diff --git a/docs/workflows/index.html b/docs/workflows/index.html
index d252848..772af08 100644
--- a/docs/workflows/index.html
+++ b/docs/workflows/index.html
@@ -1212,7 +1212,7 @@

Workflow Files Descriptions

fitted polynomial, and to find the optimal weight of each past sample for the fitting process.
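To make the polynomial-fitting idea concrete, here is a hypothetical sketch of what the PolyfitController's prediction step boils down to. The degree and per-sample weights below are illustrative assumptions — finding good values for them is exactly what the polyfit_optimizer notebook does — and this is not the library's actual implementation:

```python
import numpy as np

# Past head positions (x-coordinate only, for brevity) at known timestamps.
t_past = np.array([0.0, 0.1, 0.2, 0.3])   # seconds
x_past = np.array([5.0, 5.4, 6.1, 7.0])   # pixels

deg = 2                                    # assumed polynomial degree
weights = np.array([0.5, 0.7, 0.9, 1.0])   # assumed weights; recent samples count more

# Weighted least-squares fit of a polynomial to the past observations ...
coeffs = np.polyfit(t_past, x_past, deg, w=weights)

# ... then sample it at a future time to predict the worm's position.
t_future = 0.4
x_pred = np.polyval(coeffs, t_future)
print(f"predicted x at t={t_future}: {x_pred:.2f}")
```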

Workflow Files Dependency Graph

-
+
 Workflow outline and the dependencies between the notebook files. Blue color (rectangles) denotes an interactive notebook file, green color (diamond) denotes intermediate outputs between different files, and the global input is in yellow color (circle). Dotted lines denote optional dependencies.
diff --git a/reference/wtracker/eval/vlc/index.html b/reference/wtracker/eval/vlc/index.html
index 397e5ec..f89ca64 100644
--- a/reference/wtracker/eval/vlc/index.html
+++ b/reference/wtracker/eval/vlc/index.html
@@ -2854,7 +2854,7 @@

create_trackbar

 name: str, val: int, maxval: int,
- onChange=<function StreamViewer.<lambda> at 0x7f88625f0160>
+ onChange=<function StreamViewer.<lambda> at 0x7f9303198160>
 )
diff --git a/reference/wtracker/utils/path_utils/index.html b/reference/wtracker/utils/path_utils/index.html
index a37c231..fbeaac3 100644
--- a/reference/wtracker/utils/path_utils/index.html
+++ b/reference/wtracker/utils/path_utils/index.html
@@ -2035,7 +2035,7 @@

Files

 extension: 'str' = '',
 scan_dirs: 'bool' = False,
 return_full_path: 'bool' = True,
- sorting_key: 'Callable[[str], Union[int, str]]' = <function Files.<lambda> at 0x7f894dc4e290>
+ sorting_key: 'Callable[[str], Union[int, str]]' = <function Files.<lambda> at 0x7f93ee7f2290>
 )
diff --git a/search/search_index.json b/search/search_index.json
index 3ff4534..9177a43 100644
--- a/search/search_index.json
+++ b/search/search_index.json
@@ -1 +1 @@
-{"config":{"indexing":"full","lang":["en"],"min_search_length":3,"prebuild_index":false,"separator":"[\\s\\-]+"},"docs":[{"location":"","text":"WTracker Description This library provides tools for worm detection and movement prediction, training predictors, and analyzing the results. It includes support for YOLO-based prediction and various simulation controllers. Features Real-time worm detection and movement prediction Logging and analysis tools CSV, logging, and YOLO controllers Documentation An official documentation website covering the entire API is available. The library is fully documented within the code base. Workflow files have elaborate documentation for usage. Installation Download the Repository Download the project repository (by clicking on code -> download zip) and extract the files in the desired location. Environment Installation Step 1 - Install mamba: Install 'Miniforge' from this link, making sure to download the right version (the one that matches the OS and CPU of the computer). If asked during installation: add to PATH. * If unsure, use this link to download mamba. Step 2 - Verify that mamba is installed correctly: 1. Navigate to the folder into which the library was downloaded. 2. Open a terminal/command prompt. 3. Enter 'mamba -h'. If no error is encountered, then mamba is installed correctly. Step 3 - Create a new environment: 1. Enter the following command - \"mamba create -n bio-proj python=3.12\". 2. Enter the command - 'mamba init'. * You can choose another name (not only 'bio-proj'). If you do, you will need to change the name field in the 'requirements.yaml' file as well. Step 4 - Activate the environment: 1. Enter the command - 'mamba activate bio-proj'. * If you used another name for the environment, replace 'bio-proj' with the name you have chosen. Step 5 - Install PyTorch: 1. Head to the PyTorch website here, where you will find a table for selecting your configuration. Select the following: 1. PyTorch Build = stable 2. OS - the operating system of the computer [Windows/Linux] 3. Package - the package manager [conda] 4. Language - the programming language [Python] 5. Compute Platform - if the computer has a GPU, select the newest version of CUDA [12.1 was tested], otherwise select CPU. 2. Copy the command below the table and enter it in the terminal/command prompt. 3. Wait until the installation is complete. That might take a while.
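After Step 5, a quick sanity check (a minimal sketch; run it with the 'bio-proj' environment activated) confirms that PyTorch is installed and shows whether a GPU is visible:

```python
import torch

# Prints the installed PyTorch version, and whether a CUDA-capable GPU
# was detected; on CPU-only machines the second line prints False.
print(torch.__version__)
print(torch.cuda.is_available())
```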
License The code is licensed under the GPL v3.0 License. TL;DR Here's what the license entails: 1. Anyone can copy, modify and distribute this software. 2. You have to include the license and copyright notice with each and every distribution. 3. You can use this software privately. 4. You can use this software for commercial purposes. 5. If you dare build your business solely from this code, you risk open-sourcing the whole code base. 6. If you modify it, you have to indicate changes made to the code. 7. Any modifications of this code base MUST be distributed with the same license, GPLv3. 8. This software is provided without warranty. 9. The software author or license cannot be held liable for any damages inflicted by the software. For more details see the license file. Contact Please open an issue in the GitHub repository if you have any questions or feedback.","title":"Home"},{"location":"#wtracker","text":"","title":"WTracker"},{"location":"#description","text":"This library provides tools for worm detection and movement prediction, training predictors, and analyzing the results. It includes support for YOLO-based prediction and various simulation controllers.","title":"Description"},{"location":"#features","text":"Real-time worm detection and movement prediction Logging and analysis tools CSV, logging, and YOLO controllers","title":"Features"},{"location":"#documentation","text":"An official documentation website covering the entire API is available. The library is fully documented within the code base. Workflow files have elaborate documentation for usage.","title":"Documentation"},{"location":"#installation","text":"","title":"Installation"},{"location":"#download-the-repository","text":"Download the project repository (by clicking on code -> download zip) and extract the files in the desired location.","title":"Download the Repository"},{"location":"#environment-installation","text":"Step 1 - Install mamba: Install 'Miniforge' from this link, making sure to download the right version (the one that matches the OS and CPU of the computer). If asked during installation: add to PATH. * If unsure, use this link to download mamba. Step 2 - Verify that mamba is installed correctly: 1. Navigate to the folder into which the library was downloaded. 2. Open a terminal/command prompt. 3. Enter 'mamba -h'. If no error is encountered, then mamba is installed correctly. Step 3 - Create a new environment: 1. Enter the following command - \"mamba create -n bio-proj python=3.12\". 2. Enter the command - 'mamba init'. * You can choose another name (not only 'bio-proj'). If you do, you will need to change the name field in the 'requirements.yaml' file as well. Step 4 - Activate the environment: 1. Enter the command - 'mamba activate bio-proj'. * If you used another name for the environment, replace 'bio-proj' with the name you have chosen. Step 5 - Install PyTorch: 1. Head to the PyTorch website here, where you will find a table for selecting your configuration. Select the following: 1. PyTorch Build = stable 2. OS - the operating system of the computer [Windows/Linux] 3. Package - the package manager [conda] 4. Language - the programming language [Python] 5. Compute Platform - if the computer has a GPU, select the newest version of CUDA [12.1 was tested], otherwise select CPU. 2. Copy the command below the table and enter it in the terminal/command prompt. 3. Wait until the installation is complete. That might take a while.
Step 6 - Install the rest of the libraries: Enter the command - 'mamba env update -f requirements.yaml -n bio-proj'","title":"Environment Installation"},{"location":"#install-the-development-environment","text":"To run the project we recommend 'Visual Studio Code' (also referred to as VS Code), a free IDE. Basic usage videos and documentation can be found here. You can download and install VS Code from here. To set up VS Code for the project you need to install several extensions. Follow this link to learn how to install extensions. The extensions needed are: - Jupyter - Python - Pylance * Some extensions may already be installed by default.","title":"Install the Development Environment"},{"location":"#usage","text":"Refer to the various '.ipynb' files for usage of each workflow.","title":"Usage"},{"location":"#license","text":"The code is licensed under the GPL v3.0 License. TL;DR Here's what the license entails: 1. Anyone can copy, modify and distribute this software. 2. You have to include the license and copyright notice with each and every distribution. 3. You can use this software privately. 4. You can use this software for commercial purposes. 5. If you dare build your business solely from this code, you risk open-sourcing the whole code base. 6. If you modify it, you have to indicate changes made to the code. 7. Any modifications of this code base MUST be distributed with the same license, GPLv3. 8. This software is provided without warranty. 9. The software author or license cannot be held liable for any damages inflicted by the software. For more details see the license file.","title":"License"},{"location":"#contact","text":"Please open an issue in the GitHub repository if you have any questions or feedback.","title":"Contact"},{"location":"docs/workflows/","text":"General workflows Here we will go over the steps to do some of the main tasks, from training a YOLO model on custom data to running simulations with different configurations. All of the main workflows have a dedicated, interactive notebook (.ipynb file) ready to use, with explanations for each step. All of the workflow notebooks are located in a dedicated folder called \"workflows\". Workflow Files Descriptions create_yolo_images.ipynb - Prepares raw frames of some experiment for the process of training a YOLO model on them. This step entails detecting the worm in selected frames and cropping a region of pre-defined size around the worms. yolo_training.ipynb - Used to train a YOLO model on a given dataset. The training dataset was prepared by annotating the images which were extracted using the notebook create_yolo_images. The annotation process can be done with RoboFlow, which is an online dataset creation and annotation tool. initialize_experiment.ipynb - In order to run system simulations on a new experiment, first it\u2019s essential to initialize the experiment. The initialization step runs the YOLO predictor on the raw experiment, detects the worm\u2019s head position in each frame and saves the detection results into a log. That log is later used for simulating different control algorithms on the experiment. In addition, the background image and worm images are extracted from the raw frames. These can be used later during analysis, to calculate the segmentation-based error.
This log is useful since, in the future, the simulator can simply read worm head positions from the log, instead of using YOLO to predict the worm\u2019s head position in every frame of interest (which is much slower, especially on computers without a dedicated graphics card). simulate.ipynb - Runs a full system simulation on some previously initialized experiment. The simulation is run by reading an experiment log produced by the initialization process - in each frame, the worm\u2019s head position is retrieved from the log. In this notebook it is possible to simulate the system with any controller and any configuration parameters, not only the ones used for the initial experiment log. Similar to the initialization process, the simulation produces a log, which is later used to analyze the system\u2019s performance and its behavior. analysis.ipynb - This notebook is used to analyze the performance of a control algorithm (controller). A log which was produced by running simulate is read and analyzed, and different plots and statistics are presented. In addition, there is an option to calculate the segmentation evaluation error, by counting how many pixels of the worm are outside of the microscope view. To this end, we use the background and worm images which were extracted during the run of the initialize_experiment notebook for this experiment. visualize.ipynb - Given a system log which was produced by simulate, this notebook is able to visually recreate the simulator\u2019s behavior. At each frame, the position of the worm\u2019s head is drawn, along with the microscope FOV and the camera FOV. This notebook is used to visually assess the performance and the behavior of the simulator, and to visually investigate what causes the system to misbehave. predictor_training - Used to train a specific simulation control algorithm. The MLPController is an algorithm that uses a neural network (NN) to predict the worm\u2019s future position. Since this algorithm is NN based, it requires training. This script is responsible for training that NN from experiment log files, which were produced by either running initialize or simulate (either works). polyfit_optimizer.ipynb - This notebook is used to tune the parameters of a specific simulation control algorithm. The PolyfitController is an algorithm that uses polynomial fitting to predict the worm\u2019s future position. A polynomial is fitted from past observations at previous time stamps, and afterwards sampled at a future time to predict the worm\u2019s position. This notebook is used to determine the optimal degree of the fitted polynomial, and to find the optimal weight of each past sample for the fitting process. Workflow Files Dependency Graph Workflow outline and the dependencies between the notebook files. Blue color (rectangles) denotes an interactive notebook file, green color (diamond) denotes intermediate outputs between different files, and the global input is in yellow color (circle). Dotted lines denote optional dependencies. Complete Workflows Conducting an Experiment Here we explain how to properly capture the footage of an experiment for the simulator. Decide on the frame rate (FPS) of the experiment. There are two distinct scenarios: If the footage is captured solely for YOLO training, a very low FPS can be used (as low as 1 FPS or even lower if possible). Ideally, a single frame would be captured every few seconds.
If a simulation is to be run on the experiment footage, a high FPS should be used, preferably at least 60 FPS. Note that the chosen frame rate should be the same frame rate on which the platform control algorithms were calibrated. The camera should be completely stationary, with the entire arena visible. The footage should be captured as distinct images for each frame, not as a continuous video. We recommend using the \"bmp\" image format, which is lossless, and is able to save a single channel if the image is grayscale. Make sure that the distinct frames are saved as images with the frame number appearing in their name, such that it's possible to read them in order (a sorting sketch follows this section). If you want to run a system simulation on the experiment, follow the steps in the initialize_experiment notebook. YOLO Model Training Below is the workflow to train a YOLO model to detect the worm's head position: Conduct a new experiment and capture the footage, as explained in the previous section. Determine the image size for the YOLO model; this size should match the desired input image size during a simulation run. At the time of writing, it is possible to pass images of different sizes to YOLO models, but they are scaled to the closest multiple of 32. This means you should use and train YOLO models on images with sizes that are a multiple of 32 when possible (a rounding sketch follows this section). A YOLO model should be trained on images of a single size; it is not expected to work well on images of different sizes without special attention. Be careful of distribution shift, which means that the training data is different from (not representative of) the real world. For example: In the training data, are the worms always in a similar position? Is the background lighting consistent with the one on the system? Is the size of each pixel the same as in the system? Are the worms in the dataset representative of the worm population? Create a set of images for annotation - Follow the instructions in the create_yolo_images Python notebook. Make sure to provide the correct image size, which was determined in the previous step. Annotate the data - The annotation process is the process of marking the actual worm head position in the extracted images. To do so, we recommend using the website RoboFlow, which provides easy-to-use tools to annotate the data and create a dataset for YOLO models. Create a YOLO dataset - If you used Roboflow to create the dataset - on the dataset page you can click on 'export dataset' and then 'show download code'. You can copy the code snippet to the appropriate place in the notebook of step 6 to download the dataset to the computer and use it for training. Follow the instructions in the yolo_training notebook and train the YOLO model. There are two approaches to tackle the challenges of distribution shift, mentioned earlier. The first approach is to carefully train the YOLO model on conditions very similar to those of the final system. The resulting model will function well, but if conditions change then the model's performance will likely degrade. The other approach is to train the model on a wide variety of settings (e.g. different lighting conditions or different magnification levels), leading to a more robust model. The benefit of this approach is that the model is more robust to changes, but a disadvantage is that such models usually require more data, and may perform slightly worse than models carefully trained on some very specific conditions.
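As noted in the capture instructions above, frames must be readable in capture order; here is a minimal sketch (the directory name, file pattern, and helper function are assumptions) of sorting frames by the number embedded in their file names:

```python
import re
from pathlib import Path

def frame_number(path: Path) -> int:
    """Extract the first integer in a file name, e.g. 'frame_0042.bmp' -> 42."""
    match = re.search(r"\d+", path.stem)
    if match is None:
        raise ValueError(f"no frame number in {path.name}")
    return int(match.group())

# Plain lexicographic order would put frame_10.bmp before frame_2.bmp;
# sorting by the extracted number restores the true capture order.
frames = sorted(Path("experiment_frames").glob("*.bmp"), key=frame_number)
```

And to make the multiple-of-32 image-size rule concrete, a small helper (the function name and tie-breaking rule are our assumptions, based on the "closest multiple of 32" behavior described above):

```python
def nearest_multiple_of_32(size: int) -> int:
    """Round an image size to the nearest multiple of 32 (ties round up)."""
    return max(32, ((size + 16) // 32) * 32)

print(nearest_multiple_of_32(400))  # 416 (400 is exactly between 384 and 416)
print(nearest_multiple_of_32(640))  # 640, already a multiple of 32
```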
Perform System Simulation Below is the workflow of performing a full system simulation on some experiment and analyzing the results. If the experiment was not initialized yet, make sure to follow the instructions in the initialize_experiment notebook. Decide on the platform control algorithm to be used. If the MLPController algorithm is chosen: the MLP controller uses a neural network that predicts future positions of the worm. If that network needs training, first run the predictor_training notebook. Note that the neural network should be trained only once. Once the network is trained, there is no need to perform this step anymore. If the PolyfitController algorithm is chosen, and the hyper-parameters of the controller should be tuned, first run the polyfit_optimizer notebook. Note that the hyper-parameters of this controller should be tuned only once. Once they are tuned, there is no need to perform this step anymore. Follow the steps in the simulate notebook. The result of running this notebook is a log file containing the full simulation log. To visualize the simulation, run the visualize notebook; to analyze the performance of the control algorithm and general statistics of the conducted experiment, run the analysis notebook. Both of these notebooks analyze the log produced by simulate.","title":"Workflows"},{"location":"docs/workflows/#general-workflows","text":"Here we will go over the steps to do some of the main tasks, from training a YOLO model on custom data to running simulations with different configurations. All of the main workflows have a dedicated, interactive notebook (.ipynb file) ready to use, with explanations for each step. All of the workflow notebooks are located in a dedicated folder called \"workflows\".","title":"General workflows"},{"location":"docs/workflows/#workflow-files-descriptions","text":"create_yolo_images.ipynb - Prepares raw frames of some experiment for the process of training a YOLO model on them. This step entails detecting the worm in selected frames and cropping a region of pre-defined size around the worms. yolo_training.ipynb - Used to train a YOLO model on a given dataset. The training dataset was prepared by annotating the images which were extracted using the notebook create_yolo_images. The annotation process can be done with RoboFlow, which is an online dataset creation and annotation tool. initialize_experiment.ipynb - In order to run system simulations on a new experiment, first it\u2019s essential to initialize the experiment. The initialization step runs the YOLO predictor on the raw experiment, detects the worm\u2019s head position in each frame and saves the detection results into a log. That log is later used for simulating different control algorithms on the experiment. In addition, the background image and worm images are extracted from the raw frames. These can be used later during analysis, to calculate the segmentation-based error. This log is useful since, in the future, the simulator can simply read worm head positions from the log, instead of using YOLO to predict the worm\u2019s head position in every frame of interest (which is much slower, especially on computers without a dedicated graphics card). simulate.ipynb - Runs a full system simulation on some previously initialized experiment. The simulation is run by reading an experiment log produced by the initialization process - in each frame, the worm\u2019s head position is retrieved from the log.
In this notebook it is possible to simulate the system with any controller and any configuration parameters, not only the ones used for the initial experiment log. Similar to the initialization process, the simulation produces a log, which is later used to analyze the system\u2019s performance and its behavior. analysis.ipynb - This notebook is used to analyze the performance of a control algorithm (controller). A log which was produced by running simulate is read and analyzed, and different plots and statistics are presented. In addition, there is an option to calculate the segmentation evaluation error, by counting how many pixels of the worm are outside of the microscope view. To this end, we use the background and worm images which were extracted during the run of the initialize_experiment notebook for this experiment. visualize.ipynb - Given a system log which was produced by simulate, this notebook is able to visually recreate the simulator\u2019s behavior. At each frame, the position of the worm\u2019s head is drawn, along with the microscope FOV and the camera FOV. This notebook is used to visually assess the performance and the behavior of the simulator, and to visually investigate what causes the system to misbehave. predictor_training - Used to train a specific simulation control algorithm. The MLPController is an algorithm that uses a neural network (NN) to predict the worm\u2019s future position. Since this algorithm is NN based, it requires training. This script is responsible for training that NN from experiment log files, which were produced by either running initialize or simulate (either works). polyfit_optimizer.ipynb - This notebook is used to tune the parameters of a specific simulation control algorithm. The PolyfitController is an algorithm that uses polynomial fitting to predict the worm\u2019s future position. A polynomial is fitted from past observations at previous time stamps, and afterwards sampled at a future time to predict the worm\u2019s position. This notebook is used to determine the optimal degree of the fitted polynomial, and to find the optimal weight of each past sample for the fitting process.","title":"Workflow Files Descriptions"},{"location":"docs/workflows/#workflow-files-dependency-graph","text":"Workflow outline and the dependencies between the notebook files. Blue color (rectangles) denotes an interactive notebook file, green color (diamond) denotes intermediate outputs between different files, and the global input is in yellow color (circle). Dotted lines denote optional dependencies.","title":"Workflow Files Dependency Graph"},{"location":"docs/workflows/#complete-workflows","text":"","title":"Complete Workflows"},{"location":"docs/workflows/#conducting-an-experiment","text":"Here we explain how to properly capture the footage of an experiment for the simulator. Decide on the frame rate (FPS) of the experiment. There are two distinct scenarios: If the footage is captured solely for YOLO training, a very low FPS can be used (as low as 1 FPS or even lower if possible). Ideally, a single frame would be captured every few seconds. If a simulation is to be run on the experiment footage, a high FPS should be used, preferably at least 60 FPS. Note that the chosen frame rate should be the same frame rate on which the platform control algorithms were calibrated. The camera should be completely stationary, with the entire arena visible.
The footage should be captured as distinct images for each frame, not as a continuous video. We recommend using the \"bmp\" image format, which is lossless, and is able to save a single channel if the image is grayscale. Make sure that the distinct frames are saved as images with the frame number appearing in their name, such that it's possible to read them in order. If you want to run a system simulation on the experiment, follow the steps in the initialize_experiment notebook.","title":"Conducting an Experiment"},{"location":"docs/workflows/#yolo-model-training","text":"Below is the workflow to train a YOLO model to detect the worm's head position: Conduct a new experiment and capture the footage, as explained in the previous section. Determine the image size for the YOLO model; this size should match the desired input image size during a simulation run. At the time of writing, it is possible to pass images of different sizes to YOLO models, but they are scaled to the closest multiple of 32. This means you should use and train YOLO models on images with sizes that are a multiple of 32 when possible. A YOLO model should be trained on images of a single size; it is not expected to work well on images of different sizes without special attention. Be careful of distribution shift, which means that the training data is different from (not representative of) the real world. For example: In the training data, are the worms always in a similar position? Is the background lighting consistent with the one on the system? Is the size of each pixel the same as in the system? Are the worms in the dataset representative of the worm population? Create a set of images for annotation - Follow the instructions in the create_yolo_images Python notebook. Make sure to provide the correct image size, which was determined in the previous step. Annotate the data - The annotation process is the process of marking the actual worm head position in the extracted images. To do so, we recommend using the website RoboFlow, which provides easy-to-use tools to annotate the data and create a dataset for YOLO models. Create a YOLO dataset - If you used Roboflow to create the dataset - on the dataset page you can click on 'export dataset' and then 'show download code'. You can copy the code snippet to the appropriate place in the notebook of step 6 to download the dataset to the computer and use it for training. Follow the instructions in the yolo_training notebook and train the YOLO model. There are two approaches to tackle the challenges of distribution shift, mentioned earlier. The first approach is to carefully train the YOLO model on conditions very similar to those of the final system. The resulting model will function well, but if conditions change then the model's performance will likely degrade. The other approach is to train the model on a wide variety of settings (e.g. different lighting conditions or different magnification levels), leading to a more robust model. The benefit of this approach is that the model is more robust to changes, but a disadvantage is that such models usually require more data, and may perform slightly worse than models carefully trained on some very specific conditions.","title":"YOLO Model Training"},{"location":"docs/workflows/#perform-system-simulation","text":"Below is the workflow of performing a full system simulation on some experiment and analyzing the results. If the experiment was not initialized yet, make sure to follow the instructions in the initialize_experiment notebook.
Decide on the platform control algorithm to be used. If the MLPController algorithm is chosen: the MLP controller uses a neural network that predicts future positions of the worm. If that network needs training, first run the predictor_training notebook. Note that the neural network should be trained only once. Once the network is trained, there is no need to perform this step anymore. If the PolyfitController algorithm is chosen, and the hyper-parameters of the controller should be tuned, first run the polyfit_optimizer notebook. Note that the hyper-parameters of this controller should be tuned only once. Once they are tuned, there is no need to perform this step anymore. Follow the steps in the simulate notebook. The result of running this notebook is a log file containing the full simulation log. To visualize the simulation, run the visualize notebook; to analyze the performance of the control algorithm and general statistics of the conducted experiment, run the analysis notebook. Both of these notebooks analyze the log produced by simulate.","title":"Perform System Simulation"},{"location":"reference/wtracker/dataset/","text":"Module wtracker.dataset View Source from wtracker.dataset.sample_extractor import SampleExtractor from wtracker.dataset.box_calculator import BoxCalculator from wtracker.dataset.bg_extractor import BGExtractor Sub-modules wtracker.dataset.bg_extractor wtracker.dataset.box_calculator wtracker.dataset.sample_extractor","title":"Index"},{"location":"reference/wtracker/dataset/#module-wtrackerdataset","text":"View Source from wtracker.dataset.sample_extractor import SampleExtractor from wtracker.dataset.box_calculator import BoxCalculator from wtracker.dataset.bg_extractor import BGExtractor","title":"Module wtracker.dataset"},{"location":"reference/wtracker/dataset/#sub-modules","text":"wtracker.dataset.bg_extractor wtracker.dataset.box_calculator wtracker.dataset.sample_extractor","title":"Sub-modules"},{"location":"reference/wtracker/dataset/bg_extractor/","text":"Module wtracker.dataset.bg_extractor View Source import numpy as np from tqdm.auto import tqdm from wtracker.utils.frame_reader import FrameReader class BGExtractor : \"\"\" A class for extracting the background from a given sequence of frames, provided by a FrameReader. Args: reader (FrameReader): The FrameReader object holding the frames to extract the background from. \"\"\" def __init__ ( self , reader : FrameReader ): self . reader = reader def calc_background ( self , num_probes : int , sampling : str = \"uniform\" , method : str = \"median\" ) -> np . ndarray : \"\"\" Calculate the background of the dataset. Args: num_probes (int): The number of probes to sample for background calculation. sampling (str, optional): The sampling method for selecting probes. Can be \"random\" or \"uniform\". \"uniform\" will select frames uniformly spaced from the FrameReader. \"random\" will select frames randomly from the FrameReader. method (str, optional): The method for calculating the background. Can be \"median\" or \"mean\". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. Returns: np.ndarray: The calculated background as a numpy array. \"\"\" assert sampling in [ \"random\" , \"uniform\" ] assert method in [ \"median\" , \"mean\" ] length = len ( self . reader ) size = min ( num_probes , length ) if sampling == \"random\" : frame_ids = np . random .
choice ( length , size = size , replace = False ) elif sampling == \"uniform\" : frame_ids = np . linspace ( 0 , length - 1 , num = size ) frame_ids = np . unique ( frame_ids . astype ( int , copy = False )) if method == \"median\" : bg = self . _calc_background_median ( frame_ids ) elif method == \"mean\" : bg = self . _calc_background_mean ( frame_ids ) return bg def _calc_background_mean ( self , frame_ids : np . ndarray ) -> np . ndarray : sum = np . zeros ( self . reader . frame_shape , dtype = np . float64 ) # read frames for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ): frame = self . reader [ frame_id ] sum += frame mean = sum / len ( frame_ids ) return mean . astype ( np . uint8 , copy = False ) def _calc_background_median ( self , frame_ids : np . ndarray ) -> np . ndarray : # get frames extracted_list = [] for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ): frame = self . reader [ frame_id ] extracted_list . append ( frame ) # calculate the median along the time axis extracted = np . stack ( extracted_list , axis = 0 ) median = np . median ( extracted , axis = 0 ) . astype ( np . uint8 , copy = False ) return median Classes BGExtractor class BGExtractor ( reader : wtracker . utils . frame_reader . FrameReader ) A class for extracting the background from a given sequence of frames, provided by a FrameReader. Attributes Name Type Description Default reader FrameReader The FrameReader object holding the frames to extract the background from. None View Source class BGExtractor : \"\"\" A class for extracting the background from a given sequence of frames, provided by a FrameReader. Args: reader (FrameReader): The FrameReader object holding the frames to extract the background from. \"\"\" def __init__ ( self , reader : FrameReader ) : self . reader = reader def calc_background ( self , num_probes : int , sampling : str = \"uniform\" , method : str = \"median\" ) -> np . ndarray : \"\"\" Calculate the background of the dataset. Args: num_probes (int): The number of probes to sample for background calculation. sampling (str, optional): The sampling method for selecting probes. Can be \" random \" or \" uniform \". \" uniform \" will select frames uniformly spaced from the FrameReader. \" random \" will select frames randomly from the FrameReader. method (str, optional): The method for calculating the background. Can be \" median \" or \" mean \". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. Returns: np.ndarray: The calculated background as a numpy array. \"\"\" assert sampling in [ \"random\", \"uniform\" ] assert method in [ \"median\", \"mean\" ] length = len ( self . reader ) size = min ( num_probes , length ) if sampling == \"random\" : frame_ids = np . random . choice ( length , size = size , replace = False ) elif sampling == \"uniform\" : frame_ids = np . linspace ( 0 , length - 1 , num = size ) frame_ids = np . unique ( frame_ids . astype ( int , copy = False )) if method == \"median\" : bg = self . _calc_background_median ( frame_ids ) elif method == \"mean\" : bg = self . _calc_background_mean ( frame_ids ) return bg def _calc_background_mean ( self , frame_ids : np . ndarray ) -> np . ndarray : sum = np . zeros ( self . reader . frame_shape , dtype = np . 
float64 ) # read frames for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ) : frame = self . reader [ frame_id ] sum += frame mean = sum / len ( frame_ids ) return mean . astype ( np . uint8 , copy = False ) def _calc_background_median ( self , frame_ids : np . ndarray ) -> np . ndarray : # get frames extracted_list = [] for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ) : frame = self . reader [ frame_id ] extracted_list . append ( frame ) # calculate the median along the time axis extracted = np . stack ( extracted_list , axis = 0 ) median = np . median ( extracted , axis = 0 ). astype ( np . uint8 , copy = False ) return median Methods calc_background def calc_background ( self , num_probes : int , sampling : str = 'uniform' , method : str = 'median' ) -> numpy . ndarray Calculate the background of the dataset. Parameters: Name Type Description Default num_probes int The number of probes to sample for background calculation. None sampling str The sampling method for selecting probes. Can be \"random\" or \"uniform\". \"uniform\" will select frames uniformly spaced from the FrameReader. \"random\" will select frames randomly from the FrameReader. None method str The method for calculating the background. Can be \"median\" or \"mean\". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. None Returns: Type Description np.ndarray The calculated background as a numpy array. View Source def calc_background ( self , num_probes : int , sampling : str = \"uniform\" , method : str = \"median\" ) -> np . ndarray : \"\"\" Calculate the background of the dataset. Args: num_probes (int): The number of probes to sample for background calculation. sampling (str, optional): The sampling method for selecting probes. Can be \" random \" or \" uniform \". \" uniform \" will select frames uniformly spaced from the FrameReader. \" random \" will select frames randomly from the FrameReader. method (str, optional): The method for calculating the background. Can be \" median \" or \" mean \". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. Returns: np.ndarray: The calculated background as a numpy array. \"\"\" assert sampling in [ \"random\" , \"uniform\" ] assert method in [ \"median\" , \"mean\" ] length = len ( self . reader ) size = min ( num_probes , length ) if sampling == \"random\" : frame_ids = np . random . choice ( length , size = size , replace = False ) elif sampling == \"uniform\" : frame_ids = np . linspace ( 0 , length - 1 , num = size ) frame_ids = np . unique ( frame_ids . astype ( int , copy = False )) if method == \"median\" : bg = self . _calc_background_median ( frame_ids ) elif method == \"mean\" : bg = self . _calc_background_mean ( frame_ids ) return bg","title":"Bg Extractor"},{"location":"reference/wtracker/dataset/bg_extractor/#module-wtrackerdatasetbg_extractor","text":"View Source import numpy as np from tqdm.auto import tqdm from wtracker.utils.frame_reader import FrameReader class BGExtractor : \"\"\" A class for extracting the background from a given sequence of frames, provided by a FrameReader. Args: reader (FrameReader): The FrameReader object holding the frames to extract the background from. \"\"\" def __init__ ( self , reader : FrameReader ): self . 
reader = reader def calc_background ( self , num_probes : int , sampling : str = \"uniform\" , method : str = \"median\" ) -> np . ndarray : \"\"\" Calculate the background of the dataset. Args: num_probes (int): The number of probes to sample for background calculation. sampling (str, optional): The sampling method for selecting probes. Can be \"random\" or \"uniform\". \"uniform\" will select frames uniformly spaced from the FrameReader. \"random\" will select frames randomly from the FrameReader. method (str, optional): The method for calculating the background. Can be \"median\" or \"mean\". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. Returns: np.ndarray: The calculated background as a numpy array. \"\"\" assert sampling in [ \"random\" , \"uniform\" ] assert method in [ \"median\" , \"mean\" ] length = len ( self . reader ) size = min ( num_probes , length ) if sampling == \"random\" : frame_ids = np . random . choice ( length , size = size , replace = False ) elif sampling == \"uniform\" : frame_ids = np . linspace ( 0 , length - 1 , num = size ) frame_ids = np . unique ( frame_ids . astype ( int , copy = False )) if method == \"median\" : bg = self . _calc_background_median ( frame_ids ) elif method == \"mean\" : bg = self . _calc_background_mean ( frame_ids ) return bg def _calc_background_mean ( self , frame_ids : np . ndarray ) -> np . ndarray : sum = np . zeros ( self . reader . frame_shape , dtype = np . float64 ) # read frames for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ): frame = self . reader [ frame_id ] sum += frame mean = sum / len ( frame_ids ) return mean . astype ( np . uint8 , copy = False ) def _calc_background_median ( self , frame_ids : np . ndarray ) -> np . ndarray : # get frames extracted_list = [] for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ): frame = self . reader [ frame_id ] extracted_list . append ( frame ) # calculate the median along the time axis extracted = np . stack ( extracted_list , axis = 0 ) median = np . median ( extracted , axis = 0 ) . astype ( np . uint8 , copy = False ) return median","title":"Module wtracker.dataset.bg_extractor"},{"location":"reference/wtracker/dataset/bg_extractor/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/dataset/bg_extractor/#bgextractor","text":"class BGExtractor ( reader : wtracker . utils . frame_reader . FrameReader ) A class for extracting the background from a given sequence of frames, provided by a FrameReader.","title":"BGExtractor"},{"location":"reference/wtracker/dataset/bg_extractor/#attributes","text":"Name Type Description Default reader FrameReader The FrameReader object holding the frames to extract the background from. None View Source class BGExtractor : \"\"\" A class for extracting the background from a given sequence of frames, provided by a FrameReader. Args: reader (FrameReader): The FrameReader object holding the frames to extract the background from. \"\"\" def __init__ ( self , reader : FrameReader ) : self . reader = reader def calc_background ( self , num_probes : int , sampling : str = \"uniform\" , method : str = \"median\" ) -> np . ndarray : \"\"\" Calculate the background of the dataset. Args: num_probes (int): The number of probes to sample for background calculation. sampling (str, optional): The sampling method for selecting probes. 
Can be \" random \" or \" uniform \". \" uniform \" will select frames uniformly spaced from the FrameReader. \" random \" will select frames randomly from the FrameReader. method (str, optional): The method for calculating the background. Can be \" median \" or \" mean \". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. Returns: np.ndarray: The calculated background as a numpy array. \"\"\" assert sampling in [ \"random\", \"uniform\" ] assert method in [ \"median\", \"mean\" ] length = len ( self . reader ) size = min ( num_probes , length ) if sampling == \"random\" : frame_ids = np . random . choice ( length , size = size , replace = False ) elif sampling == \"uniform\" : frame_ids = np . linspace ( 0 , length - 1 , num = size ) frame_ids = np . unique ( frame_ids . astype ( int , copy = False )) if method == \"median\" : bg = self . _calc_background_median ( frame_ids ) elif method == \"mean\" : bg = self . _calc_background_mean ( frame_ids ) return bg def _calc_background_mean ( self , frame_ids : np . ndarray ) -> np . ndarray : sum = np . zeros ( self . reader . frame_shape , dtype = np . float64 ) # read frames for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ) : frame = self . reader [ frame_id ] sum += frame mean = sum / len ( frame_ids ) return mean . astype ( np . uint8 , copy = False ) def _calc_background_median ( self , frame_ids : np . ndarray ) -> np . ndarray : # get frames extracted_list = [] for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ) : frame = self . reader [ frame_id ] extracted_list . append ( frame ) # calculate the median along the time axis extracted = np . stack ( extracted_list , axis = 0 ) median = np . median ( extracted , axis = 0 ). astype ( np . uint8 , copy = False ) return median","title":"Attributes"},{"location":"reference/wtracker/dataset/bg_extractor/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/dataset/bg_extractor/#calc_background","text":"def calc_background ( self , num_probes : int , sampling : str = 'uniform' , method : str = 'median' ) -> numpy . ndarray Calculate the background of the dataset. Parameters: Name Type Description Default num_probes int The number of probes to sample for background calculation. None sampling str The sampling method for selecting probes. Can be \"random\" or \"uniform\". \"uniform\" will select frames uniformly spaced from the FrameReader. \"random\" will select frames randomly from the FrameReader. None method str The method for calculating the background. Can be \"median\" or \"mean\". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. None Returns: Type Description np.ndarray The calculated background as a numpy array. View Source def calc_background ( self , num_probes : int , sampling : str = \"uniform\" , method : str = \"median\" ) -> np . ndarray : \"\"\" Calculate the background of the dataset. Args: num_probes (int): The number of probes to sample for background calculation. sampling (str, optional): The sampling method for selecting probes. Can be \" random \" or \" uniform \". \" uniform \" will select frames uniformly spaced from the FrameReader. \" random \" will select frames randomly from the FrameReader. 
method (str, optional): The method for calculating the background. Can be \" median \" or \" mean \". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. Returns: np.ndarray: The calculated background as a numpy array. \"\"\" assert sampling in [ \"random\" , \"uniform\" ] assert method in [ \"median\" , \"mean\" ] length = len ( self . reader ) size = min ( num_probes , length ) if sampling == \"random\" : frame_ids = np . random . choice ( length , size = size , replace = False ) elif sampling == \"uniform\" : frame_ids = np . linspace ( 0 , length - 1 , num = size ) frame_ids = np . unique ( frame_ids . astype ( int , copy = False )) if method == \"median\" : bg = self . _calc_background_median ( frame_ids ) elif method == \"mean\" : bg = self . _calc_background_mean ( frame_ids ) return bg","title":"calc_background"},{"location":"reference/wtracker/dataset/box_calculator/","text":"Module wtracker.dataset.box_calculator View Source import cv2 as cv import numpy as np from typing import Collection from tqdm.auto import tqdm from tqdm.contrib import concurrent from wtracker.utils.frame_reader import FrameReader from wtracker.utils.threading_utils import adjust_num_workers class BoxCalculator : \"\"\" A class for calculating bounding boxes around an object for a sequence of frames. The bounding boxes are calculated by comparing the frames to a background image. The largest contour in the difference image between the frame and the background is used to calculate the bounding box. Args: frame_reader (FrameReader): The frame reader object holing the relevant frames. background (np.ndarray): The background image of the frames in the `frame_reader` argument. diff_thresh (int, optional): Threshold value for the detecting foreground objects. Pixels with difference value greater than this threshold are considered as foreground. \"\"\" def __init__ ( self , frame_reader : FrameReader , background : np . ndarray , diff_thresh : int = 20 , ) -> None : assert diff_thresh > 0 , \"Difference threshold must be greater than 0.\" assert frame_reader . frame_shape == background . shape , \"Background shape must match frame shape.\" # convert background to grayscale if needed if background . ndim == 3 and background . shape [ 2 ] == 3 : background = cv . cvtColor ( background , cv . COLOR_BGR2GRAY ) if background . ndim != 2 : raise ValueError ( \"background must be either a gray or a color image.\" ) self . _frame_reader = frame_reader self . _background = background self . _diff_thresh = diff_thresh self . _all_bboxes = np . full (( len ( frame_reader ), 4 ), - 1 , dtype = int ) def all_bboxes ( self ) -> np . ndarray : \"\"\" Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: np.ndarray: Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). \"\"\" return self . _all_bboxes def get_bbox ( self , frame_idx : int ) -> np . ndarray : \"\"\" Returns the bounding box for a given frame index. Args: frame_idx (int): The index of the frame from which to extract the bounding box. Returns: np.ndarray: The bounding box coordinates as a numpy array, in format (x, y, w, h). \"\"\" bbox = self . _all_bboxes [ frame_idx ] if bbox [ 0 ] == - 1 : # calculate bbox since it wasn't calculated before bbox = self . 
_calc_bounding_box ( frame_idx ) self . _all_bboxes [ frame_idx ] = bbox return bbox def _calc_bounding_box ( self , frame_idx : int ) -> np . ndarray : # get mask according to the threshold value frame = self . _frame_reader [ frame_idx ] # convert to grayscale if needed if frame . ndim == 3 and frame . shape [ 2 ] == 3 : frame = cv . cvtColor ( frame , cv . COLOR_BGR2GRAY ) diff = cv . absdiff ( frame , self . _background ) _ , mask = cv . threshold ( diff , self . _diff_thresh , 255 , cv . THRESH_BINARY ) # apply morphological ops to the mask mask = cv . morphologyEx ( mask , cv . MORPH_OPEN , np . ones (( 5 , 5 ), np . uint8 )) mask = cv . dilate ( mask , np . ones (( 11 , 11 ), np . uint8 )) # extract contours and bbox contours , _ = cv . findContours ( mask , cv . RETR_EXTERNAL , cv . CHAIN_APPROX_NONE ) if not contours : zero_bbox = np . array ([ 0 , 0 , 0 , 0 ]) self . _all_bboxes [ frame_idx ] = zero_bbox return zero_bbox largest_contour = max ( contours , key = cv . contourArea ) largest_bbox = cv . boundingRect ( largest_contour ) largest_bbox = np . asanyarray ( largest_bbox , dtype = int ) return largest_bbox def calc_specified_boxes ( self , frame_indices : Collection [ int ], num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. \"\"\" num_workers = adjust_num_workers ( len ( frame_indices ), chunk_size , num_workers ) if num_workers > 0 : bbox_list = concurrent . process_map ( self . get_bbox , frame_indices , max_workers = num_workers , chunksize = chunk_size , desc = \"Extracting bboxes\" , unit = \"fr\" , ) for idx , bbox in zip ( frame_indices , bbox_list ): self . _all_bboxes [ idx ] = bbox else : for idx in tqdm ( frame_indices , desc = \"Extracting bboxes\" , unit = \"fr\" ): self . get_bbox ( idx ) bboxes = self . _all_bboxes [ frame_indices , :] return bboxes def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for all frames. Args: num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: Array of bounding boxes for all frames. \"\"\" indices = range ( len ( self . _frame_reader )) return self . calc_specified_boxes ( indices , num_workers , chunk_size ) Classes BoxCalculator class BoxCalculator ( frame_reader : wtracker . utils . frame_reader . FrameReader , background : numpy . ndarray , diff_thresh : int = 20 ) A class for calculating bounding boxes around an object for a sequence of frames. The bounding boxes are calculated by comparing the frames to a background image. The largest contour in the difference image between the frame and the background is used to calculate the bounding box. Attributes Name Type Description Default frame_reader FrameReader The frame reader object holing the relevant frames. None background np.ndarray The background image of the frames in the frame_reader argument. 
None diff_thresh int Threshold value for the detecting foreground objects. Pixels with difference value greater than this threshold are considered as foreground. None View Source class BoxCalculator : \"\"\" A class for calculating bounding boxes around an object for a sequence of frames. The bounding boxes are calculated by comparing the frames to a background image. The largest contour in the difference image between the frame and the background is used to calculate the bounding box. Args: frame_reader (FrameReader): The frame reader object holing the relevant frames. background (np.ndarray): The background image of the frames in the `frame_reader` argument. diff_thresh (int, optional): Threshold value for the detecting foreground objects. Pixels with difference value greater than this threshold are considered as foreground. \"\"\" def __init__ ( self , frame_reader : FrameReader , background : np . ndarray , diff_thresh : int = 20 , ) -> None : assert diff_thresh > 0 , \"Difference threshold must be greater than 0.\" assert frame_reader . frame_shape == background . shape , \"Background shape must match frame shape.\" # convert background to grayscale if needed if background . ndim == 3 and background . shape [ 2 ] == 3 : background = cv . cvtColor ( background , cv . COLOR_BGR2GRAY ) if background . ndim != 2 : raise ValueError ( \"background must be either a gray or a color image.\" ) self . _frame_reader = frame_reader self . _background = background self . _diff_thresh = diff_thresh self . _all_bboxes = np . full (( len ( frame_reader ), 4 ), - 1 , dtype = int ) def all_bboxes ( self ) -> np . ndarray : \"\"\" Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: np.ndarray: Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). \"\"\" return self . _all_bboxes def get_bbox ( self , frame_idx : int ) -> np . ndarray : \"\"\" Returns the bounding box for a given frame index. Args: frame_idx (int): The index of the frame from which to extract the bounding box. Returns: np.ndarray: The bounding box coordinates as a numpy array, in format (x, y, w, h). \"\"\" bbox = self . _all_bboxes [ frame_idx ] if bbox [ 0 ] == - 1 : # calculate bbox since it wasn ' t calculated before bbox = self . _calc_bounding_box ( frame_idx ) self . _all_bboxes [ frame_idx ] = bbox return bbox def _calc_bounding_box ( self , frame_idx : int ) -> np . ndarray : # get mask according to the threshold value frame = self . _frame_reader [ frame_idx ] # convert to grayscale if needed if frame . ndim == 3 and frame . shape [ 2 ] == 3 : frame = cv . cvtColor ( frame , cv . COLOR_BGR2GRAY ) diff = cv . absdiff ( frame , self . _background ) _ , mask = cv . threshold ( diff , self . _diff_thresh , 255 , cv . THRESH_BINARY ) # apply morphological ops to the mask mask = cv . morphologyEx ( mask , cv . MORPH_OPEN , np . ones (( 5 , 5 ), np . uint8 )) mask = cv . dilate ( mask , np . ones (( 11 , 11 ), np . uint8 )) # extract contours and bbox contours , _ = cv . findContours ( mask , cv . RETR_EXTERNAL , cv . CHAIN_APPROX_NONE ) if not contours : zero_bbox = np . array ( [ 0, 0, 0, 0 ] ) self . _all_bboxes [ frame_idx ] = zero_bbox return zero_bbox largest_contour = max ( contours , key = cv . contourArea ) largest_bbox = cv . boundingRect ( largest_contour ) largest_bbox = np . 
asanyarray ( largest_bbox , dtype = int ) return largest_bbox def calc_specified_boxes ( self , frame_indices : Collection [ int ] , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. \"\"\" num_workers = adjust_num_workers ( len ( frame_indices ), chunk_size , num_workers ) if num_workers > 0 : bbox_list = concurrent . process_map ( self . get_bbox , frame_indices , max_workers = num_workers , chunksize = chunk_size , desc = \"Extracting bboxes\" , unit = \"fr\" , ) for idx , bbox in zip ( frame_indices , bbox_list ) : self . _all_bboxes [ idx ] = bbox else : for idx in tqdm ( frame_indices , desc = \"Extracting bboxes\" , unit = \"fr\" ) : self . get_bbox ( idx ) bboxes = self . _all_bboxes [ frame_indices, : ] return bboxes def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for all frames. Args: num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: Array of bounding boxes for all frames. \"\"\" indices = range ( len ( self . _frame_reader )) return self . calc_specified_boxes ( indices , num_workers , chunk_size ) Methods all_bboxes def all_bboxes ( self ) -> numpy . ndarray Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: Type Description np.ndarray Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). View Source def all_bboxes ( self ) -> np . ndarray : \"\"\" Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: np.ndarray: Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). \"\"\" return self . _all_bboxes calc_all_boxes def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 ) -> numpy . ndarray Calculate bounding boxes for all frames. Parameters: Name Type Description Default num_workers int Number of workers for parallel processing. If None is provided then number of workers is determined automatically. None chunk_size int Size of each chunk for parallel processing. None Returns: Type Description np.ndarray Array of bounding boxes for all frames. View Source def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for all frames. Args: num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: Array of bounding boxes for all frames. \"\"\" indices = range ( len ( self . 
_frame_reader )) return self . calc_specified_boxes ( indices , num_workers , chunk_size ) calc_specified_boxes def calc_specified_boxes ( self , frame_indices : Collection [ int ], num_workers : int = None , chunk_size : int = 50 ) -> numpy . ndarray Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. View Source def calc_specified_boxes ( self , frame_indices : Collection [ int ] , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. \"\"\" num_workers = adjust_num_workers ( len ( frame_indices ), chunk_size , num_workers ) if num_workers > 0 : bbox_list = concurrent . process_map ( self . get_bbox , frame_indices , max_workers = num_workers , chunksize = chunk_size , desc = \"Extracting bboxes\" , unit = \"fr\" , ) for idx , bbox in zip ( frame_indices , bbox_list ) : self . _all_bboxes [ idx ] = bbox else : for idx in tqdm ( frame_indices , desc = \"Extracting bboxes\" , unit = \"fr\" ) : self . get_bbox ( idx ) bboxes = self . _all_bboxes [ frame_indices, : ] return bboxes get_bbox def get_bbox ( self , frame_idx : int ) -> numpy . ndarray Returns the bounding box for a given frame index. Parameters: Name Type Description Default frame_idx int The index of the frame from which to extract the bounding box. None Returns: Type Description np.ndarray The bounding box coordinates as a numpy array, in format (x, y, w, h). View Source def get_bbox ( self , frame_idx : int ) -> np . ndarray : \"\"\" Returns the bounding box for a given frame index. Args: frame_idx (int): The index of the frame from which to extract the bounding box. Returns: np.ndarray: The bounding box coordinates as a numpy array, in format (x, y, w, h). \"\"\" bbox = self . _all_bboxes [ frame_idx ] if bbox [ 0 ] == - 1 : # calculate bbox since it wasn ' t calculated before bbox = self . _calc_bounding_box ( frame_idx ) self . _all_bboxes [ frame_idx ] = bbox return bbox","title":"Box Calculator"},{"location":"reference/wtracker/dataset/box_calculator/#module-wtrackerdatasetbox_calculator","text":"View Source import cv2 as cv import numpy as np from typing import Collection from tqdm.auto import tqdm from tqdm.contrib import concurrent from wtracker.utils.frame_reader import FrameReader from wtracker.utils.threading_utils import adjust_num_workers class BoxCalculator : \"\"\" A class for calculating bounding boxes around an object for a sequence of frames. The bounding boxes are calculated by comparing the frames to a background image. The largest contour in the difference image between the frame and the background is used to calculate the bounding box. Args: frame_reader (FrameReader): The frame reader object holing the relevant frames. 
background (np.ndarray): The background image of the frames in the `frame_reader` argument. diff_thresh (int, optional): Threshold value for the detecting foreground objects. Pixels with difference value greater than this threshold are considered as foreground. \"\"\" def __init__ ( self , frame_reader : FrameReader , background : np . ndarray , diff_thresh : int = 20 , ) -> None : assert diff_thresh > 0 , \"Difference threshold must be greater than 0.\" assert frame_reader . frame_shape == background . shape , \"Background shape must match frame shape.\" # convert background to grayscale if needed if background . ndim == 3 and background . shape [ 2 ] == 3 : background = cv . cvtColor ( background , cv . COLOR_BGR2GRAY ) if background . ndim != 2 : raise ValueError ( \"background must be either a gray or a color image.\" ) self . _frame_reader = frame_reader self . _background = background self . _diff_thresh = diff_thresh self . _all_bboxes = np . full (( len ( frame_reader ), 4 ), - 1 , dtype = int ) def all_bboxes ( self ) -> np . ndarray : \"\"\" Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: np.ndarray: Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). \"\"\" return self . _all_bboxes def get_bbox ( self , frame_idx : int ) -> np . ndarray : \"\"\" Returns the bounding box for a given frame index. Args: frame_idx (int): The index of the frame from which to extract the bounding box. Returns: np.ndarray: The bounding box coordinates as a numpy array, in format (x, y, w, h). \"\"\" bbox = self . _all_bboxes [ frame_idx ] if bbox [ 0 ] == - 1 : # calculate bbox since it wasn't calculated before bbox = self . _calc_bounding_box ( frame_idx ) self . _all_bboxes [ frame_idx ] = bbox return bbox def _calc_bounding_box ( self , frame_idx : int ) -> np . ndarray : # get mask according to the threshold value frame = self . _frame_reader [ frame_idx ] # convert to grayscale if needed if frame . ndim == 3 and frame . shape [ 2 ] == 3 : frame = cv . cvtColor ( frame , cv . COLOR_BGR2GRAY ) diff = cv . absdiff ( frame , self . _background ) _ , mask = cv . threshold ( diff , self . _diff_thresh , 255 , cv . THRESH_BINARY ) # apply morphological ops to the mask mask = cv . morphologyEx ( mask , cv . MORPH_OPEN , np . ones (( 5 , 5 ), np . uint8 )) mask = cv . dilate ( mask , np . ones (( 11 , 11 ), np . uint8 )) # extract contours and bbox contours , _ = cv . findContours ( mask , cv . RETR_EXTERNAL , cv . CHAIN_APPROX_NONE ) if not contours : zero_bbox = np . array ([ 0 , 0 , 0 , 0 ]) self . _all_bboxes [ frame_idx ] = zero_bbox return zero_bbox largest_contour = max ( contours , key = cv . contourArea ) largest_bbox = cv . boundingRect ( largest_contour ) largest_bbox = np . asanyarray ( largest_bbox , dtype = int ) return largest_bbox def calc_specified_boxes ( self , frame_indices : Collection [ int ], num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. 
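Example (an illustrative call; `calc` denotes an existing BoxCalculator instance): boxes = calc.calc_specified_boxes(range(100)) computes the boxes of the first 100 frames, caches them, and returns them as an array of shape (100, 4).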
Returns: np.ndarray: The calculated boxes for the specified frames. \"\"\" num_workers = adjust_num_workers ( len ( frame_indices ), chunk_size , num_workers ) if num_workers > 0 : bbox_list = concurrent . process_map ( self . get_bbox , frame_indices , max_workers = num_workers , chunksize = chunk_size , desc = \"Extracting bboxes\" , unit = \"fr\" , ) for idx , bbox in zip ( frame_indices , bbox_list ): self . _all_bboxes [ idx ] = bbox else : for idx in tqdm ( frame_indices , desc = \"Extracting bboxes\" , unit = \"fr\" ): self . get_bbox ( idx ) bboxes = self . _all_bboxes [ frame_indices , :] return bboxes def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for all frames. Args: num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: Array of bounding boxes for all frames. \"\"\" indices = range ( len ( self . _frame_reader )) return self . calc_specified_boxes ( indices , num_workers , chunk_size )","title":"Module wtracker.dataset.box_calculator"},{"location":"reference/wtracker/dataset/box_calculator/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/dataset/box_calculator/#boxcalculator","text":"class BoxCalculator ( frame_reader : wtracker . utils . frame_reader . FrameReader , background : numpy . ndarray , diff_thresh : int = 20 ) A class for calculating bounding boxes around an object for a sequence of frames. The bounding boxes are calculated by comparing the frames to a background image. The largest contour in the difference image between the frame and the background is used to calculate the bounding box.","title":"BoxCalculator"},{"location":"reference/wtracker/dataset/box_calculator/#attributes","text":"Name Type Description Default frame_reader FrameReader The frame reader object holing the relevant frames. None background np.ndarray The background image of the frames in the frame_reader argument. None diff_thresh int Threshold value for the detecting foreground objects. Pixels with difference value greater than this threshold are considered as foreground. None View Source class BoxCalculator : \"\"\" A class for calculating bounding boxes around an object for a sequence of frames. The bounding boxes are calculated by comparing the frames to a background image. The largest contour in the difference image between the frame and the background is used to calculate the bounding box. Args: frame_reader (FrameReader): The frame reader object holing the relevant frames. background (np.ndarray): The background image of the frames in the `frame_reader` argument. diff_thresh (int, optional): Threshold value for the detecting foreground objects. Pixels with difference value greater than this threshold are considered as foreground. \"\"\" def __init__ ( self , frame_reader : FrameReader , background : np . ndarray , diff_thresh : int = 20 , ) -> None : assert diff_thresh > 0 , \"Difference threshold must be greater than 0.\" assert frame_reader . frame_shape == background . shape , \"Background shape must match frame shape.\" # convert background to grayscale if needed if background . ndim == 3 and background . shape [ 2 ] == 3 : background = cv . cvtColor ( background , cv . COLOR_BGR2GRAY ) if background . ndim != 2 : raise ValueError ( \"background must be either a gray or a color image.\" ) self . 
_frame_reader = frame_reader self . _background = background self . _diff_thresh = diff_thresh self . _all_bboxes = np . full (( len ( frame_reader ), 4 ), - 1 , dtype = int ) def all_bboxes ( self ) -> np . ndarray : \"\"\" Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: np.ndarray: Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). \"\"\" return self . _all_bboxes def get_bbox ( self , frame_idx : int ) -> np . ndarray : \"\"\" Returns the bounding box for a given frame index. Args: frame_idx (int): The index of the frame from which to extract the bounding box. Returns: np.ndarray: The bounding box coordinates as a numpy array, in format (x, y, w, h). \"\"\" bbox = self . _all_bboxes [ frame_idx ] if bbox [ 0 ] == - 1 : # calculate bbox since it wasn ' t calculated before bbox = self . _calc_bounding_box ( frame_idx ) self . _all_bboxes [ frame_idx ] = bbox return bbox def _calc_bounding_box ( self , frame_idx : int ) -> np . ndarray : # get mask according to the threshold value frame = self . _frame_reader [ frame_idx ] # convert to grayscale if needed if frame . ndim == 3 and frame . shape [ 2 ] == 3 : frame = cv . cvtColor ( frame , cv . COLOR_BGR2GRAY ) diff = cv . absdiff ( frame , self . _background ) _ , mask = cv . threshold ( diff , self . _diff_thresh , 255 , cv . THRESH_BINARY ) # apply morphological ops to the mask mask = cv . morphologyEx ( mask , cv . MORPH_OPEN , np . ones (( 5 , 5 ), np . uint8 )) mask = cv . dilate ( mask , np . ones (( 11 , 11 ), np . uint8 )) # extract contours and bbox contours , _ = cv . findContours ( mask , cv . RETR_EXTERNAL , cv . CHAIN_APPROX_NONE ) if not contours : zero_bbox = np . array ( [ 0, 0, 0, 0 ] ) self . _all_bboxes [ frame_idx ] = zero_bbox return zero_bbox largest_contour = max ( contours , key = cv . contourArea ) largest_bbox = cv . boundingRect ( largest_contour ) largest_bbox = np . asanyarray ( largest_bbox , dtype = int ) return largest_bbox def calc_specified_boxes ( self , frame_indices : Collection [ int ] , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. \"\"\" num_workers = adjust_num_workers ( len ( frame_indices ), chunk_size , num_workers ) if num_workers > 0 : bbox_list = concurrent . process_map ( self . get_bbox , frame_indices , max_workers = num_workers , chunksize = chunk_size , desc = \"Extracting bboxes\" , unit = \"fr\" , ) for idx , bbox in zip ( frame_indices , bbox_list ) : self . _all_bboxes [ idx ] = bbox else : for idx in tqdm ( frame_indices , desc = \"Extracting bboxes\" , unit = \"fr\" ) : self . get_bbox ( idx ) bboxes = self . _all_bboxes [ frame_indices, : ] return bboxes def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for all frames. Args: num_workers (int, optional): Number of workers for parallel processing. 
If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: Array of bounding boxes for all frames. \"\"\" indices = range ( len ( self . _frame_reader )) return self . calc_specified_boxes ( indices , num_workers , chunk_size )","title":"Attributes"},{"location":"reference/wtracker/dataset/box_calculator/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/dataset/box_calculator/#all_bboxes","text":"def all_bboxes ( self ) -> numpy . ndarray Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: Type Description np.ndarray Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). View Source def all_bboxes ( self ) -> np . ndarray : \"\"\" Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: np.ndarray: Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). \"\"\" return self . _all_bboxes","title":"all_bboxes"},{"location":"reference/wtracker/dataset/box_calculator/#calc_all_boxes","text":"def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 ) -> numpy . ndarray Calculate bounding boxes for all frames. Parameters: Name Type Description Default num_workers int Number of workers for parallel processing. If None is provided then number of workers is determined automatically. None chunk_size int Size of each chunk for parallel processing. None Returns: Type Description np.ndarray Array of bounding boxes for all frames. View Source def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for all frames. Args: num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: Array of bounding boxes for all frames. \"\"\" indices = range ( len ( self . _frame_reader )) return self . calc_specified_boxes ( indices , num_workers , chunk_size )","title":"calc_all_boxes"},{"location":"reference/wtracker/dataset/box_calculator/#calc_specified_boxes","text":"def calc_specified_boxes ( self , frame_indices : Collection [ int ], num_workers : int = None , chunk_size : int = 50 ) -> numpy . ndarray Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. View Source def calc_specified_boxes ( self , frame_indices : Collection [ int ] , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. 
If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. \"\"\" num_workers = adjust_num_workers ( len ( frame_indices ), chunk_size , num_workers ) if num_workers > 0 : bbox_list = concurrent . process_map ( self . get_bbox , frame_indices , max_workers = num_workers , chunksize = chunk_size , desc = \"Extracting bboxes\" , unit = \"fr\" , ) for idx , bbox in zip ( frame_indices , bbox_list ) : self . _all_bboxes [ idx ] = bbox else : for idx in tqdm ( frame_indices , desc = \"Extracting bboxes\" , unit = \"fr\" ) : self . get_bbox ( idx ) bboxes = self . _all_bboxes [ frame_indices, : ] return bboxes","title":"calc_specified_boxes"},{"location":"reference/wtracker/dataset/box_calculator/#get_bbox","text":"def get_bbox ( self , frame_idx : int ) -> numpy . ndarray Returns the bounding box for a given frame index. Parameters: Name Type Description Default frame_idx int The index of the frame from which to extract the bounding box. None Returns: Type Description np.ndarray The bounding box coordinates as a numpy array, in format (x, y, w, h). View Source def get_bbox ( self , frame_idx : int ) -> np . ndarray : \"\"\" Returns the bounding box for a given frame index. Args: frame_idx (int): The index of the frame from which to extract the bounding box. Returns: np.ndarray: The bounding box coordinates as a numpy array, in format (x, y, w, h). \"\"\" bbox = self . _all_bboxes [ frame_idx ] if bbox [ 0 ] == - 1 : # calculate bbox since it wasn ' t calculated before bbox = self . _calc_bounding_box ( frame_idx ) self . _all_bboxes [ frame_idx ] = bbox return bbox","title":"get_bbox"},{"location":"reference/wtracker/dataset/sample_extractor/","text":"Module wtracker.dataset.sample_extractor View Source import numpy as np from typing import Collection from wtracker.dataset.box_calculator import BoxCalculator from wtracker.utils.bbox_utils import BoxUtils from wtracker.utils.io_utils import FrameSaver class SampleExtractor : \"\"\" A class that extracts samples from frames based on specified parameters. Each sample is a cropped image around a bounding box which was detected in the frame. The bounding boxes are calculated using the BoxCalculator class. This class is used to create image datasets for training object detection models. Args: bbox_calculator (BoxCalculator): An instance of the BoxCalculator class. \"\"\" def __init__ ( self , bbox_calculator : BoxCalculator ): self . _bbox_calculator = bbox_calculator self . _frame_reader = bbox_calculator . _frame_reader def move_bboxes_into_bounds ( self , bboxes : np . ndarray , frame_size : tuple [ int , int ]) -> np . ndarray : \"\"\" Moves the bounding boxes into the bounds of the frame. Args: bboxes (np.ndarray): The bounding boxes to be moved. frame_size (tuple[int, int]): The size of the frame in the format (w, h). Returns: np.ndarray: The updated bounding boxes. Raises: ValueError: If exists a bounding box which cannot be moved into the provided bounds without resizing it. \"\"\" max_w , max_h = frame_size x , y , w , h = BoxUtils . unpack ( bboxes ) x [ x < 0 ] = 0 mask = ( x + w ) > max_w x [ mask ] = max_w - w [ mask ] y [ y < 0 ] = 0 mask = ( y + h ) > max_h y [ mask ] = max_h - h [ mask ] if np . any ( x < 0 ) or np . any ( y < 0 ): raise ValueError () if np . any ( x + w > frame_size [ 0 ]) or np . any ( y + h > frame_size [ 1 ]): raise ValueError () return BoxUtils . 
pack ( x , y , w , h ) def create_specified_samples ( self , frame_indices : Collection [ int ], target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_ {:09d} .png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates specified samples based on the given frame indices. Args: frame_indices (Collection[int]): The indices of the frames to extract samples from. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" bboxes = self . _bbox_calculator . calc_specified_boxes ( frame_indices = frame_indices , num_workers = num_workers , chunk_size = chunk_size , ) x , y , w , h = BoxUtils . unpack ( bboxes ) x -= np . random . randint ( 0 , target_size [ 0 ] - w + 1 ) y -= np . random . randint ( 0 , target_size [ 1 ] - h + 1 ) w = np . full_like ( x , target_size [ 0 ]) h = np . full_like ( x , target_size [ 1 ]) bboxes = BoxUtils . pack ( x , y , w , h ) frame_size = tuple ( reversed ( self . _frame_reader . frame_size )) # (h, w) -> (w, h) bboxes = self . move_bboxes_into_bounds ( bboxes , frame_size ) with FrameSaver ( self . _frame_reader , root_path = save_folder , desc = \"Saving samples\" , unit = \"fr\" ) as saver : for i , bbox in enumerate ( bboxes ): saver . schedule_save ( i , bbox , name_format . format ( i )) def create_samples ( self , count : int , target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_ {:09d} .png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates random samples based on a specified count. Args: count (int): The number of samples to create. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk sent to each worker. \"\"\" length = len ( self . _frame_reader ) count = min ( length , count ) frame_indices = np . random . choice ( length , size = count , replace = False ) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size ) def create_all_samples ( self , target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_ {:09d} .png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates samples for all frames. Args: target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frame_indices = range ( 0 , len ( self . _frame_reader )) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size ) Classes SampleExtractor class SampleExtractor ( bbox_calculator : wtracker . dataset . 
box_calculator . BoxCalculator ) A class that extracts samples from frames based on specified parameters. Each sample is a cropped image around a bounding box which was detected in the frame. The bounding boxes are calculated using the BoxCalculator class. This class is used to create image datasets for training object detection models. Attributes Name Type Description Default bbox_calculator BoxCalculator An instance of the BoxCalculator class. None View Source class SampleExtractor : \"\"\" A class that extracts samples from frames based on specified parameters. Each sample is a cropped image around a bounding box which was detected in the frame. The bounding boxes are calculated using the BoxCalculator class. This class is used to create image datasets for training object detection models. Args: bbox_calculator (BoxCalculator): An instance of the BoxCalculator class. \"\"\" def __init__ ( self , bbox_calculator : BoxCalculator ) : self . _bbox_calculator = bbox_calculator self . _frame_reader = bbox_calculator . _frame_reader def move_bboxes_into_bounds ( self , bboxes : np . ndarray , frame_size : tuple [ int, int ] ) -> np . ndarray : \"\"\" Moves the bounding boxes into the bounds of the frame. Args: bboxes (np.ndarray): The bounding boxes to be moved. frame_size (tuple[int, int]): The size of the frame in the format (w, h). Returns: np.ndarray: The updated bounding boxes. Raises: ValueError: If exists a bounding box which cannot be moved into the provided bounds without resizing it. \"\"\" max_w , max_h = frame_size x , y , w , h = BoxUtils . unpack ( bboxes ) x [ x < 0 ] = 0 mask = ( x + w ) > max_w x [ mask ] = max_w - w [ mask ] y [ y < 0 ] = 0 mask = ( y + h ) > max_h y [ mask ] = max_h - h [ mask ] if np . any ( x < 0 ) or np . any ( y < 0 ) : raise ValueError () if np . any ( x + w > frame_size [ 0 ] ) or np . any ( y + h > frame_size [ 1 ] ) : raise ValueError () return BoxUtils . pack ( x , y , w , h ) def create_specified_samples ( self , frame_indices : Collection [ int ] , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates specified samples based on the given frame indices. Args: frame_indices (Collection[int]): The indices of the frames to extract samples from. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" bboxes = self . _bbox_calculator . calc_specified_boxes ( frame_indices = frame_indices , num_workers = num_workers , chunk_size = chunk_size , ) x , y , w , h = BoxUtils . unpack ( bboxes ) x -= np . random . randint ( 0 , target_size [ 0 ] - w + 1 ) y -= np . random . randint ( 0 , target_size [ 1 ] - h + 1 ) w = np . full_like ( x , target_size [ 0 ] ) h = np . full_like ( x , target_size [ 1 ] ) bboxes = BoxUtils . pack ( x , y , w , h ) frame_size = tuple ( reversed ( self . _frame_reader . frame_size )) # ( h , w ) -> ( w , h ) bboxes = self . move_bboxes_into_bounds ( bboxes , frame_size ) with FrameSaver ( self . _frame_reader , root_path = save_folder , desc = \"Saving samples\" , unit = \"fr\" ) as saver : for i , bbox in enumerate ( bboxes ) : saver . 
schedule_save ( i , bbox , name_format . format ( i )) def create_samples ( self , count : int , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates random samples based on a specified count. Args: count (int): The number of samples to create. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk sent to each worker. \"\"\" length = len ( self . _frame_reader ) count = min ( length , count ) frame_indices = np . random . choice ( length , size = count , replace = False ) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size ) def create_all_samples ( self , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates samples for all frames. Args: target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frame_indices = range ( 0 , len ( self . _frame_reader )) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size ) Methods create_all_samples def create_all_samples ( self , target_size : tuple [ int , int ], save_folder : str , name_format : str = 'img_ {:09d} .png' , num_workers : int = None , chunk_size : int = 50 ) Creates samples for all frames. Parameters: Name Type Description Default target_size tuple[int, int] The target size of the samples in the format (w, h). None save_folder str The folder path to save the samples. None name_format str The format of the sample names. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk. None View Source def create_all_samples( self, target_size: tuple[int, int], save_folder: str, name_format: str = \"img_{:09d}.png\", num_workers: int = None, chunk_size: int = 50, ): \"\"\" Creates samples for all frames. Args: target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. 
\"\"\" frame_indices = range(0, len(self._frame_reader)) self.create_specified_samples(frame_indices, target_size, save_folder, name_format, num_workers, chunk_size) create_samples def create_samples ( self , count : int , target_size : tuple [ int , int ], save_folder : str , name_format : str = 'img_ {:09d} .png' , num_workers : int = None , chunk_size : int = 50 ) Creates random samples based on a specified count. Parameters: Name Type Description Default count int The number of samples to create. None target_size tuple[int, int] The target size of the samples in the format (w, h). None save_folder str The folder path to save the samples. None name_format str The format of the sample names. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk sent to each worker. None View Source def create_samples ( self , count : int , target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates random samples based on a specified count. Args: count (int): The number of samples to create. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk sent to each worker. \"\"\" length = len ( self . _frame_reader ) count = min ( length , count ) frame_indices = np . random . choice ( length , size = count , replace = False ) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size ) create_specified_samples def create_specified_samples ( self , frame_indices : Collection [ int ], target_size : tuple [ int , int ], save_folder : str , name_format : str = 'img_ {:09d} .png' , num_workers : int = None , chunk_size : int = 50 ) Creates specified samples based on the given frame indices. Parameters: Name Type Description Default frame_indices Collection[int] The indices of the frames to extract samples from. None target_size tuple[int, int] The target size of the samples in the format (w, h). None save_folder str The folder path to save the samples. None name_format str The format of the sample names. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk. None View Source def create_specified_samples ( self , frame_indices : Collection [ int ] , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates specified samples based on the given frame indices. Args: frame_indices (Collection[int]): The indices of the frames to extract samples from. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. 
chunk_size (int, optional): The size of each processing chunk. \"\"\" bboxes = self . _bbox_calculator . calc_specified_boxes ( frame_indices = frame_indices , num_workers = num_workers , chunk_size = chunk_size , ) x , y , w , h = BoxUtils . unpack ( bboxes ) x -= np . random . randint ( 0 , target_size [ 0 ] - w + 1 ) y -= np . random . randint ( 0 , target_size [ 1 ] - h + 1 ) w = np . full_like ( x , target_size [ 0 ] ) h = np . full_like ( x , target_size [ 1 ] ) bboxes = BoxUtils . pack ( x , y , w , h ) frame_size = tuple ( reversed ( self . _frame_reader . frame_size )) # ( h , w ) -> ( w , h ) bboxes = self . move_bboxes_into_bounds ( bboxes , frame_size ) with FrameSaver ( self . _frame_reader , root_path = save_folder , desc = \"Saving samples\" , unit = \"fr\" ) as saver : for i , bbox in enumerate ( bboxes ) : saver . schedule_save ( i , bbox , name_format . format ( i )) move_bboxes_into_bounds def move_bboxes_into_bounds ( self , bboxes : numpy . ndarray , frame_size : tuple [ int , int ] ) -> numpy . ndarray Moves the bounding boxes into the bounds of the frame. Parameters: Name Type Description Default bboxes np.ndarray The bounding boxes to be moved. None frame_size tuple[int, int] The size of the frame in the format (w, h). None Returns: Type Description np.ndarray The updated bounding boxes. Raises: Type Description ValueError If exists a bounding box which cannot be moved into the provided bounds without resizing it. View Source def move_bboxes_into_bounds ( self , bboxes : np . ndarray , frame_size : tuple [ int, int ] ) -> np . ndarray : \"\"\" Moves the bounding boxes into the bounds of the frame. Args: bboxes (np.ndarray): The bounding boxes to be moved. frame_size (tuple[int, int]): The size of the frame in the format (w, h). Returns: np.ndarray: The updated bounding boxes. Raises: ValueError: If exists a bounding box which cannot be moved into the provided bounds without resizing it. \"\"\" max_w , max_h = frame_size x , y , w , h = BoxUtils . unpack ( bboxes ) x [ x < 0 ] = 0 mask = ( x + w ) > max_w x [ mask ] = max_w - w [ mask ] y [ y < 0 ] = 0 mask = ( y + h ) > max_h y [ mask ] = max_h - h [ mask ] if np . any ( x < 0 ) or np . any ( y < 0 ) : raise ValueError () if np . any ( x + w > frame_size [ 0 ] ) or np . any ( y + h > frame_size [ 1 ] ) : raise ValueError () return BoxUtils . pack ( x , y , w , h )","title":"Sample Extractor"},{"location":"reference/wtracker/dataset/sample_extractor/#module-wtrackerdatasetsample_extractor","text":"View Source import numpy as np from typing import Collection from wtracker.dataset.box_calculator import BoxCalculator from wtracker.utils.bbox_utils import BoxUtils from wtracker.utils.io_utils import FrameSaver class SampleExtractor : \"\"\" A class that extracts samples from frames based on specified parameters. Each sample is a cropped image around a bounding box which was detected in the frame. The bounding boxes are calculated using the BoxCalculator class. This class is used to create image datasets for training object detection models. Args: bbox_calculator (BoxCalculator): An instance of the BoxCalculator class. \"\"\" def __init__ ( self , bbox_calculator : BoxCalculator ): self . _bbox_calculator = bbox_calculator self . _frame_reader = bbox_calculator . _frame_reader def move_bboxes_into_bounds ( self , bboxes : np . ndarray , frame_size : tuple [ int , int ]) -> np . ndarray : \"\"\" Moves the bounding boxes into the bounds of the frame. Args: bboxes (np.ndarray): The bounding boxes to be moved. 
frame_size (tuple[int, int]): The size of the frame in the format (w, h). Returns: np.ndarray: The updated bounding boxes. Raises: ValueError: If exists a bounding box which cannot be moved into the provided bounds without resizing it. \"\"\" max_w , max_h = frame_size x , y , w , h = BoxUtils . unpack ( bboxes ) x [ x < 0 ] = 0 mask = ( x + w ) > max_w x [ mask ] = max_w - w [ mask ] y [ y < 0 ] = 0 mask = ( y + h ) > max_h y [ mask ] = max_h - h [ mask ] if np . any ( x < 0 ) or np . any ( y < 0 ): raise ValueError () if np . any ( x + w > frame_size [ 0 ]) or np . any ( y + h > frame_size [ 1 ]): raise ValueError () return BoxUtils . pack ( x , y , w , h ) def create_specified_samples ( self , frame_indices : Collection [ int ], target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_ {:09d} .png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates specified samples based on the given frame indices. Args: frame_indices (Collection[int]): The indices of the frames to extract samples from. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" bboxes = self . _bbox_calculator . calc_specified_boxes ( frame_indices = frame_indices , num_workers = num_workers , chunk_size = chunk_size , ) x , y , w , h = BoxUtils . unpack ( bboxes ) x -= np . random . randint ( 0 , target_size [ 0 ] - w + 1 ) y -= np . random . randint ( 0 , target_size [ 1 ] - h + 1 ) w = np . full_like ( x , target_size [ 0 ]) h = np . full_like ( x , target_size [ 1 ]) bboxes = BoxUtils . pack ( x , y , w , h ) frame_size = tuple ( reversed ( self . _frame_reader . frame_size )) # (h, w) -> (w, h) bboxes = self . move_bboxes_into_bounds ( bboxes , frame_size ) with FrameSaver ( self . _frame_reader , root_path = save_folder , desc = \"Saving samples\" , unit = \"fr\" ) as saver : for i , bbox in enumerate ( bboxes ): saver . schedule_save ( i , bbox , name_format . format ( i )) def create_samples ( self , count : int , target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_ {:09d} .png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates random samples based on a specified count. Args: count (int): The number of samples to create. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk sent to each worker. \"\"\" length = len ( self . _frame_reader ) count = min ( length , count ) frame_indices = np . random . choice ( length , size = count , replace = False ) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size ) def create_all_samples ( self , target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_ {:09d} .png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates samples for all frames. 
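Example (an illustrative call; `extractor` denotes an existing SampleExtractor instance): extractor.create_all_samples(target_size=(400, 400), save_folder=\"samples\") crops one 400x400 sample around the detected bounding box of every frame and saves the images as img_000000000.png, img_000000001.png, ... inside the \"samples\" folder.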
Args: target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frame_indices = range ( 0 , len ( self . _frame_reader )) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size )","title":"Module wtracker.dataset.sample_extractor"},{"location":"reference/wtracker/dataset/sample_extractor/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/dataset/sample_extractor/#sampleextractor","text":"class SampleExtractor ( bbox_calculator : wtracker . dataset . box_calculator . BoxCalculator ) A class that extracts samples from frames based on specified parameters. Each sample is a cropped image around a bounding box which was detected in the frame. The bounding boxes are calculated using the BoxCalculator class. This class is used to create image datasets for training object detection models.","title":"SampleExtractor"},{"location":"reference/wtracker/dataset/sample_extractor/#attributes","text":"Name Type Description Default bbox_calculator BoxCalculator An instance of the BoxCalculator class. None View Source class SampleExtractor : \"\"\" A class that extracts samples from frames based on specified parameters. Each sample is a cropped image around a bounding box which was detected in the frame. The bounding boxes are calculated using the BoxCalculator class. This class is used to create image datasets for training object detection models. Args: bbox_calculator (BoxCalculator): An instance of the BoxCalculator class. \"\"\" def __init__ ( self , bbox_calculator : BoxCalculator ) : self . _bbox_calculator = bbox_calculator self . _frame_reader = bbox_calculator . _frame_reader def move_bboxes_into_bounds ( self , bboxes : np . ndarray , frame_size : tuple [ int, int ] ) -> np . ndarray : \"\"\" Moves the bounding boxes into the bounds of the frame. Args: bboxes (np.ndarray): The bounding boxes to be moved. frame_size (tuple[int, int]): The size of the frame in the format (w, h). Returns: np.ndarray: The updated bounding boxes. Raises: ValueError: If exists a bounding box which cannot be moved into the provided bounds without resizing it. \"\"\" max_w , max_h = frame_size x , y , w , h = BoxUtils . unpack ( bboxes ) x [ x < 0 ] = 0 mask = ( x + w ) > max_w x [ mask ] = max_w - w [ mask ] y [ y < 0 ] = 0 mask = ( y + h ) > max_h y [ mask ] = max_h - h [ mask ] if np . any ( x < 0 ) or np . any ( y < 0 ) : raise ValueError () if np . any ( x + w > frame_size [ 0 ] ) or np . any ( y + h > frame_size [ 1 ] ) : raise ValueError () return BoxUtils . pack ( x , y , w , h ) def create_specified_samples ( self , frame_indices : Collection [ int ] , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates specified samples based on the given frame indices. Args: frame_indices (Collection[int]): The indices of the frames to extract samples from. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. 
num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" bboxes = self . _bbox_calculator . calc_specified_boxes ( frame_indices = frame_indices , num_workers = num_workers , chunk_size = chunk_size , ) x , y , w , h = BoxUtils . unpack ( bboxes ) x -= np . random . randint ( 0 , target_size [ 0 ] - w + 1 ) y -= np . random . randint ( 0 , target_size [ 1 ] - h + 1 ) w = np . full_like ( x , target_size [ 0 ] ) h = np . full_like ( x , target_size [ 1 ] ) bboxes = BoxUtils . pack ( x , y , w , h ) frame_size = tuple ( reversed ( self . _frame_reader . frame_size )) # ( h , w ) -> ( w , h ) bboxes = self . move_bboxes_into_bounds ( bboxes , frame_size ) with FrameSaver ( self . _frame_reader , root_path = save_folder , desc = \"Saving samples\" , unit = \"fr\" ) as saver : for i , bbox in enumerate ( bboxes ) : saver . schedule_save ( i , bbox , name_format . format ( i )) def create_samples ( self , count : int , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates random samples based on a specified count. Args: count (int): The number of samples to create. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk sent to each worker. \"\"\" length = len ( self . _frame_reader ) count = min ( length , count ) frame_indices = np . random . choice ( length , size = count , replace = False ) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size ) def create_all_samples ( self , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates samples for all frames. Args: target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frame_indices = range ( 0 , len ( self . _frame_reader )) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size )","title":"Attributes"},{"location":"reference/wtracker/dataset/sample_extractor/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/dataset/sample_extractor/#create_all_samples","text":"def create_all_samples ( self , target_size : tuple [ int , int ], save_folder : str , name_format : str = 'img_ {:09d} .png' , num_workers : int = None , chunk_size : int = 50 ) Creates samples for all frames. Parameters: Name Type Description Default target_size tuple[int, int] The target size of the samples in the format (w, h). None save_folder str The folder path to save the samples. None name_format str The format of the sample names. 
None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk. None View Source def create_all_samples( self, target_size: tuple[int, int], save_folder: str, name_format: str = \"img_{:09d}.png\", num_workers: int = None, chunk_size: int = 50, ): \"\"\" Creates samples for all frames. Args: target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frame_indices = range(0, len(self._frame_reader)) self.create_specified_samples(frame_indices, target_size, save_folder, name_format, num_workers, chunk_size)","title":"create_all_samples"},{"location":"reference/wtracker/dataset/sample_extractor/#create_samples","text":"def create_samples ( self , count : int , target_size : tuple [ int , int ], save_folder : str , name_format : str = 'img_ {:09d} .png' , num_workers : int = None , chunk_size : int = 50 ) Creates random samples based on a specified count. Parameters: Name Type Description Default count int The number of samples to create. None target_size tuple[int, int] The target size of the samples in the format (w, h). None save_folder str The folder path to save the samples. None name_format str The format of the sample names. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk sent to each worker. None View Source def create_samples ( self , count : int , target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates random samples based on a specified count. Args: count (int): The number of samples to create. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk sent to each worker. \"\"\" length = len ( self . _frame_reader ) count = min ( length , count ) frame_indices = np . random . choice ( length , size = count , replace = False ) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size )","title":"create_samples"},{"location":"reference/wtracker/dataset/sample_extractor/#create_specified_samples","text":"def create_specified_samples ( self , frame_indices : Collection [ int ], target_size : tuple [ int , int ], save_folder : str , name_format : str = 'img_ {:09d} .png' , num_workers : int = None , chunk_size : int = 50 ) Creates specified samples based on the given frame indices. Parameters: Name Type Description Default frame_indices Collection[int] The indices of the frames to extract samples from. None target_size tuple[int, int] The target size of the samples in the format (w, h). None save_folder str The folder path to save the samples. 
None name_format str The format of the sample names. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk. None View Source def create_specified_samples ( self , frame_indices : Collection [ int ] , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates specified samples based on the given frame indices. Args: frame_indices (Collection[int]): The indices of the frames to extract samples from. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" bboxes = self . _bbox_calculator . calc_specified_boxes ( frame_indices = frame_indices , num_workers = num_workers , chunk_size = chunk_size , ) x , y , w , h = BoxUtils . unpack ( bboxes ) x -= np . random . randint ( 0 , target_size [ 0 ] - w + 1 ) y -= np . random . randint ( 0 , target_size [ 1 ] - h + 1 ) w = np . full_like ( x , target_size [ 0 ] ) h = np . full_like ( x , target_size [ 1 ] ) bboxes = BoxUtils . pack ( x , y , w , h ) frame_size = tuple ( reversed ( self . _frame_reader . frame_size )) # ( h , w ) -> ( w , h ) bboxes = self . move_bboxes_into_bounds ( bboxes , frame_size ) with FrameSaver ( self . _frame_reader , root_path = save_folder , desc = \"Saving samples\" , unit = \"fr\" ) as saver : for i , bbox in enumerate ( bboxes ) : saver . schedule_save ( i , bbox , name_format . format ( i ))","title":"create_specified_samples"},{"location":"reference/wtracker/dataset/sample_extractor/#move_bboxes_into_bounds","text":"def move_bboxes_into_bounds ( self , bboxes : numpy . ndarray , frame_size : tuple [ int , int ] ) -> numpy . ndarray Moves the bounding boxes into the bounds of the frame. Parameters: Name Type Description Default bboxes np.ndarray The bounding boxes to be moved. None frame_size tuple[int, int] The size of the frame in the format (w, h). None Returns: Type Description np.ndarray The updated bounding boxes. Raises: Type Description ValueError If exists a bounding box which cannot be moved into the provided bounds without resizing it. View Source def move_bboxes_into_bounds ( self , bboxes : np . ndarray , frame_size : tuple [ int, int ] ) -> np . ndarray : \"\"\" Moves the bounding boxes into the bounds of the frame. Args: bboxes (np.ndarray): The bounding boxes to be moved. frame_size (tuple[int, int]): The size of the frame in the format (w, h). Returns: np.ndarray: The updated bounding boxes. Raises: ValueError: If exists a bounding box which cannot be moved into the provided bounds without resizing it. \"\"\" max_w , max_h = frame_size x , y , w , h = BoxUtils . unpack ( bboxes ) x [ x < 0 ] = 0 mask = ( x + w ) > max_w x [ mask ] = max_w - w [ mask ] y [ y < 0 ] = 0 mask = ( y + h ) > max_h y [ mask ] = max_h - h [ mask ] if np . any ( x < 0 ) or np . any ( y < 0 ) : raise ValueError () if np . any ( x + w > frame_size [ 0 ] ) or np . any ( y + h > frame_size [ 1 ] ) : raise ValueError () return BoxUtils . 
{"location":"reference/wtracker/eval/","text":"Module wtracker.eval View Source from wtracker.eval.plotter import Plotter from wtracker.eval.data_analyzer import DataAnalyzer from wtracker.eval.error_calculator import ErrorCalculator from wtracker.eval.vlc import VLC , StreamViewer , HotKey Sub-modules wtracker.eval.data_analyzer wtracker.eval.error_calculator wtracker.eval.plotter wtracker.eval.vlc","title":"Index"},{"location":"reference/wtracker/eval/#module-wtrackereval","text":"View Source from wtracker.eval.plotter import Plotter from wtracker.eval.data_analyzer import DataAnalyzer from wtracker.eval.error_calculator import ErrorCalculator from wtracker.eval.vlc import VLC , StreamViewer , HotKey","title":"Module wtracker.eval"},{"location":"reference/wtracker/eval/#sub-modules","text":"wtracker.eval.data_analyzer wtracker.eval.error_calculator wtracker.eval.plotter wtracker.eval.vlc","title":"Sub-modules"},{"location":"reference/wtracker/eval/data_analyzer/","text":"Module wtracker.eval.data_analyzer View Source from __future__ import annotations import pandas as pd import numpy as np import tqdm.contrib.concurrent as concurrent from wtracker.sim.config import TimingConfig from wtracker.eval.error_calculator import ErrorCalculator from wtracker.utils.frame_reader import FrameReader from wtracker.utils.threading_utils import adjust_num_workers class DataAnalyzer : \"\"\" A class for analyzing simulation logs. Args: time_config (TimingConfig): The timing configuration. log_data (pd.DataFrame): Dataframe containing the simulation log data. \"\"\" def __init__ ( self , time_config : TimingConfig , log_data : pd . DataFrame , ): self . time_config = time_config self . data = log_data . copy () self . _orig_data = log_data self . _unit = \"frame\" @property def unit ( self ) -> str : return self . _unit def save ( self , path : str ) -> None : \"\"\" Save the full analyzed data to a csv file. \"\"\" self . _orig_data . to_csv ( path , index = False ) @staticmethod def load ( time_config : TimingConfig , csv_path : str ) -> DataAnalyzer : \"\"\" Create a DataAnalyzer object from a csv file containing experiment data, regardless of whether it has been analyzed or not. Args: time_config (TimingConfig): The timing configuration. csv_path (str): Path to the csv file containing the experiment data. \"\"\" data = pd . read_csv ( csv_path ) return DataAnalyzer ( time_config , data ) def initialize ( self , period : int = 10 ): \"\"\" Initializes the data analyzer. It's essential to call this function if the class was created from non-analyzed log data. Args: period (int): The period for calculating speed in frames. The speed is calculated by measuring the distance between current frame and period frames before. \"\"\" data = self . _orig_data data [ \"time\" ] = data [ \"frame\" ] data [ \"cycle_step\" ] = data [ \"frame\" ] % self . time_config . cycle_frame_num data = DataAnalyzer . _calc_centers ( data ) data = DataAnalyzer . _calc_speed ( data , period ) data = DataAnalyzer . _calc_worm_deviation ( data ) data = DataAnalyzer . _calc_errors ( data ) data = data . round ( 5 ) self . _orig_data = data self . data = self . _orig_data . copy () @staticmethod def _calc_centers ( data : pd . DataFrame ) -> pd . 
DataFrame : data [ \"wrm_center_x\" ] = data [ \"wrm_x\" ] + data [ \"wrm_w\" ] / 2 data [ \"wrm_center_y\" ] = data [ \"wrm_y\" ] + data [ \"wrm_h\" ] / 2 data [ \"mic_center_x\" ] = data [ \"mic_x\" ] + data [ \"mic_w\" ] / 2 data [ \"mic_center_y\" ] = data [ \"mic_y\" ] + data [ \"mic_h\" ] / 2 return data @staticmethod def _calc_speed ( data : pd . DataFrame , n : int ) -> pd . DataFrame : diff = data [ \"time\" ] . diff ( n ) . to_numpy () data [ \"wrm_speed_x\" ] = data [ \"wrm_center_x\" ] . diff ( n ) / diff data [ \"wrm_speed_y\" ] = data [ \"wrm_center_y\" ] . diff ( n ) / diff data [ \"wrm_speed\" ] = np . sqrt ( data [ \"wrm_speed_x\" ] ** 2 + data [ \"wrm_speed_y\" ] ** 2 ) return data @staticmethod def _calc_worm_deviation ( data : pd . DataFrame ) -> pd . DataFrame : data [ \"worm_deviation_x\" ] = data [ \"wrm_center_x\" ] - data [ \"mic_center_x\" ] data [ \"worm_deviation_y\" ] = data [ \"wrm_center_y\" ] - data [ \"mic_center_y\" ] data [ \"worm_deviation\" ] = np . sqrt ( data [ \"worm_deviation_x\" ] ** 2 + data [ \"worm_deviation_y\" ] ** 2 ) return data @staticmethod def _calc_errors ( data : pd . DataFrame ) -> pd . DataFrame : wrm_bboxes = data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () bbox_error = ErrorCalculator . calculate_bbox_error ( wrm_bboxes , mic_bboxes ) data [ \"bbox_error\" ] = bbox_error data [ \"precise_error\" ] = np . nan return data def remove_cycle ( self , cycles : int | list [ int ]): \"\"\" Remove the specified cycles from the data. Args: cycles (int | list[int]): The cycle(s) to remove from the data. \"\"\" if isinstance ( cycles , int ): cycles = [ cycles ] mask = self . data [ \"cycle\" ] . isin ( cycles ) self . data = self . data [ ~ mask ] def clean ( self , trim_cycles : bool = False , imaging_only : bool = False , bounds : tuple [ float , float , float , float ] = None , ) -> None : \"\"\" Clean the data by the provided parameters. Args: trim_cycles (bool): whether to remove the first and the last cycles from the data. imaging_only (bool): Flag indicating whether to include only imaging phases in the analysis. legal_bounds (tuple[float, float, float, float]): The legal bounds for worm movement. \"\"\" data = self . data if imaging_only : mask = data [ \"phase\" ] == \"imaging\" data = data [ mask ] if bounds is not None : has_pred = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) . all ( axis = 1 ) mask_wrm = has_pred # if there is a prediction for a frame then look at worm bbox mask_wrm &= ( data [ \"wrm_x\" ] >= bounds [ 0 ]) & ( data [ \"wrm_x\" ] + data [ \"wrm_w\" ] <= bounds [ 2 ]) mask_wrm &= ( data [ \"wrm_y\" ] >= bounds [ 1 ]) & ( data [ \"wrm_y\" ] + data [ \"wrm_h\" ] <= bounds [ 3 ]) mask_mic = ~ has_pred # if there is no prediction for a frame then look at micro bbox mask_mic &= ( data [ \"mic_x\" ] >= bounds [ 0 ]) & ( data [ \"mic_x\" ] + data [ \"mic_w\" ] <= bounds [ 2 ]) mask_mic &= ( data [ \"mic_y\" ] >= bounds [ 1 ]) & ( data [ \"mic_y\" ] + data [ \"mic_h\" ] <= bounds [ 3 ]) data = data [ mask_wrm | mask_mic ] if trim_cycles : mask = data [ \"cycle\" ] != 0 mask &= data [ \"cycle\" ] != data [ \"cycle\" ] . max () data = data [ mask ] self . data = data def reset_changes ( self ): \"\"\" Reset the data to its original state. Note, that this method will not reset the unit of time and distance. \"\"\" self . data = self . _orig_data . copy () self . 
_unit = \"frame\" def column_names ( self ) -> list [ str ]: \"\"\" Returns a list of all column names in the analyzed data. Returns: list[str]: A list of column names. \"\"\" return self . data . columns . to_list () def change_unit ( self , unit : str ): \"\"\" Changes the unit of time and distance in the data. Args: unit (str, optional): The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometer. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. \"\"\" assert unit in [ \"frame\" , \"sec\" ] if self . _unit == unit : return data = self . data if unit == \"sec\" : # frame -> sec dist_factor = self . time_config . mm_per_px * 1000 time_factor = self . time_config . ms_per_frame / 1000 if unit == \"frame\" : # sec -> frame dist_factor = self . time_config . px_per_mm / 1000 time_factor = self . time_config . frames_per_sec data [ \"time\" ] *= time_factor data [[ \"plt_x\" , \"plt_y\" ]] *= dist_factor data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] *= dist_factor data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] *= dist_factor data [[ \"cam_x\" , \"cam_y\" , \"cam_w\" , \"cam_h\" ]] *= dist_factor data [[ \"wrm_center_x\" , \"wrm_center_y\" ]] *= dist_factor data [[ \"mic_center_x\" , \"mic_center_y\" ]] *= dist_factor data [[ \"worm_deviation_x\" , \"worm_deviation_y\" , \"worm_deviation\" ]] *= dist_factor data [[ \"wrm_speed_x\" , \"wrm_speed_y\" , \"wrm_speed\" ]] *= dist_factor / time_factor self . _unit = unit self . data = data # TODO: TEST # TODO: MAYBE REMOVE, THE non-multithreaded version works very fast for me for some reason # perhaps SSD is required for fast analysis. def calc_precise_error_experimental ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , num_workers : int = None , chunk_size : int = 2000 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( int , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = np . ones_like ( frames , dtype = float ) mask = np . isfinite ( wrm_bboxes ) . all ( axis = 1 ) wrm_bboxes = wrm_bboxes [ mask ] mic_bboxes = mic_bboxes [ mask ] frames = frames [ mask ] num_sections = len ( frames ) // chunk_size wrm_bboxes_list = np . array_split ( wrm_bboxes , num_sections , axis = 0 ) mic_bboxes_list = np . array_split ( mic_bboxes , num_sections , axis = 0 ) frames_list = np . 
array_split ( frames , num_sections ) # TODO: add non-multithreaded case whenever num_workers=0 num_workers = adjust_num_workers ( len ( frames ), chunk_size , num_workers ) def calc_error ( idx : int ) -> np . ndarray : return ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes_list [ idx ], mic_bboxes = mic_bboxes_list [ idx ], frame_nums = frames_list [ idx ], worm_reader = worm_reader , diff_thresh = diff_thresh , ) results = concurrent . thread_map ( calc_error , list ( range ( len ( wrm_bboxes_list ))), max_workers = num_workers , chunksize = 1 , desc = \"Extracting bboxes\" , unit = \"fr\" , leave = False , ) # set the error in the original data errors [ mask ] = np . concatenate ( results ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_precise_error ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of the worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( np . int32 , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes , mic_bboxes = mic_bboxes , frame_nums = frames , worm_reader = worm_reader , diff_thresh = diff_thresh , ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_anomalies ( self , no_preds : bool = True , min_bbox_error : float = np . inf , min_dist_error : float = np . inf , min_speed : float = np . inf , min_size : float = np . inf , remove_anomalies : bool = False , ) -> pd . DataFrame : \"\"\" Calculate anomalies in the data based on specified criteria. Args: no_preds (bool, optional): Flag indicating whether to consider instances with missing predictions. min_bbox_error (float, optional): Minimum bounding box error threshold to consider as anomaly. min_dist_error (float, optional): Minimum distance error threshold to consider as anomaly. min_speed (float, optional): Minimum speed threshold to consider as anomaly. min_size (float, optional): Minimum size threshold to consider as anomaly. remove_anomalies (bool, optional): Flag indicating whether to remove the anomalies from the data. Returns: pd.DataFrame: DataFrame containing the anomalies found in the data. \"\"\" data = self . 
data mask_speed = data [ \"wrm_speed\" ] >= min_speed mask_bbox_error = data [ \"bbox_error\" ] >= min_bbox_error mask_dist_error = data [ \"worm_deviation\" ] >= min_dist_error mask_worm_width = data [ \"wrm_w\" ] >= min_size mask_worm_height = data [ \"wrm_h\" ] >= min_size mask_no_preds = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) . all ( axis = 1 ) == False mask_no_preds = no_preds & mask_no_preds mask = mask_speed | mask_bbox_error | mask_dist_error | mask_worm_width | mask_worm_height | mask_no_preds anomalies = data [ mask ] . copy () anomalies [ \"speed_anomaly\" ] = mask_speed [ mask ] anomalies [ \"bbox_error_anomaly\" ] = mask_bbox_error [ mask ] anomalies [ \"dist_error_anomaly\" ] = mask_dist_error [ mask ] anomalies [ \"width_anomaly\" ] = mask_worm_width [ mask ] anomalies [ \"height_anomaly\" ] = mask_worm_height [ mask ] anomalies [ \"no_pred_anomaly\" ] = mask_no_preds [ mask ] if remove_anomalies : self . data = self . data [ ~ mask ] return anomalies def describe ( self , columns : list [ str ] = None , num : int = 3 , percentiles : list [ float ] = None ) -> pd . DataFrame : \"\"\" Generate descriptive statistics of the specified columns in the table containing the data. Args: columns (list[str], optional): List of column names to include in the analysis. If None, all columns will be included. num (int, optional): Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. percentiles (list[float], optional): List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. Returns: pd.DataFrame: A DataFrame containing the descriptive statistics of the specified columns. \"\"\" if columns is None : columns = self . column_names () if percentiles is None : percentiles = np . linspace ( start = 0 , stop = 1.0 , num = num + 2 )[ 1 : - 1 ] return self . data [ columns ] . describe ( percentiles ) def print_stats ( self ) -> None : \"\"\" Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Total Count of No Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. \"\"\" num_removed = len ( self . _orig_data . index ) - len ( self . data . index ) print ( f \"Count of Removed Frames: { num_removed } ( { round ( 100 * num_removed / len ( self . _orig_data . index ), 3 ) } %)\" ) no_preds = self . data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . isna () . any ( axis = 1 ) . sum () print ( f \"Count of No-Pred Frames: { no_preds } ( { round ( 100 * no_preds / len ( self . data . index ), 3 ) } %)\" ) num_cycles = self . data [ \"cycle\" ] . nunique () print ( f \"Total Num of Cycles: { num_cycles } \" ) non_perfect = ( self . data [ \"bbox_error\" ] > 1e-7 ) . sum () / len ( self . data . index ) print ( f \"Non Perfect Predictions: { round ( 100 * non_perfect , 3 ) } %\" ) Classes DataAnalyzer class DataAnalyzer ( time_config : 'TimingConfig' , log_data : 'pd.DataFrame' ) A class for analyzing simulation logs. Attributes Name Type Description Default time_config TimingConfig The timing configuration. None log_data pd.DataFrame Dataframe containing the simulation log data. None
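Example: a minimal end-to-end analysis sketch. The csv path is hypothetical, and `timing_config` is assumed to be a TimingConfig instance prepared elsewhere (how it is loaded is not shown in this section):
from wtracker.eval.data_analyzer import DataAnalyzer
# `timing_config` is assumed to exist; "logs/experiment_log.csv" is a hypothetical path.
analyzer = DataAnalyzer.load(timing_config, "logs/experiment_log.csv")
analyzer.initialize(period=10)   # required when loading non-analyzed log data
analyzer.clean(trim_cycles=True, imaging_only=True)
analyzer.change_unit("sec")      # times in seconds, distances in micrometers
analyzer.print_stats()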
View Source class DataAnalyzer : \"\"\" A class for analyzing simulation logs. Args: time_config (TimingConfig): The timing configuration. log_data (pd.DataFrame): Dataframe containing the simulation log data. \"\"\" def __init__ ( self , time_config : TimingConfig , log_data : pd . DataFrame , ): self . time_config = time_config self . data = log_data . copy () self . _orig_data = log_data self . _unit = \"frame\" @ property def unit ( self ) -> str : return self . _unit def save ( self , path : str ) -> None : \"\"\" Save the full analyzed data to a csv file. \"\"\" self . _orig_data . to_csv ( path , index = False ) @ staticmethod def load ( time_config : TimingConfig , csv_path : str ) -> DataAnalyzer : \"\"\" Create a DataAnalyzer object from a csv file containing experiment data, regardless of whether it has been analyzed or not. Args: time_config (TimingConfig): The timing configuration. csv_path (str): Path to the csv file containing the experiment data. \"\"\" data = pd . read_csv ( csv_path ) return DataAnalyzer ( time_config , data ) def initialize ( self , period : int = 10 ): \"\"\" Initializes the data analyzer. It's essential to call this function if the class was created from non-analyzed log data. Args: period (int): The period for calculating speed in frames. The speed is calculated by measuring the distance between current frame and period frames before. \"\"\" data = self . _orig_data data [ \"time\" ] = data [ \"frame\" ] data [ \"cycle_step\" ] = data [ \"frame\" ] % self . time_config . cycle_frame_num data = DataAnalyzer . _calc_centers ( data ) data = DataAnalyzer . _calc_speed ( data , period ) data = DataAnalyzer . _calc_worm_deviation ( data ) data = DataAnalyzer . _calc_errors ( data ) data = data . round ( 5 ) self . _orig_data = data self . data = self . _orig_data . copy () @ staticmethod def _calc_centers ( data : pd . DataFrame ) -> pd . DataFrame : data [ \"wrm_center_x\" ] = data [ \"wrm_x\" ] + data [ \"wrm_w\" ] / 2 data [ \"wrm_center_y\" ] = data [ \"wrm_y\" ] + data [ \"wrm_h\" ] / 2 data [ \"mic_center_x\" ] = data [ \"mic_x\" ] + data [ \"mic_w\" ] / 2 data [ \"mic_center_y\" ] = data [ \"mic_y\" ] + data [ \"mic_h\" ] / 2 return data @ staticmethod def _calc_speed ( data : pd . DataFrame , n : int ) -> pd . DataFrame : diff = data [ \"time\" ] . diff ( n ) . to_numpy () data [ \"wrm_speed_x\" ] = data [ \"wrm_center_x\" ] . diff ( n ) / diff data [ \"wrm_speed_y\" ] = data [ \"wrm_center_y\" ] . diff ( n ) / diff data [ \"wrm_speed\" ] = np . sqrt ( data [ \"wrm_speed_x\" ] ** 2 + data [ \"wrm_speed_y\" ] ** 2 ) return data @ staticmethod def _calc_worm_deviation ( data : pd . DataFrame ) -> pd . DataFrame : data [ \"worm_deviation_x\" ] = data [ \"wrm_center_x\" ] - data [ \"mic_center_x\" ] data [ \"worm_deviation_y\" ] = data [ \"wrm_center_y\" ] - data [ \"mic_center_y\" ] data [ \"worm_deviation\" ] = np . sqrt ( data [ \"worm_deviation_x\" ] ** 2 + data [ \"worm_deviation_y\" ] ** 2 ) return data @ staticmethod def _calc_errors ( data : pd . DataFrame ) -> pd . DataFrame : wrm_bboxes = data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () bbox_error = ErrorCalculator . calculate_bbox_error ( wrm_bboxes , mic_bboxes ) data [ \"bbox_error\" ] = bbox_error data [ \"precise_error\" ] = np . nan return data def remove_cycle ( self , cycles : int | list [ int ]): \"\"\" Remove the specified cycles from the data. 
Args: cycles (int | list[int]): The cycle(s) to remove from the data. \"\"\" if isinstance ( cycles , int ): cycles = [ cycles ] mask = self . data [ \"cycle\" ] . isin ( cycles ) self . data = self . data [ ~ mask ] def clean ( self , trim_cycles : bool = False , imaging_only : bool = False , bounds : tuple [ float , float , float , float ] = None , ) -> None : \"\"\" Clean the data by the provided parameters. Args: trim_cycles (bool): whether to remove the first and the last cycles from the data. imaging_only (bool): Flag indicating whether to include only imaging phases in the analysis. bounds (tuple[float, float, float, float]): The legal bounds for worm movement. \"\"\" data = self . data if imaging_only : mask = data [ \"phase\" ] == \"imaging\" data = data [ mask ] if bounds is not None : has_pred = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) . all ( axis = 1 ) mask_wrm = has_pred # if there is a prediction for a frame then look at worm bbox mask_wrm &= ( data [ \"wrm_x\" ] >= bounds [ 0 ]) & ( data [ \"wrm_x\" ] + data [ \"wrm_w\" ] <= bounds [ 2 ]) mask_wrm &= ( data [ \"wrm_y\" ] >= bounds [ 1 ]) & ( data [ \"wrm_y\" ] + data [ \"wrm_h\" ] <= bounds [ 3 ]) mask_mic = ~ has_pred # if there is no prediction for a frame then look at micro bbox mask_mic &= ( data [ \"mic_x\" ] >= bounds [ 0 ]) & ( data [ \"mic_x\" ] + data [ \"mic_w\" ] <= bounds [ 2 ]) mask_mic &= ( data [ \"mic_y\" ] >= bounds [ 1 ]) & ( data [ \"mic_y\" ] + data [ \"mic_h\" ] <= bounds [ 3 ]) data = data [ mask_wrm | mask_mic ] if trim_cycles : mask = data [ \"cycle\" ] != 0 mask &= data [ \"cycle\" ] != data [ \"cycle\" ] . max () data = data [ mask ] self . data = data def reset_changes ( self ): \"\"\" Reset the data to its original state. Note that this method will not reset the unit of time and distance. \"\"\" self . data = self . _orig_data . copy () self . _unit = \"frame\" def column_names ( self ) -> list [ str ]: \"\"\" Returns a list of all column names in the analyzed data. Returns: list[str]: A list of column names. \"\"\" return self . data . columns . to_list () def change_unit ( self , unit : str ): \"\"\" Changes the unit of time and distance in the data. Args: unit (str): The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometers. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. \"\"\" assert unit in [ \"frame\" , \"sec\" ] if self . _unit == unit : return data = self . data if unit == \"sec\" : # frame -> sec dist_factor = self . time_config . mm_per_px * 1000 time_factor = self . time_config . ms_per_frame / 1000 if unit == \"frame\" : # sec -> frame dist_factor = self . time_config . px_per_mm / 1000 time_factor = self . time_config . frames_per_sec data [ \"time\" ] *= time_factor data [[ \"plt_x\" , \"plt_y\" ]] *= dist_factor data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] *= dist_factor data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] *= dist_factor data [[ \"cam_x\" , \"cam_y\" , \"cam_w\" , \"cam_h\" ]] *= dist_factor data [[ \"wrm_center_x\" , \"wrm_center_y\" ]] *= dist_factor data [[ \"mic_center_x\" , \"mic_center_y\" ]] *= dist_factor data [[ \"worm_deviation_x\" , \"worm_deviation_y\" , \"worm_deviation\" ]] *= dist_factor data [[ \"wrm_speed_x\" , \"wrm_speed_y\" , \"wrm_speed\" ]] *= dist_factor / time_factor self . _unit = unit self . 
data = data # TODO: TEST # TODO: MAYBE REMOVE, THE non-multithreaded version works very fast for me for some reason # perhaps SSD is required for fast analysis. def calc_precise_error_experimental ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , num_workers : int = None , chunk_size : int = 2000 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of the worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( int , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = np . ones_like ( frames , dtype = float ) mask = np . isfinite ( wrm_bboxes ) . all ( axis = 1 ) wrm_bboxes = wrm_bboxes [ mask ] mic_bboxes = mic_bboxes [ mask ] frames = frames [ mask ] num_sections = len ( frames ) // chunk_size wrm_bboxes_list = np . array_split ( wrm_bboxes , num_sections , axis = 0 ) mic_bboxes_list = np . array_split ( mic_bboxes , num_sections , axis = 0 ) frames_list = np . array_split ( frames , num_sections ) # TODO: add non-multithreaded case whenever num_workers=0 num_workers = adjust_num_workers ( len ( frames ), chunk_size , num_workers ) def calc_error ( idx : int ) -> np . ndarray : return ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes_list [ idx ], mic_bboxes = mic_bboxes_list [ idx ], frame_nums = frames_list [ idx ], worm_reader = worm_reader , diff_thresh = diff_thresh , ) results = concurrent . thread_map ( calc_error , list ( range ( len ( wrm_bboxes_list ))), max_workers = num_workers , chunksize = 1 , desc = \"Extracting bboxes\" , unit = \"fr\" , leave = False , ) # set the error in the original data errors [ mask ] = np . concatenate ( results ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_precise_error ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of the worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. 
diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( np . int32 , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes , mic_bboxes = mic_bboxes , frame_nums = frames , worm_reader = worm_reader , diff_thresh = diff_thresh , ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_anomalies ( self , no_preds : bool = True , min_bbox_error : float = np . inf , min_dist_error : float = np . inf , min_speed : float = np . inf , min_size : float = np . inf , remove_anomalies : bool = False , ) -> pd . DataFrame : \"\"\" Calculate anomalies in the data based on specified criteria. Args: no_preds (bool, optional): Flag indicating whether to consider instances with missing predictions. min_bbox_error (float, optional): Minimum bounding box error threshold to consider as anomaly. min_dist_error (float, optional): Minimum distance error threshold to consider as anomaly. min_speed (float, optional): Minimum speed threshold to consider as anomaly. min_size (float, optional): Minimum size threshold to consider as anomaly. remove_anomalies (bool, optional): Flag indicating whether to remove the anomalies from the data. Returns: pd.DataFrame: DataFrame containing the anomalies found in the data. \"\"\" data = self . data mask_speed = data [ \"wrm_speed\" ] >= min_speed mask_bbox_error = data [ \"bbox_error\" ] >= min_bbox_error mask_dist_error = data [ \"worm_deviation\" ] >= min_dist_error mask_worm_width = data [ \"wrm_w\" ] >= min_size mask_worm_height = data [ \"wrm_h\" ] >= min_size mask_no_preds = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) . all ( axis = 1 ) == False mask_no_preds = no_preds & mask_no_preds mask = mask_speed | mask_bbox_error | mask_dist_error | mask_worm_width | mask_worm_height | mask_no_preds anomalies = data [ mask ] . copy () anomalies [ \"speed_anomaly\" ] = mask_speed [ mask ] anomalies [ \"bbox_error_anomaly\" ] = mask_bbox_error [ mask ] anomalies [ \"dist_error_anomaly\" ] = mask_dist_error [ mask ] anomalies [ \"width_anomaly\" ] = mask_worm_width [ mask ] anomalies [ \"height_anomaly\" ] = mask_worm_height [ mask ] anomalies [ \"no_pred_anomaly\" ] = mask_no_preds [ mask ] if remove_anomalies : self . data = self . data [ ~ mask ] return anomalies def describe ( self , columns : list [ str ] = None , num : int = 3 , percentiles : list [ float ] = None ) -> pd . DataFrame : \"\"\" Generate descriptive statistics of the specified columns in the table containing the data. Args: columns (list[str], optional): List of column names to include in the analysis. If None, all columns will be included. num (int, optional): Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. percentiles (list[float], optional): List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. 
Returns: pd.DataFrame: A DataFrame containing the descriptive statistics of the specified columns. \"\"\" if columns is None : columns = self . column_names () if percentiles is None : percentiles = np . linspace ( start = 0 , stop = 1.0 , num = num + 2 )[ 1 : - 1 ] return self . data [ columns ] . describe ( percentiles ) def print_stats ( self ) -> None : \"\"\" Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Total Count of No Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. \"\"\" num_removed = len ( self . _orig_data . index ) - len ( self . data . index ) print ( f \"Count of Removed Frames: {num_removed} ({round(100 * num_removed / len(self._orig_data.index), 3)}%)\" ) no_preds = self . data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . isna () . any ( axis = 1 ) . sum () print ( f \"Count of No-Pred Frames: {no_preds} ({round(100 * no_preds / len(self.data.index), 3)}%)\" ) num_cycles = self . data [ \"cycle\" ] . nunique () print ( f \"Total Num of Cycles: {num_cycles}\" ) non_perfect = ( self . data [ \"bbox_error\" ] > 1e-7 ) . sum () / len ( self . data . index ) print ( f \"Non Perfect Predictions: {round(100 * non_perfect, 3)}%\" ) Static methods load def load ( time_config : 'TimingConfig' , csv_path : 'str' ) -> 'DataAnalyzer' Create a DataAnalyzer object from a csv file containing experiment data, regardless of whether it has been analyzed or not. Parameters: Name Type Description Default time_config TimingConfig The timing configuration. None csv_path str Path to the csv file containing the experiment data. None View Source @ staticmethod def load ( time_config : TimingConfig , csv_path : str ) -> DataAnalyzer : \"\"\" Create a DataAnalyzer object from a csv file containing experiment data, regardless of whether it has been analyzed or not. Args: time_config (TimingConfig): The timing configuration. csv_path (str): Path to the csv file containing the experiment data. \"\"\" data = pd . read_csv ( csv_path ) return DataAnalyzer ( time_config , data ) Instance variables unit Methods calc_anomalies def calc_anomalies ( self , no_preds : 'bool' = True , min_bbox_error : 'float' = inf , min_dist_error : 'float' = inf , min_speed : 'float' = inf , min_size : 'float' = inf , remove_anomalies : 'bool' = False ) -> 'pd.DataFrame' Calculate anomalies in the data based on specified criteria. Parameters: Name Type Description Default no_preds bool Flag indicating whether to consider instances with missing predictions. None min_bbox_error float Minimum bounding box error threshold to consider as anomaly. None min_dist_error float Minimum distance error threshold to consider as anomaly. None min_speed float Minimum speed threshold to consider as anomaly. None min_size float Minimum size threshold to consider as anomaly. None remove_anomalies bool Flag indicating whether to remove the anomalies from the data. None Returns: Type Description pd.DataFrame DataFrame containing the anomalies found in the data. View Source def calc_anomalies ( self , no_preds : bool = True , min_bbox_error : float = np . inf , min_dist_error : float = np . inf , min_speed : float = np . inf , min_size : float = np . inf , remove_anomalies : bool = False , ) -> pd . 
DataFrame : \"\"\" Calculate anomalies in the data based on specified criteria. Args: no_preds (bool, optional): Flag indicating whether to consider instances with missing predictions. min_bbox_error (float, optional): Minimum bounding box error threshold to consider as anomaly. min_dist_error (float, optional): Minimum distance error threshold to consider as anomaly. min_speed (float, optional): Minimum speed threshold to consider as anomaly. min_size (float, optional): Minimum size threshold to consider as anomaly. remove_anomalies (bool, optional): Flag indicating whether to remove the anomalies from the data. Returns: pd.DataFrame: DataFrame containing the anomalies found in the data. \"\"\" data = self . data mask_speed = data [ \"wrm_speed\" ] >= min_speed mask_bbox_error = data [ \"bbox_error\" ] >= min_bbox_error mask_dist_error = data [ \"worm_deviation\" ] >= min_dist_error mask_worm_width = data [ \"wrm_w\" ] >= min_size mask_worm_height = data [ \"wrm_h\" ] >= min_size mask_no_preds = np . isfinite ( data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ()). all ( axis = 1 ) == False mask_no_preds = no_preds & mask_no_preds mask = mask_speed | mask_bbox_error | mask_dist_error | mask_worm_width | mask_worm_height | mask_no_preds anomalies = data [ mask ] . copy () anomalies [ \"speed_anomaly\" ] = mask_speed [ mask ] anomalies [ \"bbox_error_anomaly\" ] = mask_bbox_error [ mask ] anomalies [ \"dist_error_anomaly\" ] = mask_dist_error [ mask ] anomalies [ \"width_anomaly\" ] = mask_worm_width [ mask ] anomalies [ \"height_anomaly\" ] = mask_worm_height [ mask ] anomalies [ \"no_pred_anomaly\" ] = mask_no_preds [ mask ] if remove_anomalies : self . data = self . data [ ~mask ] return anomalies calc_precise_error def calc_precise_error ( self , worm_reader : 'FrameReader' , background : 'np.ndarray' , diff_thresh = 20 ) -> 'None' Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Parameters: Name Type Description Default worm_reader FrameReader Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. None background np.ndarray The background image of the entire experiment. None diff_thresh int Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. None View Source def calc_precise_error ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy (). astype ( np . int32 , copy = False ) wrm_bboxes = self . 
_orig_data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy () mic_bboxes = self . _orig_data [ [\"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\" ] ] . to_numpy () errors = ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes , mic_bboxes = mic_bboxes , frame_nums = frames , worm_reader = worm_reader , diff_thresh = diff_thresh , ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] calc_precise_error_experimental def calc_precise_error_experimental ( self , worm_reader : 'FrameReader' , background : 'np.ndarray' , diff_thresh = 20 , num_workers : 'int' = None , chunk_size : 'int' = 2000 ) -> 'None' Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Parameters: Name Type Description Default worm_reader FrameReader Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. None background np.ndarray The background image of the entire experiment. None diff_thresh int Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk. None View Source def calc_precise_error_experimental ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , num_workers : int = None , chunk_size : int = 2000 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy (). astype ( int , copy = False ) wrm_bboxes = self . _orig_data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy () mic_bboxes = self . _orig_data [ [\"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\" ] ] . to_numpy () errors = np . ones_like ( frames , dtype = float ) mask = np . isfinite ( wrm_bboxes ). all ( axis = 1 ) wrm_bboxes = wrm_bboxes [ mask ] mic_bboxes = mic_bboxes [ mask ] frames = frames [ mask ] num_sections = len ( frames ) // chunk_size wrm_bboxes_list = np . array_split ( wrm_bboxes , num_sections , axis = 0 ) mic_bboxes_list = np . array_split ( mic_bboxes , num_sections , axis = 0 ) frames_list = np . 
calc_precise_error_experimental def calc_precise_error_experimental ( self , worm_reader : 'FrameReader' , background : 'np.ndarray' , diff_thresh = 20 , num_workers : 'int' = None , chunk_size : 'int' = 2000 ) -> 'None' Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of the worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Parameters: Name Type Description Default worm_reader FrameReader Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. None background np.ndarray The background image of the entire experiment. None diff_thresh int Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk. None View Source def calc_precise_error_experimental ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , num_workers : int = None , chunk_size : int = 2000 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of the worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy (). astype ( int , copy = False ) wrm_bboxes = self . _orig_data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy () mic_bboxes = self . _orig_data [ [\"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\" ] ] . to_numpy () errors = np . ones_like ( frames , dtype = float ) mask = np . isfinite ( wrm_bboxes ). all ( axis = 1 ) wrm_bboxes = wrm_bboxes [ mask ] mic_bboxes = mic_bboxes [ mask ] frames = frames [ mask ] num_sections = len ( frames ) // chunk_size wrm_bboxes_list = np . array_split ( wrm_bboxes , num_sections , axis = 0 ) mic_bboxes_list = np . array_split ( mic_bboxes , num_sections , axis = 0 ) frames_list = np . array_split ( frames , num_sections ) # TODO : add non - multithreaded case whenever num_workers = 0 num_workers = adjust_num_workers ( len ( frames ), chunk_size , num_workers ) def calc_error ( idx : int ) -> np . ndarray : return ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes_list [ idx ] , mic_bboxes = mic_bboxes_list [ idx ] , frame_nums = frames_list [ idx ] , worm_reader = worm_reader , diff_thresh = diff_thresh , ) results = concurrent . thread_map ( calc_error , list ( range ( len ( wrm_bboxes_list ))), max_workers = num_workers , chunksize = 1 , desc = \"Extracting bboxes\" , unit = \"fr\" , leave = False , ) # set the error in the original data errors [ mask ] = np . concatenate ( results ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] change_unit def change_unit ( self , unit : 'str' ) Changes the unit of time and distance in the data. Parameters: Name Type Description Default unit str The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometers. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. None View Source def change_unit(self, unit: str): \"\"\" Changes the unit of time and distance in the data. Args: unit (str): The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometers. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. \"\"\" assert unit in [\"frame\", \"sec\"] if self._unit == unit: return data = self.data if unit == \"sec\": # frame -> sec dist_factor = self.time_config.mm_per_px * 1000 time_factor = self.time_config.ms_per_frame / 1000 if unit == \"frame\": # sec -> frame dist_factor = self.time_config.px_per_mm / 1000 time_factor = self.time_config.frames_per_sec data[\"time\"] *= time_factor data[[\"plt_x\", \"plt_y\"]] *= dist_factor data[[\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\"]] *= dist_factor data[[\"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\"]] *= dist_factor data[[\"cam_x\", \"cam_y\", \"cam_w\", \"cam_h\"]] *= dist_factor data[[\"wrm_center_x\", \"wrm_center_y\"]] *= dist_factor data[[\"mic_center_x\", \"mic_center_y\"]] *= dist_factor data[[\"worm_deviation_x\", \"worm_deviation_y\", \"worm_deviation\"]] *= dist_factor data[[\"wrm_speed_x\", \"wrm_speed_y\", \"wrm_speed\"]] *= dist_factor / time_factor self._unit = unit self.data = data clean def clean ( self , trim_cycles : 'bool' = False , imaging_only : 'bool' = False , bounds : 'tuple[float, float, float, float]' = None ) -> 'None' Clean the data by the provided parameters. Parameters: Name Type Description Default trim_cycles bool whether to remove the first and the last cycles from the data. None imaging_only bool Flag indicating whether to include only imaging phases in the analysis. None bounds tuple[float, float, float, float] The legal bounds for worm movement. None View Source def clean ( self , trim_cycles : bool = False , imaging_only : bool = False , bounds : tuple [ float, float, float, float ] = None , ) -> None : \"\"\" Clean the data by the provided parameters. Args: trim_cycles (bool): whether to remove the first and the last cycles from the data. 
imaging_only (bool): Flag indicating whether to include only imaging phases in the analysis. bounds (tuple[float, float, float, float]): The legal bounds for worm movement. \"\"\" data = self . data if imaging_only : mask = data [ \"phase\" ] == \"imaging\" data = data [ mask ] if bounds is not None : has_pred = np . isfinite ( data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ()). all ( axis = 1 ) mask_wrm = has_pred # if there is a prediction for a frame then look at worm bbox mask_wrm &= ( data [ \"wrm_x\" ] >= bounds [ 0 ] ) & ( data [ \"wrm_x\" ] + data [ \"wrm_w\" ] <= bounds [ 2 ] ) mask_wrm &= ( data [ \"wrm_y\" ] >= bounds [ 1 ] ) & ( data [ \"wrm_y\" ] + data [ \"wrm_h\" ] <= bounds [ 3 ] ) mask_mic = ~ has_pred # if there is no prediction for a frame then look at micro bbox mask_mic &= ( data [ \"mic_x\" ] >= bounds [ 0 ] ) & ( data [ \"mic_x\" ] + data [ \"mic_w\" ] <= bounds [ 2 ] ) mask_mic &= ( data [ \"mic_y\" ] >= bounds [ 1 ] ) & ( data [ \"mic_y\" ] + data [ \"mic_h\" ] <= bounds [ 3 ] ) data = data [ mask_wrm | mask_mic ] if trim_cycles : mask = data [ \"cycle\" ] != 0 mask &= data [ \"cycle\" ] != data [ \"cycle\" ] . max () data = data [ mask ] self . data = data column_names def column_names ( self ) -> 'list[str]' Returns a list of all column names in the analyzed data. Returns: Type Description list[str] A list of column names. View Source def column_names ( self ) -> list [ str ] : \"\"\" Returns a list of all column names in the analyzed data. Returns: list[str]: A list of column names. \"\"\" return self . data . columns . to_list () describe def describe ( self , columns : 'list[str]' = None , num : 'int' = 3 , percentiles : 'list[float]' = None ) -> 'pd.DataFrame' Generate descriptive statistics of the specified columns in the table containing the data. Parameters: Name Type Description Default columns list[str] List of column names to include in the analysis. If None, all columns will be included. None num int Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. None percentiles list[float] List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. None Returns: Type Description pd.DataFrame A DataFrame containing the descriptive statistics of the specified columns. View Source def describe ( self , columns : list [ str ] = None , num : int = 3 , percentiles : list [ float ] = None ) -> pd . DataFrame : \"\"\" Generate descriptive statistics of the specified columns in the table containing the data. Args: columns (list[str], optional): List of column names to include in the analysis. If None, all columns will be included. num (int, optional): Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. percentiles (list[float], optional): List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. Returns: pd.DataFrame: A DataFrame containing the descriptive statistics of the specified columns. \"\"\" if columns is None : columns = self . column_names () if percentiles is None : percentiles = np . linspace ( start = 0 , stop = 1.0 , num = num + 2 )[ 1 :- 1 ] return self . data [ columns ]. describe ( percentiles )
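Example (the column selection is illustrative; these columns are created during initialize):
stats = analyzer.describe(columns=["wrm_speed", "worm_deviation", "bbox_error"], num=3)
print(stats)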
initialize def initialize ( self , period : 'int' = 10 ) Initializes the data analyzer. It's essential to call this function if the class was created from non-analyzed log data. Parameters: Name Type Description Default period int The period for calculating speed in frames. The speed is calculated by measuring the distance between current frame and period frames before. None View Source def initialize(self, period: int = 10): \"\"\" Initializes the data analyzer. It's essential to call this function if the class was created from non-analyzed log data. Args: period (int): The period for calculating speed in frames. The speed is calculated by measuring the distance between current frame and period frames before. \"\"\" data = self._orig_data data[\"time\"] = data[\"frame\"] data[\"cycle_step\"] = data[\"frame\"] % self.time_config.cycle_frame_num data = DataAnalyzer._calc_centers(data) data = DataAnalyzer._calc_speed(data, period) data = DataAnalyzer._calc_worm_deviation(data) data = DataAnalyzer._calc_errors(data) data = data.round(5) self._orig_data = data self.data = self._orig_data.copy() print_stats def print_stats ( self ) -> 'None' Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Total Count of No Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. View Source def print_stats ( self ) -> None : \"\"\" Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Total Count of No Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. \"\"\" num_removed = len ( self . _orig_data . index ) - len ( self . data . index ) print ( f \"Count of Removed Frames: {num_removed} ({round(100 * num_removed / len(self._orig_data.index), 3)}%)\" ) no_preds = self . data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . isna () . any ( axis = 1 ) . sum () print ( f \"Count of No-Pred Frames: {no_preds} ({round(100 * no_preds / len(self.data.index), 3)}%)\" ) num_cycles = self . data [ \"cycle\" ] . nunique () print ( f \"Total Num of Cycles: {num_cycles}\" ) non_perfect = ( self . data [ \"bbox_error\" ] > 1e-7 ) . sum () / len ( self . data . index ) print ( f \"Non Perfect Predictions: {round(100 * non_perfect, 3)}%\" ) remove_cycle def remove_cycle ( self , cycles : 'int | list[int]' ) Remove the specified cycles from the data. Parameters: Name Type Description Default cycles int list[int] The cycle(s) to remove from the data. View Source def remove_cycle ( self , cycles : int | list [ int ] ) : \"\"\" Remove the specified cycles from the data. Args: cycles (int | list[int]): The cycle(s) to remove from the data. \"\"\" if isinstance ( cycles , int ) : cycles = [ cycles ] mask = self . data [ \"cycle\" ] . isin ( cycles ) self . data = self . data [ ~mask ] reset_changes def reset_changes ( self ) Reset the data to its original state. Note that this method will not reset the unit of time and distance. View Source def reset_changes(self): \"\"\" Reset the data to its original state. Note that this method will not reset the unit of time and distance. 
\"\"\" self.data = self._orig_data.copy() self._unit = \"frame\" save def save ( self , path : 'str' ) -> 'None' Save the full analyzed data to a csv file. View Source def save ( self , path : str ) -> None : \"\"\" Save the full analyzed data to a csv file. \"\"\" self . _orig_data . to_csv ( path , index = False )","title":"Data Analyzer"},{"location":"reference/wtracker/eval/data_analyzer/#module-wtrackerevaldata_analyzer","text":"View Source from __future__ import annotations import pandas as pd import numpy as np import tqdm.contrib.concurrent as concurrent from wtracker.sim.config import TimingConfig from wtracker.eval.error_calculator import ErrorCalculator from wtracker.utils.frame_reader import FrameReader from wtracker.utils.threading_utils import adjust_num_workers class DataAnalyzer : \"\"\" A class for analyzing simulation log. Args: time_config (TimingConfig): The timing configuration. log_path (pd.DataFrame): Dataframe containing the simulation log data. \"\"\" def __init__ ( self , time_config : TimingConfig , log_data : pd . DataFrame , ): self . time_config = time_config self . data = log_data . copy () self . _orig_data = log_data self . _unit = \"frame\" @property def unit ( self ) -> str : return self . _unit def save ( self , path : str ) -> None : \"\"\" Save the full analyzed data to a csv file. \"\"\" self . _orig_data . to_csv ( path , index = False ) @staticmethod def load ( time_config : TimingConfig , csv_path : str ) -> DataAnalyzer : \"\"\" Create a DataAnalyzer object from a csv file containing experiment data, regardless whether if it's analyzed or not. Args: time_config (TimingConfig): The timing configuration. csv_path (str): Path to the csv file containing the experiment data. \"\"\" data = pd . read_csv ( csv_path ) return DataAnalyzer ( time_config , data ) def initialize ( self , period : int = 10 ): \"\"\" Initializes the data analyzer. It's essential to call this function if the class was created from a non-analyzed log data. Args: period (int): The period for calculating speed in frames. The speed is calculated by measuring the distance between current frame and period frames before. \"\"\" data = self . _orig_data data [ \"time\" ] = data [ \"frame\" ] data [ \"cycle_step\" ] = data [ \"frame\" ] % self . time_config . cycle_frame_num data = DataAnalyzer . _calc_centers ( data ) data = DataAnalyzer . _calc_speed ( data , period ) data = DataAnalyzer . _calc_worm_deviation ( data ) data = DataAnalyzer . _calc_errors ( data ) data = data . round ( 5 ) self . _orig_data = data self . data = self . _orig_data . copy () @staticmethod def _calc_centers ( data : pd . DataFrame ) -> pd . DataFrame : data [ \"wrm_center_x\" ] = data [ \"wrm_x\" ] + data [ \"wrm_w\" ] / 2 data [ \"wrm_center_y\" ] = data [ \"wrm_y\" ] + data [ \"wrm_h\" ] / 2 data [ \"mic_center_x\" ] = data [ \"mic_x\" ] + data [ \"mic_w\" ] / 2 data [ \"mic_center_y\" ] = data [ \"mic_y\" ] + data [ \"mic_h\" ] / 2 return data @staticmethod def _calc_speed ( data : pd . DataFrame , n : int ) -> pd . DataFrame : diff = data [ \"time\" ] . diff ( n ) . to_numpy () data [ \"wrm_speed_x\" ] = data [ \"wrm_center_x\" ] . diff ( n ) / diff data [ \"wrm_speed_y\" ] = data [ \"wrm_center_y\" ] . diff ( n ) / diff data [ \"wrm_speed\" ] = np . sqrt ( data [ \"wrm_speed_x\" ] ** 2 + data [ \"wrm_speed_y\" ] ** 2 ) return data @staticmethod def _calc_worm_deviation ( data : pd . DataFrame ) -> pd . 
DataFrame : data [ \"worm_deviation_x\" ] = data [ \"wrm_center_x\" ] - data [ \"mic_center_x\" ] data [ \"worm_deviation_y\" ] = data [ \"wrm_center_y\" ] - data [ \"mic_center_y\" ] data [ \"worm_deviation\" ] = np . sqrt ( data [ \"worm_deviation_x\" ] ** 2 + data [ \"worm_deviation_y\" ] ** 2 ) return data @staticmethod def _calc_errors ( data : pd . DataFrame ) -> pd . DataFrame : wrm_bboxes = data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () bbox_error = ErrorCalculator . calculate_bbox_error ( wrm_bboxes , mic_bboxes ) data [ \"bbox_error\" ] = bbox_error data [ \"precise_error\" ] = np . nan return data def remove_cycle ( self , cycles : int | list [ int ]): \"\"\" Remove the specified cycles from the data. Args: cycles (int | list[int]): The cycle(s) to remove from the data. \"\"\" if isinstance ( cycles , int ): cycles = [ cycles ] mask = self . data [ \"cycle\" ] . isin ( cycles ) self . data = self . data [ ~ mask ] def clean ( self , trim_cycles : bool = False , imaging_only : bool = False , bounds : tuple [ float , float , float , float ] = None , ) -> None : \"\"\" Clean the data by the provided parameters. Args: trim_cycles (bool): whether to remove the first and the last cycles from the data. imaging_only (bool): Flag indicating whether to include only imaging phases in the analysis. legal_bounds (tuple[float, float, float, float]): The legal bounds for worm movement. \"\"\" data = self . data if imaging_only : mask = data [ \"phase\" ] == \"imaging\" data = data [ mask ] if bounds is not None : has_pred = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) . all ( axis = 1 ) mask_wrm = has_pred # if there is a prediction for a frame then look at worm bbox mask_wrm &= ( data [ \"wrm_x\" ] >= bounds [ 0 ]) & ( data [ \"wrm_x\" ] + data [ \"wrm_w\" ] <= bounds [ 2 ]) mask_wrm &= ( data [ \"wrm_y\" ] >= bounds [ 1 ]) & ( data [ \"wrm_y\" ] + data [ \"wrm_h\" ] <= bounds [ 3 ]) mask_mic = ~ has_pred # if there is no prediction for a frame then look at micro bbox mask_mic &= ( data [ \"mic_x\" ] >= bounds [ 0 ]) & ( data [ \"mic_x\" ] + data [ \"mic_w\" ] <= bounds [ 2 ]) mask_mic &= ( data [ \"mic_y\" ] >= bounds [ 1 ]) & ( data [ \"mic_y\" ] + data [ \"mic_h\" ] <= bounds [ 3 ]) data = data [ mask_wrm | mask_mic ] if trim_cycles : mask = data [ \"cycle\" ] != 0 mask &= data [ \"cycle\" ] != data [ \"cycle\" ] . max () data = data [ mask ] self . data = data def reset_changes ( self ): \"\"\" Reset the data to its original state. Note, that this method will not reset the unit of time and distance. \"\"\" self . data = self . _orig_data . copy () self . _unit = \"frame\" def column_names ( self ) -> list [ str ]: \"\"\" Returns a list of all column names in the analyzed data. Returns: list[str]: A list of column names. \"\"\" return self . data . columns . to_list () def change_unit ( self , unit : str ): \"\"\" Changes the unit of time and distance in the data. Args: unit (str, optional): The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometer. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. \"\"\" assert unit in [ \"frame\" , \"sec\" ] if self . _unit == unit : return data = self . data if unit == \"sec\" : # frame -> sec dist_factor = self . time_config . 
mm_per_px * 1000 time_factor = self . time_config . ms_per_frame / 1000 if unit == \"frame\" : # sec -> frame dist_factor = self . time_config . px_per_mm / 1000 time_factor = self . time_config . frames_per_sec data [ \"time\" ] *= time_factor data [[ \"plt_x\" , \"plt_y\" ]] *= dist_factor data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] *= dist_factor data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] *= dist_factor data [[ \"cam_x\" , \"cam_y\" , \"cam_w\" , \"cam_h\" ]] *= dist_factor data [[ \"wrm_center_x\" , \"wrm_center_y\" ]] *= dist_factor data [[ \"mic_center_x\" , \"mic_center_y\" ]] *= dist_factor data [[ \"worm_deviation_x\" , \"worm_deviation_y\" , \"worm_deviation\" ]] *= dist_factor data [[ \"wrm_speed_x\" , \"wrm_speed_y\" , \"wrm_speed\" ]] *= dist_factor / time_factor self . _unit = unit self . data = data # TODO: TEST # TODO: MAYBE REMOVE, THE non-multithreaded version works very fast for me for some reason # perhaps SSD is required for fast analysis. def calc_precise_error_experimental ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , num_workers : int = None , chunk_size : int = 2000 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( int , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = np . ones_like ( frames , dtype = float ) mask = np . isfinite ( wrm_bboxes ) . all ( axis = 1 ) wrm_bboxes = wrm_bboxes [ mask ] mic_bboxes = mic_bboxes [ mask ] frames = frames [ mask ] num_sections = len ( frames ) // chunk_size wrm_bboxes_list = np . array_split ( wrm_bboxes , num_sections , axis = 0 ) mic_bboxes_list = np . array_split ( mic_bboxes , num_sections , axis = 0 ) frames_list = np . array_split ( frames , num_sections ) # TODO: add non-multithreaded case whenever num_workers=0 num_workers = adjust_num_workers ( len ( frames ), chunk_size , num_workers ) def calc_error ( idx : int ) -> np . ndarray : return ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes_list [ idx ], mic_bboxes = mic_bboxes_list [ idx ], frame_nums = frames_list [ idx ], worm_reader = worm_reader , diff_thresh = diff_thresh , ) results = concurrent . thread_map ( calc_error , list ( range ( len ( wrm_bboxes_list ))), max_workers = num_workers , chunksize = 1 , desc = \"Extracting bboxes\" , unit = \"fr\" , leave = False , ) # set the error in the original data errors [ mask ] = np . concatenate ( results ) self . 
_orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_precise_error ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( np . int32 , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes , mic_bboxes = mic_bboxes , frame_nums = frames , worm_reader = worm_reader , diff_thresh = diff_thresh , ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_anomalies ( self , no_preds : bool = True , min_bbox_error : float = np . inf , min_dist_error : float = np . inf , min_speed : float = np . inf , min_size : float = np . inf , remove_anomalies : bool = False , ) -> pd . DataFrame : \"\"\" Calculate anomalies in the data based on specified criteria. Args: no_preds (bool, optional): Flag indicating whether to consider instances with missing predictions. min_bbox_error (float, optional): Minimum bounding box error threshold to consider as anomaly. min_dist_error (float, optional): Minimum distance error threshold to consider as anomaly. min_speed (float, optional): Minimum speed threshold to consider as anomaly. min_size (float, optional): Minimum size threshold to consider as anomaly. remove_anomalies (bool, optional): Flag indicating whether to remove the anomalies from the data. Returns: pd.DataFrame: DataFrame containing the anomalies found in the data. \"\"\" data = self . data mask_speed = data [ \"wrm_speed\" ] >= min_speed mask_bbox_error = data [ \"bbox_error\" ] >= min_bbox_error mask_dist_error = data [ \"worm_deviation\" ] >= min_dist_error mask_worm_width = data [ \"wrm_w\" ] >= min_size mask_worm_height = data [ \"wrm_h\" ] >= min_size mask_no_preds = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) . all ( axis = 1 ) == False mask_no_preds = no_preds & mask_no_preds mask = mask_speed | mask_bbox_error | mask_dist_error | mask_worm_width | mask_worm_height | mask_no_preds anomalies = data [ mask ] . 
copy () anomalies [ \"speed_anomaly\" ] = mask_speed [ mask ] anomalies [ \"bbox_error_anomaly\" ] = mask_bbox_error [ mask ] anomalies [ \"dist_error_anomaly\" ] = mask_dist_error [ mask ] anomalies [ \"width_anomaly\" ] = mask_worm_width [ mask ] anomalies [ \"height_anomaly\" ] = mask_worm_height [ mask ] anomalies [ \"no_pred_anomaly\" ] = mask_no_preds [ mask ] if remove_anomalies : self . data = self . data [ ~ mask ] return anomalies def describe ( self , columns : list [ str ] = None , num : int = 3 , percentiles : list [ float ] = None ) -> pd . DataFrame : \"\"\" Generate descriptive statistics of the specified columns in the table containing the data. Args: columns (list[str], optional): List of column names to include in the analysis. If None, all columns will be included. num (int, optional): Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. percentiles (list[float], optional): List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. Returns: pd.DataFrame: A DataFrame containing the descriptive statistics of the specified columns. \"\"\" if columns is None : columns = self . column_names () if percentiles is None : percentiles = np . linspace ( start = 0 , stop = 1.0 , num = num + 2 )[ 1 : - 1 ] return self . data [ columns ] . describe ( percentiles ) def print_stats ( self ) -> None : \"\"\" Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Total Count of No Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. \"\"\" num_removed = len ( self . _orig_data . index ) - len ( self . data . index ) print ( f \"Count of Removed Frames: { num_removed } ( { round ( 100 * num_removed / len ( self . _orig_data . index ), 3 ) } %)\" ) no_preds = self . data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . isna () . any ( axis = 1 ) . sum () print ( f \"Count of No-Pred Frames: { no_preds } ( { round ( 100 * no_preds / len ( self . data . index ), 3 ) } %)\" ) num_cycles = self . data [ \"cycle\" ] . nunique () print ( f \"Total Num of Cycles: { num_cycles } \" ) non_perfect = ( self . data [ \"bbox_error\" ] > 1e-7 ) . sum () / len ( self . data . index ) print ( f \"Non Perfect Predictions: { round ( 100 * non_perfect , 3 ) } %\" )","title":"Module wtracker.eval.data_analyzer"},{"location":"reference/wtracker/eval/data_analyzer/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/eval/data_analyzer/#dataanalyzer","text":"class DataAnalyzer ( time_config : 'TimingConfig' , log_data : 'pd.DataFrame' ) A class for analyzing simulation log.","title":"DataAnalyzer"},{"location":"reference/wtracker/eval/data_analyzer/#attributes","text":"Name Type Description Default time_config TimingConfig The timing configuration. None log_path pd.DataFrame Dataframe containing the simulation log data. None View Source class DataAnalyzer : \"\"\" A class for analyzing simulation log. Args: time_config (TimingConfig): The timing configuration. log_path (pd.DataFrame): Dataframe containing the simulation log data. \"\"\" def __init__ ( self , time_config : TimingConfig , log_data : pd . DataFrame , ): self . time_config = time_config self . data = log_data . 
copy () self . _orig_data = log_data self . _unit = \"frame\" @ property def unit ( self ) -> str : return self . _unit def save ( self , path : str ) -> None : \"\"\" Save the full analyzed data to a csv file. \"\"\" self . _orig_data . to_csv ( path , index = False ) @ staticmethod def load ( time_config : TimingConfig , csv_path : str ) -> DataAnalyzer : \"\"\" Create a DataAnalyzer object from a csv file containing experiment data, regardless whether if it's analyzed or not. Args: time_config (TimingConfig): The timing configuration. csv_path (str): Path to the csv file containing the experiment data. \"\"\" data = pd . read_csv ( csv_path ) return DataAnalyzer ( time_config , data ) def initialize ( self , period : int = 10 ): \"\"\" Initializes the data analyzer. It's essential to call this function if the class was created from a non-analyzed log data. Args: period (int): The period for calculating speed in frames. The speed is calculated by measuring the distance between current frame and period frames before. \"\"\" data = self . _orig_data data [ \"time\" ] = data [ \"frame\" ] data [ \"cycle_step\" ] = data [ \"frame\" ] % self . time_config . cycle_frame_num data = DataAnalyzer . _calc_centers ( data ) data = DataAnalyzer . _calc_speed ( data , period ) data = DataAnalyzer . _calc_worm_deviation ( data ) data = DataAnalyzer . _calc_errors ( data ) data = data . round ( 5 ) self . _orig_data = data self . data = self . _orig_data . copy () @ staticmethod def _calc_centers ( data : pd . DataFrame ) -> pd . DataFrame : data [ \"wrm_center_x\" ] = data [ \"wrm_x\" ] + data [ \"wrm_w\" ] / 2 data [ \"wrm_center_y\" ] = data [ \"wrm_y\" ] + data [ \"wrm_h\" ] / 2 data [ \"mic_center_x\" ] = data [ \"mic_x\" ] + data [ \"mic_w\" ] / 2 data [ \"mic_center_y\" ] = data [ \"mic_y\" ] + data [ \"mic_h\" ] / 2 return data @ staticmethod def _calc_speed ( data : pd . DataFrame , n : int ) -> pd . DataFrame : diff = data [ \"time\" ] . diff ( n ) . to_numpy () data [ \"wrm_speed_x\" ] = data [ \"wrm_center_x\" ] . diff ( n ) / diff data [ \"wrm_speed_y\" ] = data [ \"wrm_center_y\" ] . diff ( n ) / diff data [ \"wrm_speed\" ] = np . sqrt ( data [ \"wrm_speed_x\" ] ** 2 + data [ \"wrm_speed_y\" ] ** 2 ) return data @ staticmethod def _calc_worm_deviation ( data : pd . DataFrame ) -> pd . DataFrame : data [ \"worm_deviation_x\" ] = data [ \"wrm_center_x\" ] - data [ \"mic_center_x\" ] data [ \"worm_deviation_y\" ] = data [ \"wrm_center_y\" ] - data [ \"mic_center_y\" ] data [ \"worm_deviation\" ] = np . sqrt ( data [ \"worm_deviation_x\" ] ** 2 + data [ \"worm_deviation_y\" ] ** 2 ) return data @ staticmethod def _calc_errors ( data : pd . DataFrame ) -> pd . DataFrame : wrm_bboxes = data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () bbox_error = ErrorCalculator . calculate_bbox_error ( wrm_bboxes , mic_bboxes ) data [ \"bbox_error\" ] = bbox_error data [ \"precise_error\" ] = np . nan return data def remove_cycle ( self , cycles : int | list [ int ]): \"\"\" Remove the specified cycles from the data. Args: cycles (int | list[int]): The cycle(s) to remove from the data. \"\"\" if isinstance ( cycles , int ): cycles = [ cycles ] mask = self . data [ \"cycle\" ] . isin ( cycles ) self . data = self . 
data [ ~ mask ] def clean ( self , trim_cycles : bool = False , imaging_only : bool = False , bounds : tuple [ float , float , float , float ] = None , ) -> None : \"\"\" Clean the data by the provided parameters. Args: trim_cycles (bool): whether to remove the first and the last cycles from the data. imaging_only (bool): Flag indicating whether to include only imaging phases in the analysis. legal_bounds (tuple[float, float, float, float]): The legal bounds for worm movement. \"\"\" data = self . data if imaging_only : mask = data [ \"phase\" ] == \"imaging\" data = data [ mask ] if bounds is not None : has_pred = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) . all ( axis = 1 ) mask_wrm = has_pred # if there is a prediction for a frame then look at worm bbox mask_wrm &= ( data [ \"wrm_x\" ] >= bounds [ 0 ]) & ( data [ \"wrm_x\" ] + data [ \"wrm_w\" ] <= bounds [ 2 ]) mask_wrm &= ( data [ \"wrm_y\" ] >= bounds [ 1 ]) & ( data [ \"wrm_y\" ] + data [ \"wrm_h\" ] <= bounds [ 3 ]) mask_mic = ~ has_pred # if there is no prediction for a frame then look at micro bbox mask_mic &= ( data [ \"mic_x\" ] >= bounds [ 0 ]) & ( data [ \"mic_x\" ] + data [ \"mic_w\" ] <= bounds [ 2 ]) mask_mic &= ( data [ \"mic_y\" ] >= bounds [ 1 ]) & ( data [ \"mic_y\" ] + data [ \"mic_h\" ] <= bounds [ 3 ]) data = data [ mask_wrm | mask_mic ] if trim_cycles : mask = data [ \"cycle\" ] != 0 mask &= data [ \"cycle\" ] != data [ \"cycle\" ] . max () data = data [ mask ] self . data = data def reset_changes ( self ): \"\"\" Reset the data to its original state. Note, that this method will not reset the unit of time and distance. \"\"\" self . data = self . _orig_data . copy () self . _unit = \"frame\" def column_names ( self ) -> list [ str ]: \"\"\" Returns a list of all column names in the analyzed data. Returns: list[str]: A list of column names. \"\"\" return self . data . columns . to_list () def change_unit ( self , unit : str ): \"\"\" Changes the unit of time and distance in the data. Args: unit (str, optional): The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometer. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. \"\"\" assert unit in [ \"frame\" , \"sec\" ] if self . _unit == unit : return data = self . data if unit == \"sec\" : # frame -> sec dist_factor = self . time_config . mm_per_px * 1000 time_factor = self . time_config . ms_per_frame / 1000 if unit == \"frame\" : # sec -> frame dist_factor = self . time_config . px_per_mm / 1000 time_factor = self . time_config . frames_per_sec data [ \"time\" ] *= time_factor data [[ \"plt_x\" , \"plt_y\" ]] *= dist_factor data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] *= dist_factor data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] *= dist_factor data [[ \"cam_x\" , \"cam_y\" , \"cam_w\" , \"cam_h\" ]] *= dist_factor data [[ \"wrm_center_x\" , \"wrm_center_y\" ]] *= dist_factor data [[ \"mic_center_x\" , \"mic_center_y\" ]] *= dist_factor data [[ \"worm_deviation_x\" , \"worm_deviation_y\" , \"worm_deviation\" ]] *= dist_factor data [[ \"wrm_speed_x\" , \"wrm_speed_y\" , \"wrm_speed\" ]] *= dist_factor / time_factor self . _unit = unit self . data = data # TODO: TEST # TODO: MAYBE REMOVE, THE non-multithreaded version works very fast for me for some reason # perhaps SSD is required for fast analysis. 
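# Note: the method below splits the frames into chunks of roughly `chunk_size`
# and maps ErrorCalculator.calculate_precise over the chunks with a thread pool
# (tqdm's concurrent.thread_map). Threads rather than processes are presumably
# used because the work is dominated by reading worm crops from disk - which
# would also explain why storage speed (SSD vs. HDD) affects how fast it runs.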
def calc_precise_error_experimental ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , num_workers : int = None , chunk_size : int = 2000 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( int , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = np . ones_like ( frames , dtype = float ) mask = np . isfinite ( wrm_bboxes ) . all ( axis = 1 ) wrm_bboxes = wrm_bboxes [ mask ] mic_bboxes = mic_bboxes [ mask ] frames = frames [ mask ] num_sections = len ( frames ) // chunk_size wrm_bboxes_list = np . array_split ( wrm_bboxes , num_sections , axis = 0 ) mic_bboxes_list = np . array_split ( mic_bboxes , num_sections , axis = 0 ) frames_list = np . array_split ( frames , num_sections ) # TODO: add non-multithreaded case whenever num_workers=0 num_workers = adjust_num_workers ( len ( frames ), chunk_size , num_workers ) def calc_error ( idx : int ) -> np . ndarray : return ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes_list [ idx ], mic_bboxes = mic_bboxes_list [ idx ], frame_nums = frames_list [ idx ], worm_reader = worm_reader , diff_thresh = diff_thresh , ) results = concurrent . thread_map ( calc_error , list ( range ( len ( wrm_bboxes_list ))), max_workers = num_workers , chunksize = 1 , desc = \"Extracting bboxes\" , unit = \"fr\" , leave = False , ) # set the error in the original data errors [ mask ] = np . concatenate ( results ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_precise_error ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. \"\"\" frames = self . 
_orig_data [ \"frame\" ] . to_numpy () . astype ( np . int32 , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes , mic_bboxes = mic_bboxes , frame_nums = frames , worm_reader = worm_reader , diff_thresh = diff_thresh , ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_anomalies ( self , no_preds : bool = True , min_bbox_error : float = np . inf , min_dist_error : float = np . inf , min_speed : float = np . inf , min_size : float = np . inf , remove_anomalies : bool = False , ) -> pd . DataFrame : \"\"\" Calculate anomalies in the data based on specified criteria. Args: no_preds (bool, optional): Flag indicating whether to consider instances with missing predictions. min_bbox_error (float, optional): Minimum bounding box error threshold to consider as anomaly. min_dist_error (float, optional): Minimum distance error threshold to consider as anomaly. min_speed (float, optional): Minimum speed threshold to consider as anomaly. min_size (float, optional): Minimum size threshold to consider as anomaly. remove_anomalies (bool, optional): Flag indicating whether to remove the anomalies from the data. Returns: pd.DataFrame: DataFrame containing the anomalies found in the data. \"\"\" data = self . data mask_speed = data [ \"wrm_speed\" ] >= min_speed mask_bbox_error = data [ \"bbox_error\" ] >= min_bbox_error mask_dist_error = data [ \"worm_deviation\" ] >= min_dist_error mask_worm_width = data [ \"wrm_w\" ] >= min_size mask_worm_height = data [ \"wrm_h\" ] >= min_size mask_no_preds = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) . all ( axis = 1 ) == False mask_no_preds = no_preds & mask_no_preds mask = mask_speed | mask_bbox_error | mask_dist_error | mask_worm_width | mask_worm_height | mask_no_preds anomalies = data [ mask ] . copy () anomalies [ \"speed_anomaly\" ] = mask_speed [ mask ] anomalies [ \"bbox_error_anomaly\" ] = mask_bbox_error [ mask ] anomalies [ \"dist_error_anomaly\" ] = mask_dist_error [ mask ] anomalies [ \"width_anomaly\" ] = mask_worm_width [ mask ] anomalies [ \"height_anomaly\" ] = mask_worm_height [ mask ] anomalies [ \"no_pred_anomaly\" ] = mask_no_preds [ mask ] if remove_anomalies : self . data = self . data [ ~ mask ] return anomalies def describe ( self , columns : list [ str ] = None , num : int = 3 , percentiles : list [ float ] = None ) -> pd . DataFrame : \"\"\" Generate descriptive statistics of the specified columns in the table containing the data. Args: columns (list[str], optional): List of column names to include in the analysis. If None, all columns will be included. num (int, optional): Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. percentiles (list[float], optional): List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. Returns: pd.DataFrame: A DataFrame containing the descriptive statistics of the specified columns. \"\"\" if columns is None : columns = self . column_names () if percentiles is None : percentiles = np . 
linspace ( start = 0 , stop = 1.0 , num = num + 2 )[ 1 : - 1 ] return self . data [ columns ] . describe ( percentiles ) def print_stats ( self ) -> None : \"\"\" Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Count of No-Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. \"\"\" num_removed = len ( self . _orig_data . index ) - len ( self . data . index ) print ( f \"Count of Removed Frames: {num_removed} ({round(100 * num_removed / len(self._orig_data.index), 3)}%)\" ) no_preds = self . data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . isna () . any ( axis = 1 ) . sum () print ( f \"Count of No-Pred Frames: {no_preds} ({round(100 * no_preds / len(self.data.index), 3)}%)\" ) num_cycles = self . data [ \"cycle\" ] . nunique () print ( f \"Total Num of Cycles: {num_cycles}\" ) non_perfect = ( self . data [ \"bbox_error\" ] > 1e-7 ) . sum () / len ( self . data . index ) print ( f \"Non Perfect Predictions: {round(100 * non_perfect, 3)}%\" )","title":"Attributes"},{"location":"reference/wtracker/eval/data_analyzer/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/eval/data_analyzer/#load","text":"def load ( time_config : 'TimingConfig' , csv_path : 'str' ) -> 'DataAnalyzer' Create a DataAnalyzer object from a csv file containing experiment data, regardless of whether it has been analyzed or not. Parameters: Name Type Description Default time_config TimingConfig The timing configuration. None csv_path str Path to the csv file containing the experiment data. None View Source @ staticmethod def load ( time_config : TimingConfig , csv_path : str ) -> DataAnalyzer : \"\"\" Create a DataAnalyzer object from a csv file containing experiment data, regardless of whether it has been analyzed or not. Args: time_config (TimingConfig): The timing configuration. csv_path (str): Path to the csv file containing the experiment data. \"\"\" data = pd . read_csv ( csv_path ) return DataAnalyzer ( time_config , data )","title":"load"},{"location":"reference/wtracker/eval/data_analyzer/#instance-variables","text":"unit","title":"Instance variables"},{"location":"reference/wtracker/eval/data_analyzer/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/eval/data_analyzer/#calc_anomalies","text":"def calc_anomalies ( self , no_preds : 'bool' = True , min_bbox_error : 'float' = inf , min_dist_error : 'float' = inf , min_speed : 'float' = inf , min_size : 'float' = inf , remove_anomalies : 'bool' = False ) -> 'pd.DataFrame' Calculate anomalies in the data based on specified criteria. Parameters: Name Type Description Default no_preds bool Flag indicating whether to consider instances with missing predictions. None min_bbox_error float Minimum bounding box error threshold to consider as anomaly. None min_dist_error float Minimum distance error threshold to consider as anomaly. None min_speed float Minimum speed threshold to consider as anomaly. None min_size float Minimum size threshold to consider as anomaly. None remove_anomalies bool Flag indicating whether to remove the anomalies from the data. None Returns: Type Description pd.DataFrame DataFrame containing the anomalies found in the data.
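A minimal usage sketch (the csv path and thresholds are hypothetical placeholders; a TimingConfig instance for the experiment is assumed to exist):

from wtracker.eval.data_analyzer import DataAnalyzer

# Assumes `time_config` is an existing TimingConfig instance for this experiment.
analyzer = DataAnalyzer.load(time_config, "logs/experiment_log.csv")
analyzer.initialize(period=10)  # required when the csv has not been analyzed yet

# Flag frames with implausible speed or a large worm-to-microscope deviation.
anomalies = analyzer.calc_anomalies(
    no_preds=True,
    min_speed=50.0,        # thresholds are in the currently active units
    min_dist_error=100.0,
    remove_anomalies=False,
)
print(anomalies[["frame", "speed_anomaly", "dist_error_anomaly", "no_pred_anomaly"]].head())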
View Source def calc_anomalies ( self , no_preds : bool = True , min_bbox_error : float = np . inf , min_dist_error : float = np . inf , min_speed : float = np . inf , min_size : float = np . inf , remove_anomalies : bool = False , ) -> pd . DataFrame : \"\"\" Calculate anomalies in the data based on specified criteria. Args: no_preds (bool, optional): Flag indicating whether to consider instances with missing predictions. min_bbox_error (float, optional): Minimum bounding box error threshold to consider as anomaly. min_dist_error (float, optional): Minimum distance error threshold to consider as anomaly. min_speed (float, optional): Minimum speed threshold to consider as anomaly. min_size (float, optional): Minimum size threshold to consider as anomaly. remove_anomalies (bool, optional): Flag indicating whether to remove the anomalies from the data. Returns: pd.DataFrame: DataFrame containing the anomalies found in the data. \"\"\" data = self . data mask_speed = data [ \"wrm_speed\" ] >= min_speed mask_bbox_error = data [ \"bbox_error\" ] >= min_bbox_error mask_dist_error = data [ \"worm_deviation\" ] >= min_dist_error mask_worm_width = data [ \"wrm_w\" ] >= min_size mask_worm_height = data [ \"wrm_h\" ] >= min_size mask_no_preds = np . isfinite ( data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ()). all ( axis = 1 ) == False mask_no_preds = no_preds & mask_no_preds mask = mask_speed | mask_bbox_error | mask_dist_error | mask_worm_width | mask_worm_height | mask_no_preds anomalies = data [ mask ] . copy () anomalies [ \"speed_anomaly\" ] = mask_speed [ mask ] anomalies [ \"bbox_error_anomaly\" ] = mask_bbox_error [ mask ] anomalies [ \"dist_error_anomaly\" ] = mask_dist_error [ mask ] anomalies [ \"width_anomaly\" ] = mask_worm_width [ mask ] anomalies [ \"height_anomaly\" ] = mask_worm_height [ mask ] anomalies [ \"no_pred_anomaly\" ] = mask_no_preds [ mask ] if remove_anomalies : self . data = self . data [ ~mask ] return anomalies","title":"calc_anomalies"},{"location":"reference/wtracker/eval/data_analyzer/#calc_precise_error","text":"def calc_precise_error ( self , worm_reader : 'FrameReader' , background : 'np.ndarray' , diff_thresh = 20 ) -> 'None' Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Parameters: Name Type Description Default worm_reader FrameReader Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. None background np.ndarray The background image of the entire experiment. None diff_thresh int Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. None View Source def calc_precise_error ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. 
diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy (). astype ( np . int32 , copy = False ) wrm_bboxes = self . _orig_data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy () mic_bboxes = self . _orig_data [ [\"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\" ] ] . to_numpy () errors = ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes , mic_bboxes = mic_bboxes , frame_nums = frames , worm_reader = worm_reader , diff_thresh = diff_thresh , ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ]","title":"calc_precise_error"},{"location":"reference/wtracker/eval/data_analyzer/#calc_precise_error_experimental","text":"def calc_precise_error_experimental ( self , worm_reader : 'FrameReader' , background : 'np.ndarray' , diff_thresh = 20 , num_workers : 'int' = None , chunk_size : 'int' = 2000 ) -> 'None' Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Parameters: Name Type Description Default worm_reader FrameReader Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. None background np.ndarray The background image of the entire experiment. None diff_thresh int Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk. None View Source def calc_precise_error_experimental ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , num_workers : int = None , chunk_size : int = 2000 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy (). astype ( int , copy = False ) wrm_bboxes = self . _orig_data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy () mic_bboxes = self . _orig_data [ [\"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\" ] ] . to_numpy () errors = np . 
ones_like ( frames , dtype = float ) mask = np . isfinite ( wrm_bboxes ). all ( axis = 1 ) wrm_bboxes = wrm_bboxes [ mask ] mic_bboxes = mic_bboxes [ mask ] frames = frames [ mask ] num_sections = len ( frames ) // chunk_size wrm_bboxes_list = np . array_split ( wrm_bboxes , num_sections , axis = 0 ) mic_bboxes_list = np . array_split ( mic_bboxes , num_sections , axis = 0 ) frames_list = np . array_split ( frames , num_sections ) # TODO : add non - multithreaded case whenever num_workers = 0 num_workers = adjust_num_workers ( len ( frames ), chunk_size , num_workers ) def calc_error ( idx : int ) -> np . ndarray : return ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes_list [ idx ] , mic_bboxes = mic_bboxes_list [ idx ] , frame_nums = frames_list [ idx ] , worm_reader = worm_reader , diff_thresh = diff_thresh , ) results = concurrent . thread_map ( calc_error , list ( range ( len ( wrm_bboxes_list ))), max_workers = num_workers , chunksize = 1 , desc = \"Extracting bboxes\" , unit = \"fr\" , leave = False , ) # set the error in the original data errors [ mask ] = np . concatenate ( results ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ]","title":"calc_precise_error_experimental"},{"location":"reference/wtracker/eval/data_analyzer/#change_unit","text":"def change_unit ( self , unit : 'str' ) Changes the unit of time and distance in the data. Parameters: Name Type Description Default unit str The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometer. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. None View Source def change_unit(self, unit: str): \"\"\" Changes the unit of time and distance in the data. Args: unit (str, optional): The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometer. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. \"\"\" assert unit in [\"frame\", \"sec\"] if self._unit == unit: return data = self.data if unit == \"sec\": # frame -> sec dist_factor = self.time_config.mm_per_px * 1000 time_factor = self.time_config.ms_per_frame / 1000 if unit == \"frame\": # sec -> frame dist_factor = self.time_config.px_per_mm / 1000 time_factor = self.time_config.frames_per_sec data[\"time\"] * = time_factor data[[\"plt_x\", \"plt_y\"]] *= dist_factor data[[\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\"]] * = dist_factor data[[\"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\"]] *= dist_factor data[[\"cam_x\", \"cam_y\", \"cam_w\", \"cam_h\"]] * = dist_factor data[[\"wrm_center_x\", \"wrm_center_y\"]] *= dist_factor data[[\"mic_center_x\", \"mic_center_y\"]] * = dist_factor data[[\"worm_deviation_x\", \"worm_deviation_y\", \"worm_deviation\"]] *= dist_factor data[[\"wrm_speed_x\", \"wrm_speed_y\", \"wrm_speed\"]] * = dist_factor / time_factor self._unit = unit self.data = data","title":"change_unit"},{"location":"reference/wtracker/eval/data_analyzer/#clean","text":"def clean ( self , trim_cycles : 'bool' = False , imaging_only : 'bool' = False , bounds : 'tuple[float, float, float, float]' = None ) -> 'None' Clean the data by the provided parameters. 
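For example, continuing the sketch from the calc_anomalies section above (values are hypothetical; judging by the source, bounds is interpreted as (x_min, y_min, x_max, y_max)); the full parameter reference follows below:

# Keep only imaging-phase rows whose bounding boxes lie fully inside a
# hypothetical 1000x1000 legal region, and drop the (often partial)
# first and last cycles.
analyzer.clean(
    trim_cycles=True,
    imaging_only=True,
    bounds=(0.0, 0.0, 1000.0, 1000.0),
)
analyzer.print_stats()  # report how many frames were removed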
Parameters: Name Type Description Default trim_cycles bool whether to remove the first and the last cycles from the data. None imaging_only bool Flag indicating whether to include only imaging phases in the analysis. None legal_bounds tuple[float, float, float, float] The legal bounds for worm movement. None View Source def clean ( self , trim_cycles : bool = False , imaging_only : bool = False , bounds : tuple [ float, float, float, float ] = None , ) -> None : \"\"\" Clean the data by the provided parameters. Args: trim_cycles (bool): whether to remove the first and the last cycles from the data. imaging_only (bool): Flag indicating whether to include only imaging phases in the analysis. legal_bounds (tuple[float, float, float, float]): The legal bounds for worm movement. \"\"\" data = self . data if imaging_only : mask = data [ \"phase\" ] == \"imaging\" data = data [ mask ] if bounds is not None : has_pred = np . isfinite ( data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ()). all ( axis = 1 ) mask_wrm = has_pred # if there is a prediction for a frame then look at worm bbox mask_wrm &= ( data [ \"wrm_x\" ] >= bounds [ 0 ] ) & ( data [ \"wrm_x\" ] + data [ \"wrm_w\" ] <= bounds [ 2 ] ) mask_wrm &= ( data [ \"wrm_y\" ] >= bounds [ 1 ] ) & ( data [ \"wrm_y\" ] + data [ \"wrm_h\" ] <= bounds [ 3 ] ) mask_mic = ~ has_pred # if there is no prediction for a frame then look at micro bbox mask_mic &= ( data [ \"mic_x\" ] >= bounds [ 0 ] ) & ( data [ \"mic_x\" ] + data [ \"mic_w\" ] <= bounds [ 2 ] ) mask_mic &= ( data [ \"mic_y\" ] >= bounds [ 1 ] ) & ( data [ \"mic_y\" ] + data [ \"mic_h\" ] <= bounds [ 3 ] ) data = data [ mask_wrm | mask_mic ] if trim_cycles : mask = data [ \"cycle\" ] != 0 mask &= data [ \"cycle\" ] != data [ \"cycle\" ] . max () data = data [ mask ] self . data = data","title":"clean"},{"location":"reference/wtracker/eval/data_analyzer/#column_names","text":"def column_names ( self ) -> 'list[str]' Returns a list of all column names in the analyzed data. Returns: Type Description list[str] A list of column names. View Source def column_names ( self ) -> list [ str ] : \"\"\" Returns a list of all column names in the analyzed data. Returns: list[str]: A list of column names. \"\"\" return self . data . columns . to_list ()","title":"column_names"},{"location":"reference/wtracker/eval/data_analyzer/#describe","text":"def describe ( self , columns : 'list[str]' = None , num : 'int' = 3 , percentiles : 'list[float]' = None ) -> 'pd.DataFrame' Generate descriptive statistics of the specified columns in the table containing the data. Parameters: Name Type Description Default columns list[str] List of column names to include in the analysis. If None, all columns will be included. None num int Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. None percentiles list[float] List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. None Returns: Type Description pd.DataFrame A DataFrame containing the descriptive statistics of the specified columns. View Source def describe ( self , columns : list [ str ] = None , num : int = 3 , percentiles : list [ float ] = None ) -> pd . DataFrame : \"\"\" Generate descriptive statistics of the specified columns in the table containing the data. Args: columns (list[str], optional): List of column names to include in the analysis. If None, all columns will be included. 
num (int, optional): Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. percentiles (list[float], optional): List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. Returns: pd.DataFrame: A DataFrame containing the descriptive statistics of the specified columns. \"\"\" if columns is None : columns = self . column_names () if percentiles is None : percentiles = np . linspace ( start = 0 , stop = 1.0 , num = num + 2 )[ 1 :- 1 ] return self . data [ columns ]. describe ( percentiles )","title":"describe"},{"location":"reference/wtracker/eval/data_analyzer/#initialize","text":"def initialize ( self , period : 'int' = 10 ) Initializes the data analyzer. It's essential to call this function if the class was created from a non-analyzed log data. Parameters: Name Type Description Default period int The period for calculating speed in frames. The speed is calculated by measuring the distance between current frame and period frames before. None View Source def initialize(self, period: int = 10): \"\"\" Initializes the data analyzer. It's essential to call this function if the class was created from a non-analyzed log data. Args: period (int): The period for calculating speed in frames. The speed is calculated by measuring the distance between current frame and period frames before. \"\"\" data = self._orig_data data[\"time\"] = data[\"frame\"] data[\"cycle_step\"] = data[\"frame\"] % self.time_config.cycle_frame_num data = DataAnalyzer._calc_centers(data) data = DataAnalyzer._calc_speed(data, period) data = DataAnalyzer._calc_worm_deviation(data) data = DataAnalyzer._calc_errors(data) data = data.round(5) self._orig_data = data self.data = self._orig_data.copy()","title":"initialize"},{"location":"reference/wtracker/eval/data_analyzer/#print_stats","text":"def print_stats ( self ) -> 'None' Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Total Count of No Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. View Source def print_stats ( self ) -> None : \"\"\" Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Total Count of No Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. \"\"\" num_removed = len ( self . _orig_data . index ) - len ( self . data . index ) print ( f \"Count of Removed Frames: {num_removed} ({round(100 * num_removed / len(self._orig_data.index), 3)}%)\" ) no_preds = self . data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . isna () . any ( axis = 1 ) . sum () print ( f \"Count of No-Pred Frames: {no_preds} ({round(100 * no_preds / len(self.data.index), 3)}%)\" ) num_cycles = self . data [ \"cycle\" ] . nunique () print ( f \"Total Num of Cycles: {num_cycles}\" ) non_perfect = ( self . data [ \"bbox_error\" ] > 1e-7 ) . sum () / len ( self . data . 
index ) print ( f \"Non Perfect Predictions: {round(100 * non_perfect, 3)}%\" )","title":"print_stats"},{"location":"reference/wtracker/eval/data_analyzer/#remove_cycle","text":"def remove_cycle ( self , cycles : 'int | list[int]' ) Remove the specified cycles from the data. Parameters: Name Type Description Default cycles int list[int] The cycle(s) to remove from the data. View Source def remove_cycle ( self , cycles : int | list [ int ] ) : \"\"\" Remove the specified cycles from the data. Args: cycles (int | list[int]): The cycle(s) to remove from the data. \"\"\" if isinstance ( cycles , int ) : cycles = [ cycles ] mask = self . data [ \"cycle\" ] . isin ( cycles ) self . data = self . data [ ~mask ]","title":"remove_cycle"},{"location":"reference/wtracker/eval/data_analyzer/#reset_changes","text":"def reset_changes ( self ) Reset the data to its original state. Note, that this method will not reset the unit of time and distance. View Source def reset_changes(self): \"\"\" Reset the data to its original state. Note, that this method will not reset the unit of time and distance. \"\"\" self.data = self._orig_data.copy() self._unit = \"frame\"","title":"reset_changes"},{"location":"reference/wtracker/eval/data_analyzer/#save","text":"def save ( self , path : 'str' ) -> 'None' Save the full analyzed data to a csv file. View Source def save ( self , path : str ) -> None : \"\"\" Save the full analyzed data to a csv file. \"\"\" self . _orig_data . to_csv ( path , index = False )","title":"save"},{"location":"reference/wtracker/eval/error_calculator/","text":"Module wtracker.eval.error_calculator View Source from typing import Collection import numpy as np import cv2 as cv from tqdm.auto import tqdm from typing import Callable from wtracker.utils.frame_reader import FrameReader from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat class ErrorCalculator : \"\"\" The ErrorCalculator class provides methods to calculate different types of errors based on worm position and the microscope view. \"\"\" # TODO: Kinda a weird solution, but it works for now. Maybe find a better way to do this. probe_hook : Callable [[ np . ndarray , np . ndarray ], None ] = None # takes mask and view for testing @staticmethod def calculate_segmentation ( bbox : np . ndarray , image : np . ndarray , background : np . ndarray , diff_thresh : float , ) -> np . ndarray : \"\"\" Calculates the segmentation error between a view and background image. Args: bbox (np.ndarray): The bounding box of the image, in the format (x, y, w, h). image (np.ndarray): The image to calculate segmentation from. background (np.ndarray): The background image. diff_thresh (float): The difference threshold to distinguish foreground and background objects from. Returns: np.ndarray: The segmentation mask. Raises: ValueError: If the image is not grayscale or color. \"\"\" x , y , w , h = bbox assert image . shape [: 2 ] == ( h , w ) bg_view = background [ y : y + h , x : x + w ] diff = np . abs ( image . astype ( np . int32 ) - bg_view . astype ( np . int32 )) . astype ( np . uint8 ) # if images are color, convert to grayscale if diff . ndim == 3 and diff . shape [ 2 ] == 3 : diff = cv . cvtColor ( diff , cv . COLOR_BGR2GRAY ) if diff . ndim != 2 : raise ValueError ( \"Image must be either a gray or a color image.\" ) mask_wrm = diff > diff_thresh return mask_wrm # TODO: VERY FAST FOR ME, INVESTIGATE WHY IT'S SLOW IN THE LAB # TODO: swap the FrameReader to another type. 
The only requirement is that accessing frame index returns the correct frame. # we should probably use something like ImageLoader, which is implemented in the analysis_experimental. @staticmethod def calculate_precise ( background : np . ndarray , worm_bboxes : np . ndarray , mic_bboxes : np . ndarray , frame_nums : np . ndarray , worm_reader : FrameReader , diff_thresh : float = 10 , ) -> np . ndarray : \"\"\" Calculates the precise error for each frame in the given sequence. This error is based on precise segmentation of the worm object from the frame, and determining the exact proportion of worm's body outside the microscope view. Args: background (np.ndarray): The background image. worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). frame_nums (np.ndarray): An array of frame numbers to calculate the error for. worm_reader (FrameReader): A frame reader containing segmented worm images for each frame. These worm images should match the shape of the worm bounding boxes. Frames passed in frame_nums are read from this reader by index. diff_thresh (float, optional): The difference threshold to distinguish foreground and background objects from. A foreground object is detected if the pixel value difference with the background is greater than this threshold. Returns: np.ndarray: Array of errors of shape (N,) representing the precise segmentation error for each frame. Raises: AssertionError: If the length of frame_nums, worm_bboxes, and mic_bboxes do not match. \"\"\" assert frame_nums . ndim == 1 assert len ( frame_nums ) == worm_bboxes . shape [ 0 ] == mic_bboxes . shape [ 0 ] errors = np . zeros ( len ( frame_nums ), dtype = float ) bounds = background . shape [: 2 ] worm_bboxes , is_legal = BoxUtils . discretize ( worm_bboxes , bounds = bounds , box_format = BoxFormat . XYWH ) mic_bboxes , _ = BoxUtils . discretize ( mic_bboxes , bounds = bounds , box_format = BoxFormat . XYWH ) # filter out illegal bboxes, indicting no prediction or bad prediction. errors [ ~ is_legal ] = np . nan worm_bboxes = worm_bboxes [ is_legal ] mic_bboxes = mic_bboxes [ is_legal ] frame_nums = frame_nums [ is_legal ] # convert to xyxy format for intersection calculation worm_bboxes = BoxConverter . change_format ( worm_bboxes , BoxFormat . XYWH , BoxFormat . XYXY ) mic_bboxes = BoxConverter . change_format ( mic_bboxes , BoxFormat . XYWH , BoxFormat . XYXY ) wrm_left , wrm_top , wrm_right , wrm_bottom = BoxUtils . unpack ( worm_bboxes ) mic_left , mic_top , mic_right , mic_bottom = BoxUtils . unpack ( mic_bboxes ) # calculate intersection of worm and microscope bounding boxes int_left = np . maximum ( wrm_left , mic_left ) int_top = np . maximum ( wrm_top , mic_top ) int_right = np . minimum ( wrm_right , mic_right ) int_bottom = np . minimum ( wrm_bottom , mic_bottom ) int_width = np . maximum ( 0 , int_right - int_left ) int_height = np . maximum ( 0 , int_bottom - int_top ) # shift the intersection to the worm view coordinates int_left -= wrm_left int_top -= wrm_top # pack the intersection bounding boxes and convert to xywh format int_bboxes = BoxUtils . pack ( int_left , int_top , int_width , int_height ) worm_bboxes = BoxConverter . change_format ( worm_bboxes , BoxFormat . XYXY , BoxFormat . XYWH ) mic_bboxes = BoxConverter . change_format ( mic_bboxes , BoxFormat . 
XYXY , BoxFormat . XYWH ) for i , frame_num in tqdm ( enumerate ( frame_nums ), total = len ( frame_nums ), desc = \"Calculating Error\" , unit = \"fr\" ): wrm_bbox = worm_bboxes [ i ] int_bbox = int_bboxes [ i ] worm_view = worm_reader [ frame_num ] mask_wrm = ErrorCalculator . calculate_segmentation ( bbox = wrm_bbox , image = worm_view , background = background , diff_thresh = diff_thresh , ) if ErrorCalculator . probe_hook is not None : ErrorCalculator . probe_hook ( worm_view , mask_wrm ) mask_mic = np . zeros_like ( mask_wrm , dtype = bool ) mask_mic [ int_bbox [ 1 ] : int_bbox [ 1 ] + int_bbox [ 3 ], int_bbox [ 0 ] : int_bbox [ 0 ] + int_bbox [ 2 ]] = True total = mask_wrm . sum () if total == 0 : errors [ i ] = 0.0 continue intersection = np . logical_and ( mask_wrm , mask_mic ) . sum () error = 1.0 - intersection / total errors [ i ] = error return errors @staticmethod def calculate_bbox_error ( worm_bboxes : np . ndarray , mic_bboxes : np . ndarray ) -> np . ndarray : \"\"\" Calculate the bounding box error between worm bounding boxes and microscope bounding boxes. This error calculates the proportion of the worm bounding box that is outside the microscope bounding box. Args: worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). Returns: np.ndarray: Array of errors of shape (N,) representing the bounding box error for each pair of worm and microscope bounding boxes. \"\"\" wrm_left , wrm_top , wrm_width , wrm_height = BoxUtils . unpack ( worm_bboxes ) mic_left , mic_top , mic_width , mic_height = BoxUtils . unpack ( mic_bboxes ) wrm_right , wrm_bottom = wrm_left + wrm_width , wrm_top + wrm_height mic_right , mic_bottom = mic_left + mic_width , mic_top + mic_height int_left = np . maximum ( wrm_left , mic_left ) int_top = np . maximum ( wrm_top , mic_top ) int_right = np . minimum ( wrm_right , mic_right ) int_bottom = np . minimum ( wrm_bottom , mic_bottom ) int_width = np . maximum ( 0 , int_right - int_left ) int_height = np . maximum ( 0 , int_bottom - int_top ) intersection = int_width * int_height total = wrm_width * wrm_height errors = 1.0 - intersection / total errors [ total == 0 ] = 0.0 return errors @staticmethod def calculate_mse_error ( worm_bboxes : np . ndarray , mic_bboxes : np . ndarray ) -> np . ndarray : \"\"\" Calculates the Mean Squared Error (MSE) error between the centers of worm bounding boxes and microscope bounding boxes. Args: worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). Returns: np.ndarray: Array of errors of shape (N,) representing the MSE error for each pair of worm and microscope bounding boxes. \"\"\" worm_centers = BoxUtils . center ( worm_bboxes ) mic_centers = BoxUtils . center ( mic_bboxes ) errors = np . mean (( worm_centers - mic_centers ) ** 2 , axis = 1 ) return errors Classes ErrorCalculator class ErrorCalculator ( / , * args , ** kwargs ) The ErrorCalculator class provides methods to calculate different types of errors based on worm position and the microscope view. 
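A short worked example of the two box-based errors (a sketch; boxes are numpy arrays in (x, y, w, h) format, as the docstrings below state):

import numpy as np
from wtracker.eval.error_calculator import ErrorCalculator

worm_bboxes = np.array([[0.0, 0.0, 10.0, 10.0]])  # worm box: 10x10 at the origin
mic_bboxes = np.array([[5.0, 0.0, 10.0, 10.0]])   # microscope box shifted 5 px right

# Half of the worm box area lies outside the microscope box.
print(ErrorCalculator.calculate_bbox_error(worm_bboxes, mic_bboxes))  # [0.5]

# Squared center distance, averaged over x and y: centers are (5, 5) and (10, 5),
# so the error is ((5-10)^2 + (5-5)^2) / 2 = 12.5.
print(ErrorCalculator.calculate_mse_error(worm_bboxes, mic_bboxes))  # [12.5]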
Class variables

probe_hook

Static methods

calculate_bbox_error

def calculate_bbox_error(worm_bboxes: numpy.ndarray, mic_bboxes: numpy.ndarray) -> numpy.ndarray

Calculate the bounding box error between worm bounding boxes and microscope bounding boxes. This error is the proportion of the worm bounding box that lies outside the microscope bounding box.

Parameters:

    worm_bboxes (np.ndarray): A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h).
    mic_bboxes (np.ndarray): A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h).

Returns:

    np.ndarray: Array of errors of shape (N,) representing the bounding box error for each pair of worm and microscope bounding boxes.
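The formula implies a few useful sanity checks. A small sketch with synthetic data, not taken from the library docs:

import numpy as np
from wtracker.eval.error_calculator import ErrorCalculator

worm = np.array([
    [0.0, 0.0, 10.0, 10.0],      # fully inside the microscope view
    [100.0, 100.0, 10.0, 10.0],  # completely disjoint from it
    [0.0, 0.0, 0.0, 10.0],       # degenerate box with zero area
])
mic = np.tile(np.array([[0.0, 0.0, 20.0, 20.0]]), (3, 1))

# expected: 0.0 (contained), 1.0 (disjoint), 0.0 (zero-area boxes are defined as error 0;
# numpy may emit a divide-by-zero warning before the value is overwritten)
print(ErrorCalculator.calculate_bbox_error(worm, mic))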
calculate_mse_error

def calculate_mse_error(worm_bboxes: numpy.ndarray, mic_bboxes: numpy.ndarray) -> numpy.ndarray

Calculates the Mean Squared Error (MSE) between the centers of worm bounding boxes and microscope bounding boxes.

Parameters:

    worm_bboxes (np.ndarray): A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h).
    mic_bboxes (np.ndarray): A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h).

Returns:

    np.ndarray: Array of errors of shape (N,) representing the MSE for each pair of worm and microscope bounding boxes.
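Because this value is a mean of squared per-axis pixel offsets, taking its square root gives a root-mean-square center offset in pixel units, which is often easier to interpret. A short sketch under the same assumptions as above:

import numpy as np
from wtracker.eval.error_calculator import ErrorCalculator

worm = np.array([[0.0, 0.0, 10.0, 10.0]])  # center (5, 5)
mic = np.array([[6.0, 8.0, 10.0, 10.0]])   # center (11, 13)

mse = ErrorCalculator.calculate_mse_error(worm, mic)  # (6**2 + 8**2) / 2 = 50.0
print(np.sqrt(mse))  # RMS center offset in pixels, about 7.07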
calculate_precise

def calculate_precise(background: numpy.ndarray, worm_bboxes: numpy.ndarray, mic_bboxes: numpy.ndarray, frame_nums: numpy.ndarray, worm_reader: wtracker.utils.frame_reader.FrameReader, diff_thresh: float = 10) -> numpy.ndarray

Calculates the precise error for each frame in the given sequence. This error is based on precise segmentation of the worm object from the frame, and determines the exact proportion of the worm's body outside the microscope view.

Parameters:

    background (np.ndarray): The background image.
    worm_bboxes (np.ndarray): A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h).
    mic_bboxes (np.ndarray): A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h).
    frame_nums (np.ndarray): An array of frame numbers to calculate the error for.
    worm_reader (FrameReader): A frame reader containing segmented worm images for each frame. These worm images should match the shape of the worm bounding boxes. Frames passed in frame_nums are read from this reader by index.
    diff_thresh (float, optional): The difference threshold used to distinguish foreground objects from the background. A foreground object is detected if the pixel value difference with the background is greater than this threshold. Default: 10.

Returns:

    np.ndarray: Array of errors of shape (N,) representing the precise segmentation error for each frame.

Raises:

    AssertionError: If the lengths of frame_nums, worm_bboxes, and mic_bboxes do not match.
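A runnable sketch of how calculate_precise might be wired up, with synthetic data. It assumes the wtracker package is importable; per the TODO note in the source, any indexable object that returns the correct frame can stand in for the FrameReader, so a plain list of crops is used here:

import numpy as np
from wtracker.eval.error_calculator import ErrorCalculator

background = np.zeros((100, 100), dtype=np.uint8)

# two frames, one 10x10 worm box each; the worm pixels are brighter than the background
worm_bboxes = np.array([[20.0, 20.0, 10.0, 10.0], [80.0, 40.0, 10.0, 10.0]])
mic_bboxes = np.array([[0.0, 0.0, 50.0, 50.0], [0.0, 0.0, 50.0, 50.0]])
frame_nums = np.array([0, 1])
worm_reader = [np.full((10, 10), 255, dtype=np.uint8) for _ in range(2)]

# optional debugging hook: receives the worm view and its segmentation mask for each frame
ErrorCalculator.probe_hook = lambda view, mask: None

errors = ErrorCalculator.calculate_precise(
    background, worm_bboxes, mic_bboxes, frame_nums, worm_reader, diff_thresh=10
)
# first worm lies fully inside the 50x50 view -> 0.0; second worm is fully outside -> 1.0;
# illegal boxes would come back as np.nan, so np.nanmean is the safe aggregate
print(errors, np.nanmean(errors))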
calculate_segmentation

def calculate_segmentation(bbox: numpy.ndarray, image: numpy.ndarray, background: numpy.ndarray, diff_thresh: float) -> numpy.ndarray

Calculates the foreground segmentation mask of a view against the background image.

Parameters:

    bbox (np.ndarray): The bounding box of the image, in the format (x, y, w, h).
    image (np.ndarray): The image to calculate the segmentation from.
    background (np.ndarray): The background image.
    diff_thresh (float): The difference threshold used to distinguish foreground objects from the background.

Returns:

    np.ndarray: The segmentation mask.

Raises:

    ValueError: If the image is not grayscale or color.
","title":"Error Calculator"},
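A small demo of the background-subtraction logic behind calculate_segmentation, using synthetic arrays (assumes wtracker is importable; note the crop must match the (h, w) of the bounding box):

import numpy as np
from wtracker.eval.error_calculator import ErrorCalculator

background = np.zeros((100, 100), dtype=np.uint8)
bbox = np.array([30, 30, 20, 20])  # (x, y, w, h), integer pixel coordinates

# a 20x20 crop of that region, with a bright 5x5 blob in its center
image = np.zeros((20, 20), dtype=np.uint8)
image[8:13, 8:13] = 200

mask = ErrorCalculator.calculate_segmentation(bbox=bbox, image=image, background=background, diff_thresh=10)
print(mask.dtype, mask.sum())  # bool, 25 foreground pixels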
{"location":"reference/wtracker/eval/error_calculator/#module-wtrackerevalerror_calculator","text":"","title":"Module wtracker.eval.error_calculator"},
{"location":"reference/wtracker/eval/error_calculator/#classes","text":"","title":"Classes"},
{"location":"reference/wtracker/eval/error_calculator/#errorcalculator","text":"class ErrorCalculator(/, *args, **kwargs) The ErrorCalculator class provides methods to calculate different types of errors based on worm position and the microscope view.","title":"ErrorCalculator"},
{"location":"reference/wtracker/eval/error_calculator/#class-variables","text":"probe_hook","title":"Class variables"},
{"location":"reference/wtracker/eval/error_calculator/#static-methods","text":"","title":"Static methods"},
{"location":"reference/wtracker/eval/error_calculator/#calculate_bbox_error","text":"def calculate_bbox_error(worm_bboxes: numpy.ndarray, mic_bboxes: numpy.ndarray) -> numpy.ndarray","title":"calculate_bbox_error"},
{"location":"reference/wtracker/eval/error_calculator/#calculate_mse_error","text":"def calculate_mse_error(worm_bboxes: numpy.ndarray, mic_bboxes: numpy.ndarray) -> numpy.ndarray","title":"calculate_mse_error"},
{"location":"reference/wtracker/eval/error_calculator/#calculate_precise","text":"def calculate_precise(background: numpy.ndarray, worm_bboxes: numpy.ndarray, mic_bboxes: numpy.ndarray, frame_nums: numpy.ndarray, worm_reader: wtracker.utils.frame_reader.FrameReader, diff_thresh: float = 10) -> numpy.ndarray","title":"calculate_precise"},
{"location":"reference/wtracker/eval/error_calculator/#calculate_segmentation","text":"def calculate_segmentation(bbox: numpy.ndarray, image: numpy.ndarray, background: numpy.ndarray, diff_thresh: float) -> numpy.ndarray","title":"calculate_segmentation"}
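Tying the two modules of this section together, here is a sketch of feeding computed errors into the plotter documented below. The column names wrm_speed, bbox_error, and log_num follow the Plotter source; in a real run the DataFrames would come from the DataAnalyzer rather than random data:

import numpy as np
import pandas as pd
from wtracker.eval.plotter import Plotter

rng = np.random.default_rng(0)
df = pd.DataFrame({
    \"wrm_speed\": rng.gamma(2.0, 2.0, size=1000),    # stand-in for analyzer output
    \"bbox_error\": rng.uniform(0.0, 1.0, size=1000),
})

plotter = Plotter([df], plot_height=6)
plotter.plot_speed()
plotter.plot_error(error_kind=\"bbox\", condition=lambda d: d[\"wrm_speed\"] > 1.0)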
None background np.ndarray The background image. None diff_thresh float The difference threshold to distinguish foreground and background objects from. None Returns: Type Description np.ndarray The segmentation mask. Raises: Type Description ValueError If the image is not grayscale or color. View Source @ staticmethod def calculate_segmentation ( bbox : np . ndarray , image : np . ndarray , background : np . ndarray , diff_thresh : float , ) -> np . ndarray : \"\"\" Calculates the segmentation error between a view and background image . Args : bbox ( np . ndarray ) : The bounding box of the image , in the format ( x , y , w , h ). image ( np . ndarray ) : The image to calculate segmentation from . background ( np . ndarray ) : The background image . diff_thresh ( float ) : The difference threshold to distinguish foreground and background objects from . Returns : np . ndarray : The segmentation mask . Raises : ValueError : If the image is not grayscale or color . \"\"\" x , y , w , h = bbox assert image . shape [ : 2 ] == ( h , w ) bg_view = background [ y : y + h , x : x + w ] diff = np . abs ( image . astype ( np . int32 ) - bg_view . astype ( np . int32 )). astype ( np . uint8 ) # if images are color, convert to grayscale if diff . ndim == 3 and diff . shape [ 2 ] == 3 : diff = cv . cvtColor ( diff , cv . COLOR_BGR2GRAY ) if diff . ndim != 2 : raise ValueError ( \"Image must be either a gray or a color image.\" ) mask_wrm = diff > diff_thresh return mask_wrm","title":"calculate_segmentation"},{"location":"reference/wtracker/eval/plotter/","text":"Module wtracker.eval.plotter View Source from __future__ import annotations import pandas as pd import seaborn as sns from typing import Callable class Plotter : \"\"\" A class for plotting experiment log data. The experiment data was previously analyzed by the DataAnalyzer class. Supports analysis of multiple logs at once. Args: data_list (list[pd.DataFrame]): A list of dataframes, each holding the data of a single experiment log. plot_height (int, optional): The height of the plot. palette (str, optional): The color palette to use for the plots. \"\"\" def __init__ ( self , data_list : list [ pd . DataFrame ], plot_height : int = 7 , palette : str = \"viridis\" , ) -> None : self . plot_height = plot_height self . palette = palette for i , data in enumerate ( data_list ): data [ \"log_num\" ] = i self . data = pd . concat ([ d for d in data_list ], ignore_index = True ) def _get_error_column ( self , error_kind : str ) -> str : if error_kind == \"bbox\" : return \"bbox_error\" elif error_kind == \"dist\" : return \"worm_deviation\" elif error_kind == \"precise\" : return \"precise_error\" else : raise ValueError ( f \"Invalid error kind: { error_kind } \" ) def plot_speed ( self , log_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the speed distribution of the worm. Args: log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. \"\"\" return self . 
create_distplot ( x_col = \"wrm_speed\" , x_label = \"speed\" , title = \"Worm Speed Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , kde = True , ** kwargs , ) def plot_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , cycle_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the error distribution. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". log_wise (bool, optional): Whether to plot each log separately. cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = self . data . groupby ([ \"log_num\" , \"cycle\" ])[ error_col ] . max () . reset_index () return self . create_distplot ( x_col = error_col , x_label = f \" { error_kind } error\" , title = f \" { error_kind } Error Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , data = data , ** kwargs , ) def plot_cycle_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"boxen\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the error as a function of the cycle step. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". **kwargs: Additional keyword arguments to pass the `Plotter.create_catplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) return self . create_catplot ( x_col = \"cycle_step\" , y_col = error_col , x_label = \"cycle step\" , y_label = f \" { error_kind } error\" , title = f \" { error_kind } error as function of cycle step\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , ** kwargs , ) def plot_speed_vs_error ( self , error_kind : str = \"bbox\" , cycle_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the speed of the worm vs the error. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = ( self . data . 
groupby ([ \"log_num\" , \"cycle\" ])[[ error_col , \"wrm_speed\" ]] . aggregate ({ error_col : \"max\" , \"wrm_speed\" : \"mean\" }) . reset_index () ) return self . create_jointplot ( x_col = \"wrm_speed\" , y_col = error_col , plot_kind = kind , x_label = \"speed\" , y_label = f \" { error_kind } error\" , title = f \"Speed vs { error_kind } Error\" , condition = condition , data = data , ** kwargs , ) def plot_trajectory ( self , hue_col = \"log_num\" , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the trajectory of the worm. Args: hue_col (str, optional): The column to use for coloring the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" plot = self . create_jointplot ( x_col = \"wrm_center_x\" , y_col = \"wrm_center_y\" , x_label = \"X\" , y_label = \"Y\" , title = \"Worm Trajectory\" , hue_col = hue_col , plot_kind = \"scatter\" , alpha = 1 , linewidth = 0 , condition = condition , ** kwargs , ) plot . ax_marg_x . remove () plot . ax_marg_y . remove () plot . ax_joint . invert_yaxis () return plot def plot_head_size ( self , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the size of the worm head. Args: condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" return self . create_jointplot ( x_col = \"wrm_w\" , y_col = \"wrm_h\" , x_label = \"width\" , y_label = \"height\" , title = \"Worm Head Size\" , plot_kind = plot_kind , condition = condition , ** kwargs , ) def create_distplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"hist\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a distribution plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.displot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"hist\" , \"kde\" , \"ecdf\" ] if data is None : data = self . 
data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . displot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{ col_name }} :: { title . title () } \" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_catplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"strip\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a categorical plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"strip\" , \"box\" , \"violin\" , \"boxen\" , \"bar\" , \"count\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . catplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{ col_name }} :: { title . title () } \" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_jointplot ( self , x_col : str , y_col : str , hue_col : str = None , plot_kind : str = \"scatter\" , x_label : str = \"\" , y_label : str = \"\" , title : str = \"\" , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Create a joint plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. 
plot_kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" assert plot_kind in [ \"scatter\" , \"kde\" , \"hist\" , \"hex\" , \"reg\" , \"resid\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . jointplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , kind = plot_kind , height = self . plot_height , palette = palette , marginal_kws = dict ( palette = palette ), ** kwargs , ) plot . set_axis_labels ( x_label . capitalize (), y_label . capitalize ()) plot . figure . suptitle ( title . title ()) plot . figure . tight_layout () return plot Classes Plotter class Plotter ( data_list : 'list[pd.DataFrame]' , plot_height : 'int' = 7 , palette : 'str' = 'viridis' ) A class for plotting experiment log data. The experiment data was previously analyzed by the DataAnalyzer class. Supports analysis of multiple logs at once. Attributes Name Type Description Default data_list list[pd.DataFrame] A list of dataframes, each holding the data of a single experiment log. None plot_height int The height of the plot. None palette str The color palette to use for the plots. None View Source class Plotter : \"\"\" A class for plotting experiment log data. The experiment data was previously analyzed by the DataAnalyzer class. Supports analysis of multiple logs at once. Args: data_list (list[pd.DataFrame]): A list of dataframes, each holding the data of a single experiment log. plot_height (int, optional): The height of the plot. palette (str, optional): The color palette to use for the plots. \"\"\" def __init__ ( self , data_list : list [ pd.DataFrame ] , plot_height : int = 7 , palette : str = \"viridis\" , ) -> None : self . plot_height = plot_height self . palette = palette for i , data in enumerate ( data_list ) : data [ \"log_num\" ] = i self . data = pd . concat ( [ d for d in data_list ] , ignore_index = True ) def _get_error_column ( self , error_kind : str ) -> str : if error_kind == \"bbox\" : return \"bbox_error\" elif error_kind == \"dist\" : return \"worm_deviation\" elif error_kind == \"precise\" : return \"precise_error\" else : raise ValueError ( f \"Invalid error kind: {error_kind}\" ) def plot_speed ( self , log_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the speed distribution of the worm. Args: log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. \"\"\" return self . 
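Example (a minimal, illustrative sketch; `analyzed_logs` stands in for the per-log DataFrames produced beforehand by the DataAnalyzer class):
import pandas as pd
from wtracker.eval.plotter import Plotter
analyzed_logs: list[pd.DataFrame] = [...]  # one analyzed DataFrame per experiment log
plotter = Plotter(analyzed_logs, plot_height=7, palette=\"viridis\")  # tags each frame with log_num and concatenates them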
create_distplot ( x_col = \"wrm_speed\" , x_label = \"speed\" , title = \"Worm Speed Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , kde = True , ** kwargs , ) def plot_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , cycle_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the error distribution. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". log_wise (bool, optional): Whether to plot each log separately. cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = self . data . groupby ( [ \"log_num\", \"cycle\" ] ) [ error_col ] . max (). reset_index () return self . create_distplot ( x_col = error_col , x_label = f \"{error_kind} error\" , title = f \"{error_kind} Error Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , data = data , ** kwargs , ) def plot_cycle_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"boxen\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the error as a function of the cycle step. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" strip \", \" box \", \" violin \", \" boxen \", \" bar \", or \" count \". **kwargs: Additional keyword arguments to pass the `Plotter.create_catplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) return self . create_catplot ( x_col = \"cycle_step\" , y_col = error_col , x_label = \"cycle step\" , y_label = f \"{error_kind} error\" , title = f \"{error_kind} error as function of cycle step\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , ** kwargs , ) def plot_speed_vs_error ( self , error_kind : str = \"bbox\" , cycle_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the speed of the worm vs the error. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. kind (str, optional): The kind of plot to create. Can be \" scatter \", \" kde \", \" hist \", \" hex \", \" reg \", or \" resid \". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = ( self . data . 
groupby ( [ \"log_num\", \"cycle\" ] ) [ [error_col, \"wrm_speed\" ] ] . aggregate ( { error_col : \"max\" , \"wrm_speed\" : \"mean\" } ) . reset_index () ) return self . create_jointplot ( x_col = \"wrm_speed\" , y_col = error_col , plot_kind = kind , x_label = \"speed\" , y_label = f \"{error_kind} error\" , title = f \"Speed vs {error_kind} Error\" , condition = condition , data = data , ** kwargs , ) def plot_trajectory ( self , hue_col = \"log_num\" , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the trajectory of the worm. Args: hue_col (str, optional): The column to use for coloring the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" plot = self . create_jointplot ( x_col = \"wrm_center_x\" , y_col = \"wrm_center_y\" , x_label = \"X\" , y_label = \"Y\" , title = \"Worm Trajectory\" , hue_col = hue_col , plot_kind = \"scatter\" , alpha = 1 , linewidth = 0 , condition = condition , ** kwargs , ) plot . ax_marg_x . remove () plot . ax_marg_y . remove () plot . ax_joint . invert_yaxis () return plot def plot_head_size ( self , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the size of the worm head. Args: condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" scatter \", \" kde \", \" hist \", \" hex \", \" reg \", or \" resid \". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" return self . create_jointplot ( x_col = \"wrm_w\" , y_col = \"wrm_h\" , x_label = \"width\" , y_label = \"height\" , title = \"Worm Head Size\" , plot_kind = plot_kind , condition = condition , ** kwargs , ) def create_distplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"hist\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , transform : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a distribution plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.displot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"hist\", \"kde\", \"ecdf\" ] if data is None : data = self . 
data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition(data) ] palette = self . palette if hue_col is not None else None plot = sns . displot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_catplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"strip\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , transform : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a categorical plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \" strip \", \" box \", \" violin \", \" boxen \", \" bar \", or \" count \". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", \"count\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition(data) ] palette = self . palette if hue_col is not None else None plot = sns . catplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_jointplot ( self , x_col : str , y_col : str , hue_col : str = None , plot_kind : str = \"scatter\" , x_label : str = \"\" , y_label : str = \"\" , title : str = \"\" , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , transform : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Create a joint plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. plot_kind (str, optional): The kind of plot to create. 
Can be \" scatter \", \" kde \", \" hist \", \" hex \", \" reg \", or \" resid \". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" assert plot_kind in [ \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", \"resid\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition(data) ] palette = self . palette if hue_col is not None else None plot = sns . jointplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , kind = plot_kind , height = self . plot_height , palette = palette , marginal_kws = dict ( palette = palette ), ** kwargs , ) plot . set_axis_labels ( x_label . capitalize (), y_label . capitalize ()) plot . figure . suptitle ( title . title ()) plot . figure . tight_layout () return plot Methods create_catplot def create_catplot ( self , x_col : 'str' , y_col : 'str' = None , hue_col : 'str' = None , log_wise : 'bool' = False , plot_kind : 'str' = 'strip' , x_label : 'str' = '' , y_label : 'str' = '' , title : 'str | None' = None , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , transform : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , data : 'pd.DataFrame' = None , ** kwargs ) -> 'sns.FacetGrid' Create a categorical plot from the data. Parameters: Name Type Description Default x_col str The column to plot on the x-axis. None y_col str The column to plot on the y-axis. None hue_col str The column to use for coloring the plot. None log_wise bool Whether to plot each log separately. None plot_kind str The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". None x_label str The x-axis label. None y_label str The y-axis label. None title str The title of the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None transform Callable[[pd.DataFrame], pd.DataFrame] A function to transform the data. None data pd.DataFrame Custom data to plot from. If None, the data passed to the constructor of the class is used. None **kwargs None Additional keyword arguments to pass to the seaborn.catplot function. None Returns: Type Description sns.FacetGrid The plot object. View Source def create_catplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"strip\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a categorical plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. 
Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"strip\" , \"box\" , \"violin\" , \"boxen\" , \"bar\" , \"count\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . catplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot create_distplot def create_distplot ( self , x_col : 'str' , y_col : 'str' = None , hue_col : 'str' = None , log_wise : 'bool' = False , plot_kind : 'str' = 'hist' , x_label : 'str' = '' , y_label : 'str' = '' , title : 'str | None' = None , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , transform : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , data : 'pd.DataFrame' = None , ** kwargs ) -> 'sns.FacetGrid' Create a distribution plot from the data. Parameters: Name Type Description Default x_col str The column to plot on the x-axis. None y_col str The column to plot on the y-axis. None hue_col str The column to use for coloring the plot. None log_wise bool Whether to plot each log separately. None plot_kind str The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". None x_label str The x-axis label. None y_label str The y-axis label. None title str The title of the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None transform Callable[[pd.DataFrame], pd.DataFrame] A function to transform the data. None data pd.DataFrame Custom data to plot from. If None, the data passed to the constructor of the class is used. None **kwargs None Additional keyword arguments to pass to the seaborn.displot function. None Returns: Type Description sns.FacetGrid The plot object. View Source def create_distplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"hist\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a distribution plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. 
create_jointplot def create_jointplot( self, x_col: 'str', y_col: 'str', hue_col: 'str' = None, plot_kind: 'str' = 'scatter', x_label: 'str' = '', y_label: 'str' = '', title: 'str' = '', condition: 'Callable[[pd.DataFrame], pd.DataFrame]' = None, transform: 'Callable[[pd.DataFrame], pd.DataFrame]' = None, data: 'pd.DataFrame' = None, **kwargs ) -> 'sns.JointGrid' Create a joint plot from the data. Parameters: Name Type Description Default x_col str The column to plot on the x-axis. required y_col str The column to plot on the y-axis. required hue_col str The column to use for coloring the plot. None plot_kind str The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". 'scatter' x_label str The x-axis label. '' y_label str The y-axis label. '' title str The title of the plot. '' condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None transform Callable[[pd.DataFrame], pd.DataFrame] A function to transform the data. None data pd.DataFrame Custom data to plot from. If None, the data passed to the constructor of the class is used. None **kwargs None Additional keyword arguments to pass to the seaborn.jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source: identical to the create_jointplot code in the module listing below.
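Example (illustrative; same assumed `plotter`): a hexbin of worm width against height:
plotter.create_jointplot(
    x_col=\"wrm_w\",
    y_col=\"wrm_h\",
    plot_kind=\"hex\",
    x_label=\"width\",
    y_label=\"height\",
    title=\"Head Size\",
)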
plot_cycle_error def plot_cycle_error( self, error_kind: 'str' = 'bbox', log_wise: 'bool' = False, condition: 'Callable[[pd.DataFrame], pd.DataFrame]' = None, plot_kind: 'str' = 'boxen', **kwargs ) -> 'sns.JointGrid' Plot the error as a function of the cycle step. Parameters: Name Type Description Default error_kind str The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". 'bbox' log_wise bool Whether to plot each log separately. False condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". 'boxen' **kwargs None Additional keyword arguments to pass to the Plotter.create_catplot function. None Returns: Type Description sns.JointGrid The plot object. View Source: identical to the plot_cycle_error code in the module listing below.
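Example (illustrative; same assumed `plotter`):
plotter.plot_cycle_error(error_kind=\"dist\", plot_kind=\"violin\")  # distance error at each cycle step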
plot_error def plot_error( self, error_kind: 'str' = 'bbox', log_wise: 'bool' = False, cycle_wise: 'bool' = False, condition: 'Callable[[pd.DataFrame], pd.DataFrame]' = None, plot_kind: 'str' = 'hist', **kwargs ) -> 'sns.FacetGrid' Plot the error distribution. Parameters: Name Type Description Default error_kind str The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". 'bbox' log_wise bool Whether to plot each log separately. False cycle_wise bool Whether to plot each cycle separately. False condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". 'hist' **kwargs None Additional keyword arguments to pass to the Plotter.create_distplot function. None Returns: Type Description sns.FacetGrid The plot object. View Source: identical to the plot_error code in the module listing below. plot_head_size def plot_head_size( self, condition: 'Callable[[pd.DataFrame], pd.DataFrame]' = None, plot_kind: 'str' = 'hist', **kwargs ) -> 'sns.JointGrid' Plot the size of the worm head. Parameters: Name Type Description Default condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". 'hist' **kwargs None Additional keyword arguments to pass to the Plotter.create_jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source: identical to the plot_head_size code in the module listing below. plot_speed def plot_speed( self, log_wise: 'bool' = False, condition: 'Callable[[pd.DataFrame], pd.DataFrame]' = None, plot_kind: 'str' = 'hist', **kwargs ) -> 'sns.FacetGrid' Plot the speed distribution of the worm. Parameters: Name Type Description Default log_wise bool Whether to plot each log separately. False condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". 'hist' **kwargs None Additional keyword arguments to pass to the Plotter.create_distplot function. None Returns: Type Description sns.FacetGrid The plot object. View Source: identical to the plot_speed code in the module listing below.
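Example (illustrative; same assumed `plotter`; the \"precise\" error is only available if it was computed during analysis, so \"bbox\" is used here):
plotter.plot_speed()  # overall speed distribution
plotter.plot_error(
    error_kind=\"bbox\",
    cycle_wise=True,  # one (maximal) error value per cycle
    condition=lambda df: df[\"log_num\"] == 0,  # boolean mask used to filter rows
)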
Can be \"hist\", \"kde\", or \"ecdf\". None **kwargs None Additional keyword arguments to pass the Plotter.create_distplot function. None View Source def plot_speed( self, log_wise: bool = False, condition: Callable[[pd.DataFrame], pd.DataFrame] = None, plot_kind: str = \"hist\", **kwargs, ) -> sns.FacetGrid: \"\"\" Plot the speed distribution of the worm. Args: log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. \"\"\" return self.create_distplot( x_col=\"wrm_speed\", x_label=\"speed\", title=\"Worm Speed Distribution\", plot_kind=plot_kind, log_wise=log_wise, condition=condition, kde=True, **kwargs, ) plot_speed_vs_error def plot_speed_vs_error ( self , error_kind : 'str' = 'bbox' , cycle_wise : 'bool' = False , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , kind : 'str' = 'hist' , ** kwargs ) -> 'sns.JointGrid' Plot the speed of the worm vs the error. Parameters: Name Type Description Default error_kind str The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". None cycle_wise bool Whether to plot each cycle separately. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None kind str The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". None **kwargs None Additional keyword arguments to pass the Plotter.create_jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def plot_speed_vs_error ( self , error_kind: str = \"bbox\" , cycle_wise: bool = False , condition: Callable [[ pd . DataFrame ], pd . DataFrame ] = None , kind: str = \"hist\" , ** kwargs , ) -> sns . JointGrid: \"\"\" Plot the speed of the worm vs the error . Args: error_kind ( str , optional ) : The kind of error to plot . Can be \"bbox\" , \"dist\" , or \"precise\" . cycle_wise ( bool , optional ) : Whether to plot each cycle separately . condition ( Callable [[ pd . DataFrame ], pd . DataFrame ], optional ) : A function to filter the data . kind ( str , optional ) : The kind of plot to create . Can be \"scatter\" , \"kde\" , \"hist\" , \"hex\" , \"reg\" , or \"resid\" . ** kwargs: Additional keyword arguments to pass the `Plotter . create_jointplot ` function . Returns: sns . JointGrid: The plot object . \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise: data = ( self . data . groupby ([ \"log_num\" , \"cycle\" ])[[ error_col , \"wrm_speed\" ]] . aggregate ({ error_col: \"max\" , \"wrm_speed\" : \"mean\" }) . reset_index () ) return self . create_jointplot ( x_col = \"wrm_speed\" , y_col = error_col , plot_kind = kind , x_label = \"speed\" , y_label = f \"{error_kind} error\" , title = f \"Speed vs {error_kind} Error\" , condition = condition , data = data , ** kwargs , ) plot_trajectory def plot_trajectory ( self , hue_col = 'log_num' , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , ** kwargs ) -> 'sns.JointGrid' Plot the trajectory of the worm. Parameters: Name Type Description Default hue_col str The column to use for coloring the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None **kwargs None Additional keyword arguments to pass the Plotter.create_jointplot function. 
None Returns: Type Description sns.JointGrid The plot object. View Source def plot_trajectory( self, hue_col=\"log_num\", condition: Callable[[pd.DataFrame], pd.DataFrame] = None, **kwargs, ) -> sns.JointGrid: \"\"\" Plot the trajectory of the worm. Args: hue_col (str, optional): The column to use for coloring the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" plot = self.create_jointplot( x_col=\"wrm_center_x\", y_col=\"wrm_center_y\", x_label=\"X\", y_label=\"Y\", title=\"Worm Trajectory\", hue_col=hue_col, plot_kind=\"scatter\", alpha=1, linewidth=0, condition=condition, **kwargs, ) plot.ax_marg_x.remove() plot.ax_marg_y.remove() plot.ax_joint.invert_yaxis() return plot","title":"Plotter"},{"location":"reference/wtracker/eval/plotter/#module-wtrackerevalplotter","text":"View Source from __future__ import annotations import pandas as pd import seaborn as sns from typing import Callable class Plotter : \"\"\" A class for plotting experiment log data. The experiment data was previously analyzed by the DataAnalyzer class. Supports analysis of multiple logs at once. Args: data_list (list[pd.DataFrame]): A list of dataframes, each holding the data of a single experiment log. plot_height (int, optional): The height of the plot. palette (str, optional): The color palette to use for the plots. \"\"\" def __init__ ( self , data_list : list [ pd . DataFrame ], plot_height : int = 7 , palette : str = \"viridis\" , ) -> None : self . plot_height = plot_height self . palette = palette for i , data in enumerate ( data_list ): data [ \"log_num\" ] = i self . data = pd . concat ([ d for d in data_list ], ignore_index = True ) def _get_error_column ( self , error_kind : str ) -> str : if error_kind == \"bbox\" : return \"bbox_error\" elif error_kind == \"dist\" : return \"worm_deviation\" elif error_kind == \"precise\" : return \"precise_error\" else : raise ValueError ( f \"Invalid error kind: { error_kind } \" ) def plot_speed ( self , log_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the speed distribution of the worm. Args: log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. \"\"\" return self . create_distplot ( x_col = \"wrm_speed\" , x_label = \"speed\" , title = \"Worm Speed Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , kde = True , ** kwargs , ) def plot_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , cycle_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the error distribution. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". log_wise (bool, optional): Whether to plot each log separately. cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. 
plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = self . data . groupby ([ \"log_num\" , \"cycle\" ])[ error_col ] . max () . reset_index () return self . create_distplot ( x_col = error_col , x_label = f \" { error_kind } error\" , title = f \" { error_kind } Error Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , data = data , ** kwargs , ) def plot_cycle_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"boxen\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the error as a function of the cycle step. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". **kwargs: Additional keyword arguments to pass the `Plotter.create_catplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) return self . create_catplot ( x_col = \"cycle_step\" , y_col = error_col , x_label = \"cycle step\" , y_label = f \" { error_kind } error\" , title = f \" { error_kind } error as function of cycle step\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , ** kwargs , ) def plot_speed_vs_error ( self , error_kind : str = \"bbox\" , cycle_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the speed of the worm vs the error. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = ( self . data . groupby ([ \"log_num\" , \"cycle\" ])[[ error_col , \"wrm_speed\" ]] . aggregate ({ error_col : \"max\" , \"wrm_speed\" : \"mean\" }) . reset_index () ) return self . create_jointplot ( x_col = \"wrm_speed\" , y_col = error_col , plot_kind = kind , x_label = \"speed\" , y_label = f \" { error_kind } error\" , title = f \"Speed vs { error_kind } Error\" , condition = condition , data = data , ** kwargs , ) def plot_trajectory ( self , hue_col = \"log_num\" , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the trajectory of the worm. Args: hue_col (str, optional): The column to use for coloring the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. 
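Example (illustrative; assumes `plotter` is a Plotter built from DataAnalyzer output): >>> plotter.plot_trajectory(hue_col=\"log_num\", condition=lambda df: df[\"log_num\"] == 0)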
Returns: sns.JointGrid: The plot object. \"\"\" plot = self . create_jointplot ( x_col = \"wrm_center_x\" , y_col = \"wrm_center_y\" , x_label = \"X\" , y_label = \"Y\" , title = \"Worm Trajectory\" , hue_col = hue_col , plot_kind = \"scatter\" , alpha = 1 , linewidth = 0 , condition = condition , ** kwargs , ) plot . ax_marg_x . remove () plot . ax_marg_y . remove () plot . ax_joint . invert_yaxis () return plot def plot_head_size ( self , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the size of the worm head. Args: condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" return self . create_jointplot ( x_col = \"wrm_w\" , y_col = \"wrm_h\" , x_label = \"width\" , y_label = \"height\" , title = \"Worm Head Size\" , plot_kind = plot_kind , condition = condition , ** kwargs , ) def create_distplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"hist\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a distribution plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.displot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"hist\" , \"kde\" , \"ecdf\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . displot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{ col_name }} :: { title . title () } \" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . 
tight_layout () return plot def create_catplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"strip\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a categorical plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"strip\" , \"box\" , \"violin\" , \"boxen\" , \"bar\" , \"count\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . catplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{ col_name }} :: { title . title () } \" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_jointplot ( self , x_col : str , y_col : str , hue_col : str = None , plot_kind : str = \"scatter\" , x_label : str = \"\" , y_label : str = \"\" , title : str = \"\" , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Create a joint plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. plot_kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.jointplot` function. Returns: sns.JointGrid: The plot object. 
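Example (illustrative; shows the `transform` hook on an assumed `plotter` instance): >>> plotter.create_jointplot( ... x_col=\"wrm_speed\", ... y_col=\"worm_deviation\", ... plot_kind=\"kde\", ... transform=lambda df: df[df[\"cycle_step\"] > 0], ... x_label=\"speed\", ... y_label=\"distance error\", ... title=\"Speed vs Deviation\", ... )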
\"\"\" assert plot_kind in [ \"scatter\" , \"kde\" , \"hist\" , \"hex\" , \"reg\" , \"resid\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . jointplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , kind = plot_kind , height = self . plot_height , palette = palette , marginal_kws = dict ( palette = palette ), ** kwargs , ) plot . set_axis_labels ( x_label . capitalize (), y_label . capitalize ()) plot . figure . suptitle ( title . title ()) plot . figure . tight_layout () return plot","title":"Module wtracker.eval.plotter"},{"location":"reference/wtracker/eval/plotter/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/eval/plotter/#plotter","text":"class Plotter ( data_list : 'list[pd.DataFrame]' , plot_height : 'int' = 7 , palette : 'str' = 'viridis' ) A class for plotting experiment log data. The experiment data was previously analyzed by the DataAnalyzer class. Supports analysis of multiple logs at once.","title":"Plotter"},{"location":"reference/wtracker/eval/plotter/#attributes","text":"Name Type Description Default data_list list[pd.DataFrame] A list of dataframes, each holding the data of a single experiment log. None plot_height int The height of the plot. None palette str The color palette to use for the plots. None View Source class Plotter : \"\"\" A class for plotting experiment log data. The experiment data was previously analyzed by the DataAnalyzer class. Supports analysis of multiple logs at once. Args: data_list (list[pd.DataFrame]): A list of dataframes, each holding the data of a single experiment log. plot_height (int, optional): The height of the plot. palette (str, optional): The color palette to use for the plots. \"\"\" def __init__ ( self , data_list : list [ pd.DataFrame ] , plot_height : int = 7 , palette : str = \"viridis\" , ) -> None : self . plot_height = plot_height self . palette = palette for i , data in enumerate ( data_list ) : data [ \"log_num\" ] = i self . data = pd . concat ( [ d for d in data_list ] , ignore_index = True ) def _get_error_column ( self , error_kind : str ) -> str : if error_kind == \"bbox\" : return \"bbox_error\" elif error_kind == \"dist\" : return \"worm_deviation\" elif error_kind == \"precise\" : return \"precise_error\" else : raise ValueError ( f \"Invalid error kind: {error_kind}\" ) def plot_speed ( self , log_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the speed distribution of the worm. Args: log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. \"\"\" return self . create_distplot ( x_col = \"wrm_speed\" , x_label = \"speed\" , title = \"Worm Speed Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , kde = True , ** kwargs , ) def plot_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , cycle_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . 
FacetGrid : \"\"\" Plot the error distribution. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". log_wise (bool, optional): Whether to plot each log separately. cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = self . data . groupby ( [ \"log_num\", \"cycle\" ] ) [ error_col ] . max (). reset_index () return self . create_distplot ( x_col = error_col , x_label = f \"{error_kind} error\" , title = f \"{error_kind} Error Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , data = data , ** kwargs , ) def plot_cycle_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"boxen\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the error as a function of the cycle step. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" strip \", \" box \", \" violin \", \" boxen \", \" bar \", or \" count \". **kwargs: Additional keyword arguments to pass the `Plotter.create_catplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) return self . create_catplot ( x_col = \"cycle_step\" , y_col = error_col , x_label = \"cycle step\" , y_label = f \"{error_kind} error\" , title = f \"{error_kind} error as function of cycle step\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , ** kwargs , ) def plot_speed_vs_error ( self , error_kind : str = \"bbox\" , cycle_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the speed of the worm vs the error. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. kind (str, optional): The kind of plot to create. Can be \" scatter \", \" kde \", \" hist \", \" hex \", \" reg \", or \" resid \". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = ( self . data . groupby ( [ \"log_num\", \"cycle\" ] ) [ [error_col, \"wrm_speed\" ] ] . aggregate ( { error_col : \"max\" , \"wrm_speed\" : \"mean\" } ) . reset_index () ) return self . 
create_jointplot ( x_col = \"wrm_speed\" , y_col = error_col , plot_kind = kind , x_label = \"speed\" , y_label = f \"{error_kind} error\" , title = f \"Speed vs {error_kind} Error\" , condition = condition , data = data , ** kwargs , ) def plot_trajectory ( self , hue_col = \"log_num\" , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the trajectory of the worm. Args: hue_col (str, optional): The column to use for coloring the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" plot = self . create_jointplot ( x_col = \"wrm_center_x\" , y_col = \"wrm_center_y\" , x_label = \"X\" , y_label = \"Y\" , title = \"Worm Trajectory\" , hue_col = hue_col , plot_kind = \"scatter\" , alpha = 1 , linewidth = 0 , condition = condition , ** kwargs , ) plot . ax_marg_x . remove () plot . ax_marg_y . remove () plot . ax_joint . invert_yaxis () return plot def plot_head_size ( self , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the size of the worm head. Args: condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" scatter \", \" kde \", \" hist \", \" hex \", \" reg \", or \" resid \". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" return self . create_jointplot ( x_col = \"wrm_w\" , y_col = \"wrm_h\" , x_label = \"width\" , y_label = \"height\" , title = \"Worm Head Size\" , plot_kind = plot_kind , condition = condition , ** kwargs , ) def create_distplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"hist\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , transform : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a distribution plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.displot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"hist\", \"kde\", \"ecdf\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition(data) ] palette = self . palette if hue_col is not None else None plot = sns . 
displot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_catplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"strip\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , transform : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a categorical plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \" strip \", \" box \", \" violin \", \" boxen \", \" bar \", or \" count \". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", \"count\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition(data) ] palette = self . palette if hue_col is not None else None plot = sns . catplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_jointplot ( self , x_col : str , y_col : str , hue_col : str = None , plot_kind : str = \"scatter\" , x_label : str = \"\" , y_label : str = \"\" , title : str = \"\" , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , transform : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Create a joint plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. plot_kind (str, optional): The kind of plot to create. Can be \" scatter \", \" kde \", \" hist \", \" hex \", \" reg \", or \" resid \". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. 
title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" assert plot_kind in [ \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", \"resid\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition(data) ] palette = self . palette if hue_col is not None else None plot = sns . jointplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , kind = plot_kind , height = self . plot_height , palette = palette , marginal_kws = dict ( palette = palette ), ** kwargs , ) plot . set_axis_labels ( x_label . capitalize (), y_label . capitalize ()) plot . figure . suptitle ( title . title ()) plot . figure . tight_layout () return plot","title":"Attributes"},{"location":"reference/wtracker/eval/plotter/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/eval/plotter/#create_catplot","text":"def create_catplot ( self , x_col : 'str' , y_col : 'str' = None , hue_col : 'str' = None , log_wise : 'bool' = False , plot_kind : 'str' = 'strip' , x_label : 'str' = '' , y_label : 'str' = '' , title : 'str | None' = None , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , transform : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , data : 'pd.DataFrame' = None , ** kwargs ) -> 'sns.FacetGrid' Create a categorical plot from the data. Parameters: Name Type Description Default x_col str The column to plot on the x-axis. None y_col str The column to plot on the y-axis. None hue_col str The column to use for coloring the plot. None log_wise bool Whether to plot each log separately. None plot_kind str The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". None x_label str The x-axis label. None y_label str The y-axis label. None title str The title of the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None transform Callable[[pd.DataFrame], pd.DataFrame] A function to transform the data. None data pd.DataFrame Custom data to plot from. If None, the data passed to the constructor of the class is used. None **kwargs None Additional keyword arguments to pass to the seaborn.catplot function. None Returns: Type Description sns.FacetGrid The plot object. View Source def create_catplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"strip\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a categorical plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". 
x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"strip\" , \"box\" , \"violin\" , \"boxen\" , \"bar\" , \"count\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . catplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot","title":"create_catplot"},{"location":"reference/wtracker/eval/plotter/#create_distplot","text":"def create_distplot ( self , x_col : 'str' , y_col : 'str' = None , hue_col : 'str' = None , log_wise : 'bool' = False , plot_kind : 'str' = 'hist' , x_label : 'str' = '' , y_label : 'str' = '' , title : 'str | None' = None , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , transform : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , data : 'pd.DataFrame' = None , ** kwargs ) -> 'sns.FacetGrid' Create a distribution plot from the data. Parameters: Name Type Description Default x_col str The column to plot on the x-axis. None y_col str The column to plot on the y-axis. None hue_col str The column to use for coloring the plot. None log_wise bool Whether to plot each log separately. None plot_kind str The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". None x_label str The x-axis label. None y_label str The y-axis label. None title str The title of the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None transform Callable[[pd.DataFrame], pd.DataFrame] A function to transform the data. None data pd.DataFrame Custom data to plot from. If None, the data passed to the constructor of the class is used. None **kwargs None Additional keyword arguments to pass to the seaborn.displot function. None Returns: Type Description sns.FacetGrid The plot object. View Source def create_distplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"hist\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a distribution plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. 
plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.displot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"hist\" , \"kde\" , \"ecdf\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . displot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot","title":"create_distplot"},{"location":"reference/wtracker/eval/plotter/#create_jointplot","text":"def create_jointplot ( self , x_col : 'str' , y_col : 'str' , hue_col : 'str' = None , plot_kind : 'str' = 'scatter' , x_label : 'str' = '' , y_label : 'str' = '' , title : 'str' = '' , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , transform : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , data : 'pd.DataFrame' = None , ** kwargs ) -> 'sns.JointGrid' Create a joint plot from the data. Parameters: Name Type Description Default x_col str The column to plot on the x-axis. None y_col str The column to plot on the y-axis. None hue_col str The column to use for coloring the plot. None plot_kind str The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". None x_label str The x-axis label. None y_label str The y-axis label. None title str The title of the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None transform Callable[[pd.DataFrame], pd.DataFrame] A function to transform the data. None data pd.DataFrame Custom data to plot from. If None, the data passed to the constructor of the class is used. None **kwargs None Additional keyword arguments to pass to the seaborn.jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def create_jointplot ( self , x_col : str , y_col : str , hue_col : str = None , plot_kind : str = \"scatter\" , x_label : str = \"\" , y_label : str = \"\" , title : str = \"\" , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Create a joint plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. plot_kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". 
x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" assert plot_kind in [ \"scatter\" , \"kde\" , \"hist\" , \"hex\" , \"reg\" , \"resid\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . jointplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , kind = plot_kind , height = self . plot_height , palette = palette , marginal_kws = dict ( palette = palette ), ** kwargs , ) plot . set_axis_labels ( x_label . capitalize (), y_label . capitalize ()) plot . figure . suptitle ( title . title ()) plot . figure . tight_layout () return plot","title":"create_jointplot"},{"location":"reference/wtracker/eval/plotter/#plot_cycle_error","text":"def plot_cycle_error ( self , error_kind : 'str' = 'bbox' , log_wise : 'bool' = False , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , plot_kind : 'str' = 'boxen' , ** kwargs ) -> 'sns.JointGrid' Plot the error as a function of the cycle step. Parameters: Name Type Description Default error_kind str The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". None log_wise bool Whether to plot each log separately. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". None **kwargs None Additional keyword arguments to pass the Plotter.create_catplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def plot_cycle_error( self, error_kind: str = \"bbox\", log_wise: bool = False, condition: Callable[[pd.DataFrame], pd.DataFrame] = None, plot_kind: str = \"boxen\", **kwargs, ) -> sns.JointGrid: \"\"\" Plot the error as a function of the cycle step. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". **kwargs: Additional keyword arguments to pass the `Plotter.create_catplot` function. Returns: sns.JointGrid: The plot object. 
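Example: a minimal usage sketch. The `plotter` instance below is an assumed, already-constructed Plotter wrapping simulation log data; only the call itself follows this signature.
>>> grid = plotter.plot_cycle_error(error_kind=\"dist\", plot_kind=\"violin\")
>>> grid.figure.savefig(\"cycle_error_by_step.png\")
Since the call delegates to `Plotter.create_catplot`, the returned object is in practice a seaborn FacetGrid.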
\"\"\" error_col = self._get_error_column(error_kind) return self.create_catplot( x_col=\"cycle_step\", y_col=error_col, x_label=\"cycle step\", y_label=f\"{error_kind} error\", title=f\"{error_kind} error as function of cycle step\", plot_kind=plot_kind, log_wise=log_wise, condition=condition, **kwargs, )","title":"plot_cycle_error"},{"location":"reference/wtracker/eval/plotter/#plot_error","text":"def plot_error ( self , error_kind : 'str' = 'bbox' , log_wise : 'bool' = False , cycle_wise : 'bool' = False , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , plot_kind : 'str' = 'hist' , ** kwargs ) -> 'sns.FacetGrid' Plot the error distribution. Parameters: Name Type Description Default error_kind str The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". None log_wise bool Whether to plot each log separately. None cycle_wise bool Whether to plot each cycle separately. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". None **kwargs None Additional keyword arguments to pass the Plotter.create_distplot function. None Returns: Type Description sns.FacetGrid The plot object. View Source def plot_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , cycle_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the error distribution. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". log_wise (bool, optional): Whether to plot each log separately. cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = self . data . groupby ( [ \"log_num\", \"cycle\" ] ) [ error_col ] . max (). reset_index () return self . create_distplot ( x_col = error_col , x_label = f \"{error_kind} error\" , title = f \"{error_kind} Error Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , data = data , ** kwargs , )","title":"plot_error"},{"location":"reference/wtracker/eval/plotter/#plot_head_size","text":"def plot_head_size ( self , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , plot_kind : 'str' = 'hist' , ** kwargs ) -> 'sns.JointGrid' Plot the size of the worm head. Parameters: Name Type Description Default condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". None **kwargs None Additional keyword arguments to pass the Plotter.create_jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def plot_head_size ( self , condition: Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind: str = \"hist\" , ** kwargs , ) -> sns . JointGrid: \"\"\" Plot the size of the worm head . Args: condition ( Callable [[ pd . DataFrame ], pd . DataFrame ], optional ) : A function to filter the data . 
plot_kind ( str , optional ) : The kind of plot to create . Can be \"scatter\" , \"kde\" , \"hist\" , \"hex\" , \"reg\" , or \"resid\" . ** kwargs: Additional keyword arguments to pass the `Plotter . create_jointplot ` function . Returns: sns . JointGrid: The plot object . \"\"\" return self . create_jointplot ( x_col = \"wrm_w\" , y_col = \"wrm_h\" , x_label = \"width\" , y_label = \"height\" , title = \"Worm Head Size\" , plot_kind = plot_kind , condition = condition , ** kwargs , )","title":"plot_head_size"},{"location":"reference/wtracker/eval/plotter/#plot_speed","text":"def plot_speed ( self , log_wise : 'bool' = False , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , plot_kind : 'str' = 'hist' , ** kwargs ) -> 'sns.FacetGrid' Plot the speed distribution of the worm. Parameters: Name Type Description Default log_wise bool Whether to plot each log separately. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". None **kwargs None Additional keyword arguments to pass the Plotter.create_distplot function. None View Source def plot_speed( self, log_wise: bool = False, condition: Callable[[pd.DataFrame], pd.DataFrame] = None, plot_kind: str = \"hist\", **kwargs, ) -> sns.FacetGrid: \"\"\" Plot the speed distribution of the worm. Args: log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. \"\"\" return self.create_distplot( x_col=\"wrm_speed\", x_label=\"speed\", title=\"Worm Speed Distribution\", plot_kind=plot_kind, log_wise=log_wise, condition=condition, kde=True, **kwargs, )","title":"plot_speed"},{"location":"reference/wtracker/eval/plotter/#plot_speed_vs_error","text":"def plot_speed_vs_error ( self , error_kind : 'str' = 'bbox' , cycle_wise : 'bool' = False , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , kind : 'str' = 'hist' , ** kwargs ) -> 'sns.JointGrid' Plot the speed of the worm vs the error. Parameters: Name Type Description Default error_kind str The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". None cycle_wise bool Whether to plot each cycle separately. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None kind str The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". None **kwargs None Additional keyword arguments to pass the Plotter.create_jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def plot_speed_vs_error ( self , error_kind: str = \"bbox\" , cycle_wise: bool = False , condition: Callable [[ pd . DataFrame ], pd . DataFrame ] = None , kind: str = \"hist\" , ** kwargs , ) -> sns . JointGrid: \"\"\" Plot the speed of the worm vs the error . Args: error_kind ( str , optional ) : The kind of error to plot . Can be \"bbox\" , \"dist\" , or \"precise\" . cycle_wise ( bool , optional ) : Whether to plot each cycle separately . condition ( Callable [[ pd . DataFrame ], pd . DataFrame ], optional ) : A function to filter the data . kind ( str , optional ) : The kind of plot to create . Can be \"scatter\" , \"kde\" , \"hist\" , \"hex\" , \"reg\" , or \"resid\" . 
** kwargs: Additional keyword arguments to pass the `Plotter . create_jointplot ` function . Returns: sns . JointGrid: The plot object . \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise: data = ( self . data . groupby ([ \"log_num\" , \"cycle\" ])[[ error_col , \"wrm_speed\" ]] . aggregate ({ error_col: \"max\" , \"wrm_speed\" : \"mean\" }) . reset_index () ) return self . create_jointplot ( x_col = \"wrm_speed\" , y_col = error_col , plot_kind = kind , x_label = \"speed\" , y_label = f \"{error_kind} error\" , title = f \"Speed vs {error_kind} Error\" , condition = condition , data = data , ** kwargs , )","title":"plot_speed_vs_error"},{"location":"reference/wtracker/eval/plotter/#plot_trajectory","text":"def plot_trajectory ( self , hue_col = 'log_num' , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , ** kwargs ) -> 'sns.JointGrid' Plot the trajectory of the worm. Parameters: Name Type Description Default hue_col str The column to use for coloring the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None **kwargs None Additional keyword arguments to pass the Plotter.create_jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def plot_trajectory( self, hue_col=\"log_num\", condition: Callable[[pd.DataFrame], pd.DataFrame] = None, **kwargs, ) -> sns.JointGrid: \"\"\" Plot the trajectory of the worm. Args: hue_col (str, optional): The column to use for coloring the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" plot = self.create_jointplot( x_col=\"wrm_center_x\", y_col=\"wrm_center_y\", x_label=\"X\", y_label=\"Y\", title=\"Worm Trajectory\", hue_col=hue_col, plot_kind=\"scatter\", alpha=1, linewidth=0, condition=condition, **kwargs, ) plot.ax_marg_x.remove() plot.ax_marg_y.remove() plot.ax_joint.invert_yaxis() return plot","title":"plot_trajectory"},{"location":"reference/wtracker/eval/vlc/","text":"Module wtracker.eval.vlc View Source import pandas as pd import numpy as np from math import ceil , floor import os import cv2 as cv from typing import Callable from dataclasses import dataclass , field import matplotlib matplotlib . use ( \"QTAgg\" ) from wtracker.utils.path_utils import Files , create_directory , join_paths from wtracker.utils.io_utils import ImageSaver from wtracker.utils.frame_reader import FrameReader , DummyReader from wtracker.sim.config import TimingConfig @dataclass class HotKey : \"\"\" Represents a hotkey that can be used to trigger a specific function. Attributes: key (str): The key for the hotkey. func (Callable[[str], None]): The function to be called when the hotkey is triggered. description (str): The description of the hotkey (optional). \"\"\" key : str func : Callable [[ str ], None ] description : str = field ( default = \"\" ) def __post_init__ ( self ): self . key = self . key . lower () class StreamViewer : \"\"\" A class for viewing and interacting with photos and video streams. Args: window_name (str, optional): The name of the window. Example: with StreamViewer() as streamer: streamer.imshow(image) streamer.waitKey() \"\"\" def __init__ ( self , window_name : str = \"streamer\" ) -> None : self . window_name = window_name self . window = None self . hotkeys : list [ HotKey ] = [] self . register_hotkey ( HotKey ( \"q\" , self . 
close , \"close the window\" )) def register_hotkey ( self , hotkey : HotKey ): \"\"\" Registers a hotkey. Args: hotkey (HotKey): The hotkey to register. \"\"\" self . hotkeys . append ( hotkey ) def create_trackbar ( self , name : str , val : int , maxval : int , onChange = lambda x : x ): \"\"\" Creates a trackbar. Args: name (str): The name of the trackbar. val (int): The initial value of the trackbar. maxval (int): The maximum value of the trackbar. onChange (function): The function to call when the trackbar value changes. \"\"\" cv . createTrackbar ( name , self . window_name , val , maxval , onChange ) def update_trackbar ( self , name : str , val : int ): \"\"\" Updates the value of a trackbar. Args: name (str): The name of the trackbar. val (int): The new value of the trackbar. \"\"\" cv . setTrackbarPos ( name , self . window_name , val ) def set_title ( self , title : str ): \"\"\" Sets the title of the window. Args: title (str): The new title of the window. \"\"\" cv . setWindowTitle ( self . window_name , title ) def __enter__ ( self ): \"\"\" Enters the context manager. \"\"\" self . open () return self def __exit__ ( self , exc_type , exc_value , traceback ): \"\"\" Exits the context manager. \"\"\" self . close () def __del__ ( self ): \"\"\" Destructor method. \"\"\" self . close () def update ( self , image : np . ndarray , wait : int = 1 ): \"\"\" Updates the window with a new image. Args: image (np.ndarray): The image to display. wait (int): The delay in milliseconds before updating the window. \"\"\" cv . imshow ( self . window_name , image ) self . waitKey ( wait ) def waitKey ( self , timeout : int = 0 ): \"\"\" Waits for a key press. This Function also triggers the hotkeys. Args: timeout (int): The timeout in milliseconds. Returns: str: The key that was pressed. \"\"\" key = cv . waitKey ( timeout ) if key <= 0 : return key key = chr ( key ) . lower () for hotkey in self . hotkeys : if key in hotkey . key : hotkey . func ( key ) return key def open ( self ): \"\"\" Opens the window. \"\"\" self . close () self . window = cv . namedWindow ( self . window_name , flags = cv . WINDOW_GUI_EXPANDED ) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1) self . set_title ( self . window_name ) def close ( self , key : str = \"q\" ): \"\"\" Closes the window. Args: key (str): The key to close the window. \"\"\" if self . window is not None : cv . destroyWindow ( self . window_name ) self . window = None def imshow ( self , image : np . ndarray , title : str = \"image\" ): \"\"\" Displays an image in the window. Args: image (np.ndarray): The image to display. title (str): The title of the image. \"\"\" self . update ( image , wait = 0 ) self . set_title ( title ) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1) class VLC : \"\"\" The VLC class represents a video player for visualizing Simulations. This class supports saving Simulation frames (with or without boxes overlay) as well. Args: files (Files): The files to read frames from. If None, the video player will present the log data (simulation) on a white background. config (TimingConfig): The timing configuration of the system. log_path (str): The path to the log file. cam_type (str): The type of camera. This should match the prefix of the corresponding columns in the log file. show_pred (bool, optional): Whether to show the prediction box. show_micro (bool, optional): Whether to show the microscope box. show_cam (bool, optional): Whether to show the camera box. 
\"\"\" def __init__ ( self , files : Files | None , config : TimingConfig , log_path : str , cam_type : str , show_pred : bool = True , show_micro : bool = False , show_cam : bool = False , ) -> None : self . streamer = StreamViewer ( window_name = \"VLC\" ) self . index = 0 self . _curr_row = None self . exit = False self . delay = 0 self . play = False self . show_pred = show_pred self . show_micro = show_micro self . show_cam = show_cam self . cam_type : str = cam_type self . config : TimingConfig = config self . log : pd . DataFrame = self . _load_log ( log_path ) self . reader : FrameReader = self . _create_reader ( files ) def initialize ( self ) -> None : \"\"\" Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. \"\"\" self . _init_hotkeys () self . _create_window () self . streamer . update_trackbar ( \"delay\" , round ( self . config . ms_per_frame )) self . print_hotkeys () def _load_log ( self , log_path : str ) -> pd . DataFrame : if log_path is None : return None log = pd . read_csv ( log_path , index_col = \"frame\" ) if self . cam_type == \"plt\" : log [ \"plt_x\" ] = 0 log [ \"plt_y\" ] = 0 log [ \"plt_h\" ] = max ( log [ \"cam_y\" ]) + max ( log [ \"cam_h\" ]) log [ \"plt_w\" ] = max ( log [ \"cam_x\" ]) + max ( log [ \"cam_w\" ]) # assert len(log.index) == len(self.reader) self . _curr_row = log . iloc [ self . index ] return log def _init_hotkeys ( self ) -> None : self . streamer . register_hotkey ( HotKey ( \"q\" , self . close , \"close VLC\" )) self . streamer . register_hotkey ( HotKey ( \"d\" , self . next , \"next frame\" )) self . streamer . register_hotkey ( HotKey ( \"a\" , self . prev , \"previous frame\" )) self . streamer . register_hotkey ( HotKey ( \"p\" , self . toggle_play , \"play/pause\" )) self . streamer . register_hotkey ( HotKey ( \"h\" , self . toggle_pred , \"toggle prediction box\" )) self . streamer . register_hotkey ( HotKey ( \"m\" , self . toggle_micro , \"toggle microscope box\" )) self . streamer . register_hotkey ( HotKey ( \"c\" , self . toggle_cam , \"toggle camera box\" )) def print_hotkeys ( self ): print ( \"Hotkeys:\" ) for hotkey in self . streamer . hotkeys : print ( f \" - { hotkey . key } : { hotkey . description } \" ) def _create_window ( self ): self . streamer . open () self . streamer . create_trackbar ( \"delay\" , 0 , 250 , self . set_delay ) self . streamer . create_trackbar ( \"#frame\" , 0 , len ( self . reader ), self . seek ) def _create_reader ( self , files : Files ) -> FrameReader : if files is None : frame_num = len ( self . log . index ) frame_size = ( self . get_attribute ( self . cam_type + \"_h\" ), self . get_attribute ( self . cam_type + \"_w\" ), ) return DummyReader ( frame_num , frame_size ) filenames = [ f for f in files ] reader = FrameReader ( files . root , filenames ) return reader def __enter__ ( self ): return self def __exit__ ( self , exc_type , exc_value , traceback ): self . streamer . close () def _get_title ( self ): curr_phase = self . get_attribute ( \"phase\" ) phase_title = f \"Action: { curr_phase } \" cycle_len = self . config . imaging_frame_num + self . config . moving_frame_num cycle_progress = 1 + self . index % cycle_len cycle_title = ( f \"cycle progress [ { cycle_progress } / { cycle_len } ]: \" + cycle_progress * \"#\" + ( cycle_len - cycle_progress ) * \"_\" ) title = f \" { phase_title } :: { cycle_title } \" return title def get_attribute ( self , col_name : str ): return self . 
_curr_row [ col_name ] def update_curr_row ( self ): self . _curr_row = self . log . iloc [ self . index ] def get_photo ( self ) -> np . ndarray : photo = self . reader [ self . index ] if self . show_pred : self . add_pred ( photo ) if self . show_micro : self . add_micro_box ( photo ) if self . show_cam : self . add_cam_box ( photo ) self . draw_center ( photo ) return photo def seek ( self , pos : int ): self . index = ( pos ) % len ( self . reader ) self . update_curr_row () self . streamer . update ( self . get_photo ()) self . streamer . set_title ( self . _get_title ()) def next ( self , key = None ): self . index = ( self . index + 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def prev ( self , key = None ): self . index = ( self . index - 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def close ( self , key = None ): self . exit = True def set_delay ( self , delay : int ): self . delay = delay def toggle_play ( self , key : str = None ): self . play = not self . play def toggle_pred ( self , key : str = None ): self . show_pred = not self . show_pred def toggle_micro ( self , key : str = None ): self . show_micro = not self . show_micro def toggle_cam ( self , key : str = None ): self . show_cam = not self . show_cam def mainloop ( self ): \"\"\" Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the `exit` flag is set to True (by self.close() (called by an hotkey)). It checks the `play` flag to determine if the player should continue playing or pause. The `delay` variable is used to control the delay between each iteration of the loop and is set to 0 to pause. \"\"\" with self as vlc : while not self . exit : delay = 0 if not self . play else self . delay if self . play : self . next () vlc . streamer . waitKey ( delay ) def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ]: x = self . get_attribute ( prefix + \"_x\" ) y = self . get_attribute ( prefix + \"_y\" ) w = self . get_attribute ( prefix + \"_w\" ) h = self . get_attribute ( prefix + \"_h\" ) return ( x , y , w , h ) def draw_box ( self , photo : np . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 , ) -> None : if not np . isfinite ( bbox ) . all (): return x , y , w , h = self . get_bbox ( self . cam_type ) pred_x , pred_y , pred_w , pred_h = bbox pred_x = floor ( pred_x - x ) pred_y = floor ( pred_y - y ) pred_w = ceil ( pred_w ) pred_h = ceil ( pred_h ) cv . rectangle ( photo , ( pred_x , pred_y ), ( pred_x + pred_w , pred_y + pred_h ), color , width ) def draw_marker ( self , photo : np . ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = cv . MARKER_CROSS , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 , ) -> None : frame_x , frame_y , frame_w , frame_h = self . get_bbox ( self . cam_type ) x , y = floor ( x - frame_x ), floor ( y - frame_y ) cv . drawMarker ( photo , ( x , y ), color , marker_type , marker_size , thickness ) def draw_center ( self , photo : np . ndarray ): x , y , w , h = self . get_bbox ( \"mic\" ) center = ( x + w // 2 , y + h // 2 ) cv . drawMarker ( photo , center , ( 0 , 0 , 255 ), cv . MARKER_CROSS , 7 , 1 ) def add_pred ( self , photo : np . ndarray ) -> None : worm_bbox = self . get_bbox ( \"wrm\" ) self . 
draw_box ( photo , worm_bbox , ( 0 , 0 , 0 ), 1 ) def add_micro_box ( self , photo : np . ndarray ) -> None : mic_bbox = self . get_bbox ( \"mic\" ) self . draw_box ( photo , mic_bbox , ( 0 , 0 , 255 ), 1 ) def add_cam_box ( self , photo : np . ndarray ) -> None : cam_bbox = self . get_bbox ( \"cam\" ) self . draw_box ( photo , cam_bbox , ( 128 , 0 , 0 ), 2 ) def save_stream ( self , folder_path : str , ) -> None : create_directory ( folder_path ) filename = f \" { self . cam_type } _\" + \" {:07d} .png\" with ImageSaver ( folder_path , tqdm_kwargs = { \"total\" : len ( self . log . index )}) as worker : for index in range ( len ( self . log . index )): self . index = index self . update_curr_row () path = join_paths ( folder_path , filename . format ( index )) img = self . get_photo () worker . schedule_save ( img , path ) image_format = filename . replace ( \"{:\" , \"%\" ) . replace ( \"}\" , \"\" ) self . make_vid ( folder_path , image_format , folder_path ) def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None : fps = self . config . frames_per_sec command = f \"ffmpeg -framerate { fps } -start_number 0 -i { join_paths ( folder_path , img_name_format ) } -c:v copy { join_paths ( output_dir , 'video.mp4' ) } \" print ( command ) os . system ( command ) Classes HotKey class HotKey ( key : str , func : Callable [[ str ], NoneType ], description : str = '' ) Represents a hotkey that can be used to trigger a specific function. Attributes Name Type Description Default key str The key for the hotkey. None func Callable[[str], None] The function to be called when the hotkey is triggered. None description str The description of the hotkey (optional). None View Source @dataclass class HotKey : \"\"\" Represents a hotkey that can be used to trigger a specific function. Attributes: key (str): The key for the hotkey. func (Callable[[str], None]): The function to be called when the hotkey is triggered. description (str): The description of the hotkey (optional). \"\"\" key : str func : Callable [ [str ] , None ] description : str = field ( default = \"\" ) def __post_init__ ( self ) : self . key = self . key . lower () Class variables description StreamViewer class StreamViewer ( window_name : str = 'streamer' ) A class for viewing and interacting with photos and video streams. Attributes Name Type Description Default window_name str The name of the window. None View Source class StreamViewer : \"\"\" A class for viewing and interacting with photos and video streams. Args: window_name (str, optional): The name of the window. Example: with StreamViewer() as streamer: streamer.imshow(image) streamer.waitKey() \"\"\" def __init__ ( self , window_name : str = \"streamer\" ) -> None : self . window_name = window_name self . window = None self . hotkeys : list [ HotKey ] = [] self . register_hotkey ( HotKey ( \"q\" , self . close , \"close the window\" )) def register_hotkey ( self , hotkey : HotKey ) : \"\"\" Registers a hotkey. Args: hotkey (HotKey): The hotkey to register. \"\"\" self . hotkeys . append ( hotkey ) def create_trackbar ( self , name : str , val : int , maxval : int , onChange = lambda x : x ) : \"\"\" Creates a trackbar. Args: name (str): The name of the trackbar. val (int): The initial value of the trackbar. maxval (int): The maximum value of the trackbar. onChange (function): The function to call when the trackbar value changes. \"\"\" cv . createTrackbar ( name , self . 
window_name , val , maxval , onChange ) def update_trackbar ( self , name : str , val : int ) : \"\"\" Updates the value of a trackbar. Args: name (str): The name of the trackbar. val (int): The new value of the trackbar. \"\"\" cv . setTrackbarPos ( name , self . window_name , val ) def set_title ( self , title : str ) : \"\"\" Sets the title of the window. Args: title (str): The new title of the window. \"\"\" cv . setWindowTitle ( self . window_name , title ) def __enter__ ( self ) : \"\"\" Enters the context manager. \"\"\" self . open () return self def __exit__ ( self , exc_type , exc_value , traceback ) : \"\"\" Exits the context manager. \"\"\" self . close () def __del__ ( self ) : \"\"\" Destructor method. \"\"\" self . close () def update ( self , image : np . ndarray , wait : int = 1 ) : \"\"\" Updates the window with a new image. Args: image (np.ndarray): The image to display. wait (int): The delay in milliseconds before updating the window. \"\"\" cv . imshow ( self . window_name , image ) self . waitKey ( wait ) def waitKey ( self , timeout : int = 0 ) : \"\"\" Waits for a key press. This Function also triggers the hotkeys. Args: timeout (int): The timeout in milliseconds. Returns: str: The key that was pressed. \"\"\" key = cv . waitKey ( timeout ) if key <= 0 : return key key = chr ( key ). lower () for hotkey in self . hotkeys : if key in hotkey . key : hotkey . func ( key ) return key def open ( self ) : \"\"\" Opens the window. \"\"\" self . close () self . window = cv . namedWindow ( self . window_name , flags = cv . WINDOW_GUI_EXPANDED ) # cv . setWindowProperty ( self . window_name , cv . WND_PROP_TOPMOST , 1 ) self . set_title ( self . window_name ) def close ( self , key : str = \"q\" ) : \"\"\" Closes the window. Args: key (str): The key to close the window. \"\"\" if self . window is not None : cv . destroyWindow ( self . window_name ) self . window = None def imshow ( self , image : np . ndarray , title : str = \"image\" ) : \"\"\" Displays an image in the window. Args: image (np.ndarray): The image to display. title (str): The title of the image. \"\"\" self . update ( image , wait = 0 ) self . set_title ( title ) # cv . setWindowProperty ( self . window_name , cv . WND_PROP_TOPMOST , 1 ) Methods close def close ( self , key : str = 'q' ) Closes the window. Parameters: Name Type Description Default key str The key to close the window. None View Source def close(self, key: str = \"q\"): \"\"\" Closes the window. Args: key (str): The key to close the window. \"\"\" if self.window is not None: cv.destroyWindow(self.window_name) self.window = None create_trackbar def create_trackbar ( self , name : str , val : int , maxval : int , onChange =< function StreamViewer .< lambda > at 0x7f88625f0160 > ) Creates a trackbar. Parameters: Name Type Description Default name str The name of the trackbar. None val int The initial value of the trackbar. None maxval int The maximum value of the trackbar. None onChange function The function to call when the trackbar value changes. None View Source def create_trackbar(self, name: str, val: int, maxval: int, onChange=lambda x: x): \"\"\" Creates a trackbar. Args: name (str): The name of the trackbar. val (int): The initial value of the trackbar. maxval (int): The maximum value of the trackbar. onChange (function): The function to call when the trackbar value changes. \"\"\" cv.createTrackbar(name, self.window_name, val, maxval, onChange) imshow def imshow ( self , image : numpy . 
ndarray , title : str = 'image' ) Displays an image in the window. Parameters: Name Type Description Default image np.ndarray The image to display. None title str The title of the image. None View Source def imshow(self, image: np.ndarray, title: str = \"image\"): \"\"\" Displays an image in the window. Args: image (np.ndarray): The image to display. title (str): The title of the image. \"\"\" self.update(image, wait=0) self.set_title(title) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1) open def open ( self ) Opens the window. View Source def open(self): \"\"\" Opens the window. \"\"\" self.close() self.window = cv.namedWindow(self.window_name, flags=cv.WINDOW_GUI_EXPANDED) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1) self.set_title(self.window_name) register_hotkey def register_hotkey ( self , hotkey : wtracker . eval . vlc . HotKey ) Registers a hotkey. Parameters: Name Type Description Default hotkey HotKey The hotkey to register. None View Source def register_hotkey ( self , hotkey: HotKey ) : \"\"\" Registers a hotkey. Args: hotkey ( HotKey ) : The hotkey to register. \"\"\" self . hotkeys . append ( hotkey ) set_title def set_title ( self , title : str ) Sets the title of the window. Parameters: Name Type Description Default title str The new title of the window. None View Source def set_title(self, title: str): \"\"\" Sets the title of the window. Args: title (str): The new title of the window. \"\"\" cv.setWindowTitle(self.window_name, title) update def update ( self , image : numpy . ndarray , wait : int = 1 ) Updates the window with a new image. Parameters: Name Type Description Default image np.ndarray The image to display. None wait int The delay in milliseconds before updating the window. None View Source def update(self, image: np.ndarray, wait: int = 1): \"\"\" Updates the window with a new image. Args: image (np.ndarray): The image to display. wait (int): The delay in milliseconds before updating the window. \"\"\" cv.imshow(self.window_name, image) self.waitKey(wait) update_trackbar def update_trackbar ( self , name : str , val : int ) Updates the value of a trackbar. Parameters: Name Type Description Default name str The name of the trackbar. None val int The new value of the trackbar. None View Source def update_trackbar(self, name: str, val: int): \"\"\" Updates the value of a trackbar. Args: name (str): The name of the trackbar. val (int): The new value of the trackbar. \"\"\" cv.setTrackbarPos(name, self.window_name, val) waitKey def waitKey ( self , timeout : int = 0 ) Waits for a key press. This function also triggers the hotkeys. Parameters: Name Type Description Default timeout int The timeout in milliseconds. None Returns: Type Description str The key that was pressed. View Source def waitKey ( self , timeout : int = 0 ) : \"\"\" Waits for a key press. This function also triggers the hotkeys. Args: timeout (int): The timeout in milliseconds. Returns: str: The key that was pressed. \"\"\" key = cv.waitKey(timeout) if key <= 0: return key key = chr(key).lower() for hotkey in self.hotkeys: if key in hotkey.key: hotkey.func(key) return key VLC class VLC ( files : wtracker . utils . path_utils . Files | None , config : wtracker . sim . config . TimingConfig , log_path : str , cam_type : str , show_pred : bool = True , show_micro : bool = False , show_cam : bool = False ) The VLC class represents a video player for visualizing Simulations. This class also supports saving Simulation frames (with or without box overlays).
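For a concrete sense of the StreamViewer and HotKey API documented above, here is a minimal self-contained sketch; the window name, the 's' binding, and the synthetic frame are illustrative assumptions ('q' is already registered by default to close the window).
import numpy as np
from wtracker.eval.vlc import StreamViewer, HotKey
frame = np.full((480, 640, 3), 255, dtype=np.uint8) # synthetic all-white frame
with StreamViewer(window_name=\"demo\") as viewer:
    viewer.register_hotkey(HotKey(\"s\", lambda key: print(\"snapshot!\"), \"print a message\"))
    # imshow calls update(wait=0), which blocks in waitKey until a key
    # is pressed; registered hotkeys fire from there
    viewer.imshow(frame, title=\"demo frame\")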
Attributes Name Type Description Default files Files The files to read frames from. If None, the video player will present the log data (simulation) on a white background. None config TimingConfig The timing configuration of the system. None log_path str The path to the log file. None cam_type str The type of camera. This should match the prefix of the corresponding columns in the log file. None show_pred bool Whether to show the prediction box. None show_micro bool Whether to show the microscope box. None show_cam bool Whether to show the camera box. None View Source class VLC : \"\"\" The VLC class represents a video player for visualizing Simulations. This class supports saving Simulation frames (with or without boxes overlay) as well. Args: files (Files): The files to read frames from. If None, the video player will present the log data (simulation) on a white background. config (TimingConfig): The timing configuration of the system. log_path (str): The path to the log file. cam_type (str): The type of camera. This should match the prefix of the corresponding columns in the log file. show_pred (bool, optional): Whether to show the prediction box. show_micro (bool, optional): Whether to show the microscope box. show_cam (bool, optional): Whether to show the camera box. \"\"\" def __init__ ( self , files : Files | None , config : TimingConfig , log_path : str , cam_type : str , show_pred : bool = True , show_micro : bool = False , show_cam : bool = False , ) -> None : self . streamer = StreamViewer ( window_name = \"VLC\" ) self . index = 0 self . _curr_row = None self . exit = False self . delay = 0 self . play = False self . show_pred = show_pred self . show_micro = show_micro self . show_cam = show_cam self . cam_type : str = cam_type self . config : TimingConfig = config self . log : pd . DataFrame = self . _load_log ( log_path ) self . reader : FrameReader = self . _create_reader ( files ) def initialize ( self ) -> None : \"\"\" Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. \"\"\" self . _init_hotkeys () self . _create_window () self . streamer . update_trackbar ( \"delay\" , round ( self . config . ms_per_frame )) self . print_hotkeys () def _load_log ( self , log_path : str ) -> pd . DataFrame : if log_path is None : return None log = pd . read_csv ( log_path , index_col = \"frame\" ) if self . cam_type == \"plt\" : log [ \"plt_x\" ] = 0 log [ \"plt_y\" ] = 0 log [ \"plt_h\" ] = max ( log [ \"cam_y\" ]) + max ( log [ \"cam_h\" ]) log [ \"plt_w\" ] = max ( log [ \"cam_x\" ]) + max ( log [ \"cam_w\" ]) # assert len(log.index) == len(self.reader) self . _curr_row = log . iloc [ self . index ] return log def _init_hotkeys ( self ) -> None : self . streamer . register_hotkey ( HotKey ( \"q\" , self . close , \"close VLC\" )) self . streamer . register_hotkey ( HotKey ( \"d\" , self . next , \"next frame\" )) self . streamer . register_hotkey ( HotKey ( \"a\" , self . prev , \"previous frame\" )) self . streamer . register_hotkey ( HotKey ( \"p\" , self . toggle_play , \"play/pause\" )) self . streamer . register_hotkey ( HotKey ( \"h\" , self . toggle_pred , \"toggle prediction box\" )) self . streamer . register_hotkey ( HotKey ( \"m\" , self . toggle_micro , \"toggle microscope box\" )) self . streamer . register_hotkey ( HotKey ( \"c\" , self . toggle_cam , \"toggle camera box\" )) def print_hotkeys ( self ): print ( \"Hotkeys:\" ) for hotkey in self . streamer . 
hotkeys : print ( f \" - {hotkey.key} : {hotkey.description}\" ) def _create_window ( self ): self . streamer . open () self . streamer . create_trackbar ( \"delay\" , 0 , 250 , self . set_delay ) self . streamer . create_trackbar ( \"#frame\" , 0 , len ( self . reader ), self . seek ) def _create_reader ( self , files : Files ) -> FrameReader : if files is None : frame_num = len ( self . log . index ) frame_size = ( self . get_attribute ( self . cam_type + \"_h\" ), self . get_attribute ( self . cam_type + \"_w\" ), ) return DummyReader ( frame_num , frame_size ) filenames = [ f for f in files ] reader = FrameReader ( files . root , filenames ) return reader def __enter__ ( self ): return self def __exit__ ( self , exc_type , exc_value , traceback ): self . streamer . close () def _get_title ( self ): curr_phase = self . get_attribute ( \"phase\" ) phase_title = f \"Action: {curr_phase}\" cycle_len = self . config . imaging_frame_num + self . config . moving_frame_num cycle_progress = 1 + self . index % cycle_len cycle_title = ( f \"cycle progress [{cycle_progress}/{cycle_len}]: \" + cycle_progress * \"#\" + ( cycle_len - cycle_progress ) * \"_\" ) title = f \"{phase_title} :: {cycle_title}\" return title def get_attribute ( self , col_name : str ): return self . _curr_row [ col_name ] def update_curr_row ( self ): self . _curr_row = self . log . iloc [ self . index ] def get_photo ( self ) -> np . ndarray : photo = self . reader [ self . index ] if self . show_pred : self . add_pred ( photo ) if self . show_micro : self . add_micro_box ( photo ) if self . show_cam : self . add_cam_box ( photo ) self . draw_center ( photo ) return photo def seek ( self , pos : int ): self . index = ( pos ) % len ( self . reader ) self . update_curr_row () self . streamer . update ( self . get_photo ()) self . streamer . set_title ( self . _get_title ()) def next ( self , key = None ): self . index = ( self . index + 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def prev ( self , key = None ): self . index = ( self . index - 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def close ( self , key = None ): self . exit = True def set_delay ( self , delay : int ): self . delay = delay def toggle_play ( self , key : str = None ): self . play = not self . play def toggle_pred ( self , key : str = None ): self . show_pred = not self . show_pred def toggle_micro ( self , key : str = None ): self . show_micro = not self . show_micro def toggle_cam ( self , key : str = None ): self . show_cam = not self . show_cam def mainloop ( self ): \"\"\" Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the `exit` flag is set to True (by self.close() (called by an hotkey)). It checks the `play` flag to determine if the player should continue playing or pause. The `delay` variable is used to control the delay between each iteration of the loop and is set to 0 to pause. \"\"\" with self as vlc : while not self . exit : delay = 0 if not self . play else self . delay if self . play : self . next () vlc . streamer . waitKey ( delay ) def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ]: x = self . get_attribute ( prefix + \"_x\" ) y = self . get_attribute ( prefix + \"_y\" ) w = self . get_attribute ( prefix + \"_w\" ) h = self . 
get_attribute ( prefix + \"_h\" ) return ( x , y , w , h ) def draw_box ( self , photo : np . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 , ) -> None : if not np . isfinite ( bbox ) . all (): return x , y , w , h = self . get_bbox ( self . cam_type ) pred_x , pred_y , pred_w , pred_h = bbox pred_x = floor ( pred_x - x ) pred_y = floor ( pred_y - y ) pred_w = ceil ( pred_w ) pred_h = ceil ( pred_h ) cv . rectangle ( photo , ( pred_x , pred_y ), ( pred_x + pred_w , pred_y + pred_h ), color , width ) def draw_marker ( self , photo : np . ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = cv . MARKER_CROSS , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 , ) -> None : frame_x , frame_y , frame_w , frame_h = self . get_bbox ( self . cam_type ) x , y = floor ( x - frame_x ), floor ( y - frame_y ) cv . drawMarker ( photo , ( x , y ), color , marker_type , marker_size , thickness ) def draw_center ( self , photo : np . ndarray ): x , y , w , h = self . get_bbox ( \"mic\" ) center = ( x + w // 2 , y + h // 2 ) cv . drawMarker ( photo , center , ( 0 , 0 , 255 ), cv . MARKER_CROSS , 7 , 1 ) def add_pred ( self , photo : np . ndarray ) -> None : worm_bbox = self . get_bbox ( \"wrm\" ) self . draw_box ( photo , worm_bbox , ( 0 , 0 , 0 ), 1 ) def add_micro_box ( self , photo : np . ndarray ) -> None : mic_bbox = self . get_bbox ( \"mic\" ) self . draw_box ( photo , mic_bbox , ( 0 , 0 , 255 ), 1 ) def add_cam_box ( self , photo : np . ndarray ) -> None : cam_bbox = self . get_bbox ( \"cam\" ) self . draw_box ( photo , cam_bbox , ( 128 , 0 , 0 ), 2 ) def save_stream ( self , folder_path : str , ) -> None : create_directory ( folder_path ) filename = f \"{self.cam_type}_\" + \"{:07d}.png\" with ImageSaver ( folder_path , tqdm_kwargs = { \"total\" : len ( self . log . index )}) as worker : for index in range ( len ( self . log . index )): self . index = index self . update_curr_row () path = join_paths ( folder_path , filename . format ( index )) img = self . get_photo () worker . schedule_save ( img , path ) image_format = filename . replace ( \"{:\" , \"%\" ) . replace ( \"}\" , \"\" ) self . make_vid ( folder_path , image_format , folder_path ) def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None : fps = self . config . frames_per_sec command = f \"ffmpeg -framerate {fps} -start_number 0 -i {join_paths(folder_path, img_name_format)} -c:v copy {join_paths(output_dir, 'video.mp4')}\" print ( command ) os . system ( command ) Methods add_cam_box def add_cam_box ( self , photo : numpy . ndarray ) -> None View Source def add_cam_box ( self , photo : np . ndarray ) -> None : cam_bbox = self . get_bbox ( \"cam\" ) self . draw_box ( photo , cam_bbox , ( 128 , 0 , 0 ), 2 ) add_micro_box def add_micro_box ( self , photo : numpy . ndarray ) -> None View Source def add_micro_box ( self , photo : np . ndarray ) -> None : mic_bbox = self . get_bbox ( \"mic\" ) self . draw_box ( photo , mic_bbox , ( 0 , 0 , 255 ), 1 ) add_pred def add_pred ( self , photo : numpy . ndarray ) -> None View Source def add_pred ( self , photo : np . ndarray ) -> None : worm_bbox = self . get_bbox ( \"wrm\" ) self . draw_box ( photo , worm_bbox , ( 0 , 0 , 0 ), 1 ) close def close ( self , key = None ) View Source def close ( self , key = None ) : self . exit = True draw_box def draw_box ( self , photo : numpy . 
ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 ) -> None View Source def draw_box ( self , photo : np . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 , ) -> None : if not np . isfinite ( bbox ). all (): return x , y , w , h = self . get_bbox ( self . cam_type ) pred_x , pred_y , pred_w , pred_h = bbox pred_x = floor ( pred_x - x ) pred_y = floor ( pred_y - y ) pred_w = ceil ( pred_w ) pred_h = ceil ( pred_h ) cv . rectangle ( photo , ( pred_x , pred_y ), ( pred_x + pred_w , pred_y + pred_h ), color , width ) draw_center def draw_center ( self , photo : numpy . ndarray ) View Source def draw_center(self, photo: np.ndarray): x, y, w, h = self.get_bbox(\"mic\") center = (x + w // 2, y + h // 2) cv.drawMarker(photo, center, (0, 0, 255), cv.MARKER_CROSS, 7, 1) draw_marker def draw_marker ( self , photo : numpy . ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = 0 , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 ) -> None View Source def draw_marker ( self , photo : np . ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = cv . MARKER_CROSS , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 , ) -> None : frame_x , frame_y , frame_w , frame_h = self . get_bbox ( self . cam_type ) x , y = floor ( x - frame_x ), floor ( y - frame_y ) cv . drawMarker ( photo , ( x , y ), color , marker_type , marker_size , thickness ) get_attribute def get_attribute ( self , col_name : str ) View Source def get_attribute ( self , col_name : str ) : return self . _curr_row [ col_name ] get_bbox def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ] View Source def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ] : x = self . get_attribute ( prefix + \"_x\" ) y = self . get_attribute ( prefix + \"_y\" ) w = self . get_attribute ( prefix + \"_w\" ) h = self . get_attribute ( prefix + \"_h\" ) return ( x , y , w , h ) get_photo def get_photo ( self ) -> numpy . ndarray View Source def get_photo ( self ) -> np . ndarray : photo = self . reader [ self . index ] if self . show_pred : self . add_pred ( photo ) if self . show_micro : self . add_micro_box ( photo ) if self . show_cam : self . add_cam_box ( photo ) self . draw_center ( photo ) return photo initialize def initialize ( self ) -> None Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. View Source def initialize ( self ) -> None : \"\"\" Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. \"\"\" self . _init_hotkeys () self . _create_window () self . streamer . update_trackbar ( \"delay\" , round ( self . config . ms_per_frame )) self . print_hotkeys () mainloop def mainloop ( self ) Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the exit flag is set to True (by self.close(), called by a hotkey). It checks the play flag to determine if the player should continue playing or pause. The delay variable is used to control the delay between each iteration of the loop and is set to 0 to pause. View Source def mainloop ( self ) : \"\"\" Main loop for the VLC player.
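# --- Illustrative usage sketch (not part of the library source): driving the player loop.
# Assumptions: a TimingConfig saved at "time_config.json", a simulation log at "log.csv",
# files=None for log-only playback, and load_json being inherited from ConfigBase.
from wtracker.eval.vlc import VLC
from wtracker.sim.config import TimingConfig

config = TimingConfig.load_json("time_config.json")  # hypothetical path
player = VLC(files=None, config=config, log_path="log.csv", cam_type="cam")
player.initialize()  # registers hotkeys, opens the window, sets up the trackbars
player.mainloop()    # runs until the 'q' hotkey sets the exit flag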
This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the `exit` flag is set to True (by self.close(), called by a hotkey). It checks the `play` flag to determine if the player should continue playing or pause. The `delay` variable is used to control the delay between each iteration of the loop and is set to 0 to pause. \"\"\" with self as vlc : while not self . exit : delay = 0 if not self . play else self . delay if self . play : self . next () vlc . streamer . waitKey ( delay ) make_vid def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None View Source def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None : fps = self . config . frames_per_sec command = f \"ffmpeg -framerate {fps} -start_number 0 -i {join_paths(folder_path, img_name_format)} -c:v copy {join_paths(output_dir, 'video.mp4')}\" print ( command ) os . system ( command ) next def next ( self , key = None ) View Source def next(self, key=None): self.index = (self.index + 1) % len(self.reader) self.streamer.update_trackbar(\"#frame\", self.index) prev def prev ( self , key = None ) View Source def prev(self, key=None): self.index = (self.index - 1) % len(self.reader) self.streamer.update_trackbar(\"#frame\", self.index) print_hotkeys def print_hotkeys ( self ) View Source def print_hotkeys(self): print(\"Hotkeys:\") for hotkey in self.streamer.hotkeys: print(f\" - {hotkey.key} : {hotkey.description}\") save_stream def save_stream ( self , folder_path : str ) -> None View Source def save_stream ( self , folder_path : str , ) -> None : create_directory ( folder_path ) filename = f \"{self.cam_type}_\" + \"{:07d}.png\" with ImageSaver ( folder_path , tqdm_kwargs ={ \"total\" : len ( self . log . index )}) as worker : for index in range ( len ( self . log . index )): self . index = index self . update_curr_row () path = join_paths ( folder_path , filename . format ( index )) img = self . get_photo () worker . schedule_save ( img , path ) image_format = filename . replace ( \"{:\" , \"%\" ). replace ( \"}\" , \"\" ) self .
make_vid ( folder_path , image_format , folder_path ) seek def seek ( self , pos : int ) View Source def seek(self, pos: int): self.index = (pos) % len(self.reader) self.update_curr_row() self.streamer.update(self.get_photo()) self.streamer.set_title(self._get_title()) set_delay def set_delay ( self , delay : int ) View Source def set_delay(self, delay: int): self.delay = delay toggle_cam def toggle_cam ( self , key : str = None ) View Source def toggle_cam(self, key: str = None): self.show_cam = not self.show_cam toggle_micro def toggle_micro ( self , key : str = None ) View Source def toggle_micro(self, key: str = None): self.show_micro = not self.show_micro toggle_play def toggle_play ( self , key : str = None ) View Source def toggle_play(self, key: str = None): self.play = not self.play toggle_pred def toggle_pred ( self , key : str = None ) View Source def toggle_pred(self, key: str = None): self.show_pred = not self.show_pred update_curr_row def update_curr_row ( self ) View Source def update_curr_row(self): self._curr_row = self.log.iloc[self.index]","title":"Vlc"},{"location":"reference/wtracker/eval/vlc/#module-wtrackerevalvlc","text":"View Source import pandas as pd import numpy as np from math import ceil , floor import os import cv2 as cv from typing import Callable from dataclasses import dataclass , field import matplotlib matplotlib . use ( \"QTAgg\" ) from wtracker.utils.path_utils import Files , create_directory , join_paths from wtracker.utils.io_utils import ImageSaver from wtracker.utils.frame_reader import FrameReader , DummyReader from wtracker.sim.config import TimingConfig @dataclass class HotKey : \"\"\" Represents a hotkey that can be used to trigger a specific function. Attributes: key (str): The key for the hotkey. func (Callable[[str], None]): The function to be called when the hotkey is triggered. description (str): The description of the hotkey (optional). \"\"\" key : str func : Callable [[ str ], None ] description : str = field ( default = \"\" ) def __post_init__ ( self ): self . key = self . key . lower () class StreamViewer : \"\"\" A class for viewing and interacting with photos and video streams. Args: window_name (str, optional): The name of the window. Example: with StreamViewer() as streamer: streamer.imshow(image) streamer.waitKey() \"\"\" def __init__ ( self , window_name : str = \"streamer\" ) -> None : self . window_name = window_name self . window = None self . hotkeys : list [ HotKey ] = [] self . register_hotkey ( HotKey ( \"q\" , self . close , \"close the window\" )) def register_hotkey ( self , hotkey : HotKey ): \"\"\" Registers a hotkey. Args: hotkey (HotKey): The hotkey to register. \"\"\" self . hotkeys . append ( hotkey ) def create_trackbar ( self , name : str , val : int , maxval : int , onChange = lambda x : x ): \"\"\" Creates a trackbar. Args: name (str): The name of the trackbar. val (int): The initial value of the trackbar. maxval (int): The maximum value of the trackbar. onChange (function): The function to call when the trackbar value changes. \"\"\" cv . createTrackbar ( name , self . window_name , val , maxval , onChange ) def update_trackbar ( self , name : str , val : int ): \"\"\" Updates the value of a trackbar. Args: name (str): The name of the trackbar. val (int): The new value of the trackbar. \"\"\" cv . setTrackbarPos ( name , self . window_name , val ) def set_title ( self , title : str ): \"\"\" Sets the title of the window. Args: title (str): The new title of the window. \"\"\" cv . setWindowTitle ( self . 
window_name , title ) def __enter__ ( self ): \"\"\" Enters the context manager. \"\"\" self . open () return self def __exit__ ( self , exc_type , exc_value , traceback ): \"\"\" Exits the context manager. \"\"\" self . close () def __del__ ( self ): \"\"\" Destructor method. \"\"\" self . close () def update ( self , image : np . ndarray , wait : int = 1 ): \"\"\" Updates the window with a new image. Args: image (np.ndarray): The image to display. wait (int): The delay in milliseconds before updating the window. \"\"\" cv . imshow ( self . window_name , image ) self . waitKey ( wait ) def waitKey ( self , timeout : int = 0 ): \"\"\" Waits for a key press. This Function also triggers the hotkeys. Args: timeout (int): The timeout in milliseconds. Returns: str: The key that was pressed. \"\"\" key = cv . waitKey ( timeout ) if key <= 0 : return key key = chr ( key ) . lower () for hotkey in self . hotkeys : if key in hotkey . key : hotkey . func ( key ) return key def open ( self ): \"\"\" Opens the window. \"\"\" self . close () self . window = cv . namedWindow ( self . window_name , flags = cv . WINDOW_GUI_EXPANDED ) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1) self . set_title ( self . window_name ) def close ( self , key : str = \"q\" ): \"\"\" Closes the window. Args: key (str): The key to close the window. \"\"\" if self . window is not None : cv . destroyWindow ( self . window_name ) self . window = None def imshow ( self , image : np . ndarray , title : str = \"image\" ): \"\"\" Displays an image in the window. Args: image (np.ndarray): The image to display. title (str): The title of the image. \"\"\" self . update ( image , wait = 0 ) self . set_title ( title ) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1) class VLC : \"\"\" The VLC class represents a video player for visualizing Simulations. This class supports saving Simulation frames (with or without boxes overlay) as well. Args: files (Files): The files to read frames from. If None, the video player will present the log data (simulation) on a white background. config (TimingConfig): The timing configuration of the system. log_path (str): The path to the log file. cam_type (str): The type of camera. This should match the prefix of the corresponding columns in the log file. show_pred (bool, optional): Whether to show the prediction box. show_micro (bool, optional): Whether to show the microscope box. show_cam (bool, optional): Whether to show the camera box. \"\"\" def __init__ ( self , files : Files | None , config : TimingConfig , log_path : str , cam_type : str , show_pred : bool = True , show_micro : bool = False , show_cam : bool = False , ) -> None : self . streamer = StreamViewer ( window_name = \"VLC\" ) self . index = 0 self . _curr_row = None self . exit = False self . delay = 0 self . play = False self . show_pred = show_pred self . show_micro = show_micro self . show_cam = show_cam self . cam_type : str = cam_type self . config : TimingConfig = config self . log : pd . DataFrame = self . _load_log ( log_path ) self . reader : FrameReader = self . _create_reader ( files ) def initialize ( self ) -> None : \"\"\" Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. \"\"\" self . _init_hotkeys () self . _create_window () self . streamer . update_trackbar ( \"delay\" , round ( self . config . ms_per_frame )) self . print_hotkeys () def _load_log ( self , log_path : str ) -> pd . 
DataFrame : if log_path is None : return None log = pd . read_csv ( log_path , index_col = \"frame\" ) if self . cam_type == \"plt\" : log [ \"plt_x\" ] = 0 log [ \"plt_y\" ] = 0 log [ \"plt_h\" ] = max ( log [ \"cam_y\" ]) + max ( log [ \"cam_h\" ]) log [ \"plt_w\" ] = max ( log [ \"cam_x\" ]) + max ( log [ \"cam_w\" ]) # assert len(log.index) == len(self.reader) self . _curr_row = log . iloc [ self . index ] return log def _init_hotkeys ( self ) -> None : self . streamer . register_hotkey ( HotKey ( \"q\" , self . close , \"close VLC\" )) self . streamer . register_hotkey ( HotKey ( \"d\" , self . next , \"next frame\" )) self . streamer . register_hotkey ( HotKey ( \"a\" , self . prev , \"previous frame\" )) self . streamer . register_hotkey ( HotKey ( \"p\" , self . toggle_play , \"play/pause\" )) self . streamer . register_hotkey ( HotKey ( \"h\" , self . toggle_pred , \"toggle prediction box\" )) self . streamer . register_hotkey ( HotKey ( \"m\" , self . toggle_micro , \"toggle microscope box\" )) self . streamer . register_hotkey ( HotKey ( \"c\" , self . toggle_cam , \"toggle camera box\" )) def print_hotkeys ( self ): print ( \"Hotkeys:\" ) for hotkey in self . streamer . hotkeys : print ( f \" - { hotkey . key } : { hotkey . description } \" ) def _create_window ( self ): self . streamer . open () self . streamer . create_trackbar ( \"delay\" , 0 , 250 , self . set_delay ) self . streamer . create_trackbar ( \"#frame\" , 0 , len ( self . reader ), self . seek ) def _create_reader ( self , files : Files ) -> FrameReader : if files is None : frame_num = len ( self . log . index ) frame_size = ( self . get_attribute ( self . cam_type + \"_h\" ), self . get_attribute ( self . cam_type + \"_w\" ), ) return DummyReader ( frame_num , frame_size ) filenames = [ f for f in files ] reader = FrameReader ( files . root , filenames ) return reader def __enter__ ( self ): return self def __exit__ ( self , exc_type , exc_value , traceback ): self . streamer . close () def _get_title ( self ): curr_phase = self . get_attribute ( \"phase\" ) phase_title = f \"Action: { curr_phase } \" cycle_len = self . config . imaging_frame_num + self . config . moving_frame_num cycle_progress = 1 + self . index % cycle_len cycle_title = ( f \"cycle progress [ { cycle_progress } / { cycle_len } ]: \" + cycle_progress * \"#\" + ( cycle_len - cycle_progress ) * \"_\" ) title = f \" { phase_title } :: { cycle_title } \" return title def get_attribute ( self , col_name : str ): return self . _curr_row [ col_name ] def update_curr_row ( self ): self . _curr_row = self . log . iloc [ self . index ] def get_photo ( self ) -> np . ndarray : photo = self . reader [ self . index ] if self . show_pred : self . add_pred ( photo ) if self . show_micro : self . add_micro_box ( photo ) if self . show_cam : self . add_cam_box ( photo ) self . draw_center ( photo ) return photo def seek ( self , pos : int ): self . index = ( pos ) % len ( self . reader ) self . update_curr_row () self . streamer . update ( self . get_photo ()) self . streamer . set_title ( self . _get_title ()) def next ( self , key = None ): self . index = ( self . index + 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def prev ( self , key = None ): self . index = ( self . index - 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def close ( self , key = None ): self . exit = True def set_delay ( self , delay : int ): self . 
delay = delay def toggle_play ( self , key : str = None ): self . play = not self . play def toggle_pred ( self , key : str = None ): self . show_pred = not self . show_pred def toggle_micro ( self , key : str = None ): self . show_micro = not self . show_micro def toggle_cam ( self , key : str = None ): self . show_cam = not self . show_cam def mainloop ( self ): \"\"\" Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the `exit` flag is set to True (by self.close() (called by an hotkey)). It checks the `play` flag to determine if the player should continue playing or pause. The `delay` variable is used to control the delay between each iteration of the loop and is set to 0 to pause. \"\"\" with self as vlc : while not self . exit : delay = 0 if not self . play else self . delay if self . play : self . next () vlc . streamer . waitKey ( delay ) def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ]: x = self . get_attribute ( prefix + \"_x\" ) y = self . get_attribute ( prefix + \"_y\" ) w = self . get_attribute ( prefix + \"_w\" ) h = self . get_attribute ( prefix + \"_h\" ) return ( x , y , w , h ) def draw_box ( self , photo : np . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 , ) -> None : if not np . isfinite ( bbox ) . all (): return x , y , w , h = self . get_bbox ( self . cam_type ) pred_x , pred_y , pred_w , pred_h = bbox pred_x = floor ( pred_x - x ) pred_y = floor ( pred_y - y ) pred_w = ceil ( pred_w ) pred_h = ceil ( pred_h ) cv . rectangle ( photo , ( pred_x , pred_y ), ( pred_x + pred_w , pred_y + pred_h ), color , width ) def draw_marker ( self , photo : np . ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = cv . MARKER_CROSS , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 , ) -> None : frame_x , frame_y , frame_w , frame_h = self . get_bbox ( self . cam_type ) x , y = floor ( x - frame_x ), floor ( y - frame_y ) cv . drawMarker ( photo , ( x , y ), color , marker_type , marker_size , thickness ) def draw_center ( self , photo : np . ndarray ): x , y , w , h = self . get_bbox ( \"mic\" ) center = ( x + w // 2 , y + h // 2 ) cv . drawMarker ( photo , center , ( 0 , 0 , 255 ), cv . MARKER_CROSS , 7 , 1 ) def add_pred ( self , photo : np . ndarray ) -> None : worm_bbox = self . get_bbox ( \"wrm\" ) self . draw_box ( photo , worm_bbox , ( 0 , 0 , 0 ), 1 ) def add_micro_box ( self , photo : np . ndarray ) -> None : mic_bbox = self . get_bbox ( \"mic\" ) self . draw_box ( photo , mic_bbox , ( 0 , 0 , 255 ), 1 ) def add_cam_box ( self , photo : np . ndarray ) -> None : cam_bbox = self . get_bbox ( \"cam\" ) self . draw_box ( photo , cam_bbox , ( 128 , 0 , 0 ), 2 ) def save_stream ( self , folder_path : str , ) -> None : create_directory ( folder_path ) filename = f \" { self . cam_type } _\" + \" {:07d} .png\" with ImageSaver ( folder_path , tqdm_kwargs = { \"total\" : len ( self . log . index )}) as worker : for index in range ( len ( self . log . index )): self . index = index self . update_curr_row () path = join_paths ( folder_path , filename . format ( index )) img = self . get_photo () worker . schedule_save ( img , path ) image_format = filename . replace ( \"{:\" , \"%\" ) . replace ( \"}\" , \"\" ) self . 
make_vid ( folder_path , image_format , folder_path ) def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None : fps = self . config . frames_per_sec command = f \"ffmpeg -framerate { fps } -start_number 0 -i { join_paths ( folder_path , img_name_format ) } -c:v copy { join_paths ( output_dir , 'video.mp4' ) } \" print ( command ) os . system ( command )","title":"Module wtracker.eval.vlc"},{"location":"reference/wtracker/eval/vlc/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/eval/vlc/#hotkey","text":"class HotKey ( key : str , func : Callable [[ str ], NoneType ], description : str = '' ) Represents a hotkey that can be used to trigger a specific function.","title":"HotKey"},{"location":"reference/wtracker/eval/vlc/#attributes","text":"Name Type Description Default key str The key for the hotkey. None func Callable[[str], None] The function to be called when the hotkey is triggered. None description str The description of the hotkey (optional). None View Source @dataclass class HotKey : \"\"\" Represents a hotkey that can be used to trigger a specific function. Attributes: key (str): The key for the hotkey. func (Callable[[str], None]): The function to be called when the hotkey is triggered. description (str): The description of the hotkey (optional). \"\"\" key : str func : Callable [ [str ] , None ] description : str = field ( default = \"\" ) def __post_init__ ( self ) : self . key = self . key . lower ()","title":"Attributes"},{"location":"reference/wtracker/eval/vlc/#class-variables","text":"description","title":"Class variables"},{"location":"reference/wtracker/eval/vlc/#streamviewer","text":"class StreamViewer ( window_name : str = 'streamer' ) A class for viewing and interacting with photos and video streams.","title":"StreamViewer"},{"location":"reference/wtracker/eval/vlc/#attributes_1","text":"Name Type Description Default window_name str The name of the window. None View Source class StreamViewer : \"\"\" A class for viewing and interacting with photos and video streams. Args: window_name (str, optional): The name of the window. Example: with StreamViewer() as streamer: streamer.imshow(image) streamer.waitKey() \"\"\" def __init__ ( self , window_name : str = \"streamer\" ) -> None : self . window_name = window_name self . window = None self . hotkeys : list [ HotKey ] = [] self . register_hotkey ( HotKey ( \"q\" , self . close , \"close the window\" )) def register_hotkey ( self , hotkey : HotKey ) : \"\"\" Registers a hotkey. Args: hotkey (HotKey): The hotkey to register. \"\"\" self . hotkeys . append ( hotkey ) def create_trackbar ( self , name : str , val : int , maxval : int , onChange = lambda x : x ) : \"\"\" Creates a trackbar. Args: name (str): The name of the trackbar. val (int): The initial value of the trackbar. maxval (int): The maximum value of the trackbar. onChange (function): The function to call when the trackbar value changes. \"\"\" cv . createTrackbar ( name , self . window_name , val , maxval , onChange ) def update_trackbar ( self , name : str , val : int ) : \"\"\" Updates the value of a trackbar. Args: name (str): The name of the trackbar. val (int): The new value of the trackbar. \"\"\" cv . setTrackbarPos ( name , self . window_name , val ) def set_title ( self , title : str ) : \"\"\" Sets the title of the window. Args: title (str): The new title of the window. \"\"\" cv . setWindowTitle ( self . window_name , title ) def __enter__ ( self ) : \"\"\" Enters the context manager. 
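# --- Illustrative usage sketch (not part of the library source): registering a custom
# hotkey on a StreamViewer. The window name, image, and callback are placeholder
# assumptions; the API calls match the class listing shown here.
import numpy as np
from wtracker.eval.vlc import StreamViewer, HotKey

with StreamViewer(window_name="demo") as streamer:
    # a HotKey callback receives the pressed key as a string
    streamer.register_hotkey(HotKey("s", lambda key: print("pressed", key), "demo hotkey"))
    streamer.imshow(np.zeros((480, 640), dtype=np.uint8), title="blank frame")
    streamer.waitKey()  # blocks until a key press; 'q' closes the window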
\"\"\" self . open () return self def __exit__ ( self , exc_type , exc_value , traceback ) : \"\"\" Exits the context manager. \"\"\" self . close () def __del__ ( self ) : \"\"\" Destructor method. \"\"\" self . close () def update ( self , image : np . ndarray , wait : int = 1 ) : \"\"\" Updates the window with a new image. Args: image (np.ndarray): The image to display. wait (int): The delay in milliseconds before updating the window. \"\"\" cv . imshow ( self . window_name , image ) self . waitKey ( wait ) def waitKey ( self , timeout : int = 0 ) : \"\"\" Waits for a key press. This Function also triggers the hotkeys. Args: timeout (int): The timeout in milliseconds. Returns: str: The key that was pressed. \"\"\" key = cv . waitKey ( timeout ) if key <= 0 : return key key = chr ( key ). lower () for hotkey in self . hotkeys : if key in hotkey . key : hotkey . func ( key ) return key def open ( self ) : \"\"\" Opens the window. \"\"\" self . close () self . window = cv . namedWindow ( self . window_name , flags = cv . WINDOW_GUI_EXPANDED ) # cv . setWindowProperty ( self . window_name , cv . WND_PROP_TOPMOST , 1 ) self . set_title ( self . window_name ) def close ( self , key : str = \"q\" ) : \"\"\" Closes the window. Args: key (str): The key to close the window. \"\"\" if self . window is not None : cv . destroyWindow ( self . window_name ) self . window = None def imshow ( self , image : np . ndarray , title : str = \"image\" ) : \"\"\" Displays an image in the window. Args: image (np.ndarray): The image to display. title (str): The title of the image. \"\"\" self . update ( image , wait = 0 ) self . set_title ( title ) # cv . setWindowProperty ( self . window_name , cv . WND_PROP_TOPMOST , 1 )","title":"Attributes"},{"location":"reference/wtracker/eval/vlc/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/eval/vlc/#close","text":"def close ( self , key : str = 'q' ) Closes the window. Parameters: Name Type Description Default key str The key to close the window. None View Source def close(self, key: str = \"q\"): \"\"\" Closes the window. Args: key (str): The key to close the window. \"\"\" if self.window is not None: cv.destroyWindow(self.window_name) self.window = None","title":"close"},{"location":"reference/wtracker/eval/vlc/#create_trackbar","text":"def create_trackbar ( self , name : str , val : int , maxval : int , onChange =< function StreamViewer .< lambda > at 0x7f88625f0160 > ) Creates a trackbar. Parameters: Name Type Description Default name str The name of the trackbar. None val int The initial value of the trackbar. None maxval int The maximum value of the trackbar. None onChange function The function to call when the trackbar value changes. None View Source def create_trackbar(self, name: str, val: int, maxval: int, onChange=lambda x: x): \"\"\" Creates a trackbar. Args: name (str): The name of the trackbar. val (int): The initial value of the trackbar. maxval (int): The maximum value of the trackbar. onChange (function): The function to call when the trackbar value changes. \"\"\" cv.createTrackbar(name, self.window_name, val, maxval, onChange)","title":"create_trackbar"},{"location":"reference/wtracker/eval/vlc/#imshow","text":"def imshow ( self , image : numpy . ndarray , title : str = 'image' ) Displays an image in the window. Parameters: Name Type Description Default image np.ndarray The image to display. None title str The title of the image. 
None View Source def imshow(self, image: np.ndarray, title: str = \"image\"): \"\"\" Displays an image in the window. Args: image (np.ndarray): The image to display. title (str): The title of the image. \"\"\" self.update(image, wait=0) self.set_title(title) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1)","title":"imshow"},{"location":"reference/wtracker/eval/vlc/#open","text":"def open ( self ) Opens the window. View Source def open(self): \"\"\" Opens the window. \"\"\" self.close() self.window = cv.namedWindow(self.window_name, flags=cv.WINDOW_GUI_EXPANDED) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1) self.set_title(self.window_name)","title":"open"},{"location":"reference/wtracker/eval/vlc/#register_hotkey","text":"def register_hotkey ( self , hotkey : wtracker . eval . vlc . HotKey ) Registers a hotkey. Parameters: Name Type Description Default hotkey HotKey The hotkey to register. None View Source def register_hotkey ( self , hotkey: HotKey ) : \"\"\" Registers a hotkey. Args: hotkey (HotKey): The hotkey to register. \"\"\" self . hotkeys . append ( hotkey )","title":"register_hotkey"},{"location":"reference/wtracker/eval/vlc/#set_title","text":"def set_title ( self , title : str ) Sets the title of the window. Parameters: Name Type Description Default title str The new title of the window. None View Source def set_title(self, title: str): \"\"\" Sets the title of the window. Args: title (str): The new title of the window. \"\"\" cv.setWindowTitle(self.window_name, title)","title":"set_title"},{"location":"reference/wtracker/eval/vlc/#update","text":"def update ( self , image : numpy . ndarray , wait : int = 1 ) Updates the window with a new image. Parameters: Name Type Description Default image np.ndarray The image to display. None wait int The delay in milliseconds before updating the window. None View Source def update(self, image: np.ndarray, wait: int = 1): \"\"\" Updates the window with a new image. Args: image (np.ndarray): The image to display. wait (int): The delay in milliseconds before updating the window. \"\"\" cv.imshow(self.window_name, image) self.waitKey(wait)","title":"update"},{"location":"reference/wtracker/eval/vlc/#update_trackbar","text":"def update_trackbar ( self , name : str , val : int ) Updates the value of a trackbar. Parameters: Name Type Description Default name str The name of the trackbar. None val int The new value of the trackbar. None View Source def update_trackbar(self, name: str, val: int): \"\"\" Updates the value of a trackbar. Args: name (str): The name of the trackbar. val (int): The new value of the trackbar. \"\"\" cv.setTrackbarPos(name, self.window_name, val)","title":"update_trackbar"},{"location":"reference/wtracker/eval/vlc/#waitkey","text":"def waitKey ( self , timeout : int = 0 ) Waits for a key press. This function also triggers the hotkeys. Parameters: Name Type Description Default timeout int The timeout in milliseconds. None Returns: Type Description str The key that was pressed. View Source def waitKey ( self , timeout : int = 0 ) : \"\"\" Waits for a key press. This function also triggers the hotkeys. Args: timeout (int): The timeout in milliseconds. Returns: str: The key that was pressed. \"\"\" key = cv.waitKey(timeout) if key <= 0: return key key = chr(key).lower() for hotkey in self.hotkeys: if key in hotkey.key: hotkey.func(key) return key","title":"waitKey"},{"location":"reference/wtracker/eval/vlc/#vlc","text":"class VLC ( files : wtracker . utils . path_utils .
Files | None , config : wtracker . sim . config . TimingConfig , log_path : str , cam_type : str , show_pred : bool = True , show_micro : bool = False , show_cam : bool = False ) The VLC class represents a video player for visualizing Simulations. This class supports saving Simulation frames (with or without boxes overlay) as well.","title":"VLC"},{"location":"reference/wtracker/eval/vlc/#attributes_2","text":"Name Type Description Default files Files The files to read frames from. If None, the video player will present the log data (simulation) on a white background. None config TimingConfig The timing configuration of the system. None log_path str The path to the log file. None cam_type str The type of camera. This should match the prefix of the corresponding columns in the log file. None show_pred bool Whether to show the prediction box. None show_micro bool Whether to show the microscope box. None show_cam bool Whether to show the camera box. None View Source class VLC : \"\"\" The VLC class represents a video player for visualizing Simulations. This class supports saving Simulation frames (with or without boxes overlay) as well. Args: files (Files): The files to read frames from. If None, the video player will present the log data (simulation) on a white background. config (TimingConfig): The timing configuration of the system. log_path (str): The path to the log file. cam_type (str): The type of camera. This should match the prefix of the corresponding columns in the log file. show_pred (bool, optional): Whether to show the prediction box. show_micro (bool, optional): Whether to show the microscope box. show_cam (bool, optional): Whether to show the camera box. \"\"\" def __init__ ( self , files : Files | None , config : TimingConfig , log_path : str , cam_type : str , show_pred : bool = True , show_micro : bool = False , show_cam : bool = False , ) -> None : self . streamer = StreamViewer ( window_name = \"VLC\" ) self . index = 0 self . _curr_row = None self . exit = False self . delay = 0 self . play = False self . show_pred = show_pred self . show_micro = show_micro self . show_cam = show_cam self . cam_type : str = cam_type self . config : TimingConfig = config self . log : pd . DataFrame = self . _load_log ( log_path ) self . reader : FrameReader = self . _create_reader ( files ) def initialize ( self ) -> None : \"\"\" Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. \"\"\" self . _init_hotkeys () self . _create_window () self . streamer . update_trackbar ( \"delay\" , round ( self . config . ms_per_frame )) self . print_hotkeys () def _load_log ( self , log_path : str ) -> pd . DataFrame : if log_path is None : return None log = pd . read_csv ( log_path , index_col = \"frame\" ) if self . cam_type == \"plt\" : log [ \"plt_x\" ] = 0 log [ \"plt_y\" ] = 0 log [ \"plt_h\" ] = max ( log [ \"cam_y\" ]) + max ( log [ \"cam_h\" ]) log [ \"plt_w\" ] = max ( log [ \"cam_x\" ]) + max ( log [ \"cam_w\" ]) # assert len(log.index) == len(self.reader) self . _curr_row = log . iloc [ self . index ] return log def _init_hotkeys ( self ) -> None : self . streamer . register_hotkey ( HotKey ( \"q\" , self . close , \"close VLC\" )) self . streamer . register_hotkey ( HotKey ( \"d\" , self . next , \"next frame\" )) self . streamer . register_hotkey ( HotKey ( \"a\" , self . prev , \"previous frame\" )) self . streamer . register_hotkey ( HotKey ( \"p\" , self . toggle_play , \"play/pause\" )) self . streamer . 
register_hotkey ( HotKey ( \"h\" , self . toggle_pred , \"toggle prediction box\" )) self . streamer . register_hotkey ( HotKey ( \"m\" , self . toggle_micro , \"toggle microscope box\" )) self . streamer . register_hotkey ( HotKey ( \"c\" , self . toggle_cam , \"toggle camera box\" )) def print_hotkeys ( self ): print ( \"Hotkeys:\" ) for hotkey in self . streamer . hotkeys : print ( f \" - {hotkey.key} : {hotkey.description}\" ) def _create_window ( self ): self . streamer . open () self . streamer . create_trackbar ( \"delay\" , 0 , 250 , self . set_delay ) self . streamer . create_trackbar ( \"#frame\" , 0 , len ( self . reader ), self . seek ) def _create_reader ( self , files : Files ) -> FrameReader : if files is None : frame_num = len ( self . log . index ) frame_size = ( self . get_attribute ( self . cam_type + \"_h\" ), self . get_attribute ( self . cam_type + \"_w\" ), ) return DummyReader ( frame_num , frame_size ) filenames = [ f for f in files ] reader = FrameReader ( files . root , filenames ) return reader def __enter__ ( self ): return self def __exit__ ( self , exc_type , exc_value , traceback ): self . streamer . close () def _get_title ( self ): curr_phase = self . get_attribute ( \"phase\" ) phase_title = f \"Action: {curr_phase}\" cycle_len = self . config . imaging_frame_num + self . config . moving_frame_num cycle_progress = 1 + self . index % cycle_len cycle_title = ( f \"cycle progress [{cycle_progress}/{cycle_len}]: \" + cycle_progress * \"#\" + ( cycle_len - cycle_progress ) * \"_\" ) title = f \"{phase_title} :: {cycle_title}\" return title def get_attribute ( self , col_name : str ): return self . _curr_row [ col_name ] def update_curr_row ( self ): self . _curr_row = self . log . iloc [ self . index ] def get_photo ( self ) -> np . ndarray : photo = self . reader [ self . index ] if self . show_pred : self . add_pred ( photo ) if self . show_micro : self . add_micro_box ( photo ) if self . show_cam : self . add_cam_box ( photo ) self . draw_center ( photo ) return photo def seek ( self , pos : int ): self . index = ( pos ) % len ( self . reader ) self . update_curr_row () self . streamer . update ( self . get_photo ()) self . streamer . set_title ( self . _get_title ()) def next ( self , key = None ): self . index = ( self . index + 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def prev ( self , key = None ): self . index = ( self . index - 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def close ( self , key = None ): self . exit = True def set_delay ( self , delay : int ): self . delay = delay def toggle_play ( self , key : str = None ): self . play = not self . play def toggle_pred ( self , key : str = None ): self . show_pred = not self . show_pred def toggle_micro ( self , key : str = None ): self . show_micro = not self . show_micro def toggle_cam ( self , key : str = None ): self . show_cam = not self . show_cam def mainloop ( self ): \"\"\" Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the `exit` flag is set to True (by self.close() (called by an hotkey)). It checks the `play` flag to determine if the player should continue playing or pause. The `delay` variable is used to control the delay between each iteration of the loop and is set to 0 to pause. \"\"\" with self as vlc : while not self . exit : delay = 0 if not self . 
play else self . delay if self . play : self . next () vlc . streamer . waitKey ( delay ) def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ]: x = self . get_attribute ( prefix + \"_x\" ) y = self . get_attribute ( prefix + \"_y\" ) w = self . get_attribute ( prefix + \"_w\" ) h = self . get_attribute ( prefix + \"_h\" ) return ( x , y , w , h ) def draw_box ( self , photo : np . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 , ) -> None : if not np . isfinite ( bbox ) . all (): return x , y , w , h = self . get_bbox ( self . cam_type ) pred_x , pred_y , pred_w , pred_h = bbox pred_x = floor ( pred_x - x ) pred_y = floor ( pred_y - y ) pred_w = ceil ( pred_w ) pred_h = ceil ( pred_h ) cv . rectangle ( photo , ( pred_x , pred_y ), ( pred_x + pred_w , pred_y + pred_h ), color , width ) def draw_marker ( self , photo : np . ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = cv . MARKER_CROSS , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 , ) -> None : frame_x , frame_y , frame_w , frame_h = self . get_bbox ( self . cam_type ) x , y = floor ( x - frame_x ), floor ( y - frame_y ) cv . drawMarker ( photo , ( x , y ), color , marker_type , marker_size , thickness ) def draw_center ( self , photo : np . ndarray ): x , y , w , h = self . get_bbox ( \"mic\" ) center = ( x + w // 2 , y + h // 2 ) cv . drawMarker ( photo , center , ( 0 , 0 , 255 ), cv . MARKER_CROSS , 7 , 1 ) def add_pred ( self , photo : np . ndarray ) -> None : worm_bbox = self . get_bbox ( \"wrm\" ) self . draw_box ( photo , worm_bbox , ( 0 , 0 , 0 ), 1 ) def add_micro_box ( self , photo : np . ndarray ) -> None : mic_bbox = self . get_bbox ( \"mic\" ) self . draw_box ( photo , mic_bbox , ( 0 , 0 , 255 ), 1 ) def add_cam_box ( self , photo : np . ndarray ) -> None : cam_bbox = self . get_bbox ( \"cam\" ) self . draw_box ( photo , cam_bbox , ( 128 , 0 , 0 ), 2 ) def save_stream ( self , folder_path : str , ) -> None : create_directory ( folder_path ) filename = f \"{self.cam_type}_\" + \"{:07d}.png\" with ImageSaver ( folder_path , tqdm_kwargs = { \"total\" : len ( self . log . index )}) as worker : for index in range ( len ( self . log . index )): self . index = index self . update_curr_row () path = join_paths ( folder_path , filename . format ( index )) img = self . get_photo () worker . schedule_save ( img , path ) image_format = filename . replace ( \"{:\" , \"%\" ) . replace ( \"}\" , \"\" ) self . make_vid ( folder_path , image_format , folder_path ) def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None : fps = self . config . frames_per_sec command = f \"ffmpeg -framerate {fps} -start_number 0 -i {join_paths(folder_path, img_name_format)} -c:v copy {join_paths(output_dir, 'video.mp4')}\" print ( command ) os . system ( command )","title":"Attributes"},{"location":"reference/wtracker/eval/vlc/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/eval/vlc/#add_cam_box","text":"def add_cam_box ( self , photo : numpy . ndarray ) -> None View Source def add_cam_box ( self , photo : np . ndarray ) -> None : cam_bbox = self . get_bbox ( \"cam\" ) self . draw_box ( photo , cam_bbox , ( 128 , 0 , 0 ), 2 )","title":"add_cam_box"},{"location":"reference/wtracker/eval/vlc/#add_micro_box","text":"def add_micro_box ( self , photo : numpy . ndarray ) -> None View Source def add_micro_box ( self , photo : np . 
ndarray ) -> None : mic_bbox = self . get_bbox ( \"mic\" ) self . draw_box ( photo , mic_bbox , ( 0 , 0 , 255 ), 1 )","title":"add_micro_box"},{"location":"reference/wtracker/eval/vlc/#add_pred","text":"def add_pred ( self , photo : numpy . ndarray ) -> None View Source def add_pred ( self , photo : np . ndarray ) -> None : worm_bbox = self . get_bbox ( \"wrm\" ) self . draw_box ( photo , worm_bbox , ( 0 , 0 , 0 ), 1 )","title":"add_pred"},{"location":"reference/wtracker/eval/vlc/#close_1","text":"def close ( self , key = None ) View Source def close ( self , key = None ) : self . exit = True","title":"close"},{"location":"reference/wtracker/eval/vlc/#draw_box","text":"def draw_box ( self , photo : numpy . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 ) -> None View Source def draw_box ( self , photo : np . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 , ) -> None : if not np . isfinite ( bbox ). all (): return x , y , w , h = self . get_bbox ( self . cam_type ) pred_x , pred_y , pred_w , pred_h = bbox pred_x = floor ( pred_x - x ) pred_y = floor ( pred_y - y ) pred_w = ceil ( pred_w ) pred_h = ceil ( pred_h ) cv . rectangle ( photo , ( pred_x , pred_y ), ( pred_x + pred_w , pred_y + pred_h ), color , width )","title":"draw_box"},{"location":"reference/wtracker/eval/vlc/#draw_center","text":"def draw_center ( self , photo : numpy . ndarray ) View Source def draw_center(self, photo: np.ndarray): x, y, w, h = self.get_bbox(\"mic\") center = (x + w // 2, y + h // 2) cv.drawMarker(photo, center, (0, 0, 255), cv.MARKER_CROSS, 7, 1)","title":"draw_center"},{"location":"reference/wtracker/eval/vlc/#draw_marker","text":"def draw_marker ( self , photo : numpy . ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = 0 , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 ) -> None View Source def draw_marker ( self , photo : np . ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = cv . MARKER_CROSS , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 , ) -> None : frame_x , frame_y , frame_w , frame_h = self . get_bbox ( self . cam_type ) x , y = floor ( x - frame_x ), floor ( y - frame_y ) cv . drawMarker ( photo , ( x , y ), color , marker_type , marker_size , thickness )","title":"draw_marker"},{"location":"reference/wtracker/eval/vlc/#get_attribute","text":"def get_attribute ( self , col_name : str ) View Source def get_attribute ( self , col_name : str ) : return self . _curr_row [ col_name ]","title":"get_attribute"},{"location":"reference/wtracker/eval/vlc/#get_bbox","text":"def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ] View Source def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ] : x = self . get_attribute ( prefix + \"_x\" ) y = self . get_attribute ( prefix + \"_y\" ) w = self . get_attribute ( prefix + \"_w\" ) h = self . get_attribute ( prefix + \"_h\" ) return ( x , y , w , h )","title":"get_bbox"},{"location":"reference/wtracker/eval/vlc/#get_photo","text":"def get_photo ( self ) -> numpy . ndarray View Source def get_photo ( self ) -> np . ndarray : photo = self . reader [ self . index ] if self . show_pred : self . add_pred ( photo ) if self . show_micro : self . add_micro_box ( photo ) if self . show_cam : self . add_cam_box ( photo ) self . 
draw_center ( photo ) return photo","title":"get_photo"},{"location":"reference/wtracker/eval/vlc/#initialize","text":"def initialize ( self ) -> None Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. View Source def initialize ( self ) -> None : \"\"\" Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. \"\"\" self . _init_hotkeys () self . _create_window () self . streamer . update_trackbar ( \"delay\" , round ( self . config . ms_per_frame )) self . print_hotkeys ()","title":"initialize"},{"location":"reference/wtracker/eval/vlc/#mainloop","text":"def mainloop ( self ) Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the exit flag is set to True (by self.close(), called by a hotkey). It checks the play flag to determine if the player should continue playing or pause. The delay variable is used to control the delay between each iteration of the loop and is set to 0 to pause. View Source def mainloop ( self ) : \"\"\" Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the `exit` flag is set to True (by self.close(), called by a hotkey). It checks the `play` flag to determine if the player should continue playing or pause. The `delay` variable is used to control the delay between each iteration of the loop and is set to 0 to pause. \"\"\" with self as vlc : while not self . exit : delay = 0 if not self . play else self . delay if self . play : self . next () vlc . streamer . waitKey ( delay )","title":"mainloop"},{"location":"reference/wtracker/eval/vlc/#make_vid","text":"def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None View Source def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None : fps = self . config . frames_per_sec command = f \"ffmpeg -framerate {fps} -start_number 0 -i {join_paths(folder_path, img_name_format)} -c:v copy {join_paths(output_dir, 'video.mp4')}\" print ( command ) os . system ( command )","title":"make_vid"},{"location":"reference/wtracker/eval/vlc/#next","text":"def next ( self , key = None ) View Source def next(self, key=None): self.index = (self.index + 1) % len(self.reader) self.streamer.update_trackbar(\"#frame\", self.index)","title":"next"},{"location":"reference/wtracker/eval/vlc/#prev","text":"def prev ( self , key = None ) View Source def prev(self, key=None): self.index = (self.index - 1) % len(self.reader) self.streamer.update_trackbar(\"#frame\", self.index)","title":"prev"},{"location":"reference/wtracker/eval/vlc/#print_hotkeys","text":"def print_hotkeys ( self ) View Source def print_hotkeys(self): print(\"Hotkeys:\") for hotkey in self.streamer.hotkeys: print(f\" - {hotkey.key} : {hotkey.description}\")","title":"print_hotkeys"},{"location":"reference/wtracker/eval/vlc/#save_stream","text":"def save_stream ( self , folder_path : str ) -> None View Source def save_stream ( self , folder_path : str , ) -> None : create_directory ( folder_path ) filename = f \"{self.cam_type}_\" + \"{:07d}.png\" with ImageSaver ( folder_path , tqdm_kwargs ={ \"total\" : len ( self . log . index )}) as worker : for index in range ( len ( self . log . index )): self .
index = index self . update_curr_row () path = join_paths ( folder_path , filename . format ( index )) img = self . get_photo () worker . schedule_save ( img , path ) image_format = filename . replace ( \"{:\" , \"%\" ). replace ( \"}\" , \"\" ) self . make_vid ( folder_path , image_format , folder_path )","title":"save_stream"},{"location":"reference/wtracker/eval/vlc/#seek","text":"def seek ( self , pos : int ) View Source def seek(self, pos: int): self.index = (pos) % len(self.reader) self.update_curr_row() self.streamer.update(self.get_photo()) self.streamer.set_title(self._get_title())","title":"seek"},{"location":"reference/wtracker/eval/vlc/#set_delay","text":"def set_delay ( self , delay : int ) View Source def set_delay(self, delay: int): self.delay = delay","title":"set_delay"},{"location":"reference/wtracker/eval/vlc/#toggle_cam","text":"def toggle_cam ( self , key : str = None ) View Source def toggle_cam(self, key: str = None): self.show_cam = not self.show_cam","title":"toggle_cam"},{"location":"reference/wtracker/eval/vlc/#toggle_micro","text":"def toggle_micro ( self , key : str = None ) View Source def toggle_micro(self, key: str = None): self.show_micro = not self.show_micro","title":"toggle_micro"},{"location":"reference/wtracker/eval/vlc/#toggle_play","text":"def toggle_play ( self , key : str = None ) View Source def toggle_play(self, key: str = None): self.play = not self.play","title":"toggle_play"},{"location":"reference/wtracker/eval/vlc/#toggle_pred","text":"def toggle_pred ( self , key : str = None ) View Source def toggle_pred(self, key: str = None): self.show_pred = not self.show_pred","title":"toggle_pred"},{"location":"reference/wtracker/eval/vlc/#update_curr_row","text":"def update_curr_row ( self ) View Source def update_curr_row(self): self._curr_row = self.log.iloc[self.index]","title":"update_curr_row"},{"location":"reference/wtracker/neural/","text":"Namespace wtracker.neural Sub-modules wtracker.neural.config wtracker.neural.dataset wtracker.neural.mlp wtracker.neural.train_results wtracker.neural.training","title":"Index"},{"location":"reference/wtracker/neural/#namespace-wtrackerneural","text":"","title":"Namespace wtracker.neural"},{"location":"reference/wtracker/neural/#sub-modules","text":"wtracker.neural.config wtracker.neural.dataset wtracker.neural.mlp wtracker.neural.train_results wtracker.neural.training","title":"Sub-modules"},{"location":"reference/wtracker/neural/config/","text":"Module wtracker.neural.config View Source from __future__ import annotations import torch from torch import nn from torch.optim import Optimizer from torch.utils.data import Dataset , DataLoader , random_split from dataclasses import dataclass , field from wtracker.utils.config_base import ConfigBase @dataclass class DatasetConfig ( ConfigBase ): input_frames : list [ int ] # The frames to use as input for the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). pred_frames : list [ int ] # The frames to predict. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). log_path : list [ str ] # The path to the log file containing the worm head predictions (by YOLO). def __post_init__ ( self ) -> None : if self . input_frames [ 0 ] != 0 : print ( \"WARNING::DatasetConfig::frames_for_pred should contain 0 as first element. 
Please verify your parameters.\" ) @staticmethod def from_io_config ( io : IOConfig , log_path : str ) -> DatasetConfig : return DatasetConfig ( io . input_frames , io . pred_frames , log_path ) OPTIMIZERS = { \"adam\" : torch . optim . Adam , \"sgd\" : torch . optim . SGD , \"rmsprop\" : torch . optim . RMSprop , \"adamw\" : torch . optim . AdamW , } LOSSES = { \"mse\" : nn . MSELoss , \"l1\" : nn . L1Loss , } @dataclass class TrainConfig ( ConfigBase ): # general parameters seed : int = field ( default = 42 , kw_only = True ) # Random seed for reproducibility dataset : DatasetConfig # The dataset to use for training, can also be a config object (if Dataset, it will be used as is) # trainer parameters model : nn . Module | str # The model to train, can also be a pretrained model (if str, it will be loaded from disk) loss_fn : str # The loss function to use, can be any of the keys in the LOSSES dict optimizer : str # The optimizer to use, can be any of the keys in the OPTIMIZERS dict device : str = \"cuda\" # 'cuda' for training on GPU or 'cpu' otherwise log : bool = False # Whether to log and save the training process with tensorboard # training parameters num_epochs : int = 100 # Number of times to iterate over the dataset checkpoints : str = None # Path to save model checkpoints, including the checkpoint name. early_stopping : int = None # Number of epochs to wait before stopping training if no improvement was made print_every : int = 5 # How often (#epochs) to print training progress # optimizer parameters learning_rate : float = 0.001 # Learning rate for the optimizer weight_decay : float = ( 1e-5 # Weight decay for the optimizer (regularization, values typically in range [0.0, 1e-4] but can be bigger) ) # dataloader parameters batch_size : int = 256 # Number of samples in each batch shuffle : bool = True # Whether to shuffle the dataset at the beginning of each epoch num_workers : int = 0 # Number of subprocesses to use for data loading train_test_split : float = 0.8 # Fraction of the dataset to use for training, the rest will be used for testing dl_train : DataLoader = field ( init = False ) dl_test : DataLoader = field ( init = False ) @dataclass class IOConfig ( ConfigBase ): \"\"\" Configuration for the basic input/output of the network The input_frames and pred_frames are lists of integers that represent the frames that will be used as input and output of the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). To calculate in_dim,out_dim we assume that each input frame has 4 features (x,y,w,h), representing the worm bbox in that frame and each prediction frame has 2 features (x,y), representing the worm center in that frame. \"\"\" input_frames : list [ int ] pred_frames : list [ int ] in_dim : int = field ( init = False ) out_dim : int = field ( init = False ) def __post_init__ ( self ): if 0 not in self . input_frames : print ( \"WARNING::IOConfig::__post_init__::input_frames doesn't contain 0 (the prediction frame). Please verify your parameters.\" ) self . in_dim = len ( self . input_frames ) * 4 self . out_dim = len ( self . pred_frames ) * 2 @staticmethod def from_datasetConfig ( config : DatasetConfig ) -> IOConfig : return IOConfig ( config . input_frames , config .
pred_frames ) Variables LOSSES OPTIMIZERS Classes DatasetConfig class DatasetConfig ( input_frames : 'list[int]' , pred_frames : 'list[int]' , log_path : 'list[str]' ) DatasetConfig(input_frames: 'list[int]', pred_frames: 'list[int]', log_path: 'list[str]') View Source @dataclass class DatasetConfig ( ConfigBase ) : input_frames : list [ int ] # The frames to use as input for the network . The frames are in the format of the number of frames before ( negative ) or after ( positive ) the prediction frame ( 0 ). pred_frames : list [ int ] # The frames to predict . The frames are in the format of the number of frames before ( negative ) or after ( positive ) the prediction frame ( 0 ). log_path : list [ str ] # The path to the log file containing the worm head predictions ( by YOLO ). def __post_init__ ( self ) -> None : if self . input_frames [ 0 ] != 0 : print ( \"WARNING::DatasetConfig::frames_for_pred should contain 0 as first element. Please verify your parameters.\" ) @staticmethod def from_io_config ( io : IOConfig , log_path : str ) -> DatasetConfig : return DatasetConfig ( io . input_frames , io . pred_frames , log_path ) Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Static methods from_io_config def from_io_config ( io : 'IOConfig' , log_path : 'str' ) -> 'DatasetConfig' View Source @staticmethod def from_io_config ( io : IOConfig , log_path : str ) -> DatasetConfig : return DatasetConfig ( io . input_frames , io . pred_frames , log_path ) load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods save_json def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self .
__dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) IOConfig class IOConfig ( input_frames : 'list[int]' , pred_frames : 'list[int]' ) Configuration for the basic input/output of the network The input_frames and pred_frames are lists of integers that represent the frames that will be used as input and output of the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). To calculate in_dim,out_dim we assume that each input frame has 4 features (x,y,w,h), representing the worm bbox in that frame and each prediction frame has 2 features (x,y), representing the worm center in that frame. View Source @dataclass class IOConfig ( ConfigBase ) : \"\"\" Configuration for the basic input/output of the network The input_frames and pred_frames are lists of integers that represent the frames that will be used as input and output of the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). To calculate in_dim,out_dim we assume that each input frame has 4 features (x,y,w,h), representing the worm bbox in that frame and each prediction frame has 2 features (x,y), representing the worm center in that frame. \"\"\" input_frames : list [ int ] pred_frames : list [ int ] in_dim : int = field ( init = False ) out_dim : int = field ( init = False ) def __post_init__ ( self ) : if 0 not in self . input_frames : print ( \"WARNING::IOConfig::__post_init__::input_frames doesn't contain 0 (the prediction frame). Please verify your parameters.\" ) self . in_dim = len ( self . input_frames ) * 4 self . out_dim = len ( self . pred_frames ) * 2 @staticmethod def from_datasetConfig ( config : DatasetConfig ) -> IOConfig : return IOConfig ( config . input_frames , config . pred_frames ) Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Static methods from_datasetConfig def from_datasetConfig ( config : 'DatasetConfig' ) -> 'IOConfig' View Source @staticmethod def from_datasetConfig ( config : DatasetConfig ) -> IOConfig : return IOConfig ( config . input_frames , config . pred_frames ) load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. 
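To make the in_dim/out_dim rule in the IOConfig docstring above concrete, here is a minimal usage sketch (the frame offsets are arbitrary example values, not from the library): each input frame contributes 4 bbox features and each prediction frame 2 center features.

    from wtracker.neural.config import IOConfig

    io = IOConfig(input_frames=[-2, -1, 0], pred_frames=[5, 10])
    assert io.in_dim == 12  # 3 input frames * 4 bbox features (x, y, w, h)
    assert io.out_dim == 4  # 2 prediction frames * 2 center features (x, y)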
Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods save_json def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) TrainConfig class TrainConfig ( dataset : 'DatasetConfig' , model : 'nn.Module | str' , loss_fn : 'str' , optimizer : 'str' , device : 'str' = 'cuda' , log : 'bool' = False , num_epochs : 'int' = 100 , checkpoints : 'str' = None , early_stopping : 'int' = None , print_every : 'int' = 5 , learning_rate : 'float' = 0.001 , weight_decay : 'float' = 1e-05 , batch_size : 'int' = 256 , shuffle : 'bool' = True , num_workers : 'int' = 0 , train_test_split : 'float' = 0.8 , * , seed : 'int' = 42 ) TrainConfig(dataset: 'DatasetConfig', model: 'nn.Module | str', loss_fn: 'str', optimizer: 'str', device: 'str' = 'cuda', log: 'bool' = False, num_epochs: 'int' = 100, checkpoints: 'str' = None, early_stopping: 'int' = None, print_every: 'int' = 5, learning_rate: 'float' = 0.001, weight_decay: 'float' = 1e-05, batch_size: 'int' = 256, shuffle: 'bool' = True, num_workers: 'int' = 0, train_test_split: 'float' = 0.8, *, seed: 'int' = 42) View Source @ dataclass class TrainConfig ( ConfigBase ): # general parameters seed : int = field ( default = 42 , kw_only = True ) # Random seed for reproducibility dataset : DatasetConfig # The dataset to use for training, can also be a config object (if Dataset, it will be used as is) # trainer parameters model : nn . 
Module | str # The model to train, can also be a pretrained model (if str, it will be loaded from disk) loss_fn : str # The loss function to use, can be any of the keys in the LOSSES dict optimizer : str # The optimizer to use, can be any of the keys in the OPTIMIZERS dict device : str = \"cuda\" # 'cuda' for training on GPU or 'cpu' otherwise log : bool = False # Whether to log and save the training process with tensorboard # training parameters num_epochs : int = 100 # Number of times to iterate over the dataset checkpoints : str = None # Path to save model checkpoints, including the checkpoint name. early_stopping : int = None # Number of epochs to wait before stopping training if no improvement was made print_every : int = 5 # How often (#epochs) to print training progress # optimizer parameters learning_rate : float = 0.001 # Learning rate for the optimizer weight_decay : float = ( 1e-5 # Weight decay for the optimizer (regularization, values typically in range [0.0, 1e-4] but can be bigger) ) # dataloader parameters batch_size : int = 256 # Number of samples in each batch shuffle : bool = True # Whether to shuffle the dataset at the beginning of each epoch num_workers : int = 0 # Number of subprocesses to use for data loading train_test_split : float = 0.8 # Fraction of the dataset to use for training, the rest will be used for testing dl_train : DataLoader = field ( init = False ) dl_test : DataLoader = field ( init = False ) Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Class variables batch_size checkpoints device early_stopping learning_rate log num_epochs num_workers print_every seed shuffle train_test_split weight_decay Static methods load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods save_json def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt .
save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"Config"},{"location":"reference/wtracker/neural/config/#module-wtrackerneuralconfig","text":"View Source from __future__ import annotations import torch from torch import nn from torch.optim import Optimizer from torch.utils.data import Dataset , DataLoader , random_split from dataclasses import dataclass , field from wtracker.utils.config_base import ConfigBase @dataclass class DatasetConfig ( ConfigBase ): input_frames : list [ int ] # The frames to use as input for the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). pred_frames : list [ int ] # The frames to predict. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). log_path : list [ str ] # The path to the log file containing the worm head predictions (by YOLO). def __post_init__ ( self ) -> None : if self . input_frames [ 0 ] != 0 : print ( \"WARNING::DatasetConfig::input_frames should contain 0 as its first element. Please verify your parameters.\" ) @staticmethod def from_io_config ( io : IOConfig , log_path : str ) -> DatasetConfig : return DatasetConfig ( io . input_frames , io . pred_frames , log_path ) OPTIMIZERS = { \"adam\" : torch . optim . Adam , \"sgd\" : torch . optim . SGD , \"rmsprop\" : torch . optim . RMSprop , \"adamw\" : torch . optim . AdamW , } LOSSES = { \"mse\" : nn . MSELoss , \"l1\" : nn . L1Loss , } @dataclass class TrainConfig ( ConfigBase ): # general parameters seed : int = field ( default = 42 , kw_only = True ) # Random seed for reproducibility dataset : DatasetConfig # The dataset to use for training, can also be a config object (if Dataset, it will be used as is) # trainer parameters model : nn . Module | str # The model to train, can also be a pretrained model (if str, it will be loaded from disk) loss_fn : str # The loss function to use, can be any of the keys in the LOSSES dict optimizer : str # The optimizer to use, can be any of the keys in the OPTIMIZERS dict device : str = \"cuda\" # 'cuda' for training on GPU or 'cpu' otherwise log : bool = False # Whether to log and save the training process with tensorboard # training parameters num_epochs : int = 100 # Number of times to iterate over the dataset checkpoints : str = None # Path to save model checkpoints, including the checkpoint name.
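The string keys accepted by loss_fn and optimizer are looked up in the LOSSES and OPTIMIZERS dicts defined above. A minimal construction sketch is shown here (the log and model paths are hypothetical, purely for illustration):

    from wtracker.neural.config import DatasetConfig, IOConfig, TrainConfig

    io = IOConfig(input_frames=[0, -1, -2], pred_frames=[5])  # 0 first, as DatasetConfig expects
    dataset_cfg = DatasetConfig.from_io_config(io, log_path=\"logs/worm_head.csv\")  # hypothetical log path
    train_cfg = TrainConfig(
        dataset=dataset_cfg,
        model=\"models/predictor.pt\",  # hypothetical path; an nn.Module instance also works
        loss_fn=\"mse\",      # key into LOSSES -> nn.MSELoss
        optimizer=\"adamw\",  # key into OPTIMIZERS -> torch.optim.AdamW
        device=\"cpu\",
    )
    # the training samples themselves can then be built via NumpyDataset.create_from_config(dataset_cfg)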
early_stopping : int = None # Number of epochs to wait before stopping training if no improvement was made print_every : int = 5 # How often (#epochs) to print training progress # optimizer parameters learning_rate : float = 0.001 # Learning rate for the optimizer weight_decay : float = ( 1e-5 # Weight decay for the optimizer (regularization, values typically in range [0.0, 1e-4] but can be bigger) ) # dataloader parameters batch_size : int = 256 # Number of samples in each batch shuffle : bool = True # Whether to shuffle the dataset at the beginning of each epoch num_workers : int = 0 # Number of subprocesses to use for data loading train_test_split : float = 0.8 # Fraction of the dataset to use for training, the rest will be used for testing dl_train : DataLoader = field ( init = False ) dl_test : DataLoader = field ( init = False ) @dataclass class IOConfig ( ConfigBase ): \"\"\" Configuration for the basic input/output of the network The input_frames and pred_frames are lists of integers that represent the frames that will be used as input and output of the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). To calculate in_dim,out_dim we assume that each input frame has 4 features (x,y,w,h), representing the worm bbox in that frame and each prediction frame has 2 features (x,y), representing the worm center in that frame. \"\"\" input_frames : list [ int ] pred_frames : list [ int ] in_dim : int = field ( init = False ) out_dim : int = field ( init = False ) def __post_init__ ( self ): if 0 not in self . input_frames : print ( \"WARNING::IOConfig::__post_init__::input_frames doesn't contain 0 (the prediction frame). Please verify your parameters.\" ) self . in_dim = len ( self . input_frames ) * 4 self . out_dim = len ( self . pred_frames ) * 2 @staticmethod def from_datasetConfig ( config : DatasetConfig ) -> IOConfig : return IOConfig ( config . input_frames , config . pred_frames )","title":"Module wtracker.neural.config"},{"location":"reference/wtracker/neural/config/#variables","text":"LOSSES OPTIMIZERS","title":"Variables"},{"location":"reference/wtracker/neural/config/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/neural/config/#datasetconfig","text":"class DatasetConfig ( input_frames : 'list[int]' , pred_frames : 'list[int]' , log_path : 'list[str]' ) DatasetConfig(input_frames: 'list[int]', pred_frames: 'list[int]', log_path: 'list[str]') View Source @dataclass class DatasetConfig ( ConfigBase ) : input_frames : list [ int ] # The frames to use as input for the network . The frames are in the format of the number of frames before ( negative ) or after ( positive ) the prediction frame ( 0 ). pred_frames : list [ int ] # The frames to predict . The frames are in the format of the number of frames before ( negative ) or after ( positive ) the prediction frame ( 0 ). log_path : list [ str ] # The path to the log file containing the worm head predictions ( by YOLO ). def __post_init__ ( self ) -> None : if self . input_frames [ 0 ] != 0 : print ( \"WARNING::DatasetConfig::input_frames should contain 0 as its first element. Please verify your parameters.\" ) @staticmethod def from_io_config ( io : IOConfig , log_path : str ) -> DatasetConfig : return DatasetConfig ( io . input_frames , io .
pred_frames , log_path )","title":"DatasetConfig"},{"location":"reference/wtracker/neural/config/#ancestors-in-mro","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/config/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/neural/config/#from_io_config","text":"def from_io_config ( io : 'IOConfig' , log_path : 'str' ) -> 'DatasetConfig' View Source @staticmethod def from_io_config ( io : IOConfig , log_path : str ) -> DatasetConfig : return DatasetConfig ( io . input_frames , io . pred_frames , log_path )","title":"from_io_config"},{"location":"reference/wtracker/neural/config/#load_json","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/neural/config/#load_pickle","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/neural/config/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/neural/config/#save_json","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/neural/config/#save_pickle","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . 
save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/neural/config/#ioconfig","text":"class IOConfig ( input_frames : 'list[int]' , pred_frames : 'list[int]' ) Configuration for the basic input/output of the network The input_frames and pred_frames are lists of integers that represent the frames that will be used as input and output of the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). To calculate in_dim,out_dim we assume that each input frame has 4 features (x,y,w,h), representing the worm bbox in that frame and each prediction frame has 2 features (x,y), representing the worm center in that frame. View Source @dataclass class IOConfig ( ConfigBase ) : \"\"\" Configuration for the basic input/output of the network The input_frames and pred_frames are lists of integers that represent the frames that will be used as input and output of the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). To calculate in_dim,out_dim we assume that each input frame has 4 features (x,y,w,h), representing the worm bbox in that frame and each prediction frame has 2 features (x,y), representing the worm center in that frame. \"\"\" input_frames : list [ int ] pred_frames : list [ int ] in_dim : int = field ( init = False ) out_dim : int = field ( init = False ) def __post_init__ ( self ) : if 0 not in self . input_frames : print ( \"WARNING::IOConfig::__post_init__::input_frames doesn't contain 0 (the prediction frame). Please verify your parameters.\" ) self . in_dim = len ( self . input_frames ) * 4 self . out_dim = len ( self . pred_frames ) * 2 @staticmethod def from_datasetConfig ( config : DatasetConfig ) -> IOConfig : return IOConfig ( config . input_frames , config . pred_frames )","title":"IOConfig"},{"location":"reference/wtracker/neural/config/#ancestors-in-mro_1","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/config/#static-methods_1","text":"","title":"Static methods"},{"location":"reference/wtracker/neural/config/#from_datasetconfig","text":"def from_datasetConfig ( config : 'DatasetConfig' ) -> 'IOConfig' View Source @staticmethod def from_datasetConfig ( config : DatasetConfig ) -> IOConfig : return IOConfig ( config . input_frames , config . pred_frames )","title":"from_datasetConfig"},{"location":"reference/wtracker/neural/config/#load_json_1","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . 
update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/neural/config/#load_pickle_1","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/neural/config/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/neural/config/#save_json_1","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/neural/config/#save_pickle_1","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/neural/config/#trainconfig","text":"class TrainConfig ( dataset : 'DatasetConfig' , model : 'nn.Module | str' , loss_fn : 'str' , optimizer : 'str' , device : 'str' = 'cuda' , log : 'bool' = False , num_epochs : 'int' = 100 , checkpoints : 'str' = None , early_stopping : 'int' = None , print_every : 'int' = 5 , learning_rate : 'float' = 0.001 , weight_decay : 'float' = 1e-05 , batch_size : 'int' = 256 , shuffle : 'bool' = True , num_workers : 'int' = 0 , train_test_split : 'float' = 0.8 , * , seed : 'int' = 42 ) TrainConfig(dataset: 'DatasetConfig', model: 'nn.Module | str', loss_fn: 'str', optimizer: 'str', device: 'str' = 'cuda', log: 'bool' = False, num_epochs: 'int' = 100, checkpoints: 'str' = None, early_stopping: 'int' = None, print_every: 'int' = 5, learning_rate: 'float' = 0.001, weight_decay: 'float' = 1e-05, batch_size: 'int' = 256, shuffle: 'bool' = True, num_workers: 'int' = 0, train_test_split: 'float' = 0.8, *, seed: 'int' = 42) View Source @ dataclass class TrainConfig ( ConfigBase ): # general parameters seed : int = field ( default = 42 , kw_only = True ) # Random seed for reproducibility dataset : DatasetConfig # The dataset to use for training, can also be a config object (if Dataset, it will be used as is) # trainer parameters model : nn . 
Module | str # The model to train, can also be a pretrained model (if str, it will be loaded from disk) loss_fn : str # The loss function to use, can be any of the keys in the LOSSES dict optimizer : str # The optimizer to use, can be any of the keys in the OPTIMIZERS dict device : str = \"cuda\" # 'cuda' for training on GPU or 'cpu' otherwise log : bool = False # Whether to log and save the training process with tensorboard # training parameters num_epochs : int = 100 # Number of times to iterate over the dataset checkpoints : str = None # Path to save model checkpoints, including the checkpoint name. early_stopping : int = None # Number of epochs to wait before stopping training if no improvement was made print_every : int = 5 # How often (#epochs) to print training progress # optimizer parameters learning_rate : float = 0.001 # Learning rate for the optimizer weight_decay : float = ( 1e-5 # Weight decay for the optimizer (regularization, values typically in range [0.0, 1e-4] but can be bigger) ) # dataloader parameters batch_size : int = 256 # Number of samples in each batch shuffle : bool = True # Whether to shuffle the dataset at the beginning of each epoch num_workers : int = 0 # Number of subprocesses to use for data loading train_test_split : float = 0.8 # Fraction of the dataset to use for training, the rest will be used for testing dl_train : DataLoader = field ( init = False ) dl_test : DataLoader = field ( init = False )","title":"TrainConfig"},{"location":"reference/wtracker/neural/config/#ancestors-in-mro_2","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/config/#class-variables","text":"batch_size checkpoints device early_stopping learning_rate log num_epochs num_workers print_every seed shuffle train_test_split weight_decay","title":"Class variables"},{"location":"reference/wtracker/neural/config/#static-methods_2","text":"","title":"Static methods"},{"location":"reference/wtracker/neural/config/#load_json_2","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/neural/config/#load_pickle_2","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt .
open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/neural/config/#methods_2","text":"","title":"Methods"},{"location":"reference/wtracker/neural/config/#save_json_2","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/neural/config/#save_pickle_2","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/neural/dataset/","text":"Module wtracker.neural.dataset View Source from __future__ import annotations from torch.utils.data import Dataset from torch import Tensor import torch import pandas as pd import numpy as np from wtracker.neural.config import DatasetConfig from wtracker.utils.bbox_utils import BoxUtils class NumpyDataset ( Dataset ): \"\"\" A custom Dataset class used to train the neural network. This class is used to create a PyTorch Dataset from a numpy array, and can be initialized with 'ndarrays' of the samples and labels, as well as a DatasetConfig configuration, in which the samples (X) and labels(y) will be created automatically. Args: X (np.ndarray): The input data as a numpy array. y (np.ndarray): The output data as a numpy array. config (DatasetConfig, optional): The configuration object for the dataset. \"\"\" def __init__ ( self , X : np . ndarray , y : np . ndarray , config : DatasetConfig = None ): self . config = config self . X = Tensor ( X ) self . y = Tensor ( y ) def __len__ ( self ): return self . X . shape [ 0 ] def __getitem__ ( self , idx ): return self . X [ idx , :], self . y [ idx , :] def save ( self , path : str ) -> None : torch . save ( self , path ) @staticmethod def load ( path : str ) -> None : return torch . load ( path ) @staticmethod def create_from_config ( config : DatasetConfig , save_path : str | None = None ) -> NumpyDataset : data = pd . read_csv ( config . log_path ) start_idx = abs ( min ( config . input_frames )) + 1 X_mask = np . asanyarray ( config . input_frames ) y_mask = np . asanyarray ( config . pred_frames ) wrm_boxes = data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ( dtype = np . float64 ) wrm_centers = BoxUtils . center ( wrm_boxes ) # Create columns for X and y X_cols_prefix = [ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ] y_cols_prefix = [ \"wrm_center_x\" , \"wrm_center_y\" ] X_cols = [] y_cols = [] for i in config . 
input_frames : X_cols += [ col + str ( i ) for col in X_cols_prefix ] for i in config . pred_frames : y_cols += [ col + str ( i ) for col in y_cols_prefix ] # Create X and y X = pd . DataFrame ( index = data . index , columns = X_cols ) y = pd . DataFrame ( index = data . index , columns = y_cols ) for i in range ( start_idx , len ( data ) - max ( config . pred_frames ) - 1 ): X . iloc [ i ] = wrm_boxes [ i + X_mask ] . reshape ( 1 , - 1 ) y . iloc [ i ] = wrm_centers [ i + y_mask ] . reshape ( 1 , - 1 ) # Drop rows with NaN values na_mask = np . ma . mask_or ( X . isna () . any ( axis = 1 ), y . isna () . any ( axis = 1 )) X = X . loc [ ~ na_mask ] y = y . loc [ ~ na_mask ] X = X . to_numpy ( dtype = np . float32 , copy = True ) y = y . to_numpy ( dtype = np . float32 , copy = True ) # make X and y coordinates relative to the prediction frame x_cords = X [:, 0 ] . reshape ( - 1 , 1 ) y_cords = X [:, 1 ] . reshape ( - 1 , 1 ) x_cord_mask = np . arange ( y . shape [ 1 ]) % 2 == 0 y_cord_mask = np . arange ( y . shape [ 1 ]) % 2 == 1 y [:, x_cord_mask ] -= x_cords y [:, y_cord_mask ] -= y_cords x_cord_mask = np . arange ( X . shape [ 1 ]) % 4 == 0 y_cord_mask = np . arange ( X . shape [ 1 ]) % 4 == 1 X [:, x_cord_mask ] -= x_cords # X [:, y_cord_mask ] -= y_cords # .reshape(-1, 1) dataset = NumpyDataset ( X , y , config ) if save_path is not None : dataset . save ( save_path ) return dataset Classes NumpyDataset class NumpyDataset ( X : 'np.ndarray' , y : 'np.ndarray' , config : 'DatasetConfig' = None ) A custom Dataset class used to train the neural network. This class is used to create a PyTorch Dataset from a numpy array, and can be initialized with 'ndarrays' of the samples and labels, as well as a DatasetConfig configuration, in which the samples (X) and labels(y) will be created automatically. Attributes Name Type Description Default X np.ndarray The input data as a numpy array. None y np.ndarray The output data as a numpy array. None config DatasetConfig The configuration object for the dataset. None View Source class NumpyDataset ( Dataset ) : \"\"\" A custom Dataset class used to train the neural network. This class is used to create a PyTorch Dataset from a numpy array, and can be initialized with 'ndarrays' of the samples and labels, as well as a DatasetConfig configuration, in which the samples (X) and labels(y) will be created automatically. Args: X (np.ndarray): The input data as a numpy array. y (np.ndarray): The output data as a numpy array. config (DatasetConfig, optional): The configuration object for the dataset. \"\"\" def __init__ ( self , X : np . ndarray , y : np . ndarray , config : DatasetConfig = None ) : self . config = config self . X = Tensor ( X ) self . y = Tensor ( y ) def __len__ ( self ) : return self . X . shape [ 0 ] def __getitem__ ( self , idx ) : return self . X [ idx, : ] , self . y [ idx, : ] def save ( self , path : str ) -> None : torch . save ( self , path ) @staticmethod def load ( path : str ) -> None : return torch . load ( path ) @staticmethod def create_from_config ( config : DatasetConfig , save_path : str | None = None ) -> NumpyDataset : data = pd . read_csv ( config . log_path ) start_idx = abs ( min ( config . input_frames )) + 1 X_mask = np . asanyarray ( config . input_frames ) y_mask = np . asanyarray ( config . pred_frames ) wrm_boxes = data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ( dtype = np . float64 ) wrm_centers = BoxUtils . 
center ( wrm_boxes ) # Create columns for X and y X_cols_prefix = [ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] y_cols_prefix = [ \"wrm_center_x\", \"wrm_center_y\" ] X_cols = [] y_cols = [] for i in config . input_frames : X_cols += [ col + str(i) for col in X_cols_prefix ] for i in config . pred_frames : y_cols += [ col + str(i) for col in y_cols_prefix ] # Create X and y X = pd . DataFrame ( index = data . index , columns = X_cols ) y = pd . DataFrame ( index = data . index , columns = y_cols ) for i in range ( start_idx , len ( data ) - max ( config . pred_frames ) - 1 ) : X . iloc [ i ] = wrm_boxes [ i + X_mask ] . reshape ( 1 , - 1 ) y . iloc [ i ] = wrm_centers [ i + y_mask ] . reshape ( 1 , - 1 ) # Drop rows with NaN values na_mask = np . ma . mask_or ( X . isna (). any ( axis = 1 ), y . isna (). any ( axis = 1 )) X = X . loc [ ~na_mask ] y = y . loc [ ~na_mask ] X = X . to_numpy ( dtype = np . float32 , copy = True ) y = y . to_numpy ( dtype = np . float32 , copy = True ) # make X and y coordinates relative to the prediction frame x_cords = X [ :, 0 ] . reshape ( - 1 , 1 ) y_cords = X [ :, 1 ] . reshape ( - 1 , 1 ) x_cord_mask = np . arange ( y . shape [ 1 ] ) % 2 == 0 y_cord_mask = np . arange ( y . shape [ 1 ] ) % 2 == 1 y [ :, x_cord_mask ] -= x_cords y [ :, y_cord_mask ] -= y_cords x_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 0 y_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 1 X [ :, x_cord_mask ] -= x_cords # X [ :, y_cord_mask ] -= y_cords # . reshape ( - 1 , 1 ) dataset = NumpyDataset ( X , y , config ) if save_path is not None : dataset . save ( save_path ) return dataset Ancestors (in MRO) torch.utils.data.dataset.Dataset typing.Generic Static methods create_from_config def create_from_config ( config : 'DatasetConfig' , save_path : 'str | None' = None ) -> 'NumpyDataset' View Source @staticmethod def create_from_config ( config : DatasetConfig , save_path : str | None = None ) -> NumpyDataset : data = pd . read_csv ( config . log_path ) start_idx = abs ( min ( config . input_frames )) + 1 X_mask = np . asanyarray ( config . input_frames ) y_mask = np . asanyarray ( config . pred_frames ) wrm_boxes = data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ( dtype = np . float64 ) wrm_centers = BoxUtils . center ( wrm_boxes ) # Create columns for X and y X_cols_prefix = [ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] y_cols_prefix = [ \"wrm_center_x\", \"wrm_center_y\" ] X_cols = [] y_cols = [] for i in config . input_frames : X_cols += [ col + str(i) for col in X_cols_prefix ] for i in config . pred_frames : y_cols += [ col + str(i) for col in y_cols_prefix ] # Create X and y X = pd . DataFrame ( index = data . index , columns = X_cols ) y = pd . DataFrame ( index = data . index , columns = y_cols ) for i in range ( start_idx , len ( data ) - max ( config . pred_frames ) - 1 ) : X . iloc [ i ] = wrm_boxes [ i + X_mask ] . reshape ( 1 , - 1 ) y . iloc [ i ] = wrm_centers [ i + y_mask ] . reshape ( 1 , - 1 ) # Drop rows with NaN values na_mask = np . ma . mask_or ( X . isna (). any ( axis = 1 ), y . isna (). any ( axis = 1 )) X = X . loc [ ~na_mask ] y = y . loc [ ~na_mask ] X = X . to_numpy ( dtype = np . float32 , copy = True ) y = y . to_numpy ( dtype = np . float32 , copy = True ) # make X and y coordinates relative to the prediction frame x_cords = X [ :, 0 ] . reshape ( - 1 , 1 ) y_cords = X [ :, 1 ] . reshape ( - 1 , 1 ) x_cord_mask = np . arange ( y . shape [ 1 ] ) % 2 == 0 y_cord_mask = np . arange ( y . 
shape [ 1 ] ) % 2 == 1 y [ :, x_cord_mask ] -= x_cords y [ :, y_cord_mask ] -= y_cords x_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 0 y_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 1 X [ :, x_cord_mask ] -= x_cords # X [ :, y_cord_mask ] -= y_cords # . reshape ( - 1 , 1 ) dataset = NumpyDataset ( X , y , config ) if save_path is not None : dataset . save ( save_path ) return dataset load def load ( path : 'str' ) -> 'None' View Source @ staticmethod def load ( path : str ) -> None : return torch . load ( path ) Methods save def save ( self , path : 'str' ) -> 'None' View Source def save ( self , path : str ) -> None : torch . save ( self , path )","title":"Dataset"},{"location":"reference/wtracker/neural/dataset/#module-wtrackerneuraldataset","text":"View Source from __future__ import annotations from torch.utils.data import Dataset from torch import Tensor import torch import pandas as pd import numpy as np from wtracker.neural.config import DatasetConfig from wtracker.utils.bbox_utils import BoxUtils class NumpyDataset ( Dataset ): \"\"\" A custom Dataset class used to train the neural network. This class is used to create a PyTorch Dataset from a numpy array, and can be initialized with 'ndarrays' of the samples and labels, as well as a DatasetConfig configuration, in which the samples (X) and labels(y) will be created automatically. Args: X (np.ndarray): The input data as a numpy array. y (np.ndarray): The output data as a numpy array. config (DatasetConfig, optional): The configuration object for the dataset. \"\"\" def __init__ ( self , X : np . ndarray , y : np . ndarray , config : DatasetConfig = None ): self . config = config self . X = Tensor ( X ) self . y = Tensor ( y ) def __len__ ( self ): return self . X . shape [ 0 ] def __getitem__ ( self , idx ): return self . X [ idx , :], self . y [ idx , :] def save ( self , path : str ) -> None : torch . save ( self , path ) @staticmethod def load ( path : str ) -> None : return torch . load ( path ) @staticmethod def create_from_config ( config : DatasetConfig , save_path : str | None = None ) -> NumpyDataset : data = pd . read_csv ( config . log_path ) start_idx = abs ( min ( config . input_frames )) + 1 X_mask = np . asanyarray ( config . input_frames ) y_mask = np . asanyarray ( config . pred_frames ) wrm_boxes = data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ( dtype = np . float64 ) wrm_centers = BoxUtils . center ( wrm_boxes ) # Create columns for X and y X_cols_prefix = [ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ] y_cols_prefix = [ \"wrm_center_x\" , \"wrm_center_y\" ] X_cols = [] y_cols = [] for i in config . input_frames : X_cols += [ col + str ( i ) for col in X_cols_prefix ] for i in config . pred_frames : y_cols += [ col + str ( i ) for col in y_cols_prefix ] # Create X and y X = pd . DataFrame ( index = data . index , columns = X_cols ) y = pd . DataFrame ( index = data . index , columns = y_cols ) for i in range ( start_idx , len ( data ) - max ( config . pred_frames ) - 1 ): X . iloc [ i ] = wrm_boxes [ i + X_mask ] . reshape ( 1 , - 1 ) y . iloc [ i ] = wrm_centers [ i + y_mask ] . reshape ( 1 , - 1 ) # Drop rows with NaN values na_mask = np . ma . mask_or ( X . isna () . any ( axis = 1 ), y . isna () . any ( axis = 1 )) X = X . loc [ ~ na_mask ] y = y . loc [ ~ na_mask ] X = X . to_numpy ( dtype = np . float32 , copy = True ) y = y . to_numpy ( dtype = np . float32 , copy = True ) # make X and y coordinates relative to the prediction frame x_cords = X [:, 0 ] . 
reshape ( - 1 , 1 ) y_cords = X [:, 1 ] . reshape ( - 1 , 1 ) x_cord_mask = np . arange ( y . shape [ 1 ]) % 2 == 0 y_cord_mask = np . arange ( y . shape [ 1 ]) % 2 == 1 y [:, x_cord_mask ] -= x_cords y [:, y_cord_mask ] -= y_cords x_cord_mask = np . arange ( X . shape [ 1 ]) % 4 == 0 y_cord_mask = np . arange ( X . shape [ 1 ]) % 4 == 1 X [:, x_cord_mask ] -= x_cords # X [:, y_cord_mask ] -= y_cords # .reshape(-1, 1) dataset = NumpyDataset ( X , y , config ) if save_path is not None : dataset . save ( save_path ) return dataset","title":"Module wtracker.neural.dataset"},{"location":"reference/wtracker/neural/dataset/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/neural/dataset/#numpydataset","text":"class NumpyDataset ( X : 'np.ndarray' , y : 'np.ndarray' , config : 'DatasetConfig' = None ) A custom Dataset class used to train the neural network. This class is used to create a PyTorch Dataset from a numpy array, and can be initialized with 'ndarrays' of the samples and labels, as well as a DatasetConfig configuration, in which the samples (X) and labels(y) will be created automatically.","title":"NumpyDataset"},{"location":"reference/wtracker/neural/dataset/#attributes","text":"Name Type Description Default X np.ndarray The input data as a numpy array. None y np.ndarray The output data as a numpy array. None config DatasetConfig The configuration object for the dataset. None View Source class NumpyDataset ( Dataset ) : \"\"\" A custom Dataset class used to train the neural network. This class is used to create a PyTorch Dataset from a numpy array, and can be initialized with 'ndarrays' of the samples and labels, as well as a DatasetConfig configuration, in which the samples (X) and labels(y) will be created automatically. Args: X (np.ndarray): The input data as a numpy array. y (np.ndarray): The output data as a numpy array. config (DatasetConfig, optional): The configuration object for the dataset. \"\"\" def __init__ ( self , X : np . ndarray , y : np . ndarray , config : DatasetConfig = None ) : self . config = config self . X = Tensor ( X ) self . y = Tensor ( y ) def __len__ ( self ) : return self . X . shape [ 0 ] def __getitem__ ( self , idx ) : return self . X [ idx, : ] , self . y [ idx, : ] def save ( self , path : str ) -> None : torch . save ( self , path ) @staticmethod def load ( path : str ) -> None : return torch . load ( path ) @staticmethod def create_from_config ( config : DatasetConfig , save_path : str | None = None ) -> NumpyDataset : data = pd . read_csv ( config . log_path ) start_idx = abs ( min ( config . input_frames )) + 1 X_mask = np . asanyarray ( config . input_frames ) y_mask = np . asanyarray ( config . pred_frames ) wrm_boxes = data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ( dtype = np . float64 ) wrm_centers = BoxUtils . center ( wrm_boxes ) # Create columns for X and y X_cols_prefix = [ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] y_cols_prefix = [ \"wrm_center_x\", \"wrm_center_y\" ] X_cols = [] y_cols = [] for i in config . input_frames : X_cols += [ col + str(i) for col in X_cols_prefix ] for i in config . pred_frames : y_cols += [ col + str(i) for col in y_cols_prefix ] # Create X and y X = pd . DataFrame ( index = data . index , columns = X_cols ) y = pd . DataFrame ( index = data . index , columns = y_cols ) for i in range ( start_idx , len ( data ) - max ( config . pred_frames ) - 1 ) : X . iloc [ i ] = wrm_boxes [ i + X_mask ] . reshape ( 1 , - 1 ) y . iloc [ i ] = wrm_centers [ i + y_mask ] . 
reshape ( 1 , - 1 ) # Drop rows with NaN values na_mask = np . ma . mask_or ( X . isna (). any ( axis = 1 ), y . isna (). any ( axis = 1 )) X = X . loc [ ~na_mask ] y = y . loc [ ~na_mask ] X = X . to_numpy ( dtype = np . float32 , copy = True ) y = y . to_numpy ( dtype = np . float32 , copy = True ) # make X and y coordinates relative to the prediction frame x_cords = X [ :, 0 ] . reshape ( - 1 , 1 ) y_cords = X [ :, 1 ] . reshape ( - 1 , 1 ) x_cord_mask = np . arange ( y . shape [ 1 ] ) % 2 == 0 y_cord_mask = np . arange ( y . shape [ 1 ] ) % 2 == 1 y [ :, x_cord_mask ] -= x_cords y [ :, y_cord_mask ] -= y_cords x_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 0 y_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 1 X [ :, x_cord_mask ] -= x_cords # X [ :, y_cord_mask ] -= y_cords # . reshape ( - 1 , 1 ) dataset = NumpyDataset ( X , y , config ) if save_path is not None : dataset . save ( save_path ) return dataset","title":"Attributes"},{"location":"reference/wtracker/neural/dataset/#ancestors-in-mro","text":"torch.utils.data.dataset.Dataset typing.Generic","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/dataset/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/neural/dataset/#create_from_config","text":"def create_from_config ( config : 'DatasetConfig' , save_path : 'str | None' = None ) -> 'NumpyDataset' View Source @staticmethod def create_from_config ( config : DatasetConfig , save_path : str | None = None ) -> NumpyDataset : data = pd . read_csv ( config . log_path ) start_idx = abs ( min ( config . input_frames )) + 1 X_mask = np . asanyarray ( config . input_frames ) y_mask = np . asanyarray ( config . pred_frames ) wrm_boxes = data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ( dtype = np . float64 ) wrm_centers = BoxUtils . center ( wrm_boxes ) # Create columns for X and y X_cols_prefix = [ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] y_cols_prefix = [ \"wrm_center_x\", \"wrm_center_y\" ] X_cols = [] y_cols = [] for i in config . input_frames : X_cols += [ col + str(i) for col in X_cols_prefix ] for i in config . pred_frames : y_cols += [ col + str(i) for col in y_cols_prefix ] # Create X and y X = pd . DataFrame ( index = data . index , columns = X_cols ) y = pd . DataFrame ( index = data . index , columns = y_cols ) for i in range ( start_idx , len ( data ) - max ( config . pred_frames ) - 1 ) : X . iloc [ i ] = wrm_boxes [ i + X_mask ] . reshape ( 1 , - 1 ) y . iloc [ i ] = wrm_centers [ i + y_mask ] . reshape ( 1 , - 1 ) # Drop rows with NaN values na_mask = np . ma . mask_or ( X . isna (). any ( axis = 1 ), y . isna (). any ( axis = 1 )) X = X . loc [ ~na_mask ] y = y . loc [ ~na_mask ] X = X . to_numpy ( dtype = np . float32 , copy = True ) y = y . to_numpy ( dtype = np . float32 , copy = True ) # make X and y coordinates relative to the prediction frame x_cords = X [ :, 0 ] . reshape ( - 1 , 1 ) y_cords = X [ :, 1 ] . reshape ( - 1 , 1 ) x_cord_mask = np . arange ( y . shape [ 1 ] ) % 2 == 0 y_cord_mask = np . arange ( y . shape [ 1 ] ) % 2 == 1 y [ :, x_cord_mask ] -= x_cords y [ :, y_cord_mask ] -= y_cords x_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 0 y_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 1 X [ :, x_cord_mask ] -= x_cords # X [ :, y_cord_mask ] -= y_cords # . reshape ( - 1 , 1 ) dataset = NumpyDataset ( X , y , config ) if save_path is not None : dataset . 
save ( save_path ) return dataset","title":"create_from_config"},{"location":"reference/wtracker/neural/dataset/#load","text":"def load ( path : 'str' ) -> 'None' View Source @ staticmethod def load ( path : str ) -> None : return torch . load ( path )","title":"load"},{"location":"reference/wtracker/neural/dataset/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/neural/dataset/#save","text":"def save ( self , path : 'str' ) -> 'None' View Source def save ( self , path : str ) -> None : torch . save ( self , path )","title":"save"},{"location":"reference/wtracker/neural/mlp/","text":"Module wtracker.neural.mlp View Source from torch import Tensor , nn from typing import Union , Sequence from collections import defaultdict from wtracker.neural.config import IOConfig ACTIVATIONS = { \"relu\" : nn . ReLU , \"tanh\" : nn . Tanh , \"sigmoid\" : nn . Sigmoid , \"softmax\" : nn . Softmax , \"logsoftmax\" : nn . LogSoftmax , \"lrelu\" : nn . LeakyReLU , \"none\" : nn . Identity , None : nn . Identity , } # Default keyword arguments to pass to activation class constructors, e.g. # activation_cls(**ACTIVATION_DEFAULT_KWARGS[name]) ACTIVATION_DEFAULT_KWARGS = defaultdict ( dict , { ### \"softmax\" : dict ( dim = 1 ), \"logsoftmax\" : dict ( dim = 1 ), }, ) class WormPredictor ( nn . Module ): \"\"\" A class that represents neural network models that predict worm behavior. After a model is created from several layers or blocks, it is wrapped in this class so that it can be distinguished from other models that don't predict worm behavior (for example the layers/blocks that make this model). This class also holds the IOConfig object that is used to determine the input and output shapes of the model, and the specific frames it expects as input and output. Attributes: model: The neural network model that predicts worm behavior. io_config: The IOConfig object of the model. \"\"\" def __init__ ( self , model : nn . Module , io_config : IOConfig ): super () . __init__ () self . io_config : IOConfig = io_config self . model : nn . Module = model def forward ( self , x : Tensor ) -> Tensor : return self . model ( x ) class MLPLayer ( nn . Module ): \"\"\" A single-layer perceptron that can hold batch-norm and activation layers as well. \"\"\" def __init__ ( self , in_dim : int , out_dim : Sequence [ int ], nonlin : Union [ str , nn . Module ], batch_norm : bool = True , ) -> None : super () . __init__ () layers = [] layers . append ( nn . Linear ( in_dim , out_dim )) in_dim = out_dim if batch_norm and nonlin not in [ \"none\" , None ]: layers . append ( nn . BatchNorm1d ( out_dim )) layers . append ( self . _make_activation ( nonlin )) self . mlp_layer = nn . Sequential ( * layers ) def _make_activation ( self , act : Union [ str , nn . Module ]) -> nn . Module : if isinstance ( act , str ): return ACTIVATIONS [ act ]( ** ACTIVATION_DEFAULT_KWARGS [ act ]) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . mlp_layer . forward ( x . reshape ( x . size ( 0 ), - 1 )) class MlpBlock ( nn . Module ): \"\"\" A general-purpose MLP. Args: in_dim: Input dimension. dims: Hidden dimensions, including output dimension. nonlins: Non-linearities to apply after each one of the hidden dimensions. Can be either a sequence of strings which are keys in the ACTIVATIONS dict, or instances of nn.Module (e.g.
an instance of nn.ReLU()). Length should match 'dims'. \"\"\" def __init__ ( self , in_dim : int , dims : Sequence [ int ], nonlins : Sequence [ Union [ str , nn . Module ]], batch_norm : bool = True , ): assert len ( nonlins ) == len ( dims ) self . in_dim = in_dim self . out_dim = dims [ - 1 ] self . dims = dims self . nonlins = nonlins super () . __init__ () layers = [] for i , out_dim in enumerate ( self . dims ): layers . append ( MLPLayer ( in_dim , out_dim , nonlins [ i ], batch_norm )) in_dim = out_dim self . sequence = nn . Sequential ( * layers ) def _make_activation ( self , act : Union [ str , nn . Module ]) -> nn . Module : if isinstance ( act , str ): return ACTIVATIONS [ act ]( ** ACTIVATION_DEFAULT_KWARGS [ act ]) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . sequence . forward ( x . reshape ( x . size ( 0 ), - 1 )) class RMLP ( nn . Module ): def __init__ ( self , block_in_dim : int , block_dims : Sequence [ int ], block_nonlins : Sequence [ Union [ str , nn . Module ]], n_blocks : int , out_dim : int , in_dim : int = None , # if in_dim is an int, then a first layer will be made batch_norm : bool = True , ) -> None : super () . __init__ () # Create first layer if in_dim is not None self . input = nn . Identity () if in_dim is not None : self . input = MLPLayer ( in_dim , block_in_dim , block_nonlins [ 0 ], batch_norm ) # Create blocks layers = [] for i in range ( n_blocks ): layers . append ( MlpBlock ( block_in_dim , block_dims , block_nonlins , batch_norm )) self . blocks = nn . ModuleList ( layers ) # Create output layer self . output = nn . Linear ( block_dims [ - 1 ], out_dim ) def _make_activation ( self , act : Union [ str , nn . Module ]) -> nn . Module : if isinstance ( act , str ): return ACTIVATIONS [ act ]( ** ACTIVATION_DEFAULT_KWARGS [ act ]) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" x = self . input ( x ) for block in self . blocks : out = block ( x ) x = x + out return self . output ( x ) Variables ACTIVATIONS ACTIVATION_DEFAULT_KWARGS Classes MLPLayer class MLPLayer ( in_dim : int , out_dim : Sequence [ int ], nonlin : Union [ str , torch . nn . modules . module . Module ], batch_norm : bool = True ) A single-layer perceptron that can hold batch-norm and activation layers as well. View Source class MLPLayer ( nn . Module ) : \"\"\" A single-layer perceptron that can hold batch-norm and activation layers as well. \"\"\" def __init__ ( self , in_dim : int , out_dim : Sequence [ int ] , nonlin : Union [ str, nn.Module ] , batch_norm : bool = True , ) -> None : super (). __init__ () layers = [] layers . append ( nn . Linear ( in_dim , out_dim )) in_dim = out_dim if batch_norm and nonlin not in [ \"none\", None ] : layers . append ( nn . BatchNorm1d ( out_dim )) layers . append ( self . _make_activation ( nonlin )) self . mlp_layer = nn . Sequential ( * layers ) def _make_activation ( self , act : Union [ str, nn.Module ] ) -> nn . Module : if isinstance ( act , str ) : return ACTIVATIONS [ act ] ( ** ACTIVATION_DEFAULT_KWARGS [ act ] ) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features.
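Shape-wise, both MLPLayer and MlpBlock flatten their input to (N, -1) before applying the layers, as the forward docstrings above describe. A small sketch with made-up dimensions (illustrative only):

    import torch
    from wtracker.neural.mlp import MlpBlock

    block = MlpBlock(in_dim=12, dims=[64, 64, 2], nonlins=[\"relu\", \"relu\", \"none\"])
    x = torch.randn(8, 12)  # N=8 samples, D=12 features
    out = block(x)          # flattened to (8, 12), then mapped to shape (8, 2)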
Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . mlp_layer . forward ( x . reshape ( x . size ( 0 ), - 1 )) Ancestors (in MRO) torch.nn.modules.module.Module Class variables T_destination call_super_init dump_patches Methods add_module def add_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Add a child module to the current module. The module can be accessed as an attribute using the given name. Parameters: Name Type Description Default name str name of the child module. The child module can be accessed from this module using the given name None module Module child module to be added to the module. None View Source def add_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \"\"\"Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (str): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module. \"\"\" if not isinstance ( module , Module ) and module is not None : raise TypeError ( f \"{torch.typename(module)} is not a Module subclass\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"module name should be a string. Got {torch.typename(name)}\" ) elif hasattr ( self , name ) and name not in self . _modules : raise KeyError ( f \"attribute '{name}' already exists\" ) elif '.' in name : raise KeyError ( f \"module name can't contain \\\" . \\ \", got: {name}\" ) elif name == '' : raise KeyError ( \"module name can't be empty string \\\" \\ \"\" ) for hook in _global_module_registration_hooks . values () : output = hook ( self , name , module ) if output is not None : module = output self . _modules [ name ] = module apply def apply ( self : ~ T , fn : Callable [[ ForwardRef ( 'Module' )], NoneType ] ) -> ~ T Apply fn recursively to every submodule (as returned by .children() ) as well as self. Typical use includes initializing the parameters of a model (see also :ref: nn-init-doc ). Parameters: Name Type Description Default fn ( None class: Module -> None): function to be applied to each submodule None Returns: Type Description Module self View Source def apply ( self : T , fn : Callable [[ 'Module' ] , None ] ) -> T : r \" \"\" Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example:: >>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) \"\" \" for module in self . children () : module . apply ( fn ) fn ( self ) return self bfloat16 def bfloat16 ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to bfloat16 datatype. .. note:: This method modifies the module in-place. 
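To make the wtracker building blocks documented above concrete, here is a minimal usage sketch of MLPLayer, MlpBlock, and RMLP. It assumes the wtracker package is importable; all dimensions are made up for illustration.

```python
import torch
from wtracker.neural.mlp import MLPLayer, MlpBlock, RMLP

x = torch.randn(8, 16)  # 8 samples, 16 features each

# A single layer: Linear(16 -> 32) + BatchNorm1d + ReLU
layer = MLPLayer(in_dim=16, out_dim=32, nonlin="relu")
print(layer(x).shape)  # torch.Size([8, 32])

# A small MLP: 16 -> 64 -> 32, one non-linearity per entry in dims
block = MlpBlock(in_dim=16, dims=[64, 32], nonlins=["relu", "none"])
print(block(x).shape)  # torch.Size([8, 32])

# A residual MLP: each block maps 16 -> 16 so that x + block(x) is valid
rmlp = RMLP(
    block_in_dim=16,
    block_dims=[32, 16],
    block_nonlins=["relu", "none"],
    n_blocks=2,
    out_dim=4,
)
print(rmlp(x).shape)  # torch.Size([8, 4])
```

Note that RMLP's residual connection requires each block to map its input dimension back to itself (block_dims[-1] == block_in_dim), which the dimensions above satisfy.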
Returns: Type Description Module self View Source def bfloat16 ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``bfloat16`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . bfloat16 () if t . is_floating_point () else t ) buffers def buffers ( self , recurse : bool = True ) -> Iterator [ torch . Tensor ] Return an iterator over module buffers. Parameters: Name Type Description Default recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. None Yields: Type Description torch.Tensor module buffer View Source def buffers ( self , recurse : bool = True ) -> Iterator [ Tensor ] : r \"\"\"Return an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: torch.Tensor: module buffer Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for buf in model.buffers(): >>> print(type(buf), buf.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for _ , buf in self . named_buffers ( recurse = recurse ) : yield buf children def children ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over immediate children modules. Yields: Type Description Module a child module View Source def children ( self ) -> Iterator [ 'Module']: r \"\"\"Return an iterator over immediate children modules. Yields: Module: a child module \"\"\" for name , module in self . named_children (): yield module compile def compile ( self , * args , ** kwargs ) Compile this Module's forward using :func: torch.compile . This Module's __call__ method is compiled and all arguments are passed as-is to :func: torch.compile . See :func: torch.compile for details on the arguments for this function. View Source def compile ( self , * args , ** kwargs ) : \" \"\" Compile this Module's forward using :func:`torch.compile`. This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`. See :func:`torch.compile` for details on the arguments for this function. \"\" \" self . _compiled_call_impl = torch . compile ( self . _call_impl , * args , ** kwargs ) cpu def cpu ( self : ~ T ) -> ~ T Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def cpu ( self : T ) -> T : r \"\"\"Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cpu ()) cuda def cuda ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def cuda ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. 
So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cuda ( device )) double def double ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to double datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def double ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``double`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . double () if t . is_floating_point () else t ) eval def eval ( self : ~ T ) -> ~ T Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. This is equivalent with :meth: self.train(False) . See :ref: locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it. Returns: Type Description Module self View Source def eval ( self : T ) -> T : r \" \"\" Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent with :meth:`self.train(False) `. See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it. Returns: Module: self \"\" \" return self . train ( False ) extra_repr def extra_repr ( self ) -> str Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. View Source def extra_repr ( self ) -> str : r \"\"\"Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. \"\"\" return '' float def float ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to float datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def float ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``float`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . float () if t . is_floating_point () else t ) forward def forward ( self , x : torch . Tensor ) -> torch . Tensor Parameters: Name Type Description Default x None An input tensor, of shape (N, D) containing N samples with D features. None Returns: Type Description None An output tensor of shape (N, D_out) where D_out is the output dim. View Source def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . mlp_layer . forward ( x . reshape ( x . 
size ( 0 ), - 1 )) get_buffer def get_buffer ( self , target : str ) -> 'Tensor' Return the buffer given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.Tensor The buffer referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not a buffer View Source def get_buffer ( self , target : str ) -> \"Tensor\" : \" \"\" Return the buffer given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the buffer to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.Tensor: The buffer referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not a buffer \"\" \" module_path , _ , buffer_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , buffer_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + buffer_name + \"`\" ) buffer : torch . Tensor = getattr ( mod , buffer_name ) if buffer_name not in mod . _buffers : raise AttributeError ( \"`\" + buffer_name + \"` is not a buffer\" ) return buffer get_extra_state def get_extra_state ( self ) -> Any Return any extra state to include in the module's state_dict. Implement this and a corresponding :func: set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict() . Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: Type Description object Any extra state to store in the module's state_dict View Source def get_extra_state ( self ) -> Any : \" \"\" Return any extra state to include in the module's state_dict. Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`. Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: object: Any extra state to store in the module's state_dict \"\" \" raise RuntimeError ( \"Reached a code path in Module.get_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" ) get_parameter def get_parameter ( self , target : str ) -> 'Parameter' Return the parameter given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target .
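As a quick illustration of these accessors (get_submodule, get_parameter, get_buffer), consider a small module tree; the layout below is invented for the example.

```python
import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))

# Submodules are addressed by dot-separated paths; nn.Sequential
# names its children "0", "1", "2" by default.
bn = net.get_submodule("1")
assert isinstance(bn, nn.BatchNorm1d)

# Parameters and buffers hang off those same paths.
weight = net.get_parameter("0.weight")      # nn.Parameter, shape (8, 4)
running = net.get_buffer("1.running_mean")  # plain tensor, shape (8,)
print(weight.shape, running.shape)
```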
Parameters: Name Type Description Default target None The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Parameter The Parameter referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Parameter View Source def get_parameter ( self , target : str ) -> \"Parameter\" : \" \"\" Return the parameter given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the Parameter to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.nn.Parameter: The Parameter referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Parameter`` \"\" \" module_path , _ , param_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , param_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + param_name + \"`\" ) param : torch . nn . Parameter = getattr ( mod , param_name ) if not isinstance ( param , torch . nn . Parameter ) : raise AttributeError ( \"`\" + param_name + \"` is not an \" \"nn.Parameter\" ) return param get_submodule def get_submodule ( self , target : str ) -> 'Module' Return the submodule given by target if it exists, otherwise throw an error. For example, let's say you have an nn.Module A that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an nn.Module A . A has a nested submodule net_b , which itself has two submodules net_c and linear . net_c then has a submodule conv .) To check whether or not we have the linear submodule, we would call get_submodule(\"net_b.linear\") . To check whether we have the conv submodule, we would call get_submodule(\"net_b.net_c.conv\") . The runtime of get_submodule is bounded by the degree of module nesting in target . A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used. Parameters: Name Type Description Default target None The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Module The submodule referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Module View Source def get_submodule ( self , target : str ) -> \"Module\" : \" \"\" Return the submodule given by ``target`` if it exists, otherwise throw an error. For example, let's say you have an ``nn.Module`` ``A`` that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. 
``net_c`` then has a submodule ``conv``.) To check whether or not we have the ``linear`` submodule, we would call ``get_submodule(\" net_b . linear \")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule(\" net_b . net_c . conv \")``. The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used. Args: target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) Returns: torch.nn.Module: The submodule referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Module`` \"\" \" if target == \"\" : return self atoms : List [ str ] = target . split ( \".\" ) mod : torch . nn . Module = self for item in atoms : if not hasattr ( mod , item ) : raise AttributeError ( mod . _get_name () + \" has no \" \"attribute `\" + item + \"`\" ) mod = getattr ( mod , item ) if not isinstance ( mod , torch . nn . Module ) : raise AttributeError ( \"`\" + item + \"` is not \" \"an nn.Module\" ) return mod half def half ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to half datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def half ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``half`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . half () if t . is_floating_point () else t ) ipu def ipu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def ipu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . ipu ( device )) load_state_dict def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) Copy parameters and buffers from :attr: state_dict into this module and its descendants. If :attr: strict is True , then the keys of :attr: state_dict must exactly match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. .. warning:: If :attr: assign is True the optimizer must be created after the call to :attr: load_state_dict unless :func: ~torch.__future__.get_swap_module_params_on_conversion is True . 
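A minimal sketch of the strict and non-strict loading behavior described here; the module shapes are illustrative only.

```python
import torch
from torch import nn

src = nn.Linear(4, 2)
dst = nn.Sequential(nn.Linear(4, 2))

# Key names differ ("weight" vs "0.weight"), so a strict load would raise.
# With strict=False the mismatches are reported instead.
result = dst.load_state_dict(src.state_dict(), strict=False)
print(result.missing_keys)     # ['0.weight', '0.bias']
print(result.unexpected_keys)  # ['weight', 'bias']

# With matching keys, a strict load round-trips exactly.
dst2 = nn.Linear(4, 2)
dst2.load_state_dict(src.state_dict())
assert torch.equal(dst2.weight, src.weight)
```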
Parameters: Name Type Description Default state_dict dict a dict containing parameters and persistent buffers. None strict bool whether to strictly enforce that the keys in :attr: state_dict match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. Default: True None assign bool When False , the properties of the tensors in the current module are preserved while when True , the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of :class: ~torch.nn.Parameter s for which the value from the module is preserved. Default: False None Returns: Type Description None NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys unexpected_keys is a list of str containing the unexpected keys View Source def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) : r \"\"\"Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. \"\"\" if not isinstance ( state_dict , Mapping ) : raise TypeError ( f \"Expected state_dict to be dict-like, got {type(state_dict)}.\" ) missing_keys : List [ str ] = [] unexpected_keys : List [ str ] = [] error_msgs : List [ str ] = [] # copy state_dict so _ load_from_state_dict can modify it metadata = getattr ( state_dict , '_metadata' , None ) state_dict = OrderedDict ( state_dict ) if metadata is not None : # mypy isn't aware that \"_metadata\" exists in state_dict state_dict._metadata = metadata # type: ignore[attr-defined] def load(module, local_state_dict, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) if assign: local_metadata['assign_to_params_buffers'] = assign module._load_from_state_dict( local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name + ' . 
' child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} load(child, child_state_dict, child_prefix) # noqa: F821 # Note that the hook can modify missing_keys and unexpected_keys. incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) for hook in module._load_state_dict_post_hooks.values(): out = hook(module, incompatible_keys) assert out is None, ( \"Hooks registered with ``register_load_state_dict_post_hook`` are not\" \"expected to return new values, if incompatible_keys need to be modified,\" \"it should be done inplace.\" ) load(self, state_dict) del load if strict: if len(unexpected_keys) > 0: error_msgs.insert( 0, ' Unexpected key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in unexpected_keys))) if len(missing_keys) > 0: error_msgs.insert( 0, ' Missing key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in missing_keys))) if len(error_msgs) > 0: raise RuntimeError(' Error ( s ) in loading state_dict for {} : \\n\\t {} ' . format ( self . __ class__ . __ name__ , \"\\n\\t\" . join ( error_msgs ))) return _ IncompatibleKeys ( missing_keys , unexpected_keys ) modules def modules ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over all modules in the network. Yields: Type Description Module a module in the network View Source def modules ( self ) -> Iterator [ 'Module' ] : r \" \"\" Return an iterator over all modules in the network. Yields: Module: a module in the network Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True) \"\" \" for _ , module in self . named_modules () : yield module named_buffers def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . Tensor ]] Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Parameters: Name Type Description Default prefix str prefix to prepend to all buffer names. None recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. None remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True Yields: Type Description None (str, torch.Tensor): Tuple containing the name and buffer View Source def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Tensor ]]: r \"\"\"Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Args: prefix (str): prefix to prepend to all buffer names. recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. 
Yields: (str, torch.Tensor): Tuple containing the name and buffer Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size()) \"\"\" gen = self . _named_members ( lambda module : module . _buffers . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen named_children def named_children ( self ) -> Iterator [ Tuple [ str , ForwardRef ( 'Module' )]] Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: Type Description None (str, Module): Tuple containing a name and child module View Source def named_children ( self ) -> Iterator [ Tuple [ str , 'Module' ]]: r \"\"\"Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: (str, Module): Tuple containing a name and child module Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module) \"\"\" memo = set () for name , module in self . _modules . items (): if module is not None and module not in memo : memo . add ( module ) yield name , module named_modules def named_modules ( self , memo : Optional [ Set [ ForwardRef ( 'Module' )]] = None , prefix : str = '' , remove_duplicate : bool = True ) Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Parameters: Name Type Description Default memo None a memo to store the set of modules already added to the result None prefix None a prefix that will be added to the name of the module None remove_duplicate None whether to remove the duplicated module instances in the result or not None Yields: Type Description None (str, Module): Tuple of name and module View Source def named_modules ( self , memo : Optional [ Set [ 'Module' ]] = None , prefix : str = '' , remove_duplicate : bool = True ) : r \" \"\" Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: memo: a memo to store the set of modules already added to the result prefix: a prefix that will be added to the name of the module remove_duplicate: whether to remove the duplicated module instances in the result or not Yields: (str, Module): Tuple of name and module Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) \"\" \" if memo is None : memo = set () if self not in memo : if remove_duplicate : memo . add ( self ) yield prefix , self for name , module in self . _modules . items () : if module is None : continue submodule_prefix = prefix + ( '.' if prefix else '' ) + name yield from module . named_modules ( memo , submodule_prefix , remove_duplicate ) named_parameters def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . nn . parameter . Parameter ]] Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. 
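A common use of named_parameters is selective freezing; a short sketch (layer names follow nn.Sequential's default numbering):

```python
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Freeze everything except the final layer by matching on the yielded names.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("2.")

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['2.weight', '2.bias']
```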
Parameters: Name Type Description Default prefix str prefix to prepend to all parameter names. None recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. None Yields: Type Description None (str, Parameter): Tuple containing the name and parameter View Source def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Parameter ]]: r \"\"\"Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True. Yields: (str, Parameter): Tuple containing the name and parameter Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) \"\"\" gen = self . _named_members ( lambda module : module . _parameters . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen parameters def parameters ( self , recurse : bool = True ) -> Iterator [ torch . nn . parameter . Parameter ] Return an iterator over module parameters. This is typically passed to an optimizer. Parameters: Name Type Description Default recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None Yields: Type Description Parameter module parameter View Source def parameters ( self , recurse : bool = True ) -> Iterator [ Parameter ] : r \"\"\"Return an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Yields: Parameter: module parameter Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for param in model.parameters(): >>> print(type(param), param.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for name , param in self . named_parameters ( recurse = recurse ) : yield param register_backward_hook def register_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]] ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. This function is deprecated in favor of :meth: ~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_backward_hook ( self , hook: Callable [[' Module ', _grad_t , _grad_t ], Union [ None , _grad_t ]] ) -> RemovableHandle: r \"\"\"Register a backward hook on the module. This function is deprecated in favor of : meth: ` ~ torch . nn . Module . 
register_full_backward_hook ` and the behavior of this function will change in future versions . Returns: : class: `torch . utils . hooks . RemovableHandle ` : a handle that can be used to remove the added hook by calling ` `handle . remove () `` \"\"\" if self . _is_full_backward_hook is True: raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = False handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook return handle register_buffer def register_buffer ( self , name : str , tensor : Optional [ torch . Tensor ], persistent : bool = True ) -> None Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr: persistent to False . The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr: state_dict . Buffers can be accessed as attributes using given names. Parameters: Name Type Description Default name str name of the buffer. The buffer can be accessed from this module using the given name None tensor Tensor or None buffer to be registered. If None , then operations that run on buffers, such as :attr: cuda , are ignored. If None , the buffer is not included in the module's :attr: state_dict . None persistent bool whether the buffer is part of this module's :attr: state_dict . None View Source def register_buffer ( self , name : str , tensor : Optional [ Tensor ] , persistent : bool = True ) -> None : r \" \"\" Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`. Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> self.register_buffer('running_mean', torch.zeros(num_features)) \"\" \" if persistent is False and isinstance ( self , torch . jit . ScriptModule ) : raise RuntimeError ( \"ScriptModule does not support non-persistent buffers\" ) if '_buffers' not in self . __dict__ : raise AttributeError ( \"cannot assign buffer before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"buffer name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"buffer name can't contain \\\" .
\\\" \" ) elif name == '' : raise KeyError ( \"buffer name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _buffers : raise KeyError ( f \"attribute '{name}' already exists\" ) elif tensor is not None and not isinstance ( tensor , torch . Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self . _buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name ) register_forward_hook def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward hook on the module. The hook will be called every time after :func: forward has computed an output. If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func: forward is called. The hook should have the following signature:: hook ( module , args , output ) -> None or modified output If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook ( module , args , kwargs , output ) -> None or modified output Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward hooks on this :class: torch.nn.modules.Module . Note that global forward hooks registered with :func: register_module_forward_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ] , Any ] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ] , Any ] , Optional [ Any ]] , ] , * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. 
It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature:: hook(module, args, output) -> None or modified output If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook(module, args, kwargs, output) -> None or modified output Args: hook (Callable): The user defined hook to be registered. prepend (bool): If ``True``, the provided ``hook`` will be fired before all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward`` hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` always_call (bool): If ``True`` the ``hook`` will be run regardless of whether an exception is raised while calling the Module. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_hooks , extra_dict = [ self . _forward_hooks_with_kwargs , self . _forward_hooks_always_called ] , ) self . _forward_hooks [ handle . id ] = hook if with_kwargs : self . _forward_hooks_with_kwargs [ handle . id ] = True if always_call : self . _forward_hooks_always_called [ handle . id ] = True if prepend : self . _forward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_forward_pre_hook def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward pre-hook on the module. The hook will be called every time before :func: forward is invoked. If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook ( module , args ) -> None or modified input If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook ( module , args , kwargs ) -> None or a tuple of modified input and kwargs Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward_pre hooks on this :class: torch.nn.modules.Module . 
Note that global forward_pre hooks registered with :func: register_module_forward_pre_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_pre_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ]] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ]] , Optional [ Tuple [ Any , Dict [ str , Any ]]]] , ] , * , prepend : bool = False , with_kwargs : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_full_backward_hook def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. 
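Before turning to the backward hooks, a minimal sketch of the forward-hook mechanics just described; the shape-logging hook is purely illustrative.

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)
shapes = []

def log_shape(module, args, output):
    # Called after forward(); records the output shape.
    shapes.append(tuple(output.shape))

handle = layer.register_forward_hook(log_shape)
layer(torch.randn(3, 4))
handle.remove()  # hooks are detached via the returned handle
print(shapes)    # [(3, 2)]
```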
The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. 
Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . _is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_full_backward_pre_hook def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. 
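A minimal sketch of a full backward hook that captures the gradient of a module's output; the toy model and tensor shapes are made up for the example.

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)
grads = []

def grab_grad(module, grad_input, grad_output):
    # grad_output holds the gradients w.r.t. the module's outputs.
    grads.append(grad_output[0].detach().clone())

handle = layer.register_full_backward_hook(grab_grad)
x = torch.randn(3, 4, requires_grad=True)  # so grad_input is populated too
layer(x).sum().backward()
handle.remove()
print(grads[0].shape)  # torch.Size([3, 2])
```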
.. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_load_state_dict_post_hook def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . id ] = hook return handle register_module def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . 
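Because register_module simply forwards to add_module, the two calls are interchangeable:

```python
from torch import nn

m = nn.Module()
m.add_module("encoder", nn.Linear(4, 8))       # canonical spelling
m.register_module("decoder", nn.Linear(8, 4))  # alias, identical effect
print(list(dict(m.named_children())))          # ['encoder', 'decoder']
```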
register_module

def register_module(self, name: str, module: Optional['Module']) -> None

Alias for :func:`add_module`.

View Source
    def register_module(self, name: str, module: Optional['Module']) -> None:
        self.add_module(name, module)

register_parameter

def register_parameter(self, name: str, param: Optional[torch.nn.parameter.Parameter]) -> None

Add a parameter to the module. The parameter can be accessed as an attribute using the given name.

Parameters:
    name (str): name of the parameter. The parameter can be accessed from this module using the given name.
    param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored, and the parameter is **not** included in the module's :attr:`state_dict`.

View Source
    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        if '_parameters' not in self.__dict__:
            raise AttributeError("cannot assign parameter before Module.__init__() call")
        elif not isinstance(name, str):
            raise TypeError(f"parameter name should be a string. Got {torch.typename(name)}")
        elif '.' in name:
            raise KeyError("parameter name can't contain \".\"")
        elif name == '':
            raise KeyError("parameter name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError(f"attribute '{name}' already exists")
        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError(f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
                            "(torch.nn.Parameter or None required)")
        elif param.grad_fn:
            raise ValueError(
                f"Cannot assign non-leaf Tensor to parameter '{name}'. Model "
                f"parameters must be created explicitly. To express '{name}' "
                "as a function of another Tensor, compute the value in "
                "the forward() method.")
        else:
            for hook in _global_parameter_registration_hooks.values():
                output = hook(self, name, param)
                if output is not None:
                    param = output
            self._parameters[name] = param

register_state_dict_pre_hook

def register_state_dict_pre_hook(self, hook)

Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. These hooks will be called with the arguments ``self``, ``prefix``, and ``keep_vars`` before ``state_dict`` is called on ``self``, and can be used to perform pre-processing before the ``state_dict`` call is made.

View Source
    def register_state_dict_pre_hook(self, hook):
        handle = hooks.RemovableHandle(self._state_dict_pre_hooks)
        self._state_dict_pre_hooks[handle.id] = hook
        return handle

requires_grad_

def requires_grad_(self: T, requires_grad: bool = True) -> T

Change whether autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in place. It is helpful for freezing part of a module for finetuning, or for training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between ``.requires_grad_()`` and several similar mechanisms that may be confused with it.

Parameters:
    requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``.

Returns:
    Module: self

View Source
    def requires_grad_(self: T, requires_grad: bool = True) -> T:
        for p in self.parameters():
            p.requires_grad_(requires_grad)
        return self

set_extra_state

def set_extra_state(self, state: Any) -> None

Set extra state contained in the loaded ``state_dict``. This function is called from :func:`load_state_dict` to handle any extra state found within the ``state_dict``. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its ``state_dict``.

Parameters:
    state (dict): extra state from the ``state_dict``.

View Source
    def set_extra_state(self, state: Any) -> None:
        raise RuntimeError(
            "Reached a code path in Module.set_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
            "to report this bug.")

share_memory

def share_memory(self: T) -> T

See :meth:`torch.Tensor.share_memory_`.

View Source
    def share_memory(self: T) -> T:
        return self._apply(lambda t: t.share_memory_())
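To make the freezing use case described under ``requires_grad_`` concrete, the following sketch (the two-layer model is illustrative only) freezes a pretrained backbone while leaving a new head trainable:

    import torch.nn as nn

    model = nn.Sequential(
        nn.Linear(8, 16),   # pretend this is a pretrained backbone
        nn.ReLU(),
        nn.Linear(16, 2),   # new head to finetune
    )

    # Freeze everything in place, then re-enable gradients for the head only.
    model.requires_grad_(False)
    model[2].requires_grad_(True)

    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    print(trainable)  # ['2.weight', '2.bias']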
state_dict

def state_dict(self, *args, destination=None, prefix='', keep_vars=False)

Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included; keys are the corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included.

.. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers.

.. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars``, in that order. However, this is being deprecated and keyword arguments will be enforced in future releases.

.. warning:: Please avoid the use of the ``destination`` argument, as it is not designed for end users.

Parameters:
    destination (dict, optional): if provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``.
    prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in the state_dict. Default: ``''``.
    keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If set to ``True``, detaching will not be performed. Default: ``False``.

Returns:
    dict: a dictionary containing the whole state of the module

Example::

    >>> module.state_dict().keys()
    ['bias', 'weight']

View Source
    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        # TODO: Remove `args` and the parsing logic when BC allows.
        if len(args) > 0:
            if destination is None:
                destination = args[0]
            if len(args) > 1 and prefix == '':
                prefix = args[1]
            if len(args) > 2 and keep_vars is False:
                keep_vars = args[2]
            # DeprecationWarning is ignored by default
            warnings.warn(
                "Positional args are being deprecated, use kwargs instead. Refer to "
                "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
                " for details.")
        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()
        local_metadata = dict(version=self._version)
        if hasattr(destination, "_metadata"):
            destination._metadata[prefix[:-1]] = local_metadata
        for hook in self._state_dict_pre_hooks.values():
            hook(self, prefix, keep_vars)
        self._save_to_state_dict(destination, prefix, keep_vars)
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
        for hook in self._state_dict_hooks.values():
            hook_result = hook(self, destination, prefix, local_metadata)
            if hook_result is not None:
                destination = hook_result
        return destination
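A minimal round-trip sketch of the checkpointing workflow that ``state_dict`` supports (the file name is illustrative):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)

    # Save only the state, not the module object itself.
    torch.save(model.state_dict(), "checkpoint.pt")  # illustrative path

    # Later: rebuild the same architecture, then restore the state into it.
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load("checkpoint.pt"))
    print(list(restored.state_dict().keys()))  # ['weight', 'bias']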
to

def to(self, *args, **kwargs)

Move and/or cast the parameters and buffers. This can be called as

.. function:: to(device=None, dtype=None, non_blocking=False)
.. function:: to(dtype, non_blocking=False)
.. function:: to(tensor, non_blocking=False)
.. function:: to(memory_format=torch.channels_last)

Its signature is similar to :meth:`torch.Tensor.to`, but it only accepts floating point or complex :attr:`dtype`\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved to :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples.

.. note:: This method modifies the module in-place.

Parameters:
    device (:class:`torch.device`): the desired device of the parameters and buffers in this module.
    dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module.
    tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module.
    memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword-only argument).

Returns:
    Module: self

Examples::

    >>> linear = nn.Linear(2, 2)
    >>> linear.weight
    Parameter containing:
    tensor([[ 0.1913, -0.3420],
            [-0.5113, -0.2325]])
    >>> linear.to(torch.double)
    Linear(in_features=2, out_features=2, bias=True)
    >>> linear.weight
    Parameter containing:
    tensor([[ 0.1913, -0.3420],
            [-0.5113, -0.2325]], dtype=torch.float64)
    >>> gpu1 = torch.device("cuda:1")
    >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
    Linear(in_features=2, out_features=2, bias=True)
    >>> linear.weight
    Parameter containing:
    tensor([[ 0.1914, -0.3420],
            [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
    >>> cpu = torch.device("cpu")
    >>> linear.to(cpu)
    Linear(in_features=2, out_features=2, bias=True)
    >>> linear.weight
    Parameter containing:
    tensor([[ 0.1914, -0.3420],
            [-0.5112, -0.2324]], dtype=torch.float16)
    >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
    >>> linear.weight
    Parameter containing:
    tensor([[ 0.3741+0.j,  0.2382+0.j],
            [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
    >>> linear(torch.ones(3, 2, dtype=torch.cdouble))
    tensor([[0.6122+0.j, 0.1150+0.j],
            [0.6122+0.j, 0.1150+0.j],
            [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)

View Source
    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if dtype is not None:
            if not (dtype.is_floating_point or dtype.is_complex):
                raise TypeError('nn.Module.to only accepts floating point or complex '
                                f'dtypes, but got desired dtype={dtype}')
            if dtype.is_complex:
                warnings.warn(
                    "Complex modules are a new feature under active development whose design may change, "
                    "and some modules might not work as expected when using complex tensors as parameters or buffers. "
                    "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
                    "if a complex module does not work as expected.")

        def convert(t):
            try:
                if convert_to_format is not None and t.dim() in (4, 5):
                    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
                                non_blocking, memory_format=convert_to_format)
                return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
            except NotImplementedError as e:
                if str(e) == "Cannot copy out of meta tensor; no data!":
                    raise NotImplementedError(
                        f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
                        f"when moving module from meta to a different device.") from None
                else:
                    raise

        return self._apply(convert)

to_empty

def to_empty(self: T, *, device: Union[int, str, torch.device, None], recurse: bool = True) -> T

Move the parameters and buffers to the specified device without copying storage.

Parameters:
    device (:class:`torch.device`): the desired device of the parameters and buffers in this module.
    recurse (bool): whether parameters and buffers of submodules should be recursively moved to the specified device.

Returns:
    Module: self

View Source
    def to_empty(self: T, *, device: Optional[DeviceLikeType], recurse: bool = True) -> T:
        return self._apply(lambda t: torch.empty_like(t, device=device), recurse=recurse)

train

def train(self: T, mode: bool = True) -> T

Set the module in training mode. This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc.

Parameters:
    mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``.

Returns:
    Module: self

View Source
    def train(self: T, mode: bool = True) -> T:
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

type

def type(self: T, dst_type: Union[torch.dtype, str]) -> T

Casts all parameters and buffers to :attr:`dst_type`.

.. note:: This method modifies the module in-place.

Parameters:
    dst_type (type or string): the desired type.

Returns:
    Module: self

View Source
    def type(self: T, dst_type: Union[dtype, str]) -> T:
        return self._apply(lambda t: t.type(dst_type))

xpu

def xpu(self: T, device: Union[int, torch.device, None] = None) -> T

Move all model parameters and buffers to the XPU. This also makes the associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on XPU while being optimized.

.. note:: This method modifies the module in-place.

Parameters:
    device (int, optional): if specified, all parameters will be copied to that device.

Returns:
    Module: self

View Source
    def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:
        return self._apply(lambda t: t.xpu(device))
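Tying the device- and dtype-movement methods above together, a short sketch that picks a device at runtime and moves/casts a module with ``.to()``:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)

    # Fall back to CPU when no GPU is present.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Cast floating-point parameters to float64 and move them to `device`;
    # integral buffers (if any) would be moved but keep their dtype.
    model.to(device, dtype=torch.float64)
    print(model.weight.dtype, model.weight.device)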
zero_grad

def zero_grad(self, set_to_none: bool = True) -> None

Reset gradients of all model parameters. See the similar function under :class:`torch.optim.Optimizer` for more context.

Parameters:
    set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details.

View Source
    def zero_grad(self, set_to_none: bool = True) -> None:
        if getattr(self, '_is_replica', False):
            warnings.warn(
                "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
                "The parameters are copied (in a differentiable manner) from the original module. "
                "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
                "If you need gradients in your forward method, consider using autograd.grad instead.")
        for p in self.parameters():
            if p.grad is not None:
                if set_to_none:
                    p.grad = None
                else:
                    if p.grad.grad_fn is not None:
                        p.grad.detach_()
                    else:
                        p.grad.requires_grad_(False)
                    p.grad.zero_()

MlpBlock

class MlpBlock(in_dim: int, dims: Sequence[int], nonlins: Sequence[Union[str, torch.nn.modules.module.Module]], batch_norm: bool = True)

A general-purpose MLP.

Attributes:
    in_dim: input dimension.
    dims: hidden dimensions, including the output dimension.
    nonlins: non-linearities to apply after each of the hidden dimensions. Each entry can be either a string key into the ACTIVATIONS dict, or an instance of nn.Module (e.g. an instance of nn.ReLU()). Length should match ``dims``.

View Source
    class MlpBlock(nn.Module):
        def __init__(
            self,
            in_dim: int,
            dims: Sequence[int],
            nonlins: Sequence[Union[str, nn.Module]],
            batch_norm: bool = True,
        ):
            assert len(nonlins) == len(dims)
            self.in_dim = in_dim
            self.out_dim = dims[-1]
            self.dims = dims
            self.nonlins = nonlins
            super().__init__()
            layers = []
            for i, out_dim in enumerate(self.dims):
                layers.append(MLPLayer(in_dim, out_dim, nonlins[i], batch_norm))
                in_dim = out_dim
            self.sequence = nn.Sequential(*layers)

        def _make_activation(self, act: Union[str, nn.Module]) -> nn.Module:
            if isinstance(act, str):
                return ACTIVATIONS[act](**ACTIVATION_DEFAULT_KWARGS[act])
            return act

        def forward(self, x: Tensor) -> Tensor:
            # x: an input tensor of shape (N, D) with N samples of D features;
            # returns an output tensor of shape (N, D_out).
            return self.sequence.forward(x.reshape(x.size(0), -1))
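A brief usage sketch for ``MlpBlock``. The import path below is a guess at where the class lives in this library, and passing ``nn.Module`` instances (rather than ACTIVATIONS keys) avoids assuming which string keys the registry defines:

    import torch
    import torch.nn as nn
    from wtracker.neural.mlp import MlpBlock  # hypothetical import path

    # Two hidden layers of width 64 and an 8-dim output; nn.Identity()
    # leaves the final layer linear. len(nonlins) must equal len(dims).
    block = MlpBlock(in_dim=10, dims=[64, 64, 8],
                     nonlins=[nn.ReLU(), nn.ReLU(), nn.Identity()])

    x = torch.randn(32, 10)   # batch of 32 samples with 10 features
    y = block(x)              # forward flattens to (N, D) and applies the MLP
    print(y.shape)            # torch.Size([32, 8])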
Ancestors (in MRO)
    torch.nn.modules.module.Module

Class variables
    T_destination
    call_super_init
    dump_patches

Methods

add_module

def add_module(self, name: str, module: Optional['Module']) -> None

Add a child module to the current module. The module can be accessed as an attribute using the given name.

Parameters:
    name (str): name of the child module. The child module can be accessed from this module using the given name.
    module (Module): child module to be added to the module.

View Source
    def add_module(self, name: str, module: Optional['Module']) -> None:
        if not isinstance(module, Module) and module is not None:
            raise TypeError(f"{torch.typename(module)} is not a Module subclass")
        elif not isinstance(name, str):
            raise TypeError(f"module name should be a string. Got {torch.typename(name)}")
        elif hasattr(self, name) and name not in self._modules:
            raise KeyError(f"attribute '{name}' already exists")
        elif '.' in name:
            raise KeyError(f"module name can't contain \".\", got: {name}")
        elif name == '':
            raise KeyError("module name can't be empty string \"\"")
        for hook in _global_module_registration_hooks.values():
            output = hook(self, name, module)
            if output is not None:
                module = output
        self._modules[name] = module

apply

def apply(self: T, fn: Callable[['Module'], None]) -> T

Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`).

Parameters:
    fn (:class:`Module` -> None): function to be applied to each submodule.

Returns:
    Module: self

Example::

    >>> @torch.no_grad()
    >>> def init_weights(m):
    >>>     print(m)
    >>>     if type(m) == nn.Linear:
    >>>         m.weight.fill_(1.0)
    >>>         print(m.weight)
    >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    >>> net.apply(init_weights)

View Source
    def apply(self: T, fn: Callable[['Module'], None]) -> T:
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self

bfloat16

def bfloat16(self: T) -> T

Casts all floating point parameters and buffers to ``bfloat16`` datatype.

.. note:: This method modifies the module in-place.

Returns:
    Module: self

View Source
    def bfloat16(self: T) -> T:
        return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t)

buffers

def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]

Return an iterator over module buffers.

Parameters:
    recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.

Yields:
    torch.Tensor: module buffer

View Source
    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
        for _, buf in self.named_buffers(recurse=recurse):
            yield buf

children

def children(self) -> Iterator['Module']

Return an iterator over immediate children modules.

Yields:
    Module: a child module

View Source
    def children(self) -> Iterator['Module']:
        for name, module in self.named_children():
            yield module

compile

def compile(self, *args, **kwargs)

Compile this Module's forward using :func:`torch.compile`. This Module's ``__call__`` method is compiled and all arguments are passed as-is to :func:`torch.compile`. See :func:`torch.compile` for details on the arguments for this function.

View Source
    def compile(self, *args, **kwargs):
        self._compiled_call_impl = torch.compile(self._call_impl, *args, **kwargs)

cpu

def cpu(self: T) -> T

Move all model parameters and buffers to the CPU.

.. note:: This method modifies the module in-place.

Returns:
    Module: self

View Source
    def cpu(self: T) -> T:
        return self._apply(lambda t: t.cpu())

cuda

def cuda(self: T, device: Union[int, torch.device, None] = None) -> T

Move all model parameters and buffers to the GPU. This also makes the associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on GPU while being optimized.

.. note:: This method modifies the module in-place.

Parameters:
    device (int, optional): if specified, all parameters will be copied to that device.

Returns:
    Module: self

View Source
    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
        return self._apply(lambda t: t.cuda(device))

double

def double(self: T) -> T

Casts all floating point parameters and buffers to ``double`` datatype.

.. note:: This method modifies the module in-place.

Returns:
    Module: self

View Source
    def double(self: T) -> T:
        return self._apply(lambda t: t.double() if t.is_floating_point() else t)

eval

def eval(self: T) -> T

Set the module in evaluation mode. This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent to :meth:`self.train(False)`. See :ref:`locally-disable-grad-doc` for a comparison between ``.eval()`` and several similar mechanisms that may be confused with it.

Returns:
    Module: self

View Source
    def eval(self: T) -> T:
        return self.train(False)

extra_repr

def extra_repr(self) -> str

Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

View Source
    def extra_repr(self) -> str:
        return ''

float

def float(self: T) -> T

Casts all floating point parameters and buffers to ``float`` datatype.

.. note:: This method modifies the module in-place.

Returns:
    Module: self

View Source
    def float(self: T) -> T:
        return self._apply(lambda t: t.float() if t.is_floating_point() else t)

forward

def forward(self, x: torch.Tensor) -> torch.Tensor

Parameters:
    x: an input tensor of shape (N, D) containing N samples with D features.

Returns:
    An output tensor of shape (N, D_out) where D_out is the output dim.

View Source
    def forward(self, x: Tensor) -> Tensor:
        return self.sequence.forward(x.reshape(x.size(0), -1))
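A small sketch of the ``train``/``eval`` switch described above, using Dropout (which behaves differently in the two modes) as the illustrative module:

    import torch
    import torch.nn as nn

    net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
    x = torch.ones(1, 4)

    net.train()      # dropout is active: outputs vary from run to run
    print(net(x))

    net.eval()       # dropout becomes a no-op: outputs are deterministic
    print(net(x))
    print(net(x))    # identical to the previous line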
get_buffer

def get_buffer(self, target: str) -> 'Tensor'

Return the buffer given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``.

Parameters:
    target: the fully-qualified string name of the buffer to look for. (See ``get_submodule`` for how to specify a fully-qualified string.)

Returns:
    torch.Tensor: the buffer referenced by ``target``

Raises:
    AttributeError: if the target string references an invalid path or resolves to something that is not a buffer.

View Source
    def get_buffer(self, target: str) -> "Tensor":
        module_path, _, buffer_name = target.rpartition(".")
        mod: torch.nn.Module = self.get_submodule(module_path)
        if not hasattr(mod, buffer_name):
            raise AttributeError(mod._get_name() + " has no attribute `" + buffer_name + "`")
        buffer: torch.Tensor = getattr(mod, buffer_name)
        if buffer_name not in mod._buffers:
            raise AttributeError("`" + buffer_name + "` is not a buffer")
        return buffer

get_extra_state

def get_extra_state(self) -> Any

Return any extra state to include in the module's state_dict. Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's ``state_dict()``. Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns:
    object: any extra state to store in the module's state_dict

View Source
    def get_extra_state(self) -> Any:
        raise RuntimeError(
            "Reached a code path in Module.get_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
            "to report this bug.")

get_parameter

def get_parameter(self, target: str) -> 'Parameter'

Return the parameter given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``.

Parameters:
    target: the fully-qualified string name of the Parameter to look for. (See ``get_submodule`` for how to specify a fully-qualified string.)

Returns:
    torch.nn.Parameter: the Parameter referenced by ``target``

Raises:
    AttributeError: if the target string references an invalid path or resolves to something that is not an ``nn.Parameter``.

View Source
    def get_parameter(self, target: str) -> "Parameter":
        module_path, _, param_name = target.rpartition(".")
        mod: torch.nn.Module = self.get_submodule(module_path)
        if not hasattr(mod, param_name):
            raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")
        param: torch.nn.Parameter = getattr(mod, param_name)
        if not isinstance(param, torch.nn.Parameter):
            raise AttributeError("`" + param_name + "` is not an nn.Parameter")
        return param

get_submodule

def get_submodule(self, target: str) -> 'Module'

Return the submodule given by ``target`` if it exists, otherwise throw an error. For example, let's say you have an ``nn.Module`` ``A`` that looks like this:

.. code-block:: text

    A(
        (net_b): Module(
            (net_c): Module(
                (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
            )
            (linear): Linear(in_features=100, out_features=200, bias=True)
        )
    )

(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.) To check whether or not we have the ``linear`` submodule, we would call ``get_submodule("net_b.linear")``; to check whether we have the ``conv`` submodule, we would call ``get_submodule("net_b.net_c.conv")``.

The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used.

Parameters:
    target: the fully-qualified string name of the submodule to look for. (See the above example for how to specify a fully-qualified string.)

Returns:
    torch.nn.Module: the submodule referenced by ``target``

Raises:
    AttributeError: if the target string references an invalid path or resolves to something that is not an ``nn.Module``.

View Source
    def get_submodule(self, target: str) -> "Module":
        if target == "":
            return self
        atoms: List[str] = target.split(".")
        mod: torch.nn.Module = self
        for item in atoms:
            if not hasattr(mod, item):
                raise AttributeError(mod._get_name() + " has no attribute `" + item + "`")
            mod = getattr(mod, item)
            if not isinstance(mod, torch.nn.Module):
                raise AttributeError("`" + item + "` is not an nn.Module")
        return mod
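To illustrate the fully-qualified ``target`` strings these accessors expect, a short sketch (the module layout is illustrative only):

    import torch.nn as nn

    net = nn.Sequential()
    net.add_module("backbone", nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))
    net.add_module("head", nn.Linear(8, 2))

    # Dotted paths mirror the attribute hierarchy.
    conv = net.get_submodule("backbone.0")
    bias = net.get_parameter("head.bias")
    print(type(conv).__name__, bias.shape)  # Conv2d torch.Size([2])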
half

def half(self: T) -> T

Casts all floating point parameters and buffers to ``half`` datatype.

.. note:: This method modifies the module in-place.

Returns:
    Module: self

View Source
    def half(self: T) -> T:
        return self._apply(lambda t: t.half() if t.is_floating_point() else t)

ipu

def ipu(self: T, device: Union[int, torch.device, None] = None) -> T

Move all model parameters and buffers to the IPU. This also makes the associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on IPU while being optimized.

.. note:: This method modifies the module in-place.

Parameters:
    device (int, optional): if specified, all parameters will be copied to that device.

Returns:
    Module: self

View Source
    def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:
        return self._apply(lambda t: t.ipu(device))

load_state_dict

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)

Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function.

.. warning:: If :attr:`assign` is ``True``, the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.

Parameters:
    state_dict (dict): a dict containing parameters and persistent buffers.
    strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True``.
    assign (bool, optional): when ``False``, the properties of the tensors in the current module are preserved, while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter` s, for which the value from the module is preserved. Default: ``False``.

Returns:
    ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: ``missing_keys`` is a list of str containing the missing keys, and ``unexpected_keys`` is a list of str containing the unexpected keys.

Note: if a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.

View Source
    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
        if not isinstance(state_dict, Mapping):
            raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")
        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []
        # Copy state_dict so _load_from_state_dict can modify it.
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            # mypy isn't aware that "_metadata" exists in state_dict
            state_dict._metadata = metadata  # type: ignore[attr-defined]

        def load(module, local_state_dict, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            if assign:
                local_metadata['assign_to_params_buffers'] = assign
            module._load_from_state_dict(
                local_state_dict, prefix, local_metadata, True,
                missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    child_prefix = prefix + name + '.'
                    child_state_dict = {k: v for k, v in local_state_dict.items()
                                        if k.startswith(child_prefix)}
                    load(child, child_state_dict, child_prefix)  # noqa: F821
            # Note that the hook can modify missing_keys and unexpected_keys.
            incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
            for hook in module._load_state_dict_post_hooks.values():
                out = hook(module, incompatible_keys)
                assert out is None, (
                    "Hooks registered with ``register_load_state_dict_post_hook`` are not"
                    "expected to return new values, if incompatible_keys need to be modified,"
                    "it should be done inplace.")

        load(self, state_dict)
        del load
        if strict:
            if len(unexpected_keys) > 0:
                error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join(f'"{k}"' for k in unexpected_keys)))
            if len(missing_keys) > 0:
                error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join(f'"{k}"' for k in missing_keys)))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                self.__class__.__name__, "\n\t".join(error_msgs)))
        return _IncompatibleKeys(missing_keys, unexpected_keys)
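A small sketch of the non-strict loading path described above: with ``strict=False``, mismatched keys are reported in the returned ``NamedTuple`` instead of raising (the key mismatch is deliberately constructed here):

    import torch.nn as nn

    src = nn.Linear(4, 2)
    dst = nn.Sequential(nn.Linear(4, 2))  # dst keys are prefixed with '0.'

    result = dst.load_state_dict(src.state_dict(), strict=False)
    print(result.missing_keys)     # ['0.weight', '0.bias']
    print(result.unexpected_keys)  # ['weight', 'bias']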
modules

def modules(self) -> Iterator['Module']

Return an iterator over all modules in the network.

Yields:
    Module: a module in the network

Note: duplicate modules are returned only once. In the following example, ``l`` will be returned only once.

Example::

    >>> l = nn.Linear(2, 2)
    >>> net = nn.Sequential(l, l)
    >>> for idx, m in enumerate(net.modules()):
    ...     print(idx, '->', m)
    0 -> Sequential(
        (0): Linear(in_features=2, out_features=2, bias=True)
        (1): Linear(in_features=2, out_features=2, bias=True)
    )
    1 -> Linear(in_features=2, out_features=2, bias=True)

View Source
    def modules(self) -> Iterator['Module']:
        for _, module in self.named_modules():
            yield module

named_buffers

def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:
    prefix (str): prefix to prepend to all buffer names.
    recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.
    remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.

Yields:
    (str, torch.Tensor): tuple containing the name and buffer

View Source
    def named_buffers(self, prefix: str = '', recurse: bool = True,
                      remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
        gen = self._named_members(
            lambda module: module._buffers.items(),
            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
        yield from gen

named_children

def named_children(self) -> Iterator[Tuple[str, 'Module']]

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields:
    (str, Module): tuple containing a name and child module

View Source
    def named_children(self) -> Iterator[Tuple[str, 'Module']]:
        memo = set()
        for name, module in self._modules.items():
            if module is not None and module not in memo:
                memo.add(module)
                yield name, module

named_modules

def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True)

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Duplicate modules are returned only once.

Parameters:
    memo: a memo to store the set of modules already added to the result.
    prefix: a prefix that will be added to the name of the module.
    remove_duplicate: whether to remove the duplicated module instances in the result or not.

Yields:
    (str, Module): tuple of name and module

View Source
    def named_modules(self, memo: Optional[Set['Module']] = None,
                      prefix: str = '', remove_duplicate: bool = True):
        if memo is None:
            memo = set()
        if self not in memo:
            if remove_duplicate:
                memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + name
                yield from module.named_modules(memo, submodule_prefix, remove_duplicate)
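To show how the dotted names produced by ``named_modules`` mirror the nesting of the network (the layout below is illustrative):

    import torch.nn as nn

    net = nn.Sequential(nn.Linear(2, 2),
                        nn.Sequential(nn.ReLU(), nn.Linear(2, 1)))

    for name, module in net.named_modules():
        print(repr(name), '->', type(module).__name__)
    # ''    -> Sequential
    # '0'   -> Linear
    # '1'   -> Sequential
    # '1.0' -> ReLU
    # '1.1' -> Linear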
None remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. None Yields: Type Description None (str, Parameter): Tuple containing the name and parameter View Source def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Parameter ]]: r \"\"\"Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True. Yields: (str, Parameter): Tuple containing the name and parameter Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) \"\"\" gen = self . _named_members ( lambda module : module . _parameters . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen parameters def parameters ( self , recurse : bool = True ) -> Iterator [ torch . nn . parameter . Parameter ] Return an iterator over module parameters. This is typically passed to an optimizer. Parameters: Name Type Description Default recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None Yields: Type Description Parameter module parameter View Source def parameters ( self , recurse : bool = True ) -> Iterator [ Parameter ] : r \"\"\"Return an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Yields: Parameter: module parameter Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for param in model.parameters(): >>> print(type(param), param.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for name , param in self . named_parameters ( recurse = recurse ) : yield param register_backward_hook def register_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]] ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. This function is deprecated in favor of :meth: ~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_backward_hook ( self , hook: Callable [[' Module ', _grad_t , _grad_t ], Union [ None , _grad_t ]] ) -> RemovableHandle: r \"\"\"Register a backward hook on the module. This function is deprecated in favor of : meth: ` ~ torch . nn . Module . register_full_backward_hook ` and the behavior of this function will change in future versions . Returns: : class: `torch . utils . hooks . RemovableHandle ` : a handle that can be used to remove the added hook by calling ` `handle . remove () `` \"\"\" if self . 
_is_full_backward_hook is True: raise RuntimeError( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self._is_full_backward_hook = False handle = hooks.RemovableHandle(self._backward_hooks) self._backward_hooks[handle.id] = hook return handle register_buffer def register_buffer ( self , name : str , tensor : Optional [ torch . Tensor ], persistent : bool = True ) -> None Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr: persistent to False . The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr: state_dict . Buffers can be accessed as attributes using given names. Parameters: Name Type Description Default name str name of the buffer. The buffer can be accessed from this module using the given name None tensor Tensor or None buffer to be registered. If None , then operations that run on buffers, such as :attr: cuda , are ignored. If None , the buffer is not included in the module's :attr: state_dict . None persistent bool whether the buffer is part of this module's :attr: state_dict . None View Source def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None: r\"\"\"Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`. Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> self.register_buffer('running_mean', torch.zeros(num_features)) \"\"\" if persistent is False and isinstance(self, torch.jit.ScriptModule): raise RuntimeError(\"ScriptModule does not support non-persistent buffers\") if '_buffers' not in self.__dict__: raise AttributeError(\"cannot assign buffer before Module.__init__() call\") elif not isinstance(name, str): raise TypeError(f\"buffer name should be a string. Got {torch.typename(name)}\") elif '.' in name: raise KeyError(\"buffer name can't contain \\\".\\\"\") elif name == '': raise KeyError(\"buffer name can't be empty string \\\"\\\"\") elif hasattr(self, name) and name not in self._buffers: raise KeyError(f\"attribute '{name}' already exists\") elif tensor is not None and not isinstance(tensor, torch.
Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self . _buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name ) register_forward_hook def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward hook on the module. The hook will be called every time after :func: forward has computed an output. If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func: forward is called. The hook should have the following signature:: hook ( module , args , output ) -> None or modified output If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook ( module , args , kwargs , output ) -> None or modified output Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward hooks on this :class: torch.nn.modules.Module . Note that global forward hooks registered with :func: register_module_forward_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ] , Any ] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ] , Any ] , Optional [ Any ]] , ] , * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. 
The hook should have the following signature:: hook(module, args, output) -> None or modified output If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook(module, args, kwargs, output) -> None or modified output Args: hook (Callable): The user defined hook to be registered. prepend (bool): If ``True``, the provided ``hook`` will be fired before all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward`` hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` always_call (bool): If ``True`` the ``hook`` will be run regardless of whether an exception is raised while calling the Module. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_hooks , extra_dict = [ self . _forward_hooks_with_kwargs , self . _forward_hooks_always_called ] , ) self . _forward_hooks [ handle . id ] = hook if with_kwargs : self . _forward_hooks_with_kwargs [ handle . id ] = True if always_call : self . _forward_hooks_always_called [ handle . id ] = True if prepend : self . _forward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_forward_pre_hook def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward pre-hook on the module. The hook will be called every time before :func: forward is invoked. If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook ( module , args ) -> None or modified input If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook ( module , args , kwargs ) -> None or a tuple of modified input and kwargs Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward_pre hooks on this :class: torch.nn.modules.Module . Note that global forward_pre hooks registered with :func: register_module_forward_pre_hook will fire before all hooks registered by this method. 
Default: False None with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_pre_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ]] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ]] , Optional [ Tuple [ Any , Dict [ str , Any ]]]] , ] , * , prepend : bool = False , with_kwargs : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_full_backward_hook def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. 
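As a concrete illustration (a hypothetical toy model and hook, not part of this library; standard PyTorch is assumed), a full backward hook that logs gradient norms might look like this:

import torch
import torch.nn as nn

def log_grad_norms(module, grad_input, grad_output):
    # Runs once the gradients w.r.t. this module's outputs are available.
    norms = [g.norm().item() for g in grad_output if g is not None]
    print(f\"{module.__class__.__name__} grad_output norms: {norms}\")
    # Returning None leaves grad_input unchanged.

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
handle = model[0].register_full_backward_hook(log_grad_norms)
model(torch.randn(2, 4)).sum().backward()  # triggers the hook
handle.remove()  # detach the hook once it is no longer needed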
The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . 
_is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_full_backward_pre_hook def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. 
prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_load_state_dict_post_hook def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . id ] = hook return handle register_module def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . View Source def register_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \" \"\" Alias for :func:`add_module`. \"\" \" self . 
add_module ( name , module ) register_parameter def register_parameter ( self , name : str , param : Optional [ torch . nn . parameter . Parameter ] ) -> None Add a parameter to the module. The parameter can be accessed as an attribute using given name. Parameters: Name Type Description Default name str name of the parameter. The parameter can be accessed from this module using the given name None param Parameter or None parameter to be added to the module. If None , then operations that run on parameters, such as :attr: cuda , are ignored. If None , the parameter is not included in the module's :attr: state_dict . None View Source def register_parameter ( self , name : str , param : Optional [ Parameter ] ) -> None : r \" \"\" Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. \"\" \" if '_parameters' not in self . __dict__ : raise AttributeError ( \"cannot assign parameter before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"parameter name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"parameter name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"parameter name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _parameters : raise KeyError ( f \"attribute '{name}' already exists\" ) if param is None : self . _parameters [ name ] = None elif not isinstance ( param , Parameter ) : raise TypeError ( f \"cannot assign '{torch.typename(param)}' object to parameter '{name}' \" \"(torch.nn.Parameter or None required)\" ) elif param . grad_fn : raise ValueError ( f \"Cannot assign non-leaf Tensor to parameter '{name}'. Model \" f \"parameters must be created explicitly. To express '{name}' \" \"as a function of another Tensor, compute the value in \" \"the forward() method.\" ) else : for hook in _global_parameter_registration_hooks . values () : output = hook ( self , name , param ) if output is not None : param = output self . _parameters [ name ] = param register_state_dict_pre_hook def register_state_dict_pre_hook ( self , hook ) Register a pre-hook for the :meth: ~torch.nn.Module.state_dict method. These hooks will be called with arguments: self , prefix , and keep_vars before calling state_dict on self . The registered hooks can be used to perform pre-processing before the state_dict call is made. View Source def register_state_dict_pre_hook ( self , hook ) : r \" \"\" Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. \"\" \" handle = hooks . RemovableHandle ( self . _state_dict_pre_hooks ) self . _state_dict_pre_hooks [ handle . id ] = hook return handle requires_grad_ def requires_grad_ ( self : ~ T , requires_grad : bool = True ) -> ~ T Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr: requires_grad attributes in-place. 
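For instance (a minimal sketch with a made-up two-layer model, assuming standard PyTorch), flipping the flag in-place on one submodule excludes its parameters from gradient recording:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
model[0].requires_grad_(False)  # freeze the first layer in-place

# Hand only the still-trainable parameters to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=1e-2)

model(torch.randn(2, 4)).sum().backward()
print(model[0].weight.grad)  # None: the frozen layer accumulated no gradient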
This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref: locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it. Parameters: Name Type Description Default requires_grad bool whether autograd should record operations on parameters in this module. Default: True . None Returns: Type Description Module self View Source def requires_grad_ ( self : T , requires_grad : bool = True ) -> T : r \" \"\" Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it. Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self \"\" \" for p in self . parameters () : p . requires_grad_ ( requires_grad ) return self set_extra_state def set_extra_state ( self , state : Any ) -> None Set extra state contained in the loaded state_dict . This function is called from :func: load_state_dict to handle any extra state found within the state_dict . Implement this function and a corresponding View Source def set _extra_state ( self , state : Any ) -> None : \" \"\" Set extra state contained in the loaded `state_dict`. This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`. Args: state (dict): Extra state from the `state_dict` \"\" \" raise RuntimeError ( \"Reached a code path in Module.set_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" ) share_memory def share_memory ( self : ~ T ) -> ~ T See :meth: torch.Tensor.share_memory_ . View Source def share_memory ( self : T ) -> T : r \"\"\"See :meth:`torch.Tensor.share_memory_`.\"\"\" return self . _apply ( lambda t : t . share_memory_ ()) state_dict def state_dict ( self , * args , destination = None , prefix = '' , keep_vars = False ) Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently state_dict() also accepts positional arguments for destination , prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument destination as it is not designed for end-users. Parameters: Name Type Description Default destination dict If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None . None prefix str a prefix added to parameter and buffer names to compose the keys in state_dict. Default: '' . 
None keep_vars bool by default the :class: ~torch.Tensor s returned in the state dict are detached from autograd. If it's set to True , detaching will not be performed. Default: False . None Returns: Type Description dict a dictionary containing a whole state of the module View Source def state_dict ( self , * args , destination = None , prefix='' , keep_vars = False ) : r \"\"\"Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars`` in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument ``destination`` as it is not designed for end-users. Args: destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``. prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. Returns: dict: a dictionary containing a whole state of the module Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> module.state_dict().keys() ['bias', 'weight'] \"\"\" # TODO : Remove ` args ` and the parsing logic when BC allows . if len ( args ) > 0 : if destination is None : destination = args [ 0 ] if len ( args ) > 1 and prefix == '' : prefix = args [ 1 ] if len ( args ) > 2 and keep_vars is False : keep_vars = args [ 2 ] # DeprecationWarning is ignored by default warnings . warn ( \"Positional args are being deprecated, use kwargs instead. Refer to \" \"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict\" \" for details.\" ) if destination is None : destination = OrderedDict () destination . _ metadata = OrderedDict () local_metadata = dict ( version = self . _ version ) if hasattr ( destination , \"_metadata\" ) : destination . _ metadata [ prefix [ :- 1 ]] = local_metadata for hook in self . _ state_dict_pre_hooks . val ues () : hook ( self , prefix , keep_vars ) self . _ save_to_state_dict ( destination , prefix , keep_vars ) for name , module in self . _ modules . items () : if module is not None : module . state_dict ( destination = destination , prefix = prefix + name + '.' , keep_vars = keep_vars ) for hook in self . _ state_dict_hooks . val ues () : hook_result = hook ( self , destination , prefix , local_metadata ) if hook_result is not None : destination = hook_result return destination to def to ( self , * args , ** kwargs ) Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth: torch.Tensor.to , but only accepts floating point or complex :attr: dtype \\ s. 
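A short usage sketch (plain PyTorch with a hypothetical Linear module; the CUDA branch is guarded so the snippet also runs on CPU-only machines):

import torch
import torch.nn as nn

model = nn.Linear(2, 2)
model.to(torch.float64)    # casts floating-point parameters and buffers in-place
print(model.weight.dtype)  # torch.float64

if torch.cuda.is_available():
    # Move to the first GPU and cast back to float32, asynchronously if possible.
    model.to(torch.device(\"cuda:0\"), dtype=torch.float32, non_blocking=True)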
In addition, this method will only cast the floating point or complex parameters and buffers to :attr: dtype (if given). The integral parameters and buffers will be moved to :attr: device , if that is given, but with dtypes unchanged. When :attr: non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device :class: torch.device the desired device of the parameters and buffers in this module None dtype :class: torch.dtype the desired floating point or complex dtype of the parameters and buffers in this module None tensor torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module None memory_format :class: torch.memory_format the desired memory format for 4D parameters and buffers in this module (keyword only argument) None Returns: Type Description Module self View Source def to(self, *args, **kwargs): r\"\"\"Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype`\\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved to :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place.
Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self Examples:: >>> # xdoctest: +IGNORE_WANT(\" non - deterministic \") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device(\" cuda : 1 \") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device(\" cpu \") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) \"\" \" device , dtype , non_blocking , convert_to_format = torch . _C . _nn . _parse_to ( * args , ** kwargs ) if dtype is not None : if not ( dtype . is_floating_point or dtype . is_complex ) : raise TypeError ( 'nn.Module.to only accepts floating point or complex ' f 'dtypes, but got desired dtype={dtype}' ) if dtype . is_complex : warnings . warn ( \"Complex modules are a new feature under active development whose design may change, \" \"and some modules might not work as expected when using complex tensors as parameters or buffers. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"if a complex module does not work as expected.\" ) def convert ( t ) : try : if convert_to_format is not None and t . dim () in ( 4 , 5 ) : return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , memory_format = convert_to_format , ) return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , ) except NotImplementedError as e : if str ( e ) == \"Cannot copy out of meta tensor; no data!\" : raise NotImplementedError ( f \"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() \" f \"when moving module from meta to a different device.\" ) from None else : raise return self . _apply ( convert ) to_empty def to_empty ( self : ~ T , * , device : Union [ int , str , torch . device , NoneType ], recurse : bool = True ) -> ~ T Move the parameters and buffers to the specified device without copying storage. Parameters: Name Type Description Default device ( None class: torch.device ): The desired device of the parameters and buffers in this module. 
None recurse bool Whether parameters and buffers of submodules should be recursively moved to the specified device. None Returns: Type Description Module self View Source def to_empty ( self : T , * , device : Optional [ DeviceLikeType ] , recurse : bool = True ) -> T : r \"\"\"Move the parameters and buffers to the specified device without copying storage. Args: device (:class:`torch.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. Returns: Module: self \"\"\" return self . _apply ( lambda t : torch . empty_like ( t , device = device ), recurse = recurse ) train def train ( self : ~ T , mode : bool = True ) -> ~ T Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. Parameters: Name Type Description Default mode bool whether to set training mode ( True ) or evaluation mode ( False ). Default: True . None Returns: Type Description Module self View Source def train ( self : T , mode : bool = True ) -> T : r \" \"\" Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self \"\" \" if not isinstance ( mode , bool ) : raise ValueError ( \"training mode is expected to be boolean\" ) self . training = mode for module in self . children () : module . train ( mode ) return self type def type ( self : ~ T , dst_type : Union [ torch . dtype , str ] ) -> ~ T Casts all parameters and buffers to :attr: dst_type . .. note:: This method modifies the module in-place. Parameters: Name Type Description Default dst_type type or string the desired type None Returns: Type Description Module self View Source def type ( self : T , dst_type : Union [ dtype , str ] ) -> T : r \" \"\" Casts all parameters and buffers to :attr:`dst_type`. .. note:: This method modifies the module in-place. Args: dst_type (type or string): the desired type Returns: Module: self \"\" \" return self . _apply ( lambda t : t . type ( dst_type )) xpu def xpu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def xpu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . 
xpu ( device )) zero_grad def zero_grad ( self , set_to_none : bool = True ) -> None Reset gradients of all model parameters. See similar function under :class: torch.optim.Optimizer for more context. Parameters: Name Type Description Default set_to_none bool instead of setting to zero, set the grads to None. See :meth: torch.optim.Optimizer.zero_grad for details. None View Source def zero_grad ( self , set_to_none : bool = True ) -> None : r \"\"\"Reset gradients of all model parameters. See similar function under :class:`torch.optim.Optimizer` for more context. Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. \"\"\" if getattr ( self , '_is_replica' , False ) : warnings . warn ( \"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. \" \"The parameters are copied (in a differentiable manner) from the original module. \" \"This means they are not leaf nodes in autograd and so don' t accumulate gradients . \" \" If you need gradients in your forward method , consider using autograd . grad instead . \") for p in self.parameters(): if p.grad is not None: if set_to_none: p.grad = None else: if p.grad.grad_fn is not None: p.grad.detach_() else: p.grad.requires_grad_(False) p.grad.zero_() RMLP class RMLP ( block_in_dim : int , block_dims : Sequence [ int ], block_nonlins : Sequence [ Union [ str , torch . nn . modules . module . Module ]], n_blocks : int , out_dim : int , in_dim : int = None , batch_norm : bool = True ) Base class for all neural network modules. Your models should also subclass this class. Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:: import torch.nn as nn import torch.nn.functional as F class Model ( nn . Module ): def __init__ ( self ): super () . __init__ () self . conv1 = nn . Conv2d ( 1 , 20 , 5 ) self . conv2 = nn . Conv2d ( 20 , 20 , 5 ) def forward ( self , x ): x = F . relu ( self . conv1 ( x )) return F . relu ( self . conv2 ( x )) Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth: to , etc. .. note:: As per the example above, an __init__() call to the parent class must be made before assignment on the child. View Source class RMLP ( nn . Module ) : def __init__ ( self , block_in_dim : int , block_dims : Sequence [ int ] , block_nonlins : Sequence [ Union[str, nn.Module ] ] , n_blocks : int , out_dim : int , in_dim : int = None , # if in_dim is an int , then a first layer will be made batch_norm : bool = True , ) -> None : super (). __init__ () # Create first layer if in_dim is not None self . input = nn . Identity () if in_dim is not None : self . input = MLPLayer ( in_dim , block_in_dim , block_nonlins [ 0 ] , batch_norm ) # Create blocks layers = [] for i in range ( n_blocks ) : layers . append ( MlpBlock ( block_in_dim , block_dims , block_nonlins , batch_norm )) self . blocks = nn . ModuleList ( layers ) # Create output layer self . output = nn . Linear ( block_dims [ -1 ] , out_dim ) def _make_activation ( self , act : Union [ str, nn.Module ] ) -> nn . Module : if isinstance ( act , str ) : return ACTIVATIONS [ act ] ( ** ACTIVATION_DEFAULT_KWARGS [ act ] ) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" x = self . 
input ( x ) for block in self . blocks : out = block ( x ) x = x + out return self . output ( x ) Ancestors (in MRO) torch.nn.modules.module.Module Class variables T_destination call_super_init dump_patches Methods add_module def add_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Add a child module to the current module. The module can be accessed as an attribute using the given name. Parameters: Name Type Description Default name str name of the child module. The child module can be accessed from this module using the given name None module Module child module to be added to the module. None View Source def add_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \"\"\"Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (str): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module. \"\"\" if not isinstance ( module , Module ) and module is not None : raise TypeError ( f \"{torch.typename(module)} is not a Module subclass\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"module name should be a string. Got {torch.typename(name)}\" ) elif hasattr ( self , name ) and name not in self . _modules : raise KeyError ( f \"attribute '{name}' already exists\" ) elif '.' in name : raise KeyError ( f \"module name can't contain \\\" . \\ \", got: {name}\" ) elif name == '' : raise KeyError ( \"module name can't be empty string \\\" \\ \"\" ) for hook in _global_module_registration_hooks . values () : output = hook ( self , name , module ) if output is not None : module = output self . _modules [ name ] = module apply def apply ( self : ~ T , fn : Callable [[ ForwardRef ( 'Module' )], NoneType ] ) -> ~ T Apply fn recursively to every submodule (as returned by .children() ) as well as self. Typical use includes initializing the parameters of a model (see also :ref: nn-init-doc ). Parameters: Name Type Description Default fn ( None class: Module -> None): function to be applied to each submodule None Returns: Type Description Module self View Source def apply ( self : T , fn : Callable [[ 'Module' ] , None ] ) -> T : r \" \"\" Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example:: >>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) \"\" \" for module in self . children () : module . apply ( fn ) fn ( self ) return self bfloat16 def bfloat16 ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to bfloat16 datatype. .. note:: This method modifies the module in-place. 
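For example (an illustrative cast; real speedups additionally depend on bfloat16 support in the hardware):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
model.bfloat16()  # in-place: floating-point parameters/buffers become bfloat16
print(model.weight.dtype)  # torch.bfloat16
out = model(torch.randn(1, 4).bfloat16())  # inputs must use a matching dtype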
Returns: Type Description Module self View Source def bfloat16 ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``bfloat16`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . bfloat16 () if t . is_floating_point () else t ) buffers def buffers ( self , recurse : bool = True ) -> Iterator [ torch . Tensor ] Return an iterator over module buffers. Parameters: Name Type Description Default recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. None Yields: Type Description torch.Tensor module buffer View Source def buffers ( self , recurse : bool = True ) -> Iterator [ Tensor ] : r \"\"\"Return an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: torch.Tensor: module buffer Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for buf in model.buffers(): >>> print(type(buf), buf.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for _ , buf in self . named_buffers ( recurse = recurse ) : yield buf children def children ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over immediate children modules. Yields: Type Description Module a child module View Source def children ( self ) -> Iterator [ 'Module']: r \"\"\"Return an iterator over immediate children modules. Yields: Module: a child module \"\"\" for name , module in self . named_children (): yield module compile def compile ( self , * args , ** kwargs ) Compile this Module's forward using :func: torch.compile . This Module's __call__ method is compiled and all arguments are passed as-is to :func: torch.compile . See :func: torch.compile for details on the arguments for this function. View Source def compile ( self , * args , ** kwargs ) : \" \"\" Compile this Module's forward using :func:`torch.compile`. This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`. See :func:`torch.compile` for details on the arguments for this function. \"\" \" self . _compiled_call_impl = torch . compile ( self . _call_impl , * args , ** kwargs ) cpu def cpu ( self : ~ T ) -> ~ T Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def cpu ( self : T ) -> T : r \"\"\"Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cpu ()) cuda def cuda ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def cuda ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. 
So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cuda ( device )) double def double ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to double datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def double ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``double`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . double () if t . is_floating_point () else t ) eval def eval ( self : ~ T ) -> ~ T Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. This is equivalent with :meth: self.train(False) . See :ref: locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it. Returns: Type Description Module self View Source def eval ( self : T ) -> T : r \" \"\" Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent with :meth:`self.train(False) `. See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it. Returns: Module: self \"\" \" return self . train ( False ) extra_repr def extra_repr ( self ) -> str Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. View Source def extra_repr ( self ) -> str : r \"\"\"Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. \"\"\" return '' float def float ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to float datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def float ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``float`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . float () if t . is_floating_point () else t ) forward def forward ( self , x : torch . Tensor ) -> torch . Tensor Parameters: Name Type Description Default x None An input tensor, of shape (N, D) containing N samples with D features. None Returns: Type Description None An output tensor of shape (N, D_out) where D_out is the output dim. View Source def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" x = self . input ( x ) for block in self . blocks : out = block ( x ) x = x + out return self . 
output ( x ) get_buffer def get_buffer ( self , target : str ) -> 'Tensor' Return the buffer given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.Tensor The buffer referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not a buffer View Source def get_buffer ( self , target : str ) -> \"Tensor\" : \" \"\" Return the buffer given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the buffer to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.Tensor: The buffer referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not a buffer \"\" \" module_path , _ , buffer_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , buffer_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + buffer_name + \"`\" ) buffer : torch . Tensor = getattr ( mod , buffer_name ) if buffer_name not in mod . _buffers : raise AttributeError ( \"`\" + buffer_name + \"` is not a buffer\" ) return buffer get_extra_state def get_extra_state ( self ) -> Any Return any extra state to include in the module's state_dict. Implement this and a corresponding :func: set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict() . Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: Type Description object Any extra state to store in the module's state_dict View Source def get_extra_state ( self ) -> Any : \" \"\" Return any extra state to include in the module's state_dict. Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`. Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: object: Any extra state to store in the module's state_dict \"\" \" raise RuntimeError ( \"Reached a code path in Module.get_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" ) get_parameter def get_parameter ( self , target : str ) -> 'Parameter' Return the parameter given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target .
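Both get_buffer and get_parameter resolve fully-qualified dotted names, as described above. A short sketch, with a module layout chosen purely for illustration:

```python
import torch.nn as nn

net = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3))
weight = net.get_parameter("0.weight")    # parameter of the Linear child
mean = net.get_buffer("1.running_mean")   # buffer of the BatchNorm child
```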
Parameters: Name Type Description Default target None The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Parameter The Parameter referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Parameter View Source def get_parameter ( self , target : str ) -> \"Parameter\" : \" \"\" Return the parameter given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the Parameter to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.nn.Parameter: The Parameter referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Parameter`` \"\" \" module_path , _ , param_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , param_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + param_name + \"`\" ) param : torch . nn . Parameter = getattr ( mod , param_name ) if not isinstance ( param , torch . nn . Parameter ) : raise AttributeError ( \"`\" + param_name + \"` is not an \" \"nn.Parameter\" ) return param get_submodule def get_submodule ( self , target : str ) -> 'Module' Return the submodule given by target if it exists, otherwise throw an error. For example, let's say you have an nn.Module A that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an nn.Module A . A has a nested submodule net_b , which itself has two submodules net_c and linear . net_c then has a submodule conv .) To check whether or not we have the linear submodule, we would call get_submodule(\"net_b.linear\") . To check whether we have the conv submodule, we would call get_submodule(\"net_b.net_c.conv\") . The runtime of get_submodule is bounded by the degree of module nesting in target . A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used. Parameters: Name Type Description Default target None The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Module The submodule referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Module View Source def get_submodule ( self , target : str ) -> \"Module\" : \" \"\" Return the submodule given by ``target`` if it exists, otherwise throw an error. For example, let's say you have an ``nn.Module`` ``A`` that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. 
``net_c`` then has a submodule ``conv``.) To check whether or not we have the ``linear`` submodule, we would call ``get_submodule(\" net_b . linear \")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule(\" net_b . net_c . conv \")``. The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used. Args: target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) Returns: torch.nn.Module: The submodule referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Module`` \"\" \" if target == \"\" : return self atoms : List [ str ] = target . split ( \".\" ) mod : torch . nn . Module = self for item in atoms : if not hasattr ( mod , item ) : raise AttributeError ( mod . _get_name () + \" has no \" \"attribute `\" + item + \"`\" ) mod = getattr ( mod , item ) if not isinstance ( mod , torch . nn . Module ) : raise AttributeError ( \"`\" + item + \"` is not \" \"an nn.Module\" ) return mod half def half ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to half datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def half ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``half`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . half () if t . is_floating_point () else t ) ipu def ipu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def ipu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . ipu ( device )) load_state_dict def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) Copy parameters and buffers from :attr: state_dict into this module and its descendants. If :attr: strict is True , then the keys of :attr: state_dict must exactly match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. .. warning:: If :attr: assign is True the optimizer must be created after the call to :attr: load_state_dict unless :func: ~torch.__future__.get_swap_module_params_on_conversion is True . 
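A small round-trip sketch of the strict behaviour described above; the checkpoint filename is hypothetical:

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 2)
torch.save(model.state_dict(), "checkpoint.pt")  # hypothetical path

restored = nn.Linear(2, 2)
result = restored.load_state_dict(torch.load("checkpoint.pt"), strict=True)
# With strict=True and matching architectures, nothing should be reported.
assert result.missing_keys == [] and result.unexpected_keys == []
```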
Parameters: Name Type Description Default state_dict dict a dict containing parameters and persistent buffers. None strict bool whether to strictly enforce that the keys in :attr: state_dict match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. Default: True None assign bool When False , the properties of the tensors in the current module are preserved while when True , the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of :class: ~torch.nn.Parameter s for which the value from the module is preserved. Default: False None Returns: Type Description None NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys unexpected_keys is a list of str containing the unexpected keys View Source def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) : r \"\"\"Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. \"\"\" if not isinstance ( state_dict , Mapping ) : raise TypeError ( f \"Expected state_dict to be dict-like, got {type(state_dict)}.\" ) missing_keys : List [ str ] = [] unexpected_keys : List [ str ] = [] error_msgs : List [ str ] = [] # copy state_dict so _ load_from_state_dict can modify it metadata = getattr ( state_dict , '_metadata' , None ) state_dict = OrderedDict ( state_dict ) if metadata is not None : # mypy isn't aware that \"_metadata\" exists in state_dict state_dict._metadata = metadata # type: ignore[attr-defined] def load(module, local_state_dict, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) if assign: local_metadata['assign_to_params_buffers'] = assign module._load_from_state_dict( local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name + ' . 
' child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} load(child, child_state_dict, child_prefix) # noqa: F821 # Note that the hook can modify missing_keys and unexpected_keys. incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) for hook in module._load_state_dict_post_hooks.values(): out = hook(module, incompatible_keys) assert out is None, ( \"Hooks registered with ``register_load_state_dict_post_hook`` are not\" \"expected to return new values, if incompatible_keys need to be modified,\" \"it should be done inplace.\" ) load(self, state_dict) del load if strict: if len(unexpected_keys) > 0: error_msgs.insert( 0, ' Unexpected key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in unexpected_keys))) if len(missing_keys) > 0: error_msgs.insert( 0, ' Missing key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in missing_keys))) if len(error_msgs) > 0: raise RuntimeError(' Error ( s ) in loading state_dict for {} : \\n\\t {} ' . format ( self . __ class__ . __ name__ , \"\\n\\t\" . join ( error_msgs ))) return _ IncompatibleKeys ( missing_keys , unexpected_keys ) modules def modules ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over all modules in the network. Yields: Type Description Module a module in the network View Source def modules ( self ) -> Iterator [ 'Module' ] : r \" \"\" Return an iterator over all modules in the network. Yields: Module: a module in the network Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True) \"\" \" for _ , module in self . named_modules () : yield module named_buffers def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . Tensor ]] Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Parameters: Name Type Description Default prefix str prefix to prepend to all buffer names. None recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. None remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True Yields: Type Description None (str, torch.Tensor): Tuple containing the name and buffer View Source def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Tensor ]]: r \"\"\"Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Args: prefix (str): prefix to prepend to all buffer names. recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. 
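As a sketch of the recursive traversal (module layout illustrative only), each yielded name prefixes the buffer with the path of its owning submodule:

```python
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.BatchNorm2d(8))
for name, buf in model.named_buffers():
    # Prints '1.running_mean', '1.running_var', '1.num_batches_tracked'
    print(name, tuple(buf.shape))
```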
Yields: (str, torch.Tensor): Tuple containing the name and buffer Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size()) \"\"\" gen = self . _named_members ( lambda module : module . _buffers . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen named_children def named_children ( self ) -> Iterator [ Tuple [ str , ForwardRef ( 'Module' )]] Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: Type Description None (str, Module): Tuple containing a name and child module View Source def named_children ( self ) -> Iterator [ Tuple [ str , 'Module' ]]: r \"\"\"Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: (str, Module): Tuple containing a name and child module Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module) \"\"\" memo = set () for name , module in self . _modules . items (): if module is not None and module not in memo : memo . add ( module ) yield name , module named_modules def named_modules ( self , memo : Optional [ Set [ ForwardRef ( 'Module' )]] = None , prefix : str = '' , remove_duplicate : bool = True ) Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Parameters: Name Type Description Default memo None a memo to store the set of modules already added to the result None prefix None a prefix that will be added to the name of the module None remove_duplicate None whether to remove the duplicated module instances in the result or not None Yields: Type Description None (str, Module): Tuple of name and module View Source def named_modules ( self , memo : Optional [ Set [ 'Module' ]] = None , prefix : str = '' , remove_duplicate : bool = True ) : r \" \"\" Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: memo: a memo to store the set of modules already added to the result prefix: a prefix that will be added to the name of the module remove_duplicate: whether to remove the duplicated module instances in the result or not Yields: (str, Module): Tuple of name and module Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) \"\" \" if memo is None : memo = set () if self not in memo : if remove_duplicate : memo . add ( self ) yield prefix , self for name , module in self . _modules . items () : if module is None : continue submodule_prefix = prefix + ( '.' if prefix else '' ) + name yield from module . named_modules ( memo , submodule_prefix , remove_duplicate ) named_parameters def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . nn . parameter . Parameter ]] Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. 
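The qualified names yielded here are convenient for selective freezing; a minimal sketch, assuming a two-layer stack where only the last layer should train:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("1.")  # train the final layer only
```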
Parameters: Name Type Description Default prefix str prefix to prepend to all parameter names. None recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. None Yields: Type Description None (str, Parameter): Tuple containing the name and parameter View Source def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Parameter ]]: r \"\"\"Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True. Yields: (str, Parameter): Tuple containing the name and parameter Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) \"\"\" gen = self . _named_members ( lambda module : module . _parameters . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen parameters def parameters ( self , recurse : bool = True ) -> Iterator [ torch . nn . parameter . Parameter ] Return an iterator over module parameters. This is typically passed to an optimizer. Parameters: Name Type Description Default recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None Yields: Type Description Parameter module parameter View Source def parameters ( self , recurse : bool = True ) -> Iterator [ Parameter ] : r \"\"\"Return an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Yields: Parameter: module parameter Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for param in model.parameters(): >>> print(type(param), param.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for name , param in self . named_parameters ( recurse = recurse ) : yield param register_backward_hook def register_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]] ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. This function is deprecated in favor of :meth: ~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_backward_hook ( self , hook: Callable [[' Module ', _grad_t , _grad_t ], Union [ None , _grad_t ]] ) -> RemovableHandle: r \"\"\"Register a backward hook on the module. This function is deprecated in favor of : meth: ` ~ torch . nn . Module . 
register_full_backward_hook ` and the behavior of this function will change in future versions . Returns: : class: `torch . utils . hooks . RemovableHandle ` : a handle that can be used to remove the added hook by calling ` `handle . remove () `` \"\"\" if self . _is_full_backward_hook is True: raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = False handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook return handle register_buffer def register_buffer ( self , name : str , tensor : Optional [ torch . Tensor ], persistent : bool = True ) -> None Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr: persistent to False . The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr: state_dict . Buffers can be accessed as attributes using given names. Parameters: Name Type Description Default name str name of the buffer. The buffer can be accessed from this module using the given name None tensor Tensor or None buffer to be registered. If None , then operations that run on buffers, such as :attr: cuda , are ignored. If None , the buffer is not included in the module's :attr: state_dict . None persistent bool whether the buffer is part of this module's :attr: state_dict . None View Source def register_buffer ( self , name : str , tensor : Optional [ Tensor ] , persistent : bool = True ) -> None : r \" \"\" Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`. Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> self.register_buffer('running_mean', torch.zeros(num_features)) \"\" \" if persistent is False and isinstance ( self , torch . jit . ScriptModule ) : raise RuntimeError ( \"ScriptModule does not support non-persistent buffers\" ) if '_buffers' not in self . __dict__ : raise AttributeError ( \"cannot assign buffer before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"buffer name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"buffer name can't contain \\\" .
\\\" \" ) elif name == '' : raise KeyError ( \"buffer name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _buffers : raise KeyError ( f \"attribute '{name}' already exists\" ) elif tensor is not None and not isinstance ( tensor , torch . Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self . _buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name ) register_forward_hook def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward hook on the module. The hook will be called every time after :func: forward has computed an output. If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func: forward is called. The hook should have the following signature:: hook ( module , args , output ) -> None or modified output If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook ( module , args , kwargs , output ) -> None or modified output Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward hooks on this :class: torch.nn.modules.Module . Note that global forward hooks registered with :func: register_module_forward_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ] , Any ] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ] , Any ] , Optional [ Any ]] , ] , * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. 
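A common read-only use of a forward hook is capturing intermediate activations; a minimal sketch (names are illustrative):

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 3)
captured = {}

def save_output(module, args, output):
    captured["out"] = output.detach()  # observe only; returning None leaves the output unchanged

handle = layer.register_forward_hook(save_output)
layer(torch.randn(1, 3))
handle.remove()  # detach the hook once it is no longer needed
```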
It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature:: hook(module, args, output) -> None or modified output If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook(module, args, kwargs, output) -> None or modified output Args: hook (Callable): The user defined hook to be registered. prepend (bool): If ``True``, the provided ``hook`` will be fired before all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward`` hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` always_call (bool): If ``True`` the ``hook`` will be run regardless of whether an exception is raised while calling the Module. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_hooks , extra_dict = [ self . _forward_hooks_with_kwargs , self . _forward_hooks_always_called ] , ) self . _forward_hooks [ handle . id ] = hook if with_kwargs : self . _forward_hooks_with_kwargs [ handle . id ] = True if always_call : self . _forward_hooks_always_called [ handle . id ] = True if prepend : self . _forward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_forward_pre_hook def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward pre-hook on the module. The hook will be called every time before :func: forward is invoked. If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook ( module , args ) -> None or modified input If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook ( module , args , kwargs ) -> None or a tuple of modified input and kwargs Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward_pre hooks on this :class: torch.nn.modules.Module . 
Note that global forward_pre hooks registered with :func: register_module_forward_pre_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_pre_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ]] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ]] , Optional [ Tuple [ Any , Dict [ str , Any ]]]] , ] , * , prepend : bool = False , with_kwargs : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_full_backward_hook def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. 
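A minimal sketch of such a hook, printing gradient shapes as they flow back through one layer (layer size arbitrary):

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 1)

def report(module, grad_input, grad_output):
    # Entries can be None for non-Tensor arguments.
    print([None if g is None else tuple(g.shape) for g in grad_output])

handle = layer.register_full_backward_hook(report)
layer(torch.randn(2, 3)).sum().backward()
handle.remove()
```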
The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. 
Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . _is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_full_backward_pre_hook def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. 
.. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_load_state_dict_post_hook def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . id ] = hook return handle register_module def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . 
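Since register_module is a straight alias, the two spellings are interchangeable; a hypothetical wrapper class as a sketch:

```python
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_module("body", nn.Linear(8, 8))  # identical to add_module("body", ...)

    def forward(self, x):
        return self.body(x)
```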
View Source def register_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \" \"\" Alias for :func:`add_module`. \"\" \" self . add_module ( name , module ) register_parameter def register_parameter ( self , name : str , param : Optional [ torch . nn . parameter . Parameter ] ) -> None Add a parameter to the module. The parameter can be accessed as an attribute using given name. Parameters: Name Type Description Default name str name of the parameter. The parameter can be accessed from this module using the given name None param Parameter or None parameter to be added to the module. If None , then operations that run on parameters, such as :attr: cuda , are ignored. If None , the parameter is not included in the module's :attr: state_dict . None View Source def register_parameter ( self , name : str , param : Optional [ Parameter ] ) -> None : r \" \"\" Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. \"\" \" if '_parameters' not in self . __dict__ : raise AttributeError ( \"cannot assign parameter before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"parameter name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"parameter name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"parameter name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _parameters : raise KeyError ( f \"attribute '{name}' already exists\" ) if param is None : self . _parameters [ name ] = None elif not isinstance ( param , Parameter ) : raise TypeError ( f \"cannot assign '{torch.typename(param)}' object to parameter '{name}' \" \"(torch.nn.Parameter or None required)\" ) elif param . grad_fn : raise ValueError ( f \"Cannot assign non-leaf Tensor to parameter '{name}'. Model \" f \"parameters must be created explicitly. To express '{name}' \" \"as a function of another Tensor, compute the value in \" \"the forward() method.\" ) else : for hook in _global_parameter_registration_hooks . values () : output = hook ( self , name , param ) if output is not None : param = output self . _parameters [ name ] = param register_state_dict_pre_hook def register_state_dict_pre_hook ( self , hook ) Register a pre-hook for the :meth: ~torch.nn.Module.state_dict method. These hooks will be called with arguments: self , prefix , and keep_vars before calling state_dict on self . The registered hooks can be used to perform pre-processing before the state_dict call is made. View Source def register_state_dict_pre_hook ( self , hook ) : r \" \"\" Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. \"\" \" handle = hooks . RemovableHandle ( self . _state_dict_pre_hooks ) self . _state_dict_pre_hooks [ handle . 
id ] = hook return handle requires_grad_ def requires_grad_ ( self : ~ T , requires_grad : bool = True ) -> ~ T Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr: requires_grad attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref: locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it. Parameters: Name Type Description Default requires_grad bool whether autograd should record operations on parameters in this module. Default: True . None Returns: Type Description Module self View Source def requires_grad_ ( self : T , requires_grad : bool = True ) -> T : r \" \"\" Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it. Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self \"\" \" for p in self . parameters () : p . requires_grad_ ( requires_grad ) return self set_extra_state def set_extra_state ( self , state : Any ) -> None Set extra state contained in the loaded state_dict . This function is called from :func: load_state_dict to handle any extra state found within the state_dict . Implement this function and a corresponding View Source def set _extra_state ( self , state : Any ) -> None : \" \"\" Set extra state contained in the loaded `state_dict`. This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`. Args: state (dict): Extra state from the `state_dict` \"\" \" raise RuntimeError ( \"Reached a code path in Module.set_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" ) share_memory def share_memory ( self : ~ T ) -> ~ T See :meth: torch.Tensor.share_memory_ . View Source def share_memory ( self : T ) -> T : r \"\"\"See :meth:`torch.Tensor.share_memory_`.\"\"\" return self . _apply ( lambda t : t . share_memory_ ()) state_dict def state_dict ( self , * args , destination = None , prefix = '' , keep_vars = False ) Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently state_dict() also accepts positional arguments for destination , prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument destination as it is not designed for end-users. 
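A quick sketch of the flattened key naming (layout illustrative only): each entry is keyed by the qualified parameter or buffer name:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 2))
print(list(model.state_dict().keys()))  # ['0.weight', '0.bias']
```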
Parameters: Name Type Description Default destination dict If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None . None prefix str a prefix added to parameter and buffer names to compose the keys in state_dict. Default: '' . None keep_vars bool by default the :class: ~torch.Tensor s returned in the state dict are detached from autograd. If it's set to True , detaching will not be performed. Default: False . None Returns: Type Description dict a dictionary containing a whole state of the module View Source def state_dict ( self , * args , destination = None , prefix='' , keep_vars = False ) : r \"\"\"Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars`` in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument ``destination`` as it is not designed for end-users. Args: destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``. prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. Returns: dict: a dictionary containing a whole state of the module Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> module.state_dict().keys() ['bias', 'weight'] \"\"\" # TODO : Remove ` args ` and the parsing logic when BC allows . if len ( args ) > 0 : if destination is None : destination = args [ 0 ] if len ( args ) > 1 and prefix == '' : prefix = args [ 1 ] if len ( args ) > 2 and keep_vars is False : keep_vars = args [ 2 ] # DeprecationWarning is ignored by default warnings . warn ( \"Positional args are being deprecated, use kwargs instead. Refer to \" \"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict\" \" for details.\" ) if destination is None : destination = OrderedDict () destination . _ metadata = OrderedDict () local_metadata = dict ( version = self . _ version ) if hasattr ( destination , \"_metadata\" ) : destination . _ metadata [ prefix [ :- 1 ]] = local_metadata for hook in self . _ state_dict_pre_hooks . val ues () : hook ( self , prefix , keep_vars ) self . _ save_to_state_dict ( destination , prefix , keep_vars ) for name , module in self . _ modules . items () : if module is not None : module . state_dict ( destination = destination , prefix = prefix + name + '.' , keep_vars = keep_vars ) for hook in self . _ state_dict_hooks . val ues () : hook_result = hook ( self , destination , prefix , local_metadata ) if hook_result is not None : destination = hook_result return destination to def to ( self , * args , ** kwargs ) Move and/or cast the parameters and buffers. This can be called as .. 
function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth: torch.Tensor.to , but only accepts floating point or complex :attr: dtype \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr: dtype (if given). The integral parameters and buffers will be moved :attr: device , if that is given, but with dtypes unchanged. When :attr: non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device ( None class: torch.device ): the desired device of the parameters and buffers in this module None dtype ( None class: torch.dtype ): the desired floating point or complex dtype of the parameters and buffers in this module None tensor torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module None memory_format ( None class: torch.memory_format ): the desired memory format for 4D parameters and buffers in this module (keyword only argument) None Returns: Type Description Module self View Source def to ( self , * args , ** kwargs ) : r \" \"\" Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype` \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. 
Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self Examples:: >>> # xdoctest: +IGNORE_WANT(\" non - deterministic \") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device(\" cuda : 1 \") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device(\" cpu \") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) \"\" \" device , dtype , non_blocking , convert_to_format = torch . _C . _nn . _parse_to ( * args , ** kwargs ) if dtype is not None : if not ( dtype . is_floating_point or dtype . is_complex ) : raise TypeError ( 'nn.Module.to only accepts floating point or complex ' f 'dtypes, but got desired dtype={dtype}' ) if dtype . is_complex : warnings . warn ( \"Complex modules are a new feature under active development whose design may change, \" \"and some modules might not work as expected when using complex tensors as parameters or buffers. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"if a complex module does not work as expected.\" ) def convert ( t ) : try : if convert_to_format is not None and t . dim () in ( 4 , 5 ) : return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , memory_format = convert_to_format , ) return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , ) except NotImplementedError as e : if str ( e ) == \"Cannot copy out of meta tensor; no data!\" : raise NotImplementedError ( f \"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() \" f \"when moving module from meta to a different device.\" ) from None else : raise return self . _apply ( convert ) to_empty def to_empty ( self : ~ T , * , device : Union [ int , str , torch . device , NoneType ], recurse : bool = True ) -> ~ T Move the parameters and buffers to the specified device without copying storage. Parameters: Name Type Description Default device ( None class: torch.device ): The desired device of the parameters and buffers in this module. 
None recurse bool Whether parameters and buffers of submodules should be recursively moved to the specified device. None Returns: Type Description Module self View Source def to_empty ( self : T , * , device : Optional [ DeviceLikeType ] , recurse : bool = True ) -> T : r \"\"\"Move the parameters and buffers to the specified device without copying storage. Args: device (:class:`torch.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. Returns: Module: self \"\"\" return self . _apply ( lambda t : torch . empty_like ( t , device = device ), recurse = recurse ) train def train ( self : ~ T , mode : bool = True ) -> ~ T Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. Parameters: Name Type Description Default mode bool whether to set training mode ( True ) or evaluation mode ( False ). Default: True . None Returns: Type Description Module self View Source def train ( self : T , mode : bool = True ) -> T : r \" \"\" Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self \"\" \" if not isinstance ( mode , bool ) : raise ValueError ( \"training mode is expected to be boolean\" ) self . training = mode for module in self . children () : module . train ( mode ) return self type def type ( self : ~ T , dst_type : Union [ torch . dtype , str ] ) -> ~ T Casts all parameters and buffers to :attr: dst_type . .. note:: This method modifies the module in-place. Parameters: Name Type Description Default dst_type type or string the desired type None Returns: Type Description Module self View Source def type ( self : T , dst_type : Union [ dtype , str ] ) -> T : r \" \"\" Casts all parameters and buffers to :attr:`dst_type`. .. note:: This method modifies the module in-place. Args: dst_type (type or string): the desired type Returns: Module: self \"\" \" return self . _apply ( lambda t : t . type ( dst_type )) xpu def xpu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def xpu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . 
xpu ( device ))

zero_grad

def zero_grad ( self , set_to_none : bool = True ) -> None

Reset gradients of all model parameters.

See similar function under :class:`torch.optim.Optimizer` for more context.

Parameters: Name Type Description Default set_to_none bool instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. None

View Source

def zero_grad(self, set_to_none: bool = True) -> None:
    r"""Reset gradients of all model parameters.

    See similar function under :class:`torch.optim.Optimizer` for more context.

    Args:
        set_to_none (bool): instead of setting to zero, set the grads to None.
            See :meth:`torch.optim.Optimizer.zero_grad` for details.
    """
    if getattr(self, '_is_replica', False):
        warnings.warn(
            "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
            "The parameters are copied (in a differentiable manner) from the original module. "
            "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
            "If you need gradients in your forward method, consider using autograd.grad instead.")

    for p in self.parameters():
        if p.grad is not None:
            if set_to_none:
                p.grad = None
            else:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()

WormPredictor

class WormPredictor ( model : torch.nn.modules.module.Module , io_config : wtracker.neural.config.IOConfig )

A class that represents neural network models that predict worm behavior. After a model is created from several layers or blocks, it is wrapped in this class so that it can be distinguished from other models that don't predict worm behavior (for example the layers/blocks that make this model). This class also holds the IOConfig object that is used to determine the input and output shapes of the model, and the specific frames it expects as input and output.

Attributes: Name Type Description Default model None The neural network model that predicts worm behavior. None io_config None The IOConfig object of the model. None

View Source

class WormPredictor(nn.Module):
    """A class that represents neural network models that predict worm behavior.

    After a model is created from several layers or blocks, it is wrapped in this class so that
    it can be distinguished from other models that don't predict worm behavior (for example the
    layers/blocks that make this model). This class also holds the IOConfig object that is used
    to determine the input and output shapes of the model, and the specific frames it expects
    as input and output.

    Attributes:
        model: The neural network model that predicts worm behavior.
        io_config: The IOConfig object of the model.
    """

    def __init__(self, model: nn.Module, io_config: IOConfig):
        super().__init__()
        self.io_config: IOConfig = io_config
        self.model: nn.Module = model

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)
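Because WormPredictor is itself an nn.Module, it trains and evaluates like any other PyTorch model while carrying its IOConfig next to the weights. A minimal construction sketch (the layer sizes, the in_dim/out_dim placeholders, and the pre-built io_config object are assumptions for illustration, not part of this API reference; build the real IOConfig as described in wtracker.neural.config):

Example::

    >>> # xdoctest: +SKIP("requires a wtracker IOConfig instance")
    >>> from torch import nn
    >>> in_dim, out_dim = 8, 4  # placeholders: must match what io_config expects as input/output
    >>> layers = nn.Sequential(nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, out_dim))
    >>> predictor = WormPredictor(model=layers, io_config=io_config)  # io_config: an existing IOConfig
    >>> isinstance(predictor, nn.Module)  # wraps, and behaves as, a regular module
    True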
Ancestors (in MRO) torch.nn.modules.module.Module

Class variables T_destination call_super_init dump_patches

Methods

add_module

def add_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None

Add a child module to the current module. The module can be accessed as an attribute using the given name.

Parameters: Name Type Description Default name str name of the child module. The child module can be accessed from this module using the given name None module Module child module to be added to the module. None

View Source

def add_module(self, name: str, module: Optional['Module']) -> None:
    r"""Add a child module to the current module.

    The module can be accessed as an attribute using the given name.

    Args:
        name (str): name of the child module. The child module can be
            accessed from this module using the given name
        module (Module): child module to be added to the module.
    """
    if not isinstance(module, Module) and module is not None:
        raise TypeError(f"{torch.typename(module)} is not a Module subclass")
    elif not isinstance(name, str):
        raise TypeError(f"module name should be a string. Got {torch.typename(name)}")
    elif hasattr(self, name) and name not in self._modules:
        raise KeyError(f"attribute '{name}' already exists")
    elif '.' in name:
        raise KeyError(f"module name can't contain \".\", got: {name}")
    elif name == '':
        raise KeyError("module name can't be empty string \"\"")
    for hook in _global_module_registration_hooks.values():
        output = hook(self, name, module)
        if output is not None:
            module = output
    self._modules[name] = module

apply

def apply ( self : ~T , fn : Callable [[ ForwardRef ( 'Module' )], NoneType ] ) -> ~T

Apply fn recursively to every submodule (as returned by .children() ) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc` ).

Parameters: Name Type Description Default fn (:class:`Module` -> None) function to be applied to each submodule None

Returns: Type Description Module self

View Source

def apply(self: T, fn: Callable[['Module'], None]) -> T:
    r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.

    Typical use includes initializing the parameters of a model
    (see also :ref:`nn-init-doc`).

    Args:
        fn (:class:`Module` -> None): function to be applied to each submodule

    Returns:
        Module: self

    Example::

        >>> @torch.no_grad()
        >>> def init_weights(m):
        >>>     print(m)
        >>>     if type(m) == nn.Linear:
        >>>         m.weight.fill_(1.0)
        >>>         print(m.weight)
        >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
        >>> net.apply(init_weights)
        Linear(in_features=2, out_features=2, bias=True)
        Parameter containing:
        tensor([[1., 1.],
                [1., 1.]], requires_grad=True)
        Linear(in_features=2, out_features=2, bias=True)
        Parameter containing:
        tensor([[1., 1.],
                [1., 1.]], requires_grad=True)
        Sequential(
          (0): Linear(in_features=2, out_features=2, bias=True)
          (1): Linear(in_features=2, out_features=2, bias=True)
        )
    """
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self

bfloat16

def bfloat16 ( self : ~T ) -> ~T

Casts all floating point parameters and buffers to bfloat16 datatype.

.. note:: This method modifies the module in-place.

Returns: Type Description Module self

View Source

def bfloat16(self: T) -> T:
    r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype.

    .. note::
        This method modifies the module in-place.

    Returns:
        Module: self
    """
    return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t)

buffers

def buffers ( self , recurse : bool = True ) -> Iterator [ torch . Tensor ]

Return an iterator over module buffers.

Parameters: Name Type Description Default recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
None Yields: Type Description torch.Tensor module buffer View Source def buffers ( self , recurse : bool = True ) -> Iterator [ Tensor ] : r \"\"\"Return an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: torch.Tensor: module buffer Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for buf in model.buffers(): >>> print(type(buf), buf.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for _ , buf in self . named_buffers ( recurse = recurse ) : yield buf children def children ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over immediate children modules. Yields: Type Description Module a child module View Source def children ( self ) -> Iterator [ 'Module']: r \"\"\"Return an iterator over immediate children modules. Yields: Module: a child module \"\"\" for name , module in self . named_children (): yield module compile def compile ( self , * args , ** kwargs ) Compile this Module's forward using :func: torch.compile . This Module's __call__ method is compiled and all arguments are passed as-is to :func: torch.compile . See :func: torch.compile for details on the arguments for this function. View Source def compile ( self , * args , ** kwargs ) : \" \"\" Compile this Module's forward using :func:`torch.compile`. This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`. See :func:`torch.compile` for details on the arguments for this function. \"\" \" self . _compiled_call_impl = torch . compile ( self . _call_impl , * args , ** kwargs ) cpu def cpu ( self : ~ T ) -> ~ T Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def cpu ( self : T ) -> T : r \"\"\"Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cpu ()) cuda def cuda ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def cuda ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cuda ( device )) double def double ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to double datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def double ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``double`` datatype. .. note:: This method modifies the module in-place. 
Returns: Module: self \"\" \" return self . _apply ( lambda t : t . double () if t . is_floating_point () else t ) eval def eval ( self : ~ T ) -> ~ T Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. This is equivalent with :meth: self.train(False) . See :ref: locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it. Returns: Type Description Module self View Source def eval ( self : T ) -> T : r \" \"\" Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent with :meth:`self.train(False) `. See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it. Returns: Module: self \"\" \" return self . train ( False ) extra_repr def extra_repr ( self ) -> str Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. View Source def extra_repr ( self ) -> str : r \"\"\"Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. \"\"\" return '' float def float ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to float datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def float ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``float`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . float () if t . is_floating_point () else t ) forward def forward ( self , x : torch . Tensor ) -> torch . Tensor Define the computation performed at every call. Should be overridden by all subclasses. .. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class: Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them. View Source def forward ( self , x : Tensor ) -> Tensor : return self . model ( x ) get_buffer def get_buffer ( self , target : str ) -> 'Tensor' Return the buffer given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.Tensor The buffer referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not a buffer View Source def get_buffer ( self , target : str ) -> \"Tensor\" : \" \"\" Return the buffer given by ``target`` if it exists, otherwise throw an error. 
See the docstring for ``get_submodule`` for a more detailed
    explanation of this method's functionality as well as how to
    correctly specify ``target``.

    Args:
        target: The fully-qualified string name of the buffer
            to look for. (See ``get_submodule`` for how to specify a
            fully-qualified string.)

    Returns:
        torch.Tensor: The buffer referenced by ``target``

    Raises:
        AttributeError: If the target string references an invalid
            path or resolves to something that is not a buffer
    """
    module_path, _, buffer_name = target.rpartition(".")
    mod: torch.nn.Module = self.get_submodule(module_path)
    if not hasattr(mod, buffer_name):
        raise AttributeError(mod._get_name() + " has no attribute `" + buffer_name + "`")
    buffer: torch.Tensor = getattr(mod, buffer_name)
    if buffer_name not in mod._buffers:
        raise AttributeError("`" + buffer_name + "` is not a buffer")
    return buffer

get_extra_state

def get_extra_state ( self ) -> Any

Return any extra state to include in the module's state_dict.

Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's state_dict() . Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns: Type Description object Any extra state to store in the module's state_dict

View Source

def get_extra_state(self) -> Any:
    """Return any extra state to include in the module's state_dict.

    Implement this and a corresponding :func:`set_extra_state` for your module
    if you need to store extra state. This function is called when building the
    module's `state_dict()`.

    Note that extra state should be picklable to ensure working serialization
    of the state_dict. We only provide backwards compatibility guarantees
    for serializing Tensors; other objects may break backwards compatibility if
    their serialized pickled form changes.

    Returns:
        object: Any extra state to store in the module's state_dict
    """
    raise RuntimeError(
        "Reached a code path in Module.get_extra_state() that should never be called. "
        "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
        "to report this bug.")

get_parameter

def get_parameter ( self , target : str ) -> 'Parameter'

Return the parameter given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target .

Parameters: Name Type Description Default target None The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.) None

Returns: Type Description torch.nn.Parameter The Parameter referenced by target

Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Parameter

View Source

def get_parameter(self, target: str) -> "Parameter":
    """Return the parameter given by ``target`` if it exists, otherwise throw an error.

    See the docstring for ``get_submodule`` for a more detailed
    explanation of this method's functionality as well as how to
    correctly specify ``target``.

    Args:
        target: The fully-qualified string name of the Parameter
            to look for.
(See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.nn.Parameter: The Parameter referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Parameter`` \"\" \" module_path , _ , param_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , param_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + param_name + \"`\" ) param : torch . nn . Parameter = getattr ( mod , param_name ) if not isinstance ( param , torch . nn . Parameter ) : raise AttributeError ( \"`\" + param_name + \"` is not an \" \"nn.Parameter\" ) return param get_submodule def get_submodule ( self , target : str ) -> 'Module' Return the submodule given by target if it exists, otherwise throw an error. For example, let's say you have an nn.Module A that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an nn.Module A . A has a nested submodule net_b , which itself has two submodules net_c and linear . net_c then has a submodule conv .) To check whether or not we have the linear submodule, we would call get_submodule(\"net_b.linear\") . To check whether we have the conv submodule, we would call get_submodule(\"net_b.net_c.conv\") . The runtime of get_submodule is bounded by the degree of module nesting in target . A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used. Parameters: Name Type Description Default target None The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Module The submodule referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Module View Source def get_submodule ( self , target : str ) -> \"Module\" : \" \"\" Return the submodule given by ``target`` if it exists, otherwise throw an error. For example, let's say you have an ``nn.Module`` ``A`` that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.) To check whether or not we have the ``linear`` submodule, we would call ``get_submodule(\" net_b . linear \")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule(\" net_b . net_c . conv \")``. The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used. Args: target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) 
Returns: torch.nn.Module: The submodule referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Module`` \"\" \" if target == \"\" : return self atoms : List [ str ] = target . split ( \".\" ) mod : torch . nn . Module = self for item in atoms : if not hasattr ( mod , item ) : raise AttributeError ( mod . _get_name () + \" has no \" \"attribute `\" + item + \"`\" ) mod = getattr ( mod , item ) if not isinstance ( mod , torch . nn . Module ) : raise AttributeError ( \"`\" + item + \"` is not \" \"an nn.Module\" ) return mod half def half ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to half datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def half ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``half`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . half () if t . is_floating_point () else t ) ipu def ipu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def ipu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . ipu ( device )) load_state_dict def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) Copy parameters and buffers from :attr: state_dict into this module and its descendants. If :attr: strict is True , then the keys of :attr: state_dict must exactly match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. .. warning:: If :attr: assign is True the optimizer must be created after the call to :attr: load_state_dict unless :func: ~torch.__future__.get_swap_module_params_on_conversion is True . Parameters: Name Type Description Default state_dict dict a dict containing parameters and persistent buffers. None strict bool whether to strictly enforce that the keys in :attr: state_dict match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. Default: True None assign bool When False , the properties of the tensors in the current module are preserved while when True , the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of :class: ~torch.nn.Parameter s for which the value from the module is preserved. 
Default: False None

Returns: Type Description None NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys, unexpected_keys is a list of str containing the unexpected keys

View Source

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
    r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

    If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match
    the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function.

    .. warning::
        If :attr:`assign` is ``True`` the optimizer must be created after
        the call to :attr:`load_state_dict` unless
        :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.

    Args:
        state_dict (dict): a dict containing parameters and persistent buffers.
        strict (bool, optional): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
        assign (bool, optional): When ``False``, the properties of the tensors
            in the current module are preserved while when ``True``, the
            properties of the Tensors in the state dict are preserved. The only
            exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
            for which the value from the module is preserved. Default: ``False``

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    Note:
        If a parameter or buffer is registered as ``None`` and its corresponding key
        exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
        ``RuntimeError``.
    """
    if not isinstance(state_dict, Mapping):
        raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")

    missing_keys: List[str] = []
    unexpected_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        # mypy isn't aware that "_metadata" exists in state_dict
        state_dict._metadata = metadata  # type: ignore[attr-defined]

    def load(module, local_state_dict, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        if assign:
            local_metadata['assign_to_params_buffers'] = assign
        module._load_from_state_dict(
            local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                child_prefix = prefix + name + '.'
                child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                load(child, child_state_dict, child_prefix)  # noqa: F821

        # Note that the hook can modify missing_keys and unexpected_keys.
        incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
        for hook in module._load_state_dict_post_hooks.values():
            out = hook(module, incompatible_keys)
            assert out is None, (
                "Hooks registered with ``register_load_state_dict_post_hook`` are not"
                "expected to return new values, if incompatible_keys need to be modified,"
                "it should be done inplace."
            )

    load(self, state_dict)
    del load

    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join(f'"{k}"' for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join(f'"{k}"' for k in missing_keys)))

    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
            self.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)
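For completeness, a round-trip sketch pairing state_dict with load_state_dict (the model and file name are illustrative; strict=False is used here only to show the returned NamedTuple, since strict=True raises on any mismatch):

Example::

    >>> # xdoctest: +SKIP("undefined vars")
    >>> torch.save(model.state_dict(), "checkpoint.pt")
    >>> state = torch.load("checkpoint.pt")
    >>> missing, unexpected = model.load_state_dict(state, strict=False)
    >>> missing, unexpected  # both empty when the checkpoint matches the module exactly
    ([], [])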
modules

def modules ( self ) -> Iterator [ ForwardRef ( 'Module' )]

Return an iterator over all modules in the network.

Yields: Type Description Module a module in the network

View Source

def modules(self) -> Iterator['Module']:
    r"""Return an iterator over all modules in the network.

    Yields:
        Module: a module in the network

    Note:
        Duplicate modules are returned only once. In the following
        example, ``l`` will be returned only once.

    Example::

        >>> l = nn.Linear(2, 2)
        >>> net = nn.Sequential(l, l)
        >>> for idx, m in enumerate(net.modules()):
        ...     print(idx, '->', m)

        0 -> Sequential(
          (0): Linear(in_features=2, out_features=2, bias=True)
          (1): Linear(in_features=2, out_features=2, bias=True)
        )
        1 -> Linear(in_features=2, out_features=2, bias=True)
    """
    for _, module in self.named_modules():
        yield module

named_buffers

def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . Tensor ]]

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters: Name Type Description Default prefix str prefix to prepend to all buffer names. None recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. None remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True

Yields: Type Description None (str, torch.Tensor): Tuple containing the name and buffer

View Source

def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
    r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

    Args:
        prefix (str): prefix to prepend to all buffer names.
        recurse (bool, optional): if True, then yields buffers of this module
            and all submodules. Otherwise, yields only buffers that
            are direct members of this module. Defaults to True.
        remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.

    Yields:
        (str, torch.Tensor): Tuple containing the name and buffer

    Example::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> for name, buf in self.named_buffers():
        >>>     if name in ['running_var']:
        >>>         print(buf.size())
    """
    gen = self._named_members(
        lambda module: module._buffers.items(),
        prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
    yield from gen

named_children

def named_children ( self ) -> Iterator [ Tuple [ str , ForwardRef ( 'Module' )]]

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields: Type Description None (str, Module): Tuple containing a name and child module

View Source

def named_children(self) -> Iterator[Tuple[str, 'Module']]:
    r"""Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
Yields: (str, Module): Tuple containing a name and child module Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module) \"\"\" memo = set () for name , module in self . _modules . items (): if module is not None and module not in memo : memo . add ( module ) yield name , module named_modules def named_modules ( self , memo : Optional [ Set [ ForwardRef ( 'Module' )]] = None , prefix : str = '' , remove_duplicate : bool = True ) Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Parameters: Name Type Description Default memo None a memo to store the set of modules already added to the result None prefix None a prefix that will be added to the name of the module None remove_duplicate None whether to remove the duplicated module instances in the result or not None Yields: Type Description None (str, Module): Tuple of name and module View Source def named_modules ( self , memo : Optional [ Set [ 'Module' ]] = None , prefix : str = '' , remove_duplicate : bool = True ) : r \" \"\" Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: memo: a memo to store the set of modules already added to the result prefix: a prefix that will be added to the name of the module remove_duplicate: whether to remove the duplicated module instances in the result or not Yields: (str, Module): Tuple of name and module Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) \"\" \" if memo is None : memo = set () if self not in memo : if remove_duplicate : memo . add ( self ) yield prefix , self for name , module in self . _modules . items () : if module is None : continue submodule_prefix = prefix + ( '.' if prefix else '' ) + name yield from module . named_modules ( memo , submodule_prefix , remove_duplicate ) named_parameters def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . nn . parameter . Parameter ]] Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Parameters: Name Type Description Default prefix str prefix to prepend to all parameter names. None recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. None Yields: Type Description None (str, Parameter): Tuple containing the name and parameter View Source def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Parameter ]]: r \"\"\"Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. 
Otherwise, yields only parameters that are direct members of this module. remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True. Yields: (str, Parameter): Tuple containing the name and parameter Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) \"\"\" gen = self . _named_members ( lambda module : module . _parameters . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen parameters def parameters ( self , recurse : bool = True ) -> Iterator [ torch . nn . parameter . Parameter ] Return an iterator over module parameters. This is typically passed to an optimizer. Parameters: Name Type Description Default recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None Yields: Type Description Parameter module parameter View Source def parameters ( self , recurse : bool = True ) -> Iterator [ Parameter ] : r \"\"\"Return an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Yields: Parameter: module parameter Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for param in model.parameters(): >>> print(type(param), param.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for name , param in self . named_parameters ( recurse = recurse ) : yield param register_backward_hook def register_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]] ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. This function is deprecated in favor of :meth: ~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_backward_hook ( self , hook: Callable [[' Module ', _grad_t , _grad_t ], Union [ None , _grad_t ]] ) -> RemovableHandle: r \"\"\"Register a backward hook on the module. This function is deprecated in favor of : meth: ` ~ torch . nn . Module . register_full_backward_hook ` and the behavior of this function will change in future versions . Returns: : class: `torch . utils . hooks . RemovableHandle ` : a handle that can be used to remove the added hook by calling ` `handle . remove () `` \"\"\" if self . _is_full_backward_hook is True: raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = False handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook return handle register_buffer def register_buffer ( self , name : str , tensor : Optional [ torch . Tensor ], persistent : bool = True ) -> None Add a buffer to the module. This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. 
Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr: persistent to False . The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr: state_dict . Buffers can be accessed as attributes using given names. Parameters: Name Type Description Default name str name of the buffer. The buffer can be accessed from this module using the given name None tensor Tensor or None buffer to be registered. If None , then operations that run on buffers, such as :attr: cuda , are ignored. If None , the buffer is not included in the module's :attr: state_dict . None persistent bool whether the buffer is part of this module's :attr: state_dict . None View Source def register_buffer ( self , name : str , tensor : Optional [ Tensor ] , persistent : bool = True ) -> None : r \" \"\" Add a buffer to the module. This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`. Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> self.register_buffer('running_mean', torch.zeros(num_features)) \"\" \" if persistent is False and isinstance ( self , torch . jit . ScriptModule ) : raise RuntimeError ( \"ScriptModule does not support non-persistent buffers\" ) if '_buffers' not in self . __dict__ : raise AttributeError ( \"cannot assign buffer before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"buffer name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"buffer name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"buffer name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _buffers : raise KeyError ( f \"attribute '{name}' already exists\" ) elif tensor is not None and not isinstance ( tensor , torch . Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self . _buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name ) register_forward_hook def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils . 
hooks . RemovableHandle

Register a forward hook on the module.

The hook will be called every time after :func:`forward` has computed an output.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature::

hook ( module , args , output ) -> None or modified output

If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature::

hook ( module , args , kwargs , output ) -> None or modified output

Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class:`torch.nn.modules.Module` . Otherwise, the provided hook will be fired after all existing forward hooks on this :class:`torch.nn.modules.Module` . Note that global forward hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False None

Returns: Type Description None :class:`torch.utils.hooks.RemovableHandle` : a handle that can be used to remove the added hook by calling handle.remove()

View Source

def register_forward_hook(
    self,
    hook: Union[
        Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
        Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
    ],
    *,
    prepend: bool = False,
    with_kwargs: bool = False,
    always_call: bool = False,
) -> RemovableHandle:
    r"""Register a forward hook on the module.

    The hook will be called every time after :func:`forward` has computed an output.

    If ``with_kwargs`` is ``False`` or not specified, the input contains only
    the positional arguments given to the module. Keyword arguments won't be
    passed to the hooks and only to the ``forward``. The hook can modify the
    output. It can modify the input inplace but it will not have effect on
    forward since this is called after :func:`forward` is called. The hook
    should have the following signature::

        hook(module, args, output) -> None or modified output

    If ``with_kwargs`` is ``True``, the forward hook will be passed the
    ``kwargs`` given to the forward function and be expected to return the
    output possibly modified. The hook should have the following signature::

        hook(module, args, kwargs, output) -> None or modified output

    Args:
        hook (Callable): The user defined hook to be registered.
        prepend (bool): If ``True``, the provided ``hook`` will be fired
            before all existing ``forward`` hooks on this
            :class:`torch.nn.modules.Module`. Otherwise, the provided
            ``hook`` will be fired after all existing ``forward`` hooks on
            this :class:`torch.nn.modules.Module`. Note that global
            ``forward`` hooks registered with
            :func:`register_module_forward_hook` will fire before all hooks
            registered by this method. Default: ``False``
        with_kwargs (bool): If ``True``, the ``hook`` will be passed the
            kwargs given to the forward function. Default: ``False``
        always_call (bool): If ``True`` the ``hook`` will be run regardless of
            whether an exception is raised while calling the Module.
            Default: ``False``

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(
        self._forward_hooks,
        extra_dict=[self._forward_hooks_with_kwargs, self._forward_hooks_always_called],
    )
    self._forward_hooks[handle.id] = hook
    if with_kwargs:
        self._forward_hooks_with_kwargs[handle.id] = True
    if always_call:
        self._forward_hooks_always_called[handle.id] = True
    if prepend:
        self._forward_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
    return handle
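A quick sketch of attaching and removing a forward hook (the module and shapes are arbitrary examples):

Example::

    >>> import torch
    >>> from torch import nn
    >>> net = nn.Linear(4, 2)
    >>> def log_shape(module, args, output):
    ...     print(type(module).__name__, tuple(output.shape))
    >>> handle = net.register_forward_hook(log_shape)
    >>> _ = net(torch.randn(3, 4))
    Linear (3, 2)
    >>> handle.remove()  # the hook no longer fires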
register_forward_pre_hook

def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~T , Tuple [ Any , ... ], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle

Register a forward pre-hook on the module.

The hook will be called every time before :func:`forward` is invoked.

If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input.
User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_full_backward_hook def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. 
None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . _is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_full_backward_pre_hook def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. 
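For example, a sketch (assuming a PyTorch version that provides this method) of a backward pre-hook that logs gradient norms as they arrive:
import torch
from torch import nn

layer = nn.Linear(4, 2)

def log_grad_output(module, grad_output):
    # grad_output is a tuple of gradients with respect to the module outputs
    print('grad_output norm:', grad_output[0].norm().item())

handle = layer.register_full_backward_pre_hook(log_grad_output)
layer(torch.randn(3, 4)).sum().backward()  # triggers the hook
handle.remove()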
The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_load_state_dict_post_hook def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. 
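As a quick illustration before the signature is spelled out below, a sketch of such a hook that reports a partial load; the reporting logic is hypothetical:
import torch
from torch import nn

model = nn.Linear(4, 2)

def report_incompatible(module, incompatible_keys):
    # incompatible_keys exposes .missing_keys and .unexpected_keys lists
    if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys:
        print('partial load:', incompatible_keys)

model.register_load_state_dict_post_hook(report_incompatible)
model.load_state_dict(model.state_dict())  # clean load, nothing is printed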
It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . id ] = hook return handle register_module def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . View Source def register_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \" \"\" Alias for :func:`add_module`. \"\" \" self . add_module ( name , module ) register_parameter def register_parameter ( self , name : str , param : Optional [ torch . nn . parameter . Parameter ] ) -> None Add a parameter to the module. The parameter can be accessed as an attribute using given name. Parameters: Name Type Description Default name str name of the parameter. The parameter can be accessed from this module using the given name None param Parameter or None parameter to be added to the module. If None , then operations that run on parameters, such as :attr: cuda , are ignored. If None , the parameter is not included in the module's :attr: state_dict . None View Source def register_parameter ( self , name : str , param : Optional [ Parameter ] ) -> None : r \" \"\" Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. 
If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. \"\" \" if '_parameters' not in self . __dict__ : raise AttributeError ( \"cannot assign parameter before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"parameter name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"parameter name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"parameter name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _parameters : raise KeyError ( f \"attribute '{name}' already exists\" ) if param is None : self . _parameters [ name ] = None elif not isinstance ( param , Parameter ) : raise TypeError ( f \"cannot assign '{torch.typename(param)}' object to parameter '{name}' \" \"(torch.nn.Parameter or None required)\" ) elif param . grad_fn : raise ValueError ( f \"Cannot assign non-leaf Tensor to parameter '{name}'. Model \" f \"parameters must be created explicitly. To express '{name}' \" \"as a function of another Tensor, compute the value in \" \"the forward() method.\" ) else : for hook in _global_parameter_registration_hooks . values () : output = hook ( self , name , param ) if output is not None : param = output self . _parameters [ name ] = param register_state_dict_pre_hook def register_state_dict_pre_hook ( self , hook ) Register a pre-hook for the :meth: ~torch.nn.Module.state_dict method. These hooks will be called with arguments: self , prefix , and keep_vars before calling state_dict on self . The registered hooks can be used to perform pre-processing before the state_dict call is made. View Source def register_state_dict_pre_hook ( self , hook ) : r \" \"\" Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. \"\" \" handle = hooks . RemovableHandle ( self . _state_dict_pre_hooks ) self . _state_dict_pre_hooks [ handle . id ] = hook return handle requires_grad_ def requires_grad_ ( self : ~ T , requires_grad : bool = True ) -> ~ T Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr: requires_grad attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref: locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it. Parameters: Name Type Description Default requires_grad bool whether autograd should record operations on parameters in this module. Default: True . None Returns: Type Description Module self View Source def requires_grad_ ( self : T , requires_grad : bool = True ) -> T : r \" \"\" Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it. 
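For instance, a minimal freezing sketch (the two-layer setup is hypothetical):
import torch
from torch import nn

backbone = nn.Linear(8, 8)
head = nn.Linear(8, 2)

backbone.requires_grad_(False)  # freeze the backbone in-place
head(backbone(torch.randn(1, 8))).sum().backward()
print(backbone.weight.grad)           # None: nothing recorded for frozen params
print(head.weight.grad is not None)   # True: the head still trains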
Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self \"\" \" for p in self . parameters () : p . requires_grad_ ( requires_grad ) return self set_extra_state def set_extra_state ( self , state : Any ) -> None Set extra state contained in the loaded state_dict . This function is called from :func: load_state_dict to handle any extra state found within the state_dict . Implement this function and a corresponding :func: get_extra_state for your module if you need to store extra state within its state_dict . View Source def set_extra_state ( self , state : Any ) -> None : \" \"\" Set extra state contained in the loaded `state_dict`. This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`. Args: state (dict): Extra state from the `state_dict` \"\" \" raise RuntimeError ( \"Reached a code path in Module.set_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" ) share_memory def share_memory ( self : ~ T ) -> ~ T See :meth: torch.Tensor.share_memory_ . View Source def share_memory ( self : T ) -> T : r \"\"\"See :meth:`torch.Tensor.share_memory_`.\"\"\" return self . _apply ( lambda t : t . share_memory_ ()) state_dict def state_dict ( self , * args , destination = None , prefix = '' , keep_vars = False ) Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently state_dict() also accepts positional arguments for destination , prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument destination as it is not designed for end-users. Parameters: Name Type Description Default destination dict If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None . None prefix str a prefix added to parameter and buffer names to compose the keys in state_dict. Default: '' . None keep_vars bool by default the :class: ~torch.Tensor s returned in the state dict are detached from autograd. If it's set to True , detaching will not be performed. Default: False . None Returns: Type Description dict a dictionary containing a whole state of the module View Source def state_dict ( self , * args , destination = None , prefix='' , keep_vars = False ) : r \"\"\"Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars`` in order. However, this is being deprecated and keyword arguments will be enforced in future releases. ..
warning:: Please avoid the use of argument ``destination`` as it is not designed for end-users. Args: destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``. prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. Returns: dict: a dictionary containing a whole state of the module Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> module.state_dict().keys() ['bias', 'weight'] \"\"\" # TODO : Remove ` args ` and the parsing logic when BC allows . if len ( args ) > 0 : if destination is None : destination = args [ 0 ] if len ( args ) > 1 and prefix == '' : prefix = args [ 1 ] if len ( args ) > 2 and keep_vars is False : keep_vars = args [ 2 ] # DeprecationWarning is ignored by default warnings . warn ( \"Positional args are being deprecated, use kwargs instead. Refer to \" \"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict\" \" for details.\" ) if destination is None : destination = OrderedDict () destination . _metadata = OrderedDict () local_metadata = dict ( version = self . _version ) if hasattr ( destination , \"_metadata\" ) : destination . _metadata [ prefix [ :- 1 ]] = local_metadata for hook in self . _state_dict_pre_hooks . values () : hook ( self , prefix , keep_vars ) self . _save_to_state_dict ( destination , prefix , keep_vars ) for name , module in self . _modules . items () : if module is not None : module . state_dict ( destination = destination , prefix = prefix + name + '.' , keep_vars = keep_vars ) for hook in self . _state_dict_hooks . values () : hook_result = hook ( self , destination , prefix , local_metadata ) if hook_result is not None : destination = hook_result return destination to def to ( self , * args , ** kwargs ) Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth: torch.Tensor.to , but only accepts floating point or complex :attr: dtype \ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr: dtype (if given). The integral parameters and buffers will be moved :attr: device , if that is given, but with dtypes unchanged. When :attr: non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place.
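A small usage sketch (guarded, since a CUDA device may not be present on the host):
import torch
from torch import nn

model = nn.Linear(4, 2)
model.to(torch.float64)  # casts floating point parameters in-place
if torch.cuda.is_available():
    model.to(torch.device('cuda'), non_blocking=True)
print(model.weight.dtype)  # torch.float64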
Parameters: Name Type Description Default device ( None class: torch.device ): the desired device of the parameters and buffers in this module None dtype ( None class: torch.dtype ): the desired floating point or complex dtype of the parameters and buffers in this module None tensor torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module None memory_format ( None class: torch.memory_format ): the desired memory format for 4D parameters and buffers in this module (keyword only argument) None Returns: Type Description Module self View Source def to ( self , * args , ** kwargs ) : r \" \"\" Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype` \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self Examples:: >>> # xdoctest: +IGNORE_WANT(\" non - deterministic \") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device(\" cuda : 1 \") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device(\" cpu \") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) \"\" \" device , dtype , non_blocking , convert_to_format = torch . _C . _nn . _parse_to ( * args , ** kwargs ) if dtype is not None : if not ( dtype . is_floating_point or dtype . 
is_complex ) : raise TypeError ( 'nn.Module.to only accepts floating point or complex ' f 'dtypes, but got desired dtype={dtype}' ) if dtype . is_complex : warnings . warn ( \"Complex modules are a new feature under active development whose design may change, \" \"and some modules might not work as expected when using complex tensors as parameters or buffers. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"if a complex module does not work as expected.\" ) def convert ( t ) : try : if convert_to_format is not None and t . dim () in ( 4 , 5 ) : return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , memory_format = convert_to_format , ) return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , ) except NotImplementedError as e : if str ( e ) == \"Cannot copy out of meta tensor; no data!\" : raise NotImplementedError ( f \"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() \" f \"when moving module from meta to a different device.\" ) from None else : raise return self . _apply ( convert ) to_empty def to_empty ( self : ~ T , * , device : Union [ int , str , torch . device , NoneType ], recurse : bool = True ) -> ~ T Move the parameters and buffers to the specified device without copying storage. Parameters: Name Type Description Default device ( None class: torch.device ): The desired device of the parameters and buffers in this module. None recurse bool Whether parameters and buffers of submodules should be recursively moved to the specified device. None Returns: Type Description Module self View Source def to_empty ( self : T , * , device : Optional [ DeviceLikeType ] , recurse : bool = True ) -> T : r \"\"\"Move the parameters and buffers to the specified device without copying storage. Args: device (:class:`torch.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. Returns: Module: self \"\"\" return self . _apply ( lambda t : torch . empty_like ( t , device = device ), recurse = recurse ) train def train ( self : ~ T , mode : bool = True ) -> ~ T Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. Parameters: Name Type Description Default mode bool whether to set training mode ( True ) or evaluation mode ( False ). Default: True . None Returns: Type Description Module self View Source def train ( self : T , mode : bool = True ) -> T : r \" \"\" Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self \"\" \" if not isinstance ( mode , bool ) : raise ValueError ( \"training mode is expected to be boolean\" ) self . training = mode for module in self . children () : module . train ( mode ) return self type def type ( self : ~ T , dst_type : Union [ torch . dtype , str ] ) -> ~ T Casts all parameters and buffers to :attr: dst_type . .. 
note:: This method modifies the module in-place. Parameters: Name Type Description Default dst_type type or string the desired type None Returns: Type Description Module self View Source def type ( self : T , dst_type : Union [ dtype , str ] ) -> T : r \" \"\" Casts all parameters and buffers to :attr:`dst_type`. .. note:: This method modifies the module in-place. Args: dst_type (type or string): the desired type Returns: Module: self \"\" \" return self . _apply ( lambda t : t . type ( dst_type )) xpu def xpu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def xpu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . xpu ( device )) zero_grad def zero_grad ( self , set_to_none : bool = True ) -> None Reset gradients of all model parameters. See similar function under :class: torch.optim.Optimizer for more context. Parameters: Name Type Description Default set_to_none bool instead of setting to zero, set the grads to None. See :meth: torch.optim.Optimizer.zero_grad for details. None View Source def zero_grad ( self , set_to_none : bool = True ) -> None : r \"\"\"Reset gradients of all model parameters. See similar function under :class:`torch.optim.Optimizer` for more context. Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. \"\"\" if getattr ( self , '_is_replica' , False ) : warnings . warn ( \"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. \" \"The parameters are copied (in a differentiable manner) from the original module. \" \"This means they are not leaf nodes in autograd and so don't accumulate gradients. \" \"If you need gradients in your forward method, consider using autograd.grad instead.\" ) for p in self.parameters(): if p.grad is not None: if set_to_none: p.grad = None else: if p.grad.grad_fn is not None: p.grad.detach_() else: p.grad.requires_grad_(False) p.grad.zero_()","title":"Mlp"},{"location":"reference/wtracker/neural/mlp/#module-wtrackerneuralmlp","text":"View Source from torch import Tensor , nn from typing import Union , Sequence from collections import defaultdict from wtracker.neural.config import IOConfig ACTIVATIONS = { \"relu\" : nn . ReLU , \"tanh\" : nn . Tanh , \"sigmoid\" : nn . Sigmoid , \"softmax\" : nn . Softmax , \"logsoftmax\" : nn . LogSoftmax , \"lrelu\" : nn . LeakyReLU , \"none\" : nn . Identity , None : nn . Identity , } # Default keyword arguments to pass to activation class constructors, e.g.
# activation_cls(**ACTIVATION_DEFAULT_KWARGS[name]) ACTIVATION_DEFAULT_KWARGS = defaultdict ( dict , { ### \"softmax\" : dict ( dim = 1 ), \"logsoftmax\" : dict ( dim = 1 ), }, ) class WormPredictor ( nn . Module ): \"\"\" A class that represents neural network models that predict worm behavior. After a model is created from several layers or blocks, it is wrapped in this class so that it can be distinguished from other models that don't predict worm behavior (for example the layers/blocks that make this model). This class also holds the IOConfig object that is used to determine the input and output shapes of the model, and the specific frames it expects as input and output. Attributes: model: The neural network model that predicts worm behavior. io_config: The IOConfig object of the model. \"\"\" def __init__ ( self , model : nn . Module , io_config : IOConfig ): super () . __init__ () self . io_config : IOConfig = io_config self . model : nn . Module = model def forward ( self , x : Tensor ) -> Tensor : return self . model ( x ) class MLPLayer ( nn . Module ): \"\"\" A single-layer perceptron that can hold batch-norm and activation layers as well. \"\"\" def __init__ ( self , in_dim : int , out_dim : Sequence [ int ], nonlin : Union [ str , nn . Module ], batch_norm : bool = True , ) -> None : super () . __init__ () layers = [] layers . append ( nn . Linear ( in_dim , out_dim )) in_dim = out_dim if batch_norm and nonlin not in [ \"none\" , None ]: layers . append ( nn . BatchNorm1d ( out_dim )) layers . append ( self . _make_activation ( nonlin )) self . mlp_layer = nn . Sequential ( * layers ) def _make_activation ( self , act : Union [ str , nn . Module ]) -> nn . Module : if isinstance ( act , str ): return ACTIVATIONS [ act ]( ** ACTIVATION_DEFAULT_KWARGS [ act ]) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . mlp_layer . forward ( x . reshape ( x . size ( 0 ), - 1 )) class MlpBlock ( nn . Module ): \"\"\" A general-purpose MLP. Args: in_dim: Input dimension. dims: Hidden dimensions, including output dimension. nonlins: Non-linearities to apply after each one of the hidden dimensions. Can be either a sequence of strings which are keys in the ACTIVATIONS dict, or instances of nn.Module (e.g. an instance of nn.ReLU()). Length should match 'dims'. \"\"\" def __init__ ( self , in_dim : int , dims : Sequence [ int ], nonlins : Sequence [ Union [ str , nn . Module ]], batch_norm : bool = True , ): assert len ( nonlins ) == len ( dims ) self . in_dim = in_dim self . out_dim = dims [ - 1 ] self . dims = dims self . nonlins = nonlins super () . __init__ () layers = [] for i , out_dim in enumerate ( self . dims ): layers . append ( MLPLayer ( in_dim , out_dim , nonlins [ i ], batch_norm )) in_dim = out_dim self . sequence = nn . Sequential ( * layers ) def _make_activation ( self , act : Union [ str , nn . Module ]) -> nn . Module : if isinstance ( act , str ): return ACTIVATIONS [ act ]( ** ACTIVATION_DEFAULT_KWARGS [ act ]) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . sequence . forward ( x . reshape ( x . size ( 0 ), - 1 )) class RMLP ( nn .
Module ): def __init__ ( self , block_in_dim : int , block_dims : Sequence [ int ], block_nonlins : Sequence [ Union [ str , nn . Module ]], n_blocks : int , out_dim : int , in_dim : int = None , # if in_dim is an int, then a first layer will be made batch_norm : bool = True , ) -> None : super () . __init__ () # Create first layer if in_dim is not None self . input = nn . Identity () if in_dim is not None : self . input = MLPLayer ( in_dim , block_in_dim , block_nonlins [ 0 ], batch_norm ) # Create blocks layers = [] for i in range ( n_blocks ): layers . append ( MlpBlock ( block_in_dim , block_dims , block_nonlins , batch_norm )) self . blocks = nn . ModuleList ( layers ) # Create output layer self . output = nn . Linear ( block_dims [ - 1 ], out_dim ) def _make_activation ( self , act : Union [ str , nn . Module ]) -> nn . Module : if isinstance ( act , str ): return ACTIVATIONS [ act ]( ** ACTIVATION_DEFAULT_KWARGS [ act ]) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" x = self . input ( x ) for block in self . blocks : out = block ( x ) x = x + out return self . output ( x )","title":"Module wtracker.neural.mlp"},{"location":"reference/wtracker/neural/mlp/#variables","text":"ACTIVATIONS ACTIVATION_DEFAULT_KWARGS","title":"Variables"},{"location":"reference/wtracker/neural/mlp/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/neural/mlp/#mlplayer","text":"class MLPLayer ( in_dim : int , out_dim : Sequence [ int ], nonlin : Union [ str , torch . nn . modules . module . Module ], batch_norm : bool = True ) A single-layer perceptron that can hold batch-norm and activation layers as well. View Source class MLPLayer ( nn . Module ) : \"\"\" A single-layer perceptron that can hold batch-norm and activation layers as well. \"\"\" def __init__ ( self , in_dim : int , out_dim : Sequence [ int ] , nonlin : Union [ str, nn.Module ] , batch_norm : bool = True , ) -> None : super (). __init__ () layers = [] layers . append ( nn . Linear ( in_dim , out_dim )) in_dim = out_dim if batch_norm and nonlin not in [ \"none\", None ] : layers . append ( nn . BatchNorm1d ( out_dim )) layers . append ( self . _make_activation ( nonlin )) self . mlp_layer = nn . Sequential ( * layers ) def _make_activation ( self , act : Union [ str, nn.Module ] ) -> nn . Module : if isinstance ( act , str ) : return ACTIVATIONS [ act ] ( ** ACTIVATION_DEFAULT_KWARGS [ act ] ) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . mlp_layer . forward ( x . reshape ( x . size ( 0 ), - 1 ))","title":"MLPLayer"},{"location":"reference/wtracker/neural/mlp/#ancestors-in-mro","text":"torch.nn.modules.module.Module","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/mlp/#class-variables","text":"T_destination call_super_init dump_patches","title":"Class variables"},{"location":"reference/wtracker/neural/mlp/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/neural/mlp/#add_module","text":"def add_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Add a child module to the current module. The module can be accessed as an attribute using the given name.
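Tying the wtracker classes documented above together, a construction sketch; all dimensions and activation choices here are hypothetical:
import torch
from wtracker.neural.mlp import MLPLayer, MlpBlock

block = MlpBlock(in_dim=16, dims=[32, 32, 8], nonlins=['relu', 'relu', 'none'])
layer = MLPLayer(in_dim=8, out_dim=4, nonlin='tanh', batch_norm=False)
x = torch.randn(5, 16)
print(layer(block(x)).shape)  # torch.Size([5, 4])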
Parameters: Name Type Description Default name str name of the child module. The child module can be accessed from this module using the given name None module Module child module to be added to the module. None View Source def add_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \"\"\"Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (str): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module. \"\"\" if not isinstance ( module , Module ) and module is not None : raise TypeError ( f \"{torch.typename(module)} is not a Module subclass\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"module name should be a string. Got {torch.typename(name)}\" ) elif hasattr ( self , name ) and name not in self . _modules : raise KeyError ( f \"attribute '{name}' already exists\" ) elif '.' in name : raise KeyError ( f \"module name can't contain \\\" . \\ \", got: {name}\" ) elif name == '' : raise KeyError ( \"module name can't be empty string \\\" \\ \"\" ) for hook in _global_module_registration_hooks . values () : output = hook ( self , name , module ) if output is not None : module = output self . _modules [ name ] = module","title":"add_module"},{"location":"reference/wtracker/neural/mlp/#apply","text":"def apply ( self : ~ T , fn : Callable [[ ForwardRef ( 'Module' )], NoneType ] ) -> ~ T Apply fn recursively to every submodule (as returned by .children() ) as well as self. Typical use includes initializing the parameters of a model (see also :ref: nn-init-doc ). Parameters: Name Type Description Default fn ( None class: Module -> None): function to be applied to each submodule None Returns: Type Description Module self View Source def apply ( self : T , fn : Callable [[ 'Module' ] , None ] ) -> T : r \" \"\" Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example:: >>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) \"\" \" for module in self . children () : module . apply ( fn ) fn ( self ) return self","title":"apply"},{"location":"reference/wtracker/neural/mlp/#bfloat16","text":"def bfloat16 ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to bfloat16 datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def bfloat16 ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``bfloat16`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . bfloat16 () if t . 
is_floating_point () else t )","title":"bfloat16"},{"location":"reference/wtracker/neural/mlp/#buffers","text":"def buffers ( self , recurse : bool = True ) -> Iterator [ torch . Tensor ] Return an iterator over module buffers. Parameters: Name Type Description Default recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. None Yields: Type Description torch.Tensor module buffer View Source def buffers ( self , recurse : bool = True ) -> Iterator [ Tensor ] : r \"\"\"Return an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: torch.Tensor: module buffer Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for buf in model.buffers(): >>> print(type(buf), buf.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for _ , buf in self . named_buffers ( recurse = recurse ) : yield buf","title":"buffers"},{"location":"reference/wtracker/neural/mlp/#children","text":"def children ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over immediate children modules. Yields: Type Description Module a child module View Source def children ( self ) -> Iterator [ 'Module']: r \"\"\"Return an iterator over immediate children modules. Yields: Module: a child module \"\"\" for name , module in self . named_children (): yield module","title":"children"},{"location":"reference/wtracker/neural/mlp/#compile","text":"def compile ( self , * args , ** kwargs ) Compile this Module's forward using :func: torch.compile . This Module's __call__ method is compiled and all arguments are passed as-is to :func: torch.compile . See :func: torch.compile for details on the arguments for this function. View Source def compile ( self , * args , ** kwargs ) : \" \"\" Compile this Module's forward using :func:`torch.compile`. This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`. See :func:`torch.compile` for details on the arguments for this function. \"\" \" self . _compiled_call_impl = torch . compile ( self . _call_impl , * args , ** kwargs )","title":"compile"},{"location":"reference/wtracker/neural/mlp/#cpu","text":"def cpu ( self : ~ T ) -> ~ T Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def cpu ( self : T ) -> T : r \"\"\"Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cpu ())","title":"cpu"},{"location":"reference/wtracker/neural/mlp/#cuda","text":"def cuda ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def cuda ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. 
So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cuda ( device ))","title":"cuda"},{"location":"reference/wtracker/neural/mlp/#double","text":"def double ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to double datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def double ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``double`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . double () if t . is_floating_point () else t )","title":"double"},{"location":"reference/wtracker/neural/mlp/#eval","text":"def eval ( self : ~ T ) -> ~ T Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. This is equivalent with :meth: self.train(False) . See :ref: locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it. Returns: Type Description Module self View Source def eval ( self : T ) -> T : r \" \"\" Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent with :meth:`self.train(False) `. See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it. Returns: Module: self \"\" \" return self . train ( False )","title":"eval"},{"location":"reference/wtracker/neural/mlp/#extra_repr","text":"def extra_repr ( self ) -> str Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. View Source def extra_repr ( self ) -> str : r \"\"\"Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. \"\"\" return ''","title":"extra_repr"},{"location":"reference/wtracker/neural/mlp/#float","text":"def float ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to float datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def float ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``float`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . float () if t . is_floating_point () else t )","title":"float"},{"location":"reference/wtracker/neural/mlp/#forward","text":"def forward ( self , x : torch . Tensor ) -> torch . Tensor Parameters: Name Type Description Default x None An input tensor, of shape (N, D) containing N samples with D features. None Returns: Type Description None An output tensor of shape (N, D_out) where D_out is the output dim. 
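Because the source below flattens its input to (N, -1), non-flat inputs are accepted as long as the trailing dimensions multiply to in_dim; a sketch with hypothetical sizes:
import torch
from wtracker.neural.mlp import MLPLayer

layer = MLPLayer(in_dim=12, out_dim=3, nonlin='relu')
x = torch.randn(5, 4, 3)  # trailing dims flatten to D = 12
print(layer(x).shape)     # torch.Size([5, 3])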
View Source def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . mlp_layer . forward ( x . reshape ( x . size ( 0 ), - 1 ))","title":"forward"},{"location":"reference/wtracker/neural/mlp/#get_buffer","text":"def get_buffer ( self , target : str ) -> 'Tensor' Return the buffer given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.Tensor The buffer referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not a buffer View Source def get_buffer ( self , target : str ) -> \"Tensor\" : \" \"\" Return the buffer given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the buffer to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.Tensor: The buffer referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not a buffer \"\" \" module_path , _ , buffer_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , buffer_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + buffer_name + \"`\" ) buffer : torch . Tensor = getattr ( mod , buffer_name ) if buffer_name not in mod . _buffers : raise AttributeError ( \"`\" + buffer_name + \"` is not a buffer\" ) return buffer","title":"get_buffer"},{"location":"reference/wtracker/neural/mlp/#get_extra_state","text":"def get_extra_state ( self ) -> Any Return any extra state to include in the module's state_dict. Implement this and a corresponding :func: set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict() . Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: Type Description object Any extra state to store in the module's state_dict View Source def get_extra_state ( self ) -> Any : \" \"\" Return any extra state to include in the module's state_dict. Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`. Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.
Returns: object: Any extra state to store in the module's state_dict \"\" \" raise RuntimeError ( \"Reached a code path in Module.get_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"get_extra_state"},{"location":"reference/wtracker/neural/mlp/#get_parameter","text":"def get_parameter ( self , target : str ) -> 'Parameter' Return the parameter given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Parameter The Parameter referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Parameter View Source def get_parameter ( self , target : str ) -> \"Parameter\" : \" \"\" Return the parameter given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the Parameter to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.nn.Parameter: The Parameter referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Parameter`` \"\" \" module_path , _ , param_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , param_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + param_name + \"`\" ) param : torch . nn . Parameter = getattr ( mod , param_name ) if not isinstance ( param , torch . nn . Parameter ) : raise AttributeError ( \"`\" + param_name + \"` is not an \" \"nn.Parameter\" ) return param","title":"get_parameter"},{"location":"reference/wtracker/neural/mlp/#get_submodule","text":"def get_submodule ( self , target : str ) -> 'Module' Return the submodule given by target if it exists, otherwise throw an error. For example, let's say you have an nn.Module A that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an nn.Module A . A has a nested submodule net_b , which itself has two submodules net_c and linear . net_c then has a submodule conv .) To check whether or not we have the linear submodule, we would call get_submodule(\"net_b.linear\") . To check whether we have the conv submodule, we would call get_submodule(\"net_b.net_c.conv\") . The runtime of get_submodule is bounded by the degree of module nesting in target . A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used. Parameters: Name Type Description Default target None The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) 
None Returns: Type Description torch.nn.Module The submodule referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Module View Source def get_submodule ( self , target : str ) -> \"Module\" : \" \"\" Return the submodule given by ``target`` if it exists, otherwise throw an error. For example, let's say you have an ``nn.Module`` ``A`` that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.) To check whether or not we have the ``linear`` submodule, we would call ``get_submodule(\" net_b . linear \")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule(\" net_b . net_c . conv \")``. The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used. Args: target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) Returns: torch.nn.Module: The submodule referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Module`` \"\" \" if target == \"\" : return self atoms : List [ str ] = target . split ( \".\" ) mod : torch . nn . Module = self for item in atoms : if not hasattr ( mod , item ) : raise AttributeError ( mod . _get_name () + \" has no \" \"attribute `\" + item + \"`\" ) mod = getattr ( mod , item ) if not isinstance ( mod , torch . nn . Module ) : raise AttributeError ( \"`\" + item + \"` is not \" \"an nn.Module\" ) return mod","title":"get_submodule"},{"location":"reference/wtracker/neural/mlp/#half","text":"def half ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to half datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def half ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``half`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . half () if t . is_floating_point () else t )","title":"half"},{"location":"reference/wtracker/neural/mlp/#ipu","text":"def ipu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def ipu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. 
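A minimal sketch of the fully-qualified lookups described by ``get_submodule`` and ``get_parameter`` above, using an illustrative toy model:

>>> import torch.nn as nn
>>> net = nn.Sequential(nn.Linear(2, 3), nn.ReLU())
>>> net.get_submodule('0')  # walk the dotted path one level down
Linear(in_features=2, out_features=3, bias=True)
>>> net.get_parameter('0.weight').shape  # weight is stored as (out_features, in_features)
torch.Size([3, 2])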
So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . ipu ( device ))","title":"ipu"},{"location":"reference/wtracker/neural/mlp/#load_state_dict","text":"def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) Copy parameters and buffers from :attr: state_dict into this module and its descendants. If :attr: strict is True , then the keys of :attr: state_dict must exactly match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. .. warning:: If :attr: assign is True the optimizer must be created after the call to :attr: load_state_dict unless :func: ~torch.__future__.get_swap_module_params_on_conversion is True . Parameters: Name Type Description Default state_dict dict a dict containing parameters and persistent buffers. None strict bool whether to strictly enforce that the keys in :attr: state_dict match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. Default: True None assign bool When False , the properties of the tensors in the current module are preserved while when True , the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of :class: ~torch.nn.Parameter s for which the value from the module is preserved. Default: False None Returns: Type Description None NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys unexpected_keys is a list of str containing the unexpected keys View Source def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) : r \"\"\"Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. 
\"\"\" if not isinstance ( state_dict , Mapping ) : raise TypeError ( f \"Expected state_dict to be dict-like, got {type(state_dict)}.\" ) missing_keys : List [ str ] = [] unexpected_keys : List [ str ] = [] error_msgs : List [ str ] = [] # copy state_dict so _ load_from_state_dict can modify it metadata = getattr ( state_dict , '_metadata' , None ) state_dict = OrderedDict ( state_dict ) if metadata is not None : # mypy isn't aware that \"_metadata\" exists in state_dict state_dict._metadata = metadata # type: ignore[attr-defined] def load(module, local_state_dict, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) if assign: local_metadata['assign_to_params_buffers'] = assign module._load_from_state_dict( local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name + ' . ' child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} load(child, child_state_dict, child_prefix) # noqa: F821 # Note that the hook can modify missing_keys and unexpected_keys. incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) for hook in module._load_state_dict_post_hooks.values(): out = hook(module, incompatible_keys) assert out is None, ( \"Hooks registered with ``register_load_state_dict_post_hook`` are not\" \"expected to return new values, if incompatible_keys need to be modified,\" \"it should be done inplace.\" ) load(self, state_dict) del load if strict: if len(unexpected_keys) > 0: error_msgs.insert( 0, ' Unexpected key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in unexpected_keys))) if len(missing_keys) > 0: error_msgs.insert( 0, ' Missing key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in missing_keys))) if len(error_msgs) > 0: raise RuntimeError(' Error ( s ) in loading state_dict for {} : \\n\\t {} ' . format ( self . __ class__ . __ name__ , \"\\n\\t\" . join ( error_msgs ))) return _ IncompatibleKeys ( missing_keys , unexpected_keys )","title":"load_state_dict"},{"location":"reference/wtracker/neural/mlp/#modules","text":"def modules ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over all modules in the network. Yields: Type Description Module a module in the network View Source def modules ( self ) -> Iterator [ 'Module' ] : r \" \"\" Return an iterator over all modules in the network. Yields: Module: a module in the network Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True) \"\" \" for _ , module in self . named_modules () : yield module","title":"modules"},{"location":"reference/wtracker/neural/mlp/#named_buffers","text":"def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . Tensor ]] Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Parameters: Name Type Description Default prefix str prefix to prepend to all buffer names. None recurse bool if True, then yields buffers of this module and all submodules. 
Otherwise, yields only buffers that are direct members of this module. Defaults to True. None remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True Yields: Type Description None (str, torch.Tensor): Tuple containing the name and buffer View Source def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Tensor ]]: r \"\"\"Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Args: prefix (str): prefix to prepend to all buffer names. recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. Yields: (str, torch.Tensor): Tuple containing the name and buffer Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size()) \"\"\" gen = self . _named_members ( lambda module : module . _buffers . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_buffers"},{"location":"reference/wtracker/neural/mlp/#named_children","text":"def named_children ( self ) -> Iterator [ Tuple [ str , ForwardRef ( 'Module' )]] Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: Type Description None (str, Module): Tuple containing a name and child module View Source def named_children ( self ) -> Iterator [ Tuple [ str , 'Module' ]]: r \"\"\"Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: (str, Module): Tuple containing a name and child module Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module) \"\"\" memo = set () for name , module in self . _modules . items (): if module is not None and module not in memo : memo . add ( module ) yield name , module","title":"named_children"},{"location":"reference/wtracker/neural/mlp/#named_modules","text":"def named_modules ( self , memo : Optional [ Set [ ForwardRef ( 'Module' )]] = None , prefix : str = '' , remove_duplicate : bool = True ) Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Parameters: Name Type Description Default memo None a memo to store the set of modules already added to the result None prefix None a prefix that will be added to the name of the module None remove_duplicate None whether to remove the duplicated module instances in the result or not None Yields: Type Description None (str, Module): Tuple of name and module View Source def named_modules ( self , memo : Optional [ Set [ 'Module' ]] = None , prefix : str = '' , remove_duplicate : bool = True ) : r \" \"\" Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: memo: a memo to store the set of modules already added to the result prefix: a prefix that will be added to the name of the module remove_duplicate: whether to remove the duplicated module instances in the result or not Yields: (str, Module): Tuple of name and module Note: Duplicate modules are returned only once. 
In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) \"\" \" if memo is None : memo = set () if self not in memo : if remove_duplicate : memo . add ( self ) yield prefix , self for name , module in self . _modules . items () : if module is None : continue submodule_prefix = prefix + ( '.' if prefix else '' ) + name yield from module . named_modules ( memo , submodule_prefix , remove_duplicate )","title":"named_modules"},{"location":"reference/wtracker/neural/mlp/#named_parameters","text":"def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . nn . parameter . Parameter ]] Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Parameters: Name Type Description Default prefix str prefix to prepend to all parameter names. None recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. None Yields: Type Description None (str, Parameter): Tuple containing the name and parameter View Source def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Parameter ]]: r \"\"\"Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True. Yields: (str, Parameter): Tuple containing the name and parameter Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) \"\"\" gen = self . _named_members ( lambda module : module . _parameters . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_parameters"},{"location":"reference/wtracker/neural/mlp/#parameters","text":"def parameters ( self , recurse : bool = True ) -> Iterator [ torch . nn . parameter . Parameter ] Return an iterator over module parameters. This is typically passed to an optimizer. Parameters: Name Type Description Default recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None Yields: Type Description Parameter module parameter View Source def parameters ( self , recurse : bool = True ) -> Iterator [ Parameter ] : r \"\"\"Return an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. 
Yields: Parameter: module parameter Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for param in model.parameters(): >>> print(type(param), param.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for name , param in self . named_parameters ( recurse = recurse ) : yield param","title":"parameters"},{"location":"reference/wtracker/neural/mlp/#register_backward_hook","text":"def register_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]] ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. This function is deprecated in favor of :meth: ~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_backward_hook ( self , hook: Callable [[' Module ', _grad_t , _grad_t ], Union [ None , _grad_t ]] ) -> RemovableHandle: r \"\"\"Register a backward hook on the module. This function is deprecated in favor of : meth: ` ~ torch . nn . Module . register_full_backward_hook ` and the behavior of this function will change in future versions . Returns: : class: `torch . utils . hooks . RemovableHandle ` : a handle that can be used to remove the added hook by calling ` `handle . remove () `` \"\"\" if self . _is_full_backward_hook is True: raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = False handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook return handle","title":"register_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_buffer","text":"def register_buffer ( self , name : str , tensor : Optional [ torch . Tensor ], persistent : bool = True ) -> None Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr: persistent to False . The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr: state_dict . Buffers can be accessed as attributes using given names. Parameters: Name Type Description Default name str name of the buffer. The buffer can be accessed from this module using the given name None tensor Tensor or None buffer to be registered. If None , then operations that run on buffers, such as :attr: cuda , are ignored. If None , the buffer is not included in the module's :attr: state_dict . None persistent bool whether the buffer is part of this module's :attr: state_dict . None View Source def register_buffer ( self , name : str , tensor : Optional [ Tensor ] , persistent : bool = True ) -> None : r \" \"\" Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters.
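Returning to the ``parameters`` entry above ("typically passed to an optimizer"), a minimal sketch; the model and learning rate are illustrative:

>>> import torch
>>> import torch.nn as nn
>>> model = nn.Linear(4, 2)
>>> opt = torch.optim.SGD(model.parameters(), lr=0.1)  # hand every trainable parameter to the optimizer
>>> sum(p.numel() for p in model.parameters())  # 4*2 weights + 2 biases
10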
This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`. Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> self.register_buffer('running_mean', torch.zeros(num_features)) \"\" \" if persistent is False and isinstance ( self , torch . jit . ScriptModule ) : raise RuntimeError ( \"ScriptModule does not support non-persistent buffers\" ) if '_buffers' not in self . __dict__ : raise AttributeError ( \"cannot assign buffer before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"buffer name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"buffer name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"buffer name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _buffers : raise KeyError ( f \"attribute '{name}' already exists\" ) elif tensor is not None and not isinstance ( tensor , torch . Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self . _buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name )","title":"register_buffer"},{"location":"reference/wtracker/neural/mlp/#register_forward_hook","text":"def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward hook on the module. The hook will be called every time after :func: forward has computed an output. If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func: forward is called. The hook should have the following signature:: hook ( module , args , output ) -> None or modified output If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook ( module , args , kwargs , output ) -> None or modified output Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class: torch.nn.modules.Module . 
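A minimal sketch of the ``hook(module, args, output)`` signature just described for ``register_forward_hook``; the hook fires after ``forward`` runs and the returned handle removes it:

>>> import torch
>>> import torch.nn as nn
>>> layer = nn.Linear(2, 2)
>>> def hook(module, args, output):
...     print('output shape:', tuple(output.shape))
>>> handle = layer.register_forward_hook(hook)
>>> _ = layer(torch.randn(1, 2))
output shape: (1, 2)
>>> handle.remove()  # detach the hook when done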
Otherwise, the provided hook will be fired after all existing forward hooks on this :class: torch.nn.modules.Module . Note that global forward hooks registered with :func: register_module_forward_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ] , Any ] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ] , Any ] , Optional [ Any ]] , ] , * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature:: hook(module, args, output) -> None or modified output If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook(module, args, kwargs, output) -> None or modified output Args: hook (Callable): The user defined hook to be registered. prepend (bool): If ``True``, the provided ``hook`` will be fired before all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward`` hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` always_call (bool): If ``True`` the ``hook`` will be run regardless of whether an exception is raised while calling the Module. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_hooks , extra_dict = [ self . _forward_hooks_with_kwargs , self . _forward_hooks_always_called ] , ) self . _forward_hooks [ handle . id ] = hook if with_kwargs : self . _forward_hooks_with_kwargs [ handle . id ] = True if always_call : self . _forward_hooks_always_called [ handle . id ] = True if prepend : self . _forward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_hook"},{"location":"reference/wtracker/neural/mlp/#register_forward_pre_hook","text":"def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... 
], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward pre-hook on the module. The hook will be called every time before :func: forward is invoked. If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook ( module , args ) -> None or modified input If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook ( module , args , kwargs ) -> None or a tuple of modified input and kwargs Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward_pre hooks on this :class: torch.nn.modules.Module . Note that global forward_pre hooks registered with :func: register_module_forward_pre_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_pre_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ]] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ]] , Optional [ Tuple [ Any , Dict [ str , Any ]]]] , ] , * , prepend : bool = False , with_kwargs : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. 
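Likewise, a minimal sketch of the pre-hook contract above; a single returned value is wrapped into a tuple and replaces the positional input:

>>> import torch
>>> import torch.nn as nn
>>> layer = nn.Identity()
>>> handle = layer.register_forward_pre_hook(lambda mod, args: args[0] + 1)
>>> layer(torch.zeros(2))  # Identity now sees the rewritten input
tensor([1., 1.])
>>> handle.remove()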
Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_hook","text":"def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. 
the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . _is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_pre_hook","text":"def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. 
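A minimal sketch of ``register_full_backward_hook`` as completed above; the hook fires once gradients with respect to the module output have been computed:

>>> import torch
>>> import torch.nn as nn
>>> layer = nn.Linear(2, 2)
>>> def bw_hook(module, grad_input, grad_output):
...     print('grad_output entries:', len(grad_output))
>>> handle = layer.register_full_backward_hook(bw_hook)
>>> layer(torch.randn(1, 2, requires_grad=True)).sum().backward()
grad_output entries: 1
>>> handle.remove()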
Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_load_state_dict_post_hook","text":"def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. 
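And a matching sketch for ``register_full_backward_pre_hook``; the hook sees ``grad_output`` before the module's own gradients are computed and may return ``None`` to leave it unchanged:

>>> import torch
>>> import torch.nn as nn
>>> layer = nn.Linear(2, 2)
>>> handle = layer.register_full_backward_pre_hook(
...     lambda mod, grad_output: print('pre-hook saw', len(grad_output), 'grad(s)'))
>>> layer(torch.randn(1, 2, requires_grad=True)).sum().backward()
pre-hook saw 1 grad(s)
>>> handle.remove()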
Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . id ] = hook return handle","title":"register_load_state_dict_post_hook"},{"location":"reference/wtracker/neural/mlp/#register_module","text":"def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . View Source def register_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \" \"\" Alias for :func:`add_module`. \"\" \" self . add_module ( name , module )","title":"register_module"},{"location":"reference/wtracker/neural/mlp/#register_parameter","text":"def register_parameter ( self , name : str , param : Optional [ torch . nn . parameter . Parameter ] ) -> None Add a parameter to the module. The parameter can be accessed as an attribute using given name. Parameters: Name Type Description Default name str name of the parameter. The parameter can be accessed from this module using the given name None param Parameter or None parameter to be added to the module. If None , then operations that run on parameters, such as :attr: cuda , are ignored. If None , the parameter is not included in the module's :attr: state_dict . None View Source def register_parameter ( self , name : str , param : Optional [ Parameter ] ) -> None : r \" \"\" Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. \"\" \" if '_parameters' not in self . __dict__ : raise AttributeError ( \"cannot assign parameter before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"parameter name should be a string. Got {torch.typename(name)}\" ) elif '.' 
in name : raise KeyError ( \"parameter name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"parameter name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _parameters : raise KeyError ( f \"attribute '{name}' already exists\" ) if param is None : self . _parameters [ name ] = None elif not isinstance ( param , Parameter ) : raise TypeError ( f \"cannot assign '{torch.typename(param)}' object to parameter '{name}' \" \"(torch.nn.Parameter or None required)\" ) elif param . grad_fn : raise ValueError ( f \"Cannot assign non-leaf Tensor to parameter '{name}'. Model \" f \"parameters must be created explicitly. To express '{name}' \" \"as a function of another Tensor, compute the value in \" \"the forward() method.\" ) else : for hook in _global_parameter_registration_hooks . values () : output = hook ( self , name , param ) if output is not None : param = output self . _parameters [ name ] = param","title":"register_parameter"},{"location":"reference/wtracker/neural/mlp/#register_state_dict_pre_hook","text":"def register_state_dict_pre_hook ( self , hook ) Register a pre-hook for the :meth: ~torch.nn.Module.state_dict method. These hooks will be called with arguments: self , prefix , and keep_vars before calling state_dict on self . The registered hooks can be used to perform pre-processing before the state_dict call is made. View Source def register_state_dict_pre_hook ( self , hook ) : r \" \"\" Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. \"\" \" handle = hooks . RemovableHandle ( self . _state_dict_pre_hooks ) self . _state_dict_pre_hooks [ handle . id ] = hook return handle","title":"register_state_dict_pre_hook"},{"location":"reference/wtracker/neural/mlp/#requires_grad_","text":"def requires_grad_ ( self : ~ T , requires_grad : bool = True ) -> ~ T Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr: requires_grad attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref: locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it. Parameters: Name Type Description Default requires_grad bool whether autograd should record operations on parameters in this module. Default: True . None Returns: Type Description Module self View Source def requires_grad_ ( self : T , requires_grad : bool = True ) -> T : r \" \"\" Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it. Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self \"\" \" for p in self . parameters () : p . 
requires_grad_ ( requires_grad ) return self","title":"requires_grad_"},{"location":"reference/wtracker/neural/mlp/#set_extra_state","text":"def set_extra_state ( self , state : Any ) -> None Set extra state contained in the loaded state_dict . This function is called from :func: load_state_dict to handle any extra state found within the state_dict . Implement this function and a corresponding :func: get_extra_state for your module if you need to store extra state within its state_dict . Parameters: Name Type Description Default state dict Extra state from the state_dict None View Source def set_extra_state ( self , state : Any ) -> None : \" \"\" Set extra state contained in the loaded `state_dict`. This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`. Args: state (dict): Extra state from the `state_dict` \"\" \" raise RuntimeError ( \"Reached a code path in Module.set_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"set_extra_state"},{"location":"reference/wtracker/neural/mlp/#share_memory","text":"def share_memory ( self : ~ T ) -> ~ T See :meth: torch.Tensor.share_memory_ . View Source def share_memory ( self : T ) -> T : r \"\"\"See :meth:`torch.Tensor.share_memory_`.\"\"\" return self . _apply ( lambda t : t . share_memory_ ())","title":"share_memory"},{"location":"reference/wtracker/neural/mlp/#state_dict","text":"def state_dict ( self , * args , destination = None , prefix = '' , keep_vars = False ) Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently state_dict() also accepts positional arguments for destination , prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument destination as it is not designed for end-users. Parameters: Name Type Description Default destination dict If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None . None prefix str a prefix added to parameter and buffer names to compose the keys in state_dict. Default: '' . None keep_vars bool by default the :class: ~torch.Tensor s returned in the state dict are detached from autograd. If it's set to True , detaching will not be performed. Default: False . None Returns: Type Description dict a dictionary containing a whole state of the module View Source def state_dict ( self , * args , destination = None , prefix='' , keep_vars = False ) : r \"\"\"Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars`` in order.
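A minimal sketch pairing ``set_extra_state`` with the ``get_extra_state`` entry earlier on this page; the class and its ``note`` field are hypothetical:

>>> import torch.nn as nn
>>> class WithExtra(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.note = 'v1'            # plain Python state, not a Parameter or buffer
...     def get_extra_state(self):
...         return {'note': self.note}  # must be picklable
...     def set_extra_state(self, state):
...         self.note = state['note']
>>> src, dst = WithExtra(), WithExtra()
>>> dst.note = 'v0'
>>> _ = dst.load_state_dict(src.state_dict())  # calls set_extra_state under the hood
>>> dst.note
'v1'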
However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument ``destination`` as it is not designed for end-users. Args: destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``. prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. Returns: dict: a dictionary containing a whole state of the module Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> module.state_dict().keys() ['bias', 'weight'] \"\"\" # TODO : Remove ` args ` and the parsing logic when BC allows . if len ( args ) > 0 : if destination is None : destination = args [ 0 ] if len ( args ) > 1 and prefix == '' : prefix = args [ 1 ] if len ( args ) > 2 and keep_vars is False : keep_vars = args [ 2 ] # DeprecationWarning is ignored by default warnings . warn ( \"Positional args are being deprecated, use kwargs instead. Refer to \" \"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict\" \" for details.\" ) if destination is None : destination = OrderedDict () destination . _ metadata = OrderedDict () local_metadata = dict ( version = self . _ version ) if hasattr ( destination , \"_metadata\" ) : destination . _ metadata [ prefix [ :- 1 ]] = local_metadata for hook in self . _ state_dict_pre_hooks . val ues () : hook ( self , prefix , keep_vars ) self . _ save_to_state_dict ( destination , prefix , keep_vars ) for name , module in self . _ modules . items () : if module is not None : module . state_dict ( destination = destination , prefix = prefix + name + '.' , keep_vars = keep_vars ) for hook in self . _ state_dict_hooks . val ues () : hook_result = hook ( self , destination , prefix , local_metadata ) if hook_result is not None : destination = hook_result return destination","title":"state_dict"},{"location":"reference/wtracker/neural/mlp/#to","text":"def to ( self , * args , ** kwargs ) Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth: torch.Tensor.to , but only accepts floating point or complex :attr: dtype \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr: dtype (if given). The integral parameters and buffers will be moved :attr: device , if that is given, but with dtypes unchanged. When :attr: non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. 
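The usual persistence pattern built on ``state_dict`` as documented above, shown as a minimal sketch (the file path is illustrative):

>>> import torch
>>> import torch.nn as nn
>>> model = nn.Linear(2, 2)
>>> torch.save(model.state_dict(), '/tmp/mlp.pt')  # shallow-copied references are serialized
>>> model.load_state_dict(torch.load('/tmp/mlp.pt'))
<All keys matched successfully>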
Parameters: Name Type Description Default device ( None class: torch.device ): the desired device of the parameters and buffers in this module None dtype ( None class: torch.dtype ): the desired floating point or complex dtype of the parameters and buffers in this module None tensor torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module None memory_format ( None class: torch.memory_format ): the desired memory format for 4D parameters and buffers in this module (keyword only argument) None Returns: Type Description Module self View Source def to ( self , * args , ** kwargs ) : r \" \"\" Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype` \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self Examples:: >>> # xdoctest: +IGNORE_WANT(\" non - deterministic \") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device(\" cuda : 1 \") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device(\" cpu \") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) \"\" \" device , dtype , non_blocking , convert_to_format = torch . _C . _nn . _parse_to ( * args , ** kwargs ) if dtype is not None : if not ( dtype . is_floating_point or dtype . 
is_complex ) : raise TypeError ( 'nn.Module.to only accepts floating point or complex ' f 'dtypes, but got desired dtype={dtype}' ) if dtype . is_complex : warnings . warn ( \"Complex modules are a new feature under active development whose design may change, \" \"and some modules might not work as expected when using complex tensors as parameters or buffers. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"if a complex module does not work as expected.\" ) def convert ( t ) : try : if convert_to_format is not None and t . dim () in ( 4 , 5 ) : return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , memory_format = convert_to_format , ) return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , ) except NotImplementedError as e : if str ( e ) == \"Cannot copy out of meta tensor; no data!\" : raise NotImplementedError ( f \"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() \" f \"when moving module from meta to a different device.\" ) from None else : raise return self . _apply ( convert )","title":"to"},{"location":"reference/wtracker/neural/mlp/#to_empty","text":"def to_empty ( self : ~ T , * , device : Union [ int , str , torch . device , NoneType ], recurse : bool = True ) -> ~ T Move the parameters and buffers to the specified device without copying storage. Parameters: Name Type Description Default device ( None class: torch.device ): The desired device of the parameters and buffers in this module. None recurse bool Whether parameters and buffers of submodules should be recursively moved to the specified device. None Returns: Type Description Module self View Source def to_empty ( self : T , * , device : Optional [ DeviceLikeType ] , recurse : bool = True ) -> T : r \"\"\"Move the parameters and buffers to the specified device without copying storage. Args: device (:class:`torch.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. Returns: Module: self \"\"\" return self . _apply ( lambda t : torch . empty_like ( t , device = device ), recurse = recurse )","title":"to_empty"},{"location":"reference/wtracker/neural/mlp/#train","text":"def train ( self : ~ T , mode : bool = True ) -> ~ T Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. Parameters: Name Type Description Default mode bool whether to set training mode ( True ) or evaluation mode ( False ). Default: True . None Returns: Type Description Module self View Source def train ( self : T , mode : bool = True ) -> T : r \" \"\" Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self \"\" \" if not isinstance ( mode , bool ) : raise ValueError ( \"training mode is expected to be boolean\" ) self . training = mode for module in self . children () : module . 
train ( mode ) return self","title":"train"},{"location":"reference/wtracker/neural/mlp/#type","text":"def type ( self : ~ T , dst_type : Union [ torch . dtype , str ] ) -> ~ T Casts all parameters and buffers to :attr: dst_type . .. note:: This method modifies the module in-place. Parameters: Name Type Description Default dst_type type or string the desired type None Returns: Type Description Module self View Source def type ( self : T , dst_type : Union [ dtype , str ] ) -> T : r \" \"\" Casts all parameters and buffers to :attr:`dst_type`. .. note:: This method modifies the module in-place. Args: dst_type (type or string): the desired type Returns: Module: self \"\" \" return self . _apply ( lambda t : t . type ( dst_type ))","title":"type"},{"location":"reference/wtracker/neural/mlp/#xpu","text":"def xpu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def xpu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . xpu ( device ))","title":"xpu"},{"location":"reference/wtracker/neural/mlp/#zero_grad","text":"def zero_grad ( self , set_to_none : bool = True ) -> None Reset gradients of all model parameters. See similar function under :class: torch.optim.Optimizer for more context. Parameters: Name Type Description Default set_to_none bool instead of setting to zero, set the grads to None. See :meth: torch.optim.Optimizer.zero_grad for details. None View Source def zero_grad ( self , set_to_none : bool = True ) -> None : r \"\"\"Reset gradients of all model parameters. See similar function under :class:`torch.optim.Optimizer` for more context. Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. \"\"\" if getattr ( self , '_is_replica' , False ) : warnings . warn ( \"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. \" \"The parameters are copied (in a differentiable manner) from the original module. \" \"This means they are not leaf nodes in autograd and so don't accumulate gradients. \" \"If you need gradients in your forward method, consider using autograd.grad instead.\" ) for p in self . parameters () : if p . grad is not None : if set_to_none : p . grad = None else : if p . grad . grad_fn is not None : p . grad . detach_ () else : p . grad . requires_grad_ ( False ) p . grad . zero_ ()","title":"zero_grad"},{"location":"reference/wtracker/neural/mlp/#mlpblock","text":"class MlpBlock ( in_dim : int , dims : Sequence [ int ], nonlins : Sequence [ Union [ str , torch . nn . modules . module .
Module ]], batch_norm : bool = True ) A general-purpose MLP.","title":"MlpBlock"},{"location":"reference/wtracker/neural/mlp/#attributes","text":"Name Type Description Default in_dim None Input dimension. None dims None Hidden dimensions, including output dimension. None nonlins None Non-linearities to apply after each one of the hidden dimensions. Can be either a sequence of strings which are keys in the ACTIVATIONS dict, or instances of nn.Module (e.g. an instance of nn.ReLU()). Length should match 'dims'. None View Source class MlpBlock ( nn . Module ) : \"\"\" A general-purpose MLP. Args: in_dim: Input dimension. dims: Hidden dimensions, including output dimension. nonlins: Non-linearities to apply after each one of the hidden dimensions. Can be either a sequence of strings which are keys in the ACTIVATIONS dict, or instances of nn.Module (e.g. an instance of nn.ReLU()). Length should match 'dims'. \"\"\" def __init__ ( self , in_dim : int , dims : Sequence [ int ] , nonlins : Sequence [ Union[str, nn.Module ] ] , batch_norm : bool = True , ) : assert len ( nonlins ) == len ( dims ) self . in_dim = in_dim self . out_dim = dims [ -1 ] self . dims = dims self . nonlins = nonlins super (). __init__ () layers = [] for i , out_dim in enumerate ( self . dims ) : layers . append ( MLPLayer ( in_dim , out_dim , nonlins [ i ] , batch_norm )) in_dim = out_dim self . sequence = nn . Sequential ( * layers ) def _make_activation ( self , act : Union [ str, nn.Module ] ) -> nn . Module : if isinstance ( act , str ) : return ACTIVATIONS [ act ] ( ** ACTIVATION_DEFAULT_KWARGS [ act ] ) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . sequence . forward ( x . reshape ( x . size ( 0 ), - 1 ))","title":"Attributes"},{"location":"reference/wtracker/neural/mlp/#ancestors-in-mro_1","text":"torch.nn.modules.module.Module","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/mlp/#class-variables_1","text":"T_destination call_super_init dump_patches","title":"Class variables"},{"location":"reference/wtracker/neural/mlp/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/neural/mlp/#add_module_1","text":"def add_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Add a child module to the current module. The module can be accessed as an attribute using the given name. Parameters: Name Type Description Default name str name of the child module. The child module can be accessed from this module using the given name None module Module child module to be added to the module. None View Source def add_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \"\"\"Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (str): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module. \"\"\" if not isinstance ( module , Module ) and module is not None : raise TypeError ( f \"{torch.typename(module)} is not a Module subclass\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"module name should be a string. Got {torch.typename(name)}\" ) elif hasattr ( self , name ) and name not in self . _modules : raise KeyError ( f \"attribute '{name}' already exists\" ) elif '.' 
in name : raise KeyError ( f \"module name can't contain \\\".\\\", got: {name}\" ) elif name == '' : raise KeyError ( \"module name can't be empty string \\\"\\\"\" ) for hook in _global_module_registration_hooks . values () : output = hook ( self , name , module ) if output is not None : module = output self . _modules [ name ] = module","title":"add_module"},{"location":"reference/wtracker/neural/mlp/#apply_1","text":"def apply ( self : ~ T , fn : Callable [[ ForwardRef ( 'Module' )], NoneType ] ) -> ~ T Apply fn recursively to every submodule (as returned by .children() ) as well as self. Typical use includes initializing the parameters of a model (see also :ref: nn-init-doc ). Parameters: Name Type Description Default fn ( None class: Module -> None): function to be applied to each submodule None Returns: Type Description Module self View Source def apply ( self : T , fn : Callable [[ 'Module' ] , None ] ) -> T : r \" \"\" Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example:: >>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) \"\" \" for module in self . children () : module . apply ( fn ) fn ( self ) return self","title":"apply"},{"location":"reference/wtracker/neural/mlp/#bfloat16_1","text":"def bfloat16 ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to bfloat16 datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def bfloat16 ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``bfloat16`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . bfloat16 () if t . is_floating_point () else t )","title":"bfloat16"},{"location":"reference/wtracker/neural/mlp/#buffers_1","text":"def buffers ( self , recurse : bool = True ) -> Iterator [ torch . Tensor ] Return an iterator over module buffers. Parameters: Name Type Description Default recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. None Yields: Type Description torch.Tensor module buffer View Source def buffers ( self , recurse : bool = True ) -> Iterator [ Tensor ] : r \"\"\"Return an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: torch.Tensor: module buffer Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for buf in model.buffers(): >>> print(type(buf), buf.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for _ , buf in self .
named_buffers ( recurse = recurse ) : yield buf","title":"buffers"},{"location":"reference/wtracker/neural/mlp/#children_1","text":"def children ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over immediate children modules. Yields: Type Description Module a child module View Source def children ( self ) -> Iterator [ 'Module']: r \"\"\"Return an iterator over immediate children modules. Yields: Module: a child module \"\"\" for name , module in self . named_children (): yield module","title":"children"},{"location":"reference/wtracker/neural/mlp/#compile_1","text":"def compile ( self , * args , ** kwargs ) Compile this Module's forward using :func: torch.compile . This Module's __call__ method is compiled and all arguments are passed as-is to :func: torch.compile . See :func: torch.compile for details on the arguments for this function. View Source def compile ( self , * args , ** kwargs ) : \" \"\" Compile this Module's forward using :func:`torch.compile`. This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`. See :func:`torch.compile` for details on the arguments for this function. \"\" \" self . _compiled_call_impl = torch . compile ( self . _call_impl , * args , ** kwargs )","title":"compile"},{"location":"reference/wtracker/neural/mlp/#cpu_1","text":"def cpu ( self : ~ T ) -> ~ T Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def cpu ( self : T ) -> T : r \"\"\"Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cpu ())","title":"cpu"},{"location":"reference/wtracker/neural/mlp/#cuda_1","text":"def cuda ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def cuda ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cuda ( device ))","title":"cuda"},{"location":"reference/wtracker/neural/mlp/#double_1","text":"def double ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to double datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def double ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``double`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . double () if t . 
is_floating_point () else t )","title":"double"},{"location":"reference/wtracker/neural/mlp/#eval_1","text":"def eval ( self : ~ T ) -> ~ T Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. This is equivalent to :meth: self.train(False) . See :ref: locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it. Returns: Type Description Module self View Source def eval ( self : T ) -> T : r \" \"\" Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent to :meth:`self.train(False)`. See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it. Returns: Module: self \"\" \" return self . train ( False )","title":"eval"},{"location":"reference/wtracker/neural/mlp/#extra_repr_1","text":"def extra_repr ( self ) -> str Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. View Source def extra_repr ( self ) -> str : r \"\"\"Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. \"\"\" return ''","title":"extra_repr"},{"location":"reference/wtracker/neural/mlp/#float_1","text":"def float ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to float datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def float ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``float`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . float () if t . is_floating_point () else t )","title":"float"},{"location":"reference/wtracker/neural/mlp/#forward_1","text":"def forward ( self , x : torch . Tensor ) -> torch . Tensor Parameters: Name Type Description Default x None An input tensor, of shape (N, D) containing N samples with D features. None Returns: Type Description None An output tensor of shape (N, D_out) where D_out is the output dim. View Source def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . sequence . forward ( x . reshape ( x . size ( 0 ), - 1 ))","title":"forward"},{"location":"reference/wtracker/neural/mlp/#get_buffer_1","text":"def get_buffer ( self , target : str ) -> 'Tensor' Return the buffer given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.)
None Returns: Type Description torch.Tensor The buffer referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not a buffer View Source def get_buffer ( self , target : str ) -> \"Tensor\" : \" \"\" Return the buffer given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the buffer to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.Tensor: The buffer referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not a buffer \"\" \" module_path , _ , buffer_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , buffer_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + buffer_name + \"`\" ) buffer : torch . Tensor = getattr ( mod , buffer_name ) if buffer_name not in mod . _buffers : raise AttributeError ( \"`\" + buffer_name + \"` is not a buffer\" ) return buffer","title":"get_buffer"},{"location":"reference/wtracker/neural/mlp/#get_extra_state_1","text":"def get_extra_state ( self ) -> Any Return any extra state to include in the module's state_dict. Implement this and a corresponding :func: set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict() . Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: Type Description object Any extra state to store in the module's state_dict View Source def get_extra_state ( self ) -> Any : \" \"\" Return any extra state to include in the module's state_dict. Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`. Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: object: Any extra state to store in the module's state_dict \"\" \" raise RuntimeError ( \"Reached a code path in Module.get_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"get_extra_state"},{"location":"reference/wtracker/neural/mlp/#get_parameter_1","text":"def get_parameter ( self , target : str ) -> 'Parameter' Return the parameter given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.)
None Returns: Type Description torch.nn.Parameter The Parameter referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Parameter View Source def get_parameter ( self , target : str ) -> \"Parameter\" : \" \"\" Return the parameter given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the Parameter to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.nn.Parameter: The Parameter referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Parameter`` \"\" \" module_path , _ , param_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , param_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + param_name + \"`\" ) param : torch . nn . Parameter = getattr ( mod , param_name ) if not isinstance ( param , torch . nn . Parameter ) : raise AttributeError ( \"`\" + param_name + \"` is not an \" \"nn.Parameter\" ) return param","title":"get_parameter"},{"location":"reference/wtracker/neural/mlp/#get_submodule_1","text":"def get_submodule ( self , target : str ) -> 'Module' Return the submodule given by target if it exists, otherwise throw an error. For example, let's say you have an nn.Module A that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an nn.Module A . A has a nested submodule net_b , which itself has two submodules net_c and linear . net_c then has a submodule conv .) To check whether or not we have the linear submodule, we would call get_submodule(\"net_b.linear\") . To check whether we have the conv submodule, we would call get_submodule(\"net_b.net_c.conv\") . The runtime of get_submodule is bounded by the degree of module nesting in target . A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used. Parameters: Name Type Description Default target None The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Module The submodule referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Module View Source def get_submodule ( self , target : str ) -> \"Module\" : \" \"\" Return the submodule given by ``target`` if it exists, otherwise throw an error. For example, let's say you have an ``nn.Module`` ``A`` that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.) 
To check whether or not we have the ``linear`` submodule, we would call ``get_submodule(\" net_b . linear \")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule(\" net_b . net_c . conv \")``. The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used. Args: target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) Returns: torch.nn.Module: The submodule referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Module`` \"\" \" if target == \"\" : return self atoms : List [ str ] = target . split ( \".\" ) mod : torch . nn . Module = self for item in atoms : if not hasattr ( mod , item ) : raise AttributeError ( mod . _get_name () + \" has no \" \"attribute `\" + item + \"`\" ) mod = getattr ( mod , item ) if not isinstance ( mod , torch . nn . Module ) : raise AttributeError ( \"`\" + item + \"` is not \" \"an nn.Module\" ) return mod","title":"get_submodule"},{"location":"reference/wtracker/neural/mlp/#half_1","text":"def half ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to half datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def half ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``half`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . half () if t . is_floating_point () else t )","title":"half"},{"location":"reference/wtracker/neural/mlp/#ipu_1","text":"def ipu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def ipu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . ipu ( device ))","title":"ipu"},{"location":"reference/wtracker/neural/mlp/#load_state_dict_1","text":"def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) Copy parameters and buffers from :attr: state_dict into this module and its descendants. If :attr: strict is True , then the keys of :attr: state_dict must exactly match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. .. 
warning:: If :attr: assign is True the optimizer must be created after the call to :attr: load_state_dict unless :func: ~torch.__future__.get_swap_module_params_on_conversion is True . Parameters: Name Type Description Default state_dict dict a dict containing parameters and persistent buffers. None strict bool whether to strictly enforce that the keys in :attr: state_dict match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. Default: True None assign bool When False , the properties of the tensors in the current module are preserved while when True , the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of :class: ~torch.nn.Parameter s for which the value from the module is preserved. Default: False None Returns: Type Description None NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys unexpected_keys is a list of str containing the unexpected keys View Source def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) : r \"\"\"Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. \"\"\" if not isinstance ( state_dict , Mapping ) : raise TypeError ( f \"Expected state_dict to be dict-like, got {type(state_dict)}.\" ) missing_keys : List [ str ] = [] unexpected_keys : List [ str ] = [] error_msgs : List [ str ] = [] # copy state_dict so _load_from_state_dict can modify it metadata = getattr ( state_dict , '_metadata' , None ) state_dict = OrderedDict ( state_dict ) if metadata is not None : # mypy isn't aware that \"_metadata\" exists in state_dict state_dict . _metadata = metadata # type: ignore[attr-defined] def load ( module , local_state_dict , prefix = '' ) : local_metadata = {} if metadata is None else metadata . get ( prefix [ :- 1 ] , {} ) if assign : local_metadata [ 'assign_to_params_buffers' ] = assign module . _load_from_state_dict ( local_state_dict , prefix , local_metadata , True , missing_keys , unexpected_keys , error_msgs ) for name , child in module . _modules . items () : if child is not None : child_prefix = prefix + name + '.' child_state_dict = { k : v for k , v in local_state_dict . items () if k . startswith ( child_prefix )} load ( child , child_state_dict , child_prefix ) # noqa: F821 # Note that the hook can modify missing_keys and unexpected_keys. incompatible_keys = _IncompatibleKeys ( missing_keys , unexpected_keys ) for hook in module . _load_state_dict_post_hooks . values () : out = hook ( module , incompatible_keys ) assert out is None , ( \"Hooks registered with ``register_load_state_dict_post_hook`` are not\" \"expected to return new values, if incompatible_keys need to be modified,\" \"it should be done inplace.\" ) load ( self , state_dict ) del load if strict : if len ( unexpected_keys ) > 0 : error_msgs . insert ( 0 , 'Unexpected key(s) in state_dict: {}. ' . format ( ', ' . join ( f'\"{k}\"' for k in unexpected_keys ))) if len ( missing_keys ) > 0 : error_msgs . insert ( 0 , 'Missing key(s) in state_dict: {}. ' . format ( ', ' . join ( f'\"{k}\"' for k in missing_keys ))) if len ( error_msgs ) > 0 : raise RuntimeError ( 'Error(s) in loading state_dict for {}:\\n\\t{}' . format ( self . __class__ . __name__ , \"\\n\\t\" . join ( error_msgs ))) return _IncompatibleKeys ( missing_keys , unexpected_keys )","title":"load_state_dict"},{"location":"reference/wtracker/neural/mlp/#modules_1","text":"def modules ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over all modules in the network. Yields: Type Description Module a module in the network View Source def modules ( self ) -> Iterator [ 'Module' ] : r \" \"\" Return an iterator over all modules in the network. Yields: Module: a module in the network Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True) \"\" \" for _ , module in self . named_modules () : yield module","title":"modules"},{"location":"reference/wtracker/neural/mlp/#named_buffers_1","text":"def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . Tensor ]] Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Parameters: Name Type Description Default prefix str prefix to prepend to all buffer names. None recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. None remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True Yields: Type Description None (str, torch.Tensor): Tuple containing the name and buffer View Source def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Tensor ]]: r \"\"\"Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Args: prefix (str): prefix to prepend to all buffer names. recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.
Yields: (str, torch.Tensor): Tuple containing the name and buffer Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size()) \"\"\" gen = self . _named_members ( lambda module : module . _buffers . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_buffers"},{"location":"reference/wtracker/neural/mlp/#named_children_1","text":"def named_children ( self ) -> Iterator [ Tuple [ str , ForwardRef ( 'Module' )]] Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: Type Description None (str, Module): Tuple containing a name and child module View Source def named_children ( self ) -> Iterator [ Tuple [ str , 'Module' ]]: r \"\"\"Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: (str, Module): Tuple containing a name and child module Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module) \"\"\" memo = set () for name , module in self . _modules . items (): if module is not None and module not in memo : memo . add ( module ) yield name , module","title":"named_children"},{"location":"reference/wtracker/neural/mlp/#named_modules_1","text":"def named_modules ( self , memo : Optional [ Set [ ForwardRef ( 'Module' )]] = None , prefix : str = '' , remove_duplicate : bool = True ) Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Parameters: Name Type Description Default memo None a memo to store the set of modules already added to the result None prefix None a prefix that will be added to the name of the module None remove_duplicate None whether to remove the duplicated module instances in the result or not None Yields: Type Description None (str, Module): Tuple of name and module View Source def named_modules ( self , memo : Optional [ Set [ 'Module' ]] = None , prefix : str = '' , remove_duplicate : bool = True ) : r \" \"\" Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: memo: a memo to store the set of modules already added to the result prefix: a prefix that will be added to the name of the module remove_duplicate: whether to remove the duplicated module instances in the result or not Yields: (str, Module): Tuple of name and module Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) \"\" \" if memo is None : memo = set () if self not in memo : if remove_duplicate : memo . add ( self ) yield prefix , self for name , module in self . _modules . items () : if module is None : continue submodule_prefix = prefix + ( '.' if prefix else '' ) + name yield from module . 
named_modules ( memo , submodule_prefix , remove_duplicate )","title":"named_modules"},{"location":"reference/wtracker/neural/mlp/#named_parameters_1","text":"def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . nn . parameter . Parameter ]] Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Parameters: Name Type Description Default prefix str prefix to prepend to all parameter names. None recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. None Yields: Type Description None (str, Parameter): Tuple containing the name and parameter View Source def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Parameter ]]: r \"\"\"Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True. Yields: (str, Parameter): Tuple containing the name and parameter Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) \"\"\" gen = self . _named_members ( lambda module : module . _parameters . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_parameters"},{"location":"reference/wtracker/neural/mlp/#parameters_1","text":"def parameters ( self , recurse : bool = True ) -> Iterator [ torch . nn . parameter . Parameter ] Return an iterator over module parameters. This is typically passed to an optimizer. Parameters: Name Type Description Default recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None Yields: Type Description Parameter module parameter View Source def parameters ( self , recurse : bool = True ) -> Iterator [ Parameter ] : r \"\"\"Return an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Yields: Parameter: module parameter Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for param in model.parameters(): >>> print(type(param), param.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for name , param in self . named_parameters ( recurse = recurse ) : yield param","title":"parameters"},{"location":"reference/wtracker/neural/mlp/#register_backward_hook_1","text":"def register_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]] ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. 
This function is deprecated in favor of :meth: ~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_backward_hook ( self , hook : Callable [[ 'Module' , _grad_t , _grad_t ] , Union [ None , _grad_t ]] ) -> RemovableHandle : r \"\"\"Register a backward hook on the module. This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and the behavior of this function will change in future versions. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\"\" if self . _is_full_backward_hook is True : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = False handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook return handle","title":"register_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_buffer_1","text":"def register_buffer ( self , name : str , tensor : Optional [ torch . Tensor ], persistent : bool = True ) -> None Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr: persistent to False . The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr: state_dict . Buffers can be accessed as attributes using given names. Parameters: Name Type Description Default name str name of the buffer. The buffer can be accessed from this module using the given name None tensor Tensor or None buffer to be registered. If None , then operations that run on buffers, such as :attr: cuda , are ignored. If None , the buffer is not included in the module's :attr: state_dict . None persistent bool whether the buffer is part of this module's :attr: state_dict . None View Source def register_buffer ( self , name : str , tensor : Optional [ Tensor ] , persistent : bool = True ) -> None : r \" \"\" Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`.
Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> self.register_buffer('running_mean', torch.zeros(num_features)) \"\" \" if persistent is False and isinstance ( self , torch . jit . ScriptModule ) : raise RuntimeError ( \"ScriptModule does not support non-persistent buffers\" ) if '_buffers' not in self . __dict__ : raise AttributeError ( \"cannot assign buffer before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"buffer name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"buffer name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"buffer name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _buffers : raise KeyError ( f \"attribute '{name}' already exists\" ) elif tensor is not None and not isinstance ( tensor , torch . Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self . _buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name )","title":"register_buffer"},{"location":"reference/wtracker/neural/mlp/#register_forward_hook_1","text":"def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward hook on the module. The hook will be called every time after :func: forward has computed an output. If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func: forward is called. The hook should have the following signature:: hook ( module , args , output ) -> None or modified output If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook ( module , args , kwargs , output ) -> None or modified output Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward hooks on this :class: torch.nn.modules.Module . Note that global forward hooks registered with :func: register_module_forward_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. 
Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ] , Any ] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ] , Any ] , Optional [ Any ]] , ] , * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature:: hook(module, args, output) -> None or modified output If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook(module, args, kwargs, output) -> None or modified output Args: hook (Callable): The user defined hook to be registered. prepend (bool): If ``True``, the provided ``hook`` will be fired before all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward`` hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` always_call (bool): If ``True`` the ``hook`` will be run regardless of whether an exception is raised while calling the Module. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_hooks , extra_dict = [ self . _forward_hooks_with_kwargs , self . _forward_hooks_always_called ] , ) self . _forward_hooks [ handle . id ] = hook if with_kwargs : self . _forward_hooks_with_kwargs [ handle . id ] = True if always_call : self . _forward_hooks_always_called [ handle . id ] = True if prepend : self . _forward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_hook"},{"location":"reference/wtracker/neural/mlp/#register_forward_pre_hook_1","text":"def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward pre-hook on the module. The hook will be called every time before :func: forward is invoked. If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input. 
User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook ( module , args ) -> None or modified input If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook ( module , args , kwargs ) -> None or a tuple of modified input and kwargs Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward_pre hooks on this :class: torch.nn.modules.Module . Note that global forward_pre hooks registered with :func: register_module_forward_pre_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_pre_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ]] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ]] , Optional [ Tuple [ Any , Dict [ str , Any ]]]] , ] , * , prepend : bool = False , with_kwargs : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . 
id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_hook_1","text":"def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. 
Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . _is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_pre_hook_1","text":"def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. 
None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_load_state_dict_post_hook_1","text":"def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. 
It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . id ] = hook return handle","title":"register_load_state_dict_post_hook"},{"location":"reference/wtracker/neural/mlp/#register_module_1","text":"def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . View Source def register_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \" \"\" Alias for :func:`add_module`. \"\" \" self . add_module ( name , module )","title":"register_module"},{"location":"reference/wtracker/neural/mlp/#register_parameter_1","text":"def register_parameter ( self , name : str , param : Optional [ torch . nn . parameter . Parameter ] ) -> None Add a parameter to the module. The parameter can be accessed as an attribute using given name. Parameters: Name Type Description Default name str name of the parameter. The parameter can be accessed from this module using the given name None param Parameter or None parameter to be added to the module. If None , then operations that run on parameters, such as :attr: cuda , are ignored. If None , the parameter is not included in the module's :attr: state_dict . None View Source def register_parameter ( self , name : str , param : Optional [ Parameter ] ) -> None : r \" \"\" Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. \"\" \" if '_parameters' not in self . __dict__ : raise AttributeError ( \"cannot assign parameter before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"parameter name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"parameter name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"parameter name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _parameters : raise KeyError ( f \"attribute '{name}' already exists\" ) if param is None : self . 
_parameters [ name ] = None elif not isinstance ( param , Parameter ) : raise TypeError ( f \"cannot assign '{torch.typename(param)}' object to parameter '{name}' \" \"(torch.nn.Parameter or None required)\" ) elif param . grad_fn : raise ValueError ( f \"Cannot assign non-leaf Tensor to parameter '{name}'. Model \" f \"parameters must be created explicitly. To express '{name}' \" \"as a function of another Tensor, compute the value in \" \"the forward() method.\" ) else : for hook in _global_parameter_registration_hooks . values () : output = hook ( self , name , param ) if output is not None : param = output self . _parameters [ name ] = param","title":"register_parameter"},{"location":"reference/wtracker/neural/mlp/#register_state_dict_pre_hook_1","text":"def register_state_dict_pre_hook ( self , hook ) Register a pre-hook for the :meth: ~torch.nn.Module.state_dict method. These hooks will be called with arguments: self , prefix , and keep_vars before calling state_dict on self . The registered hooks can be used to perform pre-processing before the state_dict call is made. View Source def register_state_dict_pre_hook ( self , hook ) : r \" \"\" Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. \"\" \" handle = hooks . RemovableHandle ( self . _state_dict_pre_hooks ) self . _state_dict_pre_hooks [ handle . id ] = hook return handle","title":"register_state_dict_pre_hook"},{"location":"reference/wtracker/neural/mlp/#requires_grad__1","text":"def requires_grad_ ( self : ~ T , requires_grad : bool = True ) -> ~ T Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr: requires_grad attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref: locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it. Parameters: Name Type Description Default requires_grad bool whether autograd should record operations on parameters in this module. Default: True . None Returns: Type Description Module self View Source def requires_grad_ ( self : T , requires_grad : bool = True ) -> T : r \" \"\" Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it. Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self \"\" \" for p in self . parameters () : p . requires_grad_ ( requires_grad ) return self","title":"requires_grad_"},{"location":"reference/wtracker/neural/mlp/#set_extra_state_1","text":"def set_extra_state ( self , state : Any ) -> None Set extra state contained in the loaded state_dict . This function is called from :func: load_state_dict to handle any extra state found within the state_dict . 
Implement this function and a corresponding :func: get_extra_state for your module if you need to store extra state within its state_dict . View Source def set_extra_state ( self , state : Any ) -> None : \" \"\" Set extra state contained in the loaded `state_dict`. This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`. Args: state (dict): Extra state from the `state_dict` \"\" \" raise RuntimeError ( \"Reached a code path in Module.set_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"set_extra_state"},{"location":"reference/wtracker/neural/mlp/#share_memory_1","text":"def share_memory ( self : ~ T ) -> ~ T See :meth: torch.Tensor.share_memory_ . View Source def share_memory ( self : T ) -> T : r \"\"\"See :meth:`torch.Tensor.share_memory_`.\"\"\" return self . _apply ( lambda t : t . share_memory_ ())","title":"share_memory"},{"location":"reference/wtracker/neural/mlp/#state_dict_1","text":"def state_dict ( self , * args , destination = None , prefix = '' , keep_vars = False ) Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently state_dict() also accepts positional arguments for destination , prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument destination as it is not designed for end-users. Parameters: Name Type Description Default destination dict If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None . None prefix str a prefix added to parameter and buffer names to compose the keys in state_dict. Default: '' . None keep_vars bool by default the :class: ~torch.Tensor s returned in the state dict are detached from autograd. If it's set to True , detaching will not be performed. Default: False . None Returns: Type Description dict a dictionary containing a whole state of the module View Source def state_dict ( self , * args , destination = None , prefix='' , keep_vars = False ) : r \"\"\"Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars`` in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument ``destination`` as it is not designed for end-users. Args: destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``.
prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. Returns: dict: a dictionary containing a whole state of the module Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> module.state_dict().keys() ['bias', 'weight'] \"\"\" # TODO : Remove ` args ` and the parsing logic when BC allows . if len ( args ) > 0 : if destination is None : destination = args [ 0 ] if len ( args ) > 1 and prefix == '' : prefix = args [ 1 ] if len ( args ) > 2 and keep_vars is False : keep_vars = args [ 2 ] # DeprecationWarning is ignored by default warnings . warn ( \"Positional args are being deprecated, use kwargs instead. Refer to \" \"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict\" \" for details.\" ) if destination is None : destination = OrderedDict () destination . _metadata = OrderedDict () local_metadata = dict ( version = self . _version ) if hasattr ( destination , \"_metadata\" ) : destination . _metadata [ prefix [ :- 1 ]] = local_metadata for hook in self . _state_dict_pre_hooks . values () : hook ( self , prefix , keep_vars ) self . _save_to_state_dict ( destination , prefix , keep_vars ) for name , module in self . _modules . items () : if module is not None : module . state_dict ( destination = destination , prefix = prefix + name + '.' , keep_vars = keep_vars ) for hook in self . _state_dict_hooks . values () : hook_result = hook ( self , destination , prefix , local_metadata ) if hook_result is not None : destination = hook_result return destination","title":"state_dict"},{"location":"reference/wtracker/neural/mlp/#to_1","text":"def to ( self , * args , ** kwargs ) Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth: torch.Tensor.to , but only accepts floating point or complex :attr: dtype \ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr: dtype (if given). The integral parameters and buffers will be moved :attr: device , if that is given, but with dtypes unchanged. When :attr: non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device ( None class: torch.device ): the desired device of the parameters and buffers in this module None dtype ( None class: torch.dtype ): the desired floating point or complex dtype of the parameters and buffers in this module None tensor torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module None memory_format ( None class: torch.memory_format ): the desired memory format for 4D parameters and buffers in this module (keyword only argument) None Returns: Type Description Module self View Source def to ( self , * args , ** kwargs ) : r \" \"\" Move and/or cast the parameters and buffers. This can be called as ..
function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype` \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self Examples:: >>> # xdoctest: +IGNORE_WANT(\" non - deterministic \") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device(\" cuda : 1 \") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device(\" cpu \") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) \"\" \" device , dtype , non_blocking , convert_to_format = torch . _C . _nn . _parse_to ( * args , ** kwargs ) if dtype is not None : if not ( dtype . is_floating_point or dtype . is_complex ) : raise TypeError ( 'nn.Module.to only accepts floating point or complex ' f 'dtypes, but got desired dtype={dtype}' ) if dtype . is_complex : warnings . warn ( \"Complex modules are a new feature under active development whose design may change, \" \"and some modules might not work as expected when using complex tensors as parameters or buffers. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"if a complex module does not work as expected.\" ) def convert ( t ) : try : if convert_to_format is not None and t . dim () in ( 4 , 5 ) : return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , memory_format = convert_to_format , ) return t . 
to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , ) except NotImplementedError as e : if str ( e ) == \"Cannot copy out of meta tensor; no data!\" : raise NotImplementedError ( f \"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() \" f \"when moving module from meta to a different device.\" ) from None else : raise return self . _apply ( convert )","title":"to"},{"location":"reference/wtracker/neural/mlp/#to_empty_1","text":"def to_empty ( self : ~ T , * , device : Union [ int , str , torch . device , NoneType ], recurse : bool = True ) -> ~ T Move the parameters and buffers to the specified device without copying storage. Parameters: Name Type Description Default device ( None class: torch.device ): The desired device of the parameters and buffers in this module. None recurse bool Whether parameters and buffers of submodules should be recursively moved to the specified device. None Returns: Type Description Module self View Source def to_empty ( self : T , * , device : Optional [ DeviceLikeType ] , recurse : bool = True ) -> T : r \"\"\"Move the parameters and buffers to the specified device without copying storage. Args: device (:class:`torch.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. Returns: Module: self \"\"\" return self . _apply ( lambda t : torch . empty_like ( t , device = device ), recurse = recurse )","title":"to_empty"},{"location":"reference/wtracker/neural/mlp/#train_1","text":"def train ( self : ~ T , mode : bool = True ) -> ~ T Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. Parameters: Name Type Description Default mode bool whether to set training mode ( True ) or evaluation mode ( False ). Default: True . None Returns: Type Description Module self View Source def train ( self : T , mode : bool = True ) -> T : r \" \"\" Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self \"\" \" if not isinstance ( mode , bool ) : raise ValueError ( \"training mode is expected to be boolean\" ) self . training = mode for module in self . children () : module . train ( mode ) return self","title":"train"},{"location":"reference/wtracker/neural/mlp/#type_1","text":"def type ( self : ~ T , dst_type : Union [ torch . dtype , str ] ) -> ~ T Casts all parameters and buffers to :attr: dst_type . .. note:: This method modifies the module in-place. Parameters: Name Type Description Default dst_type type or string the desired type None Returns: Type Description Module self View Source def type ( self : T , dst_type : Union [ dtype , str ] ) -> T : r \" \"\" Casts all parameters and buffers to :attr:`dst_type`. .. note:: This method modifies the module in-place. Args: dst_type (type or string): the desired type Returns: Module: self \"\" \" return self . _apply ( lambda t : t . 
type ( dst_type ))","title":"type"},{"location":"reference/wtracker/neural/mlp/#xpu_1","text":"def xpu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def xpu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . xpu ( device ))","title":"xpu"},{"location":"reference/wtracker/neural/mlp/#zero_grad_1","text":"def zero_grad ( self , set_to_none : bool = True ) -> None Reset gradients of all model parameters. See similar function under :class: torch.optim.Optimizer for more context. Parameters: Name Type Description Default set_to_none bool instead of setting to zero, set the grads to None. See :meth: torch.optim.Optimizer.zero_grad for details. None View Source def zero_grad ( self , set_to_none : bool = True ) -> None : r \"\"\"Reset gradients of all model parameters. See similar function under :class:`torch.optim.Optimizer` for more context. Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. \"\"\" if getattr ( self , '_is_replica' , False ) : warnings . warn ( \"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. \" \"The parameters are copied (in a differentiable manner) from the original module. \" \"This means they are not leaf nodes in autograd and so don't accumulate gradients. \" \"If you need gradients in your forward method, consider using autograd.grad instead.\" ) for p in self . parameters () : if p . grad is not None : if set_to_none : p . grad = None else : if p . grad . grad_fn is not None : p . grad . detach_ () else : p . grad . requires_grad_ ( False ) p . grad . zero_ ()","title":"zero_grad"},{"location":"reference/wtracker/neural/mlp/#rmlp","text":"class RMLP ( block_in_dim : int , block_dims : Sequence [ int ], block_nonlins : Sequence [ Union [ str , torch . nn . modules . module . Module ]], n_blocks : int , out_dim : int , in_dim : int = None , batch_norm : bool = True ) Base class for all neural network modules. Your models should also subclass this class. Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:: import torch.nn as nn import torch.nn.functional as F class Model ( nn . Module ): def __init__ ( self ): super () . __init__ () self . conv1 = nn . Conv2d ( 1 , 20 , 5 ) self . conv2 = nn . Conv2d ( 20 , 20 , 5 ) def forward ( self , x ): x = F . relu ( self . conv1 ( x )) return F . relu ( self . conv2 ( x )) Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth: to , etc. ..
note:: As per the example above, an __init__() call to the parent class must be made before assignment on the child. View Source class RMLP ( nn . Module ) : def __init__ ( self , block_in_dim : int , block_dims : Sequence [ int ] , block_nonlins : Sequence [ Union[str, nn.Module ] ] , n_blocks : int , out_dim : int , in_dim : int = None , # if in_dim is an int , then a first layer will be made batch_norm : bool = True , ) -> None : super (). __init__ () # Create first layer if in_dim is not None self . input = nn . Identity () if in_dim is not None : self . input = MLPLayer ( in_dim , block_in_dim , block_nonlins [ 0 ] , batch_norm ) # Create blocks layers = [] for i in range ( n_blocks ) : layers . append ( MlpBlock ( block_in_dim , block_dims , block_nonlins , batch_norm )) self . blocks = nn . ModuleList ( layers ) # Create output layer self . output = nn . Linear ( block_dims [ -1 ] , out_dim ) def _make_activation ( self , act : Union [ str, nn.Module ] ) -> nn . Module : if isinstance ( act , str ) : return ACTIVATIONS [ act ] ( ** ACTIVATION_DEFAULT_KWARGS [ act ] ) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" x = self . input ( x ) for block in self . blocks : out = block ( x ) x = x + out return self . output ( x )","title":"RMLP"},{"location":"reference/wtracker/neural/mlp/#ancestors-in-mro_2","text":"torch.nn.modules.module.Module","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/mlp/#class-variables_2","text":"T_destination call_super_init dump_patches","title":"Class variables"},{"location":"reference/wtracker/neural/mlp/#methods_2","text":"","title":"Methods"},{"location":"reference/wtracker/neural/mlp/#add_module_2","text":"def add_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Add a child module to the current module. The module can be accessed as an attribute using the given name. Parameters: Name Type Description Default name str name of the child module. The child module can be accessed from this module using the given name None module Module child module to be added to the module. None View Source def add_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \"\"\"Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (str): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module. \"\"\" if not isinstance ( module , Module ) and module is not None : raise TypeError ( f \"{torch.typename(module)} is not a Module subclass\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"module name should be a string. Got {torch.typename(name)}\" ) elif hasattr ( self , name ) and name not in self . _modules : raise KeyError ( f \"attribute '{name}' already exists\" ) elif '.' in name : raise KeyError ( f \"module name can't contain \\\" . \\ \", got: {name}\" ) elif name == '' : raise KeyError ( \"module name can't be empty string \\\" \\ \"\" ) for hook in _global_module_registration_hooks . values () : output = hook ( self , name , module ) if output is not None : module = output self . 
_modules [ name ] = module","title":"add_module"},{"location":"reference/wtracker/neural/mlp/#apply_2","text":"def apply ( self : ~ T , fn : Callable [[ ForwardRef ( 'Module' )], NoneType ] ) -> ~ T Apply fn recursively to every submodule (as returned by .children() ) as well as self. Typical use includes initializing the parameters of a model (see also :ref: nn-init-doc ). Parameters: Name Type Description Default fn ( None class: Module -> None): function to be applied to each submodule None Returns: Type Description Module self View Source def apply ( self : T , fn : Callable [[ 'Module' ] , None ] ) -> T : r \" \"\" Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example:: >>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) \"\" \" for module in self . children () : module . apply ( fn ) fn ( self ) return self","title":"apply"},{"location":"reference/wtracker/neural/mlp/#bfloat16_2","text":"def bfloat16 ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to bfloat16 datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def bfloat16 ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``bfloat16`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . bfloat16 () if t . is_floating_point () else t )","title":"bfloat16"},{"location":"reference/wtracker/neural/mlp/#buffers_2","text":"def buffers ( self , recurse : bool = True ) -> Iterator [ torch . Tensor ] Return an iterator over module buffers. Parameters: Name Type Description Default recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. None Yields: Type Description torch.Tensor module buffer View Source def buffers ( self , recurse : bool = True ) -> Iterator [ Tensor ] : r \"\"\"Return an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: torch.Tensor: module buffer Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for buf in model.buffers(): >>> print(type(buf), buf.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for _ , buf in self . named_buffers ( recurse = recurse ) : yield buf","title":"buffers"},{"location":"reference/wtracker/neural/mlp/#children_2","text":"def children ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over immediate children modules. Yields: Type Description Module a child module View Source def children ( self ) -> Iterator [ 'Module']: r \"\"\"Return an iterator over immediate children modules. 
Yields: Module: a child module \"\"\" for name , module in self . named_children (): yield module","title":"children"},{"location":"reference/wtracker/neural/mlp/#compile_2","text":"def compile ( self , * args , ** kwargs ) Compile this Module's forward using :func: torch.compile . This Module's __call__ method is compiled and all arguments are passed as-is to :func: torch.compile . See :func: torch.compile for details on the arguments for this function. View Source def compile ( self , * args , ** kwargs ) : \" \"\" Compile this Module's forward using :func:`torch.compile`. This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`. See :func:`torch.compile` for details on the arguments for this function. \"\" \" self . _compiled_call_impl = torch . compile ( self . _call_impl , * args , ** kwargs )","title":"compile"},{"location":"reference/wtracker/neural/mlp/#cpu_2","text":"def cpu ( self : ~ T ) -> ~ T Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def cpu ( self : T ) -> T : r \"\"\"Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cpu ())","title":"cpu"},{"location":"reference/wtracker/neural/mlp/#cuda_2","text":"def cuda ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def cuda ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cuda ( device ))","title":"cuda"},{"location":"reference/wtracker/neural/mlp/#double_2","text":"def double ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to double datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def double ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``double`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . double () if t . is_floating_point () else t )","title":"double"},{"location":"reference/wtracker/neural/mlp/#eval_2","text":"def eval ( self : ~ T ) -> ~ T Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. This is equivalent with :meth: self.train(False) . 
See :ref: locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it. Returns: Type Description Module self View Source def eval ( self : T ) -> T : r \" \"\" Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent with :meth:`self.train(False) `. See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it. Returns: Module: self \"\" \" return self . train ( False )","title":"eval"},{"location":"reference/wtracker/neural/mlp/#extra_repr_2","text":"def extra_repr ( self ) -> str Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. View Source def extra_repr ( self ) -> str : r \"\"\"Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. \"\"\" return ''","title":"extra_repr"},{"location":"reference/wtracker/neural/mlp/#float_2","text":"def float ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to float datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def float ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``float`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . float () if t . is_floating_point () else t )","title":"float"},{"location":"reference/wtracker/neural/mlp/#forward_2","text":"def forward ( self , x : torch . Tensor ) -> torch . Tensor Parameters: Name Type Description Default x None An input tensor, of shape (N, D) containing N samples with D features. None Returns: Type Description None An output tensor of shape (N, D_out) where D_out is the output dim. View Source def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" x = self . input ( x ) for block in self . blocks : out = block ( x ) x = x + out return self . output ( x )","title":"forward"},{"location":"reference/wtracker/neural/mlp/#get_buffer_2","text":"def get_buffer ( self , target : str ) -> 'Tensor' Return the buffer given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.Tensor The buffer referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not a buffer View Source def get_buffer ( self , target : str ) -> \"Tensor\" : \" \"\" Return the buffer given by ``target`` if it exists, otherwise throw an error. 
See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the buffer to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.Tensor: The buffer referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not a buffer \"\" \" module_path , _ , buffer_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , buffer_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + buffer_name + \"`\" ) buffer : torch . Tensor = getattr ( mod , buffer_name ) if buffer_name not in mod . _buffers : raise AttributeError ( \"`\" + buffer_name + \"` is not a buffer\" ) return buffer","title":"get_buffer"},{"location":"reference/wtracker/neural/mlp/#get_extra_state_2","text":"def get_extra_state ( self ) -> Any Return any extra state to include in the module's state_dict. Implement this and a corresponding :func: set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict() . Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: Type Description object Any extra state to store in the module's state_dict View Source def get_extra_state ( self ) -> Any : \" \"\" Return any extra state to include in the module's state_dict. Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`. Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: object: Any extra state to store in the module's state_dict \"\" \" raise RuntimeError ( \"Reached a code path in Module.get_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"get_extra_state"},{"location":"reference/wtracker/neural/mlp/#get_parameter_2","text":"def get_parameter ( self , target : str ) -> 'Parameter' Return the parameter given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Parameter The Parameter referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Parameter View Source def get_parameter ( self , target : str ) -> \"Parameter\" : \" \"\" Return the parameter given by ``target`` if it exists, otherwise throw an error.
See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the Parameter to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.nn.Parameter: The Parameter referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Parameter`` \"\" \" module_path , _ , param_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , param_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + param_name + \"`\" ) param : torch . nn . Parameter = getattr ( mod , param_name ) if not isinstance ( param , torch . nn . Parameter ) : raise AttributeError ( \"`\" + param_name + \"` is not an \" \"nn.Parameter\" ) return param","title":"get_parameter"},{"location":"reference/wtracker/neural/mlp/#get_submodule_2","text":"def get_submodule ( self , target : str ) -> 'Module' Return the submodule given by target if it exists, otherwise throw an error. For example, let's say you have an nn.Module A that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an nn.Module A . A has a nested submodule net_b , which itself has two submodules net_c and linear . net_c then has a submodule conv .) To check whether or not we have the linear submodule, we would call get_submodule(\"net_b.linear\") . To check whether we have the conv submodule, we would call get_submodule(\"net_b.net_c.conv\") . The runtime of get_submodule is bounded by the degree of module nesting in target . A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used. Parameters: Name Type Description Default target None The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Module The submodule referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Module View Source def get_submodule ( self , target : str ) -> \"Module\" : \" \"\" Return the submodule given by ``target`` if it exists, otherwise throw an error. For example, let's say you have an ``nn.Module`` ``A`` that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.) To check whether or not we have the ``linear`` submodule, we would call ``get_submodule(\" net_b . linear \")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule(\" net_b . net_c . conv \")``. The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. 
So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used. Args: target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) Returns: torch.nn.Module: The submodule referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Module`` \"\" \" if target == \"\" : return self atoms : List [ str ] = target . split ( \".\" ) mod : torch . nn . Module = self for item in atoms : if not hasattr ( mod , item ) : raise AttributeError ( mod . _get_name () + \" has no \" \"attribute `\" + item + \"`\" ) mod = getattr ( mod , item ) if not isinstance ( mod , torch . nn . Module ) : raise AttributeError ( \"`\" + item + \"` is not \" \"an nn.Module\" ) return mod","title":"get_submodule"},{"location":"reference/wtracker/neural/mlp/#half_2","text":"def half ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to half datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def half ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``half`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . half () if t . is_floating_point () else t )","title":"half"},{"location":"reference/wtracker/neural/mlp/#ipu_2","text":"def ipu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def ipu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . ipu ( device ))","title":"ipu"},{"location":"reference/wtracker/neural/mlp/#load_state_dict_2","text":"def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) Copy parameters and buffers from :attr: state_dict into this module and its descendants. If :attr: strict is True , then the keys of :attr: state_dict must exactly match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. .. warning:: If :attr: assign is True the optimizer must be created after the call to :attr: load_state_dict unless :func: ~torch.__future__.get_swap_module_params_on_conversion is True . Parameters: Name Type Description Default state_dict dict a dict containing parameters and persistent buffers. None strict bool whether to strictly enforce that the keys in :attr: state_dict match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. 
Default: True None assign bool When False , the properties of the tensors in the current module are preserved while when True , the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of :class: ~torch.nn.Parameter s for which the value from the module is preserved. Default: False None Returns: Type Description None NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys unexpected_keys is a list of str containing the unexpected keys View Source def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) : r \"\"\"Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. \"\"\" if not isinstance ( state_dict , Mapping ) : raise TypeError ( f \"Expected state_dict to be dict-like, got {type(state_dict)}.\" ) missing_keys : List [ str ] = [] unexpected_keys : List [ str ] = [] error_msgs : List [ str ] = [] # copy state_dict so _ load_from_state_dict can modify it metadata = getattr ( state_dict , '_metadata' , None ) state_dict = OrderedDict ( state_dict ) if metadata is not None : # mypy isn't aware that \"_metadata\" exists in state_dict state_dict._metadata = metadata # type: ignore[attr-defined] def load(module, local_state_dict, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) if assign: local_metadata['assign_to_params_buffers'] = assign module._load_from_state_dict( local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name + ' . ' child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} load(child, child_state_dict, child_prefix) # noqa: F821 # Note that the hook can modify missing_keys and unexpected_keys. 
incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) for hook in module._load_state_dict_post_hooks.values(): out = hook(module, incompatible_keys) assert out is None, ( \"Hooks registered with ``register_load_state_dict_post_hook`` are not\" \"expected to return new values, if incompatible_keys need to be modified,\" \"it should be done inplace.\" ) load(self, state_dict) del load if strict: if len(unexpected_keys) > 0: error_msgs.insert( 0, ' Unexpected key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in unexpected_keys))) if len(missing_keys) > 0: error_msgs.insert( 0, ' Missing key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in missing_keys))) if len(error_msgs) > 0: raise RuntimeError(' Error ( s ) in loading state_dict for {} : \\n\\t {} ' . format ( self . __ class__ . __ name__ , \"\\n\\t\" . join ( error_msgs ))) return _ IncompatibleKeys ( missing_keys , unexpected_keys )","title":"load_state_dict"},{"location":"reference/wtracker/neural/mlp/#modules_2","text":"def modules ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over all modules in the network. Yields: Type Description Module a module in the network View Source def modules ( self ) -> Iterator [ 'Module' ] : r \" \"\" Return an iterator over all modules in the network. Yields: Module: a module in the network Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True) \"\" \" for _ , module in self . named_modules () : yield module","title":"modules"},{"location":"reference/wtracker/neural/mlp/#named_buffers_2","text":"def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . Tensor ]] Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Parameters: Name Type Description Default prefix str prefix to prepend to all buffer names. None recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. None remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True Yields: Type Description None (str, torch.Tensor): Tuple containing the name and buffer View Source def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Tensor ]]: r \"\"\"Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Args: prefix (str): prefix to prepend to all buffer names. recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. Yields: (str, torch.Tensor): Tuple containing the name and buffer Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size()) \"\"\" gen = self . 
_named_members ( lambda module : module . _buffers . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_buffers"},{"location":"reference/wtracker/neural/mlp/#named_children_2","text":"def named_children ( self ) -> Iterator [ Tuple [ str , ForwardRef ( 'Module' )]] Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: Type Description None (str, Module): Tuple containing a name and child module View Source def named_children ( self ) -> Iterator [ Tuple [ str , 'Module' ]]: r \"\"\"Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: (str, Module): Tuple containing a name and child module Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module) \"\"\" memo = set () for name , module in self . _modules . items (): if module is not None and module not in memo : memo . add ( module ) yield name , module","title":"named_children"},{"location":"reference/wtracker/neural/mlp/#named_modules_2","text":"def named_modules ( self , memo : Optional [ Set [ ForwardRef ( 'Module' )]] = None , prefix : str = '' , remove_duplicate : bool = True ) Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Parameters: Name Type Description Default memo None a memo to store the set of modules already added to the result None prefix None a prefix that will be added to the name of the module None remove_duplicate None whether to remove the duplicated module instances in the result or not None Yields: Type Description None (str, Module): Tuple of name and module View Source def named_modules ( self , memo : Optional [ Set [ 'Module' ]] = None , prefix : str = '' , remove_duplicate : bool = True ) : r \" \"\" Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: memo: a memo to store the set of modules already added to the result prefix: a prefix that will be added to the name of the module remove_duplicate: whether to remove the duplicated module instances in the result or not Yields: (str, Module): Tuple of name and module Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) \"\" \" if memo is None : memo = set () if self not in memo : if remove_duplicate : memo . add ( self ) yield prefix , self for name , module in self . _modules . items () : if module is None : continue submodule_prefix = prefix + ( '.' if prefix else '' ) + name yield from module . named_modules ( memo , submodule_prefix , remove_duplicate )","title":"named_modules"},{"location":"reference/wtracker/neural/mlp/#named_parameters_2","text":"def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . nn . parameter . Parameter ]] Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. 
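As a hedged illustration of how the ``prefix`` and ``recurse`` arguments below shape the yielded names (toy model; layer sizes are arbitrary):

Example::

    >>> import torch.nn as nn
    >>> net = nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 1))
    >>> [n for n, _ in net.named_parameters()]
    ['0.weight', '0.bias', '1.weight', '1.bias']
    >>> [n for n, _ in net.named_parameters(prefix='net')]
    ['net.0.weight', 'net.0.bias', 'net.1.weight', 'net.1.bias']
    >>> [n for n, _ in net.named_parameters(recurse=False)]  # the Sequential itself owns no parameters
    []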
Parameters: Name Type Description Default prefix str prefix to prepend to all parameter names. None recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. None Yields: Type Description None (str, Parameter): Tuple containing the name and parameter View Source def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Parameter ]]: r \"\"\"Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True. Yields: (str, Parameter): Tuple containing the name and parameter Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) \"\"\" gen = self . _named_members ( lambda module : module . _parameters . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_parameters"},{"location":"reference/wtracker/neural/mlp/#parameters_2","text":"def parameters ( self , recurse : bool = True ) -> Iterator [ torch . nn . parameter . Parameter ] Return an iterator over module parameters. This is typically passed to an optimizer. Parameters: Name Type Description Default recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None Yields: Type Description Parameter module parameter View Source def parameters ( self , recurse : bool = True ) -> Iterator [ Parameter ] : r \"\"\"Return an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Yields: Parameter: module parameter Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for param in model.parameters(): >>> print(type(param), param.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for name , param in self . named_parameters ( recurse = recurse ) : yield param","title":"parameters"},{"location":"reference/wtracker/neural/mlp/#register_backward_hook_2","text":"def register_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]] ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. This function is deprecated in favor of :meth: ~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions. 
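Because the behavior here is slated to change, a hedged sketch of the recommended replacement, ``register_full_backward_hook`` (toy layer; the hook body is illustrative):

Example::

    >>> import torch
    >>> import torch.nn as nn
    >>> lin = nn.Linear(2, 2)
    >>> handle = lin.register_full_backward_hook(
    ...     lambda mod, gin, gout: print("grad_output:", gout[0].shape))
    >>> lin(torch.randn(3, 2)).sum().backward()
    grad_output: torch.Size([3, 2])
    >>> handle.remove()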
Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_backward_hook ( self , hook: Callable [[' Module ', _grad_t , _grad_t ], Union [ None , _grad_t ]] ) -> RemovableHandle: r \"\"\"Register a backward hook on the module. This function is deprecated in favor of : meth: ` ~ torch . nn . Module . register_full_backward_hook ` and the behavior of this function will change in future versions . Returns: : class: `torch . utils . hooks . RemovableHandle ` : a handle that can be used to remove the added hook by calling ` `handle . remove () `` \"\"\" if self . _is_full_backward_hook is True: raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = False handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook return handle","title":"register_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_buffer_2","text":"def register_buffer ( self , name : str , tensor : Optional [ torch . Tensor ], persistent : bool = True ) -> None Add a buffer to the module. This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr: persistent to False . The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr: state_dict . Buffers can be accessed as attributes using given names. Parameters: Name Type Description Default name str name of the buffer. The buffer can be accessed from this module using the given name None tensor Tensor or None buffer to be registered. If None , then operations that run on buffers, such as :attr: cuda , are ignored. If None , the buffer is not included in the module's :attr: state_dict . None persistent bool whether the buffer is part of this module's :attr: state_dict . None View Source def register_buffer ( self , name : str , tensor : Optional [ Tensor ] , persistent : bool = True ) -> None : r \" \"\" Add a buffer to the module. This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`. Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> self.register_buffer('running_mean', torch.zeros(num_features)) \"\" \" if persistent is False and isinstance ( self , torch . jit . 
ScriptModule ) : raise RuntimeError ( \"ScriptModule does not support non-persistent buffers\" ) if '_buffers' not in self . __dict__ : raise AttributeError ( \"cannot assign buffer before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"buffer name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"buffer name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"buffer name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _buffers : raise KeyError ( f \"attribute '{name}' already exists\" ) elif tensor is not None and not isinstance ( tensor , torch . Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self . _buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name )","title":"register_buffer"},{"location":"reference/wtracker/neural/mlp/#register_forward_hook_2","text":"def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward hook on the module. The hook will be called every time after :func: forward has computed an output. If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func: forward is called. The hook should have the following signature:: hook ( module , args , output ) -> None or modified output If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook ( module , args , kwargs , output ) -> None or modified output Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward hooks on this :class: torch.nn.modules.Module . Note that global forward hooks registered with :func: register_module_forward_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ] , Any ] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... 
] , Dict [ str , Any ] , Any ] , Optional [ Any ]] , ] , * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature:: hook(module, args, output) -> None or modified output If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook(module, args, kwargs, output) -> None or modified output Args: hook (Callable): The user defined hook to be registered. prepend (bool): If ``True``, the provided ``hook`` will be fired before all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward`` hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` always_call (bool): If ``True`` the ``hook`` will be run regardless of whether an exception is raised while calling the Module. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_hooks , extra_dict = [ self . _forward_hooks_with_kwargs , self . _forward_hooks_always_called ] , ) self . _forward_hooks [ handle . id ] = hook if with_kwargs : self . _forward_hooks_with_kwargs [ handle . id ] = True if always_call : self . _forward_hooks_always_called [ handle . id ] = True if prepend : self . _forward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_hook"},{"location":"reference/wtracker/neural/mlp/#register_forward_pre_hook_2","text":"def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward pre-hook on the module. The hook will be called every time before :func: forward is invoked. If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook ( module , args ) -> None or modified input If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. 
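For instance, a hedged sketch of a ``with_kwargs`` pre-hook (the hook body is illustrative and simply passes the input through):

Example::

    >>> import torch
    >>> import torch.nn as nn
    >>> def prehook(module, args, kwargs):
    ...     print("pre-hook saw", len(args), "positional arg(s)")
    ...     return args, kwargs      # replaces the input, here unchanged
    >>> lin = nn.Linear(2, 2)
    >>> handle = lin.register_forward_pre_hook(prehook, with_kwargs=True)
    >>> _ = lin(torch.randn(1, 2))
    pre-hook saw 1 positional arg(s)
    >>> handle.remove()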
And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook ( module , args , kwargs ) -> None or a tuple of modified input and kwargs Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward_pre hooks on this :class: torch.nn.modules.Module . Note that global forward_pre hooks registered with :func: register_module_forward_pre_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_pre_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ]] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ]] , Optional [ Tuple [ Any , Dict [ str , Any ]]]] , ] , * , prepend : bool = False , with_kwargs : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . 
id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_hook_2","text":"def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. 
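As a hedged sketch of returning a replacement gradient (the scaling factor and names are illustrative), a hook that rescales the gradient flowing back to the module's inputs:

Example::

    >>> import torch
    >>> import torch.nn as nn
    >>> def halve_grad_input(module, grad_input, grad_output):
    ...     # the returned tuple is used in place of grad_input downstream
    ...     return tuple(g * 0.5 if g is not None else None for g in grad_input)
    >>> lin = nn.Linear(2, 2)
    >>> handle = lin.register_full_backward_hook(halve_grad_input)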
For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . _is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_pre_hook_2","text":"def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. 
None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_load_state_dict_post_hook_2","text":"def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. 
It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . id ] = hook return handle","title":"register_load_state_dict_post_hook"},{"location":"reference/wtracker/neural/mlp/#register_module_2","text":"def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . View Source def register_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \" \"\" Alias for :func:`add_module`. \"\" \" self . add_module ( name , module )","title":"register_module"},{"location":"reference/wtracker/neural/mlp/#register_parameter_2","text":"def register_parameter ( self , name : str , param : Optional [ torch . nn . parameter . Parameter ] ) -> None Add a parameter to the module. The parameter can be accessed as an attribute using given name. Parameters: Name Type Description Default name str name of the parameter. The parameter can be accessed from this module using the given name None param Parameter or None parameter to be added to the module. If None , then operations that run on parameters, such as :attr: cuda , are ignored. If None , the parameter is not included in the module's :attr: state_dict . None View Source def register_parameter ( self , name : str , param : Optional [ Parameter ] ) -> None : r \" \"\" Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. \"\" \" if '_parameters' not in self . __dict__ : raise AttributeError ( \"cannot assign parameter before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"parameter name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"parameter name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"parameter name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _parameters : raise KeyError ( f \"attribute '{name}' already exists\" ) if param is None : self . 
_parameters [ name ] = None elif not isinstance ( param , Parameter ) : raise TypeError ( f \"cannot assign '{torch.typename(param)}' object to parameter '{name}' \" \"(torch.nn.Parameter or None required)\" ) elif param . grad_fn : raise ValueError ( f \"Cannot assign non-leaf Tensor to parameter '{name}'. Model \" f \"parameters must be created explicitly. To express '{name}' \" \"as a function of another Tensor, compute the value in \" \"the forward() method.\" ) else : for hook in _global_parameter_registration_hooks . values () : output = hook ( self , name , param ) if output is not None : param = output self . _parameters [ name ] = param","title":"register_parameter"},{"location":"reference/wtracker/neural/mlp/#register_state_dict_pre_hook_2","text":"def register_state_dict_pre_hook ( self , hook ) Register a pre-hook for the :meth: ~torch.nn.Module.state_dict method. These hooks will be called with arguments: self , prefix , and keep_vars before calling state_dict on self . The registered hooks can be used to perform pre-processing before the state_dict call is made. View Source def register_state_dict_pre_hook ( self , hook ) : r \" \"\" Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. \"\" \" handle = hooks . RemovableHandle ( self . _state_dict_pre_hooks ) self . _state_dict_pre_hooks [ handle . id ] = hook return handle","title":"register_state_dict_pre_hook"},{"location":"reference/wtracker/neural/mlp/#requires_grad__2","text":"def requires_grad_ ( self : ~ T , requires_grad : bool = True ) -> ~ T Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr: requires_grad attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref: locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it. Parameters: Name Type Description Default requires_grad bool whether autograd should record operations on parameters in this module. Default: True . None Returns: Type Description Module self View Source def requires_grad_ ( self : T , requires_grad : bool = True ) -> T : r \" \"\" Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it. Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self \"\" \" for p in self . parameters () : p . requires_grad_ ( requires_grad ) return self","title":"requires_grad_"},{"location":"reference/wtracker/neural/mlp/#set_extra_state_2","text":"def set_extra_state ( self , state : Any ) -> None Set extra state contained in the loaded state_dict . This function is called from :func: load_state_dict to handle any extra state found within the state_dict . 
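A hedged sketch of the paired implementation (the class and field names are illustrative):

Example::

    >>> import torch.nn as nn
    >>> class Stateful(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.note = "hello"
    ...     def get_extra_state(self):
    ...         return {"note": self.note}
    ...     def set_extra_state(self, state):
    ...         self.note = state["note"]
    >>> src, dst = Stateful(), Stateful()
    >>> dst.note = "changed"
    >>> _ = dst.load_state_dict(src.state_dict())  # round-trips through the extra state
    >>> dst.note
    'hello'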
Implement this function and a corresponding View Source def set _extra_state ( self , state : Any ) -> None : \" \"\" Set extra state contained in the loaded `state_dict`. This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`. Args: state (dict): Extra state from the `state_dict` \"\" \" raise RuntimeError ( \"Reached a code path in Module.set_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"set_extra_state"},{"location":"reference/wtracker/neural/mlp/#share_memory_2","text":"def share_memory ( self : ~ T ) -> ~ T See :meth: torch.Tensor.share_memory_ . View Source def share_memory ( self : T ) -> T : r \"\"\"See :meth:`torch.Tensor.share_memory_`.\"\"\" return self . _apply ( lambda t : t . share_memory_ ())","title":"share_memory"},{"location":"reference/wtracker/neural/mlp/#state_dict_2","text":"def state_dict ( self , * args , destination = None , prefix = '' , keep_vars = False ) Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently state_dict() also accepts positional arguments for destination , prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument destination as it is not designed for end-users. Parameters: Name Type Description Default destination dict If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None . None prefix str a prefix added to parameter and buffer names to compose the keys in state_dict. Default: '' . None keep_vars bool by default the :class: ~torch.Tensor s returned in the state dict are detached from autograd. If it's set to True , detaching will not be performed. Default: False . None Returns: Type Description dict a dictionary containing a whole state of the module View Source def state_dict ( self , * args , destination = None , prefix='' , keep_vars = False ) : r \"\"\"Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars`` in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument ``destination`` as it is not designed for end-users. Args: destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``. 
prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. Returns: dict: a dictionary containing a whole state of the module Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> module.state_dict().keys() ['bias', 'weight'] \"\"\" # TODO : Remove ` args ` and the parsing logic when BC allows . if len ( args ) > 0 : if destination is None : destination = args [ 0 ] if len ( args ) > 1 and prefix == '' : prefix = args [ 1 ] if len ( args ) > 2 and keep_vars is False : keep_vars = args [ 2 ] # DeprecationWarning is ignored by default warnings . warn ( \"Positional args are being deprecated, use kwargs instead. Refer to \" \"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict\" \" for details.\" ) if destination is None : destination = OrderedDict () destination . _ metadata = OrderedDict () local_metadata = dict ( version = self . _ version ) if hasattr ( destination , \"_metadata\" ) : destination . _ metadata [ prefix [ :- 1 ]] = local_metadata for hook in self . _ state_dict_pre_hooks . val ues () : hook ( self , prefix , keep_vars ) self . _ save_to_state_dict ( destination , prefix , keep_vars ) for name , module in self . _ modules . items () : if module is not None : module . state_dict ( destination = destination , prefix = prefix + name + '.' , keep_vars = keep_vars ) for hook in self . _ state_dict_hooks . val ues () : hook_result = hook ( self , destination , prefix , local_metadata ) if hook_result is not None : destination = hook_result return destination","title":"state_dict"},{"location":"reference/wtracker/neural/mlp/#to_2","text":"def to ( self , * args , ** kwargs ) Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth: torch.Tensor.to , but only accepts floating point or complex :attr: dtype \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr: dtype (if given). The integral parameters and buffers will be moved :attr: device , if that is given, but with dtypes unchanged. When :attr: non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device ( None class: torch.device ): the desired device of the parameters and buffers in this module None dtype ( None class: torch.dtype ): the desired floating point or complex dtype of the parameters and buffers in this module None tensor torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module None memory_format ( None class: torch.memory_format ): the desired memory format for 4D parameters and buffers in this module (keyword only argument) None Returns: Type Description Module self View Source def to ( self , * args , ** kwargs ) : r \" \"\" Move and/or cast the parameters and buffers. This can be called as .. 
function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype` \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self Examples:: >>> # xdoctest: +IGNORE_WANT(\" non - deterministic \") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device(\" cuda : 1 \") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device(\" cpu \") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) \"\" \" device , dtype , non_blocking , convert_to_format = torch . _C . _nn . _parse_to ( * args , ** kwargs ) if dtype is not None : if not ( dtype . is_floating_point or dtype . is_complex ) : raise TypeError ( 'nn.Module.to only accepts floating point or complex ' f 'dtypes, but got desired dtype={dtype}' ) if dtype . is_complex : warnings . warn ( \"Complex modules are a new feature under active development whose design may change, \" \"and some modules might not work as expected when using complex tensors as parameters or buffers. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"if a complex module does not work as expected.\" ) def convert ( t ) : try : if convert_to_format is not None and t . dim () in ( 4 , 5 ) : return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , memory_format = convert_to_format , ) return t . 
to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , ) except NotImplementedError as e : if str ( e ) == \"Cannot copy out of meta tensor; no data!\" : raise NotImplementedError ( f \"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() \" f \"when moving module from meta to a different device.\" ) from None else : raise return self . _apply ( convert )","title":"to"},{"location":"reference/wtracker/neural/mlp/#to_empty_2","text":"def to_empty ( self : ~ T , * , device : Union [ int , str , torch . device , NoneType ], recurse : bool = True ) -> ~ T Move the parameters and buffers to the specified device without copying storage. Parameters: Name Type Description Default device ( None class: torch.device ): The desired device of the parameters and buffers in this module. None recurse bool Whether parameters and buffers of submodules should be recursively moved to the specified device. None Returns: Type Description Module self View Source def to_empty ( self : T , * , device : Optional [ DeviceLikeType ] , recurse : bool = True ) -> T : r \"\"\"Move the parameters and buffers to the specified device without copying storage. Args: device (:class:`torch.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. Returns: Module: self \"\"\" return self . _apply ( lambda t : torch . empty_like ( t , device = device ), recurse = recurse )","title":"to_empty"},{"location":"reference/wtracker/neural/mlp/#train_2","text":"def train ( self : ~ T , mode : bool = True ) -> ~ T Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. Parameters: Name Type Description Default mode bool whether to set training mode ( True ) or evaluation mode ( False ). Default: True . None Returns: Type Description Module self View Source def train ( self : T , mode : bool = True ) -> T : r \" \"\" Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self \"\" \" if not isinstance ( mode , bool ) : raise ValueError ( \"training mode is expected to be boolean\" ) self . training = mode for module in self . children () : module . train ( mode ) return self","title":"train"},{"location":"reference/wtracker/neural/mlp/#type_2","text":"def type ( self : ~ T , dst_type : Union [ torch . dtype , str ] ) -> ~ T Casts all parameters and buffers to :attr: dst_type . .. note:: This method modifies the module in-place. Parameters: Name Type Description Default dst_type type or string the desired type None Returns: Type Description Module self View Source def type ( self : T , dst_type : Union [ dtype , str ] ) -> T : r \" \"\" Casts all parameters and buffers to :attr:`dst_type`. .. note:: This method modifies the module in-place. Args: dst_type (type or string): the desired type Returns: Module: self \"\" \" return self . _apply ( lambda t : t . 
type ( dst_type ))","title":"type"},{"location":"reference/wtracker/neural/mlp/#xpu_2","text":"def xpu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def xpu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . xpu ( device ))","title":"xpu"},{"location":"reference/wtracker/neural/mlp/#zero_grad_2","text":"def zero_grad ( self , set_to_none : bool = True ) -> None Reset gradients of all model parameters. See similar function under :class: torch.optim.Optimizer for more context. Parameters: Name Type Description Default set_to_none bool instead of setting to zero, set the grads to None. See :meth: torch.optim.Optimizer.zero_grad for details. None View Source def zero_grad ( self , set_to_none : bool = True ) -> None : r \"\"\"Reset gradients of all model parameters. See similar function under :class:`torch.optim.Optimizer` for more context. Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. \"\"\" if getattr ( self , '_is_replica' , False ) : warnings . warn ( \"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. \" \"The parameters are copied (in a differentiable manner) from the original module. \" \"This means they are not leaf nodes in autograd and so don' t accumulate gradients . \" \" If you need gradients in your forward method , consider using autograd . grad instead . \") for p in self.parameters(): if p.grad is not None: if set_to_none: p.grad = None else: if p.grad.grad_fn is not None: p.grad.detach_() else: p.grad.requires_grad_(False) p.grad.zero_()","title":"zero_grad"},{"location":"reference/wtracker/neural/mlp/#wormpredictor","text":"class WormPredictor ( model : torch . nn . modules . module . Module , io_config : wtracker . neural . config . IOConfig ) A class that represents neural network models that predict worm behavior. After a model is created from several layers or blocks, it is wrapped in this class so that it can be distinguished from other models that don't predict worm behavior (for example, the layers/blocks that make up this model). This class also holds the IOConfig object that is used to determine the input and output shapes of the model, and the specific frames it expects as input and output.","title":"WormPredictor"},{"location":"reference/wtracker/neural/mlp/#attributes_1","text":"Name Type Description Default model None The neural network model that predicts worm behavior. None io_config None The IOConfig object of the model. None View Source class WormPredictor ( nn .
Module ): \"\"\" A class that represents neural network models that predict worm behavior. After a model is created from several layers or blocks, it is wrapped in this class so that it can be distinguished from other models that don't predict worm behavior (for example the layers/blocks that make this model). This class also holds the IOConfig object that is used to determine the input and output shapes of the model, and the specific frames it expects as input and output. Attributes: model: The neural network model that predicts worm behavior. io_config: The IOConfig object of the model. \"\"\" def __init__ ( self , model: nn . Module , io_config: IOConfig ): super (). __init__ () self . io_config: IOConfig = io_config self . model: nn . Module = model def forward ( self , x : Tensor ) -> Tensor: return self . model ( x )","title":"Attributes"},{"location":"reference/wtracker/neural/mlp/#ancestors-in-mro_3","text":"torch.nn.modules.module.Module","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/mlp/#class-variables_3","text":"T_destination call_super_init dump_patches","title":"Class variables"},{"location":"reference/wtracker/neural/mlp/#methods_3","text":"","title":"Methods"},{"location":"reference/wtracker/neural/mlp/#add_module_3","text":"def add_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Add a child module to the current module. The module can be accessed as an attribute using the given name. Parameters: Name Type Description Default name str name of the child module. The child module can be accessed from this module using the given name None module Module child module to be added to the module. None View Source def add_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \"\"\"Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (str): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module. \"\"\" if not isinstance ( module , Module ) and module is not None : raise TypeError ( f \"{torch.typename(module)} is not a Module subclass\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"module name should be a string. Got {torch.typename(name)}\" ) elif hasattr ( self , name ) and name not in self . _modules : raise KeyError ( f \"attribute '{name}' already exists\" ) elif '.' in name : raise KeyError ( f \"module name can't contain \\\" . \\ \", got: {name}\" ) elif name == '' : raise KeyError ( \"module name can't be empty string \\\" \\ \"\" ) for hook in _global_module_registration_hooks . values () : output = hook ( self , name , module ) if output is not None : module = output self . _modules [ name ] = module","title":"add_module"},{"location":"reference/wtracker/neural/mlp/#apply_3","text":"def apply ( self : ~ T , fn : Callable [[ ForwardRef ( 'Module' )], NoneType ] ) -> ~ T Apply fn recursively to every submodule (as returned by .children() ) as well as self. Typical use includes initializing the parameters of a model (see also :ref: nn-init-doc ). Parameters: Name Type Description Default fn ( None class: Module -> None): function to be applied to each submodule None Returns: Type Description Module self View Source def apply ( self : T , fn : Callable [[ 'Module' ] , None ] ) -> T : r \" \"\" Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. 
Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example:: >>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) \"\" \" for module in self . children () : module . apply ( fn ) fn ( self ) return self","title":"apply"},{"location":"reference/wtracker/neural/mlp/#bfloat16_3","text":"def bfloat16 ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to bfloat16 datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def bfloat16 ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``bfloat16`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . bfloat16 () if t . is_floating_point () else t )","title":"bfloat16"},{"location":"reference/wtracker/neural/mlp/#buffers_3","text":"def buffers ( self , recurse : bool = True ) -> Iterator [ torch . Tensor ] Return an iterator over module buffers. Parameters: Name Type Description Default recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. None Yields: Type Description torch.Tensor module buffer View Source def buffers ( self , recurse : bool = True ) -> Iterator [ Tensor ] : r \"\"\"Return an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: torch.Tensor: module buffer Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for buf in model.buffers(): >>> print(type(buf), buf.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for _ , buf in self . named_buffers ( recurse = recurse ) : yield buf","title":"buffers"},{"location":"reference/wtracker/neural/mlp/#children_3","text":"def children ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over immediate children modules. Yields: Type Description Module a child module View Source def children ( self ) -> Iterator [ 'Module']: r \"\"\"Return an iterator over immediate children modules. Yields: Module: a child module \"\"\" for name , module in self . named_children (): yield module","title":"children"},{"location":"reference/wtracker/neural/mlp/#compile_3","text":"def compile ( self , * args , ** kwargs ) Compile this Module's forward using :func: torch.compile . This Module's __call__ method is compiled and all arguments are passed as-is to :func: torch.compile . See :func: torch.compile for details on the arguments for this function. View Source def compile ( self , * args , ** kwargs ) : \" \"\" Compile this Module's forward using :func:`torch.compile`. This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`. 
See :func:`torch.compile` for details on the arguments for this function. \"\" \" self . _compiled_call_impl = torch . compile ( self . _call_impl , * args , ** kwargs )","title":"compile"},{"location":"reference/wtracker/neural/mlp/#cpu_3","text":"def cpu ( self : ~ T ) -> ~ T Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def cpu ( self : T ) -> T : r \"\"\"Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cpu ())","title":"cpu"},{"location":"reference/wtracker/neural/mlp/#cuda_3","text":"def cuda ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def cuda ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cuda ( device ))","title":"cuda"},{"location":"reference/wtracker/neural/mlp/#double_3","text":"def double ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to double datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def double ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``double`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . double () if t . is_floating_point () else t )","title":"double"},{"location":"reference/wtracker/neural/mlp/#eval_3","text":"def eval ( self : ~ T ) -> ~ T Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. This is equivalent with :meth: self.train(False) . See :ref: locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it. Returns: Type Description Module self View Source def eval ( self : T ) -> T : r \" \"\" Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent with :meth:`self.train(False) `. See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it. Returns: Module: self \"\" \" return self . 
train ( False )","title":"eval"},{"location":"reference/wtracker/neural/mlp/#extra_repr_3","text":"def extra_repr ( self ) -> str Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. View Source def extra_repr ( self ) -> str : r \"\"\"Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. \"\"\" return ''","title":"extra_repr"},{"location":"reference/wtracker/neural/mlp/#float_3","text":"def float ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to float datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def float ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``float`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . float () if t . is_floating_point () else t )","title":"float"},{"location":"reference/wtracker/neural/mlp/#forward_3","text":"def forward ( self , x : torch . Tensor ) -> torch . Tensor Define the computation performed at every call. Should be overridden by all subclasses. .. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class: Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them. View Source def forward ( self , x : Tensor ) -> Tensor : return self . model ( x )","title":"forward"},{"location":"reference/wtracker/neural/mlp/#get_buffer_3","text":"def get_buffer ( self , target : str ) -> 'Tensor' Return the buffer given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.Tensor The buffer referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not a buffer View Source def get_buffer ( self , target : str ) -> \"Tensor\" : \" \"\" Return the buffer given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the buffer to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.Tensor: The buffer referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not a buffer \"\" \" module_path , _ , buffer_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , buffer_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + buffer_name + \"`\" ) buffer : torch . Tensor = getattr ( mod , buffer_name ) if buffer_name not in mod . 
_buffers : raise AttributeError ( \"`\" + buffer_name + \"` is not a buffer\" ) return buffer","title":"get_buffer"},{"location":"reference/wtracker/neural/mlp/#get_extra_state_3","text":"def get_extra_state ( self ) -> Any Return any extra state to include in the module's state_dict. Implement this and a corresponding :func: set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict() . Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: Type Description object Any extra state to store in the module's state_dict View Source def get_extra_state ( self ) -> Any : \" \"\" Return any extra state to include in the module's state_dict. Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`. Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: object: Any extra state to store in the module's state_dict \"\" \" raise RuntimeError ( \"Reached a code path in Module.get_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"get_extra_state"},{"location":"reference/wtracker/neural/mlp/#get_parameter_3","text":"def get_parameter ( self , target : str ) -> 'Parameter' Return the parameter given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Parameter The Parameter referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Parameter View Source def get_parameter ( self , target : str ) -> \"Parameter\" : \" \"\" Return the parameter given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the Parameter to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.nn.Parameter: The Parameter referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Parameter`` \"\" \" module_path , _ , param_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , param_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + param_name + \"`\" ) param : torch . nn . Parameter = getattr ( mod , param_name ) if not isinstance ( param , torch . nn .
Parameter ) : raise AttributeError ( \"`\" + param_name + \"` is not an \" \"nn.Parameter\" ) return param","title":"get_parameter"},{"location":"reference/wtracker/neural/mlp/#get_submodule_3","text":"def get_submodule ( self , target : str ) -> 'Module' Return the submodule given by target if it exists, otherwise throw an error. For example, let's say you have an nn.Module A that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an nn.Module A . A has a nested submodule net_b , which itself has two submodules net_c and linear . net_c then has a submodule conv .) To check whether or not we have the linear submodule, we would call get_submodule(\"net_b.linear\") . To check whether we have the conv submodule, we would call get_submodule(\"net_b.net_c.conv\") . The runtime of get_submodule is bounded by the degree of module nesting in target . A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used. Parameters: Name Type Description Default target None The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Module The submodule referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Module View Source def get_submodule ( self , target : str ) -> \"Module\" : \" \"\" Return the submodule given by ``target`` if it exists, otherwise throw an error. For example, let's say you have an ``nn.Module`` ``A`` that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.) To check whether or not we have the ``linear`` submodule, we would call ``get_submodule(\" net_b . linear \")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule(\" net_b . net_c . conv \")``. The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used. Args: target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) Returns: torch.nn.Module: The submodule referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Module`` \"\" \" if target == \"\" : return self atoms : List [ str ] = target . split ( \".\" ) mod : torch . nn . Module = self for item in atoms : if not hasattr ( mod , item ) : raise AttributeError ( mod . _get_name () + \" has no \" \"attribute `\" + item + \"`\" ) mod = getattr ( mod , item ) if not isinstance ( mod , torch . nn . 
Module ) : raise AttributeError ( \"`\" + item + \"` is not \" \"an nn.Module\" ) return mod","title":"get_submodule"},{"location":"reference/wtracker/neural/mlp/#half_3","text":"def half ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to half datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def half ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``half`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . half () if t . is_floating_point () else t )","title":"half"},{"location":"reference/wtracker/neural/mlp/#ipu_3","text":"def ipu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def ipu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . ipu ( device ))","title":"ipu"},{"location":"reference/wtracker/neural/mlp/#load_state_dict_3","text":"def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) Copy parameters and buffers from :attr: state_dict into this module and its descendants. If :attr: strict is True , then the keys of :attr: state_dict must exactly match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. .. warning:: If :attr: assign is True the optimizer must be created after the call to :attr: load_state_dict unless :func: ~torch.__future__.get_swap_module_params_on_conversion is True . Parameters: Name Type Description Default state_dict dict a dict containing parameters and persistent buffers. None strict bool whether to strictly enforce that the keys in :attr: state_dict match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. Default: True None assign bool When False , the properties of the tensors in the current module are preserved while when True , the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of :class: ~torch.nn.Parameter s for which the value from the module is preserved. Default: False None Returns: Type Description None NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys unexpected_keys is a list of str containing the unexpected keys View Source def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) : r \"\"\"Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. 
If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. \"\"\" if not isinstance ( state_dict , Mapping ) : raise TypeError ( f \"Expected state_dict to be dict-like, got {type(state_dict)}.\" ) missing_keys : List [ str ] = [] unexpected_keys : List [ str ] = [] error_msgs : List [ str ] = [] # copy state_dict so _ load_from_state_dict can modify it metadata = getattr ( state_dict , '_metadata' , None ) state_dict = OrderedDict ( state_dict ) if metadata is not None : # mypy isn't aware that \"_metadata\" exists in state_dict state_dict._metadata = metadata # type: ignore[attr-defined] def load(module, local_state_dict, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) if assign: local_metadata['assign_to_params_buffers'] = assign module._load_from_state_dict( local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name + ' . ' child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} load(child, child_state_dict, child_prefix) # noqa: F821 # Note that the hook can modify missing_keys and unexpected_keys. incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) for hook in module._load_state_dict_post_hooks.values(): out = hook(module, incompatible_keys) assert out is None, ( \"Hooks registered with ``register_load_state_dict_post_hook`` are not\" \"expected to return new values, if incompatible_keys need to be modified,\" \"it should be done inplace.\" ) load(self, state_dict) del load if strict: if len(unexpected_keys) > 0: error_msgs.insert( 0, ' Unexpected key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in unexpected_keys))) if len(missing_keys) > 0: error_msgs.insert( 0, ' Missing key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in missing_keys))) if len(error_msgs) > 0: raise RuntimeError(' Error ( s ) in loading state_dict for {} : \\n\\t {} ' . format ( self . __ class__ . __ name__ , \"\\n\\t\" . 
join ( error_msgs ))) return _ IncompatibleKeys ( missing_keys , unexpected_keys )","title":"load_state_dict"},{"location":"reference/wtracker/neural/mlp/#modules_3","text":"def modules ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over all modules in the network. Yields: Type Description Module a module in the network View Source def modules ( self ) -> Iterator [ 'Module' ] : r \" \"\" Return an iterator over all modules in the network. Yields: Module: a module in the network Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True) \"\" \" for _ , module in self . named_modules () : yield module","title":"modules"},{"location":"reference/wtracker/neural/mlp/#named_buffers_3","text":"def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . Tensor ]] Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Parameters: Name Type Description Default prefix str prefix to prepend to all buffer names. None recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. None remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True Yields: Type Description None (str, torch.Tensor): Tuple containing the name and buffer View Source def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Tensor ]]: r \"\"\"Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Args: prefix (str): prefix to prepend to all buffer names. recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. Yields: (str, torch.Tensor): Tuple containing the name and buffer Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size()) \"\"\" gen = self . _named_members ( lambda module : module . _buffers . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_buffers"},{"location":"reference/wtracker/neural/mlp/#named_children_3","text":"def named_children ( self ) -> Iterator [ Tuple [ str , ForwardRef ( 'Module' )]] Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: Type Description None (str, Module): Tuple containing a name and child module View Source def named_children ( self ) -> Iterator [ Tuple [ str , 'Module' ]]: r \"\"\"Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. 
Yields: (str, Module): Tuple containing a name and child module Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module) \"\"\" memo = set () for name , module in self . _modules . items (): if module is not None and module not in memo : memo . add ( module ) yield name , module","title":"named_children"},{"location":"reference/wtracker/neural/mlp/#named_modules_3","text":"def named_modules ( self , memo : Optional [ Set [ ForwardRef ( 'Module' )]] = None , prefix : str = '' , remove_duplicate : bool = True ) Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Parameters: Name Type Description Default memo None a memo to store the set of modules already added to the result None prefix None a prefix that will be added to the name of the module None remove_duplicate None whether to remove the duplicated module instances in the result or not None Yields: Type Description None (str, Module): Tuple of name and module View Source def named_modules ( self , memo : Optional [ Set [ 'Module' ]] = None , prefix : str = '' , remove_duplicate : bool = True ) : r \" \"\" Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: memo: a memo to store the set of modules already added to the result prefix: a prefix that will be added to the name of the module remove_duplicate: whether to remove the duplicated module instances in the result or not Yields: (str, Module): Tuple of name and module Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) \"\" \" if memo is None : memo = set () if self not in memo : if remove_duplicate : memo . add ( self ) yield prefix , self for name , module in self . _modules . items () : if module is None : continue submodule_prefix = prefix + ( '.' if prefix else '' ) + name yield from module . named_modules ( memo , submodule_prefix , remove_duplicate )","title":"named_modules"},{"location":"reference/wtracker/neural/mlp/#named_parameters_3","text":"def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . nn . parameter . Parameter ]] Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Parameters: Name Type Description Default prefix str prefix to prepend to all parameter names. None recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. None Yields: Type Description None (str, Parameter): Tuple containing the name and parameter View Source def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Parameter ]]: r \"\"\"Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. 
Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True. Yields: (str, Parameter): Tuple containing the name and parameter Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) \"\"\" gen = self . _named_members ( lambda module : module . _parameters . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_parameters"},{"location":"reference/wtracker/neural/mlp/#parameters_3","text":"def parameters ( self , recurse : bool = True ) -> Iterator [ torch . nn . parameter . Parameter ] Return an iterator over module parameters. This is typically passed to an optimizer. Parameters: Name Type Description Default recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None Yields: Type Description Parameter module parameter View Source def parameters ( self , recurse : bool = True ) -> Iterator [ Parameter ] : r \"\"\"Return an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Yields: Parameter: module parameter Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for param in model.parameters(): >>> print(type(param), param.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for name , param in self . named_parameters ( recurse = recurse ) : yield param","title":"parameters"},{"location":"reference/wtracker/neural/mlp/#register_backward_hook_3","text":"def register_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]] ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. This function is deprecated in favor of :meth: ~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_backward_hook ( self , hook: Callable [[' Module ', _grad_t , _grad_t ], Union [ None , _grad_t ]] ) -> RemovableHandle: r \"\"\"Register a backward hook on the module. This function is deprecated in favor of : meth: ` ~ torch . nn . Module . register_full_backward_hook ` and the behavior of this function will change in future versions . Returns: : class: `torch . utils . hooks . RemovableHandle ` : a handle that can be used to remove the added hook by calling ` `handle . remove () `` \"\"\" if self . _is_full_backward_hook is True: raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = False handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . 
id ] = hook return handle","title":"register_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_buffer_3","text":"def register_buffer ( self , name : str , tensor : Optional [ torch . Tensor ], persistent : bool = True ) -> None Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr: persistent to False . The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr: state_dict . Buffers can be accessed as attributes using given names. Parameters: Name Type Description Default name str name of the buffer. The buffer can be accessed from this module using the given name None tensor Tensor or None buffer to be registered. If None , then operations that run on buffers, such as :attr: cuda , are ignored. If None , the buffer is not included in the module's :attr: state_dict . None persistent bool whether the buffer is part of this module's :attr: state_dict . None View Source def register_buffer ( self , name : str , tensor : Optional [ Tensor ] , persistent : bool = True ) -> None : r \" \"\" Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`. Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> self.register_buffer('running_mean', torch.zeros(num_features)) \"\" \" if persistent is False and isinstance ( self , torch . jit . ScriptModule ) : raise RuntimeError ( \"ScriptModule does not support non-persistent buffers\" ) if '_buffers' not in self . __dict__ : raise AttributeError ( \"cannot assign buffer before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"buffer name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"buffer name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"buffer name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _buffers : raise KeyError ( f \"attribute '{name}' already exists\" ) elif tensor is not None and not isinstance ( tensor , torch . Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self .
_buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name )","title":"register_buffer"},{"location":"reference/wtracker/neural/mlp/#register_forward_hook_3","text":"def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward hook on the module. The hook will be called every time after :func: forward has computed an output. If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func: forward is called. The hook should have the following signature:: hook ( module , args , output ) -> None or modified output If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook ( module , args , kwargs , output ) -> None or modified output Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward hooks on this :class: torch.nn.modules.Module . Note that global forward hooks registered with :func: register_module_forward_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ] , Any ] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ] , Any ] , Optional [ Any ]] , ] , * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature:: hook(module, args, output) -> None or modified output If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. 
The hook should have the following signature:: hook(module, args, kwargs, output) -> None or modified output Args: hook (Callable): The user defined hook to be registered. prepend (bool): If ``True``, the provided ``hook`` will be fired before all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward`` hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` always_call (bool): If ``True`` the ``hook`` will be run regardless of whether an exception is raised while calling the Module. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_hooks , extra_dict = [ self . _forward_hooks_with_kwargs , self . _forward_hooks_always_called ] , ) self . _forward_hooks [ handle . id ] = hook if with_kwargs : self . _forward_hooks_with_kwargs [ handle . id ] = True if always_call : self . _forward_hooks_always_called [ handle . id ] = True if prepend : self . _forward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_hook"},{"location":"reference/wtracker/neural/mlp/#register_forward_pre_hook_3","text":"def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward pre-hook on the module. The hook will be called every time before :func: forward is invoked. If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook ( module , args ) -> None or modified input If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook ( module , args , kwargs ) -> None or a tuple of modified input and kwargs Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward_pre hooks on this :class: torch.nn.modules.Module . Note that global forward_pre hooks registered with :func: register_module_forward_pre_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. 
Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_pre_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ]] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ]] , Optional [ Tuple [ Any , Dict [ str , Any ]]]] , ] , * , prepend : bool = False , with_kwargs : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_hook_3","text":"def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. 
The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . 
_is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_pre_hook_3","text":"def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. 
prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_load_state_dict_post_hook_3","text":"def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . 
id ] = hook return handle","title":"register_load_state_dict_post_hook"},{"location":"reference/wtracker/neural/mlp/#register_module_3","text":"def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . View Source def register_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \" \"\" Alias for :func:`add_module`. \"\" \" self . add_module ( name , module )","title":"register_module"},{"location":"reference/wtracker/neural/mlp/#register_parameter_3","text":"def register_parameter ( self , name : str , param : Optional [ torch . nn . parameter . Parameter ] ) -> None Add a parameter to the module. The parameter can be accessed as an attribute using given name. Parameters: Name Type Description Default name str name of the parameter. The parameter can be accessed from this module using the given name None param Parameter or None parameter to be added to the module. If None , then operations that run on parameters, such as :attr: cuda , are ignored. If None , the parameter is not included in the module's :attr: state_dict . None View Source def register_parameter ( self , name : str , param : Optional [ Parameter ] ) -> None : r \" \"\" Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. \"\" \" if '_parameters' not in self . __dict__ : raise AttributeError ( \"cannot assign parameter before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"parameter name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"parameter name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"parameter name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _parameters : raise KeyError ( f \"attribute '{name}' already exists\" ) if param is None : self . _parameters [ name ] = None elif not isinstance ( param , Parameter ) : raise TypeError ( f \"cannot assign '{torch.typename(param)}' object to parameter '{name}' \" \"(torch.nn.Parameter or None required)\" ) elif param . grad_fn : raise ValueError ( f \"Cannot assign non-leaf Tensor to parameter '{name}'. Model \" f \"parameters must be created explicitly. To express '{name}' \" \"as a function of another Tensor, compute the value in \" \"the forward() method.\" ) else : for hook in _global_parameter_registration_hooks . values () : output = hook ( self , name , param ) if output is not None : param = output self . _parameters [ name ] = param","title":"register_parameter"},{"location":"reference/wtracker/neural/mlp/#register_state_dict_pre_hook_3","text":"def register_state_dict_pre_hook ( self , hook ) Register a pre-hook for the :meth: ~torch.nn.Module.state_dict method. These hooks will be called with arguments: self , prefix , and keep_vars before calling state_dict on self . The registered hooks can be used to perform pre-processing before the state_dict call is made. View Source def register_state_dict_pre_hook ( self , hook ) : r \" \"\" Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. 
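To make register_parameter concrete, a minimal sketch; the Scaler module and the "scale" parameter name are illustrative, not part of the library:

import torch
import torch.nn as nn

class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        # the registered parameter appears in named_parameters() and state_dict()
        self.register_parameter("scale", nn.Parameter(torch.ones(1)))

    def forward(self, x):
        return x * self.scale

m = Scaler()
print([name for name, _ in m.named_parameters()])  # ['scale']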
These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. \"\" \" handle = hooks . RemovableHandle ( self . _state_dict_pre_hooks ) self . _state_dict_pre_hooks [ handle . id ] = hook return handle","title":"register_state_dict_pre_hook"},{"location":"reference/wtracker/neural/mlp/#requires_grad__3","text":"def requires_grad_ ( self : ~ T , requires_grad : bool = True ) -> ~ T Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr: requires_grad attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref: locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it. Parameters: Name Type Description Default requires_grad bool whether autograd should record operations on parameters in this module. Default: True . None Returns: Type Description Module self View Source def requires_grad_ ( self : T , requires_grad : bool = True ) -> T : r \" \"\" Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it. Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self \"\" \" for p in self . parameters () : p . requires_grad_ ( requires_grad ) return self","title":"requires_grad_"},{"location":"reference/wtracker/neural/mlp/#set_extra_state_3","text":"def set_extra_state ( self , state : Any ) -> None Set extra state contained in the loaded state_dict . This function is called from :func: load_state_dict to handle any extra state found within the state_dict . Implement this function and a corresponding :func: get_extra_state for your module if you need to store extra state within its state_dict . View Source def set_extra_state ( self , state : Any ) -> None : \" \"\" Set extra state contained in the loaded `state_dict`. This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`. Args: state (dict): Extra state from the `state_dict` \"\" \" raise RuntimeError ( \"Reached a code path in Module.set_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"set_extra_state"},{"location":"reference/wtracker/neural/mlp/#share_memory_3","text":"def share_memory ( self : ~ T ) -> ~ T See :meth: torch.Tensor.share_memory_ . View Source def share_memory ( self : T ) -> T : r \"\"\"See :meth:`torch.Tensor.share_memory_`.\"\"\" return self . _apply ( lambda t : t . share_memory_ ())","title":"share_memory"},{"location":"reference/wtracker/neural/mlp/#state_dict_3","text":"def state_dict ( self , * args , destination = None , prefix = '' , keep_vars = False ) Return a dictionary containing references to the whole state of the module. 
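The freezing use case mentioned for requires_grad_ above takes only two lines; this is a minimal sketch, and the two-layer model is illustrative:

import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
model[0].requires_grad_(False)  # freeze the first layer in-place
assert not any(p.requires_grad for p in model[0].parameters())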
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently state_dict() also accepts positional arguments for destination , prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument destination as it is not designed for end-users. Parameters: Name Type Description Default destination dict If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None . None prefix str a prefix added to parameter and buffer names to compose the keys in state_dict. Default: '' . None keep_vars bool by default the :class: ~torch.Tensor s returned in the state dict are detached from autograd. If it's set to True , detaching will not be performed. Default: False . None Returns: Type Description dict a dictionary containing a whole state of the module View Source def state_dict ( self , * args , destination = None , prefix='' , keep_vars = False ) : r \"\"\"Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars`` in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument ``destination`` as it is not designed for end-users. Args: destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``. prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. Returns: dict: a dictionary containing a whole state of the module Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> module.state_dict().keys() ['bias', 'weight'] \"\"\" # TODO : Remove ` args ` and the parsing logic when BC allows . if len ( args ) > 0 : if destination is None : destination = args [ 0 ] if len ( args ) > 1 and prefix == '' : prefix = args [ 1 ] if len ( args ) > 2 and keep_vars is False : keep_vars = args [ 2 ] # DeprecationWarning is ignored by default warnings . warn ( \"Positional args are being deprecated, use kwargs instead. Refer to \" \"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict\" \" for details.\" ) if destination is None : destination = OrderedDict () destination . _ metadata = OrderedDict () local_metadata = dict ( version = self . _ version ) if hasattr ( destination , \"_metadata\" ) : destination . _ metadata [ prefix [ :- 1 ]] = local_metadata for hook in self . _ state_dict_pre_hooks . 
val ues () : hook ( self , prefix , keep_vars ) self . _ save_to_state_dict ( destination , prefix , keep_vars ) for name , module in self . _ modules . items () : if module is not None : module . state_dict ( destination = destination , prefix = prefix + name + '.' , keep_vars = keep_vars ) for hook in self . _ state_dict_hooks . val ues () : hook_result = hook ( self , destination , prefix , local_metadata ) if hook_result is not None : destination = hook_result return destination","title":"state_dict"},{"location":"reference/wtracker/neural/mlp/#to_3","text":"def to ( self , * args , ** kwargs ) Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth: torch.Tensor.to , but only accepts floating point or complex :attr: dtype \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr: dtype (if given). The integral parameters and buffers will be moved :attr: device , if that is given, but with dtypes unchanged. When :attr: non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device ( None class: torch.device ): the desired device of the parameters and buffers in this module None dtype ( None class: torch.dtype ): the desired floating point or complex dtype of the parameters and buffers in this module None tensor torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module None memory_format ( None class: torch.memory_format ): the desired memory format for 4D parameters and buffers in this module (keyword only argument) None Returns: Type Description Module self View Source def to ( self , * args , ** kwargs ) : r \" \"\" Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype` \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. 
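A minimal sketch of the state_dict save/load round trip described above; the file name checkpoint.pt is illustrative:

import torch
import torch.nn as nn

model = nn.Linear(3, 3)
# state_dict() returns a shallow copy of parameters and persistent buffers
torch.save(model.state_dict(), "checkpoint.pt")
model.load_state_dict(torch.load("checkpoint.pt"))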
Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self Examples:: >>> # xdoctest: +IGNORE_WANT(\" non - deterministic \") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device(\" cuda : 1 \") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device(\" cpu \") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) \"\" \" device , dtype , non_blocking , convert_to_format = torch . _C . _nn . _parse_to ( * args , ** kwargs ) if dtype is not None : if not ( dtype . is_floating_point or dtype . is_complex ) : raise TypeError ( 'nn.Module.to only accepts floating point or complex ' f 'dtypes, but got desired dtype={dtype}' ) if dtype . is_complex : warnings . warn ( \"Complex modules are a new feature under active development whose design may change, \" \"and some modules might not work as expected when using complex tensors as parameters or buffers. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"if a complex module does not work as expected.\" ) def convert ( t ) : try : if convert_to_format is not None and t . dim () in ( 4 , 5 ) : return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , memory_format = convert_to_format , ) return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , ) except NotImplementedError as e : if str ( e ) == \"Cannot copy out of meta tensor; no data!\" : raise NotImplementedError ( f \"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() \" f \"when moving module from meta to a different device.\" ) from None else : raise return self . _apply ( convert )","title":"to"},{"location":"reference/wtracker/neural/mlp/#to_empty_3","text":"def to_empty ( self : ~ T , * , device : Union [ int , str , torch . device , NoneType ], recurse : bool = True ) -> ~ T Move the parameters and buffers to the specified device without copying storage. 
Parameters: Name Type Description Default device ( None class: torch.device ): The desired device of the parameters and buffers in this module. None recurse bool Whether parameters and buffers of submodules should be recursively moved to the specified device. None Returns: Type Description Module self View Source def to_empty ( self : T , * , device : Optional [ DeviceLikeType ] , recurse : bool = True ) -> T : r \"\"\"Move the parameters and buffers to the specified device without copying storage. Args: device (:class:`torch.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. Returns: Module: self \"\"\" return self . _apply ( lambda t : torch . empty_like ( t , device = device ), recurse = recurse )","title":"to_empty"},{"location":"reference/wtracker/neural/mlp/#train_3","text":"def train ( self : ~ T , mode : bool = True ) -> ~ T Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. Parameters: Name Type Description Default mode bool whether to set training mode ( True ) or evaluation mode ( False ). Default: True . None Returns: Type Description Module self View Source def train ( self : T , mode : bool = True ) -> T : r \" \"\" Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self \"\" \" if not isinstance ( mode , bool ) : raise ValueError ( \"training mode is expected to be boolean\" ) self . training = mode for module in self . children () : module . train ( mode ) return self","title":"train"},{"location":"reference/wtracker/neural/mlp/#type_3","text":"def type ( self : ~ T , dst_type : Union [ torch . dtype , str ] ) -> ~ T Casts all parameters and buffers to :attr: dst_type . .. note:: This method modifies the module in-place. Parameters: Name Type Description Default dst_type type or string the desired type None Returns: Type Description Module self View Source def type ( self : T , dst_type : Union [ dtype , str ] ) -> T : r \" \"\" Casts all parameters and buffers to :attr:`dst_type`. .. note:: This method modifies the module in-place. Args: dst_type (type or string): the desired type Returns: Module: self \"\" \" return self . _apply ( lambda t : t . type ( dst_type ))","title":"type"},{"location":"reference/wtracker/neural/mlp/#xpu_3","text":"def xpu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def xpu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the XPU. 
This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . xpu ( device ))","title":"xpu"},{"location":"reference/wtracker/neural/mlp/#zero_grad_3","text":"def zero_grad ( self , set_to_none : bool = True ) -> None Reset gradients of all model parameters. See similar function under :class: torch.optim.Optimizer for more context. Parameters: Name Type Description Default set_to_none bool instead of setting to zero, set the grads to None. See :meth: torch.optim.Optimizer.zero_grad for details. None View Source def zero_grad ( self , set_to_none : bool = True ) -> None : r \"\"\"Reset gradients of all model parameters. See similar function under :class:`torch.optim.Optimizer` for more context. Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. \"\"\" if getattr ( self , '_is_replica' , False ) : warnings . warn ( \"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. \" \"The parameters are copied (in a differentiable manner) from the original module. \" \"This means they are not leaf nodes in autograd and so don' t accumulate gradients . \" \" If you need gradients in your forward method , consider using autograd . grad instead . \") for p in self.parameters(): if p.grad is not None: if set_to_none: p.grad = None else: if p.grad.grad_fn is not None: p.grad.detach_() else: p.grad.requires_grad_(False) p.grad.zero_()","title":"zero_grad"},{"location":"reference/wtracker/neural/train_results/","text":"Module wtracker.neural.train_results View Source from typing import List , NamedTuple class BatchResult ( NamedTuple ): \"\"\" Represents the result of training for a single batch: the loss and number of correct classifications. \"\"\" loss : float num_correct : int class EpochResult ( NamedTuple ): \"\"\" Represents the result of training for a single epoch: the loss per batch and accuracy on the dataset (train or test). \"\"\" losses : List [ float ] accuracy : float class FitResult ( NamedTuple ): \"\"\" Represents the result of fitting a model for multiple epochs given a training and test (or validation) set. The losses are for each batch and the accuracies are per epoch. \"\"\" num_epochs : int train_loss : List [ float ] train_acc : List [ float ] test_loss : List [ float ] test_acc : List [ float ] Classes BatchResult class BatchResult ( / , * args , ** kwargs ) Represents the result of training for a single batch: the loss and number of correct classifications. View Source class BatchResult ( NamedTuple ): \"\"\" Represents the result of training for a single batch: the loss and number of correct classifications. \"\"\" loss: float num_correct: int Ancestors (in MRO) builtins.tuple Class variables loss num_correct Methods count def count ( self , value , / ) Return number of occurrences of value. index def index ( self , value , start = 0 , stop = 9223372036854775807 , / ) Return first index of value. Raises ValueError if the value is not present. EpochResult class EpochResult ( / , * args , ** kwargs ) Represents the result of training for a single epoch: the loss per batch and accuracy on the dataset (train or test). 
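A minimal sketch of constructing the result tuples defined in this module; the numeric values are illustrative:

from wtracker.neural.train_results import BatchResult, EpochResult, FitResult

batch = BatchResult(loss=0.42, num_correct=17)
epoch = EpochResult(losses=[0.42, 0.31], accuracy=85.0)
fit = FitResult(num_epochs=2, train_loss=[0.42, 0.31], train_acc=[80.0, 85.0],
                test_loss=[0.50, 0.40], test_acc=[78.0, 82.0])
print(batch.num_correct, epoch.accuracy, fit.num_epochs)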
View Source class EpochResult ( NamedTuple ) : \"\"\" Represents the result of training for a single epoch: the loss per batch and accuracy on the dataset (train or test). \"\"\" losses : List [ float ] accuracy : float Ancestors (in MRO) builtins.tuple Class variables accuracy losses Methods count def count ( self , value , / ) Return number of occurrences of value. index def index ( self , value , start = 0 , stop = 9223372036854775807 , / ) Return first index of value. Raises ValueError if the value is not present. FitResult class FitResult ( / , * args , ** kwargs ) Represents the result of fitting a model for multiple epochs given a training and test (or validation) set. The losses are for each batch and the accuracies are per epoch. View Source class FitResult ( NamedTuple ) : \"\"\" Represents the result of fitting a model for multiple epochs given a training and test (or validation) set. The losses are for each batch and the accuracies are per epoch. \"\"\" num_epochs : int train_loss : List [ float ] train_acc : List [ float ] test_loss : List [ float ] test_acc : List [ float ] Ancestors (in MRO) builtins.tuple Class variables num_epochs test_acc test_loss train_acc train_loss Methods count def count ( self , value , / ) Return number of occurrences of value. index def index ( self , value , start = 0 , stop = 9223372036854775807 , / ) Return first index of value. Raises ValueError if the value is not present.","title":"Train Results"},{"location":"reference/wtracker/neural/train_results/#module-wtrackerneuraltrain_results","text":"View Source from typing import List , NamedTuple class BatchResult ( NamedTuple ): \"\"\" Represents the result of training for a single batch: the loss and number of correct classifications. \"\"\" loss : float num_correct : int class EpochResult ( NamedTuple ): \"\"\" Represents the result of training for a single epoch: the loss per batch and accuracy on the dataset (train or test). \"\"\" losses : List [ float ] accuracy : float class FitResult ( NamedTuple ): \"\"\" Represents the result of fitting a model for multiple epochs given a training and test (or validation) set. The losses are for each batch and the accuracies are per epoch. \"\"\" num_epochs : int train_loss : List [ float ] train_acc : List [ float ] test_loss : List [ float ] test_acc : List [ float ]","title":"Module wtracker.neural.train_results"},{"location":"reference/wtracker/neural/train_results/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/neural/train_results/#batchresult","text":"class BatchResult ( / , * args , ** kwargs ) Represents the result of training for a single batch: the loss and number of correct classifications. View Source class BatchResult ( NamedTuple ): \"\"\" Represents the result of training for a single batch: the loss and number of correct classifications. 
\"\"\" loss: float num_correct: int","title":"BatchResult"},{"location":"reference/wtracker/neural/train_results/#ancestors-in-mro","text":"builtins.tuple","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/train_results/#class-variables","text":"loss num_correct","title":"Class variables"},{"location":"reference/wtracker/neural/train_results/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/neural/train_results/#count","text":"def count ( self , value , / ) Return number of occurrences of value.","title":"count"},{"location":"reference/wtracker/neural/train_results/#index","text":"def index ( self , value , start = 0 , stop = 9223372036854775807 , / ) Return first index of value. Raises ValueError if the value is not present.","title":"index"},{"location":"reference/wtracker/neural/train_results/#epochresult","text":"class EpochResult ( / , * args , ** kwargs ) Represents the result of training for a single epoch: the loss per batch and accuracy on the dataset (train or test). View Source class EpochResult ( NamedTuple ) : \"\"\" Represents the result of training for a single epoch: the loss per batch and accuracy on the dataset (train or test). \"\"\" losses : List [ float ] accuracy : float","title":"EpochResult"},{"location":"reference/wtracker/neural/train_results/#ancestors-in-mro_1","text":"builtins.tuple","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/train_results/#class-variables_1","text":"accuracy losses","title":"Class variables"},{"location":"reference/wtracker/neural/train_results/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/neural/train_results/#count_1","text":"def count ( self , value , / ) Return number of occurrences of value.","title":"count"},{"location":"reference/wtracker/neural/train_results/#index_1","text":"def index ( self , value , start = 0 , stop = 9223372036854775807 , / ) Return first index of value. Raises ValueError if the value is not present.","title":"index"},{"location":"reference/wtracker/neural/train_results/#fitresult","text":"class FitResult ( / , * args , ** kwargs ) Represents the result of fitting a model for multiple epochs given a training and test (or validation) set. The losses are for each batch and the accuracies are per epoch. View Source class FitResult ( NamedTuple ) : \"\"\" Represents the result of fitting a model for multiple epochs given a training and test (or validation) set. The losses are for each batch and the accuracies are per epoch. \"\"\" num_epochs : int train_loss : List [ float ] train_acc : List [ float ] test_loss : List [ float ] test_acc : List [ float ]","title":"FitResult"},{"location":"reference/wtracker/neural/train_results/#ancestors-in-mro_2","text":"builtins.tuple","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/train_results/#class-variables_2","text":"num_epochs test_acc test_loss train_acc train_loss","title":"Class variables"},{"location":"reference/wtracker/neural/train_results/#methods_2","text":"","title":"Methods"},{"location":"reference/wtracker/neural/train_results/#count_2","text":"def count ( self , value , / ) Return number of occurrences of value.","title":"count"},{"location":"reference/wtracker/neural/train_results/#index_2","text":"def index ( self , value , start = 0 , stop = 9223372036854775807 , / ) Return first index of value. 
Raises ValueError if the value is not present.","title":"index"},{"location":"reference/wtracker/neural/training/","text":"Module wtracker.neural.training View Source import os import abc import sys import torch import torch.nn as nn import torch.nn.functional import tqdm.auto from torch import Tensor from typing import Any , Tuple , Callable , Optional from torch.optim import Optimizer from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from wtracker.neural.train_results import FitResult , BatchResult , EpochResult class Trainer ( abc . ABC ): \"\"\" A class abstracting the various tasks of training models. Provides methods at multiple levels of granularity: - Multiple epochs (fit) - Single epoch (train_epoch/test_epoch) - Single batch (train_batch/test_batch) Args: model (nn.Module): The model to train. device (Optional[torch.device], optional): The device to run training on (CPU or GPU). log (bool, optional): Whether to log training progress with tensorboard. \"\"\" def __init__ ( self , model : nn . Module , device : Optional [ torch . device ] = None , log : bool = False , ): self . model = model self . device = device self . logger = None if not log else SummaryWriter () if self . logger is not None : self . logger . add_hparams ({ \"model\" : model . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"device\" : str ( device )}, {}, run_name = \"hparams\" ) if self . device : model . to ( self . device ) def _make_batch_result ( self , loss , num_correct ) -> BatchResult : loss = loss . item () if isinstance ( loss , Tensor ) else loss num_correct = num_correct . item () if isinstance ( num_correct , Tensor ) else num_correct return BatchResult ( float ( loss ), int ( num_correct )) def _make_fit_result ( self , num_epochs , train_losses , train_acc , test_losses , test_acc ) -> FitResult : num_epochs = num_epochs . item () if isinstance ( num_epochs , Tensor ) else num_epochs train_losses = [ x . item () if isinstance ( x , Tensor ) else x for x in train_losses ] train_acc = [ x . item () if isinstance ( x , Tensor ) else x for x in train_acc ] test_losses = [ x . item () if isinstance ( x , Tensor ) else x for x in test_losses ] test_acc = [ x . item () if isinstance ( x , Tensor ) else x for x in test_acc ] return FitResult ( int ( num_epochs ), train_losses , train_acc , test_losses , test_acc ) def fit ( self , dl_train : DataLoader , dl_test : DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw , ) -> FitResult : \"\"\" Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Args: dl_train (DataLoader): Dataloader for the training set. dl_test (DataLoader): Dataloader for the test set. num_epochs (int): Number of epochs to train for. checkpoints (str, optional): Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. early_stopping (int, optional): Whether to stop training early if there is no test loss improvement for this number of epochs. print_every (int, optional): Print progress every this number of epochs. Returns: FitResult: A FitResult object containing train and test losses per epoch. \"\"\" actual_epoch_num = 0 epochs_without_improvement = 0 train_loss , train_acc , test_loss , test_acc = [], [], [], [] best_val_loss = None # add graph to tensorboard if self . logger is not None : self . 
logger . add_graph ( self . model , next ( iter ( dl_train ))[ 0 ]) for epoch in range ( num_epochs ): actual_epoch_num += 1 verbose = False # pass this to train/test_epoch. if print_every > 0 and ( epoch % print_every == 0 or epoch == num_epochs - 1 ): verbose = True self . _print ( f \"--- EPOCH { epoch + 1 } / { num_epochs } ---\" , verbose ) train_result = self . train_epoch ( dl_train , verbose = verbose , ** kw ) test_result = self . test_epoch ( dl_test , verbose = verbose , ** kw ) train_loss . extend ( train_result . losses ) train_acc . append ( train_result . accuracy ) test_loss . extend ( test_result . losses ) test_acc . append ( test_result . accuracy ) # log results to tensorboard if self . logger is not None : self . logger . add_scalar ( \"loss/train\" , Tensor ( train_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"loss/test\" , Tensor ( test_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"accuracy/train\" , train_result . accuracy , epoch ) self . logger . add_scalar ( \"accuracy/test\" , test_result . accuracy , epoch ) self . logger . add_scalar ( \"learning_rate\" , self . optimizer . param_groups [ 0 ][ \"lr\" ], epoch ) curr_val_loss = Tensor ( test_result . losses ) . mean () . item () if best_val_loss is None or curr_val_loss < best_val_loss : best_val_loss = curr_val_loss epochs_without_improvement = 0 if checkpoints is not None : self . save_checkpoint ( checkpoints , curr_val_loss ) else : epochs_without_improvement += 1 if early_stopping is not None and epochs_without_improvement >= early_stopping : break return self . _make_fit_result ( actual_epoch_num , train_loss , train_acc , test_loss , test_acc ) def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None : \"\"\" Saves the model in its current state to a file with the given name (treated as a relative path). Args: checkpoint_filename (str): File name or relative path to save to. \"\"\" if self . logger is not None : checkpoint_filename = f \" { self . logger . log_dir } / { checkpoint_filename } \" torch . save ( self . model , checkpoint_filename ) print ( f \" \\n *** Saved checkpoint { checkpoint_filename } :: val_loss= { loss : .3f } \" ) def train_epoch ( self , dl_train : DataLoader , ** kw ) -> EpochResult : \"\"\" Train once over a training set (single epoch). Args: dl_train (DataLoader): DataLoader for the training set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self . model . train ( True ) # set train mode return self . _foreach_batch ( dl_train , self . train_batch , ** kw ) def test_epoch ( self , dl_test : DataLoader , ** kw ) -> EpochResult : \"\"\" Evaluate model once over a test set (single epoch). Args: dl_test (DataLoader): DataLoader for the test set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self . model . train ( False ) # set evaluation (test) mode return self . _foreach_batch ( dl_test , self . test_batch , ** kw ) @abc . abstractmethod def train_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model, calculates loss, performs back-propagation and updates weights. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). 
Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () @abc . abstractmethod def test_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model and calculates loss. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () @staticmethod def _print ( message , verbose = True ): \"\"\"Simple wrapper around print to make it conditional\"\"\" if verbose : print ( message ) @staticmethod def _foreach_batch ( dl : DataLoader , forward_fn : Callable [[ Any ], BatchResult ], verbose = True , max_batches = None , ) -> EpochResult : \"\"\" Evaluates the given forward-function on batches from the given dataloader, and prints progress along the way. \"\"\" losses = [] num_correct = 0 num_samples = len ( dl . sampler ) num_batches = len ( dl . batch_sampler ) if max_batches is not None : if max_batches < num_batches : num_batches = max_batches num_samples = num_batches * dl . batch_size if verbose : pbar_fn = tqdm . auto . tqdm pbar_file = sys . stdout else : pbar_fn = tqdm . tqdm pbar_file = open ( os . devnull , \"w\" ) pbar_name = forward_fn . __name__ with pbar_fn ( desc = pbar_name , total = num_batches , file = pbar_file ) as pbar : dl_iter = iter ( dl ) for batch_idx in range ( num_batches ): data = next ( dl_iter ) batch_res = forward_fn ( data ) pbar . set_description ( f \" { pbar_name } ( { batch_res . loss : .3f } )\" ) pbar . update () losses . append ( batch_res . loss ) num_correct += batch_res . num_correct avg_loss = sum ( losses ) / num_batches accuracy = 100.0 * num_correct / num_samples pbar . set_description ( f \" { pbar_name } \" f \"(Avg. Loss { avg_loss : .3f } , \" f \"Accuracy { accuracy : .2f } %)\" ) if not verbose : pbar_file . close () return EpochResult ( losses = losses , accuracy = accuracy ) def log_hparam ( self , hparam_dict : dict [ str , Any ], metric_dict : dict [ str , Any ] = {}, run_name : str = \"hparams\" ): if self . logger is not None : self . logger . add_hparams ( hparam_dict , metric_dict , run_name = run_name ) class MLPTrainer ( Trainer ): \"\"\" The `MLPTrainer` class is responsible for training and testing multi-layer perceptron (MLP) models. Args: model (nn.Module): The MLP model to be trained. loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. device (Optional[torch.device], optional): The device on which the model and data should be loaded. log (bool, optional): Whether to log training progress with tensorboard. Attributes: loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. \"\"\" def __init__ ( self , model : nn . Module , loss_fn : nn . Module , optimizer : Optimizer , device : Optional [ torch . device ] = None , log : bool = False , ): super () . __init__ ( model , device , log = log ) self . loss_fn = loss_fn self . optimizer = optimizer if self . logger is not None : self . logger . add_hparams ({ \"loss_fn\" : loss_fn . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"optimizer\" : optimizer . __class__ . 
__name__ }, {}, run_name = \"hparams\" ) optimizer_params = {} for key , val in optimizer . param_groups [ 0 ] . items (): optimizer_params [ key ] = str ( val ) optimizer_params . update ({ \"params\" : \"\" }) self . logger . add_hparams ( optimizer_params , {}, run_name = \"hparams\" ) def train_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) self . model : nn . Module self . optimizer . zero_grad () preds = self . model . forward ( X ) loss = self . loss_fn ( preds , y ) loss . backward () self . optimizer . step () num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) return self . _make_batch_result ( loss , num_correct ) @torch . no_grad () def test_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) preds = self . model . forward ( X ) num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) loss = self . loss_fn ( preds , y ) return self . _make_batch_result ( loss , num_correct ) Classes MLPTrainer class MLPTrainer ( model : torch . nn . modules . module . Module , loss_fn : torch . nn . modules . module . Module , optimizer : torch . optim . optimizer . Optimizer , device : Optional [ torch . device ] = None , log : bool = False ) The MLPTrainer class is responsible for training and testing multi-layer perceptron (MLP) models. Attributes Name Type Description Default model nn.Module The MLP model to be trained. None loss_fn nn.Module The loss function used for training. None optimizer Optimizer The optimizer used for updating the model's parameters. None device Optional[torch.device] The device on which the model and data should be loaded. None log bool Whether to log training progress with tensorboard. None loss_fn nn.Module The loss function used for training. None optimizer Optimizer The optimizer used for updating the model's parameters. None View Source class MLPTrainer ( Trainer ): \"\"\" The `MLPTrainer` class is responsible for training and testing multi-layer perceptron (MLP) models. Args: model (nn.Module): The MLP model to be trained. loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. device (Optional[torch.device], optional): The device on which the model and data should be loaded. log (bool, optional): Whether to log training progress with tensorboard. Attributes: loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. \"\"\" def __init__ ( self , model : nn . Module , loss_fn : nn . Module , optimizer : Optimizer , device : Optional [ torch . device ] = None , log : bool = False , ): super () . __init__ ( model , device , log = log ) self . loss_fn = loss_fn self . optimizer = optimizer if self . logger is not None : self . logger . add_hparams ({ \"loss_fn\" : loss_fn . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"optimizer\" : optimizer . __class__ . __name__ }, {}, run_name = \"hparams\" ) optimizer_params = {} for key , val in optimizer . param_groups [ 0 ] . items (): optimizer_params [ key ] = str ( val ) optimizer_params . update ({ \"params\" : \"\" }) self . logger . add_hparams ( optimizer_params , {}, run_name = \"hparams\" ) def train_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . 
device ) self . model : nn . Module self . optimizer . zero_grad () preds = self . model . forward ( X ) loss = self . loss_fn ( preds , y ) loss . backward () self . optimizer . step () num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) return self . _make_batch_result ( loss , num_correct ) @ torch . no_grad () def test_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) preds = self . model . forward ( X ) num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) loss = self . loss_fn ( preds , y ) return self . _make_batch_result ( loss , num_correct ) Ancestors (in MRO) wtracker.neural.training.Trainer abc.ABC Methods fit def fit ( self , dl_train : torch . utils . data . dataloader . DataLoader , dl_test : torch . utils . data . dataloader . DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw ) -> wtracker . neural . train_results . FitResult Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Parameters: Name Type Description Default dl_train DataLoader Dataloader for the training set. None dl_test DataLoader Dataloader for the test set. None num_epochs int Number of epochs to train for. None checkpoints str Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. None early_stopping int Whether to stop training early if there is no test loss improvement for this number of epochs. None print_every int Print progress every this number of epochs. None Returns: Type Description FitResult A FitResult object containing train and test losses per epoch. View Source def fit ( self , dl_train : DataLoader , dl_test : DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw , ) -> FitResult : \"\"\" Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Args: dl_train (DataLoader): Dataloader for the training set. dl_test (DataLoader): Dataloader for the test set. num_epochs (int): Number of epochs to train for. checkpoints (str, optional): Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. early_stopping (int, optional): Whether to stop training early if there is no test loss improvement for this number of epochs. print_every (int, optional): Print progress every this number of epochs. Returns: FitResult: A FitResult object containing train and test losses per epoch. \"\"\" actual_epoch_num = 0 epochs_without_improvement = 0 train_loss , train_acc , test_loss , test_acc = [], [], [], [] best_val_loss = None # add graph to tensorboard if self . logger is not None : self . logger . add_graph ( self . model , next ( iter ( dl_train ))[ 0 ]) for epoch in range ( num_epochs ): actual_epoch_num += 1 verbose = False # pass this to train/test_epoch. if print_every > 0 and ( epoch % print_every == 0 or epoch == num_epochs - 1 ): verbose = True self . _print ( f \"--- EPOCH {epoch+1}/{num_epochs} ---\" , verbose ) train_result = self . train_epoch ( dl_train , verbose = verbose , ** kw ) test_result = self . test_epoch ( dl_test , verbose = verbose , ** kw ) train_loss . extend ( train_result . losses ) train_acc . append ( train_result . accuracy ) test_loss . 
extend ( test_result . losses ) test_acc . append ( test_result . accuracy ) # log results to tensorboard if self . logger is not None : self . logger . add_scalar ( \"loss/train\" , Tensor ( train_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"loss/test\" , Tensor ( test_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"accuracy/train\" , train_result . accuracy , epoch ) self . logger . add_scalar ( \"accuracy/test\" , test_result . accuracy , epoch ) self . logger . add_scalar ( \"learning_rate\" , self . optimizer . param_groups [ 0 ][ \"lr\" ], epoch ) curr_val_loss = Tensor ( test_result . losses ) . mean () . item () if best_val_loss is None or curr_val_loss < best_val_loss : best_val_loss = curr_val_loss epochs_without_improvement = 0 if checkpoints is not None : self . save_checkpoint ( checkpoints , curr_val_loss ) else : epochs_without_improvement += 1 if early_stopping is not None and epochs_without_improvement >= early_stopping : break return self . _make_fit_result ( actual_epoch_num , train_loss , train_acc , test_loss , test_acc ) log_hparam def log_hparam ( self , hparam_dict : dict [ str , typing . Any ], metric_dict : dict [ str , typing . Any ] = {}, run_name : str = 'hparams' ) View Source def log_hparam(self, hparam_dict: dict[str, Any], metric_dict: dict[str, Any] = {}, run_name: str = \"hparams\"): if self.logger is not None: self.logger.add_hparams(hparam_dict, metric_dict, run_name=run_name) save_checkpoint def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None Saves the model in its current state to a file with the given name (treated as a relative path). Parameters: Name Type Description Default checkpoint_filename str File name or relative path to save to. None View Source def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None : \"\"\" Saves the model in its current state to a file with the given name (treated as a relative path). Args: checkpoint_filename (str): File name or relative path to save to. \"\"\" if self . logger is not None : checkpoint_filename = f \"{self.logger.log_dir}/{checkpoint_filename}\" torch . save ( self . model , checkpoint_filename ) print ( f \"\\n*** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}\" ) test_batch def test_batch ( self , batch ) -> wtracker . neural . train_results . BatchResult Runs a single batch forward through the model and calculates loss. Parameters: Name Type Description Default batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). None Returns: Type Description BatchResult A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. View Source @torch . no_grad () def test_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) preds = self . model . forward ( X ) num_correct = torch . sum (( preds - y ). norm ( dim = 1 ) < 1.0 ) loss = self . loss_fn ( preds , y ) return self . _make_batch_result ( loss , num_correct ) test_epoch def test_epoch ( self , dl_test : torch . utils . data . dataloader . DataLoader , ** kw ) -> wtracker . neural . train_results . EpochResult Evaluate model once over a test set (single epoch). Parameters: Name Type Description Default dl_test DataLoader DataLoader for the test set. 
train_batch

def train_batch(self, batch) -> wtracker.neural.train_results.BatchResult

Runs a single batch forward through the model, calculates loss, performs back-propagation and updates weights.

Parameters:

Name | Type | Description | Default
batch | | A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). | required

Returns:

Type | Description
BatchResult | A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch.

View Source

    def train_batch(self, batch) -> BatchResult:
        X, y = batch
        if self.device:
            X = X.to(self.device)
            y = y.to(self.device)

        self.model: nn.Module
        self.optimizer.zero_grad()
        preds = self.model.forward(X)
        loss = self.loss_fn(preds, y)
        loss.backward()
        self.optimizer.step()

        num_correct = torch.sum((preds - y).norm(dim=1) < 1.0)
        return self._make_batch_result(loss, num_correct)

train_epoch

def train_epoch(self, dl_train: torch.utils.data.dataloader.DataLoader, **kw) -> wtracker.neural.train_results.EpochResult

Train once over a training set (single epoch).

Parameters:

Name | Type | Description | Default
dl_train | DataLoader | DataLoader for the training set. | required
kw | | Keyword args supported by _foreach_batch. | {}

Returns:

Type | Description
EpochResult | An EpochResult for the epoch.

View Source

    def train_epoch(self, dl_train: DataLoader, **kw) -> EpochResult:
        """
        Train once over a training set (single epoch).

        Args:
            dl_train (DataLoader): DataLoader for the training set.
            kw: Keyword args supported by _foreach_batch.

        Returns:
            EpochResult: An EpochResult for the epoch.
        """
        self.model.train(True)  # set train mode
        return self._foreach_batch(dl_train, self.train_batch, **kw)
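In both train_batch and test_batch, "correctness" is not classification accuracy: a prediction counts as correct when its Euclidean distance to the target vector is below 1.0. A short illustration of that criterion in isolation (values invented for the example):

    import torch

    preds = torch.tensor([[0.5, 0.2], [3.0, 4.0]])
    targets = torch.zeros(2, 2)

    # Row-wise Euclidean distance between prediction and target.
    dist = (preds - targets).norm(dim=1)   # -> tensor([0.5385, 5.0000])
    num_correct = torch.sum(dist < 1.0)    # -> tensor(1): only the first row counts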
Trainer

class Trainer(
    model: torch.nn.modules.module.Module,
    device: Optional[torch.device] = None,
    log: bool = False
)

A class abstracting the various tasks of training models. Provides methods at multiple levels of granularity:
- Multiple epochs (fit)
- Single epoch (train_epoch/test_epoch)
- Single batch (train_batch/test_batch)

Attributes:

Name | Type | Description | Default
model | nn.Module | The model to train. | required
device | Optional[torch.device] | The device to run training on (CPU or GPU). | None
log | bool | Whether to log training progress with tensorboard. | False

View Source

    class Trainer(abc.ABC):
        """
        A class abstracting the various tasks of training models.
        Provides methods at multiple levels of granularity:
        - Multiple epochs (fit)
        - Single epoch (train_epoch/test_epoch)
        - Single batch (train_batch/test_batch)

        Args:
            model (nn.Module): The model to train.
            device (Optional[torch.device], optional): The device to run training on (CPU or GPU).
            log (bool, optional): Whether to log training progress with tensorboard.
        """

        def __init__(
            self,
            model: nn.Module,
            device: Optional[torch.device] = None,
            log: bool = False,
        ):
            self.model = model
            self.device = device
            self.logger = None if not log else SummaryWriter()

            if self.logger is not None:
                self.logger.add_hparams({"model": model.__class__.__name__}, {}, run_name="hparams")
                self.logger.add_hparams({"device": str(device)}, {}, run_name="hparams")

            if self.device:
                model.to(self.device)

        def _make_batch_result(self, loss, num_correct) -> BatchResult:
            loss = loss.item() if isinstance(loss, Tensor) else loss
            num_correct = num_correct.item() if isinstance(num_correct, Tensor) else num_correct
            return BatchResult(float(loss), int(num_correct))

        def _make_fit_result(self, num_epochs, train_losses, train_acc, test_losses, test_acc) -> FitResult:
            num_epochs = num_epochs.item() if isinstance(num_epochs, Tensor) else num_epochs
            train_losses = [x.item() if isinstance(x, Tensor) else x for x in train_losses]
            train_acc = [x.item() if isinstance(x, Tensor) else x for x in train_acc]
            test_losses = [x.item() if isinstance(x, Tensor) else x for x in test_losses]
            test_acc = [x.item() if isinstance(x, Tensor) else x for x in test_acc]
            return FitResult(int(num_epochs), train_losses, train_acc, test_losses, test_acc)

        def fit(
            self,
            dl_train: DataLoader,
            dl_test: DataLoader,
            num_epochs: int,
            checkpoints: str = None,
            early_stopping: int = None,
            print_every: int = 1,
            **kw,
        ) -> FitResult:
            # body identical to the fit listing shown earlier on this page
            ...

        def save_checkpoint(self, checkpoint_filename: str, loss: Optional[float] = None) -> None:
            """
            Saves the model in its current state to a file with the given name
            (treated as a relative path).

            Args:
                checkpoint_filename (str): File name or relative path to save to.
            """
            if self.logger is not None:
                checkpoint_filename = f"{self.logger.log_dir}/{checkpoint_filename}"
            torch.save(self.model, checkpoint_filename)
            print(f"\n*** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}")

        def train_epoch(self, dl_train: DataLoader, **kw) -> EpochResult:
            self.model.train(True)  # set train mode
            return self._foreach_batch(dl_train, self.train_batch, **kw)

        def test_epoch(self, dl_test: DataLoader, **kw) -> EpochResult:
            self.model.train(False)  # set evaluation (test) mode
            return self._foreach_batch(dl_test, self.test_batch, **kw)

        @abc.abstractmethod
        def train_batch(self, batch) -> BatchResult:
            """
            Runs a single batch forward through the model, calculates loss,
            performs back-propagation and updates weights.
            """
            raise NotImplementedError()

        @abc.abstractmethod
        def test_batch(self, batch) -> BatchResult:
            """
            Runs a single batch forward through the model and calculates loss.
            """
            raise NotImplementedError()

        @staticmethod
        def _print(message, verbose=True):
            """Simple wrapper around print to make it conditional"""
            if verbose:
                print(message)

        @staticmethod
        def _foreach_batch(
            dl: DataLoader,
            forward_fn: Callable[[Any], BatchResult],
            verbose=True,
            max_batches=None,
        ) -> EpochResult:
            """
            Evaluates the given forward-function on batches from the given dataloader,
            and prints progress along the way.
            """
            losses = []
            num_correct = 0
            num_samples = len(dl.sampler)
            num_batches = len(dl.batch_sampler)

            if max_batches is not None:
                if max_batches < num_batches:
                    num_batches = max_batches
                    num_samples = num_batches * dl.batch_size

            if verbose:
                pbar_fn = tqdm.auto.tqdm
                pbar_file = sys.stdout
            else:
                pbar_fn = tqdm.tqdm
                pbar_file = open(os.devnull, "w")

            pbar_name = forward_fn.__name__
            with pbar_fn(desc=pbar_name, total=num_batches, file=pbar_file) as pbar:
                dl_iter = iter(dl)
                for batch_idx in range(num_batches):
                    data = next(dl_iter)
                    batch_res = forward_fn(data)
                    pbar.set_description(f"{pbar_name} ({batch_res.loss:.3f})")
                    pbar.update()
                    losses.append(batch_res.loss)
                    num_correct += batch_res.num_correct

                avg_loss = sum(losses) / num_batches
                accuracy = 100.0 * num_correct / num_samples
                pbar.set_description(f"{pbar_name} (Avg. Loss {avg_loss:.3f}, Accuracy {accuracy:.2f}%)")

            if not verbose:
                pbar_file.close()

            return EpochResult(losses=losses, accuracy=accuracy)

        def log_hparam(self, hparam_dict: dict[str, Any], metric_dict: dict[str, Any] = {}, run_name: str = "hparams"):
            if self.logger is not None:
                self.logger.add_hparams(hparam_dict, metric_dict, run_name=run_name)
Ancestors (in MRO)

abc.ABC

Descendants

wtracker.neural.training.MLPTrainer

Methods

fit

def fit(
    self,
    dl_train: torch.utils.data.dataloader.DataLoader,
    dl_test: torch.utils.data.dataloader.DataLoader,
    num_epochs: int,
    checkpoints: str = None,
    early_stopping: int = None,
    print_every: int = 1,
    **kw
) -> wtracker.neural.train_results.FitResult

Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Parameters, return value, and source are the same as in the fit listing above.

log_hparam

def log_hparam(
    self,
    hparam_dict: dict[str, typing.Any],
    metric_dict: dict[str, typing.Any] = {},
    run_name: str = 'hparams'
)

save_checkpoint

def save_checkpoint(self, checkpoint_filename: str, loss: Optional[float] = None) -> None

Saves the model in its current state to a file with the given name (treated as a relative path).

test_batch

def test_batch(self, batch) -> wtracker.neural.train_results.BatchResult

Runs a single batch forward through the model and calculates loss. Abstract in Trainer; subclasses must implement it.

View Source

    @abc.abstractmethod
    def test_batch(self, batch) -> BatchResult:
        """
        Runs a single batch forward through the model and calculates loss.

        Args:
            batch: A single batch of data from a data loader (might be a tuple of
                data and labels or anything else depending on the underlying dataset).

        Returns:
            BatchResult: A BatchResult containing the value of the loss function and
                the number of correctly classified samples in the batch.
        """
        raise NotImplementedError()

test_epoch

def test_epoch(self, dl_test: torch.utils.data.dataloader.DataLoader, **kw) -> wtracker.neural.train_results.EpochResult

Evaluate model once over a test set (single epoch).

train_batch

def train_batch(self, batch) -> wtracker.neural.train_results.BatchResult

Runs a single batch forward through the model, calculates loss, performs back-propagation and updates weights. Abstract in Trainer; subclasses must implement it.

View Source

    @abc.abstractmethod
    def train_batch(self, batch) -> BatchResult:
        """
        Runs a single batch forward through the model, calculates loss,
        performs back-propagation and updates weights.

        Args:
            batch: A single batch of data from a data loader (might be a tuple of
                data and labels or anything else depending on the underlying dataset).

        Returns:
            BatchResult: A BatchResult containing the value of the loss function and
                the number of correctly classified samples in the batch.
        """
        raise NotImplementedError()

train_epoch

def train_epoch(self, dl_train: torch.utils.data.dataloader.DataLoader, **kw) -> wtracker.neural.train_results.EpochResult

Train once over a training set (single epoch).
\"\"\" self.model.train(True) # set train mode return self._foreach_batch(dl_train, self.train_batch, **kw)","title":"Training"},{"location":"reference/wtracker/neural/training/#module-wtrackerneuraltraining","text":"View Source import os import abc import sys import torch import torch.nn as nn import torch.nn.functional import tqdm.auto from torch import Tensor from typing import Any , Tuple , Callable , Optional from torch.optim import Optimizer from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from wtracker.neural.train_results import FitResult , BatchResult , EpochResult class Trainer ( abc . ABC ): \"\"\" A class abstracting the various tasks of training models. Provides methods at multiple levels of granularity: - Multiple epochs (fit) - Single epoch (train_epoch/test_epoch) - Single batch (train_batch/test_batch) Args: model (nn.Module): The model to train. device (Optional[torch.device], optional): The device to run training on (CPU or GPU). log (bool, optional): Whether to log training progress with tensorboard. \"\"\" def __init__ ( self , model : nn . Module , device : Optional [ torch . device ] = None , log : bool = False , ): self . model = model self . device = device self . logger = None if not log else SummaryWriter () if self . logger is not None : self . logger . add_hparams ({ \"model\" : model . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"device\" : str ( device )}, {}, run_name = \"hparams\" ) if self . device : model . to ( self . device ) def _make_batch_result ( self , loss , num_correct ) -> BatchResult : loss = loss . item () if isinstance ( loss , Tensor ) else loss num_correct = num_correct . item () if isinstance ( num_correct , Tensor ) else num_correct return BatchResult ( float ( loss ), int ( num_correct )) def _make_fit_result ( self , num_epochs , train_losses , train_acc , test_losses , test_acc ) -> FitResult : num_epochs = num_epochs . item () if isinstance ( num_epochs , Tensor ) else num_epochs train_losses = [ x . item () if isinstance ( x , Tensor ) else x for x in train_losses ] train_acc = [ x . item () if isinstance ( x , Tensor ) else x for x in train_acc ] test_losses = [ x . item () if isinstance ( x , Tensor ) else x for x in test_losses ] test_acc = [ x . item () if isinstance ( x , Tensor ) else x for x in test_acc ] return FitResult ( int ( num_epochs ), train_losses , train_acc , test_losses , test_acc ) def fit ( self , dl_train : DataLoader , dl_test : DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw , ) -> FitResult : \"\"\" Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Args: dl_train (DataLoader): Dataloader for the training set. dl_test (DataLoader): Dataloader for the test set. num_epochs (int): Number of epochs to train for. checkpoints (str, optional): Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. early_stopping (int, optional): Whether to stop training early if there is no test loss improvement for this number of epochs. print_every (int, optional): Print progress every this number of epochs. Returns: FitResult: A FitResult object containing train and test losses per epoch. 
\"\"\" actual_epoch_num = 0 epochs_without_improvement = 0 train_loss , train_acc , test_loss , test_acc = [], [], [], [] best_val_loss = None # add graph to tensorboard if self . logger is not None : self . logger . add_graph ( self . model , next ( iter ( dl_train ))[ 0 ]) for epoch in range ( num_epochs ): actual_epoch_num += 1 verbose = False # pass this to train/test_epoch. if print_every > 0 and ( epoch % print_every == 0 or epoch == num_epochs - 1 ): verbose = True self . _print ( f \"--- EPOCH { epoch + 1 } / { num_epochs } ---\" , verbose ) train_result = self . train_epoch ( dl_train , verbose = verbose , ** kw ) test_result = self . test_epoch ( dl_test , verbose = verbose , ** kw ) train_loss . extend ( train_result . losses ) train_acc . append ( train_result . accuracy ) test_loss . extend ( test_result . losses ) test_acc . append ( test_result . accuracy ) # log results to tensorboard if self . logger is not None : self . logger . add_scalar ( \"loss/train\" , Tensor ( train_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"loss/test\" , Tensor ( test_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"accuracy/train\" , train_result . accuracy , epoch ) self . logger . add_scalar ( \"accuracy/test\" , test_result . accuracy , epoch ) self . logger . add_scalar ( \"learning_rate\" , self . optimizer . param_groups [ 0 ][ \"lr\" ], epoch ) curr_val_loss = Tensor ( test_result . losses ) . mean () . item () if best_val_loss is None or curr_val_loss < best_val_loss : best_val_loss = curr_val_loss epochs_without_improvement = 0 if checkpoints is not None : self . save_checkpoint ( checkpoints , curr_val_loss ) else : epochs_without_improvement += 1 if early_stopping is not None and epochs_without_improvement >= early_stopping : break return self . _make_fit_result ( actual_epoch_num , train_loss , train_acc , test_loss , test_acc ) def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None : \"\"\" Saves the model in it's current state to a file with the given name (treated as a relative path). Args: checkpoint_filename (str): File name or relative path to save to. \"\"\" if self . logger is not None : checkpoint_filename = f \" { self . logger . log_dir } / { checkpoint_filename } \" torch . save ( self . model , checkpoint_filename ) print ( f \" \\n *** Saved checkpoint { checkpoint_filename } :: val_loss= { loss : .3f } \" ) def train_epoch ( self , dl_train : DataLoader , ** kw ) -> EpochResult : \"\"\" Train once over a training set (single epoch). Args: dl_train (DataLoader): DataLoader for the training set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self . model . train ( True ) # set train mode return self . _foreach_batch ( dl_train , self . train_batch , ** kw ) def test_epoch ( self , dl_test : DataLoader , ** kw ) -> EpochResult : \"\"\" Evaluate model once over a test set (single epoch). Args: dl_test (DataLoader): DataLoader for the test set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self . model . train ( False ) # set evaluation (test) mode return self . _foreach_batch ( dl_test , self . test_batch , ** kw ) @abc . abstractmethod def train_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model, calculates loss, preforms back-propagation and updates weights. 
Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () @abc . abstractmethod def test_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model and calculates loss. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () @staticmethod def _print ( message , verbose = True ): \"\"\"Simple wrapper around print to make it conditional\"\"\" if verbose : print ( message ) @staticmethod def _foreach_batch ( dl : DataLoader , forward_fn : Callable [[ Any ], BatchResult ], verbose = True , max_batches = None , ) -> EpochResult : \"\"\" Evaluates the given forward-function on batches from the given dataloader, and prints progress along the way. \"\"\" losses = [] num_correct = 0 num_samples = len ( dl . sampler ) num_batches = len ( dl . batch_sampler ) if max_batches is not None : if max_batches < num_batches : num_batches = max_batches num_samples = num_batches * dl . batch_size if verbose : pbar_fn = tqdm . auto . tqdm pbar_file = sys . stdout else : pbar_fn = tqdm . tqdm pbar_file = open ( os . devnull , \"w\" ) pbar_name = forward_fn . __name__ with pbar_fn ( desc = pbar_name , total = num_batches , file = pbar_file ) as pbar : dl_iter = iter ( dl ) for batch_idx in range ( num_batches ): data = next ( dl_iter ) batch_res = forward_fn ( data ) pbar . set_description ( f \" { pbar_name } ( { batch_res . loss : .3f } )\" ) pbar . update () losses . append ( batch_res . loss ) num_correct += batch_res . num_correct avg_loss = sum ( losses ) / num_batches accuracy = 100.0 * num_correct / num_samples pbar . set_description ( f \" { pbar_name } \" f \"(Avg. Loss { avg_loss : .3f } , \" f \"Accuracy { accuracy : .2f } %)\" ) if not verbose : pbar_file . close () return EpochResult ( losses = losses , accuracy = accuracy ) def log_hparam ( self , hparam_dict : dict [ str , Any ], metric_dict : dict [ str , Any ] = {}, run_name : str = \"hparams\" ): if self . logger is not None : self . logger . add_hparams ( hparam_dict , metric_dict , run_name = run_name ) class MLPTrainer ( Trainer ): \"\"\" The `MLPTrainer` class is responsible for training and testing a multi-layer perceptron (MLP) models. Args: model (nn.Module): The MLP model to be trained. loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. device (Optional[torch.device], optional): The device on which the model and data should be loaded. log (bool, optional): Whether to log training progress with tensorboard. Attributes: loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. \"\"\" def __init__ ( self , model : nn . Module , loss_fn : nn . Module , optimizer : Optimizer , device : Optional [ torch . device ] = None , log : bool = False , ): super () . __init__ ( model , device , log = log ) self . loss_fn = loss_fn self . optimizer = optimizer if self . logger is not None : self . logger . 
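The module's TensorBoard support (the log flag, SummaryWriter, and log_hparam) can be exercised as follows; a sketch reusing the model, loss_fn, optimizer, and dataloaders from the earlier fit example, with a standard TensorBoard CLI invocation:

    from wtracker.neural.training import MLPTrainer

    trainer = MLPTrainer(model, loss_fn, optimizer, log=True)  # creates a SummaryWriter()
    trainer.log_hparam({"batch_size": 64})  # extra hyperparameters, logged on demand
    trainer.fit(dl_train, dl_test, num_epochs=50)

    # Then, from a terminal (SummaryWriter writes under ./runs by default):
    #   tensorboard --logdir runs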
add_hparams ({ \"loss_fn\" : loss_fn . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"optimizer\" : optimizer . __class__ . __name__ }, {}, run_name = \"hparams\" ) optimizer_params = {} for key , val in optimizer . param_groups [ 0 ] . items (): optimizer_params [ key ] = str ( val ) optimizer_params . update ({ \"params\" : \"\" }) self . logger . add_hparams ( optimizer_params , {}, run_name = \"hparams\" ) def train_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) self . model : nn . Module self . optimizer . zero_grad () preds = self . model . forward ( X ) loss = self . loss_fn ( preds , y ) loss . backward () self . optimizer . step () num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) return self . _make_batch_result ( loss , num_correct ) @torch . no_grad () def test_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) preds = self . model . forward ( X ) num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) loss = self . loss_fn ( preds , y ) return self . _make_batch_result ( loss , num_correct )","title":"Module wtracker.neural.training"},{"location":"reference/wtracker/neural/training/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/neural/training/#mlptrainer","text":"class MLPTrainer ( model : torch . nn . modules . module . Module , loss_fn : torch . nn . modules . module . Module , optimizer : torch . optim . optimizer . Optimizer , device : Optional [ torch . device ] = None , log : bool = False ) The MLPTrainer class is responsible for training and testing a multi-layer perceptron (MLP) models.","title":"MLPTrainer"},{"location":"reference/wtracker/neural/training/#attributes","text":"Name Type Description Default model nn.Module The MLP model to be trained. None loss_fn nn.Module The loss function used for training. None optimizer Optimizer The optimizer used for updating the model's parameters. None device Optional[torch.device] The device on which the model and data should be loaded. None log bool Whether to log training progress with tensorboard. None loss_fn nn.Module The loss function used for training. None optimizer Optimizer The optimizer used for updating the model's parameters. None View Source class MLPTrainer ( Trainer ): \"\"\" The `MLPTrainer` class is responsible for training and testing a multi-layer perceptron (MLP) models. Args: model (nn.Module): The MLP model to be trained. loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. device (Optional[torch.device], optional): The device on which the model and data should be loaded. log (bool, optional): Whether to log training progress with tensorboard. Attributes: loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. \"\"\" def __init__ ( self , model : nn . Module , loss_fn : nn . Module , optimizer : Optimizer , device : Optional [ torch . device ] = None , log : bool = False , ): super () . __init__ ( model , device , log = log ) self . loss_fn = loss_fn self . optimizer = optimizer if self . logger is not None : self . logger . add_hparams ({ \"loss_fn\" : loss_fn . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"optimizer\" : optimizer . __class__ . 
__name__ }, {}, run_name = \"hparams\" ) optimizer_params = {} for key , val in optimizer . param_groups [ 0 ] . items (): optimizer_params [ key ] = str ( val ) optimizer_params . update ({ \"params\" : \"\" }) self . logger . add_hparams ( optimizer_params , {}, run_name = \"hparams\" ) def train_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) self . model : nn . Module self . optimizer . zero_grad () preds = self . model . forward ( X ) loss = self . loss_fn ( preds , y ) loss . backward () self . optimizer . step () num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) return self . _make_batch_result ( loss , num_correct ) @ torch . no_grad () def test_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) preds = self . model . forward ( X ) num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) loss = self . loss_fn ( preds , y ) return self . _make_batch_result ( loss , num_correct )","title":"Attributes"},{"location":"reference/wtracker/neural/training/#ancestors-in-mro","text":"wtracker.neural.training.Trainer abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/training/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/neural/training/#fit","text":"def fit ( self , dl_train : torch . utils . data . dataloader . DataLoader , dl_test : torch . utils . data . dataloader . DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw ) -> wtracker . neural . train_results . FitResult Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Parameters: Name Type Description Default dl_train DataLoader Dataloader for the training set. None dl_test DataLoader Dataloader for the test set. None num_epochs int Number of epochs to train for. None checkpoints str Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. None early_stopping int Whether to stop training early if there is no test loss improvement for this number of epochs. None print_every int Print progress every this number of epochs. None Returns: Type Description FitResult A FitResult object containing train and test losses per epoch. View Source def fit ( self , dl_train : DataLoader , dl_test : DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw , ) -> FitResult : \"\"\" Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Args: dl_train (DataLoader): Dataloader for the training set. dl_test (DataLoader): Dataloader for the test set. num_epochs (int): Number of epochs to train for. checkpoints (str, optional): Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. early_stopping (int, optional): Whether to stop training early if there is no test loss improvement for this number of epochs. print_every (int, optional): Print progress every this number of epochs. Returns: FitResult: A FitResult object containing train and test losses per epoch. 
\"\"\" actual_epoch_num = 0 epochs_without_improvement = 0 train_loss , train_acc , test_loss , test_acc = [], [], [], [] best_val_loss = None # add graph to tensorboard if self . logger is not None : self . logger . add_graph ( self . model , next ( iter ( dl_train ))[ 0 ]) for epoch in range ( num_epochs ): actual_epoch_num += 1 verbose = False # pass this to train/test_epoch. if print_every > 0 and ( epoch % print_every == 0 or epoch == num_epochs - 1 ): verbose = True self . _print ( f \"--- EPOCH {epoch+1}/{num_epochs} ---\" , verbose ) train_result = self . train_epoch ( dl_train , verbose = verbose , ** kw ) test_result = self . test_epoch ( dl_test , verbose = verbose , ** kw ) train_loss . extend ( train_result . losses ) train_acc . append ( train_result . accuracy ) test_loss . extend ( test_result . losses ) test_acc . append ( test_result . accuracy ) # log results to tensorboard if self . logger is not None : self . logger . add_scalar ( \"loss/train\" , Tensor ( train_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"loss/test\" , Tensor ( test_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"accuracy/train\" , train_result . accuracy , epoch ) self . logger . add_scalar ( \"accuracy/test\" , test_result . accuracy , epoch ) self . logger . add_scalar ( \"learning_rate\" , self . optimizer . param_groups [ 0 ][ \"lr\" ], epoch ) curr_val_loss = Tensor ( test_result . losses ) . mean () . item () if best_val_loss is None or curr_val_loss < best_val_loss : best_val_loss = curr_val_loss epochs_without_improvement = 0 if checkpoints is not None : self . save_checkpoint ( checkpoints , curr_val_loss ) else : epochs_without_improvement += 1 if early_stopping is not None and epochs_without_improvement >= early_stopping : break return self . _make_fit_result ( actual_epoch_num , train_loss , train_acc , test_loss , test_acc )","title":"fit"},{"location":"reference/wtracker/neural/training/#log_hparam","text":"def log_hparam ( self , hparam_dict : dict [ str , typing . Any ], metric_dict : dict [ str , typing . Any ] = {}, run_name : str = 'hparams' ) View Source def log_hparam(self, hparam_dict: dict[str, Any], metric_dict: dict[str, Any] = {}, run_name: str = \"hparams\"): if self.logger is not None: self.logger.add_hparams(hparam_dict, metric_dict, run_name=run_name)","title":"log_hparam"},{"location":"reference/wtracker/neural/training/#save_checkpoint","text":"def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None Saves the model in it's current state to a file with the given name (treated as a relative path). Parameters: Name Type Description Default checkpoint_filename str File name or relative path to save to. None View Source def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None : \"\"\" Saves the model in it's current state to a file with the given name (treated as a relative path). Args: checkpoint_filename (str): File name or relative path to save to. \"\"\" if self . logger is not None : checkpoint_filename = f \"{self.logger.log_dir}/{checkpoint_filename}\" torch . save ( self . model , checkpoint_filename ) print ( f \"\\n*** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}\" )","title":"save_checkpoint"},{"location":"reference/wtracker/neural/training/#test_batch","text":"def test_batch ( self , batch ) -> wtracker . neural . train_results . BatchResult Runs a single batch forward through the model and calculates loss. 
Parameters: Name Type Description Default batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). None Returns: Type Description BatchResult A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. View Source @torch . no_grad () def test_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) preds = self . model . forward ( X ) num_correct = torch . sum (( preds - y ). norm ( dim = 1 ) < 1.0 ) loss = self . loss_fn ( preds , y ) return self . _make_batch_result ( loss , num_correct )","title":"test_batch"},{"location":"reference/wtracker/neural/training/#test_epoch","text":"def test_epoch ( self , dl_test : torch . utils . data . dataloader . DataLoader , ** kw ) -> wtracker . neural . train_results . EpochResult Evaluate model once over a test set (single epoch). Parameters: Name Type Description Default dl_test DataLoader DataLoader for the test set. None kw None Keyword args supported by _foreach_batch. None Returns: Type Description EpochResult An EpochResult for the epoch. View Source def test_epoch(self, dl_test: DataLoader, **kw) -> EpochResult: \"\"\" Evaluate model once over a test set (single epoch). Args: dl_test (DataLoader): DataLoader for the test set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self.model.train(False) # set evaluation (test) mode return self._foreach_batch(dl_test, self.test_batch, **kw)","title":"test_epoch"},{"location":"reference/wtracker/neural/training/#train_batch","text":"def train_batch ( self , batch ) -> wtracker . neural . train_results . BatchResult Runs a single batch forward through the model, calculates loss, preforms back-propagation and updates weights. Parameters: Name Type Description Default batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). None Returns: Type Description BatchResult A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. View Source def train_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) self . model : nn . Module self . optimizer . zero_grad () preds = self . model . forward ( X ) loss = self . loss_fn ( preds , y ) loss . backward () self . optimizer . step () num_correct = torch . sum (( preds - y ). norm ( dim = 1 ) < 1.0 ) return self . _make_batch_result ( loss , num_correct )","title":"train_batch"},{"location":"reference/wtracker/neural/training/#train_epoch","text":"def train_epoch ( self , dl_train : torch . utils . data . dataloader . DataLoader , ** kw ) -> wtracker . neural . train_results . EpochResult Train once over a training set (single epoch). Parameters: Name Type Description Default dl_train DataLoader DataLoader for the training set. None kw None Keyword args supported by _foreach_batch. None Returns: Type Description EpochResult An EpochResult for the epoch. View Source def train_epoch(self, dl_train: DataLoader, **kw) -> EpochResult: \"\"\" Train once over a training set (single epoch). Args: dl_train (DataLoader): DataLoader for the training set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. 
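Because train_epoch and test_epoch forward extra keyword arguments to _foreach_batch, its max_batches parameter can be used to smoke-test a model on a few batches before committing to a full run; for example, with the trainer and dataloader from the sketches above:

    # Runs only the first 10 batches of the loader - handy for a quick sanity check.
    epoch_res = trainer.train_epoch(dl_train, max_batches=10)
    print(epoch_res.accuracy)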
\"\"\" self.model.train(True) # set train mode return self._foreach_batch(dl_train, self.train_batch, **kw)","title":"train_epoch"},{"location":"reference/wtracker/neural/training/#trainer","text":"class Trainer ( model : torch . nn . modules . module . Module , device : Optional [ torch . device ] = None , log : bool = False ) A class abstracting the various tasks of training models. Provides methods at multiple levels of granularity: - Multiple epochs (fit) - Single epoch (train_epoch/test_epoch) - Single batch (train_batch/test_batch)","title":"Trainer"},{"location":"reference/wtracker/neural/training/#attributes_1","text":"Name Type Description Default model nn.Module The model to train. None device Optional[torch.device] The device to run training on (CPU or GPU). None log bool Whether to log training progress with tensorboard. None View Source class Trainer ( abc . ABC ): \"\"\" A class abstracting the various tasks of training models. Provides methods at multiple levels of granularity: - Multiple epochs (fit) - Single epoch (train_epoch/test_epoch) - Single batch (train_batch/test_batch) Args: model (nn.Module): The model to train. device (Optional[torch.device], optional): The device to run training on (CPU or GPU). log (bool, optional): Whether to log training progress with tensorboard. \"\"\" def __init__ ( self , model : nn . Module , device : Optional [ torch . device ] = None , log : bool = False , ): self . model = model self . device = device self . logger = None if not log else SummaryWriter () if self . logger is not None : self . logger . add_hparams ({ \"model\" : model . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"device\" : str ( device )}, {}, run_name = \"hparams\" ) if self . device : model . to ( self . device ) def _make_batch_result ( self , loss , num_correct ) -> BatchResult : loss = loss . item () if isinstance ( loss , Tensor ) else loss num_correct = num_correct . item () if isinstance ( num_correct , Tensor ) else num_correct return BatchResult ( float ( loss ), int ( num_correct )) def _make_fit_result ( self , num_epochs , train_losses , train_acc , test_losses , test_acc ) -> FitResult : num_epochs = num_epochs . item () if isinstance ( num_epochs , Tensor ) else num_epochs train_losses = [ x . item () if isinstance ( x , Tensor ) else x for x in train_losses ] train_acc = [ x . item () if isinstance ( x , Tensor ) else x for x in train_acc ] test_losses = [ x . item () if isinstance ( x , Tensor ) else x for x in test_losses ] test_acc = [ x . item () if isinstance ( x , Tensor ) else x for x in test_acc ] return FitResult ( int ( num_epochs ), train_losses , train_acc , test_losses , test_acc ) def fit ( self , dl_train : DataLoader , dl_test : DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw , ) -> FitResult : \"\"\" Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Args: dl_train (DataLoader): Dataloader for the training set. dl_test (DataLoader): Dataloader for the test set. num_epochs (int): Number of epochs to train for. checkpoints (str, optional): Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. early_stopping (int, optional): Whether to stop training early if there is no test loss improvement for this number of epochs. 
print_every (int, optional): Print progress every this number of epochs. Returns: FitResult: A FitResult object containing train and test losses per epoch. \"\"\" actual_epoch_num = 0 epochs_without_improvement = 0 train_loss , train_acc , test_loss , test_acc = [], [], [], [] best_val_loss = None # add graph to tensorboard if self . logger is not None : self . logger . add_graph ( self . model , next ( iter ( dl_train ))[ 0 ]) for epoch in range ( num_epochs ): actual_epoch_num += 1 verbose = False # pass this to train/test_epoch. if print_every > 0 and ( epoch % print_every == 0 or epoch == num_epochs - 1 ): verbose = True self . _print ( f \"--- EPOCH {epoch+1}/{num_epochs} ---\" , verbose ) train_result = self . train_epoch ( dl_train , verbose = verbose , ** kw ) test_result = self . test_epoch ( dl_test , verbose = verbose , ** kw ) train_loss . extend ( train_result . losses ) train_acc . append ( train_result . accuracy ) test_loss . extend ( test_result . losses ) test_acc . append ( test_result . accuracy ) # log results to tensorboard if self . logger is not None : self . logger . add_scalar ( \"loss/train\" , Tensor ( train_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"loss/test\" , Tensor ( test_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"accuracy/train\" , train_result . accuracy , epoch ) self . logger . add_scalar ( \"accuracy/test\" , test_result . accuracy , epoch ) self . logger . add_scalar ( \"learning_rate\" , self . optimizer . param_groups [ 0 ][ \"lr\" ], epoch ) curr_val_loss = Tensor ( test_result . losses ) . mean () . item () if best_val_loss is None or curr_val_loss < best_val_loss : best_val_loss = curr_val_loss epochs_without_improvement = 0 if checkpoints is not None : self . save_checkpoint ( checkpoints , curr_val_loss ) else : epochs_without_improvement += 1 if early_stopping is not None and epochs_without_improvement >= early_stopping : break return self . _make_fit_result ( actual_epoch_num , train_loss , train_acc , test_loss , test_acc ) def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None : \"\"\" Saves the model in it's current state to a file with the given name (treated as a relative path). Args: checkpoint_filename (str): File name or relative path to save to. \"\"\" if self . logger is not None : checkpoint_filename = f \"{self.logger.log_dir}/{checkpoint_filename}\" torch . save ( self . model , checkpoint_filename ) print ( f \" \\n *** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}\" ) def train_epoch ( self , dl_train : DataLoader , ** kw ) -> EpochResult : \"\"\" Train once over a training set (single epoch). Args: dl_train (DataLoader): DataLoader for the training set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self . model . train ( True ) # set train mode return self . _foreach_batch ( dl_train , self . train_batch , ** kw ) def test_epoch ( self , dl_test : DataLoader , ** kw ) -> EpochResult : \"\"\" Evaluate model once over a test set (single epoch). Args: dl_test (DataLoader): DataLoader for the test set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self . model . train ( False ) # set evaluation (test) mode return self . _foreach_batch ( dl_test , self . test_batch , ** kw ) @ abc . 
abstractmethod def train_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model, calculates loss, preforms back-propagation and updates weights. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () @ abc . abstractmethod def test_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model and calculates loss. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () @ staticmethod def _print ( message , verbose = True ): \"\"\"Simple wrapper around print to make it conditional\"\"\" if verbose : print ( message ) @ staticmethod def _foreach_batch ( dl : DataLoader , forward_fn : Callable [[ Any ], BatchResult ], verbose = True , max_batches = None , ) -> EpochResult : \"\"\" Evaluates the given forward-function on batches from the given dataloader, and prints progress along the way. \"\"\" losses = [] num_correct = 0 num_samples = len ( dl . sampler ) num_batches = len ( dl . batch_sampler ) if max_batches is not None : if max_batches < num_batches : num_batches = max_batches num_samples = num_batches * dl . batch_size if verbose : pbar_fn = tqdm . auto . tqdm pbar_file = sys . stdout else : pbar_fn = tqdm . tqdm pbar_file = open ( os . devnull , \"w\" ) pbar_name = forward_fn . __name__ with pbar_fn ( desc = pbar_name , total = num_batches , file = pbar_file ) as pbar : dl_iter = iter ( dl ) for batch_idx in range ( num_batches ): data = next ( dl_iter ) batch_res = forward_fn ( data ) pbar . set_description ( f \"{pbar_name} ({batch_res.loss:.3f})\" ) pbar . update () losses . append ( batch_res . loss ) num_correct += batch_res . num_correct avg_loss = sum ( losses ) / num_batches accuracy = 100.0 * num_correct / num_samples pbar . set_description ( f \"{pbar_name} \" f \"(Avg. Loss {avg_loss:.3f}, \" f \"Accuracy {accuracy:.2f}%)\" ) if not verbose : pbar_file . close () return EpochResult ( losses = losses , accuracy = accuracy ) def log_hparam ( self , hparam_dict : dict [ str , Any ], metric_dict : dict [ str , Any ] = {}, run_name : str = \"hparams\" ): if self . logger is not None : self . logger . add_hparams ( hparam_dict , metric_dict , run_name = run_name )","title":"Attributes"},{"location":"reference/wtracker/neural/training/#ancestors-in-mro_1","text":"abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/training/#descendants","text":"wtracker.neural.training.MLPTrainer","title":"Descendants"},{"location":"reference/wtracker/neural/training/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/neural/training/#fit_1","text":"def fit ( self , dl_train : torch . utils . data . dataloader . DataLoader , dl_test : torch . utils . data . dataloader . DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw ) -> wtracker . neural . train_results . FitResult Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. 
Parameters: Name Type Description Default dl_train DataLoader Dataloader for the training set. None dl_test DataLoader Dataloader for the test set. None num_epochs int Number of epochs to train for. None checkpoints str Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. None early_stopping int Whether to stop training early if there is no test loss improvement for this number of epochs. None print_every int Print progress every this number of epochs. None Returns: Type Description FitResult A FitResult object containing train and test losses per epoch. View Source def fit ( self , dl_train : DataLoader , dl_test : DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw , ) -> FitResult : \"\"\" Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Args: dl_train (DataLoader): Dataloader for the training set. dl_test (DataLoader): Dataloader for the test set. num_epochs (int): Number of epochs to train for. checkpoints (str, optional): Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. early_stopping (int, optional): Whether to stop training early if there is no test loss improvement for this number of epochs. print_every (int, optional): Print progress every this number of epochs. Returns: FitResult: A FitResult object containing train and test losses per epoch. \"\"\" actual_epoch_num = 0 epochs_without_improvement = 0 train_loss , train_acc , test_loss , test_acc = [], [], [], [] best_val_loss = None # add graph to tensorboard if self . logger is not None : self . logger . add_graph ( self . model , next ( iter ( dl_train ))[ 0 ]) for epoch in range ( num_epochs ): actual_epoch_num += 1 verbose = False # pass this to train/test_epoch. if print_every > 0 and ( epoch % print_every == 0 or epoch == num_epochs - 1 ): verbose = True self . _print ( f \"--- EPOCH {epoch+1}/{num_epochs} ---\" , verbose ) train_result = self . train_epoch ( dl_train , verbose = verbose , ** kw ) test_result = self . test_epoch ( dl_test , verbose = verbose , ** kw ) train_loss . extend ( train_result . losses ) train_acc . append ( train_result . accuracy ) test_loss . extend ( test_result . losses ) test_acc . append ( test_result . accuracy ) # log results to tensorboard if self . logger is not None : self . logger . add_scalar ( \"loss/train\" , Tensor ( train_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"loss/test\" , Tensor ( test_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"accuracy/train\" , train_result . accuracy , epoch ) self . logger . add_scalar ( \"accuracy/test\" , test_result . accuracy , epoch ) self . logger . add_scalar ( \"learning_rate\" , self . optimizer . param_groups [ 0 ][ \"lr\" ], epoch ) curr_val_loss = Tensor ( test_result . losses ) . mean () . item () if best_val_loss is None or curr_val_loss < best_val_loss : best_val_loss = curr_val_loss epochs_without_improvement = 0 if checkpoints is not None : self . save_checkpoint ( checkpoints , curr_val_loss ) else : epochs_without_improvement += 1 if early_stopping is not None and epochs_without_improvement >= early_stopping : break return self . 
_make_fit_result ( actual_epoch_num , train_loss , train_acc , test_loss , test_acc )","title":"fit"},{"location":"reference/wtracker/neural/training/#log_hparam_1","text":"def log_hparam ( self , hparam_dict : dict [ str , typing . Any ], metric_dict : dict [ str , typing . Any ] = {}, run_name : str = 'hparams' ) View Source def log_hparam(self, hparam_dict: dict[str, Any], metric_dict: dict[str, Any] = {}, run_name: str = \"hparams\"): if self.logger is not None: self.logger.add_hparams(hparam_dict, metric_dict, run_name=run_name)","title":"log_hparam"},{"location":"reference/wtracker/neural/training/#save_checkpoint_1","text":"def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None Saves the model in its current state to a file with the given name (treated as a relative path). Parameters: Name Type Description Default checkpoint_filename str File name or relative path to save to. None View Source def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None : \"\"\" Saves the model in its current state to a file with the given name (treated as a relative path). Args: checkpoint_filename (str): File name or relative path to save to. \"\"\" if self . logger is not None : checkpoint_filename = f \"{self.logger.log_dir}/{checkpoint_filename}\" torch . save ( self . model , checkpoint_filename ) print ( f \"\\n*** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}\" )","title":"save_checkpoint"},{"location":"reference/wtracker/neural/training/#test_batch_1","text":"def test_batch ( self , batch ) -> wtracker . neural . train_results . BatchResult Runs a single batch forward through the model and calculates loss. Parameters: Name Type Description Default batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). None Returns: Type Description BatchResult A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. View Source @ abc . abstractmethod def test_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model and calculates loss. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError ()","title":"test_batch"},{"location":"reference/wtracker/neural/training/#test_epoch_1","text":"def test_epoch ( self , dl_test : torch . utils . data . dataloader . DataLoader , ** kw ) -> wtracker . neural . train_results . EpochResult Evaluate model once over a test set (single epoch). Parameters: Name Type Description Default dl_test DataLoader DataLoader for the test set. None kw None Keyword args supported by _foreach_batch. None Returns: Type Description EpochResult An EpochResult for the epoch. View Source def test_epoch(self, dl_test: DataLoader, **kw) -> EpochResult: \"\"\" Evaluate model once over a test set (single epoch). Args: dl_test (DataLoader): DataLoader for the test set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch.
\"\"\" self.model.train(False) # set evaluation (test) mode return self._foreach_batch(dl_test, self.test_batch, **kw)","title":"test_epoch"},{"location":"reference/wtracker/neural/training/#train_batch_1","text":"def train_batch ( self , batch ) -> wtracker . neural . train_results . BatchResult Runs a single batch forward through the model, calculates loss, performs back-propagation and updates weights. Parameters: Name Type Description Default batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). None Returns: Type Description BatchResult A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. View Source @ abc . abstractmethod def train_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model, calculates loss, performs back-propagation and updates weights. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError ()","title":"train_batch"},{"location":"reference/wtracker/neural/training/#train_epoch_1","text":"def train_epoch ( self , dl_train : torch . utils . data . dataloader . DataLoader , ** kw ) -> wtracker . neural . train_results . EpochResult Train once over a training set (single epoch). Parameters: Name Type Description Default dl_train DataLoader DataLoader for the training set. None kw None Keyword args supported by _foreach_batch. None Returns: Type Description EpochResult An EpochResult for the epoch. View Source def train_epoch(self, dl_train: DataLoader, **kw) -> EpochResult: \"\"\" Train once over a training set (single epoch). Args: dl_train (DataLoader): DataLoader for the training set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch.
\"\"\" self.model.train(True) # set train mode return self._foreach_batch(dl_train, self.train_batch, **kw)","title":"train_epoch"},{"location":"reference/wtracker/sim/","text":"Module wtracker.sim View Source from wtracker.sim.config import TimingConfig , ExperimentConfig from wtracker.sim.motor_controllers import MotorController , StepMotorController , SineMotorController from wtracker.sim.simulator import Simulator , SimController from wtracker.sim.view_controller import ViewController Sub-modules wtracker.sim.config wtracker.sim.motor_controllers wtracker.sim.sim_controllers wtracker.sim.simulator wtracker.sim.view_controller","title":"Index"},{"location":"reference/wtracker/sim/#module-wtrackersim","text":"View Source from wtracker.sim.config import TimingConfig , ExperimentConfig from wtracker.sim.motor_controllers import MotorController , StepMotorController , SineMotorController from wtracker.sim.simulator import Simulator , SimController from wtracker.sim.view_controller import ViewController","title":"Module wtracker.sim"},{"location":"reference/wtracker/sim/#sub-modules","text":"wtracker.sim.config wtracker.sim.motor_controllers wtracker.sim.sim_controllers wtracker.sim.simulator wtracker.sim.view_controller","title":"Sub-modules"},{"location":"reference/wtracker/sim/config/","text":"Module wtracker.sim.config View Source from __future__ import annotations from dataclasses import dataclass , field import math from wtracker.utils.config_base import ConfigBase from wtracker.utils.frame_reader import FrameReader @dataclass class TimingConfig ( ConfigBase ): \"\"\" Configuration for timing parameters of the experiment. These parameters should not change between different experiments. This class affects the timings of the simulation. \"\"\" experiment_config : ExperimentConfig = field ( repr = False ) \"\"\"The configuration of the experiment parameters.\"\"\" px_per_mm : int = field ( init = False ) mm_per_px : float = field ( init = False ) frames_per_sec : int = field ( init = False ) ms_per_frame : float = field ( init = False ) imaging_time_ms : float imaging_frame_num : int = field ( init = False ) pred_time_ms : float pred_frame_num : int = field ( init = False ) moving_time_ms : float moving_frame_num : int = field ( init = False ) camera_size_mm : tuple [ float , float ] camera_size_px : tuple [ int , int ] = field ( init = False ) micro_size_mm : tuple [ float , float ] micro_size_px : tuple [ int , int ] = field ( init = False ) def __post_init__ ( self ): self . frames_per_sec = self . experiment_config . frames_per_sec self . ms_per_frame = self . experiment_config . ms_per_frame self . imaging_frame_num = math . ceil ( self . imaging_time_ms / self . ms_per_frame ) self . pred_frame_num = math . ceil ( self . pred_time_ms / self . ms_per_frame ) self . moving_frame_num = math . ceil ( self . moving_time_ms / self . ms_per_frame ) self . mm_per_px = self . experiment_config . mm_per_px self . px_per_mm = self . experiment_config . px_per_mm self . camera_size_px = ( round ( self . px_per_mm * self . camera_size_mm [ 0 ]), round ( self . px_per_mm * self . camera_size_mm [ 1 ]), ) self . micro_size_px = ( round ( self . px_per_mm * self . micro_size_mm [ 0 ]), round ( self . px_per_mm * self . micro_size_mm [ 1 ]), ) del self . experiment_config # experiment_config was temporary, only for the constructor @property def cycle_frame_num ( self ) -> int : return self . imaging_frame_num + self . moving_frame_num @property def cycle_time_ms ( self ) -> float : return self . 
cycle_frame_num * self . ms_per_frame @dataclass class ExperimentConfig ( ConfigBase ): \"\"\" Configuration for the experiment parameters. These parameters can change between different experiments. \"\"\" name : str \"\"\"Experiment name\"\"\" num_frames : int \"\"\"total number of frames of the experiment\"\"\" frames_per_sec : float \"\"\"Number of frames per second that the experiment was recorded at\"\"\" orig_resolution : tuple [ int , int ] \"\"\"Original resolution of the frames in pixels, in format (h, w)\"\"\" px_per_mm : float \"\"\"Number of pixels in a single millimeter\"\"\" init_position : tuple [ int , int ] \"\"\"The initial position of the center of the platform, in pixels (x, y) format. Platform's initial position should point to the worm, or close to it.\"\"\" comments : str = \"\" \"\"\"Additional comments about the experiment\"\"\" mm_per_px : float = field ( init = False ) \"\"\"Number of millimeters in a single pixel\"\"\" ms_per_frame : float = field ( init = False ) \"\"\"Number of milliseconds per frame\"\"\" def __post_init__ ( self ): self . ms_per_frame = 1000 / self . frames_per_sec self . mm_per_px = 1 / self . px_per_mm @classmethod def from_frame_reader ( cls , reader : FrameReader , name : str , frames_per_sec : int , px_per_mm : float , init_position : tuple [ int , int ], ) -> ExperimentConfig : return ExperimentConfig ( name = name , num_frames = len ( reader ), frames_per_sec = frames_per_sec , orig_resolution = reader . frame_size , px_per_mm = px_per_mm , init_position = init_position , ) Classes ExperimentConfig class ExperimentConfig ( name : 'str' , num_frames : 'int' , frames_per_sec : 'float' , orig_resolution : 'tuple[int, int]' , px_per_mm : 'float' , init_position : 'tuple[int, int]' , comments : 'str' = '' ) Configuration for the experiment parameters. These parameters can change between different experiments. View Source @dataclass class ExperimentConfig ( ConfigBase ) : \"\"\" Configuration for the experiment parameters. These parameters can change between different experiments. \"\"\" name : str \"\"\"Experiment name\"\"\" num_frames : int \"\"\"total number of frames of the experiment\"\"\" frames_per_sec : float \"\"\"Number of frames per second that the experiment was recorded at\"\"\" orig_resolution : tuple [ int, int ] \"\"\"Original resolution of the frames in pixels, in format (h, w)\"\"\" px_per_mm : float \"\"\"Number of pixels in a single millimeter\"\"\" init_position : tuple [ int, int ] \"\"\"The initial position of the center of the platform, in pixels (x, y) format. Platform's initial position should point to the worm, or close to it.\"\"\" comments : str = \"\" \"\"\"Additional comments about the experiment\"\"\" mm_per_px : float = field ( init = False ) \"\"\"Number of millimeters in a single pixel\"\"\" ms_per_frame : float = field ( init = False ) \"\"\"Number of milliseconds per frame\"\"\" def __post_init__ ( self ) : self . ms_per_frame = 1000 / self . frames_per_sec self . mm_per_px = 1 / self . px_per_mm @classmethod def from_frame_reader ( cls , reader : FrameReader , name : str , frames_per_sec : int , px_per_mm : float , init_position : tuple [ int, int ] , ) -> ExperimentConfig : return ExperimentConfig ( name = name , num_frames = len ( reader ), frames_per_sec = frames_per_sec , orig_resolution = reader . 
frame_size , px_per_mm = px_per_mm , init_position = init_position , ) Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Class variables comments Static methods from_frame_reader def from_frame_reader ( reader : 'FrameReader' , name : 'str' , frames_per_sec : 'int' , px_per_mm : 'float' , init_position : 'tuple[int, int]' ) -> 'ExperimentConfig' View Source @classmethod def from_frame_reader ( cls , reader : FrameReader , name : str , frames_per_sec : int , px_per_mm : float , init_position : tuple [ int, int ] , ) -> ExperimentConfig : return ExperimentConfig ( name = name , num_frames = len ( reader ), frames_per_sec = frames_per_sec , orig_resolution = reader . frame_size , px_per_mm = px_per_mm , init_position = init_position , ) load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods save_json def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . 
save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) TimingConfig class TimingConfig ( experiment_config : 'ExperimentConfig' , imaging_time_ms : 'float' , pred_time_ms : 'float' , moving_time_ms : 'float' , camera_size_mm : 'tuple[float, float]' , micro_size_mm : 'tuple[float, float]' ) Configuration for timing parameters of the experiment. These parameters should not change between different experiments. This class affects the timings of the simulation. View Source @ dataclass class TimingConfig ( ConfigBase ): \"\"\" Configuration for timing parameters of the experiment. These parameters should not change between different experiments. This class affects the timings of the simulation. \"\"\" experiment_config : ExperimentConfig = field ( repr = False ) \"\"\"The configuration of the experiment parameters.\"\"\" px_per_mm : int = field ( init = False ) mm_per_px : float = field ( init = False ) frames_per_sec : int = field ( init = False ) ms_per_frame : float = field ( init = False ) imaging_time_ms : float imaging_frame_num : int = field ( init = False ) pred_time_ms : float pred_frame_num : int = field ( init = False ) moving_time_ms : float moving_frame_num : int = field ( init = False ) camera_size_mm : tuple [ float , float ] camera_size_px : tuple [ int , int ] = field ( init = False ) micro_size_mm : tuple [ float , float ] micro_size_px : tuple [ int , int ] = field ( init = False ) def __post_init__ ( self ): self . frames_per_sec = self . experiment_config . frames_per_sec self . ms_per_frame = self . experiment_config . ms_per_frame self . imaging_frame_num = math . ceil ( self . imaging_time_ms / self . ms_per_frame ) self . pred_frame_num = math . ceil ( self . pred_time_ms / self . ms_per_frame ) self . moving_frame_num = math . ceil ( self . moving_time_ms / self . ms_per_frame ) self . mm_per_px = self . experiment_config . mm_per_px self . px_per_mm = self . experiment_config . px_per_mm self . camera_size_px = ( round ( self . px_per_mm * self . camera_size_mm [ 0 ]), round ( self . px_per_mm * self . camera_size_mm [ 1 ]), ) self . micro_size_px = ( round ( self . px_per_mm * self . micro_size_mm [ 0 ]), round ( self . px_per_mm * self . micro_size_mm [ 1 ]), ) del self . experiment_config # experiment_config was temporary, only for the constructor @ property def cycle_frame_num ( self ) -> int : return self . imaging_frame_num + self . moving_frame_num @ property def cycle_time_ms ( self ) -> float : return self . cycle_frame_num * self . ms_per_frame Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Static methods load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. 
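The ConfigBase helpers above (save_json/load_json, save_pickle/load_pickle) give every configuration class the same persistence round-trip. A minimal sketch, assuming only the constructor signatures and helper names shown in this section (all concrete values are made up for illustration):

```python
# Hypothetical round-trip sketch (paths and values are made up): build the two
# config objects described in this module and persist them via ConfigBase.
from wtracker.sim.config import ExperimentConfig, TimingConfig

exp_config = ExperimentConfig(
    name="exp1",                     # experiment name
    num_frames=10_000,               # total frames in the recording
    frames_per_sec=60,               # recording frame rate
    orig_resolution=(2048, 2048),    # (h, w) in pixels
    px_per_mm=90,                    # calibration: pixels per millimeter
    init_position=(1024, 1024),      # platform start, (x, y) in pixels
)

timing_config = TimingConfig(
    experiment_config=exp_config,    # only used during construction
    imaging_time_ms=400.0,
    pred_time_ms=50.0,
    moving_time_ms=100.0,
    camera_size_mm=(10.0, 10.0),
    micro_size_mm=(2.0, 2.0),
)

# Derived fields are filled in __post_init__, e.g. the per-phase frame counts:
# imaging_frame_num = ceil(400 / (1000/60)) = 24, cycle_frame_num = 24 + 6 = 30.
print(timing_config.imaging_frame_num, timing_config.cycle_frame_num)

# Persist and restore via the ConfigBase helpers shown above.
timing_config.save_json("timing_config.json")
restored = TimingConfig.load_json("timing_config.json")
```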
Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Instance variables cycle_frame_num cycle_time_ms Methods save_json def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"Config"},{"location":"reference/wtracker/sim/config/#module-wtrackersimconfig","text":"View Source from __future__ import annotations from dataclasses import dataclass , field import math from wtracker.utils.config_base import ConfigBase from wtracker.utils.frame_reader import FrameReader @dataclass class TimingConfig ( ConfigBase ): \"\"\" Configuration for timing parameters of the experiment. These parameters should not change between different experiments. This class affects the timings of the simulation. \"\"\" experiment_config : ExperimentConfig = field ( repr = False ) \"\"\"The configuration of the experiment parameters.\"\"\" px_per_mm : int = field ( init = False ) mm_per_px : float = field ( init = False ) frames_per_sec : int = field ( init = False ) ms_per_frame : float = field ( init = False ) imaging_time_ms : float imaging_frame_num : int = field ( init = False ) pred_time_ms : float pred_frame_num : int = field ( init = False ) moving_time_ms : float moving_frame_num : int = field ( init = False ) camera_size_mm : tuple [ float , float ] camera_size_px : tuple [ int , int ] = field ( init = False ) micro_size_mm : tuple [ float , float ] micro_size_px : tuple [ int , int ] = field ( init = False ) def __post_init__ ( self ): self . frames_per_sec = self . experiment_config . frames_per_sec self . ms_per_frame = self . experiment_config . ms_per_frame self . imaging_frame_num = math . ceil ( self . imaging_time_ms / self . ms_per_frame ) self . pred_frame_num = math . ceil ( self . pred_time_ms / self . ms_per_frame ) self . moving_frame_num = math . ceil ( self . moving_time_ms / self . ms_per_frame ) self . mm_per_px = self . experiment_config . mm_per_px self . px_per_mm = self . experiment_config . px_per_mm self . 
camera_size_px = ( round ( self . px_per_mm * self . camera_size_mm [ 0 ]), round ( self . px_per_mm * self . camera_size_mm [ 1 ]), ) self . micro_size_px = ( round ( self . px_per_mm * self . micro_size_mm [ 0 ]), round ( self . px_per_mm * self . micro_size_mm [ 1 ]), ) del self . experiment_config # experiment_config was temporary, only for the constructor @property def cycle_frame_num ( self ) -> int : return self . imaging_frame_num + self . moving_frame_num @property def cycle_time_ms ( self ) -> float : return self . cycle_frame_num * self . ms_per_frame @dataclass class ExperimentConfig ( ConfigBase ): \"\"\" Configuration for the experiment parameters. These parameters can change between different experiments. \"\"\" name : str \"\"\"Experiment name\"\"\" num_frames : int \"\"\"total number of frames of the experiment\"\"\" frames_per_sec : float \"\"\"Number of frames per second that the experiment was recorded at\"\"\" orig_resolution : tuple [ int , int ] \"\"\"Original resolution of the frames in pixels, in format (h, w)\"\"\" px_per_mm : float \"\"\"Number of pixels in a single millimeter\"\"\" init_position : tuple [ int , int ] \"\"\"The initial position of the center of the platform, in pixels (x, y) format. Platform's initial position should point to the worm, or close to it.\"\"\" comments : str = \"\" \"\"\"Additional comments about the experiment\"\"\" mm_per_px : float = field ( init = False ) \"\"\"Number of millimeters in a single pixel\"\"\" ms_per_frame : float = field ( init = False ) \"\"\"Number of milliseconds per frame\"\"\" def __post_init__ ( self ): self . ms_per_frame = 1000 / self . frames_per_sec self . mm_per_px = 1 / self . px_per_mm @classmethod def from_frame_reader ( cls , reader : FrameReader , name : str , frames_per_sec : int , px_per_mm : float , init_position : tuple [ int , int ], ) -> ExperimentConfig : return ExperimentConfig ( name = name , num_frames = len ( reader ), frames_per_sec = frames_per_sec , orig_resolution = reader . frame_size , px_per_mm = px_per_mm , init_position = init_position , )","title":"Module wtracker.sim.config"},{"location":"reference/wtracker/sim/config/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/config/#experimentconfig","text":"class ExperimentConfig ( name : 'str' , num_frames : 'int' , frames_per_sec : 'float' , orig_resolution : 'tuple[int, int]' , px_per_mm : 'float' , init_position : 'tuple[int, int]' , comments : 'str' = '' ) Configuration for the experiment parameters. These parameters can change between different experiments. View Source @dataclass class ExperimentConfig ( ConfigBase ) : \"\"\" Configuration for the experiment parameters. These parameters can change between different experiments. \"\"\" name : str \"\"\"Experiment name\"\"\" num_frames : int \"\"\"total number of frames of the experiment\"\"\" frames_per_sec : float \"\"\"Number of frames per second that the experiment was recorded at\"\"\" orig_resolution : tuple [ int, int ] \"\"\"Original resolution of the frames in pixels, in format (h, w)\"\"\" px_per_mm : float \"\"\"Number of pixels in a single millimeter\"\"\" init_position : tuple [ int, int ] \"\"\"The initial position of the center of the platform, in pixels (x, y) format. 
Platform's initial position should point to the worm, or close to it.\"\"\" comments : str = \"\" \"\"\"Additional comments about the experiment\"\"\" mm_per_px : float = field ( init = False ) \"\"\"Number of millimeters in a single pixel\"\"\" ms_per_frame : float = field ( init = False ) \"\"\"Number of milliseconds per frame\"\"\" def __post_init__ ( self ) : self . ms_per_frame = 1000 / self . frames_per_sec self . mm_per_px = 1 / self . px_per_mm @classmethod def from_frame_reader ( cls , reader : FrameReader , name : str , frames_per_sec : int , px_per_mm : float , init_position : tuple [ int, int ] , ) -> ExperimentConfig : return ExperimentConfig ( name = name , num_frames = len ( reader ), frames_per_sec = frames_per_sec , orig_resolution = reader . frame_size , px_per_mm = px_per_mm , init_position = init_position , )","title":"ExperimentConfig"},{"location":"reference/wtracker/sim/config/#ancestors-in-mro","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/config/#class-variables","text":"comments","title":"Class variables"},{"location":"reference/wtracker/sim/config/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/sim/config/#from_frame_reader","text":"def from_frame_reader ( reader : 'FrameReader' , name : 'str' , frames_per_sec : 'int' , px_per_mm : 'float' , init_position : 'tuple[int, int]' ) -> 'ExperimentConfig' View Source @classmethod def from_frame_reader ( cls , reader : FrameReader , name : str , frames_per_sec : int , px_per_mm : float , init_position : tuple [ int, int ] , ) -> ExperimentConfig : return ExperimentConfig ( name = name , num_frames = len ( reader ), frames_per_sec = frames_per_sec , orig_resolution = reader . frame_size , px_per_mm = px_per_mm , init_position = init_position , )","title":"from_frame_reader"},{"location":"reference/wtracker/sim/config/#load_json","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/sim/config/#load_pickle","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . 
open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/sim/config/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/config/#save_json","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/sim/config/#save_pickle","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/sim/config/#timingconfig","text":"class TimingConfig ( experiment_config : 'ExperimentConfig' , imaging_time_ms : 'float' , pred_time_ms : 'float' , moving_time_ms : 'float' , camera_size_mm : 'tuple[float, float]' , micro_size_mm : 'tuple[float, float]' ) Configuration for timing parameters of the experiment. These parameters should not change between different experiments. This class affects the timings of the simulation. View Source @ dataclass class TimingConfig ( ConfigBase ): \"\"\" Configuration for timing parameters of the experiment. These parameters should not change between different experiments. This class affects the timings of the simulation. \"\"\" experiment_config : ExperimentConfig = field ( repr = False ) \"\"\"The configuration of the experiment parameters.\"\"\" px_per_mm : int = field ( init = False ) mm_per_px : float = field ( init = False ) frames_per_sec : int = field ( init = False ) ms_per_frame : float = field ( init = False ) imaging_time_ms : float imaging_frame_num : int = field ( init = False ) pred_time_ms : float pred_frame_num : int = field ( init = False ) moving_time_ms : float moving_frame_num : int = field ( init = False ) camera_size_mm : tuple [ float , float ] camera_size_px : tuple [ int , int ] = field ( init = False ) micro_size_mm : tuple [ float , float ] micro_size_px : tuple [ int , int ] = field ( init = False ) def __post_init__ ( self ): self . frames_per_sec = self . experiment_config . frames_per_sec self . ms_per_frame = self . experiment_config . ms_per_frame self . imaging_frame_num = math . ceil ( self . imaging_time_ms / self . ms_per_frame ) self . pred_frame_num = math . ceil ( self . pred_time_ms / self . ms_per_frame ) self . moving_frame_num = math . ceil ( self . moving_time_ms / self . ms_per_frame ) self . mm_per_px = self . experiment_config . mm_per_px self . px_per_mm = self . experiment_config . px_per_mm self . camera_size_px = ( round ( self . px_per_mm * self . 
camera_size_mm [ 0 ]), round ( self . px_per_mm * self . camera_size_mm [ 1 ]), ) self . micro_size_px = ( round ( self . px_per_mm * self . micro_size_mm [ 0 ]), round ( self . px_per_mm * self . micro_size_mm [ 1 ]), ) del self . experiment_config # experiment_config was temporary, only for the constructor @ property def cycle_frame_num ( self ) -> int : return self . imaging_frame_num + self . moving_frame_num @ property def cycle_time_ms ( self ) -> float : return self . cycle_frame_num * self . ms_per_frame","title":"TimingConfig"},{"location":"reference/wtracker/sim/config/#ancestors-in-mro_1","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/config/#static-methods_1","text":"","title":"Static methods"},{"location":"reference/wtracker/sim/config/#load_json_1","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/sim/config/#load_pickle_1","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/sim/config/#instance-variables","text":"cycle_frame_num cycle_time_ms","title":"Instance variables"},{"location":"reference/wtracker/sim/config/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/sim/config/#save_json_1","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/sim/config/#save_pickle_1","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. 
None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/sim/motor_controllers/","text":"Module wtracker.sim.motor_controllers View Source import abc import numpy as np from wtracker.sim.config import TimingConfig class MotorController ( abc . ABC ): \"\"\" Abstract base class for motor controllers used in the Simulator class. This motor controls the movement of the simulated platform. Args: timing_config (TimingConfig): The timing configuration of the simulation. Attributes: timing_config (TimingConfig): The timing configuration for the motor controller. movement_steps (int): The number of movement steps (in units of frames) based on the timing configuration. \"\"\" def __init__ ( self , timing_config : TimingConfig ): self . timing_config = timing_config self . movement_steps = self . timing_config . moving_frame_num @abc . abstractmethod def register_move ( self , dx : int , dy : int ): pass @abc . abstractmethod def step ( self ) -> tuple [ int , int ]: pass class StepMotorController ( MotorController ): \"\"\" A simple motor controller that manages the movement of a motor. The motor moved the entire distance in one step, the movement happens after 'move_after_ratio' percent of 'movement_steps' have passed. Args: timing_config (TimingConfig): The timing configuration of the simulation. move_after_ratio (float, optional): The ratio of movement steps after which the motor should move. \"\"\" def __init__ ( self , timing_config : TimingConfig , move_after_ratio : float = 0.5 ): assert 0 <= move_after_ratio <= 1 super () . __init__ ( timing_config ) self . queue : list = [] self . move_at_step = round ( self . movement_steps * move_after_ratio ) def register_move ( self , dx : int , dy : int ): for _ in range ( self . movement_steps - 1 ): self . queue . append (( 0 , 0 )) self . queue . insert ( self . move_at_step , ( dx , dy )) def step ( self ) -> tuple [ int , int ]: return self . queue . pop ( 0 ) class SineMotorController ( MotorController ): \"\"\" A motor controller that generates sinusoidal movements. Args: timing_config (TimingConfig): The timing configuration of the simulation. \"\"\" def __init__ ( self , timing_config : TimingConfig ): super () . __init__ ( timing_config ) self . queue : list = [] def register_move ( self , dx : int , dy : int ) -> None : assert len ( self . queue ) == 0 for i in range ( self . movement_steps ): step_size = ( np . cos (( i * np . pi ) / self . movement_steps ) - np . cos ((( i + 1 ) * np . pi ) / self . movement_steps ) ) / 2 step = ( step_size * dx , step_size * dy ) self . queue . append ( step ) def step ( self ) -> tuple [ int , int ]: dx , dy = self . queue . pop ( 0 ) rdx , rdy = ( round ( dx ), round ( dy )) resid_x , resid_y = dx - rdx , dy - rdy if self . queue : self . queue [ 0 ] = ( self . queue [ 0 ][ 0 ] + resid_x , self . queue [ 0 ][ 1 ] + resid_y ) return ( rdx , rdy ) Classes MotorController class MotorController ( timing_config : wtracker . sim . config . TimingConfig ) Abstract base class for motor controllers used in the Simulator class. This motor controls the movement of the simulated platform. 
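The contract is small: register_move queues a commanded displacement, and step is then called once per movement frame and returns the integer (dx, dy) to apply on that frame. A hypothetical controller (not part of the library) that spreads the displacement evenly over the movement frames might look like this:

```python
# Illustrative sketch only: a constant-speed motor controller that follows the
# abstract MotorController contract described above. Not part of the library.
import numpy as np

from wtracker.sim.config import TimingConfig
from wtracker.sim.motor_controllers import MotorController


class LinearMotorController(MotorController):
    """Moves the platform at constant speed (hypothetical example)."""

    def __init__(self, timing_config: TimingConfig):
        super().__init__(timing_config)
        self.queue: list = []

    def register_move(self, dx: int, dy: int):
        # split the displacement into equal integer steps; rounding residue is
        # absorbed along the way so the steps sum exactly to (dx, dy)
        steps_x = np.diff(np.round(np.linspace(0, dx, self.movement_steps + 1)))
        steps_y = np.diff(np.round(np.linspace(0, dy, self.movement_steps + 1)))
        self.queue.extend((int(sx), int(sy)) for sx, sy in zip(steps_x, steps_y))

    def step(self) -> tuple[int, int]:
        return self.queue.pop(0)
```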
Attributes Name Type Description Default timing_config TimingConfig The timing configuration of the simulation. None timing_config TimingConfig The timing configuration for the motor controller. None movement_steps int The number of movement steps (in units of frames) based on the timing configuration. None View Source class MotorController ( abc . ABC ) : \"\"\" Abstract base class for motor controllers used in the Simulator class. This motor controls the movement of the simulated platform. Args: timing_config (TimingConfig): The timing configuration of the simulation. Attributes: timing_config (TimingConfig): The timing configuration for the motor controller. movement_steps (int): The number of movement steps (in units of frames) based on the timing configuration. \"\"\" def __init__ ( self , timing_config : TimingConfig ) : self . timing_config = timing_config self . movement_steps = self . timing_config . moving_frame_num @abc . abstractmethod def register_move ( self , dx : int , dy : int ) : pass @abc . abstractmethod def step ( self ) -> tuple [ int, int ] : pass Ancestors (in MRO) abc.ABC Descendants wtracker.sim.motor_controllers.StepMotorController wtracker.sim.motor_controllers.SineMotorController Methods register_move def register_move ( self , dx : int , dy : int ) View Source @abc . abstractmethod def register_move ( self , dx : int , dy : int ) : pass step def step ( self ) -> tuple [ int , int ] View Source @abc . abstractmethod def step ( self ) -> tuple [ int, int ] : pass SineMotorController class SineMotorController ( timing_config : wtracker . sim . config . TimingConfig ) A motor controller that generates sinusoidal movements. Attributes Name Type Description Default timing_config TimingConfig The timing configuration of the simulation. None View Source class SineMotorController ( MotorController ) : \"\"\" A motor controller that generates sinusoidal movements . Args: timing_config ( TimingConfig ) : The timing configuration of the simulation . \"\"\" def __init__ ( self , timing_config: TimingConfig ) : super (). __init__ ( timing_config ) self . queue: list = [] def register_move ( self , dx: int , dy: int ) -> None: assert len ( self . queue ) == 0 for i in range ( self . movement_steps ) : step_size = ( np . cos (( i * np . pi ) / self . movement_steps ) - np . cos ((( i + 1 ) * np . pi ) / self . movement_steps ) ) / 2 step = ( step_size * dx , step_size * dy ) self . queue . append ( step ) def step ( self ) -> tuple [ int , int ] : dx , dy = self . queue . pop ( 0 ) rdx , rdy = ( round ( dx ), round ( dy )) resid_x , resid_y = dx - rdx , dy - rdy if self . queue: self . queue [ 0 ] = ( self . queue [ 0 ][ 0 ] + resid_x , self . queue [ 0 ][ 1 ] + resid_y ) return ( rdx , rdy ) Ancestors (in MRO) wtracker.sim.motor_controllers.MotorController abc.ABC Methods register_move def register_move ( self , dx : int , dy : int ) -> None View Source def register_move ( self , dx: int , dy: int ) -> None: assert len ( self . queue ) == 0 for i in range ( self . movement_steps ) : step_size = ( np . cos (( i * np . pi ) / self . movement_steps ) - np . cos ((( i + 1 ) * np . pi ) / self . movement_steps ) ) / 2 step = ( step_size * dx , step_size * dy ) self . queue . append ( step ) step def step ( self ) -> tuple [ int , int ] View Source def step ( self ) -> tuple [ int , int ] : dx , dy = self . queue . pop ( 0 ) rdx , rdy = ( round ( dx ), round ( dy )) resid_x , resid_y = dx - rdx , dy - rdy if self . queue : self . queue [ 0 ] = ( self . 
queue [ 0 ][ 0 ] + resid_x , self . queue [ 0 ][ 1 ] + resid_y ) return ( rdx , rdy ) StepMotorController class StepMotorController ( timing_config : wtracker . sim . config . TimingConfig , move_after_ratio : float = 0.5 ) A simple motor controller that manages the movement of a motor. The motor moves the entire distance in one step; the movement happens after a 'move_after_ratio' fraction of 'movement_steps' has passed. Attributes Name Type Description Default timing_config TimingConfig The timing configuration of the simulation. None move_after_ratio float The ratio of movement steps after which the motor should move. None View Source class StepMotorController ( MotorController ): \"\"\" A simple motor controller that manages the movement of a motor. The motor moves the entire distance in one step; the movement happens after a 'move_after_ratio' fraction of 'movement_steps' has passed. Args: timing_config (TimingConfig): The timing configuration of the simulation. move_after_ratio (float, optional): The ratio of movement steps after which the motor should move. \"\"\" def __init__ ( self , timing_config : TimingConfig , move_after_ratio : float = 0.5 ): assert 0 <= move_after_ratio <= 1 super (). __init__ ( timing_config ) self . queue : list = [] self . move_at_step = round ( self . movement_steps * move_after_ratio ) def register_move ( self , dx : int , dy : int ): for _ in range ( self . movement_steps - 1 ): self . queue . append (( 0 , 0 )) self . queue . insert ( self . move_at_step , ( dx , dy )) def step ( self ) -> tuple [ int , int ]: return self . queue . pop ( 0 ) Ancestors (in MRO) wtracker.sim.motor_controllers.MotorController abc.ABC Methods register_move def register_move ( self , dx : int , dy : int ) View Source def register_move ( self , dx: int , dy: int ) : for _ in range ( self . movement_steps - 1 ) : self . queue . append (( 0 , 0 )) self . queue . insert ( self . move_at_step , ( dx , dy )) step def step ( self ) -> tuple [ int , int ] View Source def step ( self ) -> tuple [ int , int ] : return self . queue . pop ( 0 )","title":"Motor Controllers"},{"location":"reference/wtracker/sim/motor_controllers/#module-wtrackersimmotor_controllers","text":"View Source import abc import numpy as np from wtracker.sim.config import TimingConfig class MotorController ( abc . ABC ): \"\"\" Abstract base class for motor controllers used in the Simulator class. This motor controls the movement of the simulated platform. Args: timing_config (TimingConfig): The timing configuration of the simulation. Attributes: timing_config (TimingConfig): The timing configuration for the motor controller. movement_steps (int): The number of movement steps (in units of frames) based on the timing configuration. \"\"\" def __init__ ( self , timing_config : TimingConfig ): self . timing_config = timing_config self . movement_steps = self . timing_config . moving_frame_num @abc . abstractmethod def register_move ( self , dx : int , dy : int ): pass @abc . abstractmethod def step ( self ) -> tuple [ int , int ]: pass class StepMotorController ( MotorController ): \"\"\" A simple motor controller that manages the movement of a motor. The motor moves the entire distance in one step; the movement happens after a 'move_after_ratio' fraction of 'movement_steps' has passed. Args: timing_config (TimingConfig): The timing configuration of the simulation. move_after_ratio (float, optional): The ratio of movement steps after which the motor should move.
\"\"\" def __init__ ( self , timing_config : TimingConfig , move_after_ratio : float = 0.5 ): assert 0 <= move_after_ratio <= 1 super () . __init__ ( timing_config ) self . queue : list = [] self . move_at_step = round ( self . movement_steps * move_after_ratio ) def register_move ( self , dx : int , dy : int ): for _ in range ( self . movement_steps - 1 ): self . queue . append (( 0 , 0 )) self . queue . insert ( self . move_at_step , ( dx , dy )) def step ( self ) -> tuple [ int , int ]: return self . queue . pop ( 0 ) class SineMotorController ( MotorController ): \"\"\" A motor controller that generates sinusoidal movements. Args: timing_config (TimingConfig): The timing configuration of the simulation. \"\"\" def __init__ ( self , timing_config : TimingConfig ): super () . __init__ ( timing_config ) self . queue : list = [] def register_move ( self , dx : int , dy : int ) -> None : assert len ( self . queue ) == 0 for i in range ( self . movement_steps ): step_size = ( np . cos (( i * np . pi ) / self . movement_steps ) - np . cos ((( i + 1 ) * np . pi ) / self . movement_steps ) ) / 2 step = ( step_size * dx , step_size * dy ) self . queue . append ( step ) def step ( self ) -> tuple [ int , int ]: dx , dy = self . queue . pop ( 0 ) rdx , rdy = ( round ( dx ), round ( dy )) resid_x , resid_y = dx - rdx , dy - rdy if self . queue : self . queue [ 0 ] = ( self . queue [ 0 ][ 0 ] + resid_x , self . queue [ 0 ][ 1 ] + resid_y ) return ( rdx , rdy )","title":"Module wtracker.sim.motor_controllers"},{"location":"reference/wtracker/sim/motor_controllers/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/motor_controllers/#motorcontroller","text":"class MotorController ( timing_config : wtracker . sim . config . TimingConfig ) Abstract base class for motor controllers used in the Simulator class. This motor controls the movement of the simulated platform.","title":"MotorController"},{"location":"reference/wtracker/sim/motor_controllers/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration of the simulation. None timing_config TimingConfig The timing configuration for the motor controller. None movement_steps int The number of movement steps (in units of frames) based on the timing configuration. None View Source class MotorController ( abc . ABC ) : \"\"\" Abstract base class for motor controllers used in the Simulator class. This motor controls the movement of the simulated platform. Args: timing_config (TimingConfig): The timing configuration of the simulation. Attributes: timing_config (TimingConfig): The timing configuration for the motor controller. movement_steps (int): The number of movement steps (in units of frames) based on the timing configuration. \"\"\" def __init__ ( self , timing_config : TimingConfig ) : self . timing_config = timing_config self . movement_steps = self . timing_config . moving_frame_num @abc . abstractmethod def register_move ( self , dx : int , dy : int ) : pass @abc . 
abstractmethod def step ( self ) -> tuple [ int, int ] : pass","title":"Attributes"},{"location":"reference/wtracker/sim/motor_controllers/#ancestors-in-mro","text":"abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/motor_controllers/#descendants","text":"wtracker.sim.motor_controllers.StepMotorController wtracker.sim.motor_controllers.SineMotorController","title":"Descendants"},{"location":"reference/wtracker/sim/motor_controllers/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/motor_controllers/#register_move","text":"def register_move ( self , dx : int , dy : int ) View Source @abc . abstractmethod def register_move ( self , dx : int , dy : int ) : pass","title":"register_move"},{"location":"reference/wtracker/sim/motor_controllers/#step","text":"def step ( self ) -> tuple [ int , int ] View Source @abc . abstractmethod def step ( self ) -> tuple [ int, int ] : pass","title":"step"},{"location":"reference/wtracker/sim/motor_controllers/#sinemotorcontroller","text":"class SineMotorController ( timing_config : wtracker . sim . config . TimingConfig ) A motor controller that generates sinusoidal movements.","title":"SineMotorController"},{"location":"reference/wtracker/sim/motor_controllers/#attributes_1","text":"Name Type Description Default timing_config TimingConfig The timing configuration of the simulation. None View Source class SineMotorController ( MotorController ) : \"\"\" A motor controller that generates sinusoidal movements . Args: timing_config ( TimingConfig ) : The timing configuration of the simulation . \"\"\" def __init__ ( self , timing_config: TimingConfig ) : super (). __init__ ( timing_config ) self . queue: list = [] def register_move ( self , dx: int , dy: int ) -> None: assert len ( self . queue ) == 0 for i in range ( self . movement_steps ) : step_size = ( np . cos (( i * np . pi ) / self . movement_steps ) - np . cos ((( i + 1 ) * np . pi ) / self . movement_steps ) ) / 2 step = ( step_size * dx , step_size * dy ) self . queue . append ( step ) def step ( self ) -> tuple [ int , int ] : dx , dy = self . queue . pop ( 0 ) rdx , rdy = ( round ( dx ), round ( dy )) resid_x , resid_y = dx - rdx , dy - rdy if self . queue: self . queue [ 0 ] = ( self . queue [ 0 ][ 0 ] + resid_x , self . queue [ 0 ][ 1 ] + resid_y ) return ( rdx , rdy )","title":"Attributes"},{"location":"reference/wtracker/sim/motor_controllers/#ancestors-in-mro_1","text":"wtracker.sim.motor_controllers.MotorController abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/motor_controllers/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/sim/motor_controllers/#register_move_1","text":"def register_move ( self , dx : int , dy : int ) -> None View Source def register_move ( self , dx: int , dy: int ) -> None: assert len ( self . queue ) == 0 for i in range ( self . movement_steps ) : step_size = ( np . cos (( i * np . pi ) / self . movement_steps ) - np . cos ((( i + 1 ) * np . pi ) / self . movement_steps ) ) / 2 step = ( step_size * dx , step_size * dy ) self . queue . append ( step )","title":"register_move"},{"location":"reference/wtracker/sim/motor_controllers/#step_1","text":"def step ( self ) -> tuple [ int , int ] View Source def step ( self ) -> tuple [ int , int ] : dx , dy = self . queue . pop ( 0 ) rdx , rdy = ( round ( dx ), round ( dy )) resid_x , resid_y = dx - rdx , dy - rdy if self . queue : self . queue [ 0 ] = ( self . queue [ 0 ][ 0 ] + resid_x , self . 
queue [ 0 ][ 1 ] + resid_y ) return ( rdx , rdy )","title":"step"},{"location":"reference/wtracker/sim/motor_controllers/#stepmotorcontroller","text":"class StepMotorController ( timing_config : wtracker . sim . config . TimingConfig , move_after_ratio : float = 0.5 ) A simple motor controller that manages the movement of a motor. The motor moves the entire distance in one step; the movement happens after a 'move_after_ratio' fraction of 'movement_steps' has passed.","title":"StepMotorController"},{"location":"reference/wtracker/sim/motor_controllers/#attributes_2","text":"Name Type Description Default timing_config TimingConfig The timing configuration of the simulation. None move_after_ratio float The ratio of movement steps after which the motor should move. None View Source class StepMotorController ( MotorController ): \"\"\" A simple motor controller that manages the movement of a motor. The motor moves the entire distance in one step; the movement happens after a 'move_after_ratio' fraction of 'movement_steps' has passed. Args: timing_config (TimingConfig): The timing configuration of the simulation. move_after_ratio (float, optional): The ratio of movement steps after which the motor should move. \"\"\" def __init__ ( self , timing_config : TimingConfig , move_after_ratio : float = 0.5 ): assert 0 <= move_after_ratio <= 1 super (). __init__ ( timing_config ) self . queue : list = [] self . move_at_step = round ( self . movement_steps * move_after_ratio ) def register_move ( self , dx : int , dy : int ): for _ in range ( self . movement_steps - 1 ): self . queue . append (( 0 , 0 )) self . queue . insert ( self . move_at_step , ( dx , dy )) def step ( self ) -> tuple [ int , int ]: return self . queue . pop ( 0 )","title":"Attributes"},{"location":"reference/wtracker/sim/motor_controllers/#ancestors-in-mro_2","text":"wtracker.sim.motor_controllers.MotorController abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/motor_controllers/#methods_2","text":"","title":"Methods"},{"location":"reference/wtracker/sim/motor_controllers/#register_move_2","text":"def register_move ( self , dx : int , dy : int ) View Source def register_move ( self , dx: int , dy: int ) : for _ in range ( self . movement_steps - 1 ) : self . queue . append (( 0 , 0 )) self . queue . insert ( self . move_at_step , ( dx , dy ))","title":"register_move"},{"location":"reference/wtracker/sim/motor_controllers/#step_2","text":"def step ( self ) -> tuple [ int , int ] View Source def step ( self ) -> tuple [ int , int ] : return self . queue . pop ( 0 )","title":"step"},{"location":"reference/wtracker/sim/simulator/","text":"Module wtracker.sim.simulator View Source from __future__ import annotations import numpy as np import abc from tqdm.auto import tqdm from wtracker.sim.view_controller import ViewController from wtracker.sim.config import TimingConfig , ExperimentConfig from wtracker.sim.motor_controllers import MotorController , SineMotorController from wtracker.utils.frame_reader import FrameReader , DummyReader class Simulator : \"\"\" A class representing a simulator for a biological experiment. Args: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration. sim_controller (SimController): The simulation controller. reader (FrameReader, optional): The frame reader. motor_controller (MotorController, optional): The motor controller.
Attributes: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration. \"\"\" def __init__ ( self , timing_config : TimingConfig , experiment_config : ExperimentConfig , sim_controller : SimController , reader : FrameReader = None , motor_controller : MotorController = None , ) -> None : self . timing_config = timing_config self . experiment_config = experiment_config self . _sim_controller = sim_controller if reader is None : num_frames = experiment_config . num_frames padding_size = ( timing_config . camera_size_px [ 0 ] // 2 * 2 , timing_config . camera_size_px [ 1 ] // 2 * 2 ) resolution = tuple ([ sum ( x ) for x in zip ( experiment_config . orig_resolution , padding_size )]) reader = DummyReader ( num_frames , resolution , colored = True ) if motor_controller is None : motor_controller = SineMotorController ( timing_config ) self . _motor_controller = motor_controller self . _view = ViewController ( frame_reader = reader , camera_size = timing_config . camera_size_px , micro_size = timing_config . micro_size_px , init_position = experiment_config . init_position , ) @property def view ( self ) -> ViewController : \"\"\" Get the view controller. Returns: ViewController: The view controller. \"\"\" return self . _view @property def position ( self ) -> tuple [ int , int ]: \"\"\" Get the current position. Returns: tuple[int, int]: The current position. \"\"\" return self . _view . position @property def cycle_number ( self ) -> int : \"\"\" Get the current cycle number. Returns: int: The current cycle number. \"\"\" return self . _view . index // self . timing_config . cycle_frame_num @property def frame_number ( self ) -> int : \"\"\" Get the current frame number. Returns: int: The current frame number. \"\"\" return self . _view . index @property def cycle_step ( self ) -> int : \"\"\" Get the current cycle step. Returns: int: The current cycle step. \"\"\" return self . _view . index % self . timing_config . cycle_frame_num def camera_view ( self ) -> np . ndarray : \"\"\" Get the view that the camera sees. Returns: np.ndarray: The camera view. \"\"\" return self . _view . camera_view () def micro_view ( self ) -> np . ndarray : \"\"\" Get the view that the microscope sees. Returns: np.ndarray: The micro view. \"\"\" return self . _view . micro_view () def _reset ( self ): \"\"\" Reset the simulator. \"\"\" self . view . reset () self . view . set_position ( * self . experiment_config . init_position ) def run ( self , visualize : bool = False , wait_key : bool = False ): \"\"\" Run the simulation. Args: visualize (bool, optional): Whether to visualize the simulation. wait_key (bool, optional): Whether to wait for a key press to advance the simulation during visualization. \"\"\" config = self . timing_config total_cycles = len ( self . _view ) // config . cycle_frame_num pbar = tqdm ( total = total_cycles , desc = \"Simulation Progress\" , unit = \"cycle\" ) self . _reset () self . _sim_controller . on_sim_start ( self ) while self . _view . progress (): if self . cycle_step == 0 : if self . cycle_number > 0 : self . _sim_controller . on_movement_end ( self ) self . _sim_controller . on_cycle_end ( self ) self . _sim_controller . on_cycle_start ( self ) self . _sim_controller . on_camera_frame ( self ) if self . cycle_step == 0 : self . _sim_controller . on_imaging_start ( self ) if self . cycle_step < config . imaging_frame_num : self . _sim_controller . on_micro_frame ( self ) if self . 
cycle_step == config . imaging_frame_num - config . pred_frame_num : self . _sim_controller . begin_movement_prediction ( self ) if self . cycle_step == config . imaging_frame_num : self . _sim_controller . on_imaging_end ( self ) dx , dy = self . _sim_controller . provide_movement_vector ( self ) self . _sim_controller . on_movement_start ( self ) self . _motor_controller . register_move ( dx , dy ) if config . imaging_frame_num <= self . cycle_step < config . imaging_frame_num + config . moving_frame_num : dx , dy = self . _motor_controller . step () self . _view . move_position ( dx , dy ) if self . cycle_step == config . cycle_frame_num - 1 : pbar . update ( 1 ) if visualize : self . _view . visualize_world ( timeout = 0 if wait_key else 1 ) self . _sim_controller . on_sim_end ( self ) pbar . close () class SimController ( abc . ABC ): \"\"\" Abstract base class for simulator controllers. Attributes: timing_config (TimingConfig): The timing configuration for the simulator. \"\"\" def __init__ ( self , timing_config : TimingConfig ): self . timing_config = timing_config def on_sim_start ( self , sim : Simulator ): \"\"\" Called when the simulation starts. \"\"\" pass def on_sim_end ( self , sim : Simulator ): \"\"\" Called when the simulation ends. \"\"\" pass def on_cycle_start ( self , sim : Simulator ): \"\"\" Called when a new cycle starts. \"\"\" pass def on_cycle_end ( self , sim : Simulator ): \"\"\" Called when a cycle ends. \"\"\" pass def on_camera_frame ( self , sim : Simulator ): \"\"\" Called when a camera frame is captured. Happens every frame. \"\"\" pass def on_imaging_start ( self , sim : Simulator ): \"\"\" Called when imaging phase starts. \"\"\" pass def on_micro_frame ( self , sim : Simulator ): \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass def on_imaging_end ( self , sim : Simulator ): \"\"\" Called when imaging phase ends. \"\"\" pass def on_movement_start ( self , sim : Simulator ): \"\"\" Called when movement phase starts. \"\"\" pass def on_movement_end ( self , sim : Simulator ): \"\"\" Called when movement phase ends. \"\"\" pass @abc . abstractmethod def begin_movement_prediction ( self , sim : Simulator ) -> None : \"\"\" Called when the movement prediction begins. \"\"\" raise NotImplementedError () @abc . abstractmethod def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: \"\"\" Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: tuple[int, int]: The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. \"\"\" raise NotImplementedError () @abc . abstractmethod def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : \"\"\" Returns a list of bbox predictions of the worm, for each frame of the current cycle. If a prediction is not available, return None for that frame. Used internally for logging. \"\"\" raise NotImplementedError () Classes SimController class SimController ( timing_config : 'TimingConfig' ) Abstract base class for simulator controllers. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class SimController ( abc . ABC ) : \"\"\" Abstract base class for simulator controllers. Attributes: timing_config (TimingConfig): The timing configuration for the simulator. \"\"\" def __init__ ( self , timing_config : TimingConfig ) : self .
timing_config = timing_config def on_sim_start ( self , sim : Simulator ) : \"\"\" Called when the simulation starts. \"\"\" pass def on_sim_end ( self , sim : Simulator ) : \"\"\" Called when the simulation ends. \"\"\" pass def on_cycle_start ( self , sim : Simulator ) : \"\"\" Called when a new cycle starts. \"\"\" pass def on_cycle_end ( self , sim : Simulator ) : \"\"\" Called when a cycle ends. \"\"\" pass def on_camera_frame ( self , sim : Simulator ) : \"\"\" Called when a camera frame is captured. Happens every frame. \"\"\" pass def on_imaging_start ( self , sim : Simulator ) : \"\"\" Called when imaging phase starts. \"\"\" pass def on_micro_frame ( self , sim : Simulator ) : \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass def on_imaging_end ( self , sim : Simulator ) : \"\"\" Called when imaging phase ends. \"\"\" pass def on_movement_start ( self , sim : Simulator ) : \"\"\" Called when movement phase starts. \"\"\" pass def on_movement_end ( self , sim : Simulator ) : \"\"\" Called when movement phase ends. \"\"\" pass @abc . abstractmethod def begin_movement_prediction ( self , sim : Simulator ) -> None : \"\"\" Called when the movement prediction begins. \"\"\" raise NotImplementedError () @abc . abstractmethod def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : \"\"\" Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: tuple[int, int]: The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. \"\"\" raise NotImplementedError () @abc . abstractmethod def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : \"\"\" Returns a list of bbox predictions of the worm, for each frame of the current cycle. If a prediction is not available, return None for that frame. Used internally for logging. \"\"\" raise NotImplementedError () Ancestors (in MRO) abc.ABC Descendants wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.sim_controllers.logging_controller.LoggingController wtracker.sim.sim_controllers.yolo_controller.YoloController Methods begin_movement_prediction def begin_movement_prediction ( self , sim : 'Simulator' ) -> 'None' Called when the movement prediction begins. View Source @abc . abstractmethod def begin_movement_prediction ( self , sim : Simulator ) -> None : \"\"\" Called when the movement prediction begins. \"\"\" raise NotImplementedError () on_camera_frame def on_camera_frame ( self , sim : 'Simulator' ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : \"\" \" Called when a camera frame is captured. Happens every frame. \" \"\" pass on_cycle_end def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass on_cycle_start def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass on_imaging_end def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass on_imaging_start def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts.
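The hook methods listed here fire in a fixed order inside Simulator.run(). A quick way to observe that order is a tracing controller; this is a hedged sketch (the class name and prints are illustrative, only the SimController API is from the library):

```python
from wtracker.sim.simulator import SimController

class TracingController(SimController):
    """Prints each lifecycle hook as Simulator.run() invokes it (sketch)."""

    def on_cycle_start(self, sim):
        print(f"cycle {sim.cycle_number} start")

    def on_imaging_start(self, sim):
        print("  imaging start")

    def on_imaging_end(self, sim):
        print("  imaging end")

    def on_movement_start(self, sim):
        print("  movement start")

    def on_movement_end(self, sim):
        print("  movement end")

    def begin_movement_prediction(self, sim):
        print(f"  prediction requested at cycle step {sim.cycle_step}")

    def provide_movement_vector(self, sim):
        return (0, 0)  # keep the platform still

    def _cycle_predict_all(self, sim):
        return None  # no per-frame bboxes in this sketch
```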
View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass on_micro_frame def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass on_movement_end def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass on_movement_start def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass on_sim_end def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass on_sim_start def on_sim_start ( self , sim : 'Simulator' ) Called when the simulation starts. View Source def on_sim_start(self, sim: Simulator): \"\"\" Called when the simulation starts. \"\"\" pass provide_movement_vector def provide_movement_vector ( self , sim : 'Simulator' ) -> 'tuple[int, int]' Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source @abc . abstractmethod def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : \"\"\" Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: tuple[int, int]: The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. \"\"\" raise NotImplementedError () Simulator class Simulator ( timing_config : 'TimingConfig' , experiment_config : 'ExperimentConfig' , sim_controller : 'SimController' , reader : 'FrameReader' = None , motor_controller : 'MotorController' = None ) A class representing a simulator for a biological experiment. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the experiment. None experiment_config ExperimentConfig The experiment configuration. None sim_controller SimController The simulation controller. None reader FrameReader The frame reader. None motor_controller MotorController The motor controller. None timing_config TimingConfig The timing configuration for the experiment. None experiment_config ExperimentConfig The experiment configuration. None View Source class Simulator : \"\"\" A class representing a simulator for a biological experiment. Args: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration. sim_controller (SimController): The simulation controller. reader (FrameReader, optional): The frame reader. motor_controller (MotorController, optional): The motor controller. Attributes: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration.
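Wiring a Simulator together only requires the two configs and a concrete controller. A hedged sketch, assuming `timing_cfg` and `experiment_cfg` are TimingConfig/ExperimentConfig instances loaded elsewhere (their construction is project-specific):

```python
from wtracker.sim.simulator import Simulator, SimController

class IdleController(SimController):
    """Minimal concrete controller that never moves the platform (sketch)."""

    def begin_movement_prediction(self, sim):
        pass

    def provide_movement_vector(self, sim):
        return (0, 0)

    def _cycle_predict_all(self, sim):
        return None

# Leaving `reader` and `motor_controller` unset falls back to a DummyReader
# and a SineMotorController, as __init__ above shows.
sim = Simulator(timing_cfg, experiment_cfg, IdleController(timing_cfg))
sim.run(visualize=False)
print("frames simulated:", sim.frame_number)
```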
\"\"\" def __init__ ( self , timing_config : TimingConfig , experiment_config : ExperimentConfig , sim_controller : SimController , reader : FrameReader = None , motor_controller : MotorController = None , ) -> None : self . timing_config = timing_config self . experiment_config = experiment_config self . _sim_controller = sim_controller if reader is None : num_frames = experiment_config . num_frames padding_size = ( timing_config . camera_size_px [ 0 ] // 2 * 2 , timing_config . camera_size_px [ 1 ] // 2 * 2 ) resolution = tuple ( [ sum(x) for x in zip(experiment_config.orig_resolution, padding_size) ] ) reader = DummyReader ( num_frames , resolution , colored = True ) if motor_controller is None : motor_controller = SineMotorController ( timing_config ) self . _motor_controller = motor_controller self . _view = ViewController ( frame_reader = reader , camera_size = timing_config . camera_size_px , micro_size = timing_config . micro_size_px , init_position = experiment_config . init_position , ) @property def view ( self ) -> ViewController : \"\"\" Get the view controller. Returns: ViewController: The view controller. \"\"\" return self . _view @property def position ( self ) -> tuple [ int, int ] : \"\"\" Get the current position. Returns: tuple[int, int]: The current position. \"\"\" return self . _view . position @property def cycle_number ( self ) -> int : \"\"\" Get the current cycle number. Returns: int: The current cycle number. \"\"\" return self . _view . index // self . timing_config . cycle_frame_num @property def frame_number ( self ) -> int : \"\"\" Get the current frame number. Returns: int: The current frame number. \"\"\" return self . _view . index @property def cycle_step ( self ) -> int : \"\"\" Get the current cycle step. Returns: int: The current cycle step. \"\"\" return self . _view . index % self . timing_config . cycle_frame_num def camera_view ( self ) -> np . ndarray : \"\"\" Get the view that the camera sees. Returns: np.ndarray: The camera view. \"\"\" return self . _view . camera_view () def micro_view ( self ) -> np . ndarray : \"\"\" Get the view that the microscope sees. Returns: np.ndarray: The micro view. \"\"\" return self . _view . micro_view () def _reset ( self ) : \"\"\" Reset the simulator. \"\"\" self . view . reset () self . view . set_position ( * self . experiment_config . init_position ) def run ( self , visualize : bool = False , wait_key : bool = False ) : \"\"\" Run the simulation. Args: visualize (bool, optional): Whether to visualize the simulation. wait_key (bool, optional): Whether to wait for a key press to advance the simulation during visualization. \"\"\" config = self . timing_config total_cycles = len ( self . _view ) // config . cycle_frame_num pbar = tqdm ( total = total_cycles , desc = \"Simulation Progress\" , unit = \"cycle\" ) self . _reset () self . _sim_controller . on_sim_start ( self ) while self . _view . progress () : if self . cycle_step == 0 : if self . cycle_number > 0 : self . _sim_controller . on_movement_end ( self ) self . _sim_controller . on_cycle_end ( self ) self . _sim_controller . on_cycle_start ( self ) self . _sim_controller . on_camera_frame ( self ) if self . cycle_step == 0 : self . _sim_controller . on_imaging_start ( self ) if self . cycle_step < config . imaging_frame_num : self . _sim_controller . on_micro_frame ( self ) if self . cycle_step == config . imaging_frame_num - config . pred_frame_num : self . _sim_controller . begin_movement_prediction ( self ) if self . cycle_step == config . 
imaging_frame_num : self . _sim_controller . on_imaging_end ( self ) dx , dy = self . _sim_controller . provide_movement_vector ( self ) self . _sim_controller . on_movement_start ( self ) self . _motor_controller . register_move ( dx , dy ) if config . imaging_frame_num <= self . cycle_step < config . imaging_frame_num + config . moving_frame_num : dx , dy = self . _motor_controller . step () self . _view . move_position ( dx , dy ) if self . cycle_step == config . cycle_frame_num - 1 : pbar . update ( 1 ) if visualize : self . _view . visualize_world ( timeout = 0 if wait_key else 1 ) self . _sim_controller . on_sim_end ( self ) pbar . close () Instance variables cycle_number Get the current cycle number. cycle_step Get the current cycle step. frame_number Get the current frame number. position Get the current position. view Get the view controller. Methods camera_view def camera_view ( self ) -> 'np.ndarray' Get the view that the camera sees. Returns: Type Description np.ndarray The camera view. View Source def camera_view ( self ) - > np . ndarray : \"\" \" Get the view that the camera sees. Returns: np.ndarray: The camera view. \" \"\" return self . _view . camera_view () micro_view def micro_view ( self ) -> 'np.ndarray' Get the view that the microscope sees. Returns: Type Description np.ndarray The micro view. View Source def micro_view ( self ) -> np . ndarray : \"\"\" Get the view that the microscope sees. Returns: np.ndarray: The micro view. \"\"\" return self . _view . micro_view () run def run ( self , visualize : 'bool' = False , wait_key : 'bool' = False ) Run the simulation. Parameters: Name Type Description Default visualize bool Whether to visualize the simulation. None wait_key bool Whether to wait for a key press to advance the simulation during visualization. None View Source def run ( self , visualize: bool = False , wait_key: bool = False ) : \"\"\" Run the simulation . Args: visualize ( bool , optional ) : Whether to visualize the simulation . wait_key ( bool , optional ) : Whether to wait for a key press to advance the simulation during visualization . \"\"\" config = self . timing_config total_cycles = len ( self . _view ) // config.cycle_frame_num pbar = tqdm ( total = total_cycles , desc = \"Simulation Progress\" , unit = \"cycle\" ) self . _reset () self . _sim_controller . on_sim_start ( self ) while self . _view . progress () : if self . cycle_step == 0 : if self . cycle_number > 0 : self . _sim_controller . on_movement_end ( self ) self . _sim_controller . on_cycle_end ( self ) self . _sim_controller . on_cycle_start ( self ) self . _sim_controller . on_camera_frame ( self ) if self . cycle_step == 0 : self . _sim_controller . on_imaging_start ( self ) if self . cycle_step < config . imaging_frame_num: self . _sim_controller . on_micro_frame ( self ) if self . cycle_step == config . imaging_frame_num - config . pred_frame_num: self . _sim_controller . begin_movement_prediction ( self ) if self . cycle_step == config . imaging_frame_num: self . _sim_controller . on_imaging_end ( self ) dx , dy = self . _sim_controller . provide_movement_vector ( self ) self . _sim_controller . on_movement_start ( self ) self . _motor_controller . register_move ( dx , dy ) if config . imaging_frame_num <= self . cycle_step < config . imaging_frame_num + config . moving_frame_num: dx , dy = self . _motor_controller . step () self . _view . move_position ( dx , dy ) if self . cycle_step == config . cycle_frame_num - 1 : pbar . update ( 1 ) if visualize: self . _view . 
visualize_world ( timeout = 0 if wait_key else 1 ) self . _sim_controller . on_sim_end ( self ) pbar . close ()","title":"Simulator"},{"location":"reference/wtracker/sim/simulator/#module-wtrackersimsimulator","text":"View Source from __future__ import annotations import numpy as np import abc from tqdm.auto import tqdm from wtracker.sim.view_controller import ViewController from wtracker.sim.config import TimingConfig , ExperimentConfig from wtracker.sim.motor_controllers import MotorController , SineMotorController from wtracker.utils.frame_reader import FrameReader , DummyReader class Simulator : \"\"\" A class representing a simulator for a biological experiment. Args: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration. sim_controller (SimController): The simulation controller. reader (FrameReader, optional): The frame reader. motor_controller (MotorController, optional): The motor controller. Attributes: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration. \"\"\" def __init__ ( self , timing_config : TimingConfig , experiment_config : ExperimentConfig , sim_controller : SimController , reader : FrameReader = None , motor_controller : MotorController = None , ) -> None : self . timing_config = timing_config self . experiment_config = experiment_config self . _sim_controller = sim_controller if reader is None : num_frames = experiment_config . num_frames padding_size = ( timing_config . camera_size_px [ 0 ] // 2 * 2 , timing_config . camera_size_px [ 1 ] // 2 * 2 ) resolution = tuple ([ sum ( x ) for x in zip ( experiment_config . orig_resolution , padding_size )]) reader = DummyReader ( num_frames , resolution , colored = True ) if motor_controller is None : motor_controller = SineMotorController ( timing_config ) self . _motor_controller = motor_controller self . _view = ViewController ( frame_reader = reader , camera_size = timing_config . camera_size_px , micro_size = timing_config . micro_size_px , init_position = experiment_config . init_position , ) @property def view ( self ) -> ViewController : \"\"\" Get the view controller. Returns: ViewController: The view controller. \"\"\" return self . _view @property def position ( self ) -> tuple [ int , int ]: \"\"\" Get the current position. Returns: tuple[int, int]: The current position. \"\"\" return self . _view . position @property def cycle_number ( self ) -> int : \"\"\" Get the current cycle number. Returns: int: The current cycle number. \"\"\" return self . _view . index // self . timing_config . cycle_frame_num @property def frame_number ( self ) -> int : \"\"\" Get the current frame number. Returns: int: The current frame number. \"\"\" return self . _view . index @property def cycle_step ( self ) -> int : \"\"\" Get the current cycle step. Returns: int: The current cycle step. \"\"\" return self . _view . index % self . timing_config . cycle_frame_num def camera_view ( self ) -> np . ndarray : \"\"\" Get the view that the camera sees. Returns: np.ndarray: The camera view. \"\"\" return self . _view . camera_view () def micro_view ( self ) -> np . ndarray : \"\"\" Get the view that the microscope sees. Returns: np.ndarray: The micro view. \"\"\" return self . _view . micro_view () def _reset ( self ): \"\"\" Reset the simulator. \"\"\" self . view . reset () self . view . set_position ( * self . experiment_config . 
init_position ) def run ( self , visualize : bool = False , wait_key : bool = False ): \"\"\" Run the simulation. Args: visualize (bool, optional): Whether to visualize the simulation. wait_key (bool, optional): Whether to wait for a key press to advance the simulation during visualization. \"\"\" config = self . timing_config total_cycles = len ( self . _view ) // config . cycle_frame_num pbar = tqdm ( total = total_cycles , desc = \"Simulation Progress\" , unit = \"cycle\" ) self . _reset () self . _sim_controller . on_sim_start ( self ) while self . _view . progress (): if self . cycle_step == 0 : if self . cycle_number > 0 : self . _sim_controller . on_movement_end ( self ) self . _sim_controller . on_cycle_end ( self ) self . _sim_controller . on_cycle_start ( self ) self . _sim_controller . on_camera_frame ( self ) if self . cycle_step == 0 : self . _sim_controller . on_imaging_start ( self ) if self . cycle_step < config . imaging_frame_num : self . _sim_controller . on_micro_frame ( self ) if self . cycle_step == config . imaging_frame_num - config . pred_frame_num : self . _sim_controller . begin_movement_prediction ( self ) if self . cycle_step == config . imaging_frame_num : self . _sim_controller . on_imaging_end ( self ) dx , dy = self . _sim_controller . provide_movement_vector ( self ) self . _sim_controller . on_movement_start ( self ) self . _motor_controller . register_move ( dx , dy ) if config . imaging_frame_num <= self . cycle_step < config . imaging_frame_num + config . moving_frame_num : dx , dy = self . _motor_controller . step () self . _view . move_position ( dx , dy ) if self . cycle_step == config . cycle_frame_num - 1 : pbar . update ( 1 ) if visualize : self . _view . visualize_world ( timeout = 0 if wait_key else 1 ) self . _sim_controller . on_sim_end ( self ) pbar . close () class SimController ( abc . ABC ): \"\"\" Abstract base class for simulator controllers. Attributes: timing_config (TimingConfig): The timing configuration for the simulator. \"\"\" def __init__ ( self , timing_config : TimingConfig ): self . timing_config = timing_config def on_sim_start ( self , sim : Simulator ): \"\"\" Called when the simulation starts. \"\"\" pass def on_sim_end ( self , sim : Simulator ): \"\"\" Called when the simulation ends. \"\"\" pass def on_cycle_start ( self , sim : Simulator ): \"\"\" Called when a new cycle starts. \"\"\" pass def on_cycle_end ( self , sim : Simulator ): \"\"\" Called when a cycle ends. \"\"\" pass def on_camera_frame ( self , sim : Simulator ): \"\"\" Called when a camera frame is captured. Happens every frame. \"\"\" pass def on_imaging_start ( self , sim : Simulator ): \"\"\" Called when imaging phase starts. \"\"\" pass def on_micro_frame ( self , sim : Simulator ): \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass def on_imaging_end ( self , sim : Simulator ): \"\"\" Called when imaging phase ends. \"\"\" pass def on_movement_start ( self , sim : Simulator ): \"\"\" Called when movement phase starts. \"\"\" pass def on_movement_end ( self , sim : Simulator ): \"\"\" Called when movement phase ends. \"\"\" pass @abc . abstractmethod def begin_movement_prediction ( self , sim : Simulator ) -> None : \"\"\" Called when the movement prediction begins. \"\"\" raise NotImplementedError () @abc . abstractmethod def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: \"\"\" Provides the movement vector for the simulator. The platform is moved by the provided vector.
Returns: tuple[int, int]: The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. \"\"\" raise NotImplementedError () @abc . abstractmethod def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : \"\"\" Returns a list of bbox predictions of the worm, for each frame of the current cycle. If a prediction is not available, return None for that frame. Used internally for logging. \"\"\" raise NotImplementedError ()","title":"Module wtracker.sim.simulator"},{"location":"reference/wtracker/sim/simulator/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/simulator/#simcontroller","text":"class SimController ( timing_config : 'TimingConfig' ) Abstract base class for simulator controllers.","title":"SimController"},{"location":"reference/wtracker/sim/simulator/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class SimController ( abc . ABC ) : \"\"\" Abstract base class for simulator controllers. Attributes: timing_config (TimingConfig): The timing configuration for the simulator. \"\"\" def __init__ ( self , timing_config : TimingConfig ) : self . timing_config = timing_config def on_sim_start ( self , sim : Simulator ) : \"\"\" Called when the simulation starts. \"\"\" pass def on_sim_end ( self , sim : Simulator ) : \"\"\" Called when the simulation ends. \"\"\" pass def on_cycle_start ( self , sim : Simulator ) : \"\"\" Called when a new cycle starts. \"\"\" pass def on_cycle_end ( self , sim : Simulator ) : \"\"\" Called when a cycle ends. \"\"\" pass def on_camera_frame ( self , sim : Simulator ) : \"\"\" Called when a camera frame is captured. Happens every frame. \"\"\" pass def on_imaging_start ( self , sim : Simulator ) : \"\"\" Called when imaging phase starts. \"\"\" pass def on_micro_frame ( self , sim : Simulator ) : \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass def on_imaging_end ( self , sim : Simulator ) : \"\"\" Called when imaging phase ends. \"\"\" pass def on_movement_start ( self , sim : Simulator ) : \"\"\" Called when movement phase starts. \"\"\" pass def on_movement_end ( self , sim : Simulator ) : \"\"\" Called when movement phase ends. \"\"\" pass @abc . abstractmethod def begin_movement_prediction ( self , sim : Simulator ) -> None : \"\"\" Called when the movement prediction begins. \"\"\" raise NotImplementedError () @abc . abstractmethod def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : \"\"\" Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: tuple[int, int]: The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. \"\"\" raise NotImplementedError () @abc . abstractmethod def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : \"\"\" Returns a list of bbox predictions of the worm, for each frame of the current cycle. If a prediction is not available, return None for that frame. Used internally for logging.
\"\"\" raise NotImplementedError ()","title":"Attributes"},{"location":"reference/wtracker/sim/simulator/#ancestors-in-mro","text":"abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/simulator/#descendants","text":"wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.sim_controllers.logging_controller.LoggingController wtracker.sim.sim_controllers.yolo_controller.YoloController","title":"Descendants"},{"location":"reference/wtracker/sim/simulator/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/simulator/#begin_movement_prediction","text":"def begin_movement_prediction ( self , sim : 'Simulator' ) -> 'None' Called when the movement prediction begins. View Source @abc . abstractmethod def begin_movement_prediction ( self , sim : Simulator ) -> None : \"\"\" Called when the movement prediction begins. \"\"\" raise NotImplementedError ()","title":"begin_movement_prediction"},{"location":"reference/wtracker/sim/simulator/#on_camera_frame","text":"def on_camera_frame ( self , sim : 'Simulator' ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : \"\" \" Called when a camera frame is captured. Happens every frame. \" \"\" pass","title":"on_camera_frame"},{"location":"reference/wtracker/sim/simulator/#on_cycle_end","text":"def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass","title":"on_cycle_end"},{"location":"reference/wtracker/sim/simulator/#on_cycle_start","text":"def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass","title":"on_cycle_start"},{"location":"reference/wtracker/sim/simulator/#on_imaging_end","text":"def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass","title":"on_imaging_end"},{"location":"reference/wtracker/sim/simulator/#on_imaging_start","text":"def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass","title":"on_imaging_start"},{"location":"reference/wtracker/sim/simulator/#on_micro_frame","text":"def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass","title":"on_micro_frame"},{"location":"reference/wtracker/sim/simulator/#on_movement_end","text":"def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass","title":"on_movement_end"},{"location":"reference/wtracker/sim/simulator/#on_movement_start","text":"def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass","title":"on_movement_start"},{"location":"reference/wtracker/sim/simulator/#on_sim_end","text":"def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. 
View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass","title":"on_sim_end"},{"location":"reference/wtracker/sim/simulator/#on_sim_start","text":"def on_sim_start ( self , sim : 'Simulator' ) Called when the simulation starts. View Source def on_sim_start(self, sim: Simulator): \"\"\" Called when the simulation starts. \"\"\" pass","title":"on_sim_start"},{"location":"reference/wtracker/sim/simulator/#provide_movement_vector","text":"def provide_movement_vector ( self , sim : 'Simulator' ) -> 'tuple[int, int]' Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source @abc . abstractmethod def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : \"\"\" Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: tuple[int, int]: The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. \"\"\" raise NotImplementedError ()","title":"provide_movement_vector"},{"location":"reference/wtracker/sim/simulator/#simulator","text":"class Simulator ( timing_config : 'TimingConfig' , experiment_config : 'ExperimentConfig' , sim_controller : 'SimController' , reader : 'FrameReader' = None , motor_controller : 'MotorController' = None ) A class representing a simulator for a biological experiment.","title":"Simulator"},{"location":"reference/wtracker/sim/simulator/#attributes_1","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the experiment. None experiment_config ExperimentConfig The experiment configuration. None sim_controller SimController The simulation controller. None reader FrameReader The frame reader. None motor_controller MotorController The motor controller. None timing_config TimingConfig The timing configuration for the experiment. None experiment_config ExperimentConfig The experiment configuration. None View Source class Simulator : \"\"\" A class representing a simulator for a biological experiment. Args: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration. sim_controller (SimController): The simulation controller. reader (FrameReader, optional): The frame reader. motor_controller (MotorController, optional): The motor controller. Attributes: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration. \"\"\" def __init__ ( self , timing_config : TimingConfig , experiment_config : ExperimentConfig , sim_controller : SimController , reader : FrameReader = None , motor_controller : MotorController = None , ) -> None : self . timing_config = timing_config self . experiment_config = experiment_config self . _sim_controller = sim_controller if reader is None : num_frames = experiment_config . num_frames padding_size = ( timing_config . camera_size_px [ 0 ] // 2 * 2 , timing_config . camera_size_px [ 1 ] // 2 * 2 ) resolution = tuple ( [ sum(x) for x in zip(experiment_config.orig_resolution, padding_size) ] ) reader = DummyReader ( num_frames , resolution , colored = True ) if motor_controller is None : motor_controller = SineMotorController ( timing_config ) self . 
_motor_controller = motor_controller self . _view = ViewController ( frame_reader = reader , camera_size = timing_config . camera_size_px , micro_size = timing_config . micro_size_px , init_position = experiment_config . init_position , ) @property def view ( self ) -> ViewController : \"\"\" Get the view controller. Returns: ViewController: The view controller. \"\"\" return self . _view @property def position ( self ) -> tuple [ int, int ] : \"\"\" Get the current position. Returns: tuple[int, int]: The current position. \"\"\" return self . _view . position @property def cycle_number ( self ) -> int : \"\"\" Get the current cycle number. Returns: int: The current cycle number. \"\"\" return self . _view . index // self . timing_config . cycle_frame_num @property def frame_number ( self ) -> int : \"\"\" Get the current frame number. Returns: int: The current frame number. \"\"\" return self . _view . index @property def cycle_step ( self ) -> int : \"\"\" Get the current cycle step. Returns: int: The current cycle step. \"\"\" return self . _view . index % self . timing_config . cycle_frame_num def camera_view ( self ) -> np . ndarray : \"\"\" Get the view that the camera sees. Returns: np.ndarray: The camera view. \"\"\" return self . _view . camera_view () def micro_view ( self ) -> np . ndarray : \"\"\" Get the view that the microscope sees. Returns: np.ndarray: The micro view. \"\"\" return self . _view . micro_view () def _reset ( self ) : \"\"\" Reset the simulator. \"\"\" self . view . reset () self . view . set_position ( * self . experiment_config . init_position ) def run ( self , visualize : bool = False , wait_key : bool = False ) : \"\"\" Run the simulation. Args: visualize (bool, optional): Whether to visualize the simulation. wait_key (bool, optional): Whether to wait for a key press to advance the simulation during visualization. \"\"\" config = self . timing_config total_cycles = len ( self . _view ) // config . cycle_frame_num pbar = tqdm ( total = total_cycles , desc = \"Simulation Progress\" , unit = \"cycle\" ) self . _reset () self . _sim_controller . on_sim_start ( self ) while self . _view . progress () : if self . cycle_step == 0 : if self . cycle_number > 0 : self . _sim_controller . on_movement_end ( self ) self . _sim_controller . on_cycle_end ( self ) self . _sim_controller . on_cycle_start ( self ) self . _sim_controller . on_camera_frame ( self ) if self . cycle_step == 0 : self . _sim_controller . on_imaging_start ( self ) if self . cycle_step < config . imaging_frame_num : self . _sim_controller . on_micro_frame ( self ) if self . cycle_step == config . imaging_frame_num - config . pred_frame_num : self . _sim_controller . begin_movement_prediction ( self ) if self . cycle_step == config . imaging_frame_num : self . _sim_controller . on_imaging_end ( self ) dx , dy = self . _sim_controller . provide_movement_vector ( self ) self . _sim_controller . on_movement_start ( self ) self . _motor_controller . register_move ( dx , dy ) if config . imaging_frame_num <= self . cycle_step < config . imaging_frame_num + config . moving_frame_num : dx , dy = self . _motor_controller . step () self . _view . move_position ( dx , dy ) if self . cycle_step == config . cycle_frame_num - 1 : pbar . update ( 1 ) if visualize : self . _view . visualize_world ( timeout = 0 if wait_key else 1 ) self . _sim_controller . on_sim_end ( self ) pbar . 
close ()","title":"Attributes"},{"location":"reference/wtracker/sim/simulator/#instance-variables","text":"cycle_number Get the current cycle number. cycle_step Get the current cycle step. frame_number Get the current frame number. position Get the current position. view Get the view controller.","title":"Instance variables"},{"location":"reference/wtracker/sim/simulator/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/sim/simulator/#camera_view","text":"def camera_view ( self ) -> 'np.ndarray' Get the view that the camera sees. Returns: Type Description np.ndarray The camera view. View Source def camera_view ( self ) - > np . ndarray : \"\" \" Get the view that the camera sees. Returns: np.ndarray: The camera view. \" \"\" return self . _view . camera_view ()","title":"camera_view"},{"location":"reference/wtracker/sim/simulator/#micro_view","text":"def micro_view ( self ) -> 'np.ndarray' Get the view that the microscope sees. Returns: Type Description np.ndarray The micro view. View Source def micro_view ( self ) -> np . ndarray : \"\"\" Get the view that the microscope sees. Returns: np.ndarray: The micro view. \"\"\" return self . _view . micro_view ()","title":"micro_view"},{"location":"reference/wtracker/sim/simulator/#run","text":"def run ( self , visualize : 'bool' = False , wait_key : 'bool' = False ) Run the simulation. Parameters: Name Type Description Default visualize bool Whether to visualize the simulation. None wait_key bool Whether to wait for a key press to advance the simulation during visualization. None View Source def run ( self , visualize: bool = False , wait_key: bool = False ) : \"\"\" Run the simulation . Args: visualize ( bool , optional ) : Whether to visualize the simulation . wait_key ( bool , optional ) : Whether to wait for a key press to advance the simulation during visualization . \"\"\" config = self . timing_config total_cycles = len ( self . _view ) // config.cycle_frame_num pbar = tqdm ( total = total_cycles , desc = \"Simulation Progress\" , unit = \"cycle\" ) self . _reset () self . _sim_controller . on_sim_start ( self ) while self . _view . progress () : if self . cycle_step == 0 : if self . cycle_number > 0 : self . _sim_controller . on_movement_end ( self ) self . _sim_controller . on_cycle_end ( self ) self . _sim_controller . on_cycle_start ( self ) self . _sim_controller . on_camera_frame ( self ) if self . cycle_step == 0 : self . _sim_controller . on_imaging_start ( self ) if self . cycle_step < config . imaging_frame_num: self . _sim_controller . on_micro_frame ( self ) if self . cycle_step == config . imaging_frame_num - config . pred_frame_num: self . _sim_controller . begin_movement_prediction ( self ) if self . cycle_step == config . imaging_frame_num: self . _sim_controller . on_imaging_end ( self ) dx , dy = self . _sim_controller . provide_movement_vector ( self ) self . _sim_controller . on_movement_start ( self ) self . _motor_controller . register_move ( dx , dy ) if config . imaging_frame_num <= self . cycle_step < config . imaging_frame_num + config . moving_frame_num: dx , dy = self . _motor_controller . step () self . _view . move_position ( dx , dy ) if self . cycle_step == config . cycle_frame_num - 1 : pbar . update ( 1 ) if visualize: self . _view . visualize_world ( timeout = 0 if wait_key else 1 ) self . _sim_controller . on_sim_end ( self ) pbar . 
close ()","title":"run"},{"location":"reference/wtracker/sim/view_controller/","text":"Module wtracker.sim.view_controller View Source import cv2 as cv import numpy as np from wtracker.utils.frame_reader import FrameReader , FrameStream class ViewController ( FrameStream ): \"\"\" A class representing a view controller for a frame stream. This class allows for easy manipulation of the camera and microscope positions, and provides their corresponding views. Args: frame_reader (FrameReader): The frame reader object. camera_size (tuple[int, int], optional): The size of the camera frame. micro_size (tuple[int, int], optional): The size of the micro frame. init_position (tuple[int, int], optional): The initial position of the view. Attributes: frame_reader (FrameReader): The frame reader object. camera_size (tuple[int, int]): The size of the camera view (w, h). micro_size (tuple[int, int]): The size of the micro view (w, h). position (tuple[int, int]): The current position of the center of the view (x, y). \"\"\" def __init__ ( self , frame_reader : FrameReader , camera_size : tuple [ int , int ] = ( 251 , 251 ), micro_size : tuple [ int , int ] = ( 45 , 45 ), init_position : tuple [ int , int ] = ( 0 , 0 ), ): super () . __init__ ( frame_reader ) assert camera_size [ 0 ] >= micro_size [ 0 ] assert camera_size [ 1 ] >= micro_size [ 1 ] self . _padding_size : tuple [ int , int ] = ( camera_size [ 0 ] // 2 , camera_size [ 1 ] // 2 ) self . _camera_size = camera_size self . _micro_size = micro_size self . _position = init_position self . set_position ( * init_position ) def read ( self ) -> np . ndarray : \"\"\" Read a frame from the frame reader and apply padding. Returns: np.ndarray: The padded frame. \"\"\" frame = super () . read () frame = cv . copyMakeBorder ( src = frame , left = self . _padding_size [ 0 ], right = self . _padding_size [ 0 ], top = self . _padding_size [ 1 ], bottom = self . _padding_size [ 1 ], borderType = cv . BORDER_REPLICATE , ) return frame @property def position ( self ) -> tuple [ int , int ]: \"\"\" Get the current position of the view controller. Returns: tuple[int, int]: The current position (x, y). \"\"\" return self . _position @property def camera_size ( self ) -> tuple [ int , int ]: \"\"\" Get the size of the camera view. Returns: tuple[int, int]: The size of the camera view (w, h). \"\"\" return self . _camera_size @property def micro_size ( self ) -> tuple [ int , int ]: \"\"\" Get the size of the micro view. Returns: tuple[int, int]: The size of the micro view (w, h). \"\"\" return self . _micro_size @property def camera_position ( self ) -> tuple [ int , int , int , int ]: \"\"\" Get the position of the camera view. Returns: tuple[int, int, int, int]: The position of the camera view (x, y, w, h). \"\"\" w , h = self . camera_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h @property def micro_position ( self ) -> tuple [ int , int , int , int ]: \"\"\" Get the position of the micro view. Returns: tuple[int, int, int, int]: The position of the micro view (x, y, w, h). \"\"\" w , h = self . micro_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h def set_position ( self , x : int , y : int ): \"\"\" Set the position of the view controller. Note, that the position is clamped to the frame size. Args: x (int): The x-coordinate of the position. y (int): The y-coordinate of the position. \"\"\" x = np . clip ( x , 0 , self . _frame_reader . 
frame_shape [ 1 ] - 1 ) y = np . clip ( y , 0 , self . _frame_reader . frame_shape [ 0 ] - 1 ) self . _position = ( x , y ) def move_position ( self , dx : int , dy : int ): \"\"\" Move the position of the view controller by dx and dy. Args: dx (int): The amount to move in the x-direction. dy (int): The amount to move in the y-direction. \"\"\" self . set_position ( self . _position [ 0 ] + dx , self . _position [ 1 ] + dy ) def _calc_view_bbox ( self , w : int , h : int ) -> tuple [ int , int , int , int ]: \"\"\" Calculate the bbox of the view, while taking padding into account. Args: w (int): The width of the view. h (int): The height of the view. Returns: tuple[int, int, int, int]: The bounding box of the view (x, y, w, h). \"\"\" x = self . _position [ 0 ] + self . _padding_size [ 0 ] - w // 2 y = self . _position [ 1 ] + self . _padding_size [ 1 ] - h // 2 return x , y , w , h def _custom_view ( self , w : int , h : int ) -> np . ndarray : \"\"\" Get a custom view of the frame. Args: w (int): The width of the view. h (int): The height of the view. Returns: np.ndarray: The custom view of the frame. \"\"\" x , y , w , h = self . _calc_view_bbox ( w , h ) frame = self . read () slice = frame [ y : y + w , x : x + h ] return slice def camera_view ( self ) -> np . ndarray : \"\"\" Get the camera view. Returns: np.ndarray: The camera view. \"\"\" return self . _custom_view ( * self . camera_size ) def micro_view ( self ) -> np . ndarray : \"\"\" Get the micro view. Returns: np.ndarray: The micro view. \"\"\" return self . _custom_view ( * self . micro_size ) def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ): \"\"\" Visualize the world view with bounding boxes. Both the camera and micro views are visualized, along with the center point. Args: line_width (int): The width of the bounding box lines. \"\"\" x_mid , y_mid , _ , _ = self . _calc_view_bbox ( 0 , 0 ) x_cam , y_cam , w_cam , h_cam = self . _calc_view_bbox ( * self . camera_size ) x_mic , y_mic , w_mic , h_mic = self . _calc_view_bbox ( * self . micro_size ) world = self . read () if len ( self . _frame_reader . frame_shape ) == 2 : world = cv . cvtColor ( world , cv . COLOR_GRAY2BGR ) cv . rectangle ( world , ( x_cam , y_cam ), ( x_cam + w_cam , y_cam + h_cam ), ( 0 , 0 , 255 ), line_width ) cv . rectangle ( world , ( x_mic , y_mic ), ( x_mic + w_mic , y_mic + h_mic ), ( 0 , 255 , 0 ), line_width ) cv . circle ( world , ( x_mid , y_mid ), 1 , ( 255 , 0 , 0 ), line_width ) cv . imshow ( \"World View\" , world ) cv . waitKey ( timeout ) Classes ViewController class ViewController ( frame_reader : wtracker . utils . frame_reader . FrameReader , camera_size : tuple [ int , int ] = ( 251 , 251 ), micro_size : tuple [ int , int ] = ( 45 , 45 ), init_position : tuple [ int , int ] = ( 0 , 0 ) ) A class representing a view controller for a frame stream. This class allows for easy manipulation of the camera and microscope positions, and provides their corresponding views. Attributes Name Type Description Default frame_reader FrameReader The frame reader object. None camera_size tuple[int, int] The size of the camera frame. None micro_size tuple[int, int] The size of the micro frame. None init_position tuple[int, int] The initial position of the view. None frame_reader FrameReader The frame reader object. None camera_size tuple[int, int] The size of the camera view (w, h). None micro_size tuple[int, int] The size of the micro view (w, h). 
None position tuple[int, int] The current position of the center of the view (x, y). None View Source class ViewController ( FrameStream ) : \"\"\" A class representing a view controller for a frame stream . This class allows for easy manipulation of the camera and microscope positions , and provides their corresponding views . Args : frame_reader ( FrameReader ) : The frame reader object . camera_size ( tuple [ int , int ], optional ) : The size of the camera frame . micro_size ( tuple [ int , int ], optional ) : The size of the micro frame . init_position ( tuple [ int , int ], optional ) : The initial position of the view . Attributes : frame_reader ( FrameReader ) : The frame reader object . camera_size ( tuple [ int , int ]) : The size of the camera view ( w , h ). micro_size ( tuple [ int , int ]) : The size of the micro view ( w , h ). position ( tuple [ int , int ]) : The current position of the center of the view ( x , y ). \"\"\" def __init__ ( self , frame_reader : FrameReader , camera_size : tuple [ int , int ] = ( 251 , 251 ), micro_size : tuple [ int , int ] = ( 45 , 45 ), init_position : tuple [ int , int ] = ( 0 , 0 ), ) : super (). __init__ ( frame_reader ) assert camera_size [ 0 ] >= micro_size [ 0 ] assert camera_size [ 1 ] >= micro_size [ 1 ] self . _padding_size : tuple [ int , int ] = ( camera_size [ 0 ] // 2, camera_size[1] // 2) self . _camera_size = camera_size self . _micro_size = micro_size self . _position = init_position self . set_position ( * init_position ) def read ( self ) -> np . ndarray : \"\"\" Read a frame from the frame reader and apply padding . Returns : np . ndarray : The padded frame . \"\"\" frame = super (). read () frame = cv . copyMakeBorder ( src = frame , left = self . _padding_size [ 0 ], right = self . _padding_size [ 0 ], top = self . _padding_size [ 1 ], bottom = self . _padding_size [ 1 ], borderType = cv . BORDER_REPLICATE , ) return frame @property def position ( self ) -> tuple [ int , int ] : \"\"\" Get the current position of the view controller . Returns : tuple [ int , int ] : The current position ( x , y ). \"\"\" return self . _position @property def camera_size ( self ) -> tuple [ int , int ] : \"\"\" Get the size of the camera view . Returns : tuple [ int , int ] : The size of the camera view ( w , h ). \"\"\" return self . _camera_size @property def micro_size ( self ) -> tuple [ int , int ] : \"\"\" Get the size of the micro view . Returns : tuple [ int , int ] : The size of the micro view ( w , h ). \"\"\" return self . _micro_size @property def camera_position ( self ) -> tuple [ int , int , int , int ] : \"\"\" Get the position of the camera view . Returns : tuple [ int , int , int , int ] : The position of the camera view ( x , y , w , h ). \"\"\" w , h = self . camera_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h @property def micro_position ( self ) -> tuple [ int , int , int , int ] : \"\"\" Get the position of the micro view . Returns : tuple [ int , int , int , int ] : The position of the micro view ( x , y , w , h ). \"\"\" w , h = self . micro_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h def set_position ( self , x : int , y : int ) : \"\"\" Set the position of the view controller . Note that the position is clamped to the frame size . Args : x ( int ) : The x - coordinate of the position . y ( int ) : The y - coordinate of the position . \"\"\" x = np . clip ( x , 0 , self . _frame_reader .
frame_shape [ 1 ] - 1 ) y = np . clip ( y , 0 , self . _frame_reader . frame_shape [ 0 ] - 1 ) self . _position = ( x , y ) def move_position ( self , dx : int , dy : int ) : \"\"\" Move the position of the view controller by dx and dy . Args : dx ( int ) : The amount to move in the x - direction . dy ( int ) : The amount to move in the y - direction . \"\"\" self . set_position ( self . _position [ 0 ] + dx , self . _position [ 1 ] + dy ) def _calc_view_bbox ( self , w : int , h : int ) -> tuple [ int , int , int , int ] : \"\"\" Calculate the bbox of the view , while taking padding into account . Args : w ( int ) : The width of the view . h ( int ) : The height of the view . Returns : tuple [ int , int , int , int ] : The bounding box of the view ( x , y , w , h ). \"\"\" x = self . _position [ 0 ] + self . _padding_size [ 0 ] - w // 2 y = self . _position [ 1 ] + self . _padding_size [ 1 ] - h // 2 return x , y , w , h def _custom_view ( self , w : int , h : int ) -> np . ndarray : \"\"\" Get a custom view of the frame . Args : w ( int ) : The width of the view . h ( int ) : The height of the view . Returns : np . ndarray : The custom view of the frame . \"\"\" x , y , w , h = self . _calc_view_bbox ( w , h ) frame = self . read () slice = frame [ y : y + w , x : x + h ] return slice def camera_view ( self ) -> np . ndarray : \"\"\" Get the camera view . Returns : np . ndarray : The camera view . \"\"\" return self . _custom_view ( * self . camera_size ) def micro_view ( self ) -> np . ndarray : \"\"\" Get the micro view . Returns : np . ndarray : The micro view . \"\"\" return self . _custom_view ( * self . micro_size ) def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ) : \"\"\" Visualize the world view with bounding boxes . Both the camera and micro views are visualized , along with the center point . Args : line_width ( int ) : The width of the bounding box lines . \"\"\" x_mid , y_mid , _ , _ = self . _calc_view_bbox ( 0 , 0 ) x_cam , y_cam , w_cam , h_cam = self . _calc_view_bbox ( * self . camera_size ) x_mic , y_mic , w_mic , h_mic = self . _calc_view_bbox ( * self . micro_size ) world = self . read () if len ( self . _frame_reader . frame_shape ) == 2 : world = cv . cvtColor ( world , cv . COLOR_GRAY2BGR ) cv . rectangle ( world , ( x_cam , y_cam ), ( x_cam + w_cam , y_cam + h_cam ), ( 0 , 0 , 255 ), line_width ) cv . rectangle ( world , ( x_mic , y_mic ), ( x_mic + w_mic , y_mic + h_mic ), ( 0 , 255 , 0 ), line_width ) cv . circle ( world , ( x_mid , y_mid ), 1 , ( 255 , 0 , 0 ), line_width ) cv . imshow ( \"World View\" , world ) cv . waitKey ( timeout ) Ancestors (in MRO) wtracker.utils.frame_reader.FrameStream Instance variables camera_position Get the position of the camera view. camera_size Get the size of the camera view. index The index of the current frame. micro_position Get the position of the micro view. micro_size Get the size of the micro view. position Get the current position of the view controller. Methods camera_view def camera_view ( self ) -> numpy . ndarray Get the camera view. Returns: Type Description np.ndarray The camera view. View Source def camera_view ( self ) - > np . ndarray : \"\" \" Get the camera view. Returns: np.ndarray: The camera view. \" \"\" return self . _custom_view ( * self . camera_size ) can_read def can_read ( self ) -> 'bool' View Source def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader ) micro_view def micro_view ( self ) -> numpy . ndarray Get the micro view. 
Returns: Type Description np.ndarray The micro view. View Source def micro_view(self) -> np.ndarray: \"\"\" Get the micro view. Returns: np.ndarray: The micro view. \"\"\" return self._custom_view(*self.micro_size) move_position def move_position ( self , dx : int , dy : int ) Move the position of the view controller by dx and dy. Parameters: Name Type Description Default dx int The amount to move in the x-direction. None dy int The amount to move in the y-direction. None View Source def move_position(self, dx: int, dy: int): \"\"\" Move the position of the view controller by dx and dy. Args: dx (int): The amount to move in the x-direction. dy (int): The amount to move in the y-direction. \"\"\" self.set_position(self._position[0] + dx, self._position[1] + dy) progress def progress ( self , n : 'int' = 1 ) -> 'bool' Moves the current index forward by the specified number of steps. Parameters: Name Type Description Default n int The number of steps to move forward. None Returns: Type Description bool True if the index was successfully moved forward, False otherwise. View Source def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . _idx + n ) read def read ( self ) -> numpy . ndarray Read a frame from the frame reader and apply padding. Returns: Type Description np.ndarray The padded frame. View Source def read ( self ) -> np . ndarray : \"\"\" Read a frame from the frame reader and apply padding. Returns: np.ndarray: The padded frame. \"\"\" frame = super (). read () frame = cv . copyMakeBorder ( src = frame , left = self . _padding_size [ 0 ], right = self . _padding_size [ 0 ], top = self . _padding_size [ 1 ], bottom = self . _padding_size [ 1 ], borderType = cv . BORDER_REPLICATE , ) return frame reset def reset ( self ) Resets the frame reader to the beginning of the stream. View Source def reset(self): \"\"\" Resets the frame reader to the beginning of the stream. \"\"\" self.seek(-1) seek def seek ( self , idx : 'int' ) -> 'bool' Move the index to the specified position. Parameters: Name Type Description Default idx int The index to seek to. None Returns: Type Description bool True if the index is within the valid range, False otherwise. View Source def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . _idx = idx self . frame = None return self . can_read () set_position def set_position ( self , x : int , y : int ) Set the position of the view controller. Note that the position is clamped to the frame size. Parameters: Name Type Description Default x int The x-coordinate of the position. None y int The y-coordinate of the position. None View Source def set_position(self, x: int, y: int): \"\"\" Set the position of the view controller. Note that the position is clamped to the frame size. Args: x (int): The x-coordinate of the position. y (int): The y-coordinate of the position. \"\"\" x = np.clip(x, 0, self._frame_reader.frame_shape[1] - 1) y = np.clip(y, 0, self._frame_reader.frame_shape[0] - 1) self._position = (x, y) visualize_world def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ) Visualize the world view with bounding boxes.
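Putting the methods on this page together, here is a hedged sketch of stepping a ViewController frame by frame; the DummyReader call mirrors the one made in Simulator.__init__ above, and the frame count, resolution, and positions are illustrative values:

```python
from wtracker.sim.view_controller import ViewController
from wtracker.utils.frame_reader import DummyReader

# DummyReader(num_frames, resolution, colored=True) matches the fallback
# used by Simulator; values here are illustrative.
reader = DummyReader(100, (1000, 1000), colored=True)
view = ViewController(reader, camera_size=(251, 251), micro_size=(45, 45))

view.set_position(500, 400)       # clamped to the frame bounds
while view.progress():            # advance one frame at a time
    cam = view.camera_view()      # 251x251 crop centered on the position
    mic = view.micro_view()       # 45x45 crop centered on the position
    view.move_position(2, 0)      # drift the view 2 px right per frame
```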
Both the camera and micro views are visualized, along with the center point. Parameters: Name Type Description Default line_width int The width of the bounding box lines. None View Source def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ) : \"\" \" Visualize the world view with bounding boxes. Both the camera and micro views are visualized, along with the center point. Args: line_width (int): The width of the bounding box lines. \" \"\" x_mid , y_mid , _ , _ = self . _calc_view_bbox ( 0 , 0 ) x_cam , y_cam , w_cam , h_cam = self . _calc_view_bbox ( * self . camera_size ) x_mic , y_mic , w_mic , h_mic = self . _calc_view_bbox ( * self . micro_size ) world = self . read () if len ( self . _frame_reader . frame_shape ) == 2 : world = cv . cvtColor ( world , cv . COLOR_GRAY2BGR ) cv . rectangle ( world , ( x_cam , y_cam ), ( x_cam + w_cam , y_cam + h_cam ), ( 0 , 0 , 255 ), line_width ) cv . rectangle ( world , ( x_mic , y_mic ), ( x_mic + w_mic , y_mic + h_mic ), ( 0 , 255 , 0 ), line_width ) cv . circle ( world , ( x_mid , y_mid ), 1 , ( 255 , 0 , 0 ), line_width ) cv . imshow ( \"World View\" , world ) cv . waitKey ( timeout )","title":"View Controller"},{"location":"reference/wtracker/sim/view_controller/#module-wtrackersimview_controller","text":"View Source import cv2 as cv import numpy as np from wtracker.utils.frame_reader import FrameReader , FrameStream class ViewController ( FrameStream ): \"\"\" A class representing a view controller for a frame stream. This class allows for easy manipulation of the camera and microscope positions, and provides their corresponding views. Args: frame_reader (FrameReader): The frame reader object. camera_size (tuple[int, int], optional): The size of the camera frame. micro_size (tuple[int, int], optional): The size of the micro frame. init_position (tuple[int, int], optional): The initial position of the view. Attributes: frame_reader (FrameReader): The frame reader object. camera_size (tuple[int, int]): The size of the camera view (w, h). micro_size (tuple[int, int]): The size of the micro view (w, h). position (tuple[int, int]): The current position of the center of the view (x, y). \"\"\" def __init__ ( self , frame_reader : FrameReader , camera_size : tuple [ int , int ] = ( 251 , 251 ), micro_size : tuple [ int , int ] = ( 45 , 45 ), init_position : tuple [ int , int ] = ( 0 , 0 ), ): super () . __init__ ( frame_reader ) assert camera_size [ 0 ] >= micro_size [ 0 ] assert camera_size [ 1 ] >= micro_size [ 1 ] self . _padding_size : tuple [ int , int ] = ( camera_size [ 0 ] // 2 , camera_size [ 1 ] // 2 ) self . _camera_size = camera_size self . _micro_size = micro_size self . _position = init_position self . set_position ( * init_position ) def read ( self ) -> np . ndarray : \"\"\" Read a frame from the frame reader and apply padding. Returns: np.ndarray: The padded frame. \"\"\" frame = super () . read () frame = cv . copyMakeBorder ( src = frame , left = self . _padding_size [ 0 ], right = self . _padding_size [ 0 ], top = self . _padding_size [ 1 ], bottom = self . _padding_size [ 1 ], borderType = cv . BORDER_REPLICATE , ) return frame @property def position ( self ) -> tuple [ int , int ]: \"\"\" Get the current position of the view controller. Returns: tuple[int, int]: The current position (x, y). \"\"\" return self . _position @property def camera_size ( self ) -> tuple [ int , int ]: \"\"\" Get the size of the camera view. Returns: tuple[int, int]: The size of the camera view (w, h). \"\"\" return self . 
_camera_size @property def micro_size ( self ) -> tuple [ int , int ]: \"\"\" Get the size of the micro view. Returns: tuple[int, int]: The size of the micro view (w, h). \"\"\" return self . _micro_size @property def camera_position ( self ) -> tuple [ int , int , int , int ]: \"\"\" Get the position of the camera view. Returns: tuple[int, int, int, int]: The position of the camera view (x, y, w, h). \"\"\" w , h = self . camera_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h @property def micro_position ( self ) -> tuple [ int , int , int , int ]: \"\"\" Get the position of the micro view. Returns: tuple[int, int, int, int]: The position of the micro view (x, y, w, h). \"\"\" w , h = self . micro_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h def set_position ( self , x : int , y : int ): \"\"\" Set the position of the view controller. Note that the position is clamped to the frame size. Args: x (int): The x-coordinate of the position. y (int): The y-coordinate of the position. \"\"\" x = np . clip ( x , 0 , self . _frame_reader . frame_shape [ 1 ] - 1 ) y = np . clip ( y , 0 , self . _frame_reader . frame_shape [ 0 ] - 1 ) self . _position = ( x , y ) def move_position ( self , dx : int , dy : int ): \"\"\" Move the position of the view controller by dx and dy. Args: dx (int): The amount to move in the x-direction. dy (int): The amount to move in the y-direction. \"\"\" self . set_position ( self . _position [ 0 ] + dx , self . _position [ 1 ] + dy ) def _calc_view_bbox ( self , w : int , h : int ) -> tuple [ int , int , int , int ]: \"\"\" Calculate the bbox of the view, while taking padding into account. Args: w (int): The width of the view. h (int): The height of the view. Returns: tuple[int, int, int, int]: The bounding box of the view (x, y, w, h). \"\"\" x = self . _position [ 0 ] + self . _padding_size [ 0 ] - w // 2 y = self . _position [ 1 ] + self . _padding_size [ 1 ] - h // 2 return x , y , w , h def _custom_view ( self , w : int , h : int ) -> np . ndarray : \"\"\" Get a custom view of the frame. Args: w (int): The width of the view. h (int): The height of the view. Returns: np.ndarray: The custom view of the frame. \"\"\" x , y , w , h = self . _calc_view_bbox ( w , h ) frame = self . read () slice = frame [ y : y + h , x : x + w ] return slice def camera_view ( self ) -> np . ndarray : \"\"\" Get the camera view. Returns: np.ndarray: The camera view. \"\"\" return self . _custom_view ( * self . camera_size ) def micro_view ( self ) -> np . ndarray : \"\"\" Get the micro view. Returns: np.ndarray: The micro view. \"\"\" return self . _custom_view ( * self . micro_size ) def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ): \"\"\" Visualize the world view with bounding boxes. Both the camera and micro views are visualized, along with the center point. Args: line_width (int): The width of the bounding box lines. \"\"\" x_mid , y_mid , _ , _ = self . _calc_view_bbox ( 0 , 0 ) x_cam , y_cam , w_cam , h_cam = self . _calc_view_bbox ( * self . camera_size ) x_mic , y_mic , w_mic , h_mic = self . _calc_view_bbox ( * self . micro_size ) world = self . read () if len ( self . _frame_reader . frame_shape ) == 2 : world = cv . cvtColor ( world , cv . COLOR_GRAY2BGR ) cv . rectangle ( world , ( x_cam , y_cam ), ( x_cam + w_cam , y_cam + h_cam ), ( 0 , 0 , 255 ), line_width ) cv .
rectangle ( world , ( x_mic , y_mic ), ( x_mic + w_mic , y_mic + h_mic ), ( 0 , 255 , 0 ), line_width ) cv . circle ( world , ( x_mid , y_mid ), 1 , ( 255 , 0 , 0 ), line_width ) cv . imshow ( \"World View\" , world ) cv . waitKey ( timeout )","title":"Module wtracker.sim.view_controller"},{"location":"reference/wtracker/sim/view_controller/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/view_controller/#viewcontroller","text":"class ViewController ( frame_reader : wtracker . utils . frame_reader . FrameReader , camera_size : tuple [ int , int ] = ( 251 , 251 ), micro_size : tuple [ int , int ] = ( 45 , 45 ), init_position : tuple [ int , int ] = ( 0 , 0 ) ) A class representing a view controller for a frame stream. This class allows for easy manipulation of the camera and microscope positions, and provides their corresponding views.","title":"ViewController"},{"location":"reference/wtracker/sim/view_controller/#attributes","text":"Name Type Description Default frame_reader FrameReader The frame reader object. None camera_size tuple[int, int] The size of the camera frame. None micro_size tuple[int, int] The size of the micro frame. None init_position tuple[int, int] The initial position of the view. None frame_reader FrameReader The frame reader object. None camera_size tuple[int, int] The size of the camera view (w, h). None micro_size tuple[int, int] The size of the micro view (w, h). None position tuple[int, int] The current position of the center of the view (x, y). None View Source class ViewController ( FrameStream ) : \"\"\" A class representing a view controller for a frame stream . This class allows for easy manipulation of the camera and microscope positions , and provides their corresponding views . Args : frame_reader ( FrameReader ) : The frame reader object . camera_size ( tuple [ int , int ], optional ) : The size of the camera frame . micro_size ( tuple [ int , int ], optional ) : The size of the micro frame . init_position ( tuple [ int , int ], optional ) : The initial position of the view . Attributes : frame_reader ( FrameReader ) : The frame reader object . camera_size ( tuple [ int , int ]) : The size of the camera view ( w , h ). micro_size ( tuple [ int , int ]) : The size of the micro view ( w , h ). position ( tuple [ int , int ]) : The current position of the center of the view ( x , y ). \"\"\" def __init__ ( self , frame_reader : FrameReader , camera_size : tuple [ int , int ] = ( 251 , 251 ), micro_size : tuple [ int , int ] = ( 45 , 45 ), init_position : tuple [ int , int ] = ( 0 , 0 ), ) : super (). __init__ ( frame_reader ) assert camera_size [ 0 ] >= micro_size [ 0 ] assert camera_size [ 1 ] >= micro_size [ 1 ] self . _padding_size : tuple [ int , int ] = ( camera_size [ 0 ] // 2, camera_size[1] // 2) self . _camera_size = camera_size self . _micro_size = micro_size self . _position = init_position self . set_position ( * init_position ) def read ( self ) -> np . ndarray : \"\"\" Read a frame from the frame reader and apply padding . Returns : np . ndarray : The padded frame . \"\"\" frame = super (). read () frame = cv . copyMakeBorder ( src = frame , left = self . _padding_size [ 0 ], right = self . _padding_size [ 0 ], top = self . _padding_size [ 1 ], bottom = self . _padding_size [ 1 ], borderType = cv . BORDER_REPLICATE , ) return frame @property def position ( self ) -> tuple [ int , int ] : \"\"\" Get the current position of the view controller . Returns : tuple [ int , int ] : The current position ( x , y ). 
\"\"\" return self . _position @property def camera_size ( self ) -> tuple [ int , int ] : \"\"\" Get the size of the camera view . Returns : tuple [ int , int ] : The size of the camera view ( w , h ). \"\"\" return self . _camera_size @property def micro_size ( self ) -> tuple [ int , int ] : \"\"\" Get the size of the micro view . Returns : tuple [ int , int ] : The size of the micro view ( w , h ). \"\"\" return self . _micro_size @property def camera_position ( self ) -> tuple [ int , int , int , int ] : \"\"\" Get the position of the camera view . Returns : tuple [ int , int , int , int ] : The position of the camera view ( x , y , w , h ). \"\"\" w , h = self . camera_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h @property def micro_position ( self ) -> tuple [ int , int , int , int ] : \"\"\" Get the position of the micro view . Returns : tuple [ int , int , int , int ] : The position of the micro view ( x , y , w , h ). \"\"\" w , h = self . micro_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h def set_position ( self , x : int , y : int ) : \"\"\" Set the position of the view controller . Note , that the position is clamped to the frame size . Args : x ( int ) : The x - coordinate of the position . y ( int ) : The y - coordinate of the position . \"\"\" x = np . clip ( x , 0 , self . _frame_reader . frame_shape [ 1 ] - 1 ) y = np . clip ( y , 0 , self . _frame_reader . frame_shape [ 0 ] - 1 ) self . _position = ( x , y ) def move_position ( self , dx : int , dy : int ) : \"\"\" Move the position of the view controller by dx and dy . Args : dx ( int ) : The amount to move in the x - direction . dy ( int ) : The amount to move in the y - direction . \"\"\" self . set_position ( self . _position [ 0 ] + dx , self . _position [ 1 ] + dy ) def _calc_view_bbox ( self , w : int , h : int ) -> tuple [ int , int , int , int ] : \"\"\" Calculate the bbox of the view , while taking padding into account . Args : w ( int ) : The width of the view . h ( int ) : The height of the view . Returns : tuple [ int , int , int , int ] : The bounding box of the view ( x , y , w , h ). \"\"\" x = self . _position [ 0 ] + self . _padding_size [ 0 ] - w // 2 y = self . _position [ 1 ] + self . _padding_size [ 1 ] - h // 2 return x , y , w , h def _custom_view ( self , w : int , h : int ) -> np . ndarray : \"\"\" Get a custom view of the frame . Args : w ( int ) : The width of the view . h ( int ) : The height of the view . Returns : np . ndarray : The custom view of the frame . \"\"\" x , y , w , h = self . _calc_view_bbox ( w , h ) frame = self . read () slice = frame [ y : y + w , x : x + h ] return slice def camera_view ( self ) -> np . ndarray : \"\"\" Get the camera view . Returns : np . ndarray : The camera view . \"\"\" return self . _custom_view ( * self . camera_size ) def micro_view ( self ) -> np . ndarray : \"\"\" Get the micro view . Returns : np . ndarray : The micro view . \"\"\" return self . _custom_view ( * self . micro_size ) def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ) : \"\"\" Visualize the world view with bounding boxes . Both the camera and micro views are visualized , along with the center point . Args : line_width ( int ) : The width of the bounding box lines . \"\"\" x_mid , y_mid , _ , _ = self . _calc_view_bbox ( 0 , 0 ) x_cam , y_cam , w_cam , h_cam = self . _calc_view_bbox ( * self . camera_size ) x_mic , y_mic , w_mic , h_mic = self . 
_calc_view_bbox ( * self . micro_size ) world = self . read () if len ( self . _frame_reader . frame_shape ) == 2 : world = cv . cvtColor ( world , cv . COLOR_GRAY2BGR ) cv . rectangle ( world , ( x_cam , y_cam ), ( x_cam + w_cam , y_cam + h_cam ), ( 0 , 0 , 255 ), line_width ) cv . rectangle ( world , ( x_mic , y_mic ), ( x_mic + w_mic , y_mic + h_mic ), ( 0 , 255 , 0 ), line_width ) cv . circle ( world , ( x_mid , y_mid ), 1 , ( 255 , 0 , 0 ), line_width ) cv . imshow ( \"World View\" , world ) cv . waitKey ( timeout )","title":"Attributes"},{"location":"reference/wtracker/sim/view_controller/#ancestors-in-mro","text":"wtracker.utils.frame_reader.FrameStream","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/view_controller/#instance-variables","text":"camera_position Get the position of the camera view. camera_size Get the size of the camera view. index The index of the current frame. micro_position Get the position of the micro view. micro_size Get the size of the micro view. position Get the current position of the view controller.","title":"Instance variables"},{"location":"reference/wtracker/sim/view_controller/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/view_controller/#camera_view","text":"def camera_view ( self ) -> numpy . ndarray Get the camera view. Returns: Type Description np.ndarray The camera view. View Source def camera_view ( self ) - > np . ndarray : \"\" \" Get the camera view. Returns: np.ndarray: The camera view. \" \"\" return self . _custom_view ( * self . camera_size )","title":"camera_view"},{"location":"reference/wtracker/sim/view_controller/#can_read","text":"def can_read ( self ) -> 'bool' View Source def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader )","title":"can_read"},{"location":"reference/wtracker/sim/view_controller/#micro_view","text":"def micro_view ( self ) -> numpy . ndarray Get the micro view. Returns: Type Description np.ndarray The micro view. View Source def micro_view(self) -> np.ndarray: \"\"\" Get the micro view. Returns: np.ndarray: The micro view. \"\"\" return self._custom_view(*self.micro_size)","title":"micro_view"},{"location":"reference/wtracker/sim/view_controller/#move_position","text":"def move_position ( self , dx : int , dy : int ) Move the position of the view controller by dx and dy. Parameters: Name Type Description Default dx int The amount to move in the x-direction. None dy int The amount to move in the y-direction. None View Source def move_position(self, dx: int, dy: int): \"\"\" Move the position of the view controller by dx and dy. Args: dx (int): The amount to move in the x-direction. dy (int): The amount to move in the y-direction. \"\"\" self.set_position(self._position[0] + dx, self._position[1] + dy)","title":"move_position"},{"location":"reference/wtracker/sim/view_controller/#progress","text":"def progress ( self , n : 'int' = 1 ) -> 'bool' Moves the current index forward by the specified number of steps. Parameters: Name Type Description Default n int The number of steps to move forward. None Returns: Type Description bool True if the index was successfully moved forward, False otherwise. View Source def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . 
_idx + n )","title":"progress"},{"location":"reference/wtracker/sim/view_controller/#read","text":"def read ( self ) -> numpy . ndarray Read a frame from the frame reader and apply padding. Returns: Type Description np.ndarray The padded frame. View Source def read ( self ) -> np . ndarray : \"\"\" Read a frame from the frame reader and apply padding. Returns: np.ndarray: The padded frame. \"\"\" frame = super (). read () frame = cv . copyMakeBorder ( src = frame , left = self . _padding_size [ 0 ], right = self . _padding_size [ 0 ], top = self . _padding_size [ 1 ], bottom = self . _padding_size [ 1 ], borderType = cv . BORDER_REPLICATE , ) return frame","title":"read"},{"location":"reference/wtracker/sim/view_controller/#reset","text":"def reset ( self ) Resets the frame reader to the beginning of the steam. View Source def reset(self): \"\"\" Resets the frame reader to the beginning of the steam. \"\"\" self.seek(-1)","title":"reset"},{"location":"reference/wtracker/sim/view_controller/#seek","text":"def seek ( self , idx : 'int' ) -> 'bool' Move the index to the specified position. Parameters: Name Type Description Default idx int The index to seek to. None Returns: Type Description bool True if the index is within the valid range, False otherwise. View Source def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . _idx = idx self . frame = None return self . can_read ()","title":"seek"},{"location":"reference/wtracker/sim/view_controller/#set_position","text":"def set_position ( self , x : int , y : int ) Set the position of the view controller. Note, that the position is clamped to the frame size. Parameters: Name Type Description Default x int The x-coordinate of the position. None y int The y-coordinate of the position. None View Source def set_position(self, x: int, y: int): \"\"\" Set the position of the view controller. Note, that the position is clamped to the frame size. Args: x (int): The x-coordinate of the position. y (int): The y-coordinate of the position. \"\"\" x = np.clip(x, 0, self._frame_reader.frame_shape[1] - 1) y = np.clip(y, 0, self._frame_reader.frame_shape[0] - 1) self._position = (x, y)","title":"set_position"},{"location":"reference/wtracker/sim/view_controller/#visualize_world","text":"def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ) Visualize the world view with bounding boxes. Both the camera and micro views are visualized, along with the center point. Parameters: Name Type Description Default line_width int The width of the bounding box lines. None View Source def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ) : \"\" \" Visualize the world view with bounding boxes. Both the camera and micro views are visualized, along with the center point. Args: line_width (int): The width of the bounding box lines. \" \"\" x_mid , y_mid , _ , _ = self . _calc_view_bbox ( 0 , 0 ) x_cam , y_cam , w_cam , h_cam = self . _calc_view_bbox ( * self . camera_size ) x_mic , y_mic , w_mic , h_mic = self . _calc_view_bbox ( * self . micro_size ) world = self . read () if len ( self . _frame_reader . frame_shape ) == 2 : world = cv . cvtColor ( world , cv . COLOR_GRAY2BGR ) cv . rectangle ( world , ( x_cam , y_cam ), ( x_cam + w_cam , y_cam + h_cam ), ( 0 , 0 , 255 ), line_width ) cv . 
rectangle ( world , ( x_mic , y_mic ), ( x_mic + w_mic , y_mic + h_mic ), ( 0 , 255 , 0 ), line_width ) cv . circle ( world , ( x_mid , y_mid ), 1 , ( 255 , 0 , 0 ), line_width ) cv . imshow ( \"World View\" , world ) cv . waitKey ( timeout )","title":"visualize_world"},{"location":"reference/wtracker/sim/sim_controllers/","text":"Module wtracker.sim.sim_controllers View Source from wtracker.sim.sim_controllers.csv_controller import CsvController from wtracker.sim.sim_controllers.mlp_controllers import MLPController from wtracker.sim.sim_controllers.logging_controller import LogConfig , LoggingController from wtracker.sim.sim_controllers.optimal_controller import OptimalController from wtracker.sim.sim_controllers.polyfit_controller import PolyfitConfig , PolyfitController from wtracker.sim.sim_controllers.yolo_controller import YoloConfig , YoloController Sub-modules wtracker.sim.sim_controllers.csv_controller wtracker.sim.sim_controllers.logging_controller wtracker.sim.sim_controllers.mlp_controllers wtracker.sim.sim_controllers.optimal_controller wtracker.sim.sim_controllers.polyfit_controller wtracker.sim.sim_controllers.yolo_controller","title":"Index"},{"location":"reference/wtracker/sim/sim_controllers/#module-wtrackersimsim_controllers","text":"View Source from wtracker.sim.sim_controllers.csv_controller import CsvController from wtracker.sim.sim_controllers.mlp_controllers import MLPController from wtracker.sim.sim_controllers.logging_controller import LogConfig , LoggingController from wtracker.sim.sim_controllers.optimal_controller import OptimalController from wtracker.sim.sim_controllers.polyfit_controller import PolyfitConfig , PolyfitController from wtracker.sim.sim_controllers.yolo_controller import YoloConfig , YoloController","title":"Module wtracker.sim.sim_controllers"},{"location":"reference/wtracker/sim/sim_controllers/#sub-modules","text":"wtracker.sim.sim_controllers.csv_controller wtracker.sim.sim_controllers.logging_controller wtracker.sim.sim_controllers.mlp_controllers wtracker.sim.sim_controllers.optimal_controller wtracker.sim.sim_controllers.polyfit_controller wtracker.sim.sim_controllers.yolo_controller","title":"Sub-modules"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/","text":"Module wtracker.sim.sim_controllers.csv_controller View Source from collections import deque from typing import Collection import pandas as pd import numpy as np from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import SimController , Simulator from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat class CsvController ( SimController ): def __init__ ( self , timing_config : TimingConfig , csv_path : str ): super () . __init__ ( timing_config ) self . csv_path = csv_path self . _csv_data = pd . read_csv ( self . csv_path , usecols = [ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]) . to_numpy ( dtype = float ) self . _camera_bboxes = deque ( maxlen = timing_config . cycle_frame_num ) def on_sim_start ( self , sim : Simulator ): self . _camera_bboxes . clear () def on_camera_frame ( self , sim : Simulator ): self . _camera_bboxes . append ( sim . view . camera_position ) def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ]) worm_bboxes = np . full (( frame_nums . shape [ 0 ], 4 ), np . 
nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums [ valid_mask ], :] if not relative : return worm_bboxes # TODO: if relative == True then it works only if the frame number is within the last cycle. # maybe fix that. cam_bboxes = [ self . _camera_bboxes [ n % self . timing_config . cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [:, 0 ] -= cam_bboxes [:, 0 ] worm_bboxes [:, 1 ] -= cam_bboxes [:, 1 ] return worm_bboxes def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: bbox = self . predict ([ sim . frame_number - self . timing_config . pred_frame_num ]) bbox = bbox [ 0 , :] if not np . isfinite ( bbox ) . all (): return 0 , 0 center = BoxUtils . center ( bbox ) cam_center = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( center [ 0 ] - cam_center [ 0 ]) dy = round ( center [ 1 ] - cam_center [ 1 ]) return dx , dy def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : start = ( sim . cycle_number - 1 ) * self . timing_config . cycle_frame_num end = start + self . timing_config . cycle_frame_num end = min ( end , len ( self . _csv_data )) return self . predict ( np . arange ( start , end )) Classes CsvController class CsvController ( timing_config : wtracker . sim . config . TimingConfig , csv_path : str ) Abstract base class for simulator controllers. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class CsvController ( SimController ) : def __init__ ( self , timing_config : TimingConfig , csv_path : str ) : super (). __init__ ( timing_config ) self . csv_path = csv_path self . _csv_data = pd . read_csv ( self . csv_path , usecols =[ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ). to_numpy ( dtype = float ) self . _camera_bboxes = deque ( maxlen = timing_config . cycle_frame_num ) def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear () def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position ) def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if the frame number is within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : bbox = self . predict ( [ sim.frame_number - self.timing_config.pred_frame_num ] ) bbox = bbox [ 0, : ] if not np . isfinite ( bbox ). all () : return 0 , 0 center = BoxUtils . center ( bbox ) cam_center = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( center [ 0 ] - cam_center [ 0 ] ) dy = round ( center [ 1 ] - cam_center [ 1 ] ) return dx , dy def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : start = ( sim . cycle_number - 1 ) * self . timing_config . cycle_frame_num end = start + self . timing_config . cycle_frame_num end = min ( end , len ( self . _csv_data )) return self . predict ( np . arange ( start , end ))
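To make the flow above concrete, here is a hedged sketch of driving CsvController directly. The file paths are placeholders, and loading TimingConfig via load_json assumes it is a ConfigBase like LogConfig; check wtracker.sim.config for its actual loader.

```python
import numpy as np
from wtracker.sim.config import TimingConfig
from wtracker.sim.sim_controllers.csv_controller import CsvController

# Assumed inputs: a timing-config JSON and a detections CSV with columns
# wrm_x, wrm_y, wrm_w, wrm_h (one row per frame), as read by __init__ above.
timing = TimingConfig.load_json("timing_config.json")  # assumed ConfigBase loader
controller = CsvController(timing, csv_path="bboxes.csv")

# relative=False skips the camera-bbox adjustment, so it works for any frame:
bboxes = controller.predict(np.arange(100, 110), relative=False)
print(bboxes.shape)  # (10, 4): one (x, y, w, h) row per frame, NaN rows if missing
```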
Ancestors (in MRO) wtracker.sim.simulator.SimController abc.ABC Descendants wtracker.sim.sim_controllers.mlp_controllers.MLPController wtracker.sim.sim_controllers.optimal_controller.OptimalController wtracker.sim.sim_controllers.polyfit_controller.PolyfitController Methods begin_movement_prediction def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass on_camera_frame def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position ) on_cycle_end def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass on_cycle_start def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass on_imaging_end def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass on_imaging_start def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass on_micro_frame def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass on_movement_end def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass on_movement_start def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass on_sim_end def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass on_sim_start def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear () predict def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if the frame number is within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes provide_movement_vector def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) - > tuple [ int , int ] : bbox = self . predict ([ sim . frame_number - self . timing_config . pred_frame_num ]) bbox = bbox [ 0 , : ] if not np . isfinite ( bbox ) . all () : return 0 , 0 center = BoxUtils . center ( bbox ) cam_center = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( center [ 0 ] - cam_center [ 0 ]) dy = round ( center [ 1 ] - cam_center [ 1 ]) return dx , dy","title":"Csv Controller"},
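Before the per-symbol entries repeat the same source, here is the provide_movement_vector arithmetic worked through with made-up numbers. It assumes BoxUtils.center returns the (x, y) midpoint of an XYWH box, which matches how the boxes are used throughout this module.

```python
# Made-up numbers: a 251x251 camera view and a predicted worm bbox
# (x, y, w, h) given relative to that view.
bbox = (140.0, 90.0, 20.0, 10.0)
center = (bbox[0] + bbox[2] / 2,        # 150.0
          bbox[1] + bbox[3] / 2)        # 95.0
cam_center = (251 / 2, 251 / 2)         # (125.5, 125.5)
dx = round(center[0] - cam_center[0])   # round(24.5)  -> 24
dy = round(center[1] - cam_center[1])   # round(-30.5) -> -30
# The platform is then moved by (24, -30) to re-center the worm in the view.
```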
{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#module-wtrackersimsim_controllerscsv_controller","text":"View Source from collections import deque from typing import Collection import pandas as pd import numpy as np from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import SimController , Simulator from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat class CsvController ( SimController ): def __init__ ( self , timing_config : TimingConfig , csv_path : str ): super () . __init__ ( timing_config ) self . csv_path = csv_path self . _csv_data = pd . read_csv ( self . csv_path , usecols = [ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]) . to_numpy ( dtype = float ) self . _camera_bboxes = deque ( maxlen = timing_config . cycle_frame_num ) def on_sim_start ( self , sim : Simulator ): self . _camera_bboxes . clear () def on_camera_frame ( self , sim : Simulator ): self . _camera_bboxes . append ( sim . view . camera_position ) def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ]) worm_bboxes = np . full (( frame_nums . shape [ 0 ], 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums [ valid_mask ], :] if not relative : return worm_bboxes # TODO: if relative == True then it works only if the frame number is within the last cycle. # maybe fix that. cam_bboxes = [ self . _camera_bboxes [ n % self . timing_config . cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [:, 0 ] -= cam_bboxes [:, 0 ] worm_bboxes [:, 1 ] -= cam_bboxes [:, 1 ] return worm_bboxes def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: bbox = self . predict ([ sim . frame_number - self .
timing_config . pred_frame_num ]) bbox = bbox [ 0 , :] if not np . isfinite ( bbox ) . all (): return 0 , 0 center = BoxUtils . center ( bbox ) cam_center = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( center [ 0 ] - cam_center [ 0 ]) dy = round ( center [ 1 ] - cam_center [ 1 ]) return dx , dy def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : start = ( sim . cycle_number - 1 ) * self . timing_config . cycle_frame_num end = start + self . timing_config . cycle_frame_num end = min ( end , len ( self . _csv_data )) return self . predict ( np . arange ( start , end ))","title":"Module wtracker.sim.sim_controllers.csv_controller"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#csvcontroller","text":"class CsvController ( timing_config : wtracker . sim . config . TimingConfig , csv_path : str ) Abstract base class for simulator controllers.","title":"CsvController"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class CsvController ( SimController ) : def __init__ ( self , timing_config : TimingConfig , csv_path : str ) : super (). __init__ ( timing_config ) self . csv_path = csv_path self . _csv_data = pd . read_csv ( self . csv_path , usecols =[ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ). to_numpy ( dtype = float ) self . _camera_bboxes = deque ( maxlen = timing_config . cycle_frame_num ) def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear () def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position ) def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if the frame number is within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : bbox = self . predict ( [ sim.frame_number - self.timing_config.pred_frame_num ] ) bbox = bbox [ 0, : ] if not np . isfinite ( bbox ). all () : return 0 , 0 center = BoxUtils . center ( bbox ) cam_center = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( center [ 0 ] - cam_center [ 0 ] ) dy = round ( center [ 1 ] - cam_center [ 1 ] ) return dx , dy def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : start = ( sim . cycle_number - 1 ) * self . timing_config . cycle_frame_num end = start + self . timing_config . cycle_frame_num end = min ( end , len ( self . _csv_data )) return self . predict ( np .
arange ( start , end ))","title":"Attributes"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#ancestors-in-mro","text":"wtracker.sim.simulator.SimController abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#descendants","text":"wtracker.sim.sim_controllers.mlp_controllers.MLPController wtracker.sim.sim_controllers.optimal_controller.OptimalController wtracker.sim.sim_controllers.polyfit_controller.PolyfitController","title":"Descendants"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#begin_movement_prediction","text":"def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass","title":"begin_movement_prediction"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#on_camera_frame","text":"def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position )","title":"on_camera_frame"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#on_cycle_end","text":"def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass","title":"on_cycle_end"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#on_cycle_start","text":"def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass","title":"on_cycle_start"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#on_imaging_end","text":"def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass","title":"on_imaging_end"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#on_imaging_start","text":"def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass","title":"on_imaging_start"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#on_micro_frame","text":"def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass","title":"on_micro_frame"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#on_movement_end","text":"def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass","title":"on_movement_end"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#on_movement_start","text":"def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts.
View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass","title":"on_movement_start"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#on_sim_end","text":"def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass","title":"on_sim_end"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#on_sim_start","text":"def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear ()","title":"on_sim_start"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#predict","text":"def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if the frame number is within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes","title":"predict"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#provide_movement_vector","text":"def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) - > tuple [ int , int ] : bbox = self . predict ([ sim . frame_number - self . timing_config . pred_frame_num ]) bbox = bbox [ 0 , : ] if not np . isfinite ( bbox ) . all () : return 0 , 0 center = BoxUtils . center ( bbox ) cam_center = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( center [ 0 ] - cam_center [ 0 ]) dy = round ( center [ 1 ] - cam_center [ 1 ]) return dx , dy","title":"provide_movement_vector"},
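The logging controller documented next wraps any SimController and records its behavior. As a hedged sketch of the intended wiring (the paths are placeholders, and `controller` stands for any controller, such as the CsvController from the earlier sketch):

```python
from wtracker.sim.sim_controllers.logging_controller import LogConfig, LoggingController

log_config = LogConfig(
    root_folder="logs/run_001",  # placeholder output directory
    save_cam_view=False,         # per-frame camera crops (off by default)
    save_err_view=True,          # camera view of frames with no prediction
)
logged = LoggingController(sim_controller=controller, log_config=log_config)
# Every SimController hook is forwarded to the wrapped controller, while the
# per-frame bounding boxes are written to bboxes.csv under root_folder and any
# requested views are saved into their respective subfolders.
```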
{"location":"reference/wtracker/sim/sim_controllers/logging_controller/","text":"Module wtracker.sim.sim_controllers.logging_controller View Source from collections import deque import numpy as np from dataclasses import dataclass , field from copy import deepcopy from wtracker.sim.simulator import Simulator , SimController from wtracker.utils.io_utils import ImageSaver , FrameSaver from wtracker.utils.log_utils import CSVLogger from wtracker.utils.config_base import ConfigBase from wtracker.utils.path_utils import join_paths , create_parent_directory from wtracker.utils.bbox_utils import BoxUtils , BoxFormat @dataclass class LogConfig ( ConfigBase ): root_folder : str \"\"\"The directory where the logs will be saved into.\"\"\" save_mic_view : bool = False \"\"\"Whether to save the microscope view of each frame.\"\"\" save_cam_view : bool = False \"\"\"Whether to save the camera view of each frame.\"\"\" save_err_view : bool = True \"\"\"Whether to save the camera view of frames in which no prediction was made.\"\"\" save_wrm_view : bool = False \"\"\"Whether to save the detected worm head of each frame.\"\"\" mic_folder_name : str = \"micro\" cam_folder_name : str = \"camera\" err_folder_name : str = \"errors\" wrm_folder_name : str = \"worms\" # TODO: WHY DO WE SAVE IN PNG FORMAT AND NOT BMP? bbox_file_name : str = \"bboxes.csv\" mic_file_name : str = \"mic_ {:09d} .png\" cam_file_name : str = \"cam_ {:09d} .png\" wrm_file_name : str = \"wrm_ {:09d} .png\" mic_file_path : str = field ( init = False ) cam_file_path : str = field ( init = False ) err_file_path : str = field ( init = False ) wrm_file_path : str = field ( init = False ) bbox_file_path : str = field ( init = False ) def __post_init__ ( self ): self . mic_file_path = join_paths ( self . root_folder , self . mic_folder_name , self . mic_file_name ) self . cam_file_path = join_paths ( self . root_folder , self . cam_folder_name , self . cam_file_name ) self . err_file_path = join_paths ( self . root_folder , self . err_folder_name , self . cam_file_name ) self . wrm_file_path = join_paths ( self . root_folder , self . wrm_folder_name , self . wrm_file_name ) self . bbox_file_path = join_paths ( self . root_folder , self . bbox_file_name ) def create_dirs ( self ) -> None : create_parent_directory ( self . bbox_file_path ) create_parent_directory ( self . mic_file_path ) create_parent_directory ( self . cam_file_path ) create_parent_directory ( self . err_file_path ) create_parent_directory ( self . wrm_file_path ) class LoggingController ( SimController ): def __init__ ( self , sim_controller : SimController , log_config : LogConfig , ): super () . __init__ ( sim_controller . timing_config ) self . sim_controller = sim_controller self . log_config = log_config self . _camera_frames = deque ( maxlen = self . timing_config . cycle_frame_num ) self . _platform_positions = deque ( maxlen = self . timing_config . cycle_frame_num ) self . _camera_bboxes = deque ( maxlen = self . timing_config . cycle_frame_num ) self . _micro_bboxes = deque ( maxlen = self . timing_config . cycle_frame_num ) def on_sim_start ( self , sim : Simulator ): self . sim_controller . on_sim_start ( sim ) self . _camera_frames . clear () self . _platform_positions . clear () self . _camera_bboxes . clear () self . _micro_bboxes . clear () self . log_config . create_dirs () self .
_image_saver = ImageSaver ( tqdm = True ) self . _image_saver . start () self . _frame_saver = FrameSaver ( deepcopy ( sim . view . _frame_reader ), tqdm = True ) self . _frame_saver . start () self . _bbox_logger = CSVLogger ( self . log_config . bbox_file_path , col_names = [ \"frame\" , \"cycle\" , \"phase\" , \"plt_x\" , \"plt_y\" , \"cam_x\" , \"cam_y\" , \"cam_w\" , \"cam_h\" , \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" , \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" , ], ) def on_cycle_start ( self , sim : Simulator ): self . sim_controller . on_cycle_start ( sim ) def on_camera_frame ( self , sim : Simulator ): self . sim_controller . on_camera_frame ( sim ) # log everything self . _platform_positions . append ( sim . position ) self . _camera_bboxes . append ( sim . view . camera_position ) self . _micro_bboxes . append ( sim . view . micro_position ) if self . log_config . save_err_view : cam_view = sim . camera_view () self . _camera_frames . append ( cam_view ) if self . log_config . save_cam_view : # save camera view cam_view = sim . camera_view () path = self . log_config . cam_file_path . format ( sim . frame_number ) self . _image_saver . schedule_save ( cam_view , path ) if self . log_config . save_mic_view : # save micro view mic_view = sim . view . micro_view () path = self . log_config . mic_file_path . format ( sim . frame_number ) self . _image_saver . schedule_save ( mic_view , path ) def _log_cycle ( self , sim : Simulator ): cycle_number = sim . cycle_number - 1 frame_offset = cycle_number * self . timing_config . cycle_frame_num worm_bboxes = self . sim_controller . _cycle_predict_all ( sim ) cam_bboxes = np . asanyarray ( list ( self . _camera_bboxes )) # make worm bboxes coordinate absolute worm_bboxes [:, 0 ] += cam_bboxes [:, 0 ] worm_bboxes [:, 1 ] += cam_bboxes [:, 1 ] # calc the crop dims to get the worm view from the original frame ( H , W ) = sim . experiment_config . orig_resolution crop_dims , is_crop_legal = BoxUtils . discretize ( worm_bboxes , ( H , W ), BoxFormat . XYWH ) for i , worm_bbox in enumerate ( worm_bboxes ): frame_number = frame_offset + i # if no prediction and we're saving error frames if not np . isfinite ( worm_bbox ) . all () and self . log_config . save_err_view : err_view = self . _camera_frames [ i ] path = self . log_config . err_file_path . format ( frame_number ) self . _image_saver . schedule_save ( img = err_view , img_name = path ) # save cropped worm view if crop is legal if self . log_config . save_wrm_view and is_crop_legal [ i ]: crop_dim = crop_dims [ i ] path = self . log_config . wrm_file_path . format ( frame_number ) self . _frame_saver . schedule_save ( img_index = frame_number , crop_dims = crop_dim , img_name = path ) csv_row = {} csv_row [ \"plt_x\" ], csv_row [ \"plt_y\" ] = self . _platform_positions [ i ] csv_row [ \"cam_x\" ], csv_row [ \"cam_y\" ], csv_row [ \"cam_w\" ], csv_row [ \"cam_h\" ] = self . _camera_bboxes [ i ] csv_row [ \"mic_x\" ], csv_row [ \"mic_y\" ], csv_row [ \"mic_w\" ], csv_row [ \"mic_h\" ] = self . _micro_bboxes [ i ] csv_row [ \"cycle\" ] = cycle_number csv_row [ \"frame\" ] = frame_number csv_row [ \"phase\" ] = \"imaging\" if i < self . timing_config . imaging_frame_num else \"moving\" csv_row [ \"wrm_x\" ], csv_row [ \"wrm_y\" ], csv_row [ \"wrm_w\" ], csv_row [ \"wrm_h\" ] = worm_bbox self . _bbox_logger . write ( csv_row ) self . _bbox_logger . flush () def on_cycle_end ( self , sim : Simulator ): self . _log_cycle ( sim ) self . sim_controller . on_cycle_end ( sim ) self . 
_camera_frames . clear () self . _platform_positions . clear () self . _camera_bboxes . clear () self . _micro_bboxes . clear () def on_sim_end ( self , sim : Simulator ): self . sim_controller . on_sim_end ( sim ) self . _image_saver . close () self . _frame_saver . close () self . _bbox_logger . close () def on_imaging_start ( self , sim : Simulator ): self . sim_controller . on_imaging_start ( sim ) def on_micro_frame ( self , sim : Simulator ): self . sim_controller . on_micro_frame ( sim ) def on_imaging_end ( self , sim : Simulator ): self . sim_controller . on_imaging_end ( sim ) def on_movement_start ( self , sim : Simulator ): self . sim_controller . on_movement_start ( sim ) def on_movement_end ( self , sim : Simulator ): self . sim_controller . on_movement_end ( sim ) def begin_movement_prediction ( self , sim : Simulator ) -> None : return self . sim_controller . begin_movement_prediction ( sim ) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: return self . sim_controller . provide_movement_vector ( sim ) def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : return self . sim_controller . _cycle_predict_all ( sim ) Classes LogConfig class LogConfig ( root_folder : str , save_mic_view : bool = False , save_cam_view : bool = False , save_err_view : bool = True , save_wrm_view : bool = False , mic_folder_name : str = 'micro' , cam_folder_name : str = 'camera' , err_folder_name : str = 'errors' , wrm_folder_name : str = 'worms' , bbox_file_name : str = 'bboxes.csv' , mic_file_name : str = 'mic_ {:09d} .png' , cam_file_name : str = 'cam_ {:09d} .png' , wrm_file_name : str = 'wrm_ {:09d} .png' ) LogConfig(root_folder: str, save_mic_view: bool = False, save_cam_view: bool = False, save_err_view: bool = True, save_wrm_view: bool = False, mic_folder_name: str = 'micro', cam_folder_name: str = 'camera', err_folder_name: str = 'errors', wrm_folder_name: str = 'worms', bbox_file_name: str = 'bboxes.csv', mic_file_name: str = 'mic_{:09d}.png', cam_file_name: str = 'cam_{:09d}.png', wrm_file_name: str = 'wrm_{:09d}.png') View Source @dataclass class LogConfig ( ConfigBase ) : root_folder : str \"\"\"The directory where the logs will be saved into.\"\"\" save_mic_view : bool = False \"\"\"Whether to save the microscope view of each frame.\"\"\" save_cam_view : bool = False \"\"\"Whether to save the camera view of each frame.\"\"\" save_err_view : bool = True \"\"\"Whether to save the camera view of frames in which no prediction was made.\"\"\" save_wrm_view : bool = False \"\"\"Whether to save the detected worm head of each frame.\"\"\" mic_folder_name : str = \"micro\" cam_folder_name : str = \"camera\" err_folder_name : str = \"errors\" wrm_folder_name : str = \"worms\" # TODO : WHY DO WE SAVE IN PNG FORMAT AND NOT BMP ? bbox_file_name : str = \"bboxes.csv\" mic_file_name : str = \"mic_{:09d}.png\" cam_file_name : str = \"cam_{:09d}.png\" wrm_file_name : str = \"wrm_{:09d}.png\" mic_file_path : str = field ( init = False ) cam_file_path : str = field ( init = False ) err_file_path : str = field ( init = False ) wrm_file_path : str = field ( init = False ) bbox_file_path : str = field ( init = False ) def __post_init__ ( self ) : self . mic_file_path = join_paths ( self . root_folder , self . mic_folder_name , self . mic_file_name ) self . cam_file_path = join_paths ( self . root_folder , self . cam_folder_name , self . cam_file_name ) self . err_file_path = join_paths ( self . root_folder , self . err_folder_name , self . cam_file_name ) self .
wrm_file_path = join_paths ( self . root_folder , self . wrm_folder_name , self . wrm_file_name ) self . bbox_file_path = join_paths ( self . root_folder , self . bbox_file_name ) def create_dirs ( self ) -> None : create_parent_directory ( self . bbox_file_path ) create_parent_directory ( self . mic_file_path ) create_parent_directory ( self . cam_file_path ) create_parent_directory ( self . err_file_path ) create_parent_directory ( self . wrm_file_path ) Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Class variables bbox_file_name cam_file_name cam_folder_name err_folder_name mic_file_name mic_folder_name save_cam_view save_err_view save_mic_view save_wrm_view wrm_file_name wrm_folder_name Static methods load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods create_dirs def create_dirs ( self ) -> None View Source def create_dirs ( self ) -> None : create_parent_directory ( self . bbox_file_path ) create_parent_directory ( self . mic_file_path ) create_parent_directory ( self . cam_file_path ) create_parent_directory ( self . err_file_path ) create_parent_directory ( self . wrm_file_path ) save_json def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . 
save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) LoggingController class LoggingController ( sim_controller : wtracker . sim . simulator . SimController , log_config : wtracker . sim . sim_controllers . logging_controller . LogConfig ) Abstract base class for simulator controllers. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class LoggingController ( SimController ) : def __init__ ( self , sim_controller : SimController , log_config : LogConfig , ) : super (). __init__ ( sim_controller . timing_config ) self . sim_controller = sim_controller self . log_config = log_config self . _camera_frames = deque ( maxlen = self . timing_config . cycle_frame_num ) self . _platform_positions = deque ( maxlen = self . timing_config . cycle_frame_num ) self . _camera_bboxes = deque ( maxlen = self . timing_config . cycle_frame_num ) self . _micro_bboxes = deque ( maxlen = self . timing_config . cycle_frame_num ) def on_sim_start ( self , sim : Simulator ) : self . sim_controller . on_sim_start ( sim ) self . _camera_frames . clear () self . _platform_positions . clear () self . _camera_bboxes . clear () self . _micro_bboxes . clear () self . log_config . create_dirs () self . _image_saver = ImageSaver ( tqdm = True ) self . _image_saver . start () self . _frame_saver = FrameSaver ( deepcopy ( sim . view . _frame_reader ), tqdm = True ) self . _frame_saver . start () self . _bbox_logger = CSVLogger ( self . log_config . bbox_file_path , col_names =[ \"frame\", \"cycle\", \"phase\", \"plt_x\", \"plt_y\", \"cam_x\", \"cam_y\", \"cam_w\", \"cam_h\", \"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\", \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\", ] , ) def on_cycle_start ( self , sim : Simulator ) : self . sim_controller . on_cycle_start ( sim ) def on_camera_frame ( self , sim : Simulator ) : self . sim_controller . on_camera_frame ( sim ) # log everything self . _platform_positions . append ( sim . position ) self . _camera_bboxes . append ( sim . view . camera_position ) self . _micro_bboxes . append ( sim . view . micro_position ) if self . log_config . save_err_view : cam_view = sim . camera_view () self . _camera_frames . append ( cam_view ) if self . log_config . save_cam_view : # save camera view cam_view = sim . camera_view () path = self . log_config . cam_file_path . format ( sim . frame_number ) self . _image_saver . schedule_save ( cam_view , path ) if self . log_config . save_mic_view : # save micro view mic_view = sim . view . micro_view () path = self . log_config . mic_file_path . format ( sim . frame_number ) self . _image_saver . schedule_save ( mic_view , path ) def _log_cycle ( self , sim : Simulator ) : cycle_number = sim . cycle_number - 1 frame_offset = cycle_number * self . timing_config . cycle_frame_num worm_bboxes = self . sim_controller . _cycle_predict_all ( sim ) cam_bboxes = np . asanyarray ( list ( self . _camera_bboxes )) # make worm bboxes coordinate absolute worm_bboxes [ :, 0 ] += cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] += cam_bboxes [ :, 1 ] # calc the crop dims to get the worm view from the original frame ( H , W ) = sim . experiment_config . orig_resolution crop_dims , is_crop_legal = BoxUtils . discretize ( worm_bboxes , ( H , W ), BoxFormat . 
XYWH ) for i , worm_bbox in enumerate ( worm_bboxes ) : frame_number = frame_offset + i # if no prediction and we ' re saving error frames if not np . isfinite ( worm_bbox ). all () and self . log_config . save_err_view : err_view = self . _camera_frames [ i ] path = self . log_config . err_file_path . format ( frame_number ) self . _image_saver . schedule_save ( img = err_view , img_name = path ) # save cropped worm view if crop is legal if self . log_config . save_wrm_view and is_crop_legal [ i ] : crop_dim = crop_dims [ i ] path = self . log_config . wrm_file_path . format ( frame_number ) self . _frame_saver . schedule_save ( img_index = frame_number , crop_dims = crop_dim , img_name = path ) csv_row = {} csv_row [ \"plt_x\" ] , csv_row [ \"plt_y\" ] = self . _platform_positions [ i ] csv_row [ \"cam_x\" ] , csv_row [ \"cam_y\" ] , csv_row [ \"cam_w\" ] , csv_row [ \"cam_h\" ] = self . _camera_bboxes [ i ] csv_row [ \"mic_x\" ] , csv_row [ \"mic_y\" ] , csv_row [ \"mic_w\" ] , csv_row [ \"mic_h\" ] = self . _micro_bboxes [ i ] csv_row [ \"cycle\" ] = cycle_number csv_row [ \"frame\" ] = frame_number csv_row [ \"phase\" ] = \"imaging\" if i < self . timing_config . imaging_frame_num else \"moving\" csv_row [ \"wrm_x\" ] , csv_row [ \"wrm_y\" ] , csv_row [ \"wrm_w\" ] , csv_row [ \"wrm_h\" ] = worm_bbox self . _bbox_logger . write ( csv_row ) self . _bbox_logger . flush () def on_cycle_end ( self , sim : Simulator ) : self . _log_cycle ( sim ) self . sim_controller . on_cycle_end ( sim ) self . _camera_frames . clear () self . _platform_positions . clear () self . _camera_bboxes . clear () self . _micro_bboxes . clear () def on_sim_end ( self , sim : Simulator ) : self . sim_controller . on_sim_end ( sim ) self . _image_saver . close () self . _frame_saver . close () self . _bbox_logger . close () def on_imaging_start ( self , sim : Simulator ) : self . sim_controller . on_imaging_start ( sim ) def on_micro_frame ( self , sim : Simulator ) : self . sim_controller . on_micro_frame ( sim ) def on_imaging_end ( self , sim : Simulator ) : self . sim_controller . on_imaging_end ( sim ) def on_movement_start ( self , sim : Simulator ) : self . sim_controller . on_movement_start ( sim ) def on_movement_end ( self , sim : Simulator ) : self . sim_controller . on_movement_end ( sim ) def begin_movement_prediction ( self , sim : Simulator ) -> None : return self . sim_controller . begin_movement_prediction ( sim ) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : return self . sim_controller . provide_movement_vector ( sim ) def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : return self . sim_controller . _cycle_predict_all ( sim ) Ancestors (in MRO) wtracker.sim.simulator.SimController abc.ABC Methods begin_movement_prediction def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : return self . sim_controller . begin_movement_prediction ( sim ) on_camera_frame def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . sim_controller . on_camera_frame ( sim ) # log everything self . _platform_positions . append ( sim . position ) self . _camera_bboxes . append ( sim . view . camera_position ) self . _micro_bboxes . append ( sim . view . 
micro_position ) if self . log_config . save_err_view : cam_view = sim . camera_view () self . _camera_frames . append ( cam_view ) if self . log_config . save_cam_view : # save camera view cam_view = sim . camera_view () path = self . log_config . cam_file_path . format ( sim . frame_number ) self . _image_saver . schedule_save ( cam_view , path ) if self . log_config . save_mic_view : # save micro view mic_view = sim . view . micro_view () path = self . log_config . mic_file_path . format ( sim . frame_number ) self . _image_saver . schedule_save ( mic_view , path ) on_cycle_end def on_cycle_end ( self , sim : wtracker . sim . simulator . Simulator ) Called when a cycle ends. View Source def on_cycle_end ( self , sim : Simulator ) : self . _log_cycle ( sim ) self . sim_controller . on_cycle_end ( sim ) self . _camera_frames . clear () self . _platform_positions . clear () self . _camera_bboxes . clear () self . _micro_bboxes . clear () on_cycle_start def on_cycle_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): self.sim_controller.on_cycle_start(sim) on_imaging_end def on_imaging_end ( self , sim : wtracker . sim . simulator . Simulator ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): self.sim_controller.on_imaging_end(sim) on_imaging_start def on_imaging_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): self.sim_controller.on_imaging_start(sim) on_micro_frame def on_micro_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): self.sim_controller.on_micro_frame(sim) on_movement_end def on_movement_end ( self , sim : wtracker . sim . simulator . Simulator ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): self.sim_controller.on_movement_end(sim) on_movement_start def on_movement_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): self.sim_controller.on_movement_start(sim) on_sim_end def on_sim_end ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): self.sim_controller.on_sim_end(sim) self._image_saver.close() self._frame_saver.close() self._bbox_logger.close() on_sim_start def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . sim_controller . on_sim_start ( sim ) self . _camera_frames . clear () self . _platform_positions . clear () self . _camera_bboxes . clear () self . _micro_bboxes . clear () self . log_config . create_dirs () self . _image_saver = ImageSaver ( tqdm = True ) self . _image_saver . start () self . _frame_saver = FrameSaver ( deepcopy ( sim . view . _frame_reader ), tqdm = True ) self . _frame_saver . start () self . _bbox_logger = CSVLogger ( self . log_config .
bbox_file_path , col_names = [ \"frame\" , \"cycle\" , \"phase\" , \"plt_x\" , \"plt_y\" , \"cam_x\" , \"cam_y\" , \"cam_w\" , \"cam_h\" , \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" , \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" , ], ) provide_movement_vector def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ] : return self . sim_controller . provide_movement_vector ( sim )","title":"Logging Controller"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#module-wtrackersimsim_controllerslogging_controller","text":"View Source from collections import deque import numpy as np from dataclasses import dataclass , field from copy import deepcopy from wtracker.sim.simulator import Simulator , SimController from wtracker.utils.io_utils import ImageSaver , FrameSaver from wtracker.utils.log_utils import CSVLogger from wtracker.utils.config_base import ConfigBase from wtracker.utils.path_utils import join_paths , create_parent_directory from wtracker.utils.bbox_utils import BoxUtils , BoxFormat @dataclass class LogConfig ( ConfigBase ): root_folder : str \"\"\"The directory where the logs will be saved into.\"\"\" save_mic_view : bool = False \"\"\"Whether to save the microscope view of each frame.\"\"\" save_cam_view : bool = False \"\"\"Whether to save the camera view of each frame.\"\"\" save_err_view : bool = True \"\"\"Whether to save the camera view of frames in which no prediction was made.\"\"\" save_wrm_view : bool = False \"\"\"Whether to save the detected worm head of each frame.\"\"\" mic_folder_name : str = \"micro\" cam_folder_name : str = \"camera\" err_folder_name : str = \"errors\" wrm_folder_name : str = \"worms\" # TODO: WHY DO WE SAVE IN PNG FORMAT AND NOT BMP? bbox_file_name : str = \"bboxes.csv\" mic_file_name : str = \"mic_ {:09d} .png\" cam_file_name : str = \"cam_ {:09d} .png\" wrm_file_name : str = \"wrm_ {:09d} .png\" mic_file_path : str = field ( init = False ) cam_file_path : str = field ( init = False ) err_file_path : str = field ( init = False ) wrm_file_path : str = field ( init = False ) bbox_file_path : str = field ( init = False ) def __post_init__ ( self ): self . mic_file_path = join_paths ( self . root_folder , self . mic_folder_name , self . mic_file_name ) self . cam_file_path = join_paths ( self . root_folder , self . cam_folder_name , self . cam_file_name ) self . err_file_path = join_paths ( self . root_folder , self . err_folder_name , self . cam_file_name ) self . wrm_file_path = join_paths ( self . root_folder , self . wrm_folder_name , self . wrm_file_name ) self . bbox_file_path = join_paths ( self . root_folder , self . bbox_file_name ) def create_dirs ( self ) -> None : create_parent_directory ( self . bbox_file_path ) create_parent_directory ( self . mic_file_path ) create_parent_directory ( self . cam_file_path ) create_parent_directory ( self . err_file_path ) create_parent_directory ( self . wrm_file_path ) class LoggingController ( SimController ): def __init__ ( self , sim_controller : SimController , log_config : LogConfig , ): super () . __init__ ( sim_controller . timing_config ) self .
sim_controller = sim_controller self . log_config = log_config self . _camera_frames = deque ( maxlen = self . timing_config . cycle_frame_num ) self . _platform_positions = deque ( maxlen = self . timing_config . cycle_frame_num ) self . _camera_bboxes = deque ( maxlen = self . timing_config . cycle_frame_num ) self . _micro_bboxes = deque ( maxlen = self . timing_config . cycle_frame_num ) def on_sim_start ( self , sim : Simulator ): self . sim_controller . on_sim_start ( sim ) self . _camera_frames . clear () self . _platform_positions . clear () self . _camera_bboxes . clear () self . _micro_bboxes . clear () self . log_config . create_dirs () self . _image_saver = ImageSaver ( tqdm = True ) self . _image_saver . start () self . _frame_saver = FrameSaver ( deepcopy ( sim . view . _frame_reader ), tqdm = True ) self . _frame_saver . start () self . _bbox_logger = CSVLogger ( self . log_config . bbox_file_path , col_names = [ \"frame\" , \"cycle\" , \"phase\" , \"plt_x\" , \"plt_y\" , \"cam_x\" , \"cam_y\" , \"cam_w\" , \"cam_h\" , \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" , \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" , ], ) def on_cycle_start ( self , sim : Simulator ): self . sim_controller . on_cycle_start ( sim ) def on_camera_frame ( self , sim : Simulator ): self . sim_controller . on_camera_frame ( sim ) # log everything self . _platform_positions . append ( sim . position ) self . _camera_bboxes . append ( sim . view . camera_position ) self . _micro_bboxes . append ( sim . view . micro_position ) if self . log_config . save_err_view : cam_view = sim . camera_view () self . _camera_frames . append ( cam_view ) if self . log_config . save_cam_view : # save camera view cam_view = sim . camera_view () path = self . log_config . cam_file_path . format ( sim . frame_number ) self . _image_saver . schedule_save ( cam_view , path ) if self . log_config . save_mic_view : # save micro view mic_view = sim . view . micro_view () path = self . log_config . mic_file_path . format ( sim . frame_number ) self . _image_saver . schedule_save ( mic_view , path ) def _log_cycle ( self , sim : Simulator ): cycle_number = sim . cycle_number - 1 frame_offset = cycle_number * self . timing_config . cycle_frame_num worm_bboxes = self . sim_controller . _cycle_predict_all ( sim ) cam_bboxes = np . asanyarray ( list ( self . _camera_bboxes )) # make worm bboxes coordinate absolute worm_bboxes [:, 0 ] += cam_bboxes [:, 0 ] worm_bboxes [:, 1 ] += cam_bboxes [:, 1 ] # calc the crop dims to get the worm view from the original frame ( H , W ) = sim . experiment_config . orig_resolution crop_dims , is_crop_legal = BoxUtils . discretize ( worm_bboxes , ( H , W ), BoxFormat . XYWH ) for i , worm_bbox in enumerate ( worm_bboxes ): frame_number = frame_offset + i # if no prediction and we're saving error frames if not np . isfinite ( worm_bbox ) . all () and self . log_config . save_err_view : err_view = self . _camera_frames [ i ] path = self . log_config . err_file_path . format ( frame_number ) self . _image_saver . schedule_save ( img = err_view , img_name = path ) # save cropped worm view if crop is legal if self . log_config . save_wrm_view and is_crop_legal [ i ]: crop_dim = crop_dims [ i ] path = self . log_config . wrm_file_path . format ( frame_number ) self . _frame_saver . schedule_save ( img_index = frame_number , crop_dims = crop_dim , img_name = path ) csv_row = {} csv_row [ \"plt_x\" ], csv_row [ \"plt_y\" ] = self . 
_platform_positions [ i ] csv_row [ \"cam_x\" ], csv_row [ \"cam_y\" ], csv_row [ \"cam_w\" ], csv_row [ \"cam_h\" ] = self . _camera_bboxes [ i ] csv_row [ \"mic_x\" ], csv_row [ \"mic_y\" ], csv_row [ \"mic_w\" ], csv_row [ \"mic_h\" ] = self . _micro_bboxes [ i ] csv_row [ \"cycle\" ] = cycle_number csv_row [ \"frame\" ] = frame_number csv_row [ \"phase\" ] = \"imaging\" if i < self . timing_config . imaging_frame_num else \"moving\" csv_row [ \"wrm_x\" ], csv_row [ \"wrm_y\" ], csv_row [ \"wrm_w\" ], csv_row [ \"wrm_h\" ] = worm_bbox self . _bbox_logger . write ( csv_row ) self . _bbox_logger . flush () def on_cycle_end ( self , sim : Simulator ): self . _log_cycle ( sim ) self . sim_controller . on_cycle_end ( sim ) self . _camera_frames . clear () self . _platform_positions . clear () self . _camera_bboxes . clear () self . _micro_bboxes . clear () def on_sim_end ( self , sim : Simulator ): self . sim_controller . on_sim_end ( sim ) self . _image_saver . close () self . _frame_saver . close () self . _bbox_logger . close () def on_imaging_start ( self , sim : Simulator ): self . sim_controller . on_imaging_start ( sim ) def on_micro_frame ( self , sim : Simulator ): self . sim_controller . on_micro_frame ( sim ) def on_imaging_end ( self , sim : Simulator ): self . sim_controller . on_imaging_end ( sim ) def on_movement_start ( self , sim : Simulator ): self . sim_controller . on_movement_start ( sim ) def on_movement_end ( self , sim : Simulator ): self . sim_controller . on_movement_end ( sim ) def begin_movement_prediction ( self , sim : Simulator ) -> None : return self . sim_controller . begin_movement_prediction ( sim ) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: return self . sim_controller . provide_movement_vector ( sim ) def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : return self . sim_controller . 
_cycle_predict_all ( sim )","title":"Module wtracker.sim.sim_controllers.logging_controller"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#logconfig","text":"class LogConfig ( root_folder : str , save_mic_view : bool = False , save_cam_view : bool = False , save_err_view : bool = True , save_wrm_view : bool = False , mic_folder_name : str = 'micro' , cam_folder_name : str = 'camera' , err_folder_name : str = 'errors' , wrm_folder_name : str = 'worms' , bbox_file_name : str = 'bboxes.csv' , mic_file_name : str = 'mic_ {:09d} .png' , cam_file_name : str = 'cam_ {:09d} .png' , wrm_file_name : str = 'wrm_ {:09d} .png' ) LogConfig(root_folder: str, save_mic_view: bool = False, save_cam_view: bool = False, save_err_view: bool = True, save_wrm_view: bool = False, mic_folder_name: str = 'micro', cam_folder_name: str = 'camera', err_folder_name: str = 'errors', wrm_folder_name: str = 'worms', bbox_file_name: str = 'bboxes.csv', mic_file_name: str = 'mic_{:09d}.png', cam_file_name: str = 'cam_{:09d}.png', wrm_file_name: str = 'wrm_{:09d}.png') View Source @dataclass class LogConfig ( ConfigBase ) : root_folder : str \"\"\"The directory where the logs will be saved into.\"\"\" save_mic_view : bool = False \"\"\"Whether to save the microscope view of each frame.\"\"\" save_cam_view : bool = False \"\"\"Whether to save the camera view of each frame.\"\"\" save_err_view : bool = True \"\"\"Whether to save the camera view of frames in which no prediction was made.\"\"\" save_wrm_view : bool = False \"\"\"Whether to save the detected worm head of each frame.\"\"\" mic_folder_name : str = \"micro\" cam_folder_name : str = \"camera\" err_folder_name : str = \"errors\" wrm_folder_name : str = \"worms\" # TODO : WHY DO WE SAVE IN PNG FORMAT AND NOT BMP ? bbox_file_name : str = \"bboxes.csv\" mic_file_name : str = \"mic_{:09d}.png\" cam_file_name : str = \"cam_{:09d}.png\" wrm_file_name : str = \"wrm_{:09d}.png\" mic_file_path : str = field ( init = False ) cam_file_path : str = field ( init = False ) err_file_path : str = field ( init = False ) wrm_file_path : str = field ( init = False ) bbox_file_path : str = field ( init = False ) def __post_init__ ( self ) : self . mic_file_path = join_paths ( self . root_folder , self . mic_folder_name , self . mic_file_name ) self . cam_file_path = join_paths ( self . root_folder , self . cam_folder_name , self . cam_file_name ) self . err_file_path = join_paths ( self . root_folder , self . err_folder_name , self . cam_file_name ) self . wrm_file_path = join_paths ( self . root_folder , self . wrm_folder_name , self . wrm_file_name ) self . bbox_file_path = join_paths ( self . root_folder , self . bbox_file_name ) def create_dirs ( self ) -> None : create_parent_directory ( self . bbox_file_path ) create_parent_directory ( self . mic_file_path ) create_parent_directory ( self . cam_file_path ) create_parent_directory ( self . err_file_path ) create_parent_directory ( self .
wrm_file_path )","title":"LogConfig"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#ancestors-in-mro","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#class-variables","text":"bbox_file_name cam_file_name cam_folder_name err_folder_name mic_file_name mic_folder_name save_cam_view save_err_view save_mic_view save_wrm_view wrm_file_name wrm_folder_name","title":"Class variables"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#load_json","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#load_pickle","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#create_dirs","text":"def create_dirs ( self ) -> None View Source def create_dirs ( self ) -> None : create_parent_directory ( self . bbox_file_path ) create_parent_directory ( self . mic_file_path ) create_parent_directory ( self . cam_file_path ) create_parent_directory ( self . err_file_path ) create_parent_directory ( self . wrm_file_path )","title":"create_dirs"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#save_json","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . 
__dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#save_pickle","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#loggingcontroller","text":"class LoggingController ( sim_controller : wtracker . sim . simulator . SimController , log_config : wtracker . sim . sim_controllers . logging_controller . LogConfig ) Abstract base class for simulator controllers.","title":"LoggingController"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class LoggingController ( SimController ) : def __init__ ( self , sim_controller : SimController , log_config : LogConfig , ) : super (). __init__ ( sim_controller . timing_config ) self . sim_controller = sim_controller self . log_config = log_config self . _camera_frames = deque ( maxlen = self . timing_config . cycle_frame_num ) self . _platform_positions = deque ( maxlen = self . timing_config . cycle_frame_num ) self . _camera_bboxes = deque ( maxlen = self . timing_config . cycle_frame_num ) self . _micro_bboxes = deque ( maxlen = self . timing_config . cycle_frame_num ) def on_sim_start ( self , sim : Simulator ) : self . sim_controller . on_sim_start ( sim ) self . _camera_frames . clear () self . _platform_positions . clear () self . _camera_bboxes . clear () self . _micro_bboxes . clear () self . log_config . create_dirs () self . _image_saver = ImageSaver ( tqdm = True ) self . _image_saver . start () self . _frame_saver = FrameSaver ( deepcopy ( sim . view . _frame_reader ), tqdm = True ) self . _frame_saver . start () self . _bbox_logger = CSVLogger ( self . log_config . bbox_file_path , col_names =[ \"frame\", \"cycle\", \"phase\", \"plt_x\", \"plt_y\", \"cam_x\", \"cam_y\", \"cam_w\", \"cam_h\", \"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\", \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\", ] , ) def on_cycle_start ( self , sim : Simulator ) : self . sim_controller . on_cycle_start ( sim ) def on_camera_frame ( self , sim : Simulator ) : self . sim_controller . on_camera_frame ( sim ) # log everything self . _platform_positions . append ( sim . position ) self . _camera_bboxes . append ( sim . view . camera_position ) self . _micro_bboxes . append ( sim . view . micro_position ) if self . log_config . save_err_view : cam_view = sim . camera_view () self . _camera_frames . append ( cam_view ) if self . log_config . save_cam_view : # save camera view cam_view = sim . camera_view () path = self . log_config . cam_file_path . format ( sim . frame_number ) self . _image_saver . schedule_save ( cam_view , path ) if self . log_config . save_mic_view : # save micro view mic_view = sim . view . micro_view () path = self . log_config . mic_file_path . format ( sim . frame_number ) self . _image_saver . 
schedule_save ( mic_view , path ) def _log_cycle ( self , sim : Simulator ) : cycle_number = sim . cycle_number - 1 frame_offset = cycle_number * self . timing_config . cycle_frame_num worm_bboxes = self . sim_controller . _cycle_predict_all ( sim ) cam_bboxes = np . asanyarray ( list ( self . _camera_bboxes )) # make worm bboxes coordinate absolute worm_bboxes [ :, 0 ] += cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] += cam_bboxes [ :, 1 ] # calc the crop dims to get the worm view from the original frame ( H , W ) = sim . experiment_config . orig_resolution crop_dims , is_crop_legal = BoxUtils . discretize ( worm_bboxes , ( H , W ), BoxFormat . XYWH ) for i , worm_bbox in enumerate ( worm_bboxes ) : frame_number = frame_offset + i # if no prediction and we ' re saving error frames if not np . isfinite ( worm_bbox ). all () and self . log_config . save_err_view : err_view = self . _camera_frames [ i ] path = self . log_config . err_file_path . format ( frame_number ) self . _image_saver . schedule_save ( img = err_view , img_name = path ) # save cropped worm view if crop is legal if self . log_config . save_wrm_view and is_crop_legal [ i ] : crop_dim = crop_dims [ i ] path = self . log_config . wrm_file_path . format ( frame_number ) self . _frame_saver . schedule_save ( img_index = frame_number , crop_dims = crop_dim , img_name = path ) csv_row = {} csv_row [ \"plt_x\" ] , csv_row [ \"plt_y\" ] = self . _platform_positions [ i ] csv_row [ \"cam_x\" ] , csv_row [ \"cam_y\" ] , csv_row [ \"cam_w\" ] , csv_row [ \"cam_h\" ] = self . _camera_bboxes [ i ] csv_row [ \"mic_x\" ] , csv_row [ \"mic_y\" ] , csv_row [ \"mic_w\" ] , csv_row [ \"mic_h\" ] = self . _micro_bboxes [ i ] csv_row [ \"cycle\" ] = cycle_number csv_row [ \"frame\" ] = frame_number csv_row [ \"phase\" ] = \"imaging\" if i < self . timing_config . imaging_frame_num else \"moving\" csv_row [ \"wrm_x\" ] , csv_row [ \"wrm_y\" ] , csv_row [ \"wrm_w\" ] , csv_row [ \"wrm_h\" ] = worm_bbox self . _bbox_logger . write ( csv_row ) self . _bbox_logger . flush () def on_cycle_end ( self , sim : Simulator ) : self . _log_cycle ( sim ) self . sim_controller . on_cycle_end ( sim ) self . _camera_frames . clear () self . _platform_positions . clear () self . _camera_bboxes . clear () self . _micro_bboxes . clear () def on_sim_end ( self , sim : Simulator ) : self . sim_controller . on_sim_end ( sim ) self . _image_saver . close () self . _frame_saver . close () self . _bbox_logger . close () def on_imaging_start ( self , sim : Simulator ) : self . sim_controller . on_imaging_start ( sim ) def on_micro_frame ( self , sim : Simulator ) : self . sim_controller . on_micro_frame ( sim ) def on_imaging_end ( self , sim : Simulator ) : self . sim_controller . on_imaging_end ( sim ) def on_movement_start ( self , sim : Simulator ) : self . sim_controller . on_movement_start ( sim ) def on_movement_end ( self , sim : Simulator ) : self . sim_controller . on_movement_end ( sim ) def begin_movement_prediction ( self , sim : Simulator ) -> None : return self . sim_controller . begin_movement_prediction ( sim ) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : return self . sim_controller . provide_movement_vector ( sim ) def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : return self . sim_controller . 
_cycle_predict_all ( sim )","title":"Attributes"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#ancestors-in-mro_1","text":"wtracker.sim.simulator.SimController abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#begin_movement_prediction","text":"def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : return self . sim_controller . begin_movement_prediction ( sim )","title":"begin_movement_prediction"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#on_camera_frame","text":"def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . sim_controller . on_camera_frame ( sim ) # log everything self . _platform_positions . append ( sim . position ) self . _camera_bboxes . append ( sim . view . camera_position ) self . _micro_bboxes . append ( sim . view . micro_position ) if self . log_config . save_err_view : cam_view = sim . camera_view () self . _camera_frames . append ( cam_view ) if self . log_config . save_cam_view : # save camera view cam_view = sim . camera_view () path = self . log_config . cam_file_path . format ( sim . frame_number ) self . _image_saver . schedule_save ( cam_view , path ) if self . log_config . save_mic_view : # save micro view mic_view = sim . view . micro_view () path = self . log_config . mic_file_path . format ( sim . frame_number ) self . _image_saver . schedule_save ( mic_view , path )","title":"on_camera_frame"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#on_cycle_end","text":"def on_cycle_end ( self , sim : wtracker . sim . simulator . Simulator ) Called when a cycle ends. View Source def on_cycle_end ( self , sim : Simulator ) : self . _log_cycle ( sim ) self . sim_controller . on_cycle_end ( sim ) self . _camera_frames . clear () self . _platform_positions . clear () self . _camera_bboxes . clear () self . _micro_bboxes . clear ()","title":"on_cycle_end"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#on_cycle_start","text":"def on_cycle_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): self.sim_controller.on_cycle_start(sim)","title":"on_cycle_start"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#on_imaging_end","text":"def on_imaging_end ( self , sim : wtracker . sim . simulator . Simulator ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): self.sim_controller.on_imaging_end(sim)","title":"on_imaging_end"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#on_imaging_start","text":"def on_imaging_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): self.sim_controller.on_imaging_start(sim)","title":"on_imaging_start"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#on_micro_frame","text":"def on_micro_frame ( self , sim : wtracker . sim . simulator . 
Simulator ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): self.sim_controller.on_micro_frame(sim)","title":"on_micro_frame"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#on_movement_end","text":"def on_movement_end ( self , sim : wtracker . sim . simulator . Simulator ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): self.sim_controller.on_movement_end(sim)","title":"on_movement_end"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#on_movement_start","text":"def on_movement_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): self.sim_controller.on_movement_start(sim)","title":"on_movement_start"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#on_sim_end","text":"def on_sim_end ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): self.sim_controller.on_sim_end(sim) self._image_saver.close() self._frame_saver.close() self._bbox_logger.close()","title":"on_sim_end"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#on_sim_start","text":"def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . sim_controller . on_sim_start ( sim ) self . _camera_frames . clear () self . _platform_positions . clear () self . _camera_bboxes . clear () self . _micro_bboxes . clear () self . log_config . create_dirs () self . _image_saver = ImageSaver ( tqdm = True ) self . _image_saver . start () self . _frame_saver = FrameSaver ( deepcopy ( sim . view . _frame_reader ), tqdm = True ) self . _frame_saver . start () self . _bbox_logger = CSVLogger ( self . log_config . bbox_file_path , col_names = [ \"frame\" , \"cycle\" , \"phase\" , \"plt_x\" , \"plt_y\" , \"cam_x\" , \"cam_y\" , \"cam_w\" , \"cam_h\" , \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" , \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" , ], )","title":"on_sim_start"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#provide_movement_vector","text":"def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ] : return self . sim_controller .
provide_movement_vector ( sim )","title":"provide_movement_vector"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/","text":"Module wtracker.sim.sim_controllers.mlp_controllers View Source from typing import Collection import numpy as np from collections import deque from torch import Tensor from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import Simulator from wtracker.sim.sim_controllers.csv_controller import CsvController from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat from wtracker.neural.mlp import WormPredictor from wtracker.neural.config import IOConfig class MLPController ( CsvController ): \"\"\" MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation. Args: timing_config (TimingConfig): The timing configuration for the simulation. csv_path (str): The path to the CSV file containing the simulation data. model (WormPredictor): The WormPredictor model used for predicting worm movement. max_speed (float): max speed of the worm in mm/s, predictions above this will be clipped. \"\"\" def __init__ ( self , timing_config : TimingConfig , csv_path : str , model : WormPredictor , max_speed : float = 0.9 ): super () . __init__ ( timing_config , csv_path ) self . model : WormPredictor = model self . io_config : IOConfig = model . io_config self . model . eval () px_per_mm = self . timing_config . px_per_mm fps = self . timing_config . frames_per_sec max_speed_px_frame = max_speed * ( px_per_mm / fps ) self . max_dist_per_pred = max_speed_px_frame * ( self . io_config . pred_frames [ 0 ]) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # frames for prediction (input to the model) frames_for_pred = np . asanyarray ( self . io_config . input_frames , dtype = int ) frames_for_pred += sim . frame_number - self . timing_config . pred_frame_num cam_center = BoxUtils . center ( np . asanyarray ( sim . view . camera_position )) worm_bboxes = self . predict ( frames_for_pred , relative = False ) . reshape ( 1 , - 1 ) if not np . isfinite ( worm_bboxes ) . all (): return 0 , 0 # relative position of the worm to the camera center, we use the worm x,y instead of center because of how the model and dataset are built rel_x , rel_y = worm_bboxes [ 0 , 0 ] - cam_center [ 0 ], worm_bboxes [ 0 , 1 ] - cam_center [ 1 ] # make coordinates relative to first bbox x = worm_bboxes [ 0 , 0 ] x_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 0 y = worm_bboxes [ 0 , 1 ] y_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 1 worm_bboxes [:, x_mask ] -= x worm_bboxes [:, y_mask ] -= y # predict the movement of the worm via the model pred = self . model . forward ( Tensor ( worm_bboxes )) . flatten () . detach () . numpy () # make sure the prediction is within the limits and apply post-proccessing steps pred = np . clip ( pred , - self . max_dist_per_pred , self . max_dist_per_pred ) dx = round ( pred [ 0 ] . item () + rel_x ) dy = round ( pred [ 1 ] . item () + rel_y ) # dx = np.clip(dx, -self.max_dist_per_pred, self.max_dist_per_pred) # dy = np.clip(dy, -self.max_dist_per_pred, self.max_dist_per_pred) return ( dx , dy ) def print_model ( self ): print ( self . model ) Classes MLPController class MLPController ( timing_config : wtracker . sim . config . TimingConfig , csv_path : str , model : wtracker . neural . mlp . 
WormPredictor , max_speed : float = 0.9 ) MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the simulation. None csv_path str The path to the CSV file containing the simulation data. None model WormPredictor The WormPredictor model used for predicting worm movement. None max_speed float max speed of the worm in mm/s, predictions above this will be clipped. None View Source class MLPController ( CsvController ): \"\"\" MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation. Args: timing_config (TimingConfig): The timing configuration for the simulation. csv_path (str): The path to the CSV file containing the simulation data. model (WormPredictor): The WormPredictor model used for predicting worm movement. max_speed (float): max speed of the worm in mm/s, predictions above this will be clipped. \"\"\" def __init__ ( self , timing_config : TimingConfig , csv_path : str , model : WormPredictor , max_speed : float = 0.9 ): super (). __init__ ( timing_config , csv_path ) self . model : WormPredictor = model self . io_config : IOConfig = model . io_config self . model . eval () px_per_mm = self . timing_config . px_per_mm fps = self . timing_config . frames_per_sec max_speed_px_frame = max_speed * ( px_per_mm / fps ) self . max_dist_per_pred = max_speed_px_frame * ( self . io_config . pred_frames [ 0 ]) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # frames for prediction ( input to the model ) frames_for_pred = np . asanyarray ( self . io_config . input_frames , dtype = int ) frames_for_pred += sim . frame_number - self . timing_config . pred_frame_num cam_center = BoxUtils . center ( np . asanyarray ( sim . view . camera_position )) worm_bboxes = self . predict ( frames_for_pred , relative = False ). reshape ( 1 , - 1 ) if not np . isfinite ( worm_bboxes ). all (): return 0 , 0 # relative position of the worm to the camera center , we use the worm x , y instead of center because of how the model and dataset are built rel_x , rel_y = worm_bboxes [ 0 , 0 ] - cam_center [ 0 ], worm_bboxes [ 0 , 1 ] - cam_center [ 1 ] # make coordinates relative to first bbox x = worm_bboxes [ 0 , 0 ] x_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 0 y = worm_bboxes [ 0 , 1 ] y_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 1 worm_bboxes [:, x_mask ] -= x worm_bboxes [:, y_mask ] -= y # predict the movement of the worm via the model pred = self . model . forward ( Tensor ( worm_bboxes )). flatten (). detach (). numpy () # make sure the prediction is within the limits and apply post - proccessing steps pred = np . clip ( pred , - self . max_dist_per_pred , self . max_dist_per_pred ) dx = round ( pred [ 0 ]. item () + rel_x ) dy = round ( pred [ 1 ]. item () + rel_y ) # dx = np . clip ( dx , - self . max_dist_per_pred , self . max_dist_per_pred ) # dy = np . clip ( dy , - self . max_dist_per_pred , self . max_dist_per_pred ) return ( dx , dy ) def print_model ( self ): print ( self . model ) Ancestors (in MRO) wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.simulator.SimController abc.ABC Methods begin_movement_prediction def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. 
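For orientation, here is a minimal usage sketch of MLPController. It is not taken verbatim from the library: `timing_config` (a TimingConfig) and `model` (a trained WormPredictor) are assumed to be constructed elsewhere, and the CSV path is a hypothetical placeholder; only the constructor arguments and print_model shown in the source above come from the library itself.

```python
from wtracker.sim.sim_controllers.mlp_controllers import MLPController

# `timing_config` and `model` are assumed to already exist;
# "logs/bboxes.csv" is a hypothetical path to previously logged detections.
controller = MLPController(
    timing_config=timing_config,
    csv_path="logs/bboxes.csv",
    model=model,
    max_speed=0.9,  # worm speed cap in mm/s; larger predictions are clipped
)
controller.print_model()  # prints the underlying predictor network
```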
View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass on_camera_frame def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position ) on_cycle_end def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass on_cycle_start def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass on_imaging_end def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass on_imaging_start def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass on_micro_frame def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass on_movement_end def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass on_movement_start def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass on_sim_end def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass on_sim_start def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear () predict def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number is within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes print_model def print_model ( self ) View Source def print_model(self): print(self.model) provide_movement_vector def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator.
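To make the coordinate arithmetic in provide_movement_vector concrete, the following self-contained sketch replays its final post-processing steps with dummy numbers (none of the values come from the library): the raw prediction is clipped to max_dist_per_pred and then shifted by the worm's offset (rel_x, rel_y) from the camera center.

```python
import numpy as np

# Dummy stand-ins for the controller's state (illustration only).
max_dist_per_pred = 50.0               # max allowed displacement, in pixels
cam_center = np.array([200.0, 150.0])  # center of the camera view (x, y)
worm_xy = np.array([230.0, 140.0])     # worm bbox (x, y) in absolute coords
pred = np.array([120.0, -10.0])        # raw model output (dx, dy)

# Offset of the worm from the camera center; added back because the
# model's output is expressed relative to the first input bbox.
rel_x, rel_y = worm_xy - cam_center    # -> 30.0, -10.0

# Clip the prediction to the physically plausible range, then recenter.
pred = np.clip(pred, -max_dist_per_pred, max_dist_per_pred)  # -> [50., -10.]
dx = round(pred[0].item() + rel_x)     # -> 80
dy = round(pred[1].item() + rel_y)     # -> -20
print(dx, dy)
```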
The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # frames for prediction ( input to the model ) frames_for_pred = np . asanyarray ( self . io_config . input_frames , dtype = int ) frames_for_pred += sim . frame_number - self . timing_config . pred_frame_num cam_center = BoxUtils . center ( np . asanyarray ( sim . view . camera_position )) worm_bboxes = self . predict ( frames_for_pred , relative = False ). reshape ( 1 , - 1 ) if not np . isfinite ( worm_bboxes ). all (): return 0 , 0 # relative position of the worm to the camera center , we use the worm x , y instead of center because of how the model and dataset are built rel_x , rel_y = worm_bboxes [ 0 , 0 ] - cam_center [ 0 ], worm_bboxes [ 0 , 1 ] - cam_center [ 1 ] # make coordinates relative to first bbox x = worm_bboxes [ 0 , 0 ] x_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 0 y = worm_bboxes [ 0 , 1 ] y_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 1 worm_bboxes [:, x_mask ] -= x worm_bboxes [:, y_mask ] -= y # predict the movement of the worm via the model pred = self . model . forward ( Tensor ( worm_bboxes )). flatten (). detach (). numpy () # make sure the prediction is within the limits and apply post - proccessing steps pred = np . clip ( pred , - self . max_dist_per_pred , self . max_dist_per_pred ) dx = round ( pred [ 0 ]. item () + rel_x ) dy = round ( pred [ 1 ]. item () + rel_y ) # dx = np . clip ( dx , - self . max_dist_per_pred , self . max_dist_per_pred ) # dy = np . clip ( dy , - self . max_dist_per_pred , self . max_dist_per_pred ) return ( dx , dy )","title":"Mlp Controllers"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#module-wtrackersimsim_controllersmlp_controllers","text":"View Source from typing import Collection import numpy as np from collections import deque from torch import Tensor from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import Simulator from wtracker.sim.sim_controllers.csv_controller import CsvController from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat from wtracker.neural.mlp import WormPredictor from wtracker.neural.config import IOConfig class MLPController ( CsvController ): \"\"\" MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation. Args: timing_config (TimingConfig): The timing configuration for the simulation. csv_path (str): The path to the CSV file containing the simulation data. model (WormPredictor): The WormPredictor model used for predicting worm movement. max_speed (float): max speed of the worm in mm/s, predictions above this will be clipped. \"\"\" def __init__ ( self , timing_config : TimingConfig , csv_path : str , model : WormPredictor , max_speed : float = 0.9 ): super () . __init__ ( timing_config , csv_path ) self . model : WormPredictor = model self . io_config : IOConfig = model . io_config self . model . eval () px_per_mm = self . timing_config . px_per_mm fps = self . timing_config . frames_per_sec max_speed_px_frame = max_speed * ( px_per_mm / fps ) self . max_dist_per_pred = max_speed_px_frame * ( self . io_config . 
pred_frames [ 0 ]) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # frames for prediction (input to the model) frames_for_pred = np . asanyarray ( self . io_config . input_frames , dtype = int ) frames_for_pred += sim . frame_number - self . timing_config . pred_frame_num cam_center = BoxUtils . center ( np . asanyarray ( sim . view . camera_position )) worm_bboxes = self . predict ( frames_for_pred , relative = False ) . reshape ( 1 , - 1 ) if not np . isfinite ( worm_bboxes ) . all (): return 0 , 0 # relative position of the worm to the camera center, we use the worm x,y instead of center because of how the model and dataset are built rel_x , rel_y = worm_bboxes [ 0 , 0 ] - cam_center [ 0 ], worm_bboxes [ 0 , 1 ] - cam_center [ 1 ] # make coordinates relative to first bbox x = worm_bboxes [ 0 , 0 ] x_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 0 y = worm_bboxes [ 0 , 1 ] y_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 1 worm_bboxes [:, x_mask ] -= x worm_bboxes [:, y_mask ] -= y # predict the movement of the worm via the model pred = self . model . forward ( Tensor ( worm_bboxes )) . flatten () . detach () . numpy () # make sure the prediction is within the limits and apply post-proccessing steps pred = np . clip ( pred , - self . max_dist_per_pred , self . max_dist_per_pred ) dx = round ( pred [ 0 ] . item () + rel_x ) dy = round ( pred [ 1 ] . item () + rel_y ) # dx = np.clip(dx, -self.max_dist_per_pred, self.max_dist_per_pred) # dy = np.clip(dy, -self.max_dist_per_pred, self.max_dist_per_pred) return ( dx , dy ) def print_model ( self ): print ( self . model )","title":"Module wtracker.sim.sim_controllers.mlp_controllers"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#mlpcontroller","text":"class MLPController ( timing_config : wtracker . sim . config . TimingConfig , csv_path : str , model : wtracker . neural . mlp . WormPredictor , max_speed : float = 0.9 ) MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation.","title":"MLPController"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the simulation. None csv_path str The path to the CSV file containing the simulation data. None model WormPredictor The WormPredictor model used for predicting worm movement. None max_speed float max speed of the worm in mm/s, predictions above this will be clipped. None View Source class MLPController ( CsvController ): \"\"\" MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation. Args: timing_config (TimingConfig): The timing configuration for the simulation. csv_path (str): The path to the CSV file containing the simulation data. model (WormPredictor): The WormPredictor model used for predicting worm movement. max_speed (float): max speed of the worm in mm/s, predictions above this will be clipped. \"\"\" def __init__ ( self , timing_config : TimingConfig , csv_path : str , model : WormPredictor , max_speed : float = 0.9 ): super (). __init__ ( timing_config , csv_path ) self . model : WormPredictor = model self . io_config : IOConfig = model . io_config self . model . eval () px_per_mm = self . timing_config . px_per_mm fps = self . timing_config . 
frames_per_sec max_speed_px_frame = max_speed * ( px_per_mm / fps ) self . max_dist_per_pred = max_speed_px_frame * ( self . io_config . pred_frames [ 0 ]) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # frames for prediction ( input to the model ) frames_for_pred = np . asanyarray ( self . io_config . input_frames , dtype = int ) frames_for_pred += sim . frame_number - self . timing_config . pred_frame_num cam_center = BoxUtils . center ( np . asanyarray ( sim . view . camera_position )) worm_bboxes = self . predict ( frames_for_pred , relative = False ). reshape ( 1 , - 1 ) if not np . isfinite ( worm_bboxes ). all (): return 0 , 0 # relative position of the worm to the camera center , we use the worm x , y instead of center because of how the model and dataset are built rel_x , rel_y = worm_bboxes [ 0 , 0 ] - cam_center [ 0 ], worm_bboxes [ 0 , 1 ] - cam_center [ 1 ] # make coordinates relative to first bbox x = worm_bboxes [ 0 , 0 ] x_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 0 y = worm_bboxes [ 0 , 1 ] y_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 1 worm_bboxes [:, x_mask ] -= x worm_bboxes [:, y_mask ] -= y # predict the movement of the worm via the model pred = self . model . forward ( Tensor ( worm_bboxes )). flatten (). detach (). numpy () # make sure the prediction is within the limits and apply post - proccessing steps pred = np . clip ( pred , - self . max_dist_per_pred , self . max_dist_per_pred ) dx = round ( pred [ 0 ]. item () + rel_x ) dy = round ( pred [ 1 ]. item () + rel_y ) # dx = np . clip ( dx , - self . max_dist_per_pred , self . max_dist_per_pred ) # dy = np . clip ( dy , - self . max_dist_per_pred , self . max_dist_per_pred ) return ( dx , dy ) def print_model ( self ): print ( self . model )","title":"Attributes"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#ancestors-in-mro","text":"wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.simulator.SimController abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#begin_movement_prediction","text":"def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass","title":"begin_movement_prediction"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_camera_frame","text":"def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position )","title":"on_camera_frame"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_cycle_end","text":"def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass","title":"on_cycle_end"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_cycle_start","text":"def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. 
\"\"\" pass","title":"on_cycle_start"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_imaging_end","text":"def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass","title":"on_imaging_end"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_imaging_start","text":"def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass","title":"on_imaging_start"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_micro_frame","text":"def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass","title":"on_micro_frame"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_movement_end","text":"def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass","title":"on_movement_end"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_movement_start","text":"def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass","title":"on_movement_start"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_sim_end","text":"def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass","title":"on_sim_end"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_sim_start","text":"def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear ()","title":"on_sim_start"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#predict","text":"def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number if within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . 
asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes","title":"predict"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#print_model","text":"def print_model ( self ) View Source def print_model(self): print(self.model)","title":"print_model"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#provide_movement_vector","text":"def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # frames for prediction ( input to the model ) frames_for_pred = np . asanyarray ( self . io_config . input_frames , dtype = int ) frames_for_pred += sim . frame_number - self . timing_config . pred_frame_num cam_center = BoxUtils . center ( np . asanyarray ( sim . view . camera_position )) worm_bboxes = self . predict ( frames_for_pred , relative = False ). reshape ( 1 , - 1 ) if not np . isfinite ( worm_bboxes ). all (): return 0 , 0 # relative position of the worm to the camera center , we use the worm x , y instead of center because of how the model and dataset are built rel_x , rel_y = worm_bboxes [ 0 , 0 ] - cam_center [ 0 ], worm_bboxes [ 0 , 1 ] - cam_center [ 1 ] # make coordinates relative to first bbox x = worm_bboxes [ 0 , 0 ] x_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 0 y = worm_bboxes [ 0 , 1 ] y_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 1 worm_bboxes [:, x_mask ] -= x worm_bboxes [:, y_mask ] -= y # predict the movement of the worm via the model pred = self . model . forward ( Tensor ( worm_bboxes )). flatten (). detach (). numpy () # make sure the prediction is within the limits and apply post - proccessing steps pred = np . clip ( pred , - self . max_dist_per_pred , self . max_dist_per_pred ) dx = round ( pred [ 0 ]. item () + rel_x ) dy = round ( pred [ 1 ]. item () + rel_y ) # dx = np . clip ( dx , - self . max_dist_per_pred , self . max_dist_per_pred ) # dy = np . clip ( dy , - self . max_dist_per_pred , self . max_dist_per_pred ) return ( dx , dy )","title":"provide_movement_vector"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/","text":"Module wtracker.sim.sim_controllers.optimal_controller View Source import numpy as np from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import Simulator from wtracker.sim.sim_controllers.csv_controller import CsvController class OptimalController ( CsvController ): def __init__ ( self , timing_config : TimingConfig , csv_path : str ): super () . __init__ ( timing_config , csv_path ) self . _csv_centers = np . empty (( len ( self . _csv_data ), 2 ), dtype = self . _csv_data . dtype ) self . _csv_centers [:, 0 ] = self . _csv_data [:, 0 ] + self . _csv_data [:, 2 ] / 2 self . _csv_centers [:, 1 ] = self . _csv_data [:, 1 ] + self . _csv_data [:, 3 ] / 2 def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # extract portion matching next imaging phase next_imaging_start = ( sim . cycle_number + 1 ) * self . timing_config . 
cycle_frame_num next_imaging_end = next_imaging_start + self . timing_config . imaging_frame_num next_imaging = self . _csv_centers [ next_imaging_start : next_imaging_end , :] next_imaging = next_imaging [ np . isfinite ( next_imaging ) . all ( axis = 1 )] if len ( next_imaging ) == 0 : return 0 , 0 x_next , y_next = np . median ( next_imaging , axis = 0 ) cam_x , cam_y , cam_w , cam_h = sim . view . camera_position cam_mid = cam_x + cam_w / 2 , cam_y + cam_h / 2 return round ( x_next - cam_mid [ 0 ]), round ( y_next - cam_mid [ 1 ]) Classes OptimalController class OptimalController ( timing_config : wtracker . sim . config . TimingConfig , csv_path : str ) Abstract base class for simulator controllers. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class OptimalController ( CsvController ): def __init__ ( self , timing_config : TimingConfig , csv_path : str ): super (). __init__ ( timing_config , csv_path ) self . _csv_centers = np . empty (( len ( self . _csv_data ), 2 ), dtype = self . _csv_data . dtype ) self . _csv_centers [:, 0 ] = self . _csv_data [:, 0 ] + self . _csv_data [:, 2 ] / 2 self . _csv_centers [:, 1 ] = self . _csv_data [:, 1 ] + self . _csv_data [:, 3 ] / 2 def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # extract portion matching next imaging phase next_imaging_start = ( sim . cycle_number + 1 ) * self . timing_config . cycle_frame_num next_imaging_end = next_imaging_start + self . timing_config . imaging_frame_num next_imaging = self . _csv_centers [ next_imaging_start : next_imaging_end , :] next_imaging = next_imaging [ np . isfinite ( next_imaging ). all ( axis = 1 )] if len ( next_imaging ) == 0 : return 0 , 0 x_next , y_next = np . median ( next_imaging , axis = 0 ) cam_x , cam_y , cam_w , cam_h = sim . view . camera_position cam_mid = cam_x + cam_w / 2 , cam_y + cam_h / 2 return round ( x_next - cam_mid [ 0 ]), round ( y_next - cam_mid [ 1 ]) Ancestors (in MRO) wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.simulator.SimController abc.ABC Methods begin_movement_prediction def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass on_camera_frame def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position ) on_cycle_end def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass on_cycle_start def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass on_imaging_end def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass on_imaging_start def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. 
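A minimal numpy sketch (assumed bbox values) of how OptimalController precomputes _csv_centers from (x, y, w, h) boxes before taking the median over the next imaging phase:

```python
import numpy as np

csv_data = np.array([[10., 20., 4., 6.],    # one (x, y, w, h) worm bbox per frame
                     [12., 21., 4., 6.]])
centers = np.empty((len(csv_data), 2), dtype=csv_data.dtype)
centers[:, 0] = csv_data[:, 0] + csv_data[:, 2] / 2  # center x = x + w/2
centers[:, 1] = csv_data[:, 1] + csv_data[:, 3] / 2  # center y = y + h/2
print(centers)  # [[12. 23.] [14. 24.]]
```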
\"\"\" pass on_micro_frame def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass on_movement_end def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass on_movement_start def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass on_sim_end def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass on_sim_start def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear () predict def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number if within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes provide_movement_vector def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) - > tuple [ int , int ] : # extract portion matching next imaging phase next_imaging_start = ( sim . cycle_number + 1 ) * self . timing_config . cycle_frame_num next_imaging_end = next_imaging_start + self . timing_config . imaging_frame_num next_imaging = self . _csv_centers [ next_imaging_start : next_imaging_end , : ] next_imaging = next_imaging [ np . isfinite ( next_imaging ) . all ( axis = 1 )] if len ( next_imaging ) == 0 : return 0 , 0 x_next , y_next = np . median ( next_imaging , axis = 0 ) cam_x , cam_y , cam_w , cam_h = sim . view . 
camera_position cam_mid = cam_x + cam_w / 2 , cam_y + cam_h / 2 return round ( x_next - cam_mid [ 0 ]), round ( y_next - cam_mid [ 1 ])","title":"Optimal Controller"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#module-wtrackersimsim_controllersoptimal_controller","text":"View Source import numpy as np from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import Simulator from wtracker.sim.sim_controllers.csv_controller import CsvController class OptimalController ( CsvController ): def __init__ ( self , timing_config : TimingConfig , csv_path : str ): super () . __init__ ( timing_config , csv_path ) self . _csv_centers = np . empty (( len ( self . _csv_data ), 2 ), dtype = self . _csv_data . dtype ) self . _csv_centers [:, 0 ] = self . _csv_data [:, 0 ] + self . _csv_data [:, 2 ] / 2 self . _csv_centers [:, 1 ] = self . _csv_data [:, 1 ] + self . _csv_data [:, 3 ] / 2 def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # extract portion matching next imaging phase next_imaging_start = ( sim . cycle_number + 1 ) * self . timing_config . cycle_frame_num next_imaging_end = next_imaging_start + self . timing_config . imaging_frame_num next_imaging = self . _csv_centers [ next_imaging_start : next_imaging_end , :] next_imaging = next_imaging [ np . isfinite ( next_imaging ) . all ( axis = 1 )] if len ( next_imaging ) == 0 : return 0 , 0 x_next , y_next = np . median ( next_imaging , axis = 0 ) cam_x , cam_y , cam_w , cam_h = sim . view . camera_position cam_mid = cam_x + cam_w / 2 , cam_y + cam_h / 2 return round ( x_next - cam_mid [ 0 ]), round ( y_next - cam_mid [ 1 ])","title":"Module wtracker.sim.sim_controllers.optimal_controller"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#optimalcontroller","text":"class OptimalController ( timing_config : wtracker . sim . config . TimingConfig , csv_path : str ) Abstract base class for simulator controllers.","title":"OptimalController"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class OptimalController ( CsvController ): def __init__ ( self , timing_config : TimingConfig , csv_path : str ): super (). __init__ ( timing_config , csv_path ) self . _csv_centers = np . empty (( len ( self . _csv_data ), 2 ), dtype = self . _csv_data . dtype ) self . _csv_centers [:, 0 ] = self . _csv_data [:, 0 ] + self . _csv_data [:, 2 ] / 2 self . _csv_centers [:, 1 ] = self . _csv_data [:, 1 ] + self . _csv_data [:, 3 ] / 2 def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # extract portion matching next imaging phase next_imaging_start = ( sim . cycle_number + 1 ) * self . timing_config . cycle_frame_num next_imaging_end = next_imaging_start + self . timing_config . imaging_frame_num next_imaging = self . _csv_centers [ next_imaging_start : next_imaging_end , :] next_imaging = next_imaging [ np . isfinite ( next_imaging ). all ( axis = 1 )] if len ( next_imaging ) == 0 : return 0 , 0 x_next , y_next = np . median ( next_imaging , axis = 0 ) cam_x , cam_y , cam_w , cam_h = sim . view . 
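Because OptimalController reads the worm position of future frames straight from the log, it gives an upper bound on achievable tracking quality. A hypothetical usage sketch; the file paths are placeholders, and loading TimingConfig via load_json assumes it derives from ConfigBase like the configs shown later in this reference:

```python
from wtracker.sim.config import TimingConfig
from wtracker.sim.sim_controllers.optimal_controller import OptimalController

timing_config = TimingConfig.load_json("timing_config.json")  # assumed path
controller = OptimalController(timing_config, csv_path="worm_bboxes.csv")  # assumed log
# The controller is then passed to a Simulator instance; each cycle it centers the
# camera on the median worm position of the upcoming imaging phase.
```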
camera_position cam_mid = cam_x + cam_w / 2 , cam_y + cam_h / 2 return round ( x_next - cam_mid [ 0 ]), round ( y_next - cam_mid [ 1 ])","title":"Attributes"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#ancestors-in-mro","text":"wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.simulator.SimController abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#begin_movement_prediction","text":"def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass","title":"begin_movement_prediction"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_camera_frame","text":"def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position )","title":"on_camera_frame"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_cycle_end","text":"def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass","title":"on_cycle_end"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_cycle_start","text":"def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass","title":"on_cycle_start"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_imaging_end","text":"def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass","title":"on_imaging_end"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_imaging_start","text":"def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass","title":"on_imaging_start"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_micro_frame","text":"def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass","title":"on_micro_frame"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_movement_end","text":"def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass","title":"on_movement_end"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_movement_start","text":"def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. 
\"\"\" pass","title":"on_movement_start"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_sim_end","text":"def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass","title":"on_sim_end"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_sim_start","text":"def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear ()","title":"on_sim_start"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#predict","text":"def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number if within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes","title":"predict"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#provide_movement_vector","text":"def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) - > tuple [ int , int ] : # extract portion matching next imaging phase next_imaging_start = ( sim . cycle_number + 1 ) * self . timing_config . cycle_frame_num next_imaging_end = next_imaging_start + self . timing_config . imaging_frame_num next_imaging = self . _csv_centers [ next_imaging_start : next_imaging_end , : ] next_imaging = next_imaging [ np . isfinite ( next_imaging ) . all ( axis = 1 )] if len ( next_imaging ) == 0 : return 0 , 0 x_next , y_next = np . median ( next_imaging , axis = 0 ) cam_x , cam_y , cam_w , cam_h = sim . view . 
camera_position cam_mid = cam_x + cam_w / 2 , cam_y + cam_h / 2 return round ( x_next - cam_mid [ 0 ]), round ( y_next - cam_mid [ 1 ])","title":"provide_movement_vector"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/","text":"Module wtracker.sim.sim_controllers.polyfit_controller View Source import numpy as np import pandas as pd from dataclasses import dataclass import numpy.polynomial.polynomial as poly from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import Simulator from wtracker.sim.sim_controllers.csv_controller import CsvController from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat from wtracker.utils.config_base import ConfigBase @dataclass class PolyfitConfig ( ConfigBase ): degree : int \"\"\"The degree of the polynomial, which will be fitted to the worm movement.\"\"\" sample_times : list [ int ] \"\"\"Times at which the worm position is sampled for the polynomial fit. Time 0 denotes the beginning of the current cycle. Negative values are allowed.\"\"\" weights : list [ float ] = None \"\"\"Weights for each position sample for the polynomial fit. If None, all weights are set to 1.0. If the weights are not uniform, a weighted polynomial fit is performed, where the residuals of samples with higher weights are more important for the fitting.\"\"\" def __post_init__ ( self ): self . sample_times = sorted ( self . sample_times ) if self . weights is None : self . weights = [ 1.0 for _ in self . sample_times ] assert len ( self . sample_times ) == len ( self . weights ) class PolyfitController ( CsvController ): def __init__ ( self , timing_config : TimingConfig , polyfit_config : PolyfitConfig , csv_path : str , ) -> None : \"\"\" Args: timing_config (TimingConfig): The timing configuration of the simulation. csv_path (str): The path to the csv file with the worm data. polyfit_config (PolyfitConfig): The configuration for the polynomial fit. \"\"\" super () . __init__ ( timing_config , csv_path ) self . polyfit_config = polyfit_config self . _sample_times = np . asanyarray ( polyfit_config . sample_times , dtype = int ) self . _weights = np . asanyarray ( polyfit_config . weights , dtype = float ) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: timing = self . timing_config config = self . polyfit_config bboxes = self . predict ( sim . cycle_number * timing . cycle_frame_num + self . _sample_times , relative = False ) # make all bboxes relative to current camera view camera_bbox = sim . view . camera_position bboxes [:, 0 ] -= camera_bbox [ 0 ] bboxes [:, 1 ] -= camera_bbox [ 1 ] positions = BoxUtils . center ( bboxes ) mask = np . isfinite ( positions ) . all ( axis = 1 ) time = self . _sample_times [ mask ] positions = positions [ mask ] weights = self . _weights [ mask ] if len ( time ) == 0 : return 0 , 0 # predict future x and future y based on the fitted polynomial coeffs = poly . polyfit ( time , positions , deg = config . degree , w = weights ) x_pred , y_pred = poly . polyval ( timing . cycle_frame_num + timing . imaging_frame_num // 2 , coeffs ) # calculate camera correction based on the speed of the worm and current worm position camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( x_pred - camera_mid [ 0 ]) dy = round ( y_pred - camera_mid [ 1 ]) return dx , dy class WeightEvaluator : \"\"\" Class for evaluating the mean absolute error (MAE) of a polynomial fit with given weights.
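An illustrative PolyfitConfig (sample values assumed): fit a quadratic to five samples spread over the cycle, giving recent samples more weight.

```python
from wtracker.sim.sim_controllers.polyfit_controller import PolyfitConfig

config = PolyfitConfig(
    degree=2,
    sample_times=[-30, -20, -10, 0, 10],  # frames, relative to cycle start
    weights=[0.5, 0.5, 1.0, 1.0, 2.0],    # later samples dominate the fit
)
# __post_init__ sorts sample_times and, if weights is None, fills it with 1.0s.
```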
Args: csv_paths (list[str]): The paths to the csv files with the worm data. timing_config (TimingConfig): The timing configuration of the simulation. input_time_offsets (np.ndarray): The time offsets for the input positions. These offsets are calculated from the beginning of the current cycle. The beginning of the current cycle is considered as time 0. pred_time_offset (int): The time offset for the target position from the beginning of the current cycle. This time offset is calculated from the beginning of the current cycle. The beginning of the current cycle is considered as time 0. min_speed (float, optional): The minimum speed of the worm for a cycle to be considered. max_speed (float, optional): The maximum speed of the worm for a cycle to be considered. \"\"\" def __init__ ( self , csv_paths : list [ str ], timing_config : TimingConfig , input_time_offsets : np . ndarray , pred_time_offset : int , min_speed : float = 0 , max_speed : float = np . inf , ): self . csv_paths = csv_paths self . timing_config = timing_config self . pred_time_offset = pred_time_offset self . min_speed = min_speed self . max_speed = max_speed self . input_time_offsets = np . sort ( input_time_offsets ) self . _construct_dataset () def _construct_dataset ( self ) -> None : input_positions = [] target_positions = [] for i , path in enumerate ( self . csv_paths ): bboxes = pd . read_csv ( path , usecols = [ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]) . to_numpy ( dtype = float ) input_pos , target_pos = self . _extract_positions ( bboxes , self . timing_config . cycle_frame_num ) input_positions . append ( input_pos ) target_positions . append ( target_pos ) # print stats init_num_cycles = len ( bboxes ) // self . timing_config . cycle_frame_num final_num_cycles = len ( target_pos ) // 2 removed_percent = round (( init_num_cycles - final_num_cycles ) / init_num_cycles * 100 , 1 ) print ( f \"Log { i } :: Number of evaluation cycles: { final_num_cycles } \" ) print ( f \"Log { i } :: Number of cycles removed: { init_num_cycles - final_num_cycles } ( { removed_percent } %)\" ) self . y_input = np . concatenate ( input_positions , axis = 1 ) self . x_input = self . input_time_offsets . reshape ( - 1 ) self . y_target = np . concatenate ( target_positions , axis = 0 ) self . x_target = np . full_like ( self . y_target , self . pred_time_offset ) def _extract_positions ( self , raw_bboxes : pd . DataFrame , cycle_length : int ) -> tuple [ np . ndarray , np . ndarray ]: N = self . input_time_offsets . shape [ 0 ] cycle_starts = np . arange ( 0 , raw_bboxes . shape [ 0 ], cycle_length , dtype = int ) centers = BoxUtils . center ( raw_bboxes ) # x are times, y are positions # create input and target arrays for the times x_input = np . repeat ( cycle_starts , repeats = N ) + np . tile ( self . input_time_offsets , reps = cycle_starts . shape [ 0 ]) x_input = x_input . reshape ( - 1 , N ) x_target = cycle_starts + self . pred_time_offset # remove input and target cycles with invalid time # i.e. when input time is negative or target time is out of bounds mask = ( x_input >= 0 ) . all ( axis = 1 ) & ( x_target < len ( centers )) x_input = x_input [ mask , :] x_target = x_target [ mask ] # get input and target positions for each cycle y_input = centers [ x_input . flatten (), :] y_input = y_input . reshape ( - 1 , N , 2 ) y_target = centers [ x_target . flatten (), :] y_target = y_target . reshape ( - 1 , 2 ) # remove all cycles with invalid positions input_mask = np . isfinite ( y_input ) .
all ( axis = ( 1 , 2 )) target_mask = np . isfinite ( y_target ) . all ( axis = 1 ) mask = input_mask & target_mask y_input = y_input [ mask , :, :] y_target = y_target [ mask , :] # remove cycles with average speed below threshold # dist = np.sqrt((y_target[:, 1] - y_input[:, 0, 1]) ** 2 + (y_target[:, 0] - y_input[:, 0, 0]) ** 2) dist = np . linalg . norm ( y_target - y_input [:, 0 , :], axis = 1 ) time = self . pred_time_offset - self . input_time_offsets [ 0 ] speed = dist / time speed_mask = ( speed >= self . min_speed ) & ( speed <= self . max_speed ) y_input = y_input [ speed_mask , :, :] y_target = y_target [ speed_mask , :] # reshape target arrays y_input = y_input . swapaxes ( 0 , 1 ) . reshape ( N , - 1 ) y_target = y_target . reshape ( - 1 ) return y_input , y_target def _polyval ( self , coeffs : np . ndarray , x : np . ndarray ) -> np . ndarray : \"\"\" Evaluate a polynomial at given values. This implementation is way faster than np.polyval for multiple polynomials. Args: coeffs (np.ndarray): Coefficients of the polynomial. Coefficients at increasing order. Should have shape [deg+1, N]. x (np.ndarray): Values at which to evaluate the polynomial. Should have shape [N]. Returns: np.ndarray: The result of evaluating the polynomial at the given values. Shape is [N]. \"\"\" coeffs = coeffs . swapaxes ( 0 , 1 ) van = np . vander ( x , N = coeffs . shape [ 1 ], increasing = True ) return np . sum ( van * coeffs , axis =- 1 ) def eval ( self , weights : np . ndarray , deg : int = 2 ) -> float : \"\"\" Evaluate the mean absolute error (MAE) of the polynomial fit. Args: weights (np.ndarray): The weights used for the polynomial fit. Should have shape [N]. deg (int, optional): The degree of the polynomial fit. Returns: float: The mean absolute error (MAE) of the polynomial fit. \"\"\" coeffs = poly . polyfit ( self . x_input , self . y_input , deg = deg , w = weights ) y_pred = self . _polyval ( coeffs , self . x_target ) mae = np . mean ( np . abs ( self . y_target - y_pred )) return mae Classes PolyfitConfig class PolyfitConfig ( degree : int , sample_times : list [ int ], weights : list [ float ] = None ) PolyfitConfig(degree: int, sample_times: list[int], weights: list[float] = None) View Source @dataclass class PolyfitConfig ( ConfigBase ) : degree : int \"\"\"The degree of the polynomial, which will be fitted to the worm movement.\"\"\" sample_times : list [ int ] \"\"\"Times at which the worm position is be sampled for the polynomial fit. Time 0 denotes the beginning of the current cycle. Negative values are allowed.\"\"\" weights : list [ float ] = None \"\"\"Weights for each position sample for the polynomial fit. If None, all weights are set to 1.0. If the weights are not uniform, weighted polynomial fit is performed, where the residuals of samples with higher weights are more important for the fitting.\"\"\" def __post_init__ ( self ) : self . sample_times = sorted ( self . sample_times ) if self . weights is None : self . weights = [ 1.0 for _ in self.sample_times ] assert len ( self . sample_times ) == len ( self . weights ) Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Class variables weights Static methods load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. 
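A hypothetical round-trip using the ConfigBase methods documented here (file path assumed); note that load_json restores the instance by updating __dict__ directly, so __post_init__ is not re-run:

```python
from wtracker.sim.sim_controllers.polyfit_controller import PolyfitConfig

config = PolyfitConfig(degree=1, sample_times=[-10, 0])  # weights default to 1.0s
config.save_json("polyfit_config.json")                  # dumps __dict__ as JSON
restored = PolyfitConfig.load_json("polyfit_config.json")
assert restored.weights == [1.0, 1.0]
```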
View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods save_json def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) PolyfitController class PolyfitController ( timing_config : wtracker . sim . config . TimingConfig , polyfit_config : wtracker . sim . sim_controllers . polyfit_controller . PolyfitConfig , csv_path : str ) Abstract base class for simulator controllers. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class PolyfitController ( CsvController ) : def __init__ ( self , timing_config : TimingConfig , polyfit_config : PolyfitConfig , csv_path : str , ) -> None : \"\"\" Args: timing_config (TimingConfig): The timing configuration of the simulation. csv_path (str): The path to the csv file with the worm data. polyfit_config (PolyfitConfig): The configuration for the polynomial fit. \"\"\" super (). __init__ ( timing_config , csv_path ) self . polyfit_config = polyfit_config self . _sample_times = np . asanyarray ( polyfit_config . sample_times , dtype = int ) self . _weights = np . asanyarray ( polyfit_config . weights , dtype = float ) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : timing = self . timing_config config = self . polyfit_config bboxes = self . predict ( sim . cycle_number * timing . cycle_frame_num + self . 
_sample_times , relative = False ) # make all bboxes relative to current camera view camera_bbox = sim . view . camera_position bboxes [ :, 0 ] -= camera_bbox [ 0 ] bboxes [ :, 1 ] -= camera_bbox [ 1 ] positions = BoxUtils . center ( bboxes ) mask = np . isfinite ( positions ). all ( axis = 1 ) time = self . _sample_times [ mask ] positions = positions [ mask ] weights = self . _weights [ mask ] if len ( time ) == 0 : return 0 , 0 # predict future x and future y based on the fitted polynomial coeffs = poly . polyfit ( time , positions , deg = config . degree , w = weights ) x_pred , y_pred = poly . polyval ( timing . cycle_frame_num + timing . imaging_frame_num // 2 , coeffs ) # calculate camera correction based on the speed of the worm and current worm position camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( x_pred - camera_mid [ 0 ] ) dy = round ( y_pred - camera_mid [ 1 ] ) return dx , dy Ancestors (in MRO) wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.simulator.SimController abc.ABC Methods begin_movement_prediction def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass on_camera_frame def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position ) on_cycle_end def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass on_cycle_start def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass on_imaging_end def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass on_imaging_start def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass on_micro_frame def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass on_movement_end def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass on_movement_start def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass on_sim_end def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass on_sim_start def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes .
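A toy-number sketch of the fit-and-extrapolate step in provide_movement_vector above: numpy's polynomial module fits x(t) and y(t) jointly (one coefficient column per axis), and the fit is evaluated at the middle of the next imaging phase. All values are assumed for illustration.

```python
import numpy as np
import numpy.polynomial.polynomial as poly

time = np.array([-30, -15, 0])        # sample times within the cycle
positions = np.array([[100., 50.],    # worm center (x, y) at each sample
                      [103., 52.],
                      [106., 54.]])
weights = np.array([1.0, 1.0, 2.0])

coeffs = poly.polyfit(time, positions, deg=1, w=weights)  # one fit per column
t_future = 40  # e.g. cycle_frame_num + imaging_frame_num // 2, assumed timing
x_pred, y_pred = poly.polyval(t_future, coeffs)
print(round(x_pred), round(y_pred))   # 114 59 -- where the camera should aim
```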
clear () predict def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number is within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes provide_movement_vector def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : timing = self . timing_config config = self . polyfit_config bboxes = self . predict ( sim . cycle_number * timing . cycle_frame_num + self . _sample_times , relative = False ) # make all bboxes relative to current camera view camera_bbox = sim . view . camera_position bboxes [ :, 0 ] -= camera_bbox [ 0 ] bboxes [ :, 1 ] -= camera_bbox [ 1 ] positions = BoxUtils . center ( bboxes ) mask = np . isfinite ( positions ). all ( axis = 1 ) time = self . _sample_times [ mask ] positions = positions [ mask ] weights = self . _weights [ mask ] if len ( time ) == 0 : return 0 , 0 # predict future x and future y based on the fitted polynomial coeffs = poly . polyfit ( time , positions , deg = config . degree , w = weights ) x_pred , y_pred = poly . polyval ( timing . cycle_frame_num + timing . imaging_frame_num // 2 , coeffs ) # calculate camera correction based on the speed of the worm and current worm position camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( x_pred - camera_mid [ 0 ] ) dy = round ( y_pred - camera_mid [ 1 ] ) return dx , dy WeightEvaluator class WeightEvaluator ( csv_paths : list [ str ], timing_config : wtracker . sim . config . TimingConfig , input_time_offsets : numpy . ndarray , pred_time_offset : int , min_speed : float = 0 , max_speed : float = inf ) Class for evaluating the mean absolute error (MAE) of a polynomial fit with given weights. Attributes Name Type Description Default csv_paths list[str] The paths to the csv files with the worm data. None timing_config TimingConfig The timing configuration of the simulation. None input_time_offsets np.ndarray The time offsets for the input positions. These offsets are calculated from the beginning of the current cycle. The beginning of the current cycle is considered as time 0. None pred_time_offset int The time offset for the target position from the beginning of the current cycle. This time offset is calculated from the beginning of the current cycle. The beginning of the current cycle is considered as time 0.
None min_speed float The minimum speed of the worm for a cycle to be considered. None max_speed float The maximum speed of the worm for a cycle to be considered. None View Source class WeightEvaluator : \"\"\" Class for evaluating the mean absolute error (MAE) of a polynomial fit with given weights. Args: csv_paths (list[str]): The paths to the csv files with the worm data. timing_config (TimingConfig): The timing configuration of the simulation. input_time_offsets (np.ndarray): The time offsets for the input positions. These offsets are calculated from the beginning of the current cycle. The beginning of the current cycle is considered as time 0. pred_time_offset (int): The time offset for the target position from the beginning of the current cycle. This time offset is calculated from the beginning of the current cycle. The beginning of the current cycle is considered as time 0. min_speed (float, optional): The minimum speed of the worm for a cycle to be considered. max_speed (float, optional): The maximum speed of the worm for a cycle to be considered. \"\"\" def __init__ ( self , csv_paths : list [ str ] , timing_config : TimingConfig , input_time_offsets : np . ndarray , pred_time_offset : int , min_speed : float = 0 , max_speed : float = np . inf , ) : self . csv_paths = csv_paths self . timing_config = timing_config self . pred_time_offset = pred_time_offset self . min_speed = min_speed self . max_speed = max_speed self . input_time_offsets = np . sort ( input_time_offsets ) self . _construct_dataset () def _construct_dataset ( self ) -> None : input_positions = [] target_positions = [] for i , path in enumerate ( self . csv_paths ) : bboxes = pd . read_csv ( path , usecols =[ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ). to_numpy ( dtype = float ) input_pos , target_pos = self . _extract_positions ( bboxes , self . timing_config . cycle_frame_num ) input_positions . append ( input_pos ) target_positions . append ( target_pos ) # print stats init_num_cycles = len ( bboxes ) // self . timing_config . cycle_frame_num final_num_cycles = len ( target_pos ) // 2 removed_percent = round (( init_num_cycles - final_num_cycles ) / init_num_cycles * 100 , 1 ) print ( f \"Log {i} :: Number of evaluation cycles: {final_num_cycles}\" ) print ( f \"Log {i} :: Number of cycles removed: {init_num_cycles - final_num_cycles} ({removed_percent} %)\" ) self . y_input = np . concatenate ( input_positions , axis = 1 ) self . x_input = self . input_time_offsets . reshape ( - 1 ) self . y_target = np . concatenate ( target_positions , axis = 0 ) self . x_target = np . full_like ( self . y_target , self . pred_time_offset ) def _extract_positions ( self , raw_bboxes : pd . DataFrame , cycle_length : int ) -> tuple [ np.ndarray, np.ndarray ] : N = self . input_time_offsets . shape [ 0 ] cycle_starts = np . arange ( 0 , raw_bboxes . shape [ 0 ] , cycle_length , dtype = int ) centers = BoxUtils . center ( raw_bboxes ) # x are times , y are positions # create input and target arrays for the times x_input = np . repeat ( cycle_starts , repeats = N ) + np . tile ( self . input_time_offsets , reps = cycle_starts . shape [ 0 ] ) x_input = x_input . reshape ( - 1 , N ) x_target = cycle_starts + self . pred_time_offset # remove input and target cycles with invalid time # i . e . when input time is negative or target time is out of bounds mask = ( x_input >= 0 ).
all ( axis = 1 ) & ( x_target < len ( centers )) x_input = x_input [ mask, : ] x_target = x_target [ mask ] # get input and target positions for each cycle y_input = centers [ x_input.flatten(), : ] y_input = y_input . reshape ( - 1 , N , 2 ) y_target = centers [ x_target.flatten(), : ] y_target = y_target . reshape ( - 1 , 2 ) # remove all cycles with invalid positions input_mask = np . isfinite ( y_input ). all ( axis = ( 1 , 2 )) target_mask = np . isfinite ( y_target ). all ( axis = 1 ) mask = input_mask & target_mask y_input = y_input [ mask, :, : ] y_target = y_target [ mask, : ] # remove cycles with average speed below threshold # dist = np . sqrt (( y_target [ :, 1 ] - y_input [ :, 0, 1 ] ) ** 2 + ( y_target [ :, 0 ] - y_input [ :, 0, 0 ] ) ** 2 ) dist = np . linalg . norm ( y_target - y_input [ :, 0, : ] , axis = 1 ) time = self . pred_time_offset - self . input_time_offsets [ 0 ] speed = dist / time speed_mask = ( speed >= self . min_speed ) & ( speed <= self . max_speed ) y_input = y_input [ speed_mask, :, : ] y_target = y_target [ speed_mask, : ] # reshape target arrays y_input = y_input . swapaxes ( 0 , 1 ). reshape ( N , - 1 ) y_target = y_target . reshape ( - 1 ) return y_input , y_target def _polyval ( self , coeffs : np . ndarray , x : np . ndarray ) -> np . ndarray : \"\"\" Evaluate a polynomial at given values. This implementation is way faster than np.polyval for multiple polynomials. Args: coeffs (np.ndarray): Coefficients of the polynomial. Coefficients at increasing order. Should have shape [deg+1, N]. x (np.ndarray): Values at which to evaluate the polynomial. Should have shape [N]. Returns: np.ndarray: The result of evaluating the polynomial at the given values. Shape is [N]. \"\"\" coeffs = coeffs . swapaxes ( 0 , 1 ) van = np . vander ( x , N = coeffs . shape [ 1 ] , increasing = True ) return np . sum ( van * coeffs , axis =- 1 ) def eval ( self , weights : np . ndarray , deg : int = 2 ) -> float : \"\"\" Evaluate the mean absolute error (MAE) of the polynomial fit. Args: weights (np.ndarray): The weights used for the polynomial fit. Should have shape [N]. deg (int, optional): The degree of the polynomial fit. Returns: float: The mean absolute error (MAE) of the polynomial fit. \"\"\" coeffs = poly . polyfit ( self . x_input , self . y_input , deg = deg , w = weights ) y_pred = self . _polyval ( coeffs , self . x_target ) mae = np . mean ( np . abs ( self . y_target - y_pred )) return mae Methods eval def eval ( self , weights : numpy . ndarray , deg : int = 2 ) -> float Evaluate the mean absolute error (MAE) of the polynomial fit. Parameters: Name Type Description Default weights np.ndarray The weights used for the polynomial fit. Should have shape [N]. None deg int The degree of the polynomial fit. None Returns: Type Description float The mean absolute error (MAE) of the polynomial fit. View Source def eval ( self , weights : np . ndarray , deg : int = 2 ) -> float : \"\"\" Evaluate the mean absolute error (MAE) of the polynomial fit. Args: weights (np.ndarray): The weights used for the polynomial fit. Should have shape [N]. deg (int, optional): The degree of the polynomial fit. Returns: float: The mean absolute error (MAE) of the polynomial fit. \"\"\" coeffs = poly . polyfit ( self . x_input , self . y_input , deg = deg , w = weights ) y_pred = self . _polyval ( coeffs , self . x_target ) mae = np . mean ( np . abs ( self . 
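A hypothetical WeightEvaluator workflow (paths and offsets are placeholders): score candidate weight vectors by the MAE of the resulting fit, e.g. inside a hyperparameter search.

```python
import numpy as np
from wtracker.sim.sim_controllers.polyfit_controller import WeightEvaluator

timing_config = ...  # a TimingConfig instance, assumed constructed elsewhere
evaluator = WeightEvaluator(
    csv_paths=["exp0_worm.csv", "exp1_worm.csv"],   # assumed log files
    timing_config=timing_config,
    input_time_offsets=np.array([-30, -20, -10, 0]),
    pred_time_offset=40,
)
mae = evaluator.eval(weights=np.array([0.5, 0.7, 1.0, 1.5]), deg=2)
print(f"fit MAE: {mae:.2f} px")
```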
y_target - y_pred )) return mae","title":"Polyfit Controller"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#module-wtrackersimsim_controllerspolyfit_controller","text":"View Source import numpy as np import pandas as pd from dataclasses import dataclass import numpy.polynomial.polynomial as poly from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import Simulator from wtracker.sim.sim_controllers.csv_controller import CsvController from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat from wtracker.utils.config_base import ConfigBase @dataclass class PolyfitConfig ( ConfigBase ): degree : int \"\"\"The degree of the polynomial, which will be fitted to the worm movement.\"\"\" sample_times : list [ int ] \"\"\"Times at which the worm position is be sampled for the polynomial fit. Time 0 denotes the beginning of the current cycle. Negative values are allowed.\"\"\" weights : list [ float ] = None \"\"\"Weights for each position sample for the polynomial fit. If None, all weights are set to 1.0. If the weights are not uniform, weighted polynomial fit is performed, where the residuals of samples with higher weights are more important for the fitting.\"\"\" def __post_init__ ( self ): self . sample_times = sorted ( self . sample_times ) if self . weights is None : self . weights = [ 1.0 for _ in self . sample_times ] assert len ( self . sample_times ) == len ( self . weights ) class PolyfitController ( CsvController ): def __init__ ( self , timing_config : TimingConfig , polyfit_config : PolyfitConfig , csv_path : str , ) -> None : \"\"\" Args: timing_config (TimingConfig): The timing configuration of the simulation. csv_path (str): The path to the csv file with the worm data. polyfit_config (PolyfitConfig): The configuration for the polynomial fit. \"\"\" super () . __init__ ( timing_config , csv_path ) self . polyfit_config = polyfit_config self . _sample_times = np . asanyarray ( polyfit_config . sample_times , dtype = int ) self . _weights = np . asanyarray ( polyfit_config . weights , dtype = float ) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: timing = self . timing_config config = self . polyfit_config bboxes = self . predict ( sim . cycle_number * timing . cycle_frame_num + self . _sample_times , relative = False ) # make all bboxes relative to current camera view camera_bbox = sim . view . camera_position bboxes [:, 0 ] -= camera_bbox [ 0 ] bboxes [:, 1 ] -= camera_bbox [ 1 ] positions = BoxUtils . center ( bboxes ) mask = np . isfinite ( positions ) . all ( axis = 1 ) time = self . _sample_times [ mask ] positions = positions [ mask ] weights = self . _weights [ mask ] if len ( time ) == 0 : return 0 , 0 # predict future x and future y based on the fitted polynomial coeffs = poly . polyfit ( time , positions , deg = config . degree , w = weights ) x_pred , y_pred = poly . polyval ( timing . cycle_frame_num + timing . imaging_frame_num // 2 , coeffs ) # calculate camera correction based on the speed of the worm and current worm position camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( x_pred - camera_mid [ 0 ]) dy = round ( y_pred - camera_mid [ 1 ]) return dx , dy class WeightEvaluator : \"\"\" Class for evaluating the mean absolute error (MAE) of a polynomial fit with given weights. Args: csv_paths (list[str]): The paths to the csv files with the worm data. timing_config (TimingConfig): The timing configuration of the simulation. 
input_time_offsets (np.ndarray): The time offsets for the input positions. These offsets are calculated from the beginning of the current cycle. The begging of the current cycle is considered as time 0. pred_time_offset (int): The time offset for the target position from the beginning of the current cycle. This time offset is calculated from the beginning of the current cycle. The begging of the current cycle is considered as time 0. min_speed (float, optional): The minimum speed of the worm for a cycle to be considered. max_speed (float, optional): The maximum speed of the worm for a cycle to be considered. \"\"\" def __init__ ( self , csv_paths : list [ str ], timing_config : TimingConfig , input_time_offsets : np . ndarray , pred_time_offset : int , min_speed : float = 0 , max_speed : float = np . inf , ): self . csv_paths = csv_paths self . timing_config = timing_config self . pred_time_offset = pred_time_offset self . min_speed = min_speed self . max_speed = max_speed self . input_time_offsets = np . sort ( input_time_offsets ) self . _construct_dataset () def _construct_dataset ( self ) -> None : input_positions = [] target_positions = [] for i , path in enumerate ( self . csv_paths ): bboxes = pd . read_csv ( path , usecols = [ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]) . to_numpy ( dtype = float ) input_pos , target_pos = self . _extract_positions ( bboxes , self . timing_config . cycle_frame_num ) input_positions . append ( input_pos ) target_positions . append ( target_pos ) # print stats init_num_cycles = len ( bboxes ) // self . timing_config . cycle_frame_num final_num_cycles = len ( target_pos ) // 2 removed_percent = round (( init_num_cycles - final_num_cycles ) / init_num_cycles * 100 , 1 ) print ( f \"Log { i } :: Number of evaluation cycles: { final_num_cycles } \" ) print ( f \"Log { i } :: Number of cycles removed: { init_num_cycles - final_num_cycles } ( { removed_percent } %)\" ) self . y_input = np . concatenate ( input_positions , axis = 1 ) self . x_input = self . input_time_offsets . reshape ( - 1 ) self . y_target = np . concatenate ( target_positions , axis = 0 ) self . x_target = np . full_like ( self . y_target , self . pred_time_offset ) def _extract_positions ( self , raw_bboxes : pd . DataFrame , cycle_length : int ) -> tuple [ np . ndarray , np . ndarray ]: N = self . input_time_offsets . shape [ 0 ] cycle_starts = np . arange ( 0 , raw_bboxes . shape [ 0 ], cycle_length , dtype = int ) centers = BoxUtils . center ( raw_bboxes ) # x are times, y are positions # create input and target arrays for the times x_input = np . repeat ( cycle_starts , repeats = N ) + np . tile ( self . input_time_offsets , reps = cycle_starts . shape [ 0 ]) x_input = x_input . reshape ( - 1 , N ) x_target = cycle_starts + self . pred_time_offset # remove input and target cycles with invalid time # i.e. when input time is negative or target time is out of bounds mask = ( x_input >= 0 ) . all ( axis = 1 ) & ( x_target < len ( centers )) x_input = x_input [ mask , :] x_target = x_target [ mask ] # get input and target positions for each cycle y_input = centers [ x_input . flatten (), :] y_input = y_input . reshape ( - 1 , N , 2 ) y_target = centers [ x_target . flatten (), :] y_target = y_target . reshape ( - 1 , 2 ) # remove all cycles with invalid positions input_mask = np . isfinite ( y_input ) . all ( axis = ( 1 , 2 )) target_mask = np . isfinite ( y_target ) . 
all ( axis = 1 ) mask = input_mask & target_mask y_input = y_input [ mask , :, :] y_target = y_target [ mask , :] # remove cycles with average speed below threshold # dist = np.sqrt((y_target[:, 1] - y_input[:, 0, 1]) ** 2 + (y_target[:, 0] - y_input[:, 0, 0]) ** 2) dist = np . linalg . norm ( y_target - y_input [:, 0 , :], axis = 1 ) time = self . pred_time_offset - self . input_time_offsets [ 0 ] speed = dist / time speed_mask = ( speed >= self . min_speed ) & ( speed <= self . max_speed ) y_input = y_input [ speed_mask , :, :] y_target = y_target [ speed_mask , :] # reshape target arrays y_input = y_input . swapaxes ( 0 , 1 ) . reshape ( N , - 1 ) y_target = y_target . reshape ( - 1 ) return y_input , y_target def _polyval ( self , coeffs : np . ndarray , x : np . ndarray ) -> np . ndarray : \"\"\" Evaluate a polynomial at given values. This implementation is way faster than np.polyval for multiple polynomials. Args: coeffs (np.ndarray): Coefficients of the polynomial. Coefficients at increasing order. Should have shape [deg+1, N]. x (np.ndarray): Values at which to evaluate the polynomial. Should have shape [N]. Returns: np.ndarray: The result of evaluating the polynomial at the given values. Shape is [N]. \"\"\" coeffs = coeffs . swapaxes ( 0 , 1 ) van = np . vander ( x , N = coeffs . shape [ 1 ], increasing = True ) return np . sum ( van * coeffs , axis =- 1 ) def eval ( self , weights : np . ndarray , deg : int = 2 ) -> float : \"\"\" Evaluate the mean absolute error (MAE) of the polynomial fit. Args: weights (np.ndarray): The weights used for the polynomial fit. Should have shape [N]. deg (int, optional): The degree of the polynomial fit. Returns: float: The mean absolute error (MAE) of the polynomial fit. \"\"\" coeffs = poly . polyfit ( self . x_input , self . y_input , deg = deg , w = weights ) y_pred = self . _polyval ( coeffs , self . x_target ) mae = np . mean ( np . abs ( self . y_target - y_pred )) return mae","title":"Module wtracker.sim.sim_controllers.polyfit_controller"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#polyfitconfig","text":"class PolyfitConfig ( degree : int , sample_times : list [ int ], weights : list [ float ] = None ) PolyfitConfig(degree: int, sample_times: list[int], weights: list[float] = None) View Source @dataclass class PolyfitConfig ( ConfigBase ) : degree : int \"\"\"The degree of the polynomial, which will be fitted to the worm movement.\"\"\" sample_times : list [ int ] \"\"\"Times at which the worm position is be sampled for the polynomial fit. Time 0 denotes the beginning of the current cycle. Negative values are allowed.\"\"\" weights : list [ float ] = None \"\"\"Weights for each position sample for the polynomial fit. If None, all weights are set to 1.0. If the weights are not uniform, weighted polynomial fit is performed, where the residuals of samples with higher weights are more important for the fitting.\"\"\" def __post_init__ ( self ) : self . sample_times = sorted ( self . sample_times ) if self . weights is None : self . weights = [ 1.0 for _ in self.sample_times ] assert len ( self . sample_times ) == len ( self . 
weights )","title":"PolyfitConfig"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#ancestors-in-mro","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#class-variables","text":"weights","title":"Class variables"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#load_json","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#load_pickle","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#save_json","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#save_pickle","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . 
save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#polyfitcontroller","text":"class PolyfitController ( timing_config : wtracker . sim . config . TimingConfig , polyfit_config : wtracker . sim . sim_controllers . polyfit_controller . PolyfitConfig , csv_path : str ) Abstract base class for simulator controllers.","title":"PolyfitController"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class PolyfitController ( CsvController ) : def __init__ ( self , timing_config : TimingConfig , polyfit_config : PolyfitConfig , csv_path : str , ) -> None : \"\"\" Args: timing_config (TimingConfig): The timing configuration of the simulation. csv_path (str): The path to the csv file with the worm data. polyfit_config (PolyfitConfig): The configuration for the polynomial fit. \"\"\" super (). __init__ ( timing_config , csv_path ) self . polyfit_config = polyfit_config self . _sample_times = np . asanyarray ( polyfit_config . sample_times , dtype = int ) self . _weights = np . asanyarray ( polyfit_config . weights , dtype = float ) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : timing = self . timing_config config = self . polyfit_config bboxes = self . predict ( sim . cycle_number * timing . cycle_frame_num + self . _sample_times , relative = False ) # make all bboxes relative to current camera view camera_bbox = sim . view . camera_position bboxes [ :, 0 ] -= camera_bbox [ 0 ] bboxes [ :, 1 ] -= camera_bbox [ 1 ] positions = BoxUtils . center ( bboxes ) mask = np . isfinite ( positions ). all ( axis = 1 ) time = self . _sample_times [ mask ] positions = positions [ mask ] weights = self . _weights [ mask ] if len ( time ) == 0 : return 0 , 0 # predict future x and future y based on the fitted polynomial coeffs = poly . polyfit ( time , positions , deg = config . degree , w = weights ) x_pred , y_pred = poly . polyval ( timing . cycle_frame_num + timing . imaging_frame_num // 2 , coeffs ) # calculate camera correction based on the speed of the worm and current worm position camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( x_pred - camera_mid [ 0 ] ) dy = round ( y_pred - camera_mid [ 1 ] ) return dx , dy","title":"Attributes"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#ancestors-in-mro_1","text":"wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.simulator.SimController abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#begin_movement_prediction","text":"def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass","title":"begin_movement_prediction"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_camera_frame","text":"def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. 
View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position )","title":"on_camera_frame"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_cycle_end","text":"def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass","title":"on_cycle_end"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_cycle_start","text":"def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass","title":"on_cycle_start"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_imaging_end","text":"def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass","title":"on_imaging_end"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_imaging_start","text":"def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass","title":"on_imaging_start"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_micro_frame","text":"def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass","title":"on_micro_frame"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_movement_end","text":"def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass","title":"on_movement_end"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_movement_start","text":"def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass","title":"on_movement_start"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_sim_end","text":"def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass","title":"on_sim_end"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_sim_start","text":"def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear ()","title":"on_sim_start"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#predict","text":"def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . 
shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number is within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes","title":"predict"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#provide_movement_vector","text":"def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : timing = self . timing_config config = self . polyfit_config bboxes = self . predict ( sim . cycle_number * timing . cycle_frame_num + self . _sample_times , relative = False ) # make all bboxes relative to current camera view camera_bbox = sim . view . camera_position bboxes [ :, 0 ] -= camera_bbox [ 0 ] bboxes [ :, 1 ] -= camera_bbox [ 1 ] positions = BoxUtils . center ( bboxes ) mask = np . isfinite ( positions ). all ( axis = 1 ) time = self . _sample_times [ mask ] positions = positions [ mask ] weights = self . _weights [ mask ] if len ( time ) == 0 : return 0 , 0 # predict future x and future y based on the fitted polynomial coeffs = poly . polyfit ( time , positions , deg = config . degree , w = weights ) x_pred , y_pred = poly . polyval ( timing . cycle_frame_num + timing . imaging_frame_num // 2 , coeffs ) # calculate camera correction based on the speed of the worm and current worm position camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( x_pred - camera_mid [ 0 ] ) dy = round ( y_pred - camera_mid [ 1 ] ) return dx , dy","title":"provide_movement_vector"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#weightevaluator","text":"class WeightEvaluator ( csv_paths : list [ str ], timing_config : wtracker . sim . config . TimingConfig , input_time_offsets : numpy . ndarray , pred_time_offset : int , min_speed : float = 0 , max_speed : float = inf ) Class for evaluating the mean absolute error (MAE) of a polynomial fit with given weights.","title":"WeightEvaluator"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#attributes_1","text":"Name Type Description Default csv_paths list[str] The paths to the csv files with the worm data. None timing_config TimingConfig The timing configuration of the simulation. None input_time_offsets np.ndarray The time offsets for the input positions. These offsets are measured from the beginning of the current cycle, which is considered time 0. None pred_time_offset int The time offset for the target position, measured from the beginning of the current cycle, which is considered time 0. None min_speed float The minimum speed of the worm for a cycle to be considered. 
None max_speed float The maximum speed of the worm for a cycle to be considered. None View Source class WeightEvaluator : \"\"\" Class for evaluating the mean absolute error (MAE) of a polynomial fit with given weights. Args: csv_paths (list[str]): The paths to the csv files with the worm data. timing_config (TimingConfig): The timing configuration of the simulation. input_time_offsets (np.ndarray): The time offsets for the input positions. These offsets are measured from the beginning of the current cycle, which is considered time 0. pred_time_offset (int): The time offset for the target position, measured from the beginning of the current cycle, which is considered time 0. min_speed (float, optional): The minimum speed of the worm for a cycle to be considered. max_speed (float, optional): The maximum speed of the worm for a cycle to be considered. \"\"\" def __init__ ( self , csv_paths : list [ str ] , timing_config : TimingConfig , input_time_offsets : np . ndarray , pred_time_offset : int , min_speed : float = 0 , max_speed : float = np . inf , ) : self . csv_paths = csv_paths self . timing_config = timing_config self . pred_time_offset = pred_time_offset self . min_speed = min_speed self . max_speed = max_speed self . input_time_offsets = np . sort ( input_time_offsets ) self . _construct_dataset () def _construct_dataset ( self ) -> None : input_positions = [] target_positions = [] for i , path in enumerate ( self . csv_paths ) : bboxes = pd . read_csv ( path , usecols =[ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ). to_numpy ( dtype = float ) input_pos , target_pos = self . _extract_positions ( bboxes , self . timing_config . cycle_frame_num ) input_positions . append ( input_pos ) target_positions . append ( target_pos ) # print stats init_num_cycles = len ( bboxes ) // self . timing_config . cycle_frame_num final_num_cycles = len ( target_pos ) // 2 removed_percent = round (( init_num_cycles - final_num_cycles ) / init_num_cycles * 100 , 1 ) print ( f \"Log {i} :: Number of evaluation cycles: {final_num_cycles}\" ) print ( f \"Log {i} :: Number of cycles removed: {init_num_cycles - final_num_cycles} ({removed_percent} %)\" ) self . y_input = np . concatenate ( input_positions , axis = 1 ) self . x_input = self . input_time_offsets . reshape ( - 1 ) self . y_target = np . concatenate ( target_positions , axis = 0 ) self . x_target = np . full_like ( self . y_target , self . pred_time_offset ) def _extract_positions ( self , raw_bboxes : pd . DataFrame , cycle_length : int ) -> tuple [ np.ndarray, np.ndarray ] : N = self . input_time_offsets . shape [ 0 ] cycle_starts = np . arange ( 0 , raw_bboxes . shape [ 0 ] , cycle_length , dtype = int ) centers = BoxUtils . center ( raw_bboxes ) # x are times , y are positions # create input and target arrays for the times x_input = np . repeat ( cycle_starts , repeats = N ) + np . tile ( self . input_time_offsets , reps = cycle_starts . shape [ 0 ] ) x_input = x_input . reshape ( - 1 , N ) x_target = cycle_starts + self . pred_time_offset # remove input and target cycles with invalid time # i . e . when input time is negative or target time is out of bounds mask = ( x_input >= 0 ). all ( axis = 1 ) & ( x_target < len ( centers )) x_input = x_input [ mask, : ] x_target = x_target [ mask ] # get input and target positions for each cycle y_input = centers [ x_input.flatten(), : ] y_input = y_input . 
reshape ( - 1 , N , 2 ) y_target = centers [ x_target.flatten(), : ] y_target = y_target . reshape ( - 1 , 2 ) # remove all cycles with invalid positions input_mask = np . isfinite ( y_input ). all ( axis = ( 1 , 2 )) target_mask = np . isfinite ( y_target ). all ( axis = 1 ) mask = input_mask & target_mask y_input = y_input [ mask, :, : ] y_target = y_target [ mask, : ] # remove cycles with average speed below threshold # dist = np . sqrt (( y_target [ :, 1 ] - y_input [ :, 0, 1 ] ) ** 2 + ( y_target [ :, 0 ] - y_input [ :, 0, 0 ] ) ** 2 ) dist = np . linalg . norm ( y_target - y_input [ :, 0, : ] , axis = 1 ) time = self . pred_time_offset - self . input_time_offsets [ 0 ] speed = dist / time speed_mask = ( speed >= self . min_speed ) & ( speed <= self . max_speed ) y_input = y_input [ speed_mask, :, : ] y_target = y_target [ speed_mask, : ] # reshape target arrays y_input = y_input . swapaxes ( 0 , 1 ). reshape ( N , - 1 ) y_target = y_target . reshape ( - 1 ) return y_input , y_target def _polyval ( self , coeffs : np . ndarray , x : np . ndarray ) -> np . ndarray : \"\"\" Evaluate a polynomial at given values. This implementation is significantly faster than np.polyval for multiple polynomials. Args: coeffs (np.ndarray): Coefficients of the polynomial. Coefficients at increasing order. Should have shape [deg+1, N]. x (np.ndarray): Values at which to evaluate the polynomial. Should have shape [N]. Returns: np.ndarray: The result of evaluating the polynomial at the given values. Shape is [N]. \"\"\" coeffs = coeffs . swapaxes ( 0 , 1 ) van = np . vander ( x , N = coeffs . shape [ 1 ] , increasing = True ) return np . sum ( van * coeffs , axis =- 1 ) def eval ( self , weights : np . ndarray , deg : int = 2 ) -> float : \"\"\" Evaluate the mean absolute error (MAE) of the polynomial fit. Args: weights (np.ndarray): The weights used for the polynomial fit. Should have shape [N]. deg (int, optional): The degree of the polynomial fit. Returns: float: The mean absolute error (MAE) of the polynomial fit. \"\"\" coeffs = poly . polyfit ( self . x_input , self . y_input , deg = deg , w = weights ) y_pred = self . _polyval ( coeffs , self . x_target ) mae = np . mean ( np . abs ( self . y_target - y_pred )) return mae","title":"Attributes"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#methods_2","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#eval","text":"def eval ( self , weights : numpy . ndarray , deg : int = 2 ) -> float Evaluate the mean absolute error (MAE) of the polynomial fit. Parameters: Name Type Description Default weights np.ndarray The weights used for the polynomial fit. Should have shape [N]. None deg int The degree of the polynomial fit. None Returns: Type Description float The mean absolute error (MAE) of the polynomial fit. View Source def eval ( self , weights : np . ndarray , deg : int = 2 ) -> float : \"\"\" Evaluate the mean absolute error (MAE) of the polynomial fit. Args: weights (np.ndarray): The weights used for the polynomial fit. Should have shape [N]. deg (int, optional): The degree of the polynomial fit. Returns: float: The mean absolute error (MAE) of the polynomial fit. \"\"\" coeffs = poly . polyfit ( self . x_input , self . y_input , deg = deg , w = weights ) y_pred = self . _polyval ( coeffs , self . x_target ) mae = np . mean ( np . abs ( self . 
y_target - y_pred )) return mae","title":"eval"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/","text":"Module wtracker.sim.sim_controllers.yolo_controller View Source from typing import Collection , Any from dataclasses import dataclass , field import numpy as np import cv2 as cv from collections import deque from ultralytics import YOLO from wtracker.sim.simulator import Simulator , SimController from wtracker.sim.config import TimingConfig from wtracker.utils.config_base import ConfigBase from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat @dataclass class YoloConfig ( ConfigBase ): model_path : str \"\"\"The path to the pretrained YOLO weights file.\"\"\" device : str = \"cpu\" \"\"\"Inference device for YOLO. Can be either 'cpu' or 'cuda'.\"\"\" verbose : bool = False \"\"\"Whether to print verbose output during YOLO inference.\"\"\" pred_kwargs : dict = field ( default_factory = lambda : { \"imgsz\" : 384 , \"conf\" : 0.1 , } ) \"\"\"Additional keyword arguments for the YOLO prediction method.\"\"\" model : YOLO = field ( default = None , init = False , repr = False ) \"\"\"The YOLO model object.\"\"\" def __getstate__ ( self ) -> dict [ str , Any ]: state = self . __dict__ . copy () del state [ \"model\" ] # we don't want to serialize the model return state def load_model ( self ) -> YOLO : if self . model is None : self . model = YOLO ( self . model_path , task = \"detect\" , verbose = self . verbose ) return self . model class YoloController ( SimController ): def __init__ ( self , timing_config : TimingConfig , yolo_config : YoloConfig ): super () . __init__ ( timing_config ) self . yolo_config = yolo_config self . _camera_frames = deque ( maxlen = timing_config . cycle_frame_num ) self . _model = yolo_config . load_model () def on_sim_start ( self , sim : Simulator ): self . _camera_frames . clear () def on_camera_frame ( self , sim : Simulator ): self . _camera_frames . append ( sim . camera_view ()) def on_cycle_end ( self , sim : Simulator ): self . _camera_frames . clear () def predict ( self , frames : Collection [ np . ndarray ]) -> np . ndarray : assert len ( frames ) > 0 # convert grayscale images to BGR because YOLO expects 3-channel images if frames [ 0 ] . ndim == 2 : frames = [ cv . cvtColor ( frame , cv . COLOR_GRAY2BGR ) for frame in frames ] # predict bounding boxes and format results results = self . _model . predict ( source = frames , device = self . yolo_config . device , max_det = 1 , verbose = self . yolo_config . verbose , ** self . yolo_config . pred_kwargs , ) results = [ res . numpy () for res in results ] bboxes = [] for res in results : if len ( res . boxes . xyxy ) == 0 : bboxes . append ( np . full ([ 4 ], np . nan )) else : bbox = BoxConverter . to_xywh ( res . boxes . xyxy [ 0 ], BoxFormat . XYXY ) bboxes . append ( bbox ) return np . stack ( bboxes , axis = 0 ) def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: frame = self . _camera_frames [ - self . timing_config . pred_frame_num ] bbox = self . predict ([ frame ]) bbox = bbox [ 0 ] if not np . isfinite ( bbox ) . all (): return 0 , 0 bbox_mid = bbox [ 0 ] + bbox [ 2 ] / 2 , bbox [ 1 ] + bbox [ 3 ] / 2 camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 return round ( bbox_mid [ 0 ] - camera_mid [ 0 ]), round ( bbox_mid [ 1 ] - camera_mid [ 1 ]) def _cycle_predict_all ( self , sim : Simulator ) -> np . 
ndarray : return self . predict ( self . _camera_frames ) Classes YoloConfig class YoloConfig ( model_path : str , device : str = 'cpu' , verbose : bool = False , pred_kwargs : dict = < factory > ) YoloConfig(model_path: str, device: str = 'cpu', verbose: bool = False, pred_kwargs: dict = ) View Source @ dataclass class YoloConfig ( ConfigBase ): model_path : str \"\"\"The path to the pretrained YOLO weights file.\"\"\" device : str = \"cpu\" \"\"\"Inference device for YOLO. Can be either 'cpu' or 'cuda'.\"\"\" verbose : bool = False \"\"\"Whether to print verbose output during YOLO inference.\"\"\" pred_kwargs : dict = field ( default_factory = lambda : { \"imgsz\" : 384 , \"conf\" : 0.1 , } ) \"\"\"Additional keyword arguments for the YOLO prediction method.\"\"\" model : YOLO = field ( default = None , init = False , repr = False ) \"\"\"The YOLO model object.\"\"\" def __getstate__ ( self ) -> dict [ str , Any ]: state = self . __dict__ . copy () del state [ \"model\" ] # we don't want to serialize the model return state def load_model ( self ) -> YOLO : if self . model is None : self . model = YOLO ( self . model_path , task = \"detect\" , verbose = self . verbose ) return self . model Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Class variables device model verbose Static methods load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods load_model def load_model ( self ) -> ultralytics . models . yolo . model . YOLO View Source def load_model ( self ) -> YOLO : if self . model is None : self . model = YOLO ( self . model_path , task = \"detect\" , verbose = self . verbose ) return self . model save_json def save_json ( self , path : 'str' = None ) Saves the class as a JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as a JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . 
__dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) YoloController class YoloController ( timing_config : wtracker . sim . config . TimingConfig , yolo_config : wtracker . sim . sim_controllers . yolo_controller . YoloConfig ) Abstract base class for simulator controllers. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class YoloController ( SimController ) : def __init__ ( self , timing_config : TimingConfig , yolo_config : YoloConfig ) : super (). __init__ ( timing_config ) self . yolo_config = yolo_config self . _camera_frames = deque ( maxlen = timing_config . cycle_frame_num ) self . _model = yolo_config . load_model () def on_sim_start ( self , sim : Simulator ) : self . _camera_frames . clear () def on_camera_frame ( self , sim : Simulator ) : self . _camera_frames . append ( sim . camera_view ()) def on_cycle_end ( self , sim : Simulator ) : self . _camera_frames . clear () def predict ( self , frames : Collection [ np.ndarray ] ) -> np . ndarray : assert len ( frames ) > 0 # convert grayscale images to BGR because YOLO expects 3 - channel images if frames [ 0 ] . ndim == 2 : frames = [ cv.cvtColor(frame, cv.COLOR_GRAY2BGR) for frame in frames ] # predict bounding boxes and format results results = self . _model . predict ( source = frames , device = self . yolo_config . device , max_det = 1 , verbose = self . yolo_config . verbose , ** self . yolo_config . pred_kwargs , ) results = [ res.numpy() for res in results ] bboxes = [] for res in results : if len ( res . boxes . xyxy ) == 0 : bboxes . append ( np . full ( [ 4 ] , np . nan )) else : bbox = BoxConverter . to_xywh ( res . boxes . xyxy [ 0 ] , BoxFormat . XYXY ) bboxes . append ( bbox ) return np . stack ( bboxes , axis = 0 ) def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : frame = self . _camera_frames [ -self.timing_config.pred_frame_num ] bbox = self . predict ( [ frame ] ) bbox = bbox [ 0 ] if not np . isfinite ( bbox ). all () : return 0 , 0 bbox_mid = bbox [ 0 ] + bbox [ 2 ] / 2 , bbox [ 1 ] + bbox [ 3 ] / 2 camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 return round ( bbox_mid [ 0 ] - camera_mid [ 0 ] ), round ( bbox_mid [ 1 ] - camera_mid [ 1 ] ) def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : return self . predict ( self . _camera_frames ) Ancestors (in MRO) wtracker.sim.simulator.SimController abc.ABC Methods begin_movement_prediction def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass on_camera_frame def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. 
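Since on_camera_frame simply buffers the most recent camera views for later prediction, a short construction sketch may help; it is illustrative only, and the paths 'yolo_weights.pt' and 'timing_config.json' are placeholders rather than files shipped with the library.

```python
from wtracker.sim.config import TimingConfig
from wtracker.sim.sim_controllers.yolo_controller import YoloConfig, YoloController

# Placeholder paths; device may be 'cpu' or 'cuda'.
yolo_config = YoloConfig(model_path="yolo_weights.pt", device="cpu")
timing_config = TimingConfig.load_json("timing_config.json")  # assumes ConfigBase.load_json

controller = YoloController(timing_config, yolo_config)
# During a simulation, on_camera_frame() buffers each camera view, and
# provide_movement_vector() runs YOLO on a recent frame to compute (dx, dy).
```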
View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_frames . append ( sim . camera_view ()) on_cycle_end def on_cycle_end ( self , sim : wtracker . sim . simulator . Simulator ) Called when a cycle ends. View Source def on_cycle_end ( self , sim : Simulator ) : self . _camera_frames . clear () on_cycle_start def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass on_imaging_end def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass on_imaging_start def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass on_micro_frame def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass on_movement_end def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass on_movement_start def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass on_sim_end def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass on_sim_start def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_frames . clear () predict def predict ( self , frames : Collection [ numpy . ndarray ] ) -> numpy . ndarray View Source def predict(self, frames: Collection[np.ndarray]) -> np.ndarray: assert len(frames) > 0 # convert grayscale images to BGR because YOLO expects 3-channel images if frames[0].ndim == 2: frames = [cv.cvtColor(frame, cv.COLOR_GRAY2BGR) for frame in frames] # predict bounding boxes and format results results = self._model.predict( source=frames, device=self.yolo_config.device, max_det=1, verbose=self.yolo_config.verbose, **self.yolo_config.pred_kwargs, ) results = [res.numpy() for res in results] bboxes = [] for res in results: if len(res.boxes.xyxy) == 0: bboxes.append(np.full([4], np.nan)) else: bbox = BoxConverter.to_xywh(res.boxes.xyxy[0], BoxFormat.XYXY) bboxes.append(bbox) return np.stack(bboxes, axis=0) provide_movement_vector def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : frame = self . _camera_frames [ -self.timing_config.pred_frame_num ] bbox = self . predict ( [ frame ] ) bbox = bbox [ 0 ] if not np . isfinite ( bbox ). 
all () : return 0 , 0 bbox_mid = bbox [ 0 ] + bbox [ 2 ] / 2 , bbox [ 1 ] + bbox [ 3 ] / 2 camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 return round ( bbox_mid [ 0 ] - camera_mid [ 0 ] ), round ( bbox_mid [ 1 ] - camera_mid [ 1 ] )","title":"Yolo Controller"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#module-wtrackersimsim_controllersyolo_controller","text":"View Source from typing import Collection , Any from dataclasses import dataclass , field import numpy as np import cv2 as cv from collections import deque from ultralytics import YOLO from wtracker.sim.simulator import Simulator , SimController from wtracker.sim.config import TimingConfig from wtracker.utils.config_base import ConfigBase from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat @dataclass class YoloConfig ( ConfigBase ): model_path : str \"\"\"The path to the pretrained YOLO weights file.\"\"\" device : str = \"cpu\" \"\"\"Inference device for YOLO. Can be either 'cpu' or 'cuda'.\"\"\" verbose : bool = False \"\"\"Whether to print verbose output during YOLO inference.\"\"\" pred_kwargs : dict = field ( default_factory = lambda : { \"imgsz\" : 384 , \"conf\" : 0.1 , } ) \"\"\"Additional keyword arguments for the YOLO prediction method.\"\"\" model : YOLO = field ( default = None , init = False , repr = False ) \"\"\"The YOLO model object.\"\"\" def __getstate__ ( self ) -> dict [ str , Any ]: state = self . __dict__ . copy () del state [ \"model\" ] # we don't want to serialize the model return state def load_model ( self ) -> YOLO : if self . model is None : self . model = YOLO ( self . model_path , task = \"detect\" , verbose = self . verbose ) return self . model class YoloController ( SimController ): def __init__ ( self , timing_config : TimingConfig , yolo_config : YoloConfig ): super () . __init__ ( timing_config ) self . yolo_config = yolo_config self . _camera_frames = deque ( maxlen = timing_config . cycle_frame_num ) self . _model = yolo_config . load_model () def on_sim_start ( self , sim : Simulator ): self . _camera_frames . clear () def on_camera_frame ( self , sim : Simulator ): self . _camera_frames . append ( sim . camera_view ()) def on_cycle_end ( self , sim : Simulator ): self . _camera_frames . clear () def predict ( self , frames : Collection [ np . ndarray ]) -> np . ndarray : assert len ( frames ) > 0 # convert grayscale images to BGR because YOLO expects 3-channel images if frames [ 0 ] . ndim == 2 : frames = [ cv . cvtColor ( frame , cv . COLOR_GRAY2BGR ) for frame in frames ] # predict bounding boxes and format results results = self . _model . predict ( source = frames , device = self . yolo_config . device , max_det = 1 , verbose = self . yolo_config . verbose , ** self . yolo_config . pred_kwargs , ) results = [ res . numpy () for res in results ] bboxes = [] for res in results : if len ( res . boxes . xyxy ) == 0 : bboxes . append ( np . full ([ 4 ], np . nan )) else : bbox = BoxConverter . to_xywh ( res . boxes . xyxy [ 0 ], BoxFormat . XYXY ) bboxes . append ( bbox ) return np . stack ( bboxes , axis = 0 ) def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: frame = self . _camera_frames [ - self . timing_config . pred_frame_num ] bbox = self . predict ([ frame ]) bbox = bbox [ 0 ] if not np . isfinite ( bbox ) . 
all (): return 0 , 0 bbox_mid = bbox [ 0 ] + bbox [ 2 ] / 2 , bbox [ 1 ] + bbox [ 3 ] / 2 camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 return round ( bbox_mid [ 0 ] - camera_mid [ 0 ]), round ( bbox_mid [ 1 ] - camera_mid [ 1 ]) def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : return self . predict ( self . _camera_frames )","title":"Module wtracker.sim.sim_controllers.yolo_controller"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#yoloconfig","text":"class YoloConfig ( model_path : str , device : str = 'cpu' , verbose : bool = False , pred_kwargs : dict = < factory > ) YoloConfig(model_path: str, device: str = 'cpu', verbose: bool = False, pred_kwargs: dict = ) View Source @ dataclass class YoloConfig ( ConfigBase ): model_path : str \"\"\"The path to the pretrained YOLO weights file.\"\"\" device : str = \"cpu\" \"\"\"Inference device for YOLO. Can be either 'cpu' or 'cuda'.\"\"\" verbose : bool = False \"\"\"Whether to print verbose output during YOLO inference.\"\"\" pred_kwargs : dict = field ( default_factory = lambda : { \"imgsz\" : 384 , \"conf\" : 0.1 , } ) \"\"\"Additional keyword arguments for the YOLO prediction method.\"\"\" model : YOLO = field ( default = None , init = False , repr = False ) \"\"\"The YOLO model object.\"\"\" def __getstate__ ( self ) -> dict [ str , Any ]: state = self . __dict__ . copy () del state [ \"model\" ] # we don't want to serialize the model return state def load_model ( self ) -> YOLO : if self . model is None : self . model = YOLO ( self . model_path , task = \"detect\" , verbose = self . verbose ) return self . model","title":"YoloConfig"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#ancestors-in-mro","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#class-variables","text":"device model verbose","title":"Class variables"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#load_json","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#load_pickle","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. 
Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#load_model","text":"def load_model ( self ) -> ultralytics . models . yolo . model . YOLO View Source def load_model ( self ) -> YOLO : if self . model is None : self . model = YOLO ( self . model_path , task = \"detect\" , verbose = self . verbose ) return self . model","title":"load_model"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#save_json","text":"def save_json ( self , path : 'str' = None ) Saves the class as a JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as a JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#save_pickle","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#yolocontroller","text":"class YoloController ( timing_config : wtracker . sim . config . TimingConfig , yolo_config : wtracker . sim . sim_controllers . yolo_controller . YoloConfig ) Abstract base class for simulator controllers.","title":"YoloController"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class YoloController ( SimController ) : def __init__ ( self , timing_config : TimingConfig , yolo_config : YoloConfig ) : super (). __init__ ( timing_config ) self . yolo_config = yolo_config self . _camera_frames = deque ( maxlen = timing_config . cycle_frame_num ) self . _model = yolo_config . load_model () def on_sim_start ( self , sim : Simulator ) : self . _camera_frames . clear () def on_camera_frame ( self , sim : Simulator ) : self . _camera_frames . append ( sim . camera_view ()) def on_cycle_end ( self , sim : Simulator ) : self . _camera_frames . clear () def predict ( self , frames : Collection [ np.ndarray ] ) -> np . ndarray : assert len ( frames ) > 0 # convert grayscale images to BGR because YOLO expects 3 - channel images if frames [ 0 ] . 
ndim == 2 : frames = [ cv.cvtColor(frame, cv.COLOR_GRAY2BGR) for frame in frames ] # predict bounding boxes and format results results = self . _model . predict ( source = frames , device = self . yolo_config . device , max_det = 1 , verbose = self . yolo_config . verbose , ** self . yolo_config . pred_kwargs , ) results = [ res.numpy() for res in results ] bboxes = [] for res in results : if len ( res . boxes . xyxy ) == 0 : bboxes . append ( np . full ( [ 4 ] , np . nan )) else : bbox = BoxConverter . to_xywh ( res . boxes . xyxy [ 0 ] , BoxFormat . XYXY ) bboxes . append ( bbox ) return np . stack ( bboxes , axis = 0 ) def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : frame = self . _camera_frames [ -self.timing_config.pred_frame_num ] bbox = self . predict ( [ frame ] ) bbox = bbox [ 0 ] if not np . isfinite ( bbox ). all () : return 0 , 0 bbox_mid = bbox [ 0 ] + bbox [ 2 ] / 2 , bbox [ 1 ] + bbox [ 3 ] / 2 camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 return round ( bbox_mid [ 0 ] - camera_mid [ 0 ] ), round ( bbox_mid [ 1 ] - camera_mid [ 1 ] ) def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : return self . predict ( self . _camera_frames )","title":"Attributes"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#ancestors-in-mro_1","text":"wtracker.sim.simulator.SimController abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#begin_movement_prediction","text":"def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass","title":"begin_movement_prediction"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_camera_frame","text":"def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_frames . append ( sim . camera_view ())","title":"on_camera_frame"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_cycle_end","text":"def on_cycle_end ( self , sim : wtracker . sim . simulator . Simulator ) Called when a cycle ends. View Source def on_cycle_end ( self , sim : Simulator ) : self . _camera_frames . clear ()","title":"on_cycle_end"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_cycle_start","text":"def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass","title":"on_cycle_start"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_imaging_end","text":"def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass","title":"on_imaging_end"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_imaging_start","text":"def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. 
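To make the hook lifecycle concrete, here is a small, hypothetical controller that merely traces the phases it passes through. It is a sketch only: depending on which methods SimController declares abstract (for example _cycle_predict_all), additional overrides may be required before it can be instantiated.

```python
from wtracker.sim.simulator import Simulator, SimController

class TracingController(SimController):
    """Hypothetical controller that logs simulation phases as they occur."""

    def on_imaging_start(self, sim: Simulator):
        print("imaging phase started")

    def on_imaging_end(self, sim: Simulator):
        print("imaging phase ended")

    def begin_movement_prediction(self, sim: Simulator) -> None:
        pass  # no state is needed for this trivial controller

    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        return 0, 0  # keep the platform stationary
```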
View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass","title":"on_imaging_start"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_micro_frame","text":"def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass","title":"on_micro_frame"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_movement_end","text":"def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass","title":"on_movement_end"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_movement_start","text":"def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass","title":"on_movement_start"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_sim_end","text":"def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass","title":"on_sim_end"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_sim_start","text":"def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_frames . clear ()","title":"on_sim_start"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#predict","text":"def predict ( self , frames : Collection [ numpy . ndarray ] ) -> numpy . ndarray View Source def predict(self, frames: Collection[np.ndarray]) -> np.ndarray: assert len(frames) > 0 # convert grayscale images to BGR because YOLO expects 3-channel images if frames[0].ndim == 2: frames = [cv.cvtColor(frame, cv.COLOR_GRAY2BGR) for frame in frames] # predict bounding boxes and format results results = self._model.predict( source=frames, device=self.yolo_config.device, max_det=1, verbose=self.yolo_config.verbose, **self.yolo_config.pred_kwargs, ) results = [res.numpy() for res in results] bboxes = [] for res in results: if len(res.boxes.xyxy) == 0: bboxes.append(np.full([4], np.nan)) else: bbox = BoxConverter.to_xywh(res.boxes.xyxy[0], BoxFormat.XYXY) bboxes.append(bbox) return np.stack(bboxes, axis=0)","title":"predict"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#provide_movement_vector","text":"def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : frame = self . _camera_frames [ -self.timing_config.pred_frame_num ] bbox = self . predict ( [ frame ] ) bbox = bbox [ 0 ] if not np . isfinite ( bbox ). all () : return 0 , 0 bbox_mid = bbox [ 0 ] + bbox [ 2 ] / 2 , bbox [ 1 ] + bbox [ 3 ] / 2 camera_mid = sim . view . 
camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 return round ( bbox_mid [ 0 ] - camera_mid [ 0 ] ), round ( bbox_mid [ 1 ] - camera_mid [ 1 ] )","title":"provide_movement_vector"},{"location":"reference/wtracker/utils/","text":"Namespace wtracker.utils Sub-modules wtracker.utils.bbox_utils wtracker.utils.config_base wtracker.utils.frame_reader wtracker.utils.gui_utils wtracker.utils.io_utils wtracker.utils.log_utils wtracker.utils.path_utils wtracker.utils.threading_utils","title":"Index"},{"location":"reference/wtracker/utils/#namespace-wtrackerutils","text":"","title":"Namespace wtracker.utils"},{"location":"reference/wtracker/utils/#sub-modules","text":"wtracker.utils.bbox_utils wtracker.utils.config_base wtracker.utils.frame_reader wtracker.utils.gui_utils wtracker.utils.io_utils wtracker.utils.log_utils wtracker.utils.path_utils wtracker.utils.threading_utils","title":"Sub-modules"},{"location":"reference/wtracker/utils/bbox_utils/","text":"Module wtracker.utils.bbox_utils View Source import numpy as np from enum import Enum class BoxFormat ( Enum ): \"\"\" Enumeration representing different box formats. Attributes: XYWH (int): Represents the box format as (x, y, width, height). XYXY (int): Represents the box format as (x1, y1, x2, y2). YOLO (int): Represents the box format as (center_x, center_y, width, height). \"\"\" XYWH = 0 XYXY = 1 YOLO = 2 class BoxUtils : \"\"\" A utility class for working with bounding boxes. \"\"\" @staticmethod def is_bbox ( array : np . ndarray ) -> bool : \"\"\" Check if the given array is a valid bounding box. Args: array (np.ndarray): The array to check. Returns: bool: True if the array is a valid bounding box, False otherwise. \"\"\" return array . shape [ - 1 ] == 4 @staticmethod def unpack ( bbox : np . ndarray ) -> tuple [ np . ndarray , np . ndarray , np . ndarray , np . ndarray ]: \"\"\" Unpack the given bounding box into its individual components. Args: bbox (np.ndarray): The bounding box to unpack. Returns: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: The unpacked components of the bounding box. \"\"\" c1 , c2 , c3 , c4 = np . split ( bbox , bbox . shape [ - 1 ], axis =- 1 ) c1 = np . squeeze ( c1 , axis =- 1 ) c2 = np . squeeze ( c2 , axis =- 1 ) c3 = np . squeeze ( c3 , axis =- 1 ) c4 = np . squeeze ( c4 , axis =- 1 ) return c1 , c2 , c3 , c4 @staticmethod def pack ( c1 : np . ndarray , c2 : np . ndarray , c3 : np . ndarray , c4 : np . ndarray ) -> np . ndarray : \"\"\" Pack the given components into a single bounding box. Args: c1 (np.ndarray): The first component of the bounding box. c2 (np.ndarray): The second component of the bounding box. c3 (np.ndarray): The third component of the bounding box. c4 (np.ndarray): The fourth component of the bounding box. Returns: np.ndarray: The packed bounding box. \"\"\" c1 = np . expand_dims ( c1 , axis =- 1 ) c2 = np . expand_dims ( c2 , axis =- 1 ) c3 = np . expand_dims ( c3 , axis =- 1 ) c4 = np . expand_dims ( c4 , axis =- 1 ) return np . concatenate (( c1 , c2 , c3 , c4 ), axis =- 1 ) @staticmethod def center ( bboxes : np . ndarray , box_format : BoxFormat = BoxFormat . XYWH ) -> np . ndarray : \"\"\" Calculate the center of the bounding boxes. Args: bboxes (np.ndarray): The input bounding boxes. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The center of the bounding boxes, in the format (center_x, center_y). \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYWH ) x , y , w , h = BoxUtils . 
unpack ( bboxes ) center_x = x + w / 2 center_y = y + h / 2 return np . array ([ center_x , center_y ]) . T @staticmethod def round ( bboxes : np . ndarray , box_format : BoxFormat ) -> np . ndarray : \"\"\" Rounds the bounding box coordinates to integers. Args: bboxes (np.ndarray): The bounding box coordinates to convert. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The bounding box coordinates as integers. \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) x1 = np . floor ( x1 ) . astype ( np . int32 , copy = False ) y1 = np . floor ( y1 ) . astype ( np . int32 , copy = False ) x2 = np . ceil ( x2 ) . astype ( np . int32 , copy = False ) y2 = np . ceil ( y2 ) . astype ( np . int32 , copy = False ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) return BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) @staticmethod def discretize ( bboxes : np . ndarray , bounds : tuple [ int , int ], box_format : BoxFormat , ) -> tuple [ np . ndarray , np . ndarray ]: \"\"\" Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Args: bboxes (np.ndarray): The bounding box coordinates to convert. bounds (tuple[int, int]): The bounds to clamp the bounding boxes to, in the format (h, w). box_format (BoxFormat): The format of the input bounding boxes. Returns: tuple[np.ndarray, np.ndarray]: The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. The second element is a boolean mask indicating which input bounding boxes are legal. \"\"\" # zero out all non-finite bounding boxes is_legal = np . isfinite ( bboxes ) . all ( axis = 1 ) bboxes [ ~ is_legal ] = 0 bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) bboxes = BoxUtils . round ( bboxes , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) # clip worm bounding boxes to the size H , W = bounds x1 = np . clip ( x1 , a_min = 0 , a_max = W ) y1 = np . clip ( y1 , a_min = 0 , a_max = H ) x2 = np . clip ( x2 , a_min = 0 , a_max = W ) y2 = np . clip ( y2 , a_min = 0 , a_max = H ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) bboxes = BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) # zero out all bounding boxes with 0 dimension w = x2 - x1 h = y2 - y1 is_legal = ( w > 0.0 ) & ( h > 0.0 ) # zero out all illegal bounding boxes and make sure return types are correct bboxes [ ~ is_legal ] = 0 bboxes = bboxes . astype ( np . int32 , copy = False ) is_legal = is_legal . astype ( bool , copy = False ) return bboxes , is_legal class BoxConverter : \"\"\" Utility class for converting bounding box coordinates between different formats. \"\"\" @staticmethod def change_format ( bbox : np . ndarray , src_format : BoxFormat , dst_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates from one format to another. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. dst_format (BoxFormat): The destination format of the bounding box coordinates. Returns: np.ndarray: The converted bounding box coordinates. 
Raises: Exception: If the conversion between the specified formats is not supported. \"\"\" if dst_format == BoxFormat . XYXY : return BoxConverter . to_xyxy ( bbox , src_format ) elif dst_format == BoxFormat . XYWH : return BoxConverter . to_xywh ( bbox , src_format ) elif dst_format == BoxFormat . YOLO : return BoxConverter . to_yolo ( bbox , src_format ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xyxy ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYXY format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYXY format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYXY : return bbox elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xywh ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYWH format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYWH format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYWH : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 return BoxUtils . pack ( x1 , y1 , w , h ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 return BoxUtils . pack ( x1 , y1 , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_yolo ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the YOLO format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the YOLO format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . YOLO : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) Classes BoxConverter class BoxConverter ( / , * args , ** kwargs ) Utility class for converting bounding box coordinates between different formats. View Source class BoxConverter : \"\"\" Utility class for converting bounding box coordinates between different formats. \"\"\" @staticmethod def change_format ( bbox : np .
ndarray , src_format : BoxFormat , dst_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates from one format to another. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. dst_format (BoxFormat): The destination format of the bounding box coordinates. Returns: np.ndarray: The converted bounding box coordinates. Raises: Exception: If the conversion between the specified formats is not supported. \"\"\" if dst_format == BoxFormat . XYXY : return BoxConverter . to_xyxy ( bbox , src_format ) elif dst_format == BoxFormat . XYWH : return BoxConverter . to_xywh ( bbox , src_format ) elif dst_format == BoxFormat . YOLO : return BoxConverter . to_yolo ( bbox , src_format ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xyxy ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYXY format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYXY format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYXY : return bbox elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xywh ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYWH format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYWH format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYWH : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 return BoxUtils . pack ( x1 , y1 , w , h ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 return BoxUtils . pack ( x1 , y1 , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_yolo ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the YOLO format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the YOLO format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . YOLO : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils .
pack ( xm , ym , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) Static methods change_format def change_format ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat , dst_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates from one format to another. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. None src_format BoxFormat The source format of the bounding box coordinates. None dst_format BoxFormat The destination format of the bounding box coordinates. None Returns: Type Description np.ndarray The converted bounding box coordinates. Raises: Type Description Exception If the conversion between the specified formats is not supported. View Source @staticmethod def change_format ( bbox : np . ndarray , src_format : BoxFormat , dst_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates from one format to another. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. dst_format (BoxFormat): The destination format of the bounding box coordinates. Returns: np.ndarray: The converted bounding box coordinates. Raises: Exception: If the conversion between the specified formats is not supported. \"\"\" if dst_format == BoxFormat . XYXY : return BoxConverter . to_xyxy ( bbox , src_format ) elif dst_format == BoxFormat . XYWH : return BoxConverter . to_xywh ( bbox , src_format ) elif dst_format == BoxFormat . YOLO : return BoxConverter . to_yolo ( bbox , src_format ) else : raise Exception ( \"unsupported bbox format conversion.\" ) to_xywh def to_xywh ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates to the XYWH format. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. None src_format BoxFormat The source format of the bounding box coordinates. None Returns: Type Description np.ndarray The bounding box coordinates in the XYWH format. Raises: Type Description Exception If the conversion from the specified source format is not supported. View Source @staticmethod def to_xywh ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYWH format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYWH format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYWH : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 return BoxUtils . pack ( x1 , y1 , w , h ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 return BoxUtils . pack ( x1 , y1 , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) to_xyxy def to_xyxy ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates to the XYXY format. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted.
None src_format BoxFormat The source format of the bounding box coordinates. None Returns: Type Description np.ndarray The bounding box coordinates in the XYXY format. Raises: Type Description Exception If the conversion from the specified source format is not supported. View Source @staticmethod def to_xyxy ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYXY format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYXY format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYXY : return bbox elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) else : raise Exception ( \"unsupported bbox format conversion.\" ) to_yolo def to_yolo ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates to the YOLO format. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. None src_format BoxFormat The source format of the bounding box coordinates. None Returns: Type Description np.ndarray The bounding box coordinates in the YOLO format. Raises: Type Description Exception If the conversion from the specified source format is not supported. View Source @staticmethod def to_yolo ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the YOLO format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the YOLO format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . YOLO : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) BoxFormat class BoxFormat ( / , * args , ** kwargs ) Enumeration representing different box formats. Attributes Name Type Description Default XYWH int Represents the box format as (x, y, width, height). None XYXY int Represents the box format as (x1, y1, x2, y2). None YOLO int Represents the box format as (center_x, center_y, width, height). None View Source class BoxFormat ( Enum ): \"\"\" Enumeration representing different box formats. Attributes: XYWH (int): Represents the box format as (x, y, width, height). XYXY (int): Represents the box format as (x1, y1, x2, y2). YOLO (int): Represents the box format as (center_x, center_y, width, height). 
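Example (an illustrative sketch, not part of the library source): the same 40x20 box, whose top-left corner lies at (100, 50), expressed in each of the three formats.
import numpy as np
from wtracker.utils.bbox_utils import BoxConverter, BoxFormat

# one box in XYWH format: (x, y, width, height)
xywh = np.array([100.0, 50.0, 40.0, 20.0])
# the same box in the two other conventions
xyxy = BoxConverter.to_xyxy(xywh, BoxFormat.XYWH)  # [100, 50, 140, 70]
yolo = BoxConverter.to_yolo(xywh, BoxFormat.XYWH)  # [120, 60, 40, 20]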
\"\"\" XYWH = 0 XYXY = 1 YOLO = 2 Ancestors (in MRO) enum.Enum Class variables XYWH XYXY YOLO name value BoxUtils class BoxUtils ( / , * args , ** kwargs ) A utility class for working with bounding boxes. View Source class BoxUtils : \"\"\" A utility class for working with bounding boxes. \"\"\" @ staticmethod def is_bbox ( array : np . ndarray ) -> bool : \"\"\" Check if the given array is a valid bounding box. Args: array (np.ndarray): The array to check. Returns: bool: True if the array is a valid bounding box, False otherwise. \"\"\" return array . shape [ - 1 ] == 4 @ staticmethod def unpack ( bbox : np . ndarray ) -> tuple [ np . ndarray , np . ndarray , np . ndarray , np . ndarray ]: \"\"\" Unpack the given bounding box into its individual components. Args: bbox (np.ndarray): The bounding box to unpack. Returns: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: The unpacked components of the bounding box. \"\"\" c1 , c2 , c3 , c4 = np . split ( bbox , bbox . shape [ - 1 ], axis = - 1 ) c1 = np . squeeze ( c1 , axis = - 1 ) c2 = np . squeeze ( c2 , axis = - 1 ) c3 = np . squeeze ( c3 , axis = - 1 ) c4 = np . squeeze ( c4 , axis = - 1 ) return c1 , c2 , c3 , c4 @ staticmethod def pack ( c1 : np . ndarray , c2 : np . ndarray , c3 : np . ndarray , c4 : np . ndarray ) -> np . ndarray : \"\"\" Pack the given components into a single bounding box. Args: c1 (np.ndarray): The first component of the bounding box. c2 (np.ndarray): The second component of the bounding box. c3 (np.ndarray): The third component of the bounding box. c4 (np.ndarray): The fourth component of the bounding box. Returns: np.ndarray: The packed bounding box. \"\"\" c1 = np . expand_dims ( c1 , axis = - 1 ) c2 = np . expand_dims ( c2 , axis = - 1 ) c3 = np . expand_dims ( c3 , axis = - 1 ) c4 = np . expand_dims ( c4 , axis = - 1 ) return np . concatenate (( c1 , c2 , c3 , c4 ), axis = - 1 ) @ staticmethod def center ( bboxes : np . ndarray , box_format : BoxFormat = BoxFormat . XYWH ) -> np . ndarray : \"\"\" Calculate the center of the bounding boxes. Args: bboxes (np.ndarray): The input bounding boxes. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The center of the bounding boxes, in the format (center_x, center_y). \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYWH ) x , y , w , h = BoxUtils . unpack ( bboxes ) center_x = x + w / 2 center_y = y + h / 2 return np . array ([ center_x , center_y ]). T @ staticmethod def round ( bboxes : np . ndarray , box_format : BoxFormat ) -> np . ndarray : \"\"\" Rounds the bounding box coordinates to integers. Args: bboxes (np.ndarray): The bounding box coordinates to convert. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The bounding box coordinates as integers. \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) x1 = np . floor ( x1 ). astype ( np . int32 , copy = False ) y1 = np . floor ( y1 ). astype ( np . int32 , copy = False ) x2 = np . ceil ( x2 ). astype ( np . int32 , copy = False ) y2 = np . ceil ( y2 ). astype ( np . int32 , copy = False ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) return BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) @ staticmethod def discretize ( bboxes : np . ndarray , bounds : tuple [ int , int ], box_format : BoxFormat , ) -> tuple [ np . ndarray , np . 
ndarray ]: \"\"\" Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Args: bboxes (np.ndarray): The bounding box coordinates to convert. bounds (tuple[int, int]): The bounds to clamp the bounding boxes to, in the format (h, w). box_format (BoxFormat): The format of the input bounding boxes. Returns: tuple[np.ndarray, np.ndarray]: The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. The second element is a boolean mask indicating which input bounding boxes are legal. \"\"\" # zero out all non - finite bounding boxes is_legal = np . isfinite ( bboxes ). all ( axis = 1 ) bboxes [ ~ is_legal ] = 0 bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) bboxes = BoxUtils . round ( bboxes , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) # clip worm bounding boxes to the size H , W = bounds x1 = np . clip ( x1 , a_min = 0 , a_max = W ) y1 = np . clip ( y1 , a_min = 0 , a_max = H ) x2 = np . clip ( x2 , a_min = 0 , a_max = W ) y2 = np . clip ( y2 , a_min = 0 , a_max = H ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) bboxes = BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) # zero out all bounding boxes with 0 dimension w = x2 - x1 h = y2 - y1 is_legal = ( w > 0.0 ) & ( h > 0.0 ) # zero out all illegal bounding boxes and make sure return types are correct bboxes [ ~ is_legal ] = 0 bboxes = bboxes . astype ( np . int32 , copy = False ) is_legal = is_legal . astype ( bool , copy = False ) return bboxes , is_legal Static methods center def center ( bboxes : numpy . ndarray , box_format : wtracker . utils . bbox_utils . BoxFormat = < BoxFormat . XYWH : 0 > ) -> numpy . ndarray Calculate the center of the bounding boxes. Parameters: Name Type Description Default bboxes np.ndarray The input bounding boxes. None box_format BoxFormat The format of the input bounding boxes. None Returns: Type Description np.ndarray The center of the bounding boxes, in the format (center_x, center_y). View Source @staticmethod def center ( bboxes : np . ndarray , box_format : BoxFormat = BoxFormat . XYWH ) -> np . ndarray : \"\"\" Calculate the center of the bounding boxes. Args: bboxes (np.ndarray): The input bounding boxes. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The center of the bounding boxes, in the format (center_x, center_y). \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYWH ) x , y , w , h = BoxUtils . unpack ( bboxes ) center_x = x + w / 2 center_y = y + h / 2 return np . array ( [ center_x, center_y ] ). T discretize def discretize ( bboxes : numpy . ndarray , bounds : tuple [ int , int ], box_format : wtracker . utils . bbox_utils . BoxFormat ) -> tuple [ numpy . ndarray , numpy . ndarray ] Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Parameters: Name Type Description Default bboxes np.ndarray The bounding box coordinates to convert. None bounds tuple[int, int] The bounds to clamp the bounding boxes to, in the format (h, w). 
None box_format BoxFormat The format of the input bounding boxes. None Returns: Type Description tuple[np.ndarray, np.ndarray] The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. The second element is a boolean mask indicating which input bounding boxes are legal. View Source @ staticmethod def discretize ( bboxes : np . ndarray , bounds : tuple [ int , int ], box_format : BoxFormat , ) -> tuple [ np . ndarray , np . ndarray ]: \"\"\" Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Args: bboxes (np.ndarray): The bounding box coordinates to convert. bounds (tuple[int, int]): The bounds to clamp the bounding boxes to, in the format (h, w). box_format (BoxFormat): The format of the input bounding boxes. Returns: tuple[np.ndarray, np.ndarray]: The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. The second element is a boolean mask indicating which input bounding boxes are legal. \"\"\" # zero out all non - finite bounding boxes is_legal = np . isfinite ( bboxes ). all ( axis = 1 ) bboxes [ ~ is_legal ] = 0 bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) bboxes = BoxUtils . round ( bboxes , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) # clip worm bounding boxes to the size H , W = bounds x1 = np . clip ( x1 , a_min = 0 , a_max = W ) y1 = np . clip ( y1 , a_min = 0 , a_max = H ) x2 = np . clip ( x2 , a_min = 0 , a_max = W ) y2 = np . clip ( y2 , a_min = 0 , a_max = H ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) bboxes = BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) # zero out all bounding boxes with 0 dimension w = x2 - x1 h = y2 - y1 is_legal = ( w > 0.0 ) & ( h > 0.0 ) # zero out all illegal bounding boxes and make sure return types are correct bboxes [ ~ is_legal ] = 0 bboxes = bboxes . astype ( np . int32 , copy = False ) is_legal = is_legal . astype ( bool , copy = False ) return bboxes , is_legal is_bbox def is_bbox ( array : numpy . ndarray ) -> bool Check if the given array is a valid bounding box. Parameters: Name Type Description Default array np.ndarray The array to check. None Returns: Type Description bool True if the array is a valid bounding box, False otherwise. View Source @staticmethod def is_bbox ( array : np . ndarray ) -> bool : \"\"\" Check if the given array is a valid bounding box. Args: array (np.ndarray): The array to check. Returns: bool: True if the array is a valid bounding box, False otherwise. \"\"\" return array . shape [ -1 ] == 4 pack def pack ( c1 : numpy . ndarray , c2 : numpy . ndarray , c3 : numpy . ndarray , c4 : numpy . ndarray ) -> numpy . ndarray Pack the given components into a single bounding box. Parameters: Name Type Description Default c1 np.ndarray The first component of the bounding box. None c2 np.ndarray The second component of the bounding box. None c3 np.ndarray The third component of the bounding box. None c4 np.ndarray The fourth component of the bounding box. None Returns: Type Description np.ndarray The packed bounding box. View Source @staticmethod def pack ( c1 : np . 
ndarray , c2 : np . ndarray , c3 : np . ndarray , c4 : np . ndarray ) -> np . ndarray : \"\"\" Pack the given components into a single bounding box. Args: c1 (np.ndarray): The first component of the bounding box. c2 (np.ndarray): The second component of the bounding box. c3 (np.ndarray): The third component of the bounding box. c4 (np.ndarray): The fourth component of the bounding box. Returns: np.ndarray: The packed bounding box. \"\"\" c1 = np . expand_dims ( c1 , axis =- 1 ) c2 = np . expand_dims ( c2 , axis =- 1 ) c3 = np . expand_dims ( c3 , axis =- 1 ) c4 = np . expand_dims ( c4 , axis =- 1 ) return np . concatenate (( c1 , c2 , c3 , c4 ), axis =- 1 ) round def round ( bboxes : numpy . ndarray , box_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Rounds the bounding box coordinates to integers. Parameters: Name Type Description Default bboxes np.ndarray The bounding box coordinates to convert. None box_format BoxFormat The format of the input bounding boxes. None Returns: Type Description np.ndarray The bounding box coordinates as integers. View Source @ staticmethod def round ( bboxes : np . ndarray , box_format : BoxFormat ) -> np . ndarray : \"\"\" Rounds the bounding box coordinates to integers. Args: bboxes (np.ndarray): The bounding box coordinates to convert. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The bounding box coordinates as integers. \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) x1 = np . floor ( x1 ). astype ( np . int32 , copy = False ) y1 = np . floor ( y1 ). astype ( np . int32 , copy = False ) x2 = np . ceil ( x2 ). astype ( np . int32 , copy = False ) y2 = np . ceil ( y2 ). astype ( np . int32 , copy = False ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) return BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) unpack def unpack ( bbox : numpy . ndarray ) -> tuple [ numpy . ndarray , numpy . ndarray , numpy . ndarray , numpy . ndarray ] Unpack the given bounding box into its individual components. Parameters: Name Type Description Default bbox np.ndarray The bounding box to unpack. None Returns: Type Description tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] The unpacked components of the bounding box. View Source @staticmethod def unpack ( bbox : np . ndarray ) -> tuple [ np.ndarray, np.ndarray, np.ndarray, np.ndarray ] : \"\"\" Unpack the given bounding box into its individual components. Args: bbox (np.ndarray): The bounding box to unpack. Returns: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: The unpacked components of the bounding box. \"\"\" c1 , c2 , c3 , c4 = np . split ( bbox , bbox . shape [ -1 ] , axis =- 1 ) c1 = np . squeeze ( c1 , axis =- 1 ) c2 = np . squeeze ( c2 , axis =- 1 ) c3 = np . squeeze ( c3 , axis =- 1 ) c4 = np . squeeze ( c4 , axis =- 1 ) return c1 , c2 , c3 , c4","title":"Bbox Utils"},{"location":"reference/wtracker/utils/bbox_utils/#module-wtrackerutilsbbox_utils","text":"View Source import numpy as np from enum import Enum class BoxFormat ( Enum ): \"\"\" Enumeration representing different box formats. Attributes: XYWH (int): Represents the box format as (x, y, width, height). XYXY (int): Represents the box format as (x1, y1, x2, y2). YOLO (int): Represents the box format as (center_x, center_y, width, height). \"\"\" XYWH = 0 XYXY = 1 YOLO = 2 class BoxUtils : \"\"\" A utility class for working with bounding boxes. 
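A minimal usage sketch of BoxUtils (illustrative values, not part of the library source), assuming a batch of boxes stored as an array of shape (N, 4):
import numpy as np
from wtracker.utils.bbox_utils import BoxUtils, BoxFormat

boxes = np.array([[100.0, 50.0, 40.0, 20.0],
                  [10.0, 10.0, 5.0, 5.0]])        # two boxes in XYWH format
x, y, w, h = BoxUtils.unpack(boxes)               # four component arrays of shape (2,)
centers = BoxUtils.center(boxes, BoxFormat.XYWH)  # per-box (center_x, center_y)
repacked = BoxUtils.pack(x, y, w, h)              # reassembled, back to shape (2, 4)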
\"\"\" @staticmethod def is_bbox ( array : np . ndarray ) -> bool : \"\"\" Check if the given array is a valid bounding box. Args: array (np.ndarray): The array to check. Returns: bool: True if the array is a valid bounding box, False otherwise. \"\"\" return array . shape [ - 1 ] == 4 @staticmethod def unpack ( bbox : np . ndarray ) -> tuple [ np . ndarray , np . ndarray , np . ndarray , np . ndarray ]: \"\"\" Unpack the given bounding box into its individual components. Args: bbox (np.ndarray): The bounding box to unpack. Returns: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: The unpacked components of the bounding box. \"\"\" c1 , c2 , c3 , c4 = np . split ( bbox , bbox . shape [ - 1 ], axis =- 1 ) c1 = np . squeeze ( c1 , axis =- 1 ) c2 = np . squeeze ( c2 , axis =- 1 ) c3 = np . squeeze ( c3 , axis =- 1 ) c4 = np . squeeze ( c4 , axis =- 1 ) return c1 , c2 , c3 , c4 @staticmethod def pack ( c1 : np . ndarray , c2 : np . ndarray , c3 : np . ndarray , c4 : np . ndarray ) -> np . ndarray : \"\"\" Pack the given components into a single bounding box. Args: c1 (np.ndarray): The first component of the bounding box. c2 (np.ndarray): The second component of the bounding box. c3 (np.ndarray): The third component of the bounding box. c4 (np.ndarray): The fourth component of the bounding box. Returns: np.ndarray: The packed bounding box. \"\"\" c1 = np . expand_dims ( c1 , axis =- 1 ) c2 = np . expand_dims ( c2 , axis =- 1 ) c3 = np . expand_dims ( c3 , axis =- 1 ) c4 = np . expand_dims ( c4 , axis =- 1 ) return np . concatenate (( c1 , c2 , c3 , c4 ), axis =- 1 ) @staticmethod def center ( bboxes : np . ndarray , box_format : BoxFormat = BoxFormat . XYWH ) -> np . ndarray : \"\"\" Calculate the center of the bounding boxes. Args: bboxes (np.ndarray): The input bounding boxes. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The center of the bounding boxes, in the format (center_x, center_y). \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYWH ) x , y , w , h = BoxUtils . unpack ( bboxes ) center_x = x + w / 2 center_y = y + h / 2 return np . array ([ center_x , center_y ]) . T @staticmethod def round ( bboxes : np . ndarray , box_format : BoxFormat ) -> np . ndarray : \"\"\" Rounds the bounding box coordinates to integers. Args: bboxes (np.ndarray): The bounding box coordinates to convert. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The bounding box coordinates as integers. \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) x1 = np . floor ( x1 ) . astype ( np . int32 , copy = False ) y1 = np . floor ( y1 ) . astype ( np . int32 , copy = False ) x2 = np . ceil ( x2 ) . astype ( np . int32 , copy = False ) y2 = np . ceil ( y2 ) . astype ( np . int32 , copy = False ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) return BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) @staticmethod def discretize ( bboxes : np . ndarray , bounds : tuple [ int , int ], box_format : BoxFormat , ) -> tuple [ np . ndarray , np . ndarray ]: \"\"\" Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Args: bboxes (np.ndarray): The bounding box coordinates to convert. 
bounds (tuple[int, int]): The bounds to clamp the bounding boxes to, in the format (h, w). box_format (BoxFormat): The format of the input bounding boxes. Returns: tuple[np.ndarray, np.ndarray]: The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. The second element is a boolean mask indicating which input bounding boxes are legal. \"\"\" # zero out all non-finite bounding boxes is_legal = np . isfinite ( bboxes ) . all ( axis = 1 ) bboxes [ ~ is_legal ] = 0 bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) bboxes = BoxUtils . round ( bboxes , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) # clip worm bounding boxes to the size H , W = bounds x1 = np . clip ( x1 , a_min = 0 , a_max = W ) y1 = np . clip ( y1 , a_min = 0 , a_max = H ) x2 = np . clip ( x2 , a_min = 0 , a_max = W ) y2 = np . clip ( y2 , a_min = 0 , a_max = H ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) bboxes = BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) # zero out all bounding boxes with 0 dimension w = x2 - x1 h = y2 - y1 is_legal = ( w > 0.0 ) & ( h > 0.0 ) # zero out all illegal bounding boxes and make sure return types are correct bboxes [ ~ is_legal ] = 0 bboxes = bboxes . astype ( np . int32 , copy = False ) is_legal = is_legal . astype ( bool , copy = False ) return bboxes , is_legal class BoxConverter : \"\"\" Utility class for converting bounding box coordinates between different formats. \"\"\" @staticmethod def change_format ( bbox : np . ndarray , src_format : BoxFormat , dst_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates from one format to another. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. dst_format (BoxFormat): The destination format of the bounding box coordinates. Returns: np.ndarray: The converted bounding box coordinates. Raises: Exception: If the conversion between the specified formats is not supported. \"\"\" if dst_format == BoxFormat . XYXY : return BoxConverter . to_xyxy ( bbox , src_format ) elif dst_format == BoxFormat . XYWH : return BoxConverter . to_xywh ( bbox , src_format ) elif dst_format == BoxFormat . YOLO : return BoxConverter . to_yolo ( bbox , src_format ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xyxy ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYXY format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYXY format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYXY : return bbox elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xywh ( bbox : np . ndarray , src_format : BoxFormat ) -> np .
ndarray : \"\"\" Converts the bounding box coordinates to the XYWH format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYWH format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYWH : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 return BoxUtils . pack ( x1 , y1 , w , h ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 return BoxUtils . pack ( x1 , y1 , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_yolo ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the YOLO format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the YOLO format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . YOLO : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" )","title":"Module wtracker.utils.bbox_utils"},{"location":"reference/wtracker/utils/bbox_utils/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/bbox_utils/#boxconverter","text":"class BoxConverter ( / , * args , ** kwargs ) Utility class for converting bounding box coordinates between different formats. View Source class BoxConverter : \"\"\" Utility class for converting bounding box coordinates between different formats. \"\"\" @staticmethod def change_format ( bbox : np . ndarray , src_format : BoxFormat , dst_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates from one format to another. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. dst_format (BoxFormat): The destination format of the bounding box coordinates. Returns: np.ndarray: The converted bounding box coordinates. Raises: Exception: If the conversion between the specified formats is not supported. \"\"\" if dst_format == BoxFormat . XYXY : return BoxConverter . to_xyxy ( bbox , src_format ) elif dst_format == BoxFormat . XYWH : return BoxConverter . to_xywh ( bbox , src_format ) elif dst_format == BoxFormat . YOLO : return BoxConverter . to_yolo ( bbox , src_format ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xyxy ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYXY format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYXY format.
Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYXY : return bbox elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xywh ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYWH format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYWH format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYWH : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 return BoxUtils . pack ( x1 , y1 , w , h ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 return BoxUtils . pack ( x1 , y1 , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_yolo ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the YOLO format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the YOLO format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . YOLO : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" )","title":"BoxConverter"},{"location":"reference/wtracker/utils/bbox_utils/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/utils/bbox_utils/#change_format","text":"def change_format ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat , dst_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates from one format to another. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. None src_format BoxFormat The source format of the bounding box coordinates. None dst_format BoxFormat The destination format of the bounding box coordinates. None Returns: Type Description np.ndarray The converted bounding box coordinates. Raises: Type Description Exception If the conversion between the specified formats is not supported. View Source @staticmethod def change_format ( bbox : np . ndarray , src_format : BoxFormat , dst_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates from one format to another. Args: bbox (np.ndarray): The bounding box coordinates to be converted. 
src_format (BoxFormat): The source format of the bounding box coordinates. dst_format (BoxFormat): The destination format of the bounding box coordinates. Returns: np.ndarray: The converted bounding box coordinates. Raises: Exception: If the conversion between the specified formats is not supported. \"\"\" if dst_format == BoxFormat . XYXY : return BoxConverter . to_xyxy ( bbox , src_format ) elif dst_format == BoxFormat . XYWH : return BoxConverter . to_xywh ( bbox , src_format ) elif dst_format == BoxFormat . YOLO : return BoxConverter . to_yolo ( bbox , src_format ) else : raise Exception ( \"unsupported bbox format conversion.\" )","title":"change_format"},{"location":"reference/wtracker/utils/bbox_utils/#to_xywh","text":"def to_xywh ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates to the XYWH format. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. None src_format BoxFormat The source format of the bounding box coordinates. None Returns: Type Description np.ndarray The bounding box coordinates in the XYWH format. Raises: Type Description Exception If the conversion from the specified source format is not supported. View Source @staticmethod def to_xywh ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYWH format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYWH format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYWH : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 return BoxUtils . pack ( x1 , y1 , w , h ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 return BoxUtils . pack ( x1 , y1 , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" )","title":"to_xywh"},{"location":"reference/wtracker/utils/bbox_utils/#to_xyxy","text":"def to_xyxy ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates to the XYXY format. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. None src_format BoxFormat The source format of the bounding box coordinates. None Returns: Type Description np.ndarray The bounding box coordinates in the XYXY format. Raises: Type Description Exception If the conversion from the specified source format is not supported. View Source @staticmethod def to_xyxy ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYXY format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYXY format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYXY : return bbox elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) elif src_format == BoxFormat .
YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) else : raise Exception ( \"unsupported bbox format conversion.\" )","title":"to_xyxy"},{"location":"reference/wtracker/utils/bbox_utils/#to_yolo","text":"def to_yolo ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates to the YOLO format. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. None src_format BoxFormat The source format of the bounding box coordinates. None Returns: Type Description np.ndarray The bounding box coordinates in the YOLO format. Raises: Type Description Exception If the conversion from the specified source format is not supported. View Source @staticmethod def to_yolo ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the YOLO format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the YOLO format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . YOLO : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" )","title":"to_yolo"},{"location":"reference/wtracker/utils/bbox_utils/#boxformat","text":"class BoxFormat ( / , * args , ** kwargs ) Enumeration representing different box formats.","title":"BoxFormat"},{"location":"reference/wtracker/utils/bbox_utils/#attributes","text":"Name Type Description Default XYWH int Represents the box format as (x, y, width, height). None XYXY int Represents the box format as (x1, y1, x2, y2). None YOLO int Represents the box format as (center_x, center_y, width, height). None View Source class BoxFormat ( Enum ): \"\"\" Enumeration representing different box formats. Attributes: XYWH (int): Represents the box format as (x, y, width, height). XYXY (int): Represents the box format as (x1, y1, x2, y2). YOLO (int): Represents the box format as (center_x, center_y, width, height). \"\"\" XYWH = 0 XYXY = 1 YOLO = 2","title":"Attributes"},{"location":"reference/wtracker/utils/bbox_utils/#ancestors-in-mro","text":"enum.Enum","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/utils/bbox_utils/#class-variables","text":"XYWH XYXY YOLO name value","title":"Class variables"},{"location":"reference/wtracker/utils/bbox_utils/#boxutils","text":"class BoxUtils ( / , * args , ** kwargs ) A utility class for working with bounding boxes. View Source class BoxUtils : \"\"\" A utility class for working with bounding boxes. \"\"\" @ staticmethod def is_bbox ( array : np . ndarray ) -> bool : \"\"\" Check if the given array is a valid bounding box. Args: array (np.ndarray): The array to check. Returns: bool: True if the array is a valid bounding box, False otherwise. \"\"\" return array . shape [ - 1 ] == 4 @ staticmethod def unpack ( bbox : np . ndarray ) -> tuple [ np . ndarray , np . ndarray , np . ndarray , np . 
ndarray ]: \"\"\" Unpack the given bounding box into its individual components. Args: bbox (np.ndarray): The bounding box to unpack. Returns: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: The unpacked components of the bounding box. \"\"\" c1 , c2 , c3 , c4 = np . split ( bbox , bbox . shape [ - 1 ], axis = - 1 ) c1 = np . squeeze ( c1 , axis = - 1 ) c2 = np . squeeze ( c2 , axis = - 1 ) c3 = np . squeeze ( c3 , axis = - 1 ) c4 = np . squeeze ( c4 , axis = - 1 ) return c1 , c2 , c3 , c4 @ staticmethod def pack ( c1 : np . ndarray , c2 : np . ndarray , c3 : np . ndarray , c4 : np . ndarray ) -> np . ndarray : \"\"\" Pack the given components into a single bounding box. Args: c1 (np.ndarray): The first component of the bounding box. c2 (np.ndarray): The second component of the bounding box. c3 (np.ndarray): The third component of the bounding box. c4 (np.ndarray): The fourth component of the bounding box. Returns: np.ndarray: The packed bounding box. \"\"\" c1 = np . expand_dims ( c1 , axis = - 1 ) c2 = np . expand_dims ( c2 , axis = - 1 ) c3 = np . expand_dims ( c3 , axis = - 1 ) c4 = np . expand_dims ( c4 , axis = - 1 ) return np . concatenate (( c1 , c2 , c3 , c4 ), axis = - 1 ) @ staticmethod def center ( bboxes : np . ndarray , box_format : BoxFormat = BoxFormat . XYWH ) -> np . ndarray : \"\"\" Calculate the center of the bounding boxes. Args: bboxes (np.ndarray): The input bounding boxes. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The center of the bounding boxes, in the format (center_x, center_y). \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYWH ) x , y , w , h = BoxUtils . unpack ( bboxes ) center_x = x + w / 2 center_y = y + h / 2 return np . array ([ center_x , center_y ]). T @ staticmethod def round ( bboxes : np . ndarray , box_format : BoxFormat ) -> np . ndarray : \"\"\" Rounds the bounding box coordinates to integers. Args: bboxes (np.ndarray): The bounding box coordinates to convert. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The bounding box coordinates as integers. \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) x1 = np . floor ( x1 ). astype ( np . int32 , copy = False ) y1 = np . floor ( y1 ). astype ( np . int32 , copy = False ) x2 = np . ceil ( x2 ). astype ( np . int32 , copy = False ) y2 = np . ceil ( y2 ). astype ( np . int32 , copy = False ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) return BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) @ staticmethod def discretize ( bboxes : np . ndarray , bounds : tuple [ int , int ], box_format : BoxFormat , ) -> tuple [ np . ndarray , np . ndarray ]: \"\"\" Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Args: bboxes (np.ndarray): The bounding box coordinates to convert. bounds (tuple[int, int]): The bounds to clamp the bounding boxes to, in the format (h, w). box_format (BoxFormat): The format of the input bounding boxes. Returns: tuple[np.ndarray, np.ndarray]: The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. 
The second element is a boolean mask indicating which input bounding boxes are legal. \"\"\" # zero out all non - finite bounding boxes is_legal = np . isfinite ( bboxes ). all ( axis = 1 ) bboxes [ ~ is_legal ] = 0 bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) bboxes = BoxUtils . round ( bboxes , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) # clip worm bounding boxes to the size H , W = bounds x1 = np . clip ( x1 , a_min = 0 , a_max = W ) y1 = np . clip ( y1 , a_min = 0 , a_max = H ) x2 = np . clip ( x2 , a_min = 0 , a_max = W ) y2 = np . clip ( y2 , a_min = 0 , a_max = H ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) bboxes = BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) # zero out all bounding boxes with 0 dimension w = x2 - x1 h = y2 - y1 is_legal = ( w > 0.0 ) & ( h > 0.0 ) # zero out all illegal bounding boxes and make sure return types are correct bboxes [ ~ is_legal ] = 0 bboxes = bboxes . astype ( np . int32 , copy = False ) is_legal = is_legal . astype ( bool , copy = False ) return bboxes , is_legal","title":"BoxUtils"},{"location":"reference/wtracker/utils/bbox_utils/#static-methods_1","text":"","title":"Static methods"},{"location":"reference/wtracker/utils/bbox_utils/#center","text":"def center ( bboxes : numpy . ndarray , box_format : wtracker . utils . bbox_utils . BoxFormat = < BoxFormat . XYWH : 0 > ) -> numpy . ndarray Calculate the center of the bounding boxes. Parameters: Name Type Description Default bboxes np.ndarray The input bounding boxes. None box_format BoxFormat The format of the input bounding boxes. None Returns: Type Description np.ndarray The center of the bounding boxes, in the format (center_x, center_y). View Source @staticmethod def center ( bboxes : np . ndarray , box_format : BoxFormat = BoxFormat . XYWH ) -> np . ndarray : \"\"\" Calculate the center of the bounding boxes. Args: bboxes (np.ndarray): The input bounding boxes. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The center of the bounding boxes, in the format (center_x, center_y). \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYWH ) x , y , w , h = BoxUtils . unpack ( bboxes ) center_x = x + w / 2 center_y = y + h / 2 return np . array ( [ center_x, center_y ] ). T","title":"center"},{"location":"reference/wtracker/utils/bbox_utils/#discretize","text":"def discretize ( bboxes : numpy . ndarray , bounds : tuple [ int , int ], box_format : wtracker . utils . bbox_utils . BoxFormat ) -> tuple [ numpy . ndarray , numpy . ndarray ] Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Parameters: Name Type Description Default bboxes np.ndarray The bounding box coordinates to convert. None bounds tuple[int, int] The bounds to clamp the bounding boxes to, in the format (h, w). None box_format BoxFormat The format of the input bounding boxes. None Returns: Type Description tuple[np.ndarray, np.ndarray] The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. The second element is a boolean mask indicating which input bounding boxes are legal. View Source @ staticmethod def discretize ( bboxes : np . 
ndarray , bounds : tuple [ int , int ], box_format : BoxFormat , ) -> tuple [ np . ndarray , np . ndarray ]: \"\"\" Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Args: bboxes (np.ndarray): The bounding box coordinates to convert. bounds (tuple[int, int]): The bounds to clamp the bounding boxes to, in the format (h, w). box_format (BoxFormat): The format of the input bounding boxes. Returns: tuple[np.ndarray, np.ndarray]: The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element contains the bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. The second element is a boolean mask indicating which input bounding boxes are legal. \"\"\" # zero out all non - finite bounding boxes is_legal = np . isfinite ( bboxes ). all ( axis = 1 ) bboxes [ ~ is_legal ] = 0 bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) bboxes = BoxUtils . round ( bboxes , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) # clip worm bounding boxes to the size H , W = bounds x1 = np . clip ( x1 , a_min = 0 , a_max = W ) y1 = np . clip ( y1 , a_min = 0 , a_max = H ) x2 = np . clip ( x2 , a_min = 0 , a_max = W ) y2 = np . clip ( y2 , a_min = 0 , a_max = H ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) bboxes = BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) # zero out all bounding boxes with 0 dimension w = x2 - x1 h = y2 - y1 is_legal = ( w > 0.0 ) & ( h > 0.0 ) # zero out all illegal bounding boxes and make sure return types are correct bboxes [ ~ is_legal ] = 0 bboxes = bboxes . astype ( np . int32 , copy = False ) is_legal = is_legal . astype ( bool , copy = False ) return bboxes , is_legal","title":"discretize"},{"location":"reference/wtracker/utils/bbox_utils/#is_bbox","text":"def is_bbox ( array : numpy . ndarray ) -> bool Check if the given array is a valid bounding box. Parameters: Name Type Description Default array np.ndarray The array to check. None Returns: Type Description bool True if the array is a valid bounding box, False otherwise. View Source @staticmethod def is_bbox ( array : np . ndarray ) -> bool : \"\"\" Check if the given array is a valid bounding box. Args: array (np.ndarray): The array to check. Returns: bool: True if the array is a valid bounding box, False otherwise. \"\"\" return array . shape [ -1 ] == 4","title":"is_bbox"},{"location":"reference/wtracker/utils/bbox_utils/#pack","text":"def pack ( c1 : numpy . ndarray , c2 : numpy . ndarray , c3 : numpy . ndarray , c4 : numpy . ndarray ) -> numpy . ndarray Pack the given components into a single bounding box. Parameters: Name Type Description Default c1 np.ndarray The first component of the bounding box. None c2 np.ndarray The second component of the bounding box. None c3 np.ndarray The third component of the bounding box. None c4 np.ndarray The fourth component of the bounding box. None Returns: Type Description np.ndarray The packed bounding box. View Source @staticmethod def pack ( c1 : np . ndarray , c2 : np . ndarray , c3 : np . ndarray , c4 : np . ndarray ) -> np . ndarray : \"\"\" Pack the given components into a single bounding box. Args: c1 (np.ndarray): The first component of the bounding box. c2 (np.ndarray): The second component of the bounding box.
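To make the discretize behavior documented above concrete, here is a small sketch under the stated (h, w) bounds convention; the frame size and box values are assumptions for illustration.

```python
# Sketch: BoxUtils.discretize clips boxes to the frame and flags illegal
# ones, so the results can be used directly for image slicing.
import numpy as np
from wtracker.utils.bbox_utils import BoxUtils, BoxFormat

frame = np.zeros((480, 640), dtype=np.uint8)       # bounds are (h, w)
bboxes = np.array([[10.2, 5.7, 30.9, 40.1],         # a valid XYXY box
                   [np.nan, 0.0, 10.0, 10.0]])      # an illegal box (NaN)

boxes, is_legal = BoxUtils.discretize(bboxes, frame.shape[:2], BoxFormat.XYXY)
for box, ok in zip(boxes, is_legal):
    if not ok:
        continue                                    # zeroed out, skip it
    x1, y1, x2, y2 = BoxUtils.unpack(box)
    crop = frame[y1:y2, x1:x2]                      # clipped integer indices
```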
c3 (np.ndarray): The third component of the bounding box. c4 (np.ndarray): The fourth component of the bounding box. Returns: np.ndarray: The packed bounding box. \"\"\" c1 = np . expand_dims ( c1 , axis =- 1 ) c2 = np . expand_dims ( c2 , axis =- 1 ) c3 = np . expand_dims ( c3 , axis =- 1 ) c4 = np . expand_dims ( c4 , axis =- 1 ) return np . concatenate (( c1 , c2 , c3 , c4 ), axis =- 1 )","title":"pack"},{"location":"reference/wtracker/utils/bbox_utils/#round","text":"def round ( bboxes : numpy . ndarray , box_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Rounds the bounding box coordinates to integers. Parameters: Name Type Description Default bboxes np.ndarray The bounding box coordinates to convert. None box_format BoxFormat The format of the input bounding boxes. None Returns: Type Description np.ndarray The bounding box coordinates as integers. View Source @ staticmethod def round ( bboxes : np . ndarray , box_format : BoxFormat ) -> np . ndarray : \"\"\" Rounds the bounding box coordinates to integers. Args: bboxes (np.ndarray): The bounding box coordinates to convert. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The bounding box coordinates as integers. \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) x1 = np . floor ( x1 ). astype ( np . int32 , copy = False ) y1 = np . floor ( y1 ). astype ( np . int32 , copy = False ) x2 = np . ceil ( x2 ). astype ( np . int32 , copy = False ) y2 = np . ceil ( y2 ). astype ( np . int32 , copy = False ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) return BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format )","title":"round"},{"location":"reference/wtracker/utils/bbox_utils/#unpack","text":"def unpack ( bbox : numpy . ndarray ) -> tuple [ numpy . ndarray , numpy . ndarray , numpy . ndarray , numpy . ndarray ] Unpack the given bounding box into its individual components. Parameters: Name Type Description Default bbox np.ndarray The bounding box to unpack. None Returns: Type Description tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] The unpacked components of the bounding box. View Source @staticmethod def unpack ( bbox : np . ndarray ) -> tuple [ np.ndarray, np.ndarray, np.ndarray, np.ndarray ] : \"\"\" Unpack the given bounding box into its individual components. Args: bbox (np.ndarray): The bounding box to unpack. Returns: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: The unpacked components of the bounding box. \"\"\" c1 , c2 , c3 , c4 = np . split ( bbox , bbox . shape [ -1 ] , axis =- 1 ) c1 = np . squeeze ( c1 , axis =- 1 ) c2 = np . squeeze ( c2 , axis =- 1 ) c3 = np . squeeze ( c3 , axis =- 1 ) c4 = np . squeeze ( c4 , axis =- 1 ) return c1 , c2 , c3 , c4","title":"unpack"},{"location":"reference/wtracker/utils/config_base/","text":"Module wtracker.utils.config_base View Source from __future__ import annotations from typing import Type , TypeVar from dataclasses import dataclass , fields , MISSING , is_dataclass import json from wtracker.utils.gui_utils import UserPrompt from wtracker.utils.io_utils import pickle_load_object , pickle_save_object T = TypeVar ( \"T\" , bound = \"ConfigBase\" ) @dataclass class ConfigBase : @classmethod def load_json ( cls : type [ T ], path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. 
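As a hedged illustration of the ConfigBase pattern documented here: a user-defined dataclass inherits the JSON round-trip for free. The CameraConfig name and its fields are invented for this sketch; only the ConfigBase API comes from the docs above.

```python
# Hypothetical ConfigBase subclass; the class name and fields are invented.
from dataclasses import dataclass
from wtracker.utils.config_base import ConfigBase

@dataclass
class CameraConfig(ConfigBase):
    fps: int = 60
    exposure_ms: float = 2.5

cfg = CameraConfig(fps=120)
cfg.save_json("camera.json")                      # dumps self.__dict__ as JSON
restored = CameraConfig.load_json("camera.json")  # rebuilt via __dict__.update
assert restored.fps == 120
```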
\"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open { cls . __name__ } File\" , file_types = [( \"json\" , \".json\" )], ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save { type ( self ) . __name__ } As\" , file_types = [( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) @classmethod def load_pickle ( cls : type [ T ], path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open { cls . __name__ } File\" , file_types = [( \"pickle\" , \".pkl\" )], ) return pickle_load_object ( path ) def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save { type ( self ) . __name__ } As\" , file_types = [( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) def print_initialization ( cls , include_default : bool = True , init_fields_only : bool = True ) -> str : \"\"\" Print the initialization of a dataclass as a string \"\"\" if not is_dataclass ( cls ): print ( f \"ERROR:: { cls . __name__ } is not a dataclass\" ) return \"\" print ( f \" { cls . __name__ } (\" ) for field in fields ( cls ): if init_fields_only and field . init is False : continue is_default = not isinstance ( field . default , type ( MISSING )) val = None if include_default and is_default : val = field . default if type ( val ) is str : val = f 'f\" { val } \"' print ( f \" { field . name } = { val } , # { field . type } \" ) print ( \")\" ) Variables T Functions print_initialization def print_initialization ( cls , include_default : 'bool' = True , init_fields_only : 'bool' = True ) -> 'str' Print the initialization of a dataclass as a string View Source def print_initialization ( cls , include_default : bool = True , init_fields_only : bool = True ) -> str : \"\"\" Print the initialization of a dataclass as a string \"\"\" if not is_dataclass ( cls ): print ( f \"ERROR::{cls.__name__} is not a dataclass\" ) return \"\" print ( f \"{cls.__name__}(\" ) for field in fields ( cls ): if init_fields_only and field . init is False : continue is_default = not isinstance ( field . default , type ( MISSING )) val = None if include_default and is_default : val = field . default if type ( val ) is str : val = f ' f \"{val}\" ' print ( f \" {field.name} = {val}, # {field.type}\" ) print ( \")\" ) Classes ConfigBase class ConfigBase ( ) ConfigBase() View Source @dataclass class ConfigBase : @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . 
update ( data ) return obj def save_json ( self , path : str = None ) : \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[ (\"json\", \".json\") ] , defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[ (\"pickle\", \".pkl\") ] , defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) Descendants wtracker.sim.config.TimingConfig wtracker.sim.config.ExperimentConfig wtracker.neural.config.DatasetConfig wtracker.neural.config.TrainConfig wtracker.neural.config.IOConfig wtracker.sim.sim_controllers.logging_controller.LogConfig wtracker.sim.sim_controllers.polyfit_controller.PolyfitConfig wtracker.sim.sim_controllers.yolo_controller.YoloConfig Static methods load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods save_json def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . 
__dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"Config Base"},{"location":"reference/wtracker/utils/config_base/#module-wtrackerutilsconfig_base","text":"View Source from __future__ import annotations from typing import Type , TypeVar from dataclasses import dataclass , fields , MISSING , is_dataclass import json from wtracker.utils.gui_utils import UserPrompt from wtracker.utils.io_utils import pickle_load_object , pickle_save_object T = TypeVar ( \"T\" , bound = \"ConfigBase\" ) @dataclass class ConfigBase : @classmethod def load_json ( cls : type [ T ], path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open { cls . __name__ } File\" , file_types = [( \"json\" , \".json\" )], ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save { type ( self ) . __name__ } As\" , file_types = [( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) @classmethod def load_pickle ( cls : type [ T ], path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open { cls . __name__ } File\" , file_types = [( \"pickle\" , \".pkl\" )], ) return pickle_load_object ( path ) def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save { type ( self ) . __name__ } As\" , file_types = [( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) def print_initialization ( cls , include_default : bool = True , init_fields_only : bool = True ) -> str : \"\"\" Print the initialization of a dataclass as a string \"\"\" if not is_dataclass ( cls ): print ( f \"ERROR:: { cls . __name__ } is not a dataclass\" ) return \"\" print ( f \" { cls . __name__ } (\" ) for field in fields ( cls ): if init_fields_only and field . init is False : continue is_default = not isinstance ( field . default , type ( MISSING )) val = None if include_default and is_default : val = field . default if type ( val ) is str : val = f 'f\" { val } \"' print ( f \" { field . name } = { val } , # { field . 
type } \" ) print ( \")\" )","title":"Module wtracker.utils.config_base"},{"location":"reference/wtracker/utils/config_base/#variables","text":"T","title":"Variables"},{"location":"reference/wtracker/utils/config_base/#functions","text":"","title":"Functions"},{"location":"reference/wtracker/utils/config_base/#print_initialization","text":"def print_initialization ( cls , include_default : 'bool' = True , init_fields_only : 'bool' = True ) -> 'str' Print the initialization of a dataclass as a string View Source def print_initialization ( cls , include_default : bool = True , init_fields_only : bool = True ) -> str : \"\"\" Print the initialization of a dataclass as a string \"\"\" if not is_dataclass ( cls ): print ( f \"ERROR::{cls.__name__} is not a dataclass\" ) return \"\" print ( f \"{cls.__name__}(\" ) for field in fields ( cls ): if init_fields_only and field . init is False : continue is_default = not isinstance ( field . default , type ( MISSING )) val = None if include_default and is_default : val = field . default if type ( val ) is str : val = f ' f \"{val}\" ' print ( f \" {field.name} = {val}, # {field.type}\" ) print ( \")\" )","title":"print_initialization"},{"location":"reference/wtracker/utils/config_base/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/config_base/#configbase","text":"class ConfigBase ( ) ConfigBase() View Source @dataclass class ConfigBase : @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj def save_json ( self , path : str = None ) : \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[ (\"json\", \".json\") ] , defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . 
save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[ (\"pickle\", \".pkl\") ] , defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"ConfigBase"},{"location":"reference/wtracker/utils/config_base/#descendants","text":"wtracker.sim.config.TimingConfig wtracker.sim.config.ExperimentConfig wtracker.neural.config.DatasetConfig wtracker.neural.config.TrainConfig wtracker.neural.config.IOConfig wtracker.sim.sim_controllers.logging_controller.LogConfig wtracker.sim.sim_controllers.polyfit_controller.PolyfitConfig wtracker.sim.sim_controllers.yolo_controller.YoloConfig","title":"Descendants"},{"location":"reference/wtracker/utils/config_base/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/utils/config_base/#load_json","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/utils/config_base/#load_pickle","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/utils/config_base/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/utils/config_base/#save_json","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/utils/config_base/#save_pickle","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . 
save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/utils/frame_reader/","text":"Module wtracker.utils.frame_reader View Source from __future__ import annotations import os import glob import numpy as np import cv2 as cv from wtracker.utils.path_utils import join_paths class FrameReader : \"\"\" An class for reading frames from a directory of frame files. Args: root_folder (str): The root folder path where the frame files are located. frame_files (list[str]): A list of frame file names. read_format (int, optional): The format in which the frames should be read. Attributes: root_folder (str): The root folder path where the frame files are located. frame_shape (tuple[int, ...]): The shape of the frame. frame_size (tuple[int, int]): The size of the frame. files (list[str]): The list of file paths. read_format (int): The read format of the frame reader. \"\"\" def __init__ ( self , root_folder : str , frame_files : list [ str ], read_format : int = cv . IMREAD_GRAYSCALE , ): assert os . path . exists ( root_folder ) assert len ( frame_files ) > 0 self . _root_folder = root_folder self . _files = frame_files self . _read_format = read_format self . _frame_shape = self . _extract_frame_shape () def _extract_frame_shape ( self ) -> tuple [ int , ... ]: path = join_paths ( self . root_folder , self . files [ 0 ]) frame = cv . imread ( path , self . _read_format ) return frame . shape @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os . path . isfile ( join_paths ( root_folder , f ))] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os . path . isfile ( join_paths ( root_folder , f ))] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @property def root_folder ( self ) -> str : \"\"\" Returns the root folder path. Returns: str: The root folder path. \"\"\" return self . _root_folder @property def frame_shape ( self ) -> tuple [ int , ... ]: \"\"\" Returns the shape of the frame. Returns: tuple[int, ...]: The shape of the frame, in format (h, w, ...). \"\"\" return self . _frame_shape @property def frame_size ( self ) -> tuple [ int , int ]: \"\"\" Returns the size of the frame. 
Returns: tuple[int, int]: The shape of the frame, in format (h, w). \"\"\" return self . _frame_shape [: 2 ] @property def files ( self ) -> list [ str ]: \"\"\" Returns the list of files associated with the FrameReader object. Returns: list[str]: The list of file paths. \"\"\" return self . _files @property def read_format ( self ) -> int : \"\"\" Returns the read format of the frame reader. Returns: int: The read format. \"\"\" return self . _read_format def __len__ ( self ) -> int : return len ( self . _files ) def __getitem__ ( self , idx : int ) -> np . ndarray : if idx < 0 or idx >= len ( self . _files ): raise IndexError ( \"index out of bounds\" ) path = join_paths ( self . root_folder , self . files [ idx ]) frame = cv . imread ( path , self . _read_format ) return frame . astype ( np . uint8 , copy = False ) def __iter__ ( self ): return FrameStream ( self ) def make_stream ( self ): \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. \"\"\" return FrameStream ( self ) class FrameStream : \"\"\" A class for streaming frames from a FrameReader object. This class serves as an iterator for the FrameReader object. Args: frame_reader (FrameReader): The frame reader object. \"\"\" def __init__ ( self , frame_reader : FrameReader ): self . _frame_reader = frame_reader self . _idx = - 1 self . frame = None @property def index ( self ) -> int : \"\"\" The index of the current frame. \"\"\" return self . _idx def __len__ ( self ): return len ( self . _frame_reader ) def __iter__ ( self ): return self def __next__ ( self ) -> np . ndarray : self . progress () if not self . can_read (): raise StopIteration () frame = self . read () return frame def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader ) def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . _idx = idx self . frame = None return self . can_read () def read ( self ) -> np . ndarray : \"\"\" Read and return the frame at the current index. Raises: IndexError: If the index is out of bounds. Returns: np.ndarray: The frame at the current index. \"\"\" if not self . can_read (): raise IndexError ( \"index out of bounds\" ) if self . frame is None : self . frame = self . _frame_reader [ self . _idx ] return self . frame def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . _idx + n ) def reset ( self ): \"\"\" Resets the frame reader to the beginning of the stream. \"\"\" self . seek ( - 1 ) class DummyReader ( FrameReader ): \"\"\" A dummy frame reader that generates empty frames of a specified resolution. Args: num_frames (int): The number of frames to generate. resolution (tuple[int, int]): The resolution of the frames, in format (h, w). colored (bool, optional): Whether the frames are colored or grayscale. \"\"\" def __init__ ( self , num_frames : int , resolution : tuple [ int , int ], colored : bool = True ): self . colored = colored self . _resolution = resolution shape = ( * resolution , 3 ) if colored else resolution self . _frame = np . full ( shape , fill_value = 255 , dtype = np .
uint8 ) frames = [ str ( i ) for i in range ( num_frames )] super () . __init__ ( \".\" , frame_files = frames ) def __getitem__ ( self , idx : int ) -> np . ndarray : return self . _frame . copy () def _extract_frame_shape ( self ) -> tuple [ int , ... ]: if self . colored : return ( * self . _resolution , 3 ) return self . _resolution Classes DummyReader class DummyReader ( num_frames : 'int' , resolution : 'tuple[int, int]' , colored : 'bool' = True ) A dummy frame reader that generates empty frames of a specified resolution. Attributes Name Type Description Default num_frames int The number of frames to generate. None resolution tuple[int, int] The resolution of the frames, in format (h, w). None colored bool Whether the frames are colored or grayscale. None View Source class DummyReader ( FrameReader ): \"\"\" A dummy frame reader that generates empty frames of a specified resolution. Args: num_frames (int): The number of frames to generate. resolution (tuple[int, int]): The resolution of the frames, in format (h, w). colored (bool, optional): Whether the frames are colored or grayscale. \"\"\" def __init__ ( self , num_frames : int , resolution : tuple [ int , int ], colored : bool = True ): self . colored = colored self . _resolution = resolution shape = ( * resolution , 3 ) if colored else resolution self . _frame = np . full ( shape , fill_value = 255 , dtype = np . uint8 ) frames = [ str ( i ) for i in range ( num_frames )] super (). __init__ ( \".\" , frame_files = frames ) def __getitem__ ( self , idx : int ) -> np . ndarray : return self . _frame . copy () def _extract_frame_shape ( self ) -> tuple [ int , ... ]: if self . colored : return ( * self . _resolution , 3 ) return self . _resolution Ancestors (in MRO) wtracker.utils.frame_reader.FrameReader Static methods create_from_directory def create_from_directory ( root_folder : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a directory. Parameters: Name Type Description Default root_folder str The root folder containing the frame files. None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) create_from_template def create_from_template ( root_folder : 'str' , name_format : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a file name template. Parameters: Name Type Description Default root_folder str The root folder where the frame files are located. None name_format str The format of the frame file names. None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . 
IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) Instance variables files Returns the list of files associated with the FrameReader object. frame_shape Returns the shape of the frame. frame_size Returns the size of the frame. read_format Returns the read format of the frame reader. root_folder Returns the root folder path. Methods make_stream def make_stream ( self ) Creates and returns a FrameStream object using the current instance of FrameReader. Returns: Type Description FrameStream A FrameStream object. View Source def make_stream(self): \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. \"\"\" return FrameStream(self) FrameReader class FrameReader ( root_folder : 'str' , frame_files : 'list[str]' , read_format : 'int' = 0 ) A class for reading frames from a directory of frame files. Attributes Name Type Description Default root_folder str The root folder path where the frame files are located. None frame_files list[str] A list of frame file names. None read_format int The format in which the frames should be read. None root_folder str The root folder path where the frame files are located. None frame_shape tuple[int, ...] The shape of the frame. None frame_size tuple[int, int] The size of the frame. None files list[str] The list of file paths. None read_format int The read format of the frame reader. None View Source class FrameReader : \"\"\" A class for reading frames from a directory of frame files. Args: root_folder (str): The root folder path where the frame files are located. frame_files (list[str]): A list of frame file names. read_format (int, optional): The format in which the frames should be read. Attributes: root_folder (str): The root folder path where the frame files are located. frame_shape (tuple[int, ...]): The shape of the frame. frame_size (tuple[int, int]): The size of the frame. files (list[str]): The list of file paths. read_format (int): The read format of the frame reader. \"\"\" def __init__ ( self , root_folder : str , frame_files : list [ str ] , read_format : int = cv . IMREAD_GRAYSCALE , ) : assert os . path . exists ( root_folder ) assert len ( frame_files ) > 0 self . _root_folder = root_folder self . _files = frame_files self . _read_format = read_format self . _frame_shape = self . _extract_frame_shape () def _extract_frame_shape ( self ) -> tuple [ int, ... ] : path = join_paths ( self . root_folder , self . files [ 0 ] ) frame = cv . imread ( path , self . _read_format ) return frame . shape @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names.
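Returning to DummyReader, documented just above, here is a short sketch of how it can stand in for real data; all values are invented.

```python
# Sketch: DummyReader yields blank white frames, handy for dry runs.
from wtracker.utils.frame_reader import DummyReader

reader = DummyReader(num_frames=100, resolution=(480, 640), colored=False)
frame = reader[0]            # shape (480, 640), every pixel equals 255
assert len(reader) == 100
```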
read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @property def root_folder ( self ) -> str : \"\"\" Returns the root folder path. Returns: str: The root folder path. \"\"\" return self . _root_folder @property def frame_shape ( self ) -> tuple [ int, ... ] : \"\"\" Returns the shape of the frame. Returns: tuple[int, ...]: The shape of the frame, in format (h, w, ...). \"\"\" return self . _frame_shape @property def frame_size ( self ) -> tuple [ int, int ] : \"\"\" Returns the size of the frame. Returns: tuple[int, int]: The shape of the frame, in format (h, w). \"\"\" return self . _frame_shape [ :2 ] @property def files ( self ) -> list [ str ] : \"\"\" Returns the list of files associated with the FrameReader object. Returns: list[str]: The list of file paths. \"\"\" return self . _files @property def read_format ( self ) -> int : \"\"\" Returns the read format of the frame reader. Returns: int: The read format. \"\"\" return self . _read_format def __len__ ( self ) -> int : return len ( self . _files ) def __getitem__ ( self , idx : int ) -> np . ndarray : if idx < 0 or idx >= len ( self . _files ) : raise IndexError ( \"index out of bounds\" ) path = join_paths ( self . root_folder , self . files [ idx ] ) frame = cv . imread ( path , self . _read_format ) return frame . astype ( np . uint8 , copy = False ) def __iter__ ( self ) : return FrameStream ( self ) def make_stream ( self ) : \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. \"\"\" return FrameStream ( self ) Descendants wtracker.utils.frame_reader.DummyReader Static methods create_from_directory def create_from_directory ( root_folder : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a directory. Parameters: Name Type Description Default root_folder str The root folder containing the frame files. None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. 
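Since __iter__ and make_stream both hand back a FrameStream, iteration over a reader is straightforward; a sketch follows (the "frames" folder is an assumption).

```python
# Sketch: iterating frames via the FrameStream returned by make_stream().
from wtracker.utils.frame_reader import FrameReader

reader = FrameReader.create_from_directory("frames")
stream = reader.make_stream()
for frame in stream:          # advances and reads until exhausted
    ...                       # process the ndarray frame here

stream.seek(10)               # jump to an arbitrary index
frame_10 = stream.read()      # raises IndexError if out of bounds
```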
\"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) create_from_template def create_from_template ( root_folder : 'str' , name_format : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a file name template. Parameters: Name Type Description Default root_folder str The root folder where the frame files are located. None name_format str The format of the frame file names. None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) Instance variables files Returns the list of files associated with the FrameReader object. frame_shape Returns the shape of the frame. frame_size Returns the size of the frame. read_format Returns the read format of the frame reader. root_folder Returns the root folder path. Methods make_stream def make_stream ( self ) Creates and returns a FrameStream object using the current instance of FrameReader. Returns: Type Description FrameStream A FrameStream object. View Source def make_stream(self): \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. \"\"\" return FrameStream(self) FrameStream class FrameStream ( frame_reader : 'FrameReader' ) A class for streaming frames from a FrameReader object. This class serves as an iterator for the FrameReader object. Attributes Name Type Description Default frame_reader FrameReader The frame reader object. None View Source class FrameStream : \"\"\" A class for streaming frames from a FrameReader object. This class serves as an iterator for the FrameReader object. Args: frame_reader (FrameReader): The frame reader object. \"\"\" def __init__ ( self , frame_reader : FrameReader ) : self . _frame_reader = frame_reader self . _idx = - 1 self . frame = None @property def index ( self ) -> int : \"\"\" The index of the current frame. \"\"\" return self . _idx def __len__ ( self ) : return len ( self . _frame_reader ) def __iter__ ( self ) : return self def __next__ ( self ) -> np . ndarray : self . progress () if not self . can_read () : raise StopIteration () frame = self . read () return frame def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader ) def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . 
_idx = idx self . frame = None return self . can_read () def read ( self ) -> np . ndarray : \"\"\" Read and return the frame at the current index. Raises: IndexError: If the index is out of bounds. Returns: np.ndarray: The frame at the current index. \"\"\" if not self . can_read () : raise IndexError ( \"index out of bounds\" ) if self . frame is None : self . frame = self . _frame_reader [ self._idx ] return self . frame def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . _idx + n ) def reset ( self ) : \"\"\" Resets the frame reader to the beginning of the stream. \"\"\" self . seek ( - 1 ) Descendants wtracker.sim.view_controller.ViewController Instance variables index The index of the current frame. Methods can_read def can_read ( self ) -> 'bool' View Source def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader ) progress def progress ( self , n : 'int' = 1 ) -> 'bool' Moves the current index forward by the specified number of steps. Parameters: Name Type Description Default n int The number of steps to move forward. None Returns: Type Description bool True if the index was successfully moved forward, False otherwise. View Source def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . _idx + n ) read def read ( self ) -> 'np.ndarray' Read and return the frame at the current index. Returns: Type Description np.ndarray The frame at the current index. Raises: Type Description IndexError If the index is out of bounds. View Source def read ( self ) -> np . ndarray : \"\"\" Read and return the frame at the current index. Raises: IndexError: If the index is out of bounds. Returns: np.ndarray: The frame at the current index. \"\"\" if not self . can_read () : raise IndexError ( \"index out of bounds\" ) if self . frame is None : self . frame = self . _frame_reader [ self . _idx ] return self . frame reset def reset ( self ) Resets the frame reader to the beginning of the stream. View Source def reset(self): \"\"\" Resets the frame reader to the beginning of the stream. \"\"\" self.seek(-1) seek def seek ( self , idx : 'int' ) -> 'bool' Move the index to the specified position. Parameters: Name Type Description Default idx int The index to seek to. None Returns: Type Description bool True if the index is within the valid range, False otherwise. View Source def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . _idx = idx self . frame = None return self . can_read ()","title":"Frame Reader"},{"location":"reference/wtracker/utils/frame_reader/#module-wtrackerutilsframe_reader","text":"View Source from __future__ import annotations import os import glob import numpy as np import cv2 as cv from wtracker.utils.path_utils import join_paths class FrameReader : \"\"\" A class for reading frames from a directory of frame files. Args: root_folder (str): The root folder path where the frame files are located.
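An aside on the cursor methods documented above (a sketch; DummyReader is used so the example needs no real image files):

```python
# Sketch: manual cursor control with progress(), checking its return value.
from wtracker.utils.frame_reader import DummyReader

stream = DummyReader(num_frames=3, resolution=(480, 640)).make_stream()
while stream.progress():      # False once the cursor passes the last frame
    frame = stream.read()     # cached until the cursor moves again
stream.reset()                # rewind to before the first frame
```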
frame_files (list[str]): A list of frame file names. read_format (int, optional): The format in which the frames should be read. Attributes: root_folder (str): The root folder path where the frame files are located. frame_shape (tuple[int, ...]): The shape of the frame. frame_size (tuple[int, int]): The size of the frame. files (list[str]): The list of file paths. read_format (int): The read format of the frame reader. \"\"\" def __init__ ( self , root_folder : str , frame_files : list [ str ], read_format : int = cv . IMREAD_GRAYSCALE , ): assert os . path . exists ( root_folder ) assert len ( frame_files ) > 0 self . _root_folder = root_folder self . _files = frame_files self . _read_format = read_format self . _frame_shape = self . _extract_frame_shape () def _extract_frame_shape ( self ) -> tuple [ int , ... ]: path = join_paths ( self . root_folder , self . files [ 0 ]) frame = cv . imread ( path , self . _read_format ) return frame . shape @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os . path . isfile ( join_paths ( root_folder , f ))] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os . path . isfile ( join_paths ( root_folder , f ))] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @property def root_folder ( self ) -> str : \"\"\" Returns the root folder path. Returns: str: The root folder path. \"\"\" return self . _root_folder @property def frame_shape ( self ) -> tuple [ int , ... ]: \"\"\" Returns the shape of the frame. Returns: tuple[int, ...]: The shape of the frame, in format (h, w, ...). \"\"\" return self . _frame_shape @property def frame_size ( self ) -> tuple [ int , int ]: \"\"\" Returns the size of the frame. Returns: tuple[int, int]: The shape of the frame, in format (h, w). \"\"\" return self . _frame_shape [: 2 ] @property def files ( self ) -> list [ str ]: \"\"\" Returns the list of files associated with the FrameReader object. Returns: list[str]: The list of file paths. \"\"\" return self . _files @property def read_format ( self ) -> int : \"\"\" Returns the read format of the frame reader. Returns: int: The read format. \"\"\" return self . _read_format def __len__ ( self ) -> int : return len ( self . _files ) def __getitem__ ( self , idx : int ) -> np . ndarray : if idx < 0 or idx >= len ( self . 
_files ): raise IndexError ( \"index out of bounds\" ) path = join_paths ( self . root_folder , self . files [ idx ]) frame = cv . imread ( path , self . _read_format ) return frame . astype ( np . uint8 , copy = False ) def __iter__ ( self ): return FrameStream ( self ) def make_stream ( self ): \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. \"\"\" return FrameStream ( self ) class FrameStream : \"\"\" A class for streaming frames from a FrameReader object. This class serves as an iterator for the FrameReader object. Args: frame_reader (FrameReader): The frame reader object. \"\"\" def __init__ ( self , frame_reader : FrameReader ): self . _frame_reader = frame_reader self . _idx = - 1 self . frame = None @property def index ( self ) -> int : \"\"\" The index of the current frame. \"\"\" return self . _idx def __len__ ( self ): return len ( self . _frame_reader ) def __iter__ ( self ): return self def __next__ ( self ) -> np . ndarray : self . progress () if not self . can_read (): raise StopIteration () frame = self . read () return frame def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader ) def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . _idx = idx self . frame = None return self . can_read () def read ( self ) -> np . ndarray : \"\"\" Read and return the frame at the current index. Raises: IndexError: If the index is out of bounds. Returns: np.ndarray: The frame at the current index. \"\"\" if not self . can_read (): raise IndexError ( \"index out of bounds\" ) if self . frame is None : self . frame = self . _frame_reader [ self . _idx ] return self . frame def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . _idx + n ) def reset ( self ): \"\"\" Resets the frame reader to the beginning of the stream. \"\"\" self . seek ( - 1 ) class DummyReader ( FrameReader ): \"\"\" A dummy frame reader that generates empty frames of a specified resolution. Args: num_frames (int): The number of frames to generate. resolution (tuple[int, int]): The resolution of the frames, in format (h, w). colored (bool, optional): Whether the frames are colored or grayscale. \"\"\" def __init__ ( self , num_frames : int , resolution : tuple [ int , int ], colored : bool = True ): self . colored = colored self . _resolution = resolution shape = ( * resolution , 3 ) if colored else resolution self . _frame = np . full ( shape , fill_value = 255 , dtype = np . uint8 ) frames = [ str ( i ) for i in range ( num_frames )] super () . __init__ ( \".\" , frame_files = frames ) def __getitem__ ( self , idx : int ) -> np . ndarray : return self . _frame . copy () def _extract_frame_shape ( self ) -> tuple [ int , ... ]: if self . colored : return ( * self . _resolution , 3 ) return self .
_resolution","title":"Module wtracker.utils.frame_reader"},{"location":"reference/wtracker/utils/frame_reader/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/frame_reader/#dummyreader","text":"class DummyReader ( num_frames : 'int' , resolution : 'tuple[int, int]' , colored : 'bool' = True ) A dummy frame reader that generates empty frames of a specified resolution.","title":"DummyReader"},{"location":"reference/wtracker/utils/frame_reader/#attributes","text":"Name Type Description Default num_frames int The number of frames to generate. None resolution tuple[int, int] The resolution of the frames, in format (h, w). None colored bool Whether the frames are colored or grayscale. None View Source class DummyReader ( FrameReader ): \"\"\" A dummy frame reader that generates empty frames of a specified resolution. Args: num_frames (int): The number of frames to generate. resolution (tuple[int, int]): The resolution of the frames, in format (h, w). colored (bool, optional): Whether the frames are colored or grayscale. \"\"\" def __init__ ( self , num_frames : int , resolution : tuple [ int , int ], colored : bool = True ): self . colored = colored self . _resolution = resolution shape = ( * resolution , 3 ) if colored else resolution self . _frame = np . full ( shape , fill_value = 255 , dtype = np . uint8 ) frames = [ str ( i ) for i in range ( num_frames )] super (). __init__ ( \".\" , frame_files = frames ) def __getitem__ ( self , idx : int ) -> np . ndarray : return self . _frame . copy () def _extract_frame_shape ( self ) -> tuple [ int , ... ]: if self . colored : return ( * self . _resolution , 3 ) return self . _resolution","title":"Attributes"},{"location":"reference/wtracker/utils/frame_reader/#ancestors-in-mro","text":"wtracker.utils.frame_reader.FrameReader","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/utils/frame_reader/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/utils/frame_reader/#create_from_directory","text":"def create_from_directory ( root_folder : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a directory. Parameters: Name Type Description Default root_folder str The root folder containing the frame files. None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format )","title":"create_from_directory"},{"location":"reference/wtracker/utils/frame_reader/#create_from_template","text":"def create_from_template ( root_folder : 'str' , name_format : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a file name template. Parameters: Name Type Description Default root_folder str The root folder where the frame files are located. None name_format str The format of the frame file names. 
None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format )","title":"create_from_template"},{"location":"reference/wtracker/utils/frame_reader/#instance-variables","text":"files Returns the list of files associated with the FrameReader object. frame_shape Returns the shape of the frame. frame_size Returns the size of the frame. read_format Returns the read format of the frame reader. root_folder Returns the root folder path.","title":"Instance variables"},{"location":"reference/wtracker/utils/frame_reader/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/utils/frame_reader/#make_stream","text":"def make_stream ( self ) Creates and returns a FrameStream object using the current instance of FrameReader. Returns: Type Description FrameStream A FrameStream object. View Source def make_stream(self): \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. \"\"\" return FrameStream(self)","title":"make_stream"},{"location":"reference/wtracker/utils/frame_reader/#framereader","text":"class FrameReader ( root_folder : 'str' , frame_files : 'list[str]' , read_format : 'int' = 0 ) A class for reading frames from a directory of frame files.","title":"FrameReader"},{"location":"reference/wtracker/utils/frame_reader/#attributes_1","text":"Name Type Description Default root_folder str The root folder path where the frame files are located. None frame_files list[str] A list of frame file names. None read_format int The format in which the frames should be read. None root_folder str The root folder path where the frame files are located. None frame_shape tuple[int, ...] The shape of the frame. None frame_size tuple[int, int] The size of the frame. None files list[str] The list of file paths. None read_format int The read format of the frame reader. None View Source class FrameReader : \"\"\" A class for reading frames from a directory of frame files. Args: root_folder (str): The root folder path where the frame files are located. frame_files (list[str]): A list of frame file names. read_format (int, optional): The format in which the frames should be read. Attributes: root_folder (str): The root folder path where the frame files are located. frame_shape (tuple[int, ...]): The shape of the frame. frame_size (tuple[int, int]): The size of the frame. files (list[str]): The list of file paths. read_format (int): The read format of the frame reader. \"\"\" def __init__ ( self , root_folder : str , frame_files : list [ str ] , read_format : int = cv . IMREAD_GRAYSCALE , ) : assert os . path . 
exists ( root_folder ) assert len ( frame_files ) > 0 self . _root_folder = root_folder self . _files = frame_files self . _read_format = read_format self . _frame_shape = self . _extract_frame_shape () def _extract_frame_shape ( self ) -> tuple [ int, ... ] : path = join_paths ( self . root_folder , self . files [ 0 ] ) frame = cv . imread ( path , self . _read_format ) return frame . shape @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @property def root_folder ( self ) -> str : \"\"\" Returns the root folder path. Returns: str: The root folder path. \"\"\" return self . _root_folder @property def frame_shape ( self ) -> tuple [ int, ... ] : \"\"\" Returns the shape of the frame. Returns: tuple[int, ...]: The shape of the frame, in format (h, w, ...). \"\"\" return self . _frame_shape @property def frame_size ( self ) -> tuple [ int, int ] : \"\"\" Returns the size of the frame. Returns: tuple[int, int]: The shape of the frame, in format (h, w). \"\"\" return self . _frame_shape [ :2 ] @property def files ( self ) -> list [ str ] : \"\"\" Returns the list of files associated with the FrameReader object. Returns: list[str]: The list of file paths. \"\"\" return self . _files @property def read_format ( self ) -> int : \"\"\" Returns the read format of the frame reader. Returns: int: The read format. \"\"\" return self . _read_format def __len__ ( self ) -> int : return len ( self . _files ) def __getitem__ ( self , idx : int ) -> np . ndarray : if idx < 0 or idx >= len ( self . _files ) : raise IndexError ( \"index out of bounds\" ) path = join_paths ( self . root_folder , self . files [ idx ] ) frame = cv . imread ( path , self . _read_format ) return frame . astype ( np . uint8 , copy = False ) def __iter__ ( self ) : return FrameStream ( self ) def make_stream ( self ) : \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. 
\"\"\" return FrameStream ( self )","title":"Attributes"},{"location":"reference/wtracker/utils/frame_reader/#descendants","text":"wtracker.utils.frame_reader.DummyReader","title":"Descendants"},{"location":"reference/wtracker/utils/frame_reader/#static-methods_1","text":"","title":"Static methods"},{"location":"reference/wtracker/utils/frame_reader/#create_from_directory_1","text":"def create_from_directory ( root_folder : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a directory. Parameters: Name Type Description Default root_folder str The root folder containing the frame files. None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format )","title":"create_from_directory"},{"location":"reference/wtracker/utils/frame_reader/#create_from_template_1","text":"def create_from_template ( root_folder : 'str' , name_format : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a file name template. Parameters: Name Type Description Default root_folder str The root folder where the frame files are located. None name_format str The format of the frame file names. None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format )","title":"create_from_template"},{"location":"reference/wtracker/utils/frame_reader/#instance-variables_1","text":"files Returns the list of files associated with the FrameReader object. frame_shape Returns the shape of the frame. frame_size Returns the size of the frame. read_format Returns the read format of the frame reader. root_folder Returns the root folder path.","title":"Instance variables"},{"location":"reference/wtracker/utils/frame_reader/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/utils/frame_reader/#make_stream_1","text":"def make_stream ( self ) Creates and returns a FrameStream object using the current instance of FrameReader. 
Returns: Type Description FrameStream A FrameStream object. View Source def make_stream(self): \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. \"\"\" return FrameStream(self)","title":"make_stream"},{"location":"reference/wtracker/utils/frame_reader/#framestream","text":"class FrameStream ( frame_reader : 'FrameReader' ) A class for streaming frames from a FrameReader object. This class serves as an iterator for the FrameReader object.","title":"FrameStream"},{"location":"reference/wtracker/utils/frame_reader/#attributes_2","text":"Name Type Description Default frame_reader FrameReader The frame reader object. None View Source class FrameStream : \"\"\" A class for streaming frames from a FrameReader object. This class serves as an iterator for the FrameReader object. Args: frame_reader (FrameReader): The frame reader object. \"\"\" def __init__ ( self , frame_reader : FrameReader ) : self . _frame_reader = frame_reader self . _idx = - 1 self . frame = None @property def index ( self ) -> int : \"\"\" The index of the current frame. \"\"\" return self . _idx def __len__ ( self ) : return len ( self . _frame_reader ) def __iter__ ( self ) : return self def __next__ ( self ) -> np . ndarray : self . progress () if not self . can_read () : raise StopIteration () frame = self . read () return frame def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader ) def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . _idx = idx self . frame = None return self . can_read () def read ( self ) -> np . ndarray : \"\"\" Read and return the frame at the current index. Raises: IndexError: If the index is out of bounds. Returns: np.ndarray: The frame at the current index. \"\"\" if not self . can_read () : raise IndexError ( \"index out of bounds\" ) if self . frame is None : self . frame = self . _frame_reader [ self._idx ] return self . frame def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . _idx + n ) def reset ( self ) : \"\"\" Resets the frame reader to the beginning of the stream. \"\"\" self . seek ( - 1 )","title":"Attributes"},{"location":"reference/wtracker/utils/frame_reader/#descendants_1","text":"wtracker.sim.view_controller.ViewController","title":"Descendants"},{"location":"reference/wtracker/utils/frame_reader/#instance-variables_2","text":"index The index of the current frame.","title":"Instance variables"},{"location":"reference/wtracker/utils/frame_reader/#methods_2","text":"","title":"Methods"},{"location":"reference/wtracker/utils/frame_reader/#can_read","text":"def can_read ( self ) -> 'bool' View Source def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader )","title":"can_read"},{"location":"reference/wtracker/utils/frame_reader/#progress","text":"def progress ( self , n : 'int' = 1 ) -> 'bool' Moves the current index forward by the specified number of steps. Parameters: Name Type Description Default n int The number of steps to move forward. 
None Returns: Type Description bool True if the index was successfully moved forward, False otherwise. View Source def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . _idx + n )","title":"progress"},{"location":"reference/wtracker/utils/frame_reader/#read","text":"def read ( self ) -> 'np.ndarray' Read and return the frame at the current index. Returns: Type Description np.ndarray The frame at the current index. Raises: Type Description IndexError If the index is out of bounds. View Source def read ( self ) -> np . ndarray : \"\"\" Read and return the frame at the current index. Raises: IndexError: If the index is out of bounds. Returns: np.ndarray: The frame at the current index. \"\"\" if not self . can_read () : raise IndexError ( \"index out of bounds\" ) if self . frame is None : self . frame = self . _frame_reader [ self . _idx ] return self . frame","title":"read"},{"location":"reference/wtracker/utils/frame_reader/#reset","text":"def reset ( self ) Resets the frame reader to the beginning of the stream. View Source def reset(self): \"\"\" Resets the frame reader to the beginning of the stream. \"\"\" self.seek(-1)","title":"reset"},{"location":"reference/wtracker/utils/frame_reader/#seek","text":"def seek ( self , idx : 'int' ) -> 'bool' Move the index to the specified position. Parameters: Name Type Description Default idx int The index to seek to. None Returns: Type Description bool True if the index is within the valid range, False otherwise. View Source def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . _idx = idx self . frame = None return self . can_read ()","title":"seek"},{"location":"reference/wtracker/utils/gui_utils/","text":"Module wtracker.utils.gui_utils View Source import tkinter as tk from tkinter import filedialog class FocusedWindow : def __init__ ( self ): root = tk . Tk () self . root = root self . hide () def __enter__ ( self ) -> tk . Tk : return self . focus () def __exit__ ( self , exc_type , exc_val , exc_tb ): self . hide () def focus ( self ) -> tk . Tk : root = self . root root . eval ( \"tk::PlaceWindow %s center\" % root . winfo_pathname ( root . winfo_id ())) root . deiconify () root . lift () root . attributes ( \"-topmost\" , True ) root . focus_force () root . update () root . after_idle ( root . attributes , \"-topmost\" , False ) return root def hide ( self ) -> tk . Tk : root = self . root root . withdraw () root . overrideredirect ( True ) root . geometry ( \"0x0+0+0\" ) root . update () return root def close ( self ): self . root . destroy () def __del__ ( self ): self . close () class UserPrompt : \"\"\"Class for creating user prompt dialogs.\"\"\" @staticmethod def open_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , multiple : bool = False , ** kwargs , ) -> str | list [ str ]: \"\"\" Opens a file dialog to select one or multiple files. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). 
multiple (bool, optional): Whether to allow multiple file selection. **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str | list[str]: The path of the selected file(s). If multiple is True, a list of paths is returned. Otherwise, a single path is returned. \"\"\" if file_types is None : file_types = [] file_types += [( \"all files\" , \"*.*\" )] with FocusedWindow () as root : if multiple : path = filedialog . askopenfilenames ( parent = root , title = title , filetypes = file_types , ** kwargs , ) return list ( path ) else : return filedialog . askopenfilename ( parent = root , title = title , filetypes = file_types , ** kwargs , ) @staticmethod def save_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , ** kwargs ) -> str : \"\"\" Opens a file dialog to save a file and returns the selected file path. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str: The selected file path. \"\"\" if file_types is None : file_types = [] file_types += [( \"all files\" , \"*.*\" )] with FocusedWindow () as root : return filedialog . asksaveasfilename ( parent = root , title = title , filetypes = file_types , confirmoverwrite = True , ** kwargs , ) @staticmethod def open_directory ( title : str = None , ** kwargs ) -> str : \"\"\" Opens a dialog box to select a directory. Args: title (str, optional): The title of the dialog box. **kwargs: Additional keyword arguments to be passed to the filedialog.askdirectory function. Returns: str: The path of the selected directory. \"\"\" with FocusedWindow () as root : return filedialog . askdirectory ( parent = root , title = title , mustexist = False , ** kwargs , ) Classes FocusedWindow class FocusedWindow ( ) View Source class FocusedWindow : def __init__ ( self ): root = tk . Tk () self . root = root self . hide () def __enter__ ( self ) -> tk . Tk : return self . focus () def __exit__ ( self , exc_type , exc_val , exc_tb ): self . hide () def focus ( self ) -> tk . Tk : root = self . root root . eval ( \"tk::PlaceWindow %s center\" % root . winfo_pathname ( root . winfo_id ())) root . deiconify () root . lift () root . attributes ( \"-topmost\" , True ) root . focus_force () root . update () root . after_idle ( root . attributes , \"-topmost\" , False ) return root def hide ( self ) -> tk . Tk : root = self . root root . withdraw () root . overrideredirect ( True ) root . geometry ( \"0x0+0+0\" ) root . update () return root def close ( self ): self . root . destroy () def __del__ ( self ): self . close () Methods close def close ( self ) View Source def close(self): self.root.destroy() focus def focus ( self ) -> tkinter . Tk View Source def focus ( self ) -> tk . Tk : root = self . root root . eval ( \"tk::PlaceWindow %s center\" % root . winfo_pathname ( root . winfo_id ())) root . deiconify () root . lift () root . attributes ( \"-topmost\" , True ) root . focus_force () root . update () root . after_idle ( root . attributes , \"-topmost\" , False ) return root hide def hide ( self ) -> tkinter . Tk View Source def hide ( self ) -> tk . Tk : root = self . root root . withdraw () root . overrideredirect ( True ) root . geometry ( \"0x0+0+0\" ) root . 
update () return root UserPrompt class UserPrompt ( / , * args , ** kwargs ) Class for creating user prompt dialogs. View Source class UserPrompt : \"\"\"Class for creating user prompt dialogs.\"\"\" @staticmethod def open_file ( title : str = None , file_types : list [ tuple[str, str ] ] = None , multiple : bool = False , ** kwargs , ) -> str | list [ str ] : \"\"\" Opens a file dialog to select one or multiple files. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). multiple (bool, optional): Whether to allow multiple file selection. **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str | list[str]: The path of the selected file(s). If multiple is True, a list of paths is returned. Otherwise, a single path is returned. \"\"\" if file_types is None : file_types = [] file_types += [ (\"all files\", \"*.*\") ] with FocusedWindow () as root : if multiple : path = filedialog . askopenfilenames ( parent = root , title = title , filetypes = file_types , ** kwargs , ) return list ( path ) else : return filedialog . askopenfilename ( parent = root , title = title , filetypes = file_types , ** kwargs , ) @staticmethod def save_file ( title : str = None , file_types : list [ tuple[str, str ] ] = None , ** kwargs ) -> str : \"\"\" Opens a file dialog to save a file and returns the selected file path. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str: The selected file path. \"\"\" if file_types is None : file_types = [] file_types += [ (\"all files\", \"*.*\") ] with FocusedWindow () as root : return filedialog . asksaveasfilename ( parent = root , title = title , filetypes = file_types , confirmoverwrite = True , ** kwargs , ) @staticmethod def open_directory ( title : str = None , ** kwargs ) -> str : \"\"\" Opens a dialog box to select a directory. Args: title (str, optional): The title of the dialog box. **kwargs: Additional keyword arguments to be passed to the filedialog.askdirectory function. Returns: str: The path of the selected directory. \"\"\" with FocusedWindow () as root : return filedialog . askdirectory ( parent = root , title = title , mustexist = False , ** kwargs , ) Static methods open_directory def open_directory ( title : str = None , ** kwargs ) -> str Opens a dialog box to select a directory. Parameters: Name Type Description Default title str The title of the dialog box. None **kwargs None Additional keyword arguments to be passed to the filedialog.askdirectory function. None Returns: Type Description str The path of the selected directory. View Source @staticmethod def open_directory ( title : str = None , ** kwargs ) -> str : \"\"\" Opens a dialog box to select a directory. Args: title (str, optional): The title of the dialog box. **kwargs: Additional keyword arguments to be passed to the filedialog.askdirectory function. Returns: str: The path of the selected directory. \"\"\" with FocusedWindow () as root : return filedialog . 
askdirectory ( parent = root , title = title , mustexist = False , ** kwargs , ) open_file def open_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , multiple : bool = False , ** kwargs ) -> str | list [ str ] Opens a file dialog to select one or multiple files. Parameters: Name Type Description Default title str The title of the file dialog window. None file_types list[tuple[str, str]] A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). None multiple bool Whether to allow multiple file selection. None **kwargs None Additional keyword arguments to be passed to the file dialog. None Returns: Type Description str list[str] View Source @staticmethod def open_file ( title : str = None , file_types : list [ tuple[str, str ] ] = None , multiple : bool = False , ** kwargs , ) -> str | list [ str ] : \"\"\" Opens a file dialog to select one or multiple files. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). multiple (bool, optional): Whether to allow multiple file selection. **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str | list[str]: The path of the selected file(s). If multiple is True, a list of paths is returned. Otherwise, a single path is returned. \"\"\" if file_types is None : file_types = [] file_types += [ (\"all files\", \"*.*\") ] with FocusedWindow () as root : if multiple : path = filedialog . askopenfilenames ( parent = root , title = title , filetypes = file_types , ** kwargs , ) return list ( path ) else : return filedialog . askopenfilename ( parent = root , title = title , filetypes = file_types , ** kwargs , ) save_file def save_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , ** kwargs ) -> str Opens a file dialog to save a file and returns the selected file path. Parameters: Name Type Description Default title str The title of the file dialog window. None file_types list[tuple[str, str]] A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). None **kwargs None Additional keyword arguments to be passed to the file dialog. None Returns: Type Description str The selected file path. View Source @ staticmethod def save_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , ** kwargs ) -> str : \"\"\" Opens a file dialog to save a file and returns the selected file path. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str: The selected file path. \"\"\" if file_types is None : file_types = [] file_types += [( \"all files\" , \"*.*\" )] with FocusedWindow () as root : return filedialog . asksaveasfilename ( parent = root , title = title , filetypes = file_types , confirmoverwrite = True , ** kwargs , )","title":"Gui Utils"},{"location":"reference/wtracker/utils/gui_utils/#module-wtrackerutilsgui_utils","text":"View Source import tkinter as tk from tkinter import filedialog class FocusedWindow : def __init__ ( self ): root = tk . Tk () self . root = root self . 
hide () def __enter__ ( self ) -> tk . Tk : return self . focus () def __exit__ ( self , exc_type , exc_val , exc_tb ): self . hide () def focus ( self ) -> tk . Tk : root = self . root root . eval ( \"tk::PlaceWindow %s center\" % root . winfo_pathname ( root . winfo_id ())) root . deiconify () root . lift () root . attributes ( \"-topmost\" , True ) root . focus_force () root . update () root . after_idle ( root . attributes , \"-topmost\" , False ) return root def hide ( self ) -> tk . Tk : root = self . root root . withdraw () root . overrideredirect ( True ) root . geometry ( \"0x0+0+0\" ) root . update () return root def close ( self ): self . root . destroy () def __del__ ( self ): self . close () class UserPrompt : \"\"\"Class for creating user prompt dialogs.\"\"\" @staticmethod def open_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , multiple : bool = False , ** kwargs , ) -> str | list [ str ]: \"\"\" Opens a file dialog to select one or multiple files. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). multiple (bool, optional): Whether to allow multiple file selection. **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str | list[str]: The path of the selected file(s). If multiple is True, a list of paths is returned. Otherwise, a single path is returned. \"\"\" if file_types is None : file_types = [] file_types += [( \"all files\" , \"*.*\" )] with FocusedWindow () as root : if multiple : path = filedialog . askopenfilenames ( parent = root , title = title , filetypes = file_types , ** kwargs , ) return list ( path ) else : return filedialog . askopenfilename ( parent = root , title = title , filetypes = file_types , ** kwargs , ) @staticmethod def save_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , ** kwargs ) -> str : \"\"\" Opens a file dialog to save a file and returns the selected file path. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str: The selected file path. \"\"\" if file_types is None : file_types = [] file_types += [( \"all files\" , \"*.*\" )] with FocusedWindow () as root : return filedialog . asksaveasfilename ( parent = root , title = title , filetypes = file_types , confirmoverwrite = True , ** kwargs , ) @staticmethod def open_directory ( title : str = None , ** kwargs ) -> str : \"\"\" Opens a dialog box to select a directory. Args: title (str, optional): The title of the dialog box. **kwargs: Additional keyword arguments to be passed to the filedialog.askdirectory function. Returns: str: The path of the selected directory. \"\"\" with FocusedWindow () as root : return filedialog . askdirectory ( parent = root , title = title , mustexist = False , ** kwargs , )","title":"Module wtracker.utils.gui_utils"},{"location":"reference/wtracker/utils/gui_utils/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/gui_utils/#focusedwindow","text":"class FocusedWindow ( ) View Source class FocusedWindow : def __init__ ( self ): root = tk . Tk () self . root = root self . 
hide () def __enter__ ( self ) -> tk . Tk : return self . focus () def __exit__ ( self , exc_type , exc_val , exc_tb ): self . hide () def focus ( self ) -> tk . Tk : root = self . root root . eval ( \"tk::PlaceWindow %s center\" % root . winfo_pathname ( root . winfo_id ())) root . deiconify () root . lift () root . attributes ( \"-topmost\" , True ) root . focus_force () root . update () root . after_idle ( root . attributes , \"-topmost\" , False ) return root def hide ( self ) -> tk . Tk : root = self . root root . withdraw () root . overrideredirect ( True ) root . geometry ( \"0x0+0+0\" ) root . update () return root def close ( self ): self . root . destroy () def __del__ ( self ): self . close ()","title":"FocusedWindow"},{"location":"reference/wtracker/utils/gui_utils/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/utils/gui_utils/#close","text":"def close ( self ) View Source def close(self): self.root.destroy()","title":"close"},{"location":"reference/wtracker/utils/gui_utils/#focus","text":"def focus ( self ) -> tkinter . Tk View Source def focus ( self ) -> tk . Tk : root = self . root root . eval ( \"tk::PlaceWindow %s center\" % root . winfo_pathname ( root . winfo_id ())) root . deiconify () root . lift () root . attributes ( \"-topmost\" , True ) root . focus_force () root . update () root . after_idle ( root . attributes , \"-topmost\" , False ) return root","title":"focus"},{"location":"reference/wtracker/utils/gui_utils/#hide","text":"def hide ( self ) -> tkinter . Tk View Source def hide ( self ) -> tk . Tk : root = self . root root . withdraw () root . overrideredirect ( True ) root . geometry ( \"0x0+0+0\" ) root . update () return root","title":"hide"},{"location":"reference/wtracker/utils/gui_utils/#userprompt","text":"class UserPrompt ( / , * args , ** kwargs ) Class for creating user prompt dialogs. View Source class UserPrompt : \"\"\"Class for creating user prompt dialogs.\"\"\" @staticmethod def open_file ( title : str = None , file_types : list [ tuple[str, str ] ] = None , multiple : bool = False , ** kwargs , ) -> str | list [ str ] : \"\"\" Opens a file dialog to select one or multiple files. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). multiple (bool, optional): Whether to allow multiple file selection. **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str | list[str]: The path of the selected file(s). If multiple is True, a list of paths is returned. Otherwise, a single path is returned. \"\"\" if file_types is None : file_types = [] file_types += [ (\"all files\", \"*.*\") ] with FocusedWindow () as root : if multiple : path = filedialog . askopenfilenames ( parent = root , title = title , filetypes = file_types , ** kwargs , ) return list ( path ) else : return filedialog . askopenfilename ( parent = root , title = title , filetypes = file_types , ** kwargs , ) @staticmethod def save_file ( title : str = None , file_types : list [ tuple[str, str ] ] = None , ** kwargs ) -> str : \"\"\" Opens a file dialog to save a file and returns the selected file path. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). 
**kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str: The selected file path. \"\"\" if file_types is None : file_types = [] file_types += [ (\"all files\", \"*.*\") ] with FocusedWindow () as root : return filedialog . asksaveasfilename ( parent = root , title = title , filetypes = file_types , confirmoverwrite = True , ** kwargs , ) @staticmethod def open_directory ( title : str = None , ** kwargs ) -> str : \"\"\" Opens a dialog box to select a directory. Args: title (str, optional): The title of the dialog box. **kwargs: Additional keyword arguments to be passed to the filedialog.askdirectory function. Returns: str: The path of the selected directory. \"\"\" with FocusedWindow () as root : return filedialog . askdirectory ( parent = root , title = title , mustexist = False , ** kwargs , )","title":"UserPrompt"},{"location":"reference/wtracker/utils/gui_utils/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/utils/gui_utils/#open_directory","text":"def open_directory ( title : str = None , ** kwargs ) -> str Opens a dialog box to select a directory. Parameters: Name Type Description Default title str The title of the dialog box. None **kwargs None Additional keyword arguments to be passed to the filedialog.askdirectory function. None Returns: Type Description str The path of the selected directory. View Source @staticmethod def open_directory ( title : str = None , ** kwargs ) -> str : \"\"\" Opens a dialog box to select a directory. Args: title (str, optional): The title of the dialog box. **kwargs: Additional keyword arguments to be passed to the filedialog.askdirectory function. Returns: str: The path of the selected directory. \"\"\" with FocusedWindow () as root : return filedialog . askdirectory ( parent = root , title = title , mustexist = False , ** kwargs , )","title":"open_directory"},{"location":"reference/wtracker/utils/gui_utils/#open_file","text":"def open_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , multiple : bool = False , ** kwargs ) -> str | list [ str ] Opens a file dialog to select one or multiple files. Parameters: Name Type Description Default title str The title of the file dialog window. None file_types list[tuple[str, str]] A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). None multiple bool Whether to allow multiple file selection. None **kwargs None Additional keyword arguments to be passed to the file dialog. None Returns: Type Description str list[str] View Source @staticmethod def open_file ( title : str = None , file_types : list [ tuple[str, str ] ] = None , multiple : bool = False , ** kwargs , ) -> str | list [ str ] : \"\"\" Opens a file dialog to select one or multiple files. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). multiple (bool, optional): Whether to allow multiple file selection. **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str | list[str]: The path of the selected file(s). If multiple is True, a list of paths is returned. Otherwise, a single path is returned. \"\"\" if file_types is None : file_types = [] file_types += [ (\"all files\", \"*.*\") ] with FocusedWindow () as root : if multiple : path = filedialog . 
askopenfilenames ( parent = root , title = title , filetypes = file_types , ** kwargs , ) return list ( path ) else : return filedialog . askopenfilename ( parent = root , title = title , filetypes = file_types , ** kwargs , )","title":"open_file"},{"location":"reference/wtracker/utils/gui_utils/#save_file","text":"def save_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , ** kwargs ) -> str Opens a file dialog to save a file and returns the selected file path. Parameters: Name Type Description Default title str The title of the file dialog window. None file_types list[tuple[str, str]] A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). None **kwargs None Additional keyword arguments to be passed to the file dialog. None Returns: Type Description str The selected file path. View Source @ staticmethod def save_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , ** kwargs ) -> str : \"\"\" Opens a file dialog to save a file and returns the selected file path. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str: The selected file path. \"\"\" if file_types is None : file_types = [] file_types += [( \"all files\" , \"*.*\" )] with FocusedWindow () as root : return filedialog . asksaveasfilename ( parent = root , title = title , filetypes = file_types , confirmoverwrite = True , ** kwargs , )","title":"save_file"},{"location":"reference/wtracker/utils/io_utils/","text":"Module wtracker.utils.io_utils View Source import cv2 as cv import numpy as np import pickle import math from wtracker.utils.path_utils import join_paths , create_directory , create_parent_directory from wtracker.utils.frame_reader import FrameReader from wtracker.utils.threading_utils import TaskScheduler class FrameSaver ( TaskScheduler ): \"\"\" A class for saving images from a frame reader to a specified folder. This class utilizes a queue to save images in a separate thread, which allows for non-blocking image saving. Args: frame_reader (FrameReader): The frame reader object from which images will be saved. root_path (str): The root folder path, relative to which all other paths are. maxsize (int, optional): The maximum size of the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments for tqdm. \"\"\" def __init__ ( self , frame_reader : FrameReader , root_path : str = \"\" , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs , ): super () . __init__ ( self . _save_frame , maxsize , tqdm , ** tqdm_kwargs ) self . _frame_reader = frame_reader self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img_index : int , crop_dims : tuple [ float , float , float , float ], img_name : str ): \"\"\" Adds an image to the queue for saving. Args: img_index (int): The index of the image in the frame reader. crop_dims (tuple[float, float, float, float]): The crop dimensions (x, y, w, h) for the image. img_name (str): The name (path) of the image file relative to the root path. \"\"\" super () . 
schedule_save ( img_index , crop_dims , img_name ) def _save_frame ( self , params : tuple [ int , tuple [ float , float , float , float ], str ]): img_index , crop_dims , img_name = params save_path = join_paths ( self . _root_path , img_name ) img = self . _frame_reader [ img_index ] x , y , w , h = crop_dims img = img [ y : y + h , x : x + w ] success = cv . imwrite ( save_path , img ) if not success : create_parent_directory ( save_path ) if not cv . imwrite ( save_path , img ): raise ValueError ( f \"Failed to save image { save_path } \" ) class ImageSaver ( TaskScheduler ): \"\"\" A class for saving images asynchronously using a task scheduler. Args: root_path (str): The root folder path, relative to which all other paths are. maxsize (int, optional): The maximum size of the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments for tqdm. \"\"\" def __init__ ( self , root_path : str = \"\" , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs , ): super () . __init__ ( self . _save_image , maxsize , tqdm , ** tqdm_kwargs ) self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img : np . ndarray , img_path : str ): \"\"\" Adds an image to the queue for saving. Args: img (np.ndarray): The image to save. img_path (str): The name (path) of the image file relative to the root path. \"\"\" super () . schedule_save ( img , img_path ) def _save_image ( self , params : tuple [ np . ndarray , str ]): img , img_name = params save_path = join_paths ( self . _root_path , img_name ) success = cv . imwrite ( save_path , img ) if not success : create_parent_directory ( save_path ) if not cv . imwrite ( save_path , img ): raise ValueError ( f \"Failed to save image { save_path } \" ) def pickle_load_object ( file_path : str ): \"\"\" Load an object from a pickle file. Args: file_path (str): The path to the pickle file. Returns: The loaded object. Raises: FileNotFoundError: If the file does not exist. 
ValueError: If there is an error loading the object from the pickle file. \"\"\" try : with open ( file_path , \"rb\" ) as f : return pickle . load ( f ) except FileNotFoundError : raise FileNotFoundError ( f \"file does not exist: {file_path}\" ) except Exception as e : raise ValueError ( f \"error loading object from pickle file: {e}\" ) pickle_save_object def pickle_save_object ( obj , file_path : str ) Save an object to a pickle file. Parameters: Name Type Description Default obj None The object to be saved. None file_path str The path to the pickle file. None Raises: Type Description ValueError If there is an error saving the object to the pickle file. View Source def pickle_save_object(obj, file_path: str): \"\"\" Save an object to a pickle file. Args: obj: The object to be saved. file_path (str): The path to the pickle file. Raises: ValueError: If there is an error saving the object to the pickle file. \"\"\" try: create_parent_directory(file_path) with open(file_path, \"wb\") as f: pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) except Exception as e: raise ValueError(f\"error saving object to pickle file: {e}\") Classes FrameSaver class FrameSaver ( frame_reader : wtracker . utils . frame_reader . FrameReader , root_path : str = '' , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs ) A class for saving images from a frame reader to a specified folder. This class utilizes a queue to save images in a separate thread, which allows for non-blocking image saving. Attributes Name Type Description Default frame_reader FrameReader The frame reader object from which images will be saved. None root_path str The root folder path, relative to which all other paths are. None maxsize int The maximum size of the queue. None tqdm bool Whether to use tqdm for progress tracking. None **tqdm_kwargs None Additional keyword arguments for tqdm. None View Source class FrameSaver ( TaskScheduler ) : \"\"\" A class for saving images from a frame reader to a specified folder . This class utilizes a queue to save images in a separate thread , which allows for non - blocking image saving . Args : frame_reader ( FrameReader ) : The frame reader object from which images will be saved . root_path ( str ) : The root folder path , relative to which all other paths are . maxsize ( int , optional ) : The maximum size of the queue . tqdm ( bool , optional ) : Whether to use tqdm for progress tracking . ** tqdm_kwargs : Additional keyword arguments for tqdm . \"\"\" def __init__ ( self , frame_reader : FrameReader , root_path : str = \"\" , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs , ) : super (). __init__ ( self . _save_frame , maxsize , tqdm , ** tqdm_kwargs ) self . _frame_reader = frame_reader self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img_index : int , crop_dims : tuple [ float , float , float , float ], img_name : str ) : \"\"\" Adds an image to the queue for saving . Args : img_index ( int ) : The index of the image in the frame reader . crop_dims ( tuple [ float , float , float , float ]) : The crop dimensions ( x , y , w , h ) for the image . img_name ( str ) : The name ( path ) of the image file relative to the root path . \"\"\" super (). schedule_save ( img_index , crop_dims , img_name ) def _save_frame ( self , params : tuple [ int , tuple [ float , float , float , float ], str ]) : img_index , crop_dims , img_name = params save_path = join_paths ( self . _root_path , img_name ) img = self . 
_frame_reader [ img_index ] x , y , w , h = crop_dims img = img [ y : y + h , x : x + w ] success = cv . imwrite ( save_path , img ) if not success : create_parent_directory ( save_path ) if not cv . imwrite ( save_path , img ) : raise ValueError ( f \"Failed to save image {save_path}\" ) Ancestors (in MRO) wtracker.utils.threading_utils.TaskScheduler Methods close def close ( self ) Waits for the queue to empty and then closes the worker thread. View Source def close(self): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self._queue.join() self._queue.put(None) self._worker_thread.join() schedule_save def schedule_save ( self , img_index : int , crop_dims : tuple [ float , float , float , float ], img_name : str ) Adds an image to the queue for saving. Parameters: Name Type Description Default img_index int The index of the image in the frame reader. None crop_dims tuple[float, float, float, float] The crop dimensions (x, y, w, h) for the image. None img_name str The name (path) of the image file relative to the root path. None View Source def schedule_save(self, img_index: int, crop_dims: tuple[float, float, float, float], img_name: str): \"\"\" Adds an image to the queue for saving. Args: img_index (int): The index of the image in the frame reader. crop_dims (tuple[float, float, float, float]): The crop dimensions (x, y, w, h) for the image. img_name (str): The name (path) of the image file relative to the root path. \"\"\" super().schedule_save(img_index, crop_dims, img_name) start def start ( self ) Starts the worker thread. View Source def start(self): \"\"\" Starts the worker thread. \"\"\" self._worker_thread.start() ImageSaver class ImageSaver ( root_path : str = '' , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs ) A class for saving images asynchronously using a task scheduler. Attributes Name Type Description Default root_path str The root folder path, relative to which all other paths are. None maxsize int The maximum size of the queue. None tqdm bool Whether to use tqdm for progress tracking. None **tqdm_kwargs None Additional keyword arguments for tqdm. None View Source class ImageSaver ( TaskScheduler ): \"\"\" A class for saving images asynchronously using a task scheduler. Args: root_path (str): The root folder path, relative to which all other paths are. maxsize (int, optional): The maximum size of the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments for tqdm. \"\"\" def __init__ ( self , root_path: str = \"\" , maxsize: int = 100 , tqdm: bool = True , ** tqdm_kwargs , ): super (). __init__ ( self . _save_image , maxsize , tqdm , ** tqdm_kwargs ) self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img: np . ndarray , img_path: str ): \"\"\" Adds an image to the queue for saving. Args: img (np.ndarray): The image to save. img_path (str): The name (path) of the image file relative to the root path. \"\"\" super (). schedule_save ( img , img_path ) def _save_image ( self , params: tuple [ np . ndarray , str ]): img , img_name = params save_path = join_paths ( self . _root_path , img_name ) success = cv . imwrite ( save_path , img ) if not success: create_parent_directory ( save_path ) if not cv . 
imwrite ( save_path , img ): raise ValueError ( f \"Failed to save image {save_path}\" ) Ancestors (in MRO) wtracker.utils.threading_utils.TaskScheduler Methods close def close ( self ) Waits for the queue to empty and then closes the worker thread. View Source def close(self): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self._queue.join() self._queue.put(None) self._worker_thread.join() schedule_save def schedule_save ( self , img : numpy . ndarray , img_path : str ) Adds an image to the queue for saving. Parameters: Name Type Description Default img np.ndarray The image to save. None img_path str The name (path) of the image file relative to the root path. None View Source def schedule_save(self, img: np.ndarray, img_path: str): \"\"\" Adds an image to the queue for saving. Args: img (np.ndarray): The image to save. img_path (str): The name (path) of the image file relative to the root path. \"\"\" super().schedule_save(img, img_path) start def start ( self ) Starts the worker thread. View Source def start(self): \"\"\" Starts the worker thread. \"\"\" self._worker_thread.start()","title":"Io Utils"},{"location":"reference/wtracker/utils/io_utils/#module-wtrackerutilsio_utils","text":"View Source import cv2 as cv import numpy as np import pickle import math from wtracker.utils.path_utils import join_paths , create_directory , create_parent_directory from wtracker.utils.frame_reader import FrameReader from wtracker.utils.threading_utils import TaskScheduler class FrameSaver ( TaskScheduler ): \"\"\" A class for saving images from a frame reader to a specified folder. This class utilizes a queue to save images in a separate thread, which allows for non-blocking image saving. Args: frame_reader (FrameReader): The frame reader object from which images will be saved. root_path (str): The root folder path, relative to which all other paths are. maxsize (int, optional): The maximum size of the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments for tqdm. \"\"\" def __init__ ( self , frame_reader : FrameReader , root_path : str = \"\" , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs , ): super () . __init__ ( self . _save_frame , maxsize , tqdm , ** tqdm_kwargs ) self . _frame_reader = frame_reader self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img_index : int , crop_dims : tuple [ float , float , float , float ], img_name : str ): \"\"\" Adds an image to the queue for saving. Args: img_index (int): The index of the image in the frame reader. crop_dims (tuple[float, float, float, float]): The crop dimensions (x, y, w, h) for the image. img_name (str): The name (path) of the image file relative to the root path. \"\"\" super () . 
Args: root_path (str): The root folder path, relative to which all other paths are. maxsize (int, optional): The maximum size of the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments for tqdm. \"\"\" def __init__ ( self , root_path : str = \"\" , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs , ): super () . __init__ ( self . _save_image , maxsize , tqdm , ** tqdm_kwargs ) self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img : np . ndarray , img_path : str ): \"\"\" Adds an image to the queue for saving. Args: img (np.ndarray): The image to save. img_path (str): The name (path) of the image file relative to the root path. \"\"\" super () . schedule_save ( img , img_path ) def _save_image ( self , params : tuple [ np . ndarray , str ]): img , img_name = params save_path = join_paths ( self . _root_path , img_name ) success = cv . imwrite ( save_path , img ) if not success : create_parent_directory ( save_path ) if not cv . imwrite ( save_path , img ): raise ValueError ( f \"Failed to save image { save_path } \" ) def pickle_load_object ( file_path : str ): \"\"\" Load an object from a pickle file. Args: file_path (str): The path to the pickle file. Returns: The loaded object. Raises: FileNotFoundError: If the file does not exist. ValueError: If there is an error loading the object from the pickle file. \"\"\" try : with open ( file_path , \"rb\" ) as f : return pickle . load ( f ) except FileNotFoundError : raise FileNotFoundError ( f \"file does not exist: { file_path } \" ) except Exception as e : raise ValueError ( f \"error loading object from pickle file: { e } \" ) def pickle_save_object ( obj , file_path : str ): \"\"\" Save an object to a pickle file. Args: obj: The object to be saved. file_path (str): The path to the pickle file. Raises: ValueError: If there is an error saving the object to the pickle file. \"\"\" try : create_parent_directory ( file_path ) with open ( file_path , \"wb\" ) as f : pickle . dump ( obj , f , protocol = pickle . HIGHEST_PROTOCOL ) except Exception as e : raise ValueError ( f \"error saving object to pickle file: { e } \" )","title":"Module wtracker.utils.io_utils"},{"location":"reference/wtracker/utils/io_utils/#functions","text":"","title":"Functions"},{"location":"reference/wtracker/utils/io_utils/#pickle_load_object","text":"def pickle_load_object ( file_path : str ) Load an object from a pickle file. Parameters: Name Type Description Default file_path str The path to the pickle file. None Returns: Type Description None The loaded object. Raises: Type Description FileNotFoundError If the file does not exist. ValueError If there is an error loading the object from the pickle file. View Source def pickle_load_object ( file_path : str ): \"\"\" Load an object from a pickle file. Args: file_path (str): The path to the pickle file. Returns: The loaded object. Raises: FileNotFoundError: If the file does not exist. ValueError: If there is an error loading the object from the pickle file. \"\"\" try : with open ( file_path , \"rb\" ) as f : return pickle . load ( f ) except FileNotFoundError : raise FileNotFoundError ( f \"file does not exist: {file_path}\" ) except Exception as e : raise ValueError ( f \"error loading object from pickle file: {e}\" )","title":"pickle_load_object"},{"location":"reference/wtracker/utils/io_utils/#pickle_save_object","text":"def pickle_save_object ( obj , file_path : str ) Save an object to a pickle file. 
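Example (a minimal sketch; the output path is hypothetical): pickle_save_object({\"step\": 1}, \"output/state.pkl\") state = pickle_load_object(\"output/state.pkl\") # parent directories are created automatically on save. 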
Parameters: Name Type Description Default obj None The object to be saved. None file_path str The path to the pickle file. None Raises: Type Description ValueError If there is an error saving the object to the pickle file. View Source def pickle_save_object(obj, file_path: str): \"\"\" Save an object to a pickle file. Args: obj: The object to be saved. file_path (str): The path to the pickle file. Raises: ValueError: If there is an error saving the object to the pickle file. \"\"\" try: create_parent_directory(file_path) with open(file_path, \"wb\") as f: pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) except Exception as e: raise ValueError(f\"error saving object to pickle file: {e}\")","title":"pickle_save_object"},{"location":"reference/wtracker/utils/io_utils/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/io_utils/#framesaver","text":"class FrameSaver ( frame_reader : wtracker . utils . frame_reader . FrameReader , root_path : str = '' , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs ) A class for saving images from a frame reader to a specified folder. This class utilizes a queue to save images in a separate thread, which allows for non-blocking image saving.","title":"FrameSaver"},{"location":"reference/wtracker/utils/io_utils/#attributes","text":"Name Type Description Default frame_reader FrameReader The frame reader object from which images will be saved. None root_path str The root folder path, relative to which all other paths are. None maxsize int The maximum size of the queue. None tqdm bool Whether to use tqdm for progress tracking. None **tqdm_kwargs None Additional keyword arguments for tqdm. None View Source class FrameSaver ( TaskScheduler ) : \"\"\" A class for saving images from a frame reader to a specified folder . This class utilizes a queue to save images in a separate thread , which allows for non - blocking image saving . Args : frame_reader ( FrameReader ) : The frame reader object from which images will be saved . root_path ( str ) : The root folder path , relative to which all other paths are . maxsize ( int , optional ) : The maximum size of the queue . tqdm ( bool , optional ) : Whether to use tqdm for progress tracking . ** tqdm_kwargs : Additional keyword arguments for tqdm . \"\"\" def __init__ ( self , frame_reader : FrameReader , root_path : str = \"\" , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs , ) : super (). __init__ ( self . _save_frame , maxsize , tqdm , ** tqdm_kwargs ) self . _frame_reader = frame_reader self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img_index : int , crop_dims : tuple [ float , float , float , float ], img_name : str ) : \"\"\" Adds an image to the queue for saving . Args : img_index ( int ) : The index of the image in the frame reader . crop_dims ( tuple [ float , float , float , float ]) : The crop dimensions ( x , y , w , h ) for the image . img_name ( str ) : The name ( path ) of the image file relative to the root path . \"\"\" super (). schedule_save ( img_index , crop_dims , img_name ) def _save_frame ( self , params : tuple [ int , tuple [ float , float , float , float ], str ]) : img_index , crop_dims , img_name = params save_path = join_paths ( self . _root_path , img_name ) img = self . _frame_reader [ img_index ] x , y , w , h = crop_dims img = img [ y : y + h , x : x + w ] success = cv . imwrite ( save_path , img ) if not success : create_parent_directory ( save_path ) if not cv . 
imwrite ( save_path , img ) : raise ValueError ( f \"Failed to save image {save_path}\" )","title":"Attributes"},{"location":"reference/wtracker/utils/io_utils/#ancestors-in-mro","text":"wtracker.utils.threading_utils.TaskScheduler","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/utils/io_utils/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/utils/io_utils/#close","text":"def close ( self ) Waits for the queue to empty and then closes the worker thread. View Source def close(self): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self._queue.join() self._queue.put(None) self._worker_thread.join()","title":"close"},{"location":"reference/wtracker/utils/io_utils/#schedule_save","text":"def schedule_save ( self , img_index : int , crop_dims : tuple [ float , float , float , float ], img_name : str ) Adds an image to the queue for saving. Parameters: Name Type Description Default img_index int The index of the image in the frame reader. None crop_dims tuple[float, float, float, float] The crop dimensions (x, y, w, h) for the image. None img_name str The name (path) of the image file relative to the root path. None View Source def schedule_save(self, img_index: int, crop_dims: tuple[float, float, float, float], img_name: str): \"\"\" Adds an image to the queue for saving. Args: img_index (int): The index of the image in the frame reader. crop_dims (tuple[float, float, float, float]): The crop dimensions (x, y, w, h) for the image. img_name (str): The name (path) of the image file relative to the root path. \"\"\" super().schedule_save(img_index, crop_dims, img_name)","title":"schedule_save"},{"location":"reference/wtracker/utils/io_utils/#start","text":"def start ( self ) Starts the worker thread. View Source def start(self): \"\"\" Starts the worker thread. \"\"\" self._worker_thread.start()","title":"start"},{"location":"reference/wtracker/utils/io_utils/#imagesaver","text":"class ImageSaver ( root_path : str = '' , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs ) A class for saving images asynchronously using a task scheduler.","title":"ImageSaver"},{"location":"reference/wtracker/utils/io_utils/#attributes_1","text":"Name Type Description Default root_path str The root folder path, relative to which all other paths are. None maxsize int The maximum size of the queue. None tqdm bool Whether to use tqdm for progress tracking. None **tqdm_kwargs None Additional keyword arguments for tqdm. None View Source class ImageSaver ( TaskScheduler ): \"\"\" A class for saving images asynchronously using a task scheduler. Args: root_path (str): The root folder path, relative to which all other paths are. maxsize (int, optional): The maximum size of the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments for tqdm. \"\"\" def __init__ ( self , root_path: str = \"\" , maxsize: int = 100 , tqdm: bool = True , ** tqdm_kwargs , ): super (). __init__ ( self . _save_image , maxsize , tqdm , ** tqdm_kwargs ) self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img: np . ndarray , img_path: str ): \"\"\" Adds an image to the queue for saving. Args: img (np.ndarray): The image to save. img_name (str): The name (path) of the image file relative to the root path. \"\"\" super (). schedule_save ( img , img_path ) def _save_image ( self , params: tuple [ np . ndarray , str ]): img , img_name = params save_path = join_paths ( self . 
_root_path , img_name ) success = cv . imwrite ( save_path , img ) if not success: create_parent_directory ( save_path ) if not cv . imwrite ( save_path , img ): raise ValueError ( f \"Failed to save image {save_path}\" )","title":"Attributes"},{"location":"reference/wtracker/utils/io_utils/#ancestors-in-mro_1","text":"wtracker.utils.threading_utils.TaskScheduler","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/utils/io_utils/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/utils/io_utils/#close_1","text":"def close ( self ) Waits for the queue to empty and then closes the worker thread. View Source def close(self): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self._queue.join() self._queue.put(None) self._worker_thread.join()","title":"close"},{"location":"reference/wtracker/utils/io_utils/#schedule_save_1","text":"def schedule_save ( self , img : numpy . ndarray , img_path : str ) Adds an image to the queue for saving. Parameters: Name Type Description Default img np.ndarray The image to save. None img_name str The name (path) of the image file relative to the root path. None View Source def schedule_save(self, img: np.ndarray, img_path: str): \"\"\" Adds an image to the queue for saving. Args: img (np.ndarray): The image to save. img_name (str): The name (path) of the image file relative to the root path. \"\"\" super().schedule_save(img, img_path)","title":"schedule_save"},{"location":"reference/wtracker/utils/io_utils/#start_1","text":"def start ( self ) Starts the worker thread. View Source def start(self): \"\"\" Starts the worker thread. \"\"\" self._worker_thread.start()","title":"start"},{"location":"reference/wtracker/utils/log_utils/","text":"Module wtracker.utils.log_utils View Source import csv from typing import Iterable class CSVLogger : \"\"\" A class for logging data to a CSV file. Args: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. mode (str, optional): The file mode to open the CSV file in. Attributes: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. \"\"\" def __init__ ( self , path : str , col_names : list [ str ], mode : str = \"w+\" ): self . path = path self . col_names = col_names self . _file = open ( self . path , mode , newline = \"\" ) self . _writer = csv . DictWriter ( self . _file , self . col_names , escapechar = \",\" ) self . _writer . writeheader () self . flush () def __enter__ ( self ): return self def __exit__ ( self , exc_type , exc_value , traceback ): self . close () def close ( self ): \"\"\" Closes the CSV file. \"\"\" if not self . _file . closed : self . _file . flush () self . _file . close () def _to_dict ( self , items : Iterable ) -> dict : \"\"\" Converts an iterable of items to a dictionary using the column names as keys. Args: items (Iterable): The items to convert to a dictionary. Returns: dict: The dictionary with column names as keys and items as values. \"\"\" return { k : v for k , v in zip ( self . col_names , items )} def write ( self , row : dict | Iterable ): \"\"\" Writes a single row of data to the CSV file. Args: row (dict | Iterable): The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () if not isinstance ( row , dict ): row = self . _to_dict ( row ) self . _writer . 
writerow ( row ) def writerows ( self , rows : list [ dict ] | list [ Iterable ]): \"\"\" Writes multiple rows of data to the CSV file. Args: rows (list[dict] | list[Iterable]): The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () assert len ( rows ) > 0 if not isinstance ( rows [ 0 ], dict ): rows = [ self . _to_dict ( row ) for row in rows ] self . _writer . writerows ( rows ) def flush ( self ): \"\"\" Flushes any buffered data to the CSV file. \"\"\" self . _file . flush () Classes CSVLogger class CSVLogger ( path : str , col_names : list [ str ], mode : str = 'w+' ) A class for logging data to a CSV file. Attributes Name Type Description Default path str The path to the CSV file. None col_names list[str] The column names for the CSV file. None mode str The file mode to open the CSV file in. None path str The path to the CSV file. None col_names list[str] The column names for the CSV file. None View Source class CSVLogger : \"\"\" A class for logging data to a CSV file. Args: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. mode (str, optional): The file mode to open the CSV file in. Attributes: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. \"\"\" def __init__ ( self , path : str , col_names : list [ str ] , mode : str = \"w+\" ) : self . path = path self . col_names = col_names self . _file = open ( self . path , mode , newline = \"\" ) self . _writer = csv . DictWriter ( self . _file , self . col_names , escapechar = \",\" ) self . _writer . writeheader () self . flush () def __enter__ ( self ) : return self def __exit__ ( self , exc_type , exc_value , traceback ) : self . close () def close ( self ) : \"\"\" Closes the CSV file. \"\"\" if not self . _file . closed : self . _file . flush () self . _file . close () def _to_dict ( self , items : Iterable ) -> dict : \"\"\" Converts an iterable of items to a dictionary using the column names as keys. Args: items (Iterable): The items to convert to a dictionary. Returns: dict: The dictionary with column names as keys and items as values. \"\"\" return { k : v for k , v in zip ( self . col_names , items ) } def write ( self , row : dict | Iterable ) : \"\"\" Writes a single row of data to the CSV file. Args: row (dict | Iterable): The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () if not isinstance ( row , dict ) : row = self . _to_dict ( row ) self . _writer . writerow ( row ) def writerows ( self , rows : list [ dict ] | list [ Iterable ] ) : \"\"\" Writes multiple rows of data to the CSV file. Args: rows (list[dict] | list[Iterable]): The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () assert len ( rows ) > 0 if not isinstance ( rows [ 0 ] , dict ) : rows = [ self._to_dict(row) for row in rows ] self . _writer . writerows ( rows ) def flush ( self ) : \"\"\" Flushes any buffered data to the CSV file. \"\"\" self . _file . 
flush () Methods close def close ( self ) Closes the CSV file. View Source def close(self): \"\"\" Closes the CSV file. \"\"\" if not self._file.closed: self._file.flush() self._file.close() flush def flush ( self ) Flushes any buffered data to the CSV file. View Source def flush(self): \"\"\" Flushes any buffered data to the CSV file. \"\"\" self._file.flush() write def write ( self , row : Union [ dict , Iterable ] ) Writes a single row of data to the CSV file. Parameters: Name Type Description Default row dict Iterable The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. View Source def write(self, row: dict | Iterable): \"\"\" Writes a single row of data to the CSV file. Args: row (dict | Iterable): The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. \"\"\" assert self._file.writable() if not isinstance(row, dict): row = self._to_dict(row) self._writer.writerow(row) writerows def writerows ( self , rows : list [ dict ] | list [ typing . Iterable ] ) Writes multiple rows of data to the CSV file. Parameters: Name Type Description Default rows list[dict] list[Iterable] The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. View Source def writerows ( self , rows : list [ dict ] | list [ Iterable ] ) : \"\"\" Writes multiple rows of data to the CSV file. Args: rows (list[dict] | list[Iterable]): The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () assert len ( rows ) > 0 if not isinstance ( rows [ 0 ] , dict ) : rows = [ self._to_dict(row) for row in rows ] self . _writer . writerows ( rows )","title":"Log Utils"},{"location":"reference/wtracker/utils/log_utils/#module-wtrackerutilslog_utils","text":"View Source import csv from typing import Iterable class CSVLogger : \"\"\" A class for logging data to a CSV file. Args: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. mode (str, optional): The file mode to open the CSV file in. Attributes: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. \"\"\" def __init__ ( self , path : str , col_names : list [ str ], mode : str = \"w+\" ): self . path = path self . col_names = col_names self . _file = open ( self . path , mode , newline = \"\" ) self . _writer = csv . DictWriter ( self . _file , self . col_names , escapechar = \",\" ) self . _writer . writeheader () self . flush () def __enter__ ( self ): return self def __exit__ ( self , exc_type , exc_value , traceback ): self . close () def close ( self ): \"\"\" Closes the CSV file. \"\"\" if not self . _file . closed : self . _file . flush () self . _file . close () def _to_dict ( self , items : Iterable ) -> dict : \"\"\" Converts an iterable of items to a dictionary using the column names as keys. Args: items (Iterable): The items to convert to a dictionary. Returns: dict: The dictionary with column names as keys and items as values. 
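A short sketch of `CSVLogger` as documented above, used as a context manager so `close()` is called automatically; the `log.csv` path and column names are illustrative:

```python
from wtracker.utils.log_utils import CSVLogger

# The header row is written and flushed to disk as soon as the logger is constructed.
with CSVLogger("log.csv", col_names=["frame", "x", "y"]) as logger:
    logger.write({"frame": 0, "x": 10.5, "y": 20.1})     # dict keyed by the column names
    logger.write([1, 11.0, 19.8])                        # iterable matched to columns in order
    logger.writerows([[2, 11.4, 19.5], [3, 11.9, 19.2]])  # several rows at once
    logger.flush()  # optional; the context manager's close() flushes as well
```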
\"\"\" return { k : v for k , v in zip ( self . col_names , items )} def write ( self , row : dict | Iterable ): \"\"\" Writes a single row of data to the CSV file. Args: row (dict | Iterable): The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () if not isinstance ( row , dict ): row = self . _to_dict ( row ) self . _writer . writerow ( row ) def writerows ( self , rows : list [ dict ] | list [ Iterable ]): \"\"\" Writes multiple rows of data to the CSV file. Args: rows (list[dict] | list[Iterable]): The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () assert len ( rows ) > 0 if not isinstance ( rows [ 0 ], dict ): rows = [ self . _to_dict ( row ) for row in rows ] self . _writer . writerows ( rows ) def flush ( self ): \"\"\" Flushes any buffered data to the CSV file. \"\"\" self . _file . flush ()","title":"Module wtracker.utils.log_utils"},{"location":"reference/wtracker/utils/log_utils/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/log_utils/#csvlogger","text":"class CSVLogger ( path : str , col_names : list [ str ], mode : str = 'w+' ) A class for logging data to a CSV file.","title":"CSVLogger"},{"location":"reference/wtracker/utils/log_utils/#attributes","text":"Name Type Description Default path str The path to the CSV file. None col_names list[str] The column names for the CSV file. None mode str The file mode to open the CSV file in. None path str The path to the CSV file. None col_names list[str] The column names for the CSV file. None View Source class CSVLogger : \"\"\" A class for logging data to a CSV file. Args: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. mode (str, optional): The file mode to open the CSV file in. Attributes: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. \"\"\" def __init__ ( self , path : str , col_names : list [ str ] , mode : str = \"w+\" ) : self . path = path self . col_names = col_names self . _file = open ( self . path , mode , newline = \"\" ) self . _writer = csv . DictWriter ( self . _file , self . col_names , escapechar = \",\" ) self . _writer . writeheader () self . flush () def __enter__ ( self ) : return self def __exit__ ( self , exc_type , exc_value , traceback ) : self . close () def close ( self ) : \"\"\" Closes the CSV file. \"\"\" if not self . _file . closed : self . _file . flush () self . _file . close () def _to_dict ( self , items : Iterable ) -> dict : \"\"\" Converts an iterable of items to a dictionary using the column names as keys. Args: items (Iterable): The items to convert to a dictionary. Returns: dict: The dictionary with column names as keys and items as values. \"\"\" return { k : v for k , v in zip ( self . col_names , items ) } def write ( self , row : dict | Iterable ) : \"\"\" Writes a single row of data to the CSV file. Args: row (dict | Iterable): The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . 
writable () if not isinstance ( row , dict ) : row = self . _to_dict ( row ) self . _writer . writerow ( row ) def writerows ( self , rows : list [ dict ] | list [ Iterable ] ) : \"\"\" Writes multiple rows of data to the CSV file. Args: rows (list[dict] | list[Iterable]): The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () assert len ( rows ) > 0 if not isinstance ( rows [ 0 ] , dict ) : rows = [ self._to_dict(row) for row in rows ] self . _writer . writerows ( rows ) def flush ( self ) : \"\"\" Flushes any buffered data to the CSV file. \"\"\" self . _file . flush ()","title":"Attributes"},{"location":"reference/wtracker/utils/log_utils/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/utils/log_utils/#close","text":"def close ( self ) Closes the CSV file. View Source def close(self): \"\"\" Closes the CSV file. \"\"\" if not self._file.closed: self._file.flush() self._file.close()","title":"close"},{"location":"reference/wtracker/utils/log_utils/#flush","text":"def flush ( self ) Flushes any buffered data to the CSV file. View Source def flush(self): \"\"\" Flushes any buffered data to the CSV file. \"\"\" self._file.flush()","title":"flush"},{"location":"reference/wtracker/utils/log_utils/#write","text":"def write ( self , row : Union [ dict , Iterable ] ) Writes a single row of data to the CSV file. Parameters: Name Type Description Default row dict Iterable The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. View Source def write(self, row: dict | Iterable): \"\"\" Writes a single row of data to the CSV file. Args: row (dict | Iterable): The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. \"\"\" assert self._file.writable() if not isinstance(row, dict): row = self._to_dict(row) self._writer.writerow(row)","title":"write"},{"location":"reference/wtracker/utils/log_utils/#writerows","text":"def writerows ( self , rows : list [ dict ] | list [ typing . Iterable ] ) Writes multiple rows of data to the CSV file. Parameters: Name Type Description Default rows list[dict] list[Iterable] The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. View Source def writerows ( self , rows : list [ dict ] | list [ Iterable ] ) : \"\"\" Writes multiple rows of data to the CSV file. Args: rows (list[dict] | list[Iterable]): The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () assert len ( rows ) > 0 if not isinstance ( rows [ 0 ] , dict ) : rows = [ self._to_dict(row) for row in rows ] self . _writer . 
writerows ( rows )","title":"writerows"},{"location":"reference/wtracker/utils/path_utils/","text":"Module wtracker.utils.path_utils View Source from __future__ import annotations import os from pathlib import Path , PurePath from typing import Callable , Union import shutil def absolute_path ( file_path : str ) -> str : \"\"\" Get the absolute path of a file. Args: file_path (str): The path of the file. Returns: str: The absolute path of the file. \"\"\" return Path ( file_path ) . resolve () . as_posix () def join_paths ( * path_segments : str ): \"\"\" Join multiple path segments into a single path. Args: *path_segments: Variable number of path segments to be joined. Returns: str: The joined path as a string. Example: >>> join_paths('home', 'yashlat', 'source', 'Bio-Proj', 'data') 'home/yashlat/source/Bio-Proj/data' \"\"\" return PurePath ( * path_segments ) . as_posix () def create_parent_directory ( file_path : str ): \"\"\" Create the parent directory for the given file path if it doesn't exist. Args: file_path (str): The path of the file. Returns: None \"\"\" save_folder = Path ( file_path ) . parent save_folder . mkdir ( parents = True , exist_ok = True ) def create_directory ( dir_path : str ): \"\"\" Create a directory at the specified path if it doesn't already exist. Args: dir_path (str): The path of the directory to be created. Returns: None \"\"\" Path ( dir_path ) . mkdir ( parents = True , exist_ok = True ) def bulk_rename ( dir_path : str , rename_fn : Callable [[ str ], str ]): \"\"\" Rename all files in a directory using the provided renaming function. Args: dir_path (str): The path of the directory containing the files to be renamed. rename_fn (Callable[[str], str]): The function to be used for renaming the files. Returns: None \"\"\" path : Path = Path ( dir_path ) for file_name in path . iterdir (): if file_name . is_dir (): continue new_name = path / rename_fn ( file_name . name ) file_name . rename ( new_name ) class Files : \"\"\" A utility class for working with files in a directory. Args: directory (str): The directory path to scan for files. extension (str, optional): The file extension to filter the files. scan_dirs (bool, optional): Whether to include directories in the results. return_full_path (bool, optional): Whether to return the full path of the files. sorting_key (Callable[[str], Union[int, str]], optional): A function to determine the sorting order of the files. \"\"\" def __init__ ( self , directory : str , extension : str = \"\" , scan_dirs : bool = False , return_full_path : bool = True , sorting_key : Callable [[ str ], Union [ int , str ]] = lambda name : name , ) -> None : self . root = directory self . extension = extension . lower () self . scan_dirs : bool = scan_dirs self . return_full_path = return_full_path self . results : list [ os . DirEntry ] = [] self . sorting_func = sorting_key self . _pos = - 1 self . _scan () def _scan ( self ): self . results = [] self . _pos = - 1 for result in os . scandir ( self . root ): if self . scan_dirs and result . is_dir (): self . results . append ( result ) else : if result . name . lower () . endswith ( self . extension ): self . results . append ( result ) self . results = sorted ( self . results , key = lambda f : self . sorting_func ( f . name )) def __getitem__ ( self , index : int ) -> os . DirEntry : \"\"\" Returns the file at the specified index. Args: index (int): The index of the file. Returns: os.DirEntry: The file at the specified index. \"\"\" return self . 
results [ index ] def __iter__ ( self ) -> Files : \"\"\" Returns an iterator object. Returns: Files: The iterator object. \"\"\" self . _pos = - 1 return self def __next__ ( self ) -> str : \"\"\" Returns the next file name or path in the iteration. Returns: str: The next file name or path. Raises: StopIteration: If there are no more files in the iteration. \"\"\" self . _pos += 1 if self . _pos >= self . __len__ (): raise StopIteration result = self . results [ self . _pos ] if self . return_full_path : return result . path return result . name def __len__ ( self ) -> int : \"\"\" Returns the number of files in the results list. Returns: int: The number of files. \"\"\" return len ( self . results ) def __contains__ ( self , key : str ) -> bool : \"\"\" Checks if a file with the specified name exists in the results list. Args: key (str): The file name to check. Returns: bool: True if the file exists, False otherwise. \"\"\" for res in self . results : if key == res . name : return True return False def get_filename ( self ) -> str : \"\"\" Returns the name of the current file. Returns: str: The name of the current file. \"\"\" return self . results [ self . _pos ] . name def get_path ( self ) -> str : \"\"\" Returns the path of the current file. Returns: str: The path of the current file. \"\"\" return self . results [ self . _pos ] . path def seek ( self , pos : int ) -> str : \"\"\" Moves the iterator to the specified position and returns the file name or path. Args: pos (int): The position to seek to. Returns: str: The file name or path at the specified position. Raises: AssertionError: If the specified position is invalid. \"\"\" assert 0 <= pos < self . __len__ (), \"Invalid position\" self . _pos = pos - 1 return self . __next__ () def copy ( self , dst_root : str ) -> None : \"\"\" Copies the current file to the specified destination directory. Args: dst_root (str): The destination directory path. \"\"\" shutil . copy2 ( self . get_path (), dst = dst_root ) Functions absolute_path def absolute_path ( file_path : 'str' ) -> 'str' Get the absolute path of a file. Parameters: Name Type Description Default file_path str The path of the file. None Returns: Type Description str The absolute path of the file. View Source def absolute_path ( file_path : str ) -> str : \"\"\" Get the absolute path of a file. Args: file_path (str): The path of the file. Returns: str: The absolute path of the file. \"\"\" return Path ( file_path ). resolve (). as_posix () bulk_rename def bulk_rename ( dir_path : 'str' , rename_fn : 'Callable[[str], str]' ) Rename all files in a directory using the provided renaming function. Parameters: Name Type Description Default dir_path str The path of the directory containing the files to be renamed. None rename_fn Callable[[str], str] The function to be used for renaming the files. None Returns: Type Description None None View Source def bulk_rename ( dir_path : str , rename_fn : Callable [ [str ] , str ] ) : \"\"\" Rename all files in a directory using the provided renaming function. Args: dir_path (str): The path of the directory containing the files to be renamed. rename_fn (Callable[[str], str]): The function to be used for renaming the files. Returns: None \"\"\" path : Path = Path ( dir_path ) for file_name in path . iterdir () : if file_name . is_dir () : continue new_name = path / rename_fn ( file_name . name ) file_name . 
rename ( new_name ) create_directory def create_directory ( dir_path : 'str' ) Create a directory at the specified path if it doesn't already exist. Parameters: Name Type Description Default dir_path str The path of the directory to be created. None Returns: Type Description None None View Source def create_directory(dir_path: str): \"\"\" Create a directory at the specified path if it doesn't already exist. Args: dir_path (str): The path of the directory to be created. Returns: None \"\"\" Path(dir_path).mkdir(parents=True, exist_ok=True) create_parent_directory def create_parent_directory ( file_path : 'str' ) Create the parent directory for the given file path if it doesn't exist. Parameters: Name Type Description Default file_path str The path of the file. None Returns: Type Description None None View Source def create_parent_directory(file_path: str): \"\"\" Create the parent directory for the given file path if it doesn't exist. Args: file_path (str): The path of the file. Returns: None \"\"\" save_folder = Path(file_path).parent save_folder.mkdir(parents=True, exist_ok=True) join_paths def join_paths ( * path_segments : 'str' ) Join multiple path segments into a single path. Parameters: Name Type Description Default *path_segments None Variable number of path segments to be joined. None Returns: Type Description str The joined path as a string. View Source def join_paths(*path_segments: str): \"\"\" Join multiple path segments into a single path. Args: *path_segments: Variable number of path segments to be joined. Returns: str: The joined path as a string. Example: >>> join_paths('home', 'yashlat', 'source', 'Bio-Proj', 'data') 'home/yashlat/source/Bio-Proj/data' \"\"\" return PurePath(*path_segments).as_posix() Classes Files class Files ( directory : 'str' , extension : 'str' = '' , scan_dirs : 'bool' = False , return_full_path : 'bool' = True , sorting_key : 'Callable[[str], Union[int, str]]' = < function Files .< lambda > at 0x7f894dc4e290 > ) A utility class for working with files in a directory. Attributes Name Type Description Default directory str The directory path to scan for files. None extension str The file extension to filter the files. None scan_dirs bool Whether to include directories in the results. None return_full_path bool Whether to return the full path of the files. None sorting_key Callable[[str], Union[int, str]] A function to determine the sorting order of the files. None View Source class Files : \"\"\" A utility class for working with files in a directory. Args: directory (str): The directory path to scan for files. extension (str, optional): The file extension to filter the files. scan_dirs (bool, optional): Whether to include directories in the results. return_full_path (bool, optional): Whether to return the full path of the files. sorting_key (Callable[[str], Union[int, str]], optional): A function to determine the sorting order of the files. \"\"\" def __init__ ( self , directory : str , extension : str = \"\" , scan_dirs : bool = False , return_full_path : bool = True , sorting_key : Callable [ [str ] , Union [ int, str ] ] = lambda name : name , ) -> None : self . root = directory self . extension = extension . lower () self . scan_dirs : bool = scan_dirs self . return_full_path = return_full_path self . results : list [ os.DirEntry ] = [] self . sorting_func = sorting_key self . _pos = - 1 self . _scan () def _scan ( self ) : self . results = [] self . _pos = - 1 for result in os . scandir ( self . root ) : if self . scan_dirs and result . 
is_dir () : self . results . append ( result ) else : if result . name . lower (). endswith ( self . extension ) : self . results . append ( result ) self . results = sorted ( self . results , key = lambda f : self . sorting_func ( f . name )) def __getitem__ ( self , index : int ) -> os . DirEntry : \"\"\" Returns the file at the specified index. Args: index (int): The index of the file. Returns: os.DirEntry: The file at the specified index. \"\"\" return self . results [ index ] def __iter__ ( self ) -> Files : \"\"\" Returns an iterator object. Returns: Files: The iterator object. \"\"\" self . _pos = - 1 return self def __next__ ( self ) -> str : \"\"\" Returns the next file name or path in the iteration. Returns: str: The next file name or path. Raises: StopIteration: If there are no more files in the iteration. \"\"\" self . _pos += 1 if self . _pos >= self . __len__ () : raise StopIteration result = self . results [ self._pos ] if self . return_full_path : return result . path return result . name def __len__ ( self ) -> int : \"\"\" Returns the number of files in the results list. Returns: int: The number of files. \"\"\" return len ( self . results ) def __contains__ ( self , key : str ) -> bool : \"\"\" Checks if a file with the specified name exists in the results list. Args: key (str): The file name to check. Returns: bool: True if the file exists, False otherwise. \"\"\" for res in self . results : if key == res . name : return True return False def get_filename ( self ) -> str : \"\"\" Returns the name of the current file. Returns: str: The name of the current file. \"\"\" return self . results [ self._pos ] . name def get_path ( self ) -> str : \"\"\" Returns the path of the current file. Returns: str: The path of the current file. \"\"\" return self . results [ self._pos ] . path def seek ( self , pos : int ) -> str : \"\"\" Moves the iterator to the specified position and returns the file name or path. Args: pos (int): The position to seek to. Returns: str: The file name or path at the specified position. Raises: AssertionError: If the specified position is invalid. \"\"\" assert 0 <= pos < self . __len__ (), \"Invalid position\" self . _pos = pos - 1 return self . __next__ () def copy ( self , dst_root : str ) -> None : \"\"\" Copies the current file to the specified destination directory. Args: dst_root (str): The destination directory path. \"\"\" shutil . copy2 ( self . get_path (), dst = dst_root ) Methods copy def copy ( self , dst_root : 'str' ) -> 'None' Copies the current file to the specified destination directory. Parameters: Name Type Description Default dst_root str The destination directory path. None View Source def copy ( self , dst_root : str ) -> None : \"\"\" Copies the current file to the specified destination directory. Args: dst_root (str): The destination directory path. \"\"\" shutil . copy2 ( self . get_path (), dst = dst_root ) get_filename def get_filename ( self ) -> 'str' Returns the name of the current file. Returns: Type Description str The name of the current file. View Source def get_filename ( self ) -> str : \"\"\" Returns the name of the current file. Returns: str: The name of the current file. \"\"\" return self . results [ self . _pos ]. name get_path def get_path ( self ) -> 'str' Returns the path of the current file. Returns: Type Description str The path of the current file. View Source def get_path ( self ) -> str : \"\"\" Returns the path of the current file. Returns: str: The path of the current file. \"\"\" return self . 
results [ self . _pos ]. path seek def seek ( self , pos : 'int' ) -> 'str' Moves the iterator to the specified position and returns the file name or path. Parameters: Name Type Description Default pos int The position to seek to. None Returns: Type Description str The file name or path at the specified position. Raises: Type Description AssertionError If the specified position is invalid. View Source def seek ( self , pos : int ) -> str : \"\"\" Moves the iterator to the specified position and returns the file name or path. Args: pos (int): The position to seek to. Returns: str: The file name or path at the specified position. Raises: AssertionError: If the specified position is invalid. \"\"\" assert 0 <= pos < self . __len__ (), \"Invalid position\" self . _pos = pos - 1 return self . __next__ ()","title":"Path Utils"},{"location":"reference/wtracker/utils/path_utils/#module-wtrackerutilspath_utils","text":"View Source from __future__ import annotations import os from pathlib import Path , PurePath from typing import Callable , Union import shutil def absolute_path ( file_path : str ) -> str : \"\"\" Get the absolute path of a file. Args: file_path (str): The path of the file. Returns: str: The absolute path of the file. \"\"\" return Path ( file_path ) . resolve () . as_posix () def join_paths ( * path_segments : str ): \"\"\" Join multiple path segments into a single path. Args: *path_segments: Variable number of path segments to be joined. Returns: str: The joined path as a string. Example: >>> join_paths('home', 'yashlat', 'source', 'Bio-Proj', 'data') 'home/yashlat/source/Bio-Proj/data' \"\"\" return PurePath ( * path_segments ) . as_posix () def create_parent_directory ( file_path : str ): \"\"\" Create the parent directory for the given file path if it doesn't exist. Args: file_path (str): The path of the file. Returns: None \"\"\" save_folder = Path ( file_path ) . parent save_folder . mkdir ( parents = True , exist_ok = True ) def create_directory ( dir_path : str ): \"\"\" Create a directory at the specified path if it doesn't already exist. Args: dir_path (str): The path of the directory to be created. Returns: None \"\"\" Path ( dir_path ) . mkdir ( parents = True , exist_ok = True ) def bulk_rename ( dir_path : str , rename_fn : Callable [[ str ], str ]): \"\"\" Rename all files in a directory using the provided renaming function. Args: dir_path (str): The path of the directory containing the files to be renamed. rename_fn (Callable[[str], str]): The function to be used for renaming the files. Returns: None \"\"\" path : Path = Path ( dir_path ) for file_name in path . iterdir (): if file_name . is_dir (): continue new_name = path / rename_fn ( file_name . name ) file_name . rename ( new_name ) class Files : \"\"\" A utility class for working with files in a directory. Args: directory (str): The directory path to scan for files. extension (str, optional): The file extension to filter the files. scan_dirs (bool, optional): Whether to include directories in the results. return_full_path (bool, optional): Whether to return the full path of the files. sorting_key (Callable[[str], Union[int, str]], optional): A function to determine the sorting order of the files. \"\"\" def __init__ ( self , directory : str , extension : str = \"\" , scan_dirs : bool = False , return_full_path : bool = True , sorting_key : Callable [[ str ], Union [ int , str ]] = lambda name : name , ) -> None : self . root = directory self . extension = extension . lower () self . 
scan_dirs : bool = scan_dirs self . return_full_path = return_full_path self . results : list [ os . DirEntry ] = [] self . sorting_func = sorting_key self . _pos = - 1 self . _scan () def _scan ( self ): self . results = [] self . _pos = - 1 for result in os . scandir ( self . root ): if self . scan_dirs and result . is_dir (): self . results . append ( result ) else : if result . name . lower () . endswith ( self . extension ): self . results . append ( result ) self . results = sorted ( self . results , key = lambda f : self . sorting_func ( f . name )) def __getitem__ ( self , index : int ) -> os . DirEntry : \"\"\" Returns the file at the specified index. Args: index (int): The index of the file. Returns: os.DirEntry: The file at the specified index. \"\"\" return self . results [ index ] def __iter__ ( self ) -> Files : \"\"\" Returns an iterator object. Returns: Files: The iterator object. \"\"\" self . _pos = - 1 return self def __next__ ( self ) -> str : \"\"\" Returns the next file name or path in the iteration. Returns: str: The next file name or path. Raises: StopIteration: If there are no more files in the iteration. \"\"\" self . _pos += 1 if self . _pos >= self . __len__ (): raise StopIteration result = self . results [ self . _pos ] if self . return_full_path : return result . path return result . name def __len__ ( self ) -> int : \"\"\" Returns the number of files in the results list. Returns: int: The number of files. \"\"\" return len ( self . results ) def __contains__ ( self , key : str ) -> bool : \"\"\" Checks if a file with the specified name exists in the results list. Args: key (str): The file name to check. Returns: bool: True if the file exists, False otherwise. \"\"\" for res in self . results : if key == res . name : return True return False def get_filename ( self ) -> str : \"\"\" Returns the name of the current file. Returns: str: The name of the current file. \"\"\" return self . results [ self . _pos ] . name def get_path ( self ) -> str : \"\"\" Returns the path of the current file. Returns: str: The path of the current file. \"\"\" return self . results [ self . _pos ] . path def seek ( self , pos : int ) -> str : \"\"\" Moves the iterator to the specified position and returns the file name or path. Args: pos (int): The position to seek to. Returns: str: The file name or path at the specified position. Raises: AssertionError: If the specified position is invalid. \"\"\" assert 0 <= pos < self . __len__ (), \"Invalid position\" self . _pos = pos - 1 return self . __next__ () def copy ( self , dst_root : str ) -> None : \"\"\" Copies the current file to the specified destination directory. Args: dst_root (str): The destination directory path. \"\"\" shutil . copy2 ( self . get_path (), dst = dst_root )","title":"Module wtracker.utils.path_utils"},{"location":"reference/wtracker/utils/path_utils/#functions","text":"","title":"Functions"},{"location":"reference/wtracker/utils/path_utils/#absolute_path","text":"def absolute_path ( file_path : 'str' ) -> 'str' Get the absolute path of a file. Parameters: Name Type Description Default file_path str The path of the file. None Returns: Type Description str The absolute path of the file. View Source def absolute_path ( file_path : str ) -> str : \"\"\" Get the absolute path of a file. Args: file_path (str): The path of the file. Returns: str: The absolute path of the file. \"\"\" return Path ( file_path ). resolve (). 
as_posix ()","title":"absolute_path"},{"location":"reference/wtracker/utils/path_utils/#bulk_rename","text":"def bulk_rename ( dir_path : 'str' , rename_fn : 'Callable[[str], str]' ) Rename all files in a directory using the provided renaming function. Parameters: Name Type Description Default dir_path str The path of the directory containing the files to be renamed. None rename_fn Callable[[str], str] The function to be used for renaming the files. None Returns: Type Description None None View Source def bulk_rename ( dir_path : str , rename_fn : Callable [ [str ] , str ] ) : \"\"\" Rename all files in a directory using the provided renaming function. Args: dir_path (str): The path of the directory containing the files to be renamed. rename_fn (Callable[[str], str]): The function to be used for renaming the files. Returns: None \"\"\" path : Path = Path ( dir_path ) for file_name in path . iterdir () : if file_name . is_dir () : continue new_name = path / rename_fn ( file_name . name ) file_name . rename ( new_name )","title":"bulk_rename"},{"location":"reference/wtracker/utils/path_utils/#create_directory","text":"def create_directory ( dir_path : 'str' ) Create a directory at the specified path if it doesn't already exist. Parameters: Name Type Description Default dir_path str The path of the directory to be created. None Returns: Type Description None None View Source def create_directory(dir_path: str): \"\"\" Create a directory at the specified path if it doesn't already exist. Args: dir_path (str): The path of the directory to be created. Returns: None \"\"\" Path(dir_path).mkdir(parents=True, exist_ok=True)","title":"create_directory"},{"location":"reference/wtracker/utils/path_utils/#create_parent_directory","text":"def create_parent_directory ( file_path : 'str' ) Create the parent directory for the given file path if it doesn't exist. Parameters: Name Type Description Default file_path str The path of the file. None Returns: Type Description None None View Source def create_parent_directory(file_path: str): \"\"\" Create the parent directory for the given file path if it doesn't exist. Args: file_path (str): The path of the file. Returns: None \"\"\" save_folder = Path(file_path).parent save_folder.mkdir(parents=True, exist_ok=True)","title":"create_parent_directory"},{"location":"reference/wtracker/utils/path_utils/#join_paths","text":"def join_paths ( * path_segments : 'str' ) Join multiple path segments into a single path. Parameters: Name Type Description Default *path_segments None Variable number of path segments to be joined. None Returns: Type Description str The joined path as a string. View Source def join_paths(*path_segments: str): \"\"\" Join multiple path segments into a single path. Args: *path_segments: Variable number of path segments to be joined. Returns: str: The joined path as a string. 
Example: >>> join_paths('home', 'yashlat', 'source', 'Bio-Proj', 'data') 'home/yashlat/source/Bio-Proj/data' \"\"\" return PurePath(*path_segments).as_posix()","title":"join_paths"},{"location":"reference/wtracker/utils/path_utils/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/path_utils/#files","text":"class Files ( directory : 'str' , extension : 'str' = '' , scan_dirs : 'bool' = False , return_full_path : 'bool' = True , sorting_key : 'Callable[[str], Union[int, str]]' = < function Files .< lambda > at 0x7f894dc4e290 > ) A utility class for working with files in a directory.","title":"Files"},{"location":"reference/wtracker/utils/path_utils/#attributes","text":"Name Type Description Default directory str The directory path to scan for files. None extension str The file extension to filter the files. None scan_dirs bool Whether to include directories in the results. None return_full_path bool Whether to return the full path of the files. None sorting_key Callable[[str], Union[int, str]] A function to determine the sorting order of the files. None View Source class Files : \"\"\" A utility class for working with files in a directory. Args: directory (str): The directory path to scan for files. extension (str, optional): The file extension to filter the files. scan_dirs (bool, optional): Whether to include directories in the results. return_full_path (bool, optional): Whether to return the full path of the files. sorting_key (Callable[[str], Union[int, str]], optional): A function to determine the sorting order of the files. \"\"\" def __init__ ( self , directory : str , extension : str = \"\" , scan_dirs : bool = False , return_full_path : bool = True , sorting_key : Callable [ [str ] , Union [ int, str ] ] = lambda name : name , ) -> None : self . root = directory self . extension = extension . lower () self . scan_dirs : bool = scan_dirs self . return_full_path = return_full_path self . results : list [ os.DirEntry ] = [] self . sorting_func = sorting_key self . _pos = - 1 self . _scan () def _scan ( self ) : self . results = [] self . _pos = - 1 for result in os . scandir ( self . root ) : if self . scan_dirs and result . is_dir () : self . results . append ( result ) else : if result . name . lower (). endswith ( self . extension ) : self . results . append ( result ) self . results = sorted ( self . results , key = lambda f : self . sorting_func ( f . name )) def __getitem__ ( self , index : int ) -> os . DirEntry : \"\"\" Returns the file at the specified index. Args: index (int): The index of the file. Returns: os.DirEntry: The file at the specified index. \"\"\" return self . results [ index ] def __iter__ ( self ) -> Files : \"\"\" Returns an iterator object. Returns: Files: The iterator object. \"\"\" self . _pos = - 1 return self def __next__ ( self ) -> str : \"\"\" Returns the next file name or path in the iteration. Returns: str: The next file name or path. Raises: StopIteration: If there are no more files in the iteration. \"\"\" self . _pos += 1 if self . _pos >= self . __len__ () : raise StopIteration result = self . results [ self._pos ] if self . return_full_path : return result . path return result . name def __len__ ( self ) -> int : \"\"\" Returns the number of files in the results list. Returns: int: The number of files. \"\"\" return len ( self . results ) def __contains__ ( self , key : str ) -> bool : \"\"\" Checks if a file with the specified name exists in the results list. Args: key (str): The file name to check. 
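A brief sketch tying together the path helpers and the `Files` iterator documented above; the `results/run_01` directory and the renaming rule are hypothetical:

```python
from wtracker.utils.path_utils import Files, join_paths, create_parent_directory, bulk_rename

# Build a POSIX-style path and make sure its parent directory exists.
report_path = join_paths("results", "run_01", "report.csv")  # -> 'results/run_01/report.csv'
create_parent_directory(report_path)                          # creates 'results/run_01'

# Iterate over all PNG files in that directory; names are sorted by the key function.
files = Files(
    directory="results/run_01",
    extension=".png",
    return_full_path=False,           # yield bare file names instead of full paths
    sorting_key=lambda name: name,    # the default: plain lexicographic order
)
for name in files:
    print(name)

# Rename every file in the directory with a single call (directories are skipped).
bulk_rename("results/run_01", lambda name: "old_" + name)
```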
Returns: bool: True if the file exists, False otherwise. \"\"\" for res in self . results : if key == res . name : return True return False def get_filename ( self ) -> str : \"\"\" Returns the name of the current file. Returns: str: The name of the current file. \"\"\" return self . results [ self._pos ] . name def get_path ( self ) -> str : \"\"\" Returns the path of the current file. Returns: str: The path of the current file. \"\"\" return self . results [ self._pos ] . path def seek ( self , pos : int ) -> str : \"\"\" Moves the iterator to the specified position and returns the file name or path. Args: pos (int): The position to seek to. Returns: str: The file name or path at the specified position. Raises: AssertionError: If the specified position is invalid. \"\"\" assert 0 <= pos < self . __len__ (), \"Invalid position\" self . _pos = pos - 1 return self . __next__ () def copy ( self , dst_root : str ) -> None : \"\"\" Copies the current file to the specified destination directory. Args: dst_root (str): The destination directory path. \"\"\" shutil . copy2 ( self . get_path (), dst = dst_root )","title":"Attributes"},{"location":"reference/wtracker/utils/path_utils/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/utils/path_utils/#copy","text":"def copy ( self , dst_root : 'str' ) -> 'None' Copies the current file to the specified destination directory. Parameters: Name Type Description Default dst_root str The destination directory path. None View Source def copy ( self , dst_root : str ) -> None : \"\"\" Copies the current file to the specified destination directory. Args: dst_root (str): The destination directory path. \"\"\" shutil . copy2 ( self . get_path (), dst = dst_root )","title":"copy"},{"location":"reference/wtracker/utils/path_utils/#get_filename","text":"def get_filename ( self ) -> 'str' Returns the name of the current file. Returns: Type Description str The name of the current file. View Source def get_filename ( self ) -> str : \"\"\" Returns the name of the current file. Returns: str: The name of the current file. \"\"\" return self . results [ self . _pos ]. name","title":"get_filename"},{"location":"reference/wtracker/utils/path_utils/#get_path","text":"def get_path ( self ) -> 'str' Returns the path of the current file. Returns: Type Description str The path of the current file. View Source def get_path ( self ) -> str : \"\"\" Returns the path of the current file. Returns: str: The path of the current file. \"\"\" return self . results [ self . _pos ]. path","title":"get_path"},{"location":"reference/wtracker/utils/path_utils/#seek","text":"def seek ( self , pos : 'int' ) -> 'str' Moves the iterator to the specified position and returns the file name or path. Parameters: Name Type Description Default pos int The position to seek to. None Returns: Type Description str The file name or path at the specified position. Raises: Type Description AssertionError If the specified position is invalid. View Source def seek ( self , pos : int ) -> str : \"\"\" Moves the iterator to the specified position and returns the file name or path. Args: pos (int): The position to seek to. Returns: str: The file name or path at the specified position. Raises: AssertionError: If the specified position is invalid. \"\"\" assert 0 <= pos < self . __len__ (), \"Invalid position\" self . _pos = pos - 1 return self . 
__next__ ()","title":"seek"},{"location":"reference/wtracker/utils/threading_utils/","text":"Module wtracker.utils.threading_utils View Source import queue import threading import multiprocessing from typing import Callable from tqdm.auto import tqdm def adjust_num_workers ( num_tasks : int , chunk_size : int , num_workers : int = None ) -> int : \"\"\" Adjust the number of workers based on the number of tasks and chunk size. Args: num_tasks (int): The number of tasks to be processed. chunk_size (int): The size of each processing chunk. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. \"\"\" if num_workers is None : # if None then choose automatically num_workers = min ( multiprocessing . cpu_count () / 2 , num_tasks / ( 2 * chunk_size )) num_workers = round ( num_workers ) use_multiprocessing = num_workers > 0 num_workers = min ( num_workers , num_tasks // chunk_size ) # no point having workers without tasks num_workers = min ( num_workers , multiprocessing . cpu_count ()) # no point having more workers than cpus if num_workers < 0 : # make sure value is valid num_workers = 0 if use_multiprocessing : num_workers = max ( num_workers , 1 ) elif not use_multiprocessing and num_workers == 1 : num_workers = 0 return num_workers class TqdmQueue ( queue . Queue ): \"\"\" A subclass of `queue.Queue` that provides progress tracking using `tqdm`. Args: maxsize (int): The maximum size of the queue (default: 0). **kwargs: Additional keyword arguments to be passed to the tqdm progress bar. Attributes: pbar (tqdm.tqdm): The progress bar object. total (int): The total number of items processed. Example: queue = ProgressQueue(maxsize=10) queue.put(item) queue.task_done() queue.join() \"\"\" def __init__ ( self , maxsize : int = 0 , ** kwargs ): super () . __init__ ( maxsize = maxsize ) self . pbar = tqdm ( total = 1 , ** kwargs ) self . total = 0 # Keep our own total tracker so we can update the Progressbar def task_done ( self ): \"\"\" Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. \"\"\" super () . task_done () self . pbar . update () self . pbar . refresh () # Redraw the progressbar def _put ( self , item ): super () . _put ( item ) self . total += 1 processed = self . pbar . n # Get current progress to re-apply self . pbar . reset ( self . total ) # Reset and update total self . pbar . update ( processed ) # Re-apply progress self . pbar . refresh () # Redraw the progressbar def join ( self ): \"\"\" Blocks until all items in the Queue have been gotten and processed. \"\"\" super () . join () self . pbar . close () class TaskScheduler : \"\"\" This class is used to schedule tasks to be executed by a worker thread. Args: task_func (Callable): The function to be executed by the worker thread. maxsize (int, optional): The maximum number of items that can be in the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments to be passed to the TqdmQueue constructor. \"\"\" def __init__ ( self , task_func : Callable , maxsize : int = 0 , tqdm : bool = True , ** tqdm_kwargs , ): self . _queue = TqdmQueue ( maxsize , ** tqdm_kwargs ) if tqdm else queue . Queue ( maxsize ) self . _worker_thread = threading . Thread ( target = self . _worker , args = ( self . _queue ,)) self . 
_task_func = task_func def start ( self ): \"\"\" Starts the worker thread. \"\"\" self . _worker_thread . start () def __enter__ ( self ): self . start () return self def __exit__ ( self , exc_type , exc_value , traceback ): self . close () def schedule_save ( self , * params ): \"\"\" Schedules a task by putting task parameters into the queue. Args: *params: The parameters to be passed to the task function. \"\"\" self . _queue . put ( item = params , block = True ) def _worker ( self , q : queue . Queue ): while True : params = q . get ( block = True ) # exit if signaled if params is None : break self . _task_func ( params ) q . task_done () def close ( self ): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self . _queue . join () self . _queue . put ( None ) self . _worker_thread . join () Functions adjust_num_workers def adjust_num_workers ( num_tasks : int , chunk_size : int , num_workers : int = None ) -> int Adjust the number of workers based on the number of tasks and chunk size. Parameters: Name Type Description Default num_tasks int The number of tasks to be processed. None chunk_size int The size of each processing chunk. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None View Source def adjust_num_workers(num_tasks: int, chunk_size: int, num_workers: int = None) -> int: \"\"\" Adjust the number of workers based on the number of tasks and chunk size. Args: num_tasks (int): The number of tasks to be processed. chunk_size (int): The size of each processing chunk. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. \"\"\" if num_workers is None: # if None then choose automatically num_workers = min(multiprocessing.cpu_count() / 2, num_tasks / (2 * chunk_size)) num_workers = round(num_workers) use_multiprocessing = num_workers > 0 num_workers = min(num_workers, num_tasks // chunk_size) # no point having workers without tasks num_workers = min(num_workers, multiprocessing.cpu_count()) # no point having more workers than cpus if num_workers < 0: # make sure value is valid num_workers = 0 if use_multiprocessing: num_workers = max(num_workers, 1) elif not use_multiprocessing and num_workers == 1: num_workers = 0 return num_workers Classes TaskScheduler class TaskScheduler ( task_func : Callable , maxsize : int = 0 , tqdm : bool = True , ** tqdm_kwargs ) This class is used to schedule tasks to be executed by a worker thread. Attributes Name Type Description Default task_func Callable The function to be executed by the worker thread. None maxsize int The maximum number of items that can be in the queue. None tqdm bool Whether to use tqdm for progress tracking. None **tqdm_kwargs None Additional keyword arguments to be passed to the TqdmQueue constructor. None View Source class TaskScheduler : \"\"\" This class is used to schedule tasks to be executed by a worker thread. Args: task_func (Callable): The function to be executed by the worker thread. maxsize (int, optional): The maximum number of items that can be in the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments to be passed to the TqdmQueue constructor. \"\"\" def __init__ ( self , task_func : Callable , maxsize : int = 0 , tqdm : bool = True , ** tqdm_kwargs , ): self . _queue = TqdmQueue ( maxsize , ** tqdm_kwargs ) if tqdm else queue . 
Queue ( maxsize ) self . _worker_thread = threading . Thread ( target = self . _worker , args = ( self . _queue ,)) self . _task_func = task_func def start ( self ): \"\"\" Starts the worker thread. \"\"\" self . _worker_thread . start () def __enter__ ( self ): self . start () return self def __exit__ ( self , exc_type , exc_value , traceback ): self . close () def schedule_save ( self , * params ): \"\"\" Schedules a task by putting task parameters into the queue. Args: *params: The parameters to be passed to the task function. \"\"\" self . _queue . put ( item = params , block = True ) def _worker ( self , q : queue . Queue ): while True : params = q . get ( block = True ) # exit if signaled if params is None : break self . _task_func ( params ) q . task_done () def close ( self ): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self . _queue . join () self . _queue . put ( None ) self . _worker_thread . join () Descendants wtracker.utils.io_utils.FrameSaver wtracker.utils.io_utils.ImageSaver Methods close def close ( self ) Waits for the queue to empty and then closes the worker thread. View Source def close(self): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self._queue.join() self._queue.put(None) self._worker_thread.join() schedule_save def schedule_save ( self , * params ) Schedules a task by putting task parameters into the queue. Parameters: Name Type Description Default *params None The parameters to be passed to the task function. None View Source def schedule_save(self, *params): \"\"\" Schedules a task by putting task parameters into the queue. Args: *params: The parameters to be passed to the task function. \"\"\" self._queue.put(item=params, block=True) start def start ( self ) Starts the worker thread. View Source def start(self): \"\"\" Starts the worker thread. \"\"\" self._worker_thread.start() TqdmQueue class TqdmQueue ( maxsize : int = 0 , ** kwargs ) A subclass of queue.Queue that provides progress tracking using tqdm . Attributes Name Type Description Default maxsize int The maximum size of the queue (default: 0). None **kwargs None Additional keyword arguments to be passed to the tqdm progress bar. None pbar tqdm.tqdm The progress bar object. None total int The total number of items processed. None View Source class TqdmQueue ( queue . Queue ) : \" \"\" A subclass of `queue.Queue` that provides progress tracking using `tqdm`. Args: maxsize (int): The maximum size of the queue (default: 0). **kwargs: Additional keyword arguments to be passed to the tqdm progress bar. Attributes: pbar (tqdm.tqdm): The progress bar object. total (int): The total number of items processed. Example: queue = ProgressQueue(maxsize=10) queue.put(item) queue.task_done() queue.join() \"\" \" def __init__ ( self , maxsize : int = 0 , ** kwargs ) : super (). __init__ ( maxsize = maxsize ) self . pbar = tqdm ( total = 1 , ** kwargs ) self . total = 0 # Keep our own total tracker so we can update the Progressbar def task_done ( self ) : \" \"\" Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. \"\" \" super (). task_done () self . pbar . update () self . pbar . refresh () # Redraw the progressbar def _put ( self , item ) : super (). _put ( item ) self . total += 1 processed = self . pbar . n # Get current progress to re-apply self . pbar . reset ( self . total ) # Reset and update total self . pbar . 
update ( processed ) # Re-apply progress self . pbar . refresh () # Redraw the progressbar def join ( self ) : \" \"\" Blocks until all items in the Queue have been gotten and processed. \"\" \" super (). join () self . pbar . close () Ancestors (in MRO) queue.Queue Methods empty def empty ( self ) Return True if the queue is empty, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() == 0 as a direct substitute, but be aware that either approach risks a race condition where a queue can grow before the result of empty() or qsize() can be used. To create code that needs to wait for all queued tasks to be completed, the preferred technique is to use the join() method. View Source def empty(self): '''Return True if the queue is empty, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() == 0 as a direct substitute, but be aware that either approach risks a race condition where a queue can grow before the result of empty() or qsize() can be used. To create code that needs to wait for all queued tasks to be completed, the preferred technique is to use the join() method. ''' with self.mutex: return not self._qsize() full def full ( self ) Return True if the queue is full, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() >= n as a direct substitute, but be aware that either approach risks a race condition where a queue can shrink before the result of full() or qsize() can be used. View Source def full(self): '''Return True if the queue is full, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() >= n as a direct substitute, but be aware that either approach risks a race condition where a queue can shrink before the result of full() or qsize() can be used. ''' with self.mutex: return 0 < self.maxsize <= self._qsize() get def get ( self , block = True , timeout = None ) Remove and return an item from the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until an item is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Empty exception if no item was available within that time. Otherwise ('block' is false), return an item if one is immediately available, else raise the Empty exception ('timeout' is ignored in that case). View Source def get(self, block=True, timeout=None): '''Remove and return an item from the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until an item is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Empty exception if no item was available within that time. Otherwise ('block' is false), return an item if one is immediately available, else raise the Empty exception ('timeout' is ignored in that case). ''' with self.not_empty: if not block: if not self._qsize(): raise Empty elif timeout is None: while not self._qsize(): self.not_empty.wait() elif timeout < 0: raise ValueError(\"'timeout' must be a non-negative number\") else: endtime = time() + timeout while not self._qsize(): remaining = endtime - time() if remaining <= 0.0: raise Empty self.not_empty.wait(remaining) item = self._get() self.not_full.notify() return item get_nowait def get_nowait ( self ) Remove and return an item from the queue without blocking. Only get an item if one is immediately available. Otherwise raise the Empty exception. 
View Source def get_nowait(self): '''Remove and return an item from the queue without blocking. Only get an item if one is immediately available. Otherwise raise the Empty exception. ''' return self.get(block=False) join def join ( self ) Blocks until all items in the Queue have been gotten and processed. View Source def join ( self ) : \"\" \" Blocks until all items in the Queue have been gotten and processed. \"\" \" super().join() self.pbar.close() put def put ( self , item , block = True , timeout = None ) Put an item into the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until a free slot is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Full exception if no free slot was available within that time. Otherwise ('block' is false), put an item on the queue if a free slot is immediately available, else raise the Full exception ('timeout' is ignored in that case). View Source def put(self, item, block=True, timeout=None): '''Put an item into the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until a free slot is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Full exception if no free slot was available within that time. Otherwise ('block' is false), put an item on the queue if a free slot is immediately available, else raise the Full exception ('timeout' is ignored in that case). ''' with self.not_full: if self.maxsize > 0: if not block: if self._qsize() >= self.maxsize: raise Full elif timeout is None: while self._qsize() >= self.maxsize: self.not_full.wait() elif timeout < 0: raise ValueError(\"'timeout' must be a non-negative number\") else: endtime = time() + timeout while self._qsize() >= self.maxsize: remaining = endtime - time() if remaining <= 0.0: raise Full self.not_full.wait(remaining) self._put(item) self.unfinished_tasks += 1 self.not_empty.notify() put_nowait def put_nowait ( self , item ) Put an item into the queue without blocking. Only enqueue the item if a free slot is immediately available. Otherwise raise the Full exception. View Source def put_nowait(self, item): '''Put an item into the queue without blocking. Only enqueue the item if a free slot is immediately available. Otherwise raise the Full exception. ''' return self.put(item, block=False) qsize def qsize ( self ) Return the approximate size of the queue (not reliable!). View Source def qsize(self): '''Return the approximate size of the queue (not reliable!).''' with self.mutex: return self._qsize() task_done def task_done ( self ) Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. View Source def task_done(self): \"\"\" Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. 
\"\"\" super().task_done() self.pbar.update() self.pbar.refresh() # Redraw the progressbar","title":"Threading Utils"},{"location":"reference/wtracker/utils/threading_utils/#module-wtrackerutilsthreading_utils","text":"View Source import queue import threading import multiprocessing from typing import Callable from tqdm.auto import tqdm def adjust_num_workers ( num_tasks : int , chunk_size : int , num_workers : int = None ) -> int : \"\"\" Adjust the number of workers based on the number of tasks and chunk size. Args: num_tasks (int): The number of tasks to be processed. chunk_size (int): The size of each processing chunk. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. \"\"\" if num_workers is None : # if None then choose automatically num_workers = min ( multiprocessing . cpu_count () / 2 , num_tasks / ( 2 * chunk_size )) num_workers = round ( num_workers ) use_multiprocessing = num_workers > 0 num_workers = min ( num_workers , num_tasks // chunk_size ) # no point having workers without tasks num_workers = min ( num_workers , multiprocessing . cpu_count ()) # no point having more workers than cpus if num_workers < 0 : # make sure value is valid num_workers = 0 if use_multiprocessing : num_workers = max ( num_workers , 1 ) elif not use_multiprocessing and num_workers == 1 : num_workers = 0 return num_workers class TqdmQueue ( queue . Queue ): \"\"\" A subclass of `queue.Queue` that provides progress tracking using `tqdm`. Args: maxsize (int): The maximum size of the queue (default: 0). **kwargs: Additional keyword arguments to be passed to the tqdm progress bar. Attributes: pbar (tqdm.tqdm): The progress bar object. total (int): The total number of items processed. Example: queue = ProgressQueue(maxsize=10) queue.put(item) queue.task_done() queue.join() \"\"\" def __init__ ( self , maxsize : int = 0 , ** kwargs ): super () . __init__ ( maxsize = maxsize ) self . pbar = tqdm ( total = 1 , ** kwargs ) self . total = 0 # Keep our own total tracker so we can update the Progressbar def task_done ( self ): \"\"\" Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. \"\"\" super () . task_done () self . pbar . update () self . pbar . refresh () # Redraw the progressbar def _put ( self , item ): super () . _put ( item ) self . total += 1 processed = self . pbar . n # Get current progress to re-apply self . pbar . reset ( self . total ) # Reset and update total self . pbar . update ( processed ) # Re-apply progress self . pbar . refresh () # Redraw the progressbar def join ( self ): \"\"\" Blocks until all items in the Queue have been gotten and processed. \"\"\" super () . join () self . pbar . close () class TaskScheduler : \"\"\" This class is used to schedule tasks to be executed by a worker thread. Args: task_func (Callable): The function to be executed by the worker thread. maxsize (int, optional): The maximum number of items that can be in the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments to be passed to the TqdmQueue constructor. \"\"\" def __init__ ( self , task_func : Callable , maxsize : int = 0 , tqdm : bool = True , ** tqdm_kwargs , ): self . _queue = TqdmQueue ( maxsize , ** tqdm_kwargs ) if tqdm else queue . Queue ( maxsize ) self . _worker_thread = threading . Thread ( target = self . 
_worker , args = ( self . _queue ,)) self . _task_func = task_func def start ( self ): \"\"\" Starts the worker thread. \"\"\" self . _worker_thread . start () def __enter__ ( self ): self . start () return self def __exit__ ( self , exc_type , exc_value , traceback ): self . close () def schedule_save ( self , * params ): \"\"\" Schedules a task by putting task parameters into the queue. Args: *params: The parameters to be passed to the task function. \"\"\" self . _queue . put ( item = params , block = True ) def _worker ( self , q : queue . Queue ): while True : params = q . get ( block = True ) # exit if signaled if params is None : break self . _task_func ( params ) q . task_done () def close ( self ): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self . _queue . join () self . _queue . put ( None ) self . _worker_thread . join ()","title":"Module wtracker.utils.threading_utils"},{"location":"reference/wtracker/utils/threading_utils/#functions","text":"","title":"Functions"},{"location":"reference/wtracker/utils/threading_utils/#adjust_num_workers","text":"def adjust_num_workers ( num_tasks : int , chunk_size : int , num_workers : int = None ) -> int Adjust the number of workers based on the number of tasks and chunk size. Parameters: Name Type Description Default num_tasks int The number of tasks to be processed. None chunk_size int The size of each processing chunk. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None View Source def adjust_num_workers(num_tasks: int, chunk_size: int, num_workers: int = None) -> int: \"\"\" Adjust the number of workers based on the number of tasks and chunk size. Args: num_tasks (int): The number of tasks to be processed. chunk_size (int): The size of each processing chunk. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. \"\"\" if num_workers is None: # if None then choose automatically num_workers = min(multiprocessing.cpu_count() / 2, num_tasks / (2 * chunk_size)) num_workers = round(num_workers) use_multiprocessing = num_workers > 0 num_workers = min(num_workers, num_tasks // chunk_size) # no point having workers without tasks num_workers = min(num_workers, multiprocessing.cpu_count()) # no point having more workers than cpus if num_workers < 0: # make sure value is valid num_workers = 0 if use_multiprocessing: num_workers = max(num_workers, 1) elif not use_multiprocessing and num_workers == 1: num_workers = 0 return num_workers","title":"adjust_num_workers"},{"location":"reference/wtracker/utils/threading_utils/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/threading_utils/#taskscheduler","text":"class TaskScheduler ( task_func : Callable , maxsize : int = 0 , tqdm : bool = True , ** tqdm_kwargs ) This class is used to schedule tasks to be executed by a worker thread.","title":"TaskScheduler"},{"location":"reference/wtracker/utils/threading_utils/#attributes","text":"Name Type Description Default task_func Callable The function to be executed by the worker thread. None maxsize int The maximum number of items that can be in the queue. None tqdm bool Whether to use tqdm for progress tracking. None **tqdm_kwargs None Additional keyword arguments to be passed to the TqdmQueue constructor. None View Source class TaskScheduler : \"\"\" This class is used to schedule tasks to be executed by a worker thread. 
Args: task_func (Callable): The function to be executed by the worker thread. maxsize (int, optional): The maximum number of items that can be in the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments to be passed to the TqdmQueue constructor. \"\"\" def __init__ ( self , task_func : Callable , maxsize : int = 0 , tqdm : bool = True , ** tqdm_kwargs , ): self . _queue = TqdmQueue ( maxsize , ** tqdm_kwargs ) if tqdm else queue . Queue ( maxsize ) self . _worker_thread = threading . Thread ( target = self . _worker , args = ( self . _queue ,)) self . _task_func = task_func def start ( self ): \"\"\" Starts the worker thread. \"\"\" self . _worker_thread . start () def __enter__ ( self ): self . start () return self def __exit__ ( self , exc_type , exc_value , traceback ): self . close () def schedule_save ( self , * params ): \"\"\" Schedules a task by putting task parameters into the queue. Args: *params: The parameters to be passed to the task function. \"\"\" self . _queue . put ( item = params , block = True ) def _worker ( self , q : queue . Queue ): while True : params = q . get ( block = True ) # exit if signaled if params is None : break self . _task_func ( params ) q . task_done () def close ( self ): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self . _queue . join () self . _queue . put ( None ) self . _worker_thread . join ()","title":"Attributes"},{"location":"reference/wtracker/utils/threading_utils/#descendants","text":"wtracker.utils.io_utils.FrameSaver wtracker.utils.io_utils.ImageSaver","title":"Descendants"},{"location":"reference/wtracker/utils/threading_utils/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/utils/threading_utils/#close","text":"def close ( self ) Waits for the queue to empty and then closes the worker thread. View Source def close(self): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self._queue.join() self._queue.put(None) self._worker_thread.join()","title":"close"},{"location":"reference/wtracker/utils/threading_utils/#schedule_save","text":"def schedule_save ( self , * params ) Schedules a task by putting task parameters into the queue. Parameters: Name Type Description Default *params None The parameters to be passed to the task function. None View Source def schedule_save(self, *params): \"\"\" Schedules a task by putting task parameters into the queue. Args: *params: The parameters to be passed to the task function. \"\"\" self._queue.put(item=params, block=True)","title":"schedule_save"},{"location":"reference/wtracker/utils/threading_utils/#start","text":"def start ( self ) Starts the worker thread. View Source def start(self): \"\"\" Starts the worker thread. \"\"\" self._worker_thread.start()","title":"start"},{"location":"reference/wtracker/utils/threading_utils/#tqdmqueue","text":"class TqdmQueue ( maxsize : int = 0 , ** kwargs ) A subclass of queue.Queue that provides progress tracking using tqdm .","title":"TqdmQueue"},{"location":"reference/wtracker/utils/threading_utils/#attributes_1","text":"Name Type Description Default maxsize int The maximum size of the queue (default: 0). None **kwargs None Additional keyword arguments to be passed to the tqdm progress bar. None pbar tqdm.tqdm The progress bar object. None total int The total number of items processed. None View Source class TqdmQueue ( queue . 
Queue ) : \" \"\" A subclass of `queue.Queue` that provides progress tracking using `tqdm`. Args: maxsize (int): The maximum size of the queue (default: 0). **kwargs: Additional keyword arguments to be passed to the tqdm progress bar. Attributes: pbar (tqdm.tqdm): The progress bar object. total (int): The total number of items processed. Example: queue = ProgressQueue(maxsize=10) queue.put(item) queue.task_done() queue.join() \"\" \" def __init__ ( self , maxsize : int = 0 , ** kwargs ) : super (). __init__ ( maxsize = maxsize ) self . pbar = tqdm ( total = 1 , ** kwargs ) self . total = 0 # Keep our own total tracker so we can update the Progressbar def task_done ( self ) : \" \"\" Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. \"\" \" super (). task_done () self . pbar . update () self . pbar . refresh () # Redraw the progressbar def _put ( self , item ) : super (). _put ( item ) self . total += 1 processed = self . pbar . n # Get current progress to re-apply self . pbar . reset ( self . total ) # Reset and update total self . pbar . update ( processed ) # Re-apply progress self . pbar . refresh () # Redraw the progressbar def join ( self ) : \" \"\" Blocks until all items in the Queue have been gotten and processed. \"\" \" super (). join () self . pbar . close ()","title":"Attributes"},{"location":"reference/wtracker/utils/threading_utils/#ancestors-in-mro","text":"queue.Queue","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/utils/threading_utils/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/utils/threading_utils/#empty","text":"def empty ( self ) Return True if the queue is empty, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() == 0 as a direct substitute, but be aware that either approach risks a race condition where a queue can grow before the result of empty() or qsize() can be used. To create code that needs to wait for all queued tasks to be completed, the preferred technique is to use the join() method. View Source def empty(self): '''Return True if the queue is empty, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() == 0 as a direct substitute, but be aware that either approach risks a race condition where a queue can grow before the result of empty() or qsize() can be used. To create code that needs to wait for all queued tasks to be completed, the preferred technique is to use the join() method. ''' with self.mutex: return not self._qsize()","title":"empty"},{"location":"reference/wtracker/utils/threading_utils/#full","text":"def full ( self ) Return True if the queue is full, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() >= n as a direct substitute, but be aware that either approach risks a race condition where a queue can shrink before the result of full() or qsize() can be used. View Source def full(self): '''Return True if the queue is full, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() >= n as a direct substitute, but be aware that either approach risks a race condition where a queue can shrink before the result of full() or qsize() can be used. 
''' with self.mutex: return 0 < self.maxsize <= self._qsize()","title":"full"},{"location":"reference/wtracker/utils/threading_utils/#get","text":"def get ( self , block = True , timeout = None ) Remove and return an item from the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until an item is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Empty exception if no item was available within that time. Otherwise ('block' is false), return an item if one is immediately available, else raise the Empty exception ('timeout' is ignored in that case). View Source def get(self, block=True, timeout=None): '''Remove and return an item from the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until an item is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Empty exception if no item was available within that time. Otherwise ('block' is false), return an item if one is immediately available, else raise the Empty exception ('timeout' is ignored in that case). ''' with self.not_empty: if not block: if not self._qsize(): raise Empty elif timeout is None: while not self._qsize(): self.not_empty.wait() elif timeout < 0: raise ValueError(\"'timeout' must be a non-negative number\") else: endtime = time() + timeout while not self._qsize(): remaining = endtime - time() if remaining <= 0.0: raise Empty self.not_empty.wait(remaining) item = self._get() self.not_full.notify() return item","title":"get"},{"location":"reference/wtracker/utils/threading_utils/#get_nowait","text":"def get_nowait ( self ) Remove and return an item from the queue without blocking. Only get an item if one is immediately available. Otherwise raise the Empty exception. View Source def get_nowait(self): '''Remove and return an item from the queue without blocking. Only get an item if one is immediately available. Otherwise raise the Empty exception. ''' return self.get(block=False)","title":"get_nowait"},{"location":"reference/wtracker/utils/threading_utils/#join","text":"def join ( self ) Blocks until all items in the Queue have been gotten and processed. View Source def join ( self ) : \"\" \" Blocks until all items in the Queue have been gotten and processed. \"\" \" super().join() self.pbar.close()","title":"join"},{"location":"reference/wtracker/utils/threading_utils/#put","text":"def put ( self , item , block = True , timeout = None ) Put an item into the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until a free slot is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Full exception if no free slot was available within that time. Otherwise ('block' is false), put an item on the queue if a free slot is immediately available, else raise the Full exception ('timeout' is ignored in that case). View Source def put(self, item, block=True, timeout=None): '''Put an item into the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until a free slot is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Full exception if no free slot was available within that time. Otherwise ('block' is false), put an item on the queue if a free slot is immediately available, else raise the Full exception ('timeout' is ignored in that case). 
''' with self.not_full: if self.maxsize > 0: if not block: if self._qsize() >= self.maxsize: raise Full elif timeout is None: while self._qsize() >= self.maxsize: self.not_full.wait() elif timeout < 0: raise ValueError(\"'timeout' must be a non-negative number\") else: endtime = time() + timeout while self._qsize() >= self.maxsize: remaining = endtime - time() if remaining <= 0.0: raise Full self.not_full.wait(remaining) self._put(item) self.unfinished_tasks += 1 self.not_empty.notify()","title":"put"},{"location":"reference/wtracker/utils/threading_utils/#put_nowait","text":"def put_nowait ( self , item ) Put an item into the queue without blocking. Only enqueue the item if a free slot is immediately available. Otherwise raise the Full exception. View Source def put_nowait(self, item): '''Put an item into the queue without blocking. Only enqueue the item if a free slot is immediately available. Otherwise raise the Full exception. ''' return self.put(item, block=False)","title":"put_nowait"},{"location":"reference/wtracker/utils/threading_utils/#qsize","text":"def qsize ( self ) Return the approximate size of the queue (not reliable!). View Source def qsize(self): '''Return the approximate size of the queue (not reliable!).''' with self.mutex: return self._qsize()","title":"qsize"},{"location":"reference/wtracker/utils/threading_utils/#task_done","text":"def task_done ( self ) Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. View Source def task_done(self): \"\"\" Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. \"\"\" super().task_done() self.pbar.update() self.pbar.refresh() # Redraw the progressbar","title":"task_done"}]} \ No newline at end of file +{"config":{"indexing":"full","lang":["en"],"min_search_length":3,"prebuild_index":false,"separator":"[\\s\\-]+"},"docs":[{"location":"","text":"WTracker Description This library provides tools for worm detection and movement prediction, training predictors, and analyzing the results. It includes support for YOLO-based prediction and various simulation controllers. Features Real-time Worm detection and movement prediction Logging and analysis tools CSV, logging, and YOLO controllers Documentation There is an Official Documentation website available, covering the entire API. The library is fully documented within the code base. Workflow files have elaborate documentation for usage. Installation Download the Repository Download the project repository (by clicking on code -> download zip) and extract the files in the desired location. Environment Installation Step 1 - Install mamba: Install 'Miniforge' from this link , make sure to download the right version (one that matches the OS and CPU of the computer). If asked during installation: add to PATH. * if unsure, use this link to download mamba. Step 2 - verify that mamba is installed correctly: 1. Navigate to the folder into which the library was downloaded. 2. Open terminal/command prompt. 3. Enter - 'mamba -h'. If no error is encountered, then mamba is installed correctly. Step 3 - create a new environment: 1. Enter the following command - \"mamba create -n bio-proj python=3.12\". 2. Enter the command - 'mamba init'. * You can choose another name (not only 'bio-proj'). If you do, you will need to change the name field in the 'requirements.yaml' file as well.
Step 4 - Activate the environment: 1. Enter the command - 'mamba activate bio-proj'. * If you used another name for the environment, replace 'bio-proj' with the name you have chosen. Step 5 - Installing PyTorch: 1. Head to the PyTorch website here , where you will find a table to select your configuration; select the following: 1. PyTorch Build = stable 2. OS - the operating system of the computer [Windows/Linux] 3. Package - the package manager [conda] 4. Language - the programming language [Python] 5. Compute Platform - if the computer has a GPU select the newest version of CUDA [12.1 was tested], otherwise select CPU. 2. Copy the command below the table and enter it in the terminal/command prompt. 3. Wait till the installation is complete. That might take a while. Step 6 - Install the rest of the libraries: Enter the command - 'mamba env update -f requirements.yaml -n bio-proj'
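Once Step 6 finishes, a quick way to sanity-check the environment from Python (a minimal sketch; run it inside the activated 'bio-proj' environment):

import torch  # installed in Step 5

print(torch.__version__)          # the installed PyTorch version
print(torch.cuda.is_available())  # True only if the CUDA (GPU) build is working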
Install the Development Environment To run the project we recommend 'Visual Studio Code' (also referred to as VS Code), a free IDE. Basic usage videos and documentation can be found here . You can download and install VS Code from here . To set up VS Code for the project you need to install several extensions. Follow this link to learn how to install extensions. The extensions needed are: - Jupyter - Python - Pylance * Some extensions may be already installed by default. Usage Refer to the various '.ipynb' files for usage of each workflow. License The code is licensed under the GPL v3.0 License . TL;DR Here's what the license entails: 1. Anyone can copy, modify and distribute this software. 2. You have to include the license and copyright notice with each and every distribution. 3. You can use this software privately. 4. You can use this software for commercial purposes. 5. If you dare build your business solely from this code, you risk open-sourcing the whole code base. 6. If you modify it, you have to indicate changes made to the code. 7. Any modifications of this code base MUST be distributed with the same license, GPLv3. 8. This software is provided without warranty. 9. The software author or licensor cannot be held liable for any damages inflicted by the software. For more details see the license file. Contact Please open an issue in the GitHub repository if you have any questions or feedback.","title":"Home"},{"location":"#wtracker","text":"","title":"WTracker"},{"location":"#description","text":"This library provides tools for worm detection and movement prediction, training predictors, and analyzing the results. It includes support for YOLO-based prediction and various simulation controllers.","title":"Description"},{"location":"#features","text":"Real-time Worm detection and movement prediction Logging and analysis tools CSV, logging, and YOLO controllers","title":"Features"},{"location":"#documentation","text":"There is an Official Documentation website available, covering the entire API. The library is fully documented within the code base. Workflow files have elaborate documentation for usage.","title":"Documentation"},{"location":"#installation","text":"","title":"Installation"},{"location":"#download-the-repository","text":"Download the project repository (by clicking on code -> download zip) and extract the files in the desired location.","title":"Download the Repository"},{"location":"#environment-installation","text":"Step 1 - Install mamba: Install 'Miniforge' from this link , make sure to download the right version (one that matches the OS and CPU of the computer). If asked during installation: add to PATH. * if unsure, use this link to download mamba. Step 2 - verify that mamba is installed correctly: 1. Navigate to the folder into which the library was downloaded. 2. Open terminal/command prompt. 3. Enter - 'mamba -h'. If no error is encountered, then mamba is installed correctly. Step 3 - create a new environment: 1. Enter the following command - \"mamba create -n bio-proj python=3.12\". 2. Enter the command - 'mamba init'. * You can choose another name (not only 'bio-proj'). If you do, you will need to change the name field in the 'requirements.yaml' file as well. Step 4 - Activate the environment: 1. Enter the command - 'mamba activate bio-proj'. * If you used another name for the environment, replace 'bio-proj' with the name you have chosen. Step 5 - Installing PyTorch: 1. Head to the PyTorch website here , where you will find a table to select your configuration; select the following: 1. PyTorch Build = stable 2. OS - the operating system of the computer [Windows/Linux] 3. Package - the package manager [conda] 4. Language - the programming language [Python] 5. Compute Platform - if the computer has a GPU select the newest version of CUDA [12.1 was tested], otherwise select CPU. 2. Copy the command below the table and enter it in the terminal/command prompt. 3. Wait till the installation is complete. That might take a while. Step 6 - Install the rest of the libraries: Enter the command - 'mamba env update -f requirements.yaml -n bio-proj'","title":"Environment Installation"},{"location":"#install-the-development-environment","text":"To run the project we recommend 'Visual Studio Code' (also referred to as VS Code), a free IDE. Basic usage videos and documentation can be found here . You can download and install VS Code from here . To set up VS Code for the project you need to install several extensions. Follow this link to learn how to install extensions. The extensions needed are: - Jupyter - Python - Pylance * Some extensions may be already installed by default.","title":"Install the Development Environment"},{"location":"#usage","text":"Refer to the various '.ipynb' files for usage of each workflow.","title":"Usage"},{"location":"#license","text":"The code is licensed under the GPL v3.0 License . TL;DR Here's what the license entails: 1. Anyone can copy, modify and distribute this software. 2. You have to include the license and copyright notice with each and every distribution. 3. You can use this software privately. 4. You can use this software for commercial purposes. 5. If you dare build your business solely from this code, you risk open-sourcing the whole code base. 6. If you modify it, you have to indicate changes made to the code. 7. Any modifications of this code base MUST be distributed with the same license, GPLv3. 8. This software is provided without warranty. 9. The software author or licensor cannot be held liable for any damages inflicted by the software. For more details see the license file.","title":"License"},{"location":"#contact","text":"Please open an issue in the GitHub repository if you have any questions or feedback.","title":"Contact"},{"location":"docs/workflows/","text":"General workflows Here we will go over the steps to do some of the main tasks, from training a YOLO model on custom data to running simulations with different configurations. All of the main workflows have a dedicated, interactive notebook (.ipynb file) ready to use, with explanations for each step. All of the workflow notebooks are located in a dedicated folder called \"workflows\".
Workflow Files Descriptions create_yolo_images.ipynb - Prepares raw frames of some experiment for the process of training a YOLO model on them. This step entails detecting the worm in selected frames and cropping a region of pre-defined size around the worms. yolo_training.ipynb - Used to train a YOLO model on a given dataset. The training dataset was prepared by annotating the images which were extracted using the notebook create_yolo_images. The annotation process can be done with RoboFlow, which is an online dataset creation and annotation tool. initialize_experiment.ipynb - In order to run system simulations on a new experiment, first it\u2019s essential to initialize the experiment. The initialization step runs the YOLO predictor on the raw experiment, detects the worm\u2019s head position in each frame and saves the detection results into a log. That log would later be used for simulating different control algorithms on the experiment. In addition, the background image and worm images are extracted from the raw frames. These can be used later during analysis, to calculate the segmentation-based error. This log is useful since in the future the simulator can simply read worm head positions from the log, instead of using YOLO to predict the worm\u2019s head position in every frame of interest (which is much slower, especially on computers without a dedicated graphics card). simulate.ipynb - Run a full system simulation on some previously initialized experiment. The simulation is run by reading an experiment log produced by the initialization process - in each frame, the worm\u2019s head position is retrieved from the log. In this notebook it is possible to simulate the system with any controller and any configuration parameters, not only the ones used for the initial experiment log. Similar to the initialization process, the simulation produces a log, which would later be used to analyze the system\u2019s performance and its behavior. analysis.ipynb - This notebook is used to analyze the performance of a control algorithm (controller). A log which was produced by running simulate is read and analyzed, and different plots and statistics are presented. In addition, there is an option to calculate the segmentation evaluation error, by counting how many pixels of the worm are outside of the microscope view. To this end, we use the background and worm images which were extracted during the run of the initialize_experiment notebook for this experiment. visualize.ipynb - Given a system log which was produced by simulate, this notebook is able to visually recreate the simulator\u2019s behavior. At each frame, the position of the worm\u2019s head is drawn, along with the position of the microscope FOV and the camera FOV. This notebook is used to visually assess the performance and the behavior of the simulator, and to visually investigate what causes the system to misbehave. predictor_training.ipynb - Used to train a specific simulation control algorithm. The MLPController is an algorithm that uses a neural network (NN) to predict the worm\u2019s future position. Since this algorithm is NN based, it requires training. That script is responsible for training that NN from experiment log files, which were produced by either running initialize or simulate (it doesn\u2019t matter which). polyfit_optimizer.ipynb - This notebook is used to tune the parameters of a specific simulation control algorithm. The PolyfitController is an algorithm that uses polynomial-fitting to predict the worm\u2019s future position. A polynomial is fitted from past observations at previous time stamps, and afterwards sampled at a future time to predict the worm\u2019s position. This notebook is used to determine the optimal degree of the fitted polynomial, and to find the optimal weight of each past sample for the fitting process.
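To illustrate the idea behind the PolyfitController (a minimal sketch in plain NumPy; the sample values, weights and future time are made up for illustration, and this is not the library's actual API):

import numpy as np

times = np.array([-0.3, -0.2, -0.1, 0.0])  # timestamps of past observations (seconds)
xs = np.array([10.2, 11.0, 11.9, 12.7])    # worm head x-coordinate at those times (pixels)
weights = np.array([0.4, 0.6, 0.8, 1.0])   # more recent samples can be weighted more heavily

coeffs = np.polyfit(times, xs, deg=2, w=weights)  # weighted fit of a degree-2 polynomial
x_future = np.polyval(coeffs, 0.1)                # sample the polynomial at a future time

The same fit is done for the y-coordinate; the polynomial degree and the per-sample weights are exactly the quantities this notebook tunes.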
Workflow Files Dependency Graph Workflow outline and the dependencies between the notebook files. Blue color (rectangles) denotes an interactive notebook file, green color (diamond) denotes intermediate outputs between different files, and the global input is in yellow color (circle). Dotted lines denote optional dependencies. Complete Workflows Conducting an Experiment Here we explain how to properly capture the footage of an experiment for the simulator. Decide on the frame rate (FPS) of the experiment. There are two distinct scenarios: If the sole reason for the experiment footage is to be used for YOLO training, a very low FPS can be used (can be as low as 1 FPS or even lower if possible). Ideally, a single frame would be captured every few seconds. If a simulation is to be run on the experiment footage, a high FPS should be used, preferably at least 60 FPS. Note that the chosen frame rate should be the same frame rate on which the platform control algorithms were calibrated. The camera should be completely stationary, positioned such that the entire arena is visible. The footage should be captured as distinct images for each frame, not as a continuous video. We recommend using the \"bmp\" image format, which is lossless and is able to save a single channel if the image is grayscale. Make sure that the distinct frames are saved as images with the frame number appearing in their name, such that it's possible to read them in order (see the sketch below). If you want to run a system simulation on the experiment, follow the steps in the initialize_experiment notebook.
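For instance, a minimal sketch of reading the frames back in order (the folder name and the zero-padded file names are assumptions):

from pathlib import Path

frame_dir = Path('experiment_01/frames')
# Zero-padded frame numbers ('frame_000001.bmp', 'frame_000002.bmp', ...) sort correctly as plain strings.
frame_paths = sorted(frame_dir.glob('*.bmp'))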
YOLO Model Training Below is the workflow to train a YOLO model to detect the worm's head position: Conduct a new experiment and capture the footage, as explained in the previous section. Determine the image size for the YOLO model; this size should match the desired input image size during a simulation run. At the time of writing, it is possible to pass images of different sizes to YOLO models, but they are scaled to the closest multiple of 32. This means you should use and train YOLO models on images with sizes that are a multiple of 32 when possible. A YOLO model should be trained on images of a single fixed size; it is not expected to work well on images of different sizes without special attention. Be careful of a Distribution Shift, which means that the training data is different from (not representative of) the real world. For example: In the training data, are the worms always in a similar position? Is the background lighting consistent with the one on the system? Is the size of each pixel the same as in the system? Are the worms in the dataset representative of the worm population? Create a set of images for annotation - Follow the instructions in the create_yolo_images Python notebook. Make sure to provide the correct image size, which was determined in the previous step. Annotate the data - The annotation process is the process of marking the actual worm head position in the extracted images. To do so, we recommend using the website RoboFlow , which provides easy-to-use tools to annotate the data, and create a dataset for YOLO models. Create a YOLO dataset - If you used Roboflow to create the dataset - on the dataset page you can click on 'export dataset' and then 'show download code'. You can copy the code snippet to the appropriate place in the notebook of step 6 to download the dataset to the computer and use it for training. Follow the instructions in the yolo_training notebook and train the YOLO model. There are two approaches to tackle the challenges of distribution shift, mentioned earlier. The first approach is to carefully train the YOLO model under conditions very similar to those of the final system. The resulting model will function well, but if conditions change then the model's performance will likely degrade. The other approach is to train the model on a wide variety of settings (e.g. different lighting conditions or different magnification levels), leading to a more robust model. The benefit of this approach is that the model is more robust to changes, but a disadvantage is that such models usually require more data, and may perform slightly worse than models carefully trained on some very specific conditions. Perform System Simulation Below is the workflow of performing a full system simulation on some experiment, and analyzing the results. If the experiment was not initialized yet, make sure to follow the instructions in the initialize_experiment notebook. Decide on the platform control algorithm to be used. If the MLPController algorithm is chosen: the MLP controller works by a neural network that predicts future positions of the worm (see the sketch below). If that network needs training, then first run the predictor_training notebook. Note that the neural network should be trained only once. Once the network is trained, there is no need to perform this step anymore. If the PolyfitController algorithm is chosen, and the hyper-parameters of the controller should be tuned, then first run the polyfit_optimizer notebook. Note that the hyper-parameters of this controller should be tuned only once. Once they were tuned, there is no need to perform this step anymore. Follow the steps in the simulate notebook. The result of running this notebook is a log file containing the full simulation log. To visualize the simulation, run the visualize notebook; to analyze the performance of the control algorithm and general statistics of the conducted experiment, run the analysis notebook. Both of these notebooks analyze the log produced by simulate .","title":"Workflows"}
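As a toy illustration of the MLPController idea (the input window and layer sizes are assumptions for illustration, not the library's actual architecture), such a predictor maps a short history of head positions to a future position:

import torch.nn as nn

predictor = nn.Sequential(
    nn.Linear(2 * 5, 64),  # input: the last 5 (x, y) head positions, flattened
    nn.ReLU(),
    nn.Linear(64, 2),      # output: the predicted future (x, y) position
)

Training such a network on positions logged by initialize or simulate is what the predictor_training notebook automates.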
,{"location":"docs/workflows/#general-workflows","text":"Here we will go over the steps to do some of the main tasks, from training a YOLO model on custom data to running simulations with different configurations. All of the main workflows have a dedicated, interactive notebook (.ipynb file) ready to use, with explanations for each step. All of the workflow notebooks are located in a dedicated folder called \"workflows\".","title":"General workflows"},{"location":"docs/workflows/#workflow-files-descriptions","text":"create_yolo_images.ipynb - Prepares raw frames of some experiment for the process of training a YOLO model on them. This step entails detecting the worm in selected frames and cropping a region of pre-defined size around the worms. yolo_training.ipynb - Used to train a YOLO model on a given dataset. The training dataset was prepared by annotating the images which were extracted using the notebook create_yolo_images. The annotation process can be done with RoboFlow, which is an online dataset creation and annotation tool. initialize_experiment.ipynb - In order to run system simulations on a new experiment, first it\u2019s essential to initialize the experiment. The initialization step runs the YOLO predictor on the raw experiment, detects the worm\u2019s head position in each frame and saves the detection results into a log. That log would later be used for simulating different control algorithms on the experiment. In addition, the background image and worm images are extracted from the raw frames. These can be used later during analysis, to calculate the segmentation-based error. This log is useful since in the future the simulator can simply read worm head positions from the log, instead of using YOLO to predict the worm\u2019s head position in every frame of interest (which is much slower, especially on computers without a dedicated graphics card). simulate.ipynb - Run a full system simulation on some previously initialized experiment. The simulation is run by reading an experiment log produced by the initialization process - in each frame, the worm\u2019s head position is retrieved from the log. In this notebook it is possible to simulate the system with any controller and any configuration parameters, not only the ones used for the initial experiment log. Similar to the initialization process, the simulation produces a log, which would later be used to analyze the system\u2019s performance and its behavior. analysis.ipynb - This notebook is used to analyze the performance of a control algorithm (controller). A log which was produced by running simulate is read and analyzed, and different plots and statistics are presented. In addition, there is an option to calculate the segmentation evaluation error, by counting how many pixels of the worm are outside of the microscope view. To this end, we use the background and worm images which were extracted during the run of the initialize_experiment notebook for this experiment. visualize.ipynb - Given a system log which was produced by simulate, this notebook is able to visually recreate the simulator\u2019s behavior. At each frame, the position of the worm\u2019s head is drawn, along with the position of the microscope FOV and the camera FOV. This notebook is used to visually assess the performance and the behavior of the simulator, and to visually investigate what causes the system to misbehave. predictor_training.ipynb - Used to train a specific simulation control algorithm. The MLPController is an algorithm that uses a neural network (NN) to predict the worm\u2019s future position. Since this algorithm is NN based, it requires training. That script is responsible for training that NN from experiment log files, which were produced by either running initialize or simulate (it doesn\u2019t matter which). polyfit_optimizer.ipynb - This notebook is used to tune the parameters of a specific simulation control algorithm. The PolyfitController is an algorithm that uses polynomial-fitting to predict the worm\u2019s future position. A polynomial is fitted from past observations at previous time stamps, and afterwards sampled at a future time to predict the worm\u2019s position. This notebook is used to determine the optimal degree of the fitted polynomial, and to find the optimal weight of each past sample for the fitting process.","title":"Workflow Files Descriptions"},{"location":"docs/workflows/#workflow-files-dependency-graph","text":"Workflow outline and the dependencies between the notebook files.
Blue color (rectangles) denotes an interactive notebook file, green color (diamond) denotes intermediate outputs between different files, and the global input is in yellow color (circle). Dotted lines denote optional dependencies.","title":"Workflow Files Dependency Graph"},{"location":"docs/workflows/#complete-workflows","text":"","title":"Complete Workflows"},{"location":"docs/workflows/#conducting-an-experiment","text":"Here we explain how to properly capture the footage of an experiment for the simulator. Decide on the frame rate (FPS) of the experiment. There are two distinct scenarios: If the sole reason for the experiment footage is to be used for YOLO training, a very low FPS can be used (can be as low as 1 FPS or even lower if possible). Ideally, a single frame would be captured every few seconds. If a simulation is to be run on the experiment footage, a high FPS should be used, preferably at least 60 FPS. Note that the chosen frame rate should be the same frame rate on which the platform control algorithms were calibrated. The camera should be completely stationary, positioned such that the entire arena is visible. The footage should be captured as distinct images for each frame, not as a continuous video. We recommend using the \"bmp\" image format, which is lossless and is able to save a single channel if the image is grayscale. Make sure that the distinct frames are saved as images with the frame number appearing in their name, such that it's possible to read them in order. If you want to run a system simulation on the experiment, follow the steps in the initialize_experiment notebook.","title":"Conducting an Experiment"},{"location":"docs/workflows/#yolo-model-training","text":"Below is the workflow to train a YOLO model to detect the worm's head position: Conduct a new experiment and capture the footage, as explained in the previous section. Determine the image size for the YOLO model; this size should match the desired input image size during a simulation run. At the time of writing, it is possible to pass images of different sizes to YOLO models, but they are scaled to the closest multiple of 32. This means you should use and train YOLO models on images with sizes that are a multiple of 32 when possible (see the sketch below). A YOLO model should be trained on images of a single fixed size; it is not expected to work well on images of different sizes without special attention. Be careful of a Distribution Shift, which means that the training data is different from (not representative of) the real world. For example: In the training data, are the worms always in a similar position? Is the background lighting consistent with the one on the system? Is the size of each pixel the same as in the system? Are the worms in the dataset representative of the worm population? Create a set of images for annotation - Follow the instructions in the create_yolo_images Python notebook. Make sure to provide the correct image size, which was determined in the previous step. Annotate the data - The annotation process is the process of marking the actual worm head position in the extracted images. To do so, we recommend using the website RoboFlow , which provides easy-to-use tools to annotate the data, and create a dataset for YOLO models. Create a YOLO dataset - If you used Roboflow to create the dataset - on the dataset page you can click on 'export dataset' and then 'show download code'. You can copy the code snippet to the appropriate place in the notebook of step 6 to download the dataset to the computer and use it for training. Follow the instructions in the yolo_training notebook and train the YOLO model. There are two approaches to tackle the challenges of distribution shift, mentioned earlier. The first approach is to carefully train the YOLO model under conditions very similar to those of the final system. The resulting model will function well, but if conditions change then the model's performance will likely degrade. The other approach is to train the model on a wide variety of settings (e.g. different lighting conditions or different magnification levels), leading to a more robust model. The benefit of this approach is that the model is more robust to changes, but a disadvantage is that such models usually require more data, and may perform slightly worse than models carefully trained on some very specific conditions.","title":"YOLO Model Training"}
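As a small sketch of the image-size rule mentioned above (the desired size is an arbitrary example):

desired = 410                    # the crop size we would like to use
size = 32 * round(desired / 32)  # -> 416, the nearest multiple of 32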
{"location":"docs/workflows/#yolo-model-training","text":"Below is the workflow to train a YOLO model to detect the worm's head position: Conduct a new experiment and capture the footage, as explained in the previous section. Determine the image size for the YOLO model; this size should match the desired input image size during a simulation run. At the time of writing, it is possible to pass images of different sizes to YOLO models, but they are scaled to the closest multiple of 32. This means you should use and train YOLO models on images whose sizes are a multiple of 32 when possible. A YOLO model should be trained on images of a single, fixed size; it is not expected to work well on images of different sizes without special attention. Be careful of distribution shift, which means that the training data differs from (is not representative of) the real world. For example: In the training data, are the worms always in a similar position? Is the background lighting consistent with the one on the system? Is the size of each pixel the same as in the system? Are the worms in the dataset representative of the worm population? Create a set of images for annotation - Follow the instructions in the create_yolo_images Python notebook. Make sure to provide the correct image size, which was determined in the previous step. Annotate the data - Annotation is the process of marking the actual worm head position in the extracted images. To do so, we recommend using the website Roboflow, which provides easy-to-use tools to annotate the data and create a dataset for YOLO models. Create a YOLO dataset - If you used Roboflow to create the dataset, on the dataset page you can click on 'export dataset' and then 'show download code'. You can copy the code snippet to the appropriate place in the notebook of step 6 to download the dataset to the computer and use it for training. Follow the instructions in the yolo_training notebook and train the YOLO model. There are two approaches to tackle the challenges of distribution shift, mentioned earlier. The first approach is to carefully train the YOLO model under conditions very similar to those of the final system. The resulting model will function well, but if conditions change then the model's performance will likely degrade. The other approach is to train the model on a wide variety of settings (e.g. different lighting conditions or different magnification levels), leading to a more robust model. The benefit of this approach is that the model is more robust to changes; the disadvantages are that such models usually require more data, and they may perform slightly worse than models carefully trained on some very specific conditions.","title":"YOLO Model Training"},{"location":"docs/workflows/#perform-system-simulation","text":"Below is the workflow for performing a full system simulation on some experiment and analyzing the results. If the experiment was not initialized yet, make sure to follow the instructions in the initialize_experiment notebook. Decide on the platform control algorithm to be used. If the MLPController algorithm is chosen: the MLP controller uses a neural network to predict future positions of the worm. If that network needs training, first run the predictor_training notebook. Note that the neural network needs to be trained only once; after it is trained, there is no need to repeat this step. If the PolyfitController algorithm is chosen and its hyper-parameters need tuning, first run the polyfit_optimizer notebook. Note that the hyper-parameters of this controller need to be tuned only once; after they are tuned, there is no need to repeat this step. Follow the steps in the simulate notebook. The result of running this notebook is a log file containing the full simulation log. To visualize the simulation, run the visualize notebook; to analyze the performance of the control algorithm and general statistics of the conducted experiment, run the analysis notebook. Both of these notebooks analyze the log produced by simulate.","title":"Perform System Simulation"},{"location":"reference/wtracker/dataset/","text":"Module wtracker.dataset View Source from wtracker.dataset.sample_extractor import SampleExtractor from wtracker.dataset.box_calculator import BoxCalculator from wtracker.dataset.bg_extractor import BGExtractor Sub-modules wtracker.dataset.bg_extractor wtracker.dataset.box_calculator wtracker.dataset.sample_extractor","title":"Index"},{"location":"reference/wtracker/dataset/#module-wtrackerdataset","text":"View Source from wtracker.dataset.sample_extractor import SampleExtractor from wtracker.dataset.box_calculator import BoxCalculator from wtracker.dataset.bg_extractor import BGExtractor","title":"Module wtracker.dataset"},{"location":"reference/wtracker/dataset/#sub-modules","text":"wtracker.dataset.bg_extractor wtracker.dataset.box_calculator wtracker.dataset.sample_extractor","title":"Sub-modules"},{"location":"reference/wtracker/dataset/bg_extractor/","text":"Module wtracker.dataset.bg_extractor View Source import numpy as np from tqdm.auto import tqdm from wtracker.utils.frame_reader import FrameReader class BGExtractor : \"\"\" A class for extracting the background from a given sequence of frames, provided by a FrameReader. Args: reader (FrameReader): The FrameReader object holding the frames to extract the background from. 
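Example (an illustrative usage sketch; it assumes `reader` is an already-constructed FrameReader over the frames of an initialized experiment):
>>> extractor = BGExtractor(reader)
>>> background = extractor.calc_background(num_probes=100, sampling=\"uniform\", method=\"median\")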
\"\"\" def __init__ ( self , reader : FrameReader ): self . reader = reader def calc_background ( self , num_probes : int , sampling : str = \"uniform\" , method : str = \"median\" ) -> np . ndarray : \"\"\" Calculate the background of the dataset. Args: num_probes (int): The number of probes to sample for background calculation. sampling (str, optional): The sampling method for selecting probes. Can be \"random\" or \"uniform\". \"uniform\" will select frames uniformly spaced from the FrameReader. \"random\" will select frames randomly from the FrameReader. method (str, optional): The method for calculating the background. Can be \"median\" or \"mean\". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. Returns: np.ndarray: The calculated background as a numpy array. \"\"\" assert sampling in [ \"random\" , \"uniform\" ] assert method in [ \"median\" , \"mean\" ] length = len ( self . reader ) size = min ( num_probes , length ) if sampling == \"random\" : frame_ids = np . random . choice ( length , size = size , replace = False ) elif sampling == \"uniform\" : frame_ids = np . linspace ( 0 , length - 1 , num = size ) frame_ids = np . unique ( frame_ids . astype ( int , copy = False )) if method == \"median\" : bg = self . _calc_background_median ( frame_ids ) elif method == \"mean\" : bg = self . _calc_background_mean ( frame_ids ) return bg def _calc_background_mean ( self , frame_ids : np . ndarray ) -> np . ndarray : sum = np . zeros ( self . reader . frame_shape , dtype = np . float64 ) # read frames for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ): frame = self . reader [ frame_id ] sum += frame mean = sum / len ( frame_ids ) return mean . astype ( np . uint8 , copy = False ) def _calc_background_median ( self , frame_ids : np . ndarray ) -> np . ndarray : # get frames extracted_list = [] for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ): frame = self . reader [ frame_id ] extracted_list . append ( frame ) # calculate the median along the time axis extracted = np . stack ( extracted_list , axis = 0 ) median = np . median ( extracted , axis = 0 ) . astype ( np . uint8 , copy = False ) return median Classes BGExtractor class BGExtractor ( reader : wtracker . utils . frame_reader . FrameReader ) A class for extracting the background from a given sequence of frames, provided by a FrameReader. Attributes Name Type Description Default reader FrameReader The FrameReader object holding the frames to extract the background from. None View Source class BGExtractor : \"\"\" A class for extracting the background from a given sequence of frames, provided by a FrameReader. Args: reader (FrameReader): The FrameReader object holding the frames to extract the background from. \"\"\" def __init__ ( self , reader : FrameReader ) : self . reader = reader def calc_background ( self , num_probes : int , sampling : str = \"uniform\" , method : str = \"median\" ) -> np . ndarray : \"\"\" Calculate the background of the dataset. Args: num_probes (int): The number of probes to sample for background calculation. sampling (str, optional): The sampling method for selecting probes. Can be \" random \" or \" uniform \". \" uniform \" will select frames uniformly spaced from the FrameReader. \" random \" will select frames randomly from the FrameReader. method (str, optional): The method for calculating the background. 
Can be \" median \" or \" mean \". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. Returns: np.ndarray: The calculated background as a numpy array. \"\"\" assert sampling in [ \"random\", \"uniform\" ] assert method in [ \"median\", \"mean\" ] length = len ( self . reader ) size = min ( num_probes , length ) if sampling == \"random\" : frame_ids = np . random . choice ( length , size = size , replace = False ) elif sampling == \"uniform\" : frame_ids = np . linspace ( 0 , length - 1 , num = size ) frame_ids = np . unique ( frame_ids . astype ( int , copy = False )) if method == \"median\" : bg = self . _calc_background_median ( frame_ids ) elif method == \"mean\" : bg = self . _calc_background_mean ( frame_ids ) return bg def _calc_background_mean ( self , frame_ids : np . ndarray ) -> np . ndarray : sum = np . zeros ( self . reader . frame_shape , dtype = np . float64 ) # read frames for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ) : frame = self . reader [ frame_id ] sum += frame mean = sum / len ( frame_ids ) return mean . astype ( np . uint8 , copy = False ) def _calc_background_median ( self , frame_ids : np . ndarray ) -> np . ndarray : # get frames extracted_list = [] for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ) : frame = self . reader [ frame_id ] extracted_list . append ( frame ) # calculate the median along the time axis extracted = np . stack ( extracted_list , axis = 0 ) median = np . median ( extracted , axis = 0 ). astype ( np . uint8 , copy = False ) return median Methods calc_background def calc_background ( self , num_probes : int , sampling : str = 'uniform' , method : str = 'median' ) -> numpy . ndarray Calculate the background of the dataset. Parameters: Name Type Description Default num_probes int The number of probes to sample for background calculation. None sampling str The sampling method for selecting probes. Can be \"random\" or \"uniform\". \"uniform\" will select frames uniformly spaced from the FrameReader. \"random\" will select frames randomly from the FrameReader. None method str The method for calculating the background. Can be \"median\" or \"mean\". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. None Returns: Type Description np.ndarray The calculated background as a numpy array. View Source def calc_background ( self , num_probes : int , sampling : str = \"uniform\" , method : str = \"median\" ) -> np . ndarray : \"\"\" Calculate the background of the dataset. Args: num_probes (int): The number of probes to sample for background calculation. sampling (str, optional): The sampling method for selecting probes. Can be \" random \" or \" uniform \". \" uniform \" will select frames uniformly spaced from the FrameReader. \" random \" will select frames randomly from the FrameReader. method (str, optional): The method for calculating the background. Can be \" median \" or \" mean \". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. Returns: np.ndarray: The calculated background as a numpy array. \"\"\" assert sampling in [ \"random\" , \"uniform\" ] assert method in [ \"median\" , \"mean\" ] length = len ( self . 
reader ) size = min ( num_probes , length ) if sampling == \"random\" : frame_ids = np . random . choice ( length , size = size , replace = False ) elif sampling == \"uniform\" : frame_ids = np . linspace ( 0 , length - 1 , num = size ) frame_ids = np . unique ( frame_ids . astype ( int , copy = False )) if method == \"median\" : bg = self . _calc_background_median ( frame_ids ) elif method == \"mean\" : bg = self . _calc_background_mean ( frame_ids ) return bg","title":"Bg Extractor"},{"location":"reference/wtracker/dataset/bg_extractor/#module-wtrackerdatasetbg_extractor","text":"View Source import numpy as np from tqdm.auto import tqdm from wtracker.utils.frame_reader import FrameReader class BGExtractor : \"\"\" A class for extracting the background from a given sequence of frames, provided by a FrameReader. Args: reader (FrameReader): The FrameReader object holding the frames to extract the background from. \"\"\" def __init__ ( self , reader : FrameReader ): self . reader = reader def calc_background ( self , num_probes : int , sampling : str = \"uniform\" , method : str = \"median\" ) -> np . ndarray : \"\"\" Calculate the background of the dataset. Args: num_probes (int): The number of probes to sample for background calculation. sampling (str, optional): The sampling method for selecting probes. Can be \"random\" or \"uniform\". \"uniform\" will select frames uniformly spaced from the FrameReader. \"random\" will select frames randomly from the FrameReader. method (str, optional): The method for calculating the background. Can be \"median\" or \"mean\". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. Returns: np.ndarray: The calculated background as a numpy array. \"\"\" assert sampling in [ \"random\" , \"uniform\" ] assert method in [ \"median\" , \"mean\" ] length = len ( self . reader ) size = min ( num_probes , length ) if sampling == \"random\" : frame_ids = np . random . choice ( length , size = size , replace = False ) elif sampling == \"uniform\" : frame_ids = np . linspace ( 0 , length - 1 , num = size ) frame_ids = np . unique ( frame_ids . astype ( int , copy = False )) if method == \"median\" : bg = self . _calc_background_median ( frame_ids ) elif method == \"mean\" : bg = self . _calc_background_mean ( frame_ids ) return bg def _calc_background_mean ( self , frame_ids : np . ndarray ) -> np . ndarray : sum = np . zeros ( self . reader . frame_shape , dtype = np . float64 ) # read frames for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ): frame = self . reader [ frame_id ] sum += frame mean = sum / len ( frame_ids ) return mean . astype ( np . uint8 , copy = False ) def _calc_background_median ( self , frame_ids : np . ndarray ) -> np . ndarray : # get frames extracted_list = [] for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ): frame = self . reader [ frame_id ] extracted_list . append ( frame ) # calculate the median along the time axis extracted = np . stack ( extracted_list , axis = 0 ) median = np . median ( extracted , axis = 0 ) . astype ( np . uint8 , copy = False ) return median","title":"Module wtracker.dataset.bg_extractor"},{"location":"reference/wtracker/dataset/bg_extractor/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/dataset/bg_extractor/#bgextractor","text":"class BGExtractor ( reader : wtracker . utils . frame_reader . 
FrameReader ) A class for extracting the background from a given sequence of frames, provided by a FrameReader.","title":"BGExtractor"},{"location":"reference/wtracker/dataset/bg_extractor/#attributes","text":"Name Type Description Default reader FrameReader The FrameReader object holding the frames to extract the background from. None View Source class BGExtractor : \"\"\" A class for extracting the background from a given sequence of frames, provided by a FrameReader. Args: reader (FrameReader): The FrameReader object holding the frames to extract the background from. \"\"\" def __init__ ( self , reader : FrameReader ) : self . reader = reader def calc_background ( self , num_probes : int , sampling : str = \"uniform\" , method : str = \"median\" ) -> np . ndarray : \"\"\" Calculate the background of the dataset. Args: num_probes (int): The number of probes to sample for background calculation. sampling (str, optional): The sampling method for selecting probes. Can be \" random \" or \" uniform \". \" uniform \" will select frames uniformly spaced from the FrameReader. \" random \" will select frames randomly from the FrameReader. method (str, optional): The method for calculating the background. Can be \" median \" or \" mean \". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. Returns: np.ndarray: The calculated background as a numpy array. \"\"\" assert sampling in [ \"random\", \"uniform\" ] assert method in [ \"median\", \"mean\" ] length = len ( self . reader ) size = min ( num_probes , length ) if sampling == \"random\" : frame_ids = np . random . choice ( length , size = size , replace = False ) elif sampling == \"uniform\" : frame_ids = np . linspace ( 0 , length - 1 , num = size ) frame_ids = np . unique ( frame_ids . astype ( int , copy = False )) if method == \"median\" : bg = self . _calc_background_median ( frame_ids ) elif method == \"mean\" : bg = self . _calc_background_mean ( frame_ids ) return bg def _calc_background_mean ( self , frame_ids : np . ndarray ) -> np . ndarray : sum = np . zeros ( self . reader . frame_shape , dtype = np . float64 ) # read frames for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ) : frame = self . reader [ frame_id ] sum += frame mean = sum / len ( frame_ids ) return mean . astype ( np . uint8 , copy = False ) def _calc_background_median ( self , frame_ids : np . ndarray ) -> np . ndarray : # get frames extracted_list = [] for frame_id in tqdm ( frame_ids , desc = \"Extracting background frames\" , unit = \"fr\" ) : frame = self . reader [ frame_id ] extracted_list . append ( frame ) # calculate the median along the time axis extracted = np . stack ( extracted_list , axis = 0 ) median = np . median ( extracted , axis = 0 ). astype ( np . uint8 , copy = False ) return median","title":"Attributes"},{"location":"reference/wtracker/dataset/bg_extractor/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/dataset/bg_extractor/#calc_background","text":"def calc_background ( self , num_probes : int , sampling : str = 'uniform' , method : str = 'median' ) -> numpy . ndarray Calculate the background of the dataset. Parameters: Name Type Description Default num_probes int The number of probes to sample for background calculation. None sampling str The sampling method for selecting probes. Can be \"random\" or \"uniform\". \"uniform\" will select frames uniformly spaced from the FrameReader. \"random\" will select frames randomly from the FrameReader. None method str The method for calculating the background. Can be \"median\" or \"mean\". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. None Returns: Type Description np.ndarray The calculated background as a numpy array. View Source def calc_background ( self , num_probes : int , sampling : str = \"uniform\" , method : str = \"median\" ) -> np . ndarray : \"\"\" Calculate the background of the dataset. Args: num_probes (int): The number of probes to sample for background calculation. sampling (str, optional): The sampling method for selecting probes. Can be \" random \" or \" uniform \". \" uniform \" will select frames uniformly spaced from the FrameReader. \" random \" will select frames randomly from the FrameReader. method (str, optional): The method for calculating the background. Can be \" median \" or \" mean \". The background is calculated by either taking the median or mean of the sampled frames. Calculating the mean is substantially faster, but produces worse results. Returns: np.ndarray: The calculated background as a numpy array. \"\"\" assert sampling in [ \"random\" , \"uniform\" ] assert method in [ \"median\" , \"mean\" ] length = len ( self . reader ) size = min ( num_probes , length ) if sampling == \"random\" : frame_ids = np . random . choice ( length , size = size , replace = False ) elif sampling == \"uniform\" : frame_ids = np . linspace ( 0 , length - 1 , num = size ) frame_ids = np . unique ( frame_ids . astype ( int , copy = False )) if method == \"median\" : bg = self . _calc_background_median ( frame_ids ) elif method == \"mean\" : bg = self . _calc_background_mean ( frame_ids ) return bg","title":"calc_background"},{"location":"reference/wtracker/dataset/box_calculator/","text":"Module wtracker.dataset.box_calculator View Source import cv2 as cv import numpy as np from typing import Collection from tqdm.auto import tqdm from tqdm.contrib import concurrent from wtracker.utils.frame_reader import FrameReader from wtracker.utils.threading_utils import adjust_num_workers class BoxCalculator : \"\"\" A class for calculating bounding boxes around an object for a sequence of frames. The bounding boxes are calculated by comparing the frames to a background image. The largest contour in the difference image between the frame and the background is used to calculate the bounding box. Args: frame_reader (FrameReader): The frame reader object holding the relevant frames. background (np.ndarray): The background image of the frames in the `frame_reader` argument. diff_thresh (int, optional): Threshold value for detecting foreground objects. Pixels with a difference value greater than this threshold are considered foreground. \"\"\" def __init__ ( self , frame_reader : FrameReader , background : np . ndarray , diff_thresh : int = 20 , ) -> None : assert diff_thresh > 0 , \"Difference threshold must be greater than 0.\" assert frame_reader . frame_shape == background . shape , \"Background shape must match frame shape.\" # convert background to grayscale if needed if background . ndim == 3 and background . shape [ 2 ] == 3 : background = cv . cvtColor ( background , cv . COLOR_BGR2GRAY ) if background . ndim != 2 : raise ValueError ( \"background must be either a gray or a color image.\" ) self .
_frame_reader = frame_reader self . _background = background self . _diff_thresh = diff_thresh self . _all_bboxes = np . full (( len ( frame_reader ), 4 ), - 1 , dtype = int ) def all_bboxes ( self ) -> np . ndarray : \"\"\" Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: np.ndarray: Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). \"\"\" return self . _all_bboxes def get_bbox ( self , frame_idx : int ) -> np . ndarray : \"\"\" Returns the bounding box for a given frame index. Args: frame_idx (int): The index of the frame from which to extract the bounding box. Returns: np.ndarray: The bounding box coordinates as a numpy array, in format (x, y, w, h). \"\"\" bbox = self . _all_bboxes [ frame_idx ] if bbox [ 0 ] == - 1 : # calculate bbox since it wasn't calculated before bbox = self . _calc_bounding_box ( frame_idx ) self . _all_bboxes [ frame_idx ] = bbox return bbox def _calc_bounding_box ( self , frame_idx : int ) -> np . ndarray : # get mask according to the threshold value frame = self . _frame_reader [ frame_idx ] # convert to grayscale if needed if frame . ndim == 3 and frame . shape [ 2 ] == 3 : frame = cv . cvtColor ( frame , cv . COLOR_BGR2GRAY ) diff = cv . absdiff ( frame , self . _background ) _ , mask = cv . threshold ( diff , self . _diff_thresh , 255 , cv . THRESH_BINARY ) # apply morphological ops to the mask mask = cv . morphologyEx ( mask , cv . MORPH_OPEN , np . ones (( 5 , 5 ), np . uint8 )) mask = cv . dilate ( mask , np . ones (( 11 , 11 ), np . uint8 )) # extract contours and bbox contours , _ = cv . findContours ( mask , cv . RETR_EXTERNAL , cv . CHAIN_APPROX_NONE ) if not contours : zero_bbox = np . array ([ 0 , 0 , 0 , 0 ]) self . _all_bboxes [ frame_idx ] = zero_bbox return zero_bbox largest_contour = max ( contours , key = cv . contourArea ) largest_bbox = cv . boundingRect ( largest_contour ) largest_bbox = np . asanyarray ( largest_bbox , dtype = int ) return largest_bbox def calc_specified_boxes ( self , frame_indices : Collection [ int ], num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. \"\"\" num_workers = adjust_num_workers ( len ( frame_indices ), chunk_size , num_workers ) if num_workers > 0 : bbox_list = concurrent . process_map ( self . get_bbox , frame_indices , max_workers = num_workers , chunksize = chunk_size , desc = \"Extracting bboxes\" , unit = \"fr\" , ) for idx , bbox in zip ( frame_indices , bbox_list ): self . _all_bboxes [ idx ] = bbox else : for idx in tqdm ( frame_indices , desc = \"Extracting bboxes\" , unit = \"fr\" ): self . get_bbox ( idx ) bboxes = self . _all_bboxes [ frame_indices , :] return bboxes def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for all frames. Args: num_workers (int, optional): Number of workers for parallel processing. 
If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: Array of bounding boxes for all frames. \"\"\" indices = range ( len ( self . _frame_reader )) return self . calc_specified_boxes ( indices , num_workers , chunk_size ) Classes BoxCalculator class BoxCalculator ( frame_reader : wtracker . utils . frame_reader . FrameReader , background : numpy . ndarray , diff_thresh : int = 20 ) A class for calculating bounding boxes around an object for a sequence of frames. The bounding boxes are calculated by comparing the frames to a background image. The largest contour in the difference image between the frame and the background is used to calculate the bounding box. Attributes Name Type Description Default frame_reader FrameReader The frame reader object holding the relevant frames. None background np.ndarray The background image of the frames in the frame_reader argument. None diff_thresh int Threshold value for detecting foreground objects. Pixels with a difference value greater than this threshold are considered foreground. None View Source class BoxCalculator : \"\"\" A class for calculating bounding boxes around an object for a sequence of frames. The bounding boxes are calculated by comparing the frames to a background image. The largest contour in the difference image between the frame and the background is used to calculate the bounding box. Args: frame_reader (FrameReader): The frame reader object holding the relevant frames. background (np.ndarray): The background image of the frames in the `frame_reader` argument. diff_thresh (int, optional): Threshold value for detecting foreground objects. Pixels with a difference value greater than this threshold are considered foreground. \"\"\" def __init__ ( self , frame_reader : FrameReader , background : np . ndarray , diff_thresh : int = 20 , ) -> None : assert diff_thresh > 0 , \"Difference threshold must be greater than 0.\" assert frame_reader . frame_shape == background . shape , \"Background shape must match frame shape.\" # convert background to grayscale if needed if background . ndim == 3 and background . shape [ 2 ] == 3 : background = cv . cvtColor ( background , cv . COLOR_BGR2GRAY ) if background . ndim != 2 : raise ValueError ( \"background must be either a gray or a color image.\" ) self . _frame_reader = frame_reader self . _background = background self . _diff_thresh = diff_thresh self . _all_bboxes = np . full (( len ( frame_reader ), 4 ), - 1 , dtype = int ) def all_bboxes ( self ) -> np . ndarray : \"\"\" Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: np.ndarray: Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). \"\"\" return self . _all_bboxes def get_bbox ( self , frame_idx : int ) -> np . ndarray : \"\"\" Returns the bounding box for a given frame index. Args: frame_idx (int): The index of the frame from which to extract the bounding box. Returns: np.ndarray: The bounding box coordinates as a numpy array, in format (x, y, w, h). \"\"\" bbox = self . _all_bboxes [ frame_idx ] if bbox [ 0 ] == - 1 : # calculate bbox since it wasn ' t calculated before bbox = self . _calc_bounding_box ( frame_idx ) self .
_all_bboxes [ frame_idx ] = bbox return bbox def _calc_bounding_box ( self , frame_idx : int ) -> np . ndarray : # get mask according to the threshold value frame = self . _frame_reader [ frame_idx ] # convert to grayscale if needed if frame . ndim == 3 and frame . shape [ 2 ] == 3 : frame = cv . cvtColor ( frame , cv . COLOR_BGR2GRAY ) diff = cv . absdiff ( frame , self . _background ) _ , mask = cv . threshold ( diff , self . _diff_thresh , 255 , cv . THRESH_BINARY ) # apply morphological ops to the mask mask = cv . morphologyEx ( mask , cv . MORPH_OPEN , np . ones (( 5 , 5 ), np . uint8 )) mask = cv . dilate ( mask , np . ones (( 11 , 11 ), np . uint8 )) # extract contours and bbox contours , _ = cv . findContours ( mask , cv . RETR_EXTERNAL , cv . CHAIN_APPROX_NONE ) if not contours : zero_bbox = np . array ( [ 0, 0, 0, 0 ] ) self . _all_bboxes [ frame_idx ] = zero_bbox return zero_bbox largest_contour = max ( contours , key = cv . contourArea ) largest_bbox = cv . boundingRect ( largest_contour ) largest_bbox = np . asanyarray ( largest_bbox , dtype = int ) return largest_bbox def calc_specified_boxes ( self , frame_indices : Collection [ int ] , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. \"\"\" num_workers = adjust_num_workers ( len ( frame_indices ), chunk_size , num_workers ) if num_workers > 0 : bbox_list = concurrent . process_map ( self . get_bbox , frame_indices , max_workers = num_workers , chunksize = chunk_size , desc = \"Extracting bboxes\" , unit = \"fr\" , ) for idx , bbox in zip ( frame_indices , bbox_list ) : self . _all_bboxes [ idx ] = bbox else : for idx in tqdm ( frame_indices , desc = \"Extracting bboxes\" , unit = \"fr\" ) : self . get_bbox ( idx ) bboxes = self . _all_bboxes [ frame_indices, : ] return bboxes def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for all frames. Args: num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: Array of bounding boxes for all frames. \"\"\" indices = range ( len ( self . _frame_reader )) return self . calc_specified_boxes ( indices , num_workers , chunk_size ) Methods all_bboxes def all_bboxes ( self ) -> numpy . ndarray Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: Type Description np.ndarray Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). View Source def all_bboxes ( self ) -> np . ndarray : \"\"\" Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: np.ndarray: Array of bounding boxes, in shape (N, 4), where N is the number of frames. 
The bounding boxes are stored in the format (x, y, w, h). View Source def all_bboxes ( self ) -> np . ndarray : \"\"\" Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: np.ndarray: Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). \"\"\" return self . _all_bboxes calc_all_boxes def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 ) -> numpy . ndarray Calculate bounding boxes for all frames. Parameters: Name Type Description Default num_workers int Number of workers for parallel processing. If None is provided then number of workers is determined automatically. None chunk_size int Size of each chunk for parallel processing. None Returns: Type Description np.ndarray Array of bounding boxes for all frames. View Source def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for all frames. Args: num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: Array of bounding boxes for all frames. \"\"\" indices = range ( len ( self . _frame_reader )) return self . calc_specified_boxes ( indices , num_workers , chunk_size ) calc_specified_boxes def calc_specified_boxes ( self , frame_indices : Collection [ int ], num_workers : int = None , chunk_size : int = 50 ) -> numpy . ndarray Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. View Source def calc_specified_boxes ( self , frame_indices : Collection [ int ] , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. \"\"\" num_workers = adjust_num_workers ( len ( frame_indices ), chunk_size , num_workers ) if num_workers > 0 : bbox_list = concurrent . process_map ( self . get_bbox , frame_indices , max_workers = num_workers , chunksize = chunk_size , desc = \"Extracting bboxes\" , unit = \"fr\" , ) for idx , bbox in zip ( frame_indices , bbox_list ) : self . _all_bboxes [ idx ] = bbox else : for idx in tqdm ( frame_indices , desc = \"Extracting bboxes\" , unit = \"fr\" ) : self . get_bbox ( idx ) bboxes = self . _all_bboxes [ frame_indices, : ] return bboxes get_bbox def get_bbox ( self , frame_idx : int ) -> numpy . ndarray Returns the bounding box for a given frame index. Parameters: Name Type Description Default frame_idx int The index of the frame from which to extract the bounding box. None Returns: Type Description np.ndarray The bounding box coordinates as a numpy array, in format (x, y, w, h). View Source def get_bbox ( self , frame_idx : int ) -> np . ndarray : \"\"\" Returns the bounding box for a given frame index. Args: frame_idx (int): The index of the frame from which to extract the bounding box. Returns: np.ndarray: The bounding box coordinates as a numpy array, in format (x, y, w, h). 
\"\"\" bbox = self . _all_bboxes [ frame_idx ] if bbox [ 0 ] == - 1 : # calculate bbox since it wasn ' t calculated before bbox = self . _calc_bounding_box ( frame_idx ) self . _all_bboxes [ frame_idx ] = bbox return bbox","title":"Box Calculator"},{"location":"reference/wtracker/dataset/box_calculator/#module-wtrackerdatasetbox_calculator","text":"View Source import cv2 as cv import numpy as np from typing import Collection from tqdm.auto import tqdm from tqdm.contrib import concurrent from wtracker.utils.frame_reader import FrameReader from wtracker.utils.threading_utils import adjust_num_workers class BoxCalculator : \"\"\" A class for calculating bounding boxes around an object for a sequence of frames. The bounding boxes are calculated by comparing the frames to a background image. The largest contour in the difference image between the frame and the background is used to calculate the bounding box. Args: frame_reader (FrameReader): The frame reader object holding the relevant frames. background (np.ndarray): The background image of the frames in the `frame_reader` argument. diff_thresh (int, optional): Threshold value for detecting foreground objects. Pixels with a difference value greater than this threshold are considered foreground. \"\"\" def __init__ ( self , frame_reader : FrameReader , background : np . ndarray , diff_thresh : int = 20 , ) -> None : assert diff_thresh > 0 , \"Difference threshold must be greater than 0.\" assert frame_reader . frame_shape == background . shape , \"Background shape must match frame shape.\" # convert background to grayscale if needed if background . ndim == 3 and background . shape [ 2 ] == 3 : background = cv . cvtColor ( background , cv . COLOR_BGR2GRAY ) if background . ndim != 2 : raise ValueError ( \"background must be either a gray or a color image.\" ) self . _frame_reader = frame_reader self . _background = background self . _diff_thresh = diff_thresh self . _all_bboxes = np . full (( len ( frame_reader ), 4 ), - 1 , dtype = int ) def all_bboxes ( self ) -> np . ndarray : \"\"\" Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: np.ndarray: Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). \"\"\" return self . _all_bboxes def get_bbox ( self , frame_idx : int ) -> np . ndarray : \"\"\" Returns the bounding box for a given frame index. Args: frame_idx (int): The index of the frame from which to extract the bounding box. Returns: np.ndarray: The bounding box coordinates as a numpy array, in format (x, y, w, h). \"\"\" bbox = self . _all_bboxes [ frame_idx ] if bbox [ 0 ] == - 1 : # calculate bbox since it wasn't calculated before bbox = self . _calc_bounding_box ( frame_idx ) self . _all_bboxes [ frame_idx ] = bbox return bbox def _calc_bounding_box ( self , frame_idx : int ) -> np . ndarray : # get mask according to the threshold value frame = self . _frame_reader [ frame_idx ] # convert to grayscale if needed if frame . ndim == 3 and frame . shape [ 2 ] == 3 : frame = cv . cvtColor ( frame , cv . COLOR_BGR2GRAY ) diff = cv . absdiff ( frame , self . _background ) _ , mask = cv . threshold ( diff , self . _diff_thresh , 255 , cv . THRESH_BINARY ) # apply morphological ops to the mask mask = cv . morphologyEx ( mask , cv . MORPH_OPEN , np . ones (( 5 , 5 ), np . uint8 )) mask = cv . dilate ( mask , np . ones (( 11 , 11 ), np . uint8 )) # extract contours and bbox contours , _ = cv . findContours ( mask , cv . RETR_EXTERNAL , cv . CHAIN_APPROX_NONE ) if not contours : zero_bbox = np . array ([ 0 , 0 , 0 , 0 ]) self . _all_bboxes [ frame_idx ] = zero_bbox return zero_bbox largest_contour = max ( contours , key = cv . contourArea ) largest_bbox = cv . boundingRect ( largest_contour ) largest_bbox = np . asanyarray ( largest_bbox , dtype = int ) return largest_bbox def calc_specified_boxes ( self , frame_indices : Collection [ int ], num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. \"\"\" num_workers = adjust_num_workers ( len ( frame_indices ), chunk_size , num_workers ) if num_workers > 0 : bbox_list = concurrent . process_map ( self . get_bbox , frame_indices , max_workers = num_workers , chunksize = chunk_size , desc = \"Extracting bboxes\" , unit = \"fr\" , ) for idx , bbox in zip ( frame_indices , bbox_list ): self . _all_bboxes [ idx ] = bbox else : for idx in tqdm ( frame_indices , desc = \"Extracting bboxes\" , unit = \"fr\" ): self . get_bbox ( idx ) bboxes = self . _all_bboxes [ frame_indices , :] return bboxes def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for all frames. Args: num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: Array of bounding boxes for all frames. \"\"\" indices = range ( len ( self . _frame_reader )) return self . calc_specified_boxes ( indices , num_workers , chunk_size )","title":"Module wtracker.dataset.box_calculator"},{"location":"reference/wtracker/dataset/box_calculator/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/dataset/box_calculator/#boxcalculator","text":"class BoxCalculator ( frame_reader : wtracker . utils . frame_reader . FrameReader , background : numpy . ndarray , diff_thresh : int = 20 ) A class for calculating bounding boxes around an object for a sequence of frames. The bounding boxes are calculated by comparing the frames to a background image. The largest contour in the difference image between the frame and the background is used to calculate the bounding box.","title":"BoxCalculator"},{"location":"reference/wtracker/dataset/box_calculator/#attributes","text":"Name Type Description Default frame_reader FrameReader The frame reader object holding the relevant frames. None background np.ndarray The background image of the frames in the frame_reader argument. None diff_thresh int Threshold value for detecting foreground objects. Pixels with a difference value greater than this threshold are considered foreground. None View Source class BoxCalculator : \"\"\" A class for calculating bounding boxes around an object for a sequence of frames. The bounding boxes are calculated by comparing the frames to a background image. 
The largest contour in the difference image between the frame and the background is used to calculate the bounding box. Args: frame_reader (FrameReader): The frame reader object holding the relevant frames. background (np.ndarray): The background image of the frames in the `frame_reader` argument. diff_thresh (int, optional): Threshold value for detecting foreground objects. Pixels with a difference value greater than this threshold are considered foreground. \"\"\" def __init__ ( self , frame_reader : FrameReader , background : np . ndarray , diff_thresh : int = 20 , ) -> None : assert diff_thresh > 0 , \"Difference threshold must be greater than 0.\" assert frame_reader . frame_shape == background . shape , \"Background shape must match frame shape.\" # convert background to grayscale if needed if background . ndim == 3 and background . shape [ 2 ] == 3 : background = cv . cvtColor ( background , cv . COLOR_BGR2GRAY ) if background . ndim != 2 : raise ValueError ( \"background must be either a gray or a color image.\" ) self . _frame_reader = frame_reader self . _background = background self . _diff_thresh = diff_thresh self . _all_bboxes = np . full (( len ( frame_reader ), 4 ), - 1 , dtype = int ) def all_bboxes ( self ) -> np . ndarray : \"\"\" Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: np.ndarray: Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). \"\"\" return self . _all_bboxes def get_bbox ( self , frame_idx : int ) -> np . ndarray : \"\"\" Returns the bounding box for a given frame index. Args: frame_idx (int): The index of the frame from which to extract the bounding box. Returns: np.ndarray: The bounding box coordinates as a numpy array, in format (x, y, w, h). \"\"\" bbox = self . _all_bboxes [ frame_idx ] if bbox [ 0 ] == - 1 : # calculate bbox since it wasn ' t calculated before bbox = self . _calc_bounding_box ( frame_idx ) self . _all_bboxes [ frame_idx ] = bbox return bbox def _calc_bounding_box ( self , frame_idx : int ) -> np . ndarray : # get mask according to the threshold value frame = self . _frame_reader [ frame_idx ] # convert to grayscale if needed if frame . ndim == 3 and frame . shape [ 2 ] == 3 : frame = cv . cvtColor ( frame , cv . COLOR_BGR2GRAY ) diff = cv . absdiff ( frame , self . _background ) _ , mask = cv . threshold ( diff , self . _diff_thresh , 255 , cv . THRESH_BINARY ) # apply morphological ops to the mask mask = cv . morphologyEx ( mask , cv . MORPH_OPEN , np . ones (( 5 , 5 ), np . uint8 )) mask = cv . dilate ( mask , np . ones (( 11 , 11 ), np . uint8 )) # extract contours and bbox contours , _ = cv . findContours ( mask , cv . RETR_EXTERNAL , cv . CHAIN_APPROX_NONE ) if not contours : zero_bbox = np . array ( [ 0, 0, 0, 0 ] ) self . _all_bboxes [ frame_idx ] = zero_bbox return zero_bbox largest_contour = max ( contours , key = cv . contourArea ) largest_bbox = cv . boundingRect ( largest_contour ) largest_bbox = np . asanyarray ( largest_bbox , dtype = int ) return largest_bbox def calc_specified_boxes ( self , frame_indices : Collection [ int ] , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. 
num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. \"\"\" num_workers = adjust_num_workers ( len ( frame_indices ), chunk_size , num_workers ) if num_workers > 0 : bbox_list = concurrent . process_map ( self . get_bbox , frame_indices , max_workers = num_workers , chunksize = chunk_size , desc = \"Extracting bboxes\" , unit = \"fr\" , ) for idx , bbox in zip ( frame_indices , bbox_list ) : self . _all_bboxes [ idx ] = bbox else : for idx in tqdm ( frame_indices , desc = \"Extracting bboxes\" , unit = \"fr\" ) : self . get_bbox ( idx ) bboxes = self . _all_bboxes [ frame_indices, : ] return bboxes def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for all frames. Args: num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: Array of bounding boxes for all frames. \"\"\" indices = range ( len ( self . _frame_reader )) return self . calc_specified_boxes ( indices , num_workers , chunk_size )","title":"Attributes"},{"location":"reference/wtracker/dataset/box_calculator/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/dataset/box_calculator/#all_bboxes","text":"def all_bboxes ( self ) -> numpy . ndarray Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: Type Description np.ndarray Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). View Source def all_bboxes ( self ) -> np . ndarray : \"\"\" Returns all bounding boxes for all the frames. Note that if a bounding box has not been calculated for some frame, then the matching entry will be (-1, -1, -1, -1). Returns: np.ndarray: Array of bounding boxes, in shape (N, 4), where N is the number of frames. The bounding boxes are stored in the format (x, y, w, h). \"\"\" return self . _all_bboxes","title":"all_bboxes"},{"location":"reference/wtracker/dataset/box_calculator/#calc_all_boxes","text":"def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 ) -> numpy . ndarray Calculate bounding boxes for all frames. Parameters: Name Type Description Default num_workers int Number of workers for parallel processing. If None is provided then number of workers is determined automatically. None chunk_size int Size of each chunk for parallel processing. None Returns: Type Description np.ndarray Array of bounding boxes for all frames. View Source def calc_all_boxes ( self , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for all frames. Args: num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: Array of bounding boxes for all frames. \"\"\" indices = range ( len ( self . _frame_reader )) return self . 
calc_specified_boxes ( indices , num_workers , chunk_size )","title":"calc_all_boxes"},{"location":"reference/wtracker/dataset/box_calculator/#calc_specified_boxes","text":"def calc_specified_boxes ( self , frame_indices : Collection [ int ], num_workers : int = None , chunk_size : int = 50 ) -> numpy . ndarray Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. View Source def calc_specified_boxes ( self , frame_indices : Collection [ int ] , num_workers : int = None , chunk_size : int = 50 , ) -> np . ndarray : \"\"\" Calculate bounding boxes for the specified frame indices. Args: frame_indices (Iterable[int]): The indices of the frames for which to calculate the bboxes. num_workers (int, optional): Number of workers for parallel processing. If None is provided then number of workers is determined automatically. chunk_size (int, optional): Size of each chunk for parallel processing. Returns: np.ndarray: The calculated boxes for the specified frames. \"\"\" num_workers = adjust_num_workers ( len ( frame_indices ), chunk_size , num_workers ) if num_workers > 0 : bbox_list = concurrent . process_map ( self . get_bbox , frame_indices , max_workers = num_workers , chunksize = chunk_size , desc = \"Extracting bboxes\" , unit = \"fr\" , ) for idx , bbox in zip ( frame_indices , bbox_list ) : self . _all_bboxes [ idx ] = bbox else : for idx in tqdm ( frame_indices , desc = \"Extracting bboxes\" , unit = \"fr\" ) : self . get_bbox ( idx ) bboxes = self . _all_bboxes [ frame_indices, : ] return bboxes","title":"calc_specified_boxes"},{"location":"reference/wtracker/dataset/box_calculator/#get_bbox","text":"def get_bbox ( self , frame_idx : int ) -> numpy . ndarray Returns the bounding box for a given frame index. Parameters: Name Type Description Default frame_idx int The index of the frame from which to extract the bounding box. None Returns: Type Description np.ndarray The bounding box coordinates as a numpy array, in format (x, y, w, h). View Source def get_bbox ( self , frame_idx : int ) -> np . ndarray : \"\"\" Returns the bounding box for a given frame index. Args: frame_idx (int): The index of the frame from which to extract the bounding box. Returns: np.ndarray: The bounding box coordinates as a numpy array, in format (x, y, w, h). \"\"\" bbox = self . _all_bboxes [ frame_idx ] if bbox [ 0 ] == - 1 : # calculate bbox since it wasn ' t calculated before bbox = self . _calc_bounding_box ( frame_idx ) self . _all_bboxes [ frame_idx ] = bbox return bbox","title":"get_bbox"},{"location":"reference/wtracker/dataset/sample_extractor/","text":"Module wtracker.dataset.sample_extractor View Source import numpy as np from typing import Collection from wtracker.dataset.box_calculator import BoxCalculator from wtracker.utils.bbox_utils import BoxUtils from wtracker.utils.io_utils import FrameSaver class SampleExtractor : \"\"\" A class that extracts samples from frames based on specified parameters. Each sample is a cropped image around a bounding box which was detected in the frame. The bounding boxes are calculated using the BoxCalculator class. 
This class is used to create image datasets for training object detection models. Args: bbox_calculator (BoxCalculator): An instance of the BoxCalculator class. \"\"\" def __init__ ( self , bbox_calculator : BoxCalculator ): self . _bbox_calculator = bbox_calculator self . _frame_reader = bbox_calculator . _frame_reader def move_bboxes_into_bounds ( self , bboxes : np . ndarray , frame_size : tuple [ int , int ]) -> np . ndarray : \"\"\" Moves the bounding boxes into the bounds of the frame. Args: bboxes (np.ndarray): The bounding boxes to be moved. frame_size (tuple[int, int]): The size of the frame in the format (w, h). Returns: np.ndarray: The updated bounding boxes. Raises: ValueError: If there exists a bounding box which cannot be moved into the provided bounds without resizing it. \"\"\" max_w , max_h = frame_size x , y , w , h = BoxUtils . unpack ( bboxes ) x [ x < 0 ] = 0 mask = ( x + w ) > max_w x [ mask ] = max_w - w [ mask ] y [ y < 0 ] = 0 mask = ( y + h ) > max_h y [ mask ] = max_h - h [ mask ] if np . any ( x < 0 ) or np . any ( y < 0 ): raise ValueError () if np . any ( x + w > frame_size [ 0 ]) or np . any ( y + h > frame_size [ 1 ]): raise ValueError () return BoxUtils . pack ( x , y , w , h ) def create_specified_samples ( self , frame_indices : Collection [ int ], target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_ {:09d} .png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates specified samples based on the given frame indices. Args: frame_indices (Collection[int]): The indices of the frames to extract samples from. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" bboxes = self . _bbox_calculator . calc_specified_boxes ( frame_indices = frame_indices , num_workers = num_workers , chunk_size = chunk_size , ) x , y , w , h = BoxUtils . unpack ( bboxes ) x -= np . random . randint ( 0 , target_size [ 0 ] - w + 1 ) y -= np . random . randint ( 0 , target_size [ 1 ] - h + 1 ) w = np . full_like ( x , target_size [ 0 ]) h = np . full_like ( x , target_size [ 1 ]) bboxes = BoxUtils . pack ( x , y , w , h ) frame_size = tuple ( reversed ( self . _frame_reader . frame_size )) # (h, w) -> (w, h) bboxes = self . move_bboxes_into_bounds ( bboxes , frame_size ) with FrameSaver ( self . _frame_reader , root_path = save_folder , desc = \"Saving samples\" , unit = \"fr\" ) as saver : for i , bbox in enumerate ( bboxes ): saver . schedule_save ( i , bbox , name_format . format ( i )) def create_samples ( self , count : int , target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_ {:09d} .png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates random samples based on a specified count. Args: count (int): The number of samples to create. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. 
chunk_size (int, optional): The size of each processing chunk sent to each worker. \"\"\" length = len ( self . _frame_reader ) count = min ( length , count ) frame_indices = np . random . choice ( length , size = count , replace = False ) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size ) def create_all_samples ( self , target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_ {:09d} .png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates samples for all frames. Args: target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frame_indices = range ( 0 , len ( self . _frame_reader )) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size ) Classes SampleExtractor class SampleExtractor ( bbox_calculator : wtracker . dataset . box_calculator . BoxCalculator ) A class that extracts samples from frames based on specified parameters. Each sample is a cropped image around a bounding box which was detected in the frame. The bounding boxes are calculated using the BoxCalculator class. This class is used to create image datasets for training object detection models. Attributes Name Type Description Default bbox_calculator BoxCalculator An instance of the BoxCalculator class. None View Source class SampleExtractor : \"\"\" A class that extracts samples from frames based on specified parameters. Each sample is a cropped image around a bounding box which was detected in the frame. The bounding boxes are calculated using the BoxCalculator class. This class is used to create image datasets for training object detection models. Args: bbox_calculator (BoxCalculator): An instance of the BoxCalculator class. \"\"\" def __init__ ( self , bbox_calculator : BoxCalculator ) : self . _bbox_calculator = bbox_calculator self . _frame_reader = bbox_calculator . _frame_reader def move_bboxes_into_bounds ( self , bboxes : np . ndarray , frame_size : tuple [ int, int ] ) -> np . ndarray : \"\"\" Moves the bounding boxes into the bounds of the frame. Args: bboxes (np.ndarray): The bounding boxes to be moved. frame_size (tuple[int, int]): The size of the frame in the format (w, h). Returns: np.ndarray: The updated bounding boxes. Raises: ValueError: If there exists a bounding box which cannot be moved into the provided bounds without resizing it. \"\"\" max_w , max_h = frame_size x , y , w , h = BoxUtils . unpack ( bboxes ) x [ x < 0 ] = 0 mask = ( x + w ) > max_w x [ mask ] = max_w - w [ mask ] y [ y < 0 ] = 0 mask = ( y + h ) > max_h y [ mask ] = max_h - h [ mask ] if np . any ( x < 0 ) or np . any ( y < 0 ) : raise ValueError () if np . any ( x + w > frame_size [ 0 ] ) or np . any ( y + h > frame_size [ 1 ] ) : raise ValueError () return BoxUtils . pack ( x , y , w , h ) def create_specified_samples ( self , frame_indices : Collection [ int ] , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates specified samples based on the given frame indices. 
Args: frame_indices (Collection[int]): The indices of the frames to extract samples from. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" bboxes = self . _bbox_calculator . calc_specified_boxes ( frame_indices = frame_indices , num_workers = num_workers , chunk_size = chunk_size , ) x , y , w , h = BoxUtils . unpack ( bboxes ) x -= np . random . randint ( 0 , target_size [ 0 ] - w + 1 ) y -= np . random . randint ( 0 , target_size [ 1 ] - h + 1 ) w = np . full_like ( x , target_size [ 0 ] ) h = np . full_like ( x , target_size [ 1 ] ) bboxes = BoxUtils . pack ( x , y , w , h ) frame_size = tuple ( reversed ( self . _frame_reader . frame_size )) # ( h , w ) -> ( w , h ) bboxes = self . move_bboxes_into_bounds ( bboxes , frame_size ) with FrameSaver ( self . _frame_reader , root_path = save_folder , desc = \"Saving samples\" , unit = \"fr\" ) as saver : for i , bbox in enumerate ( bboxes ) : saver . schedule_save ( i , bbox , name_format . format ( i )) def create_samples ( self , count : int , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates random samples based on a specified count. Args: count (int): The number of samples to create. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk sent to each worker. \"\"\" length = len ( self . _frame_reader ) count = min ( length , count ) frame_indices = np . random . choice ( length , size = count , replace = False ) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size ) def create_all_samples ( self , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates samples for all frames. Args: target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frame_indices = range ( 0 , len ( self . _frame_reader )) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size ) Methods create_all_samples def create_all_samples ( self , target_size : tuple [ int , int ], save_folder : str , name_format : str = 'img_ {:09d} .png' , num_workers : int = None , chunk_size : int = 50 ) Creates samples for all frames. Parameters: Name Type Description Default target_size tuple[int, int] The target size of the samples in the format (w, h). 
None save_folder str The folder path to save the samples. None name_format str The format of the sample names. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk. None View Source def create_all_samples( self, target_size: tuple[int, int], save_folder: str, name_format: str = \"img_{:09d}.png\", num_workers: int = None, chunk_size: int = 50, ): \"\"\" Creates samples for all frames. Args: target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frame_indices = range(0, len(self._frame_reader)) self.create_specified_samples(frame_indices, target_size, save_folder, name_format, num_workers, chunk_size) create_samples def create_samples ( self , count : int , target_size : tuple [ int , int ], save_folder : str , name_format : str = 'img_ {:09d} .png' , num_workers : int = None , chunk_size : int = 50 ) Creates random samples based on a specified count. Parameters: Name Type Description Default count int The number of samples to create. None target_size tuple[int, int] The target size of the samples in the format (w, h). None save_folder str The folder path to save the samples. None name_format str The format of the sample names. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk sent to each worker. None View Source def create_samples ( self , count : int , target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates random samples based on a specified count. Args: count (int): The number of samples to create. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk sent to each worker. \"\"\" length = len ( self . _frame_reader ) count = min ( length , count ) frame_indices = np . random . choice ( length , size = count , replace = False ) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size ) create_specified_samples def create_specified_samples ( self , frame_indices : Collection [ int ], target_size : tuple [ int , int ], save_folder : str , name_format : str = 'img_ {:09d} .png' , num_workers : int = None , chunk_size : int = 50 ) Creates specified samples based on the given frame indices. Parameters: Name Type Description Default frame_indices Collection[int] The indices of the frames to extract samples from. None target_size tuple[int, int] The target size of the samples in the format (w, h). None save_folder str The folder path to save the samples. None name_format str The format of the sample names. 
None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk. None View Source def create_specified_samples ( self , frame_indices : Collection [ int ] , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates specified samples based on the given frame indices. Args: frame_indices (Collection[int]): The indices of the frames to extract samples from. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" bboxes = self . _bbox_calculator . calc_specified_boxes ( frame_indices = frame_indices , num_workers = num_workers , chunk_size = chunk_size , ) x , y , w , h = BoxUtils . unpack ( bboxes ) x -= np . random . randint ( 0 , target_size [ 0 ] - w + 1 ) y -= np . random . randint ( 0 , target_size [ 1 ] - h + 1 ) w = np . full_like ( x , target_size [ 0 ] ) h = np . full_like ( x , target_size [ 1 ] ) bboxes = BoxUtils . pack ( x , y , w , h ) frame_size = tuple ( reversed ( self . _frame_reader . frame_size )) # ( h , w ) -> ( w , h ) bboxes = self . move_bboxes_into_bounds ( bboxes , frame_size ) with FrameSaver ( self . _frame_reader , root_path = save_folder , desc = \"Saving samples\" , unit = \"fr\" ) as saver : for i , bbox in enumerate ( bboxes ) : saver . schedule_save ( i , bbox , name_format . format ( i )) move_bboxes_into_bounds def move_bboxes_into_bounds ( self , bboxes : numpy . ndarray , frame_size : tuple [ int , int ] ) -> numpy . ndarray Moves the bounding boxes into the bounds of the frame. Parameters: Name Type Description Default bboxes np.ndarray The bounding boxes to be moved. None frame_size tuple[int, int] The size of the frame in the format (w, h). None Returns: Type Description np.ndarray The updated bounding boxes. Raises: Type Description ValueError If there exists a bounding box which cannot be moved into the provided bounds without resizing it. View Source def move_bboxes_into_bounds ( self , bboxes : np . ndarray , frame_size : tuple [ int, int ] ) -> np . ndarray : \"\"\" Moves the bounding boxes into the bounds of the frame. Args: bboxes (np.ndarray): The bounding boxes to be moved. frame_size (tuple[int, int]): The size of the frame in the format (w, h). Returns: np.ndarray: The updated bounding boxes. Raises: ValueError: If there exists a bounding box which cannot be moved into the provided bounds without resizing it. \"\"\" max_w , max_h = frame_size x , y , w , h = BoxUtils . unpack ( bboxes ) x [ x < 0 ] = 0 mask = ( x + w ) > max_w x [ mask ] = max_w - w [ mask ] y [ y < 0 ] = 0 mask = ( y + h ) > max_h y [ mask ] = max_h - h [ mask ] if np . any ( x < 0 ) or np . any ( y < 0 ) : raise ValueError () if np . any ( x + w > frame_size [ 0 ] ) or np . any ( y + h > frame_size [ 1 ] ) : raise ValueError () return BoxUtils .
pack ( x , y , w , h )","title":"Sample Extractor"},{"location":"reference/wtracker/dataset/sample_extractor/#module-wtrackerdatasetsample_extractor","text":"View Source import numpy as np from typing import Collection from wtracker.dataset.box_calculator import BoxCalculator from wtracker.utils.bbox_utils import BoxUtils from wtracker.utils.io_utils import FrameSaver class SampleExtractor : \"\"\" A class that extracts samples from frames based on specified parameters. Each sample is a cropped image around a bounding box which was detected in the frame. The bounding boxes are calculated using the BoxCalculator class. This class is used to create image datasets for training object detection models. Args: bbox_calculator (BoxCalculator): An instance of the BoxCalculator class. \"\"\" def __init__ ( self , bbox_calculator : BoxCalculator ): self . _bbox_calculator = bbox_calculator self . _frame_reader = bbox_calculator . _frame_reader def move_bboxes_into_bounds ( self , bboxes : np . ndarray , frame_size : tuple [ int , int ]) -> np . ndarray : \"\"\" Moves the bounding boxes into the bounds of the frame. Args: bboxes (np.ndarray): The bounding boxes to be moved. frame_size (tuple[int, int]): The size of the frame in the format (w, h). Returns: np.ndarray: The updated bounding boxes. Raises: ValueError: If there exists a bounding box which cannot be moved into the provided bounds without resizing it. \"\"\" max_w , max_h = frame_size x , y , w , h = BoxUtils . unpack ( bboxes ) x [ x < 0 ] = 0 mask = ( x + w ) > max_w x [ mask ] = max_w - w [ mask ] y [ y < 0 ] = 0 mask = ( y + h ) > max_h y [ mask ] = max_h - h [ mask ] if np . any ( x < 0 ) or np . any ( y < 0 ): raise ValueError () if np . any ( x + w > frame_size [ 0 ]) or np . any ( y + h > frame_size [ 1 ]): raise ValueError () return BoxUtils . pack ( x , y , w , h ) def create_specified_samples ( self , frame_indices : Collection [ int ], target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_ {:09d} .png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates specified samples based on the given frame indices. Args: frame_indices (Collection[int]): The indices of the frames to extract samples from. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" bboxes = self . _bbox_calculator . calc_specified_boxes ( frame_indices = frame_indices , num_workers = num_workers , chunk_size = chunk_size , ) x , y , w , h = BoxUtils . unpack ( bboxes ) x -= np . random . randint ( 0 , target_size [ 0 ] - w + 1 ) y -= np . random . randint ( 0 , target_size [ 1 ] - h + 1 ) w = np . full_like ( x , target_size [ 0 ]) h = np . full_like ( x , target_size [ 1 ]) bboxes = BoxUtils . pack ( x , y , w , h ) frame_size = tuple ( reversed ( self . _frame_reader . frame_size )) # (h, w) -> (w, h) bboxes = self . move_bboxes_into_bounds ( bboxes , frame_size ) with FrameSaver ( self . _frame_reader , root_path = save_folder , desc = \"Saving samples\" , unit = \"fr\" ) as saver : for i , bbox in enumerate ( bboxes ): saver . schedule_save ( i , bbox , name_format . format ( i ))
def create_samples ( self , count : int , target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_ {:09d} .png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates random samples based on a specified count. Args: count (int): The number of samples to create. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk sent to each worker. \"\"\" length = len ( self . _frame_reader ) count = min ( length , count ) frame_indices = np . random . choice ( length , size = count , replace = False ) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size ) def create_all_samples ( self , target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_ {:09d} .png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates samples for all frames. Args: target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frame_indices = range ( 0 , len ( self . _frame_reader )) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size )","title":"Module wtracker.dataset.sample_extractor"},{"location":"reference/wtracker/dataset/sample_extractor/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/dataset/sample_extractor/#sampleextractor","text":"class SampleExtractor ( bbox_calculator : wtracker . dataset . box_calculator . BoxCalculator ) A class that extracts samples from frames based on specified parameters. Each sample is a cropped image around a bounding box which was detected in the frame. The bounding boxes are calculated using the BoxCalculator class. This class is used to create image datasets for training object detection models.
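For orientation, here is a minimal usage sketch, not taken from the library itself: the folder path is hypothetical, and box_calculator stands for a previously constructed BoxCalculator instance (see wtracker.dataset.box_calculator).
from wtracker.dataset.sample_extractor import SampleExtractor

# `box_calculator` is assumed to be an existing BoxCalculator.
extractor = SampleExtractor(box_calculator)

# Save 1000 randomly chosen 400x400 crops (one per sampled frame)
# into a dataset folder; worker count is chosen automatically.
extractor.create_samples(
    count=1000,
    target_size=(400, 400),  # (w, h)
    save_folder="datasets/worm_samples",  # hypothetical path
)
Use create_all_samples with the same target_size and save_folder arguments to extract a crop for every frame instead of a random subset.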
","title":"SampleExtractor"},{"location":"reference/wtracker/dataset/sample_extractor/#attributes","text":"Name Type Description Default bbox_calculator BoxCalculator An instance of the BoxCalculator class. None View Source class SampleExtractor : \"\"\" A class that extracts samples from frames based on specified parameters. Each sample is a cropped image around a bounding box which was detected in the frame. The bounding boxes are calculated using the BoxCalculator class. This class is used to create image datasets for training object detection models. Args: bbox_calculator (BoxCalculator): An instance of the BoxCalculator class. \"\"\" def __init__ ( self , bbox_calculator : BoxCalculator ) : self . _bbox_calculator = bbox_calculator self . _frame_reader = bbox_calculator . _frame_reader def move_bboxes_into_bounds ( self , bboxes : np . ndarray , frame_size : tuple [ int, int ] ) -> np . ndarray : \"\"\" Moves the bounding boxes into the bounds of the frame. Args: bboxes (np.ndarray): The bounding boxes to be moved. frame_size (tuple[int, int]): The size of the frame in the format (w, h). Returns: np.ndarray: The updated bounding boxes. Raises: ValueError: If there exists a bounding box which cannot be moved into the provided bounds without resizing it. \"\"\" max_w , max_h = frame_size x , y , w , h = BoxUtils . unpack ( bboxes ) x [ x < 0 ] = 0 mask = ( x + w ) > max_w x [ mask ] = max_w - w [ mask ] y [ y < 0 ] = 0 mask = ( y + h ) > max_h y [ mask ] = max_h - h [ mask ] if np . any ( x < 0 ) or np . any ( y < 0 ) : raise ValueError () if np . any ( x + w > frame_size [ 0 ] ) or np . any ( y + h > frame_size [ 1 ] ) : raise ValueError () return BoxUtils . pack ( x , y , w , h ) def create_specified_samples ( self , frame_indices : Collection [ int ] , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates specified samples based on the given frame indices. Args: frame_indices (Collection[int]): The indices of the frames to extract samples from. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" bboxes = self . _bbox_calculator . calc_specified_boxes ( frame_indices = frame_indices , num_workers = num_workers , chunk_size = chunk_size , ) x , y , w , h = BoxUtils . unpack ( bboxes ) x -= np . random . randint ( 0 , target_size [ 0 ] - w + 1 ) y -= np . random . randint ( 0 , target_size [ 1 ] - h + 1 ) w = np . full_like ( x , target_size [ 0 ] ) h = np . full_like ( x , target_size [ 1 ] ) bboxes = BoxUtils . pack ( x , y , w , h ) frame_size = tuple ( reversed ( self . _frame_reader . frame_size )) # ( h , w ) -> ( w , h ) bboxes = self . move_bboxes_into_bounds ( bboxes , frame_size ) with FrameSaver ( self . _frame_reader , root_path = save_folder , desc = \"Saving samples\" , unit = \"fr\" ) as saver : for i , bbox in enumerate ( bboxes ) : saver . schedule_save ( i , bbox , name_format . format ( i )) def create_samples ( self , count : int , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates random samples based on a specified count. Args: count (int): The number of samples to create. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk sent to each worker. \"\"\" length = len ( self . _frame_reader ) count = min ( length , count ) frame_indices = np . random . choice ( length , size = count , replace = False ) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size ) def create_all_samples ( self , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates samples for all frames.
Args: target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frame_indices = range ( 0 , len ( self . _frame_reader )) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size )","title":"Attributes"},{"location":"reference/wtracker/dataset/sample_extractor/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/dataset/sample_extractor/#create_all_samples","text":"def create_all_samples ( self , target_size : tuple [ int , int ], save_folder : str , name_format : str = 'img_ {:09d} .png' , num_workers : int = None , chunk_size : int = 50 ) Creates samples for all frames. Parameters: Name Type Description Default target_size tuple[int, int] The target size of the samples in the format (w, h). None save_folder str The folder path to save the samples. None name_format str The format of the sample names. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk. None View Source def create_all_samples( self, target_size: tuple[int, int], save_folder: str, name_format: str = \"img_{:09d}.png\", num_workers: int = None, chunk_size: int = 50, ): \"\"\" Creates samples for all frames. Args: target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frame_indices = range(0, len(self._frame_reader)) self.create_specified_samples(frame_indices, target_size, save_folder, name_format, num_workers, chunk_size)","title":"create_all_samples"},{"location":"reference/wtracker/dataset/sample_extractor/#create_samples","text":"def create_samples ( self , count : int , target_size : tuple [ int , int ], save_folder : str , name_format : str = 'img_ {:09d} .png' , num_workers : int = None , chunk_size : int = 50 ) Creates random samples based on a specified count. Parameters: Name Type Description Default count int The number of samples to create. None target_size tuple[int, int] The target size of the samples in the format (w, h). None save_folder str The folder path to save the samples. None name_format str The format of the sample names. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk sent to each worker. None View Source def create_samples ( self , count : int , target_size : tuple [ int , int ], save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ): \"\"\" Creates random samples based on a specified count. Args: count (int): The number of samples to create. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. 
name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk sent to each worker. \"\"\" length = len ( self . _frame_reader ) count = min ( length , count ) frame_indices = np . random . choice ( length , size = count , replace = False ) self . create_specified_samples ( frame_indices , target_size , save_folder , name_format , num_workers , chunk_size )","title":"create_samples"},{"location":"reference/wtracker/dataset/sample_extractor/#create_specified_samples","text":"def create_specified_samples ( self , frame_indices : Collection [ int ], target_size : tuple [ int , int ], save_folder : str , name_format : str = 'img_ {:09d} .png' , num_workers : int = None , chunk_size : int = 50 ) Creates specified samples based on the given frame indices. Parameters: Name Type Description Default frame_indices Collection[int] The indices of the frames to extract samples from. None target_size tuple[int, int] The target size of the samples in the format (w, h). None save_folder str The folder path to save the samples. None name_format str The format of the sample names. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk. None View Source def create_specified_samples ( self , frame_indices : Collection [ int ] , target_size : tuple [ int, int ] , save_folder : str , name_format : str = \"img_{:09d}.png\" , num_workers : int = None , chunk_size : int = 50 , ) : \"\"\" Creates specified samples based on the given frame indices. Args: frame_indices (Collection[int]): The indices of the frames to extract samples from. target_size (tuple[int, int]): The target size of the samples in the format (w, h). save_folder (str): The folder path to save the samples. name_format (str, optional): The format of the sample names. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" bboxes = self . _bbox_calculator . calc_specified_boxes ( frame_indices = frame_indices , num_workers = num_workers , chunk_size = chunk_size , ) x , y , w , h = BoxUtils . unpack ( bboxes ) x -= np . random . randint ( 0 , target_size [ 0 ] - w + 1 ) y -= np . random . randint ( 0 , target_size [ 1 ] - h + 1 ) w = np . full_like ( x , target_size [ 0 ] ) h = np . full_like ( x , target_size [ 1 ] ) bboxes = BoxUtils . pack ( x , y , w , h ) frame_size = tuple ( reversed ( self . _frame_reader . frame_size )) # ( h , w ) -> ( w , h ) bboxes = self . move_bboxes_into_bounds ( bboxes , frame_size ) with FrameSaver ( self . _frame_reader , root_path = save_folder , desc = \"Saving samples\" , unit = \"fr\" ) as saver : for i , bbox in enumerate ( bboxes ) : saver . schedule_save ( i , bbox , name_format . format ( i ))","title":"create_specified_samples"},{"location":"reference/wtracker/dataset/sample_extractor/#move_bboxes_into_bounds","text":"def move_bboxes_into_bounds ( self , bboxes : numpy . ndarray , frame_size : tuple [ int , int ] ) -> numpy . ndarray Moves the bounding boxes into the bounds of the frame. Parameters: Name Type Description Default bboxes np.ndarray The bounding boxes to be moved. 
None frame_size tuple[int, int] The size of the frame in the format (w, h). None Returns: Type Description np.ndarray The updated bounding boxes. Raises: Type Description ValueError If there exists a bounding box which cannot be moved into the provided bounds without resizing it. View Source def move_bboxes_into_bounds ( self , bboxes : np . ndarray , frame_size : tuple [ int, int ] ) -> np . ndarray : \"\"\" Moves the bounding boxes into the bounds of the frame. Args: bboxes (np.ndarray): The bounding boxes to be moved. frame_size (tuple[int, int]): The size of the frame in the format (w, h). Returns: np.ndarray: The updated bounding boxes. Raises: ValueError: If there exists a bounding box which cannot be moved into the provided bounds without resizing it. \"\"\" max_w , max_h = frame_size x , y , w , h = BoxUtils . unpack ( bboxes ) x [ x < 0 ] = 0 mask = ( x + w ) > max_w x [ mask ] = max_w - w [ mask ] y [ y < 0 ] = 0 mask = ( y + h ) > max_h y [ mask ] = max_h - h [ mask ] if np . any ( x < 0 ) or np . any ( y < 0 ) : raise ValueError () if np . any ( x + w > frame_size [ 0 ] ) or np . any ( y + h > frame_size [ 1 ] ) : raise ValueError () return BoxUtils . pack ( x , y , w , h )","title":"move_bboxes_into_bounds"},{"location":"reference/wtracker/eval/","text":"Module wtracker.eval View Source from wtracker.eval.plotter import Plotter from wtracker.eval.data_analyzer import DataAnalyzer from wtracker.eval.error_calculator import ErrorCalculator from wtracker.eval.vlc import VLC , StreamViewer , HotKey Sub-modules wtracker.eval.data_analyzer wtracker.eval.error_calculator wtracker.eval.plotter wtracker.eval.vlc","title":"Index"},{"location":"reference/wtracker/eval/#module-wtrackereval","text":"View Source from wtracker.eval.plotter import Plotter from wtracker.eval.data_analyzer import DataAnalyzer from wtracker.eval.error_calculator import ErrorCalculator from wtracker.eval.vlc import VLC , StreamViewer , HotKey","title":"Module wtracker.eval"},{"location":"reference/wtracker/eval/#sub-modules","text":"wtracker.eval.data_analyzer wtracker.eval.error_calculator wtracker.eval.plotter wtracker.eval.vlc","title":"Sub-modules"},{"location":"reference/wtracker/eval/data_analyzer/","text":"Module wtracker.eval.data_analyzer View Source from __future__ import annotations import pandas as pd import numpy as np import tqdm.contrib.concurrent as concurrent from wtracker.sim.config import TimingConfig from wtracker.eval.error_calculator import ErrorCalculator from wtracker.utils.frame_reader import FrameReader from wtracker.utils.threading_utils import adjust_num_workers class DataAnalyzer : \"\"\" A class for analyzing a simulation log. Args: time_config (TimingConfig): The timing configuration. log_data (pd.DataFrame): Dataframe containing the simulation log data. \"\"\" def __init__ ( self , time_config : TimingConfig , log_data : pd . DataFrame , ): self . time_config = time_config self . data = log_data . copy () self . _orig_data = log_data self . _unit = \"frame\" @property def unit ( self ) -> str : return self . _unit def save ( self , path : str ) -> None : \"\"\" Save the full analyzed data to a csv file. \"\"\" self . _orig_data . to_csv ( path , index = False ) @staticmethod def load ( time_config : TimingConfig , csv_path : str ) -> DataAnalyzer : \"\"\" Create a DataAnalyzer object from a csv file containing experiment data, regardless of whether it has been analyzed or not. Args: time_config (TimingConfig): The timing configuration.
csv_path (str): Path to the csv file containing the experiment data. \"\"\" data = pd . read_csv ( csv_path ) return DataAnalyzer ( time_config , data ) def initialize ( self , period : int = 10 ): \"\"\" Initializes the data analyzer. It is essential to call this function if the class was created from non-analyzed log data. Args: period (int): The period for calculating speed in frames. The speed is calculated by measuring the distance between the current frame and the frame period frames before it. \"\"\" data = self . _orig_data data [ \"time\" ] = data [ \"frame\" ] data [ \"cycle_step\" ] = data [ \"frame\" ] % self . time_config . cycle_frame_num data = DataAnalyzer . _calc_centers ( data ) data = DataAnalyzer . _calc_speed ( data , period ) data = DataAnalyzer . _calc_worm_deviation ( data ) data = DataAnalyzer . _calc_errors ( data ) data = data . round ( 5 ) self . _orig_data = data self . data = self . _orig_data . copy () @staticmethod def _calc_centers ( data : pd . DataFrame ) -> pd . DataFrame : data [ \"wrm_center_x\" ] = data [ \"wrm_x\" ] + data [ \"wrm_w\" ] / 2 data [ \"wrm_center_y\" ] = data [ \"wrm_y\" ] + data [ \"wrm_h\" ] / 2 data [ \"mic_center_x\" ] = data [ \"mic_x\" ] + data [ \"mic_w\" ] / 2 data [ \"mic_center_y\" ] = data [ \"mic_y\" ] + data [ \"mic_h\" ] / 2 return data @staticmethod def _calc_speed ( data : pd . DataFrame , n : int ) -> pd . DataFrame : diff = data [ \"time\" ] . diff ( n ) . to_numpy () data [ \"wrm_speed_x\" ] = data [ \"wrm_center_x\" ] . diff ( n ) / diff data [ \"wrm_speed_y\" ] = data [ \"wrm_center_y\" ] . diff ( n ) / diff data [ \"wrm_speed\" ] = np . sqrt ( data [ \"wrm_speed_x\" ] ** 2 + data [ \"wrm_speed_y\" ] ** 2 ) return data @staticmethod def _calc_worm_deviation ( data : pd . DataFrame ) -> pd . DataFrame : data [ \"worm_deviation_x\" ] = data [ \"wrm_center_x\" ] - data [ \"mic_center_x\" ] data [ \"worm_deviation_y\" ] = data [ \"wrm_center_y\" ] - data [ \"mic_center_y\" ] data [ \"worm_deviation\" ] = np . sqrt ( data [ \"worm_deviation_x\" ] ** 2 + data [ \"worm_deviation_y\" ] ** 2 ) return data @staticmethod def _calc_errors ( data : pd . DataFrame ) -> pd . DataFrame : wrm_bboxes = data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () bbox_error = ErrorCalculator . calculate_bbox_error ( wrm_bboxes , mic_bboxes ) data [ \"bbox_error\" ] = bbox_error data [ \"precise_error\" ] = np . nan return data def remove_cycle ( self , cycles : int | list [ int ]): \"\"\" Remove the specified cycles from the data. Args: cycles (int | list[int]): The cycle(s) to remove from the data. \"\"\" if isinstance ( cycles , int ): cycles = [ cycles ] mask = self . data [ \"cycle\" ] . isin ( cycles ) self . data = self . data [ ~ mask ] def clean ( self , trim_cycles : bool = False , imaging_only : bool = False , bounds : tuple [ float , float , float , float ] = None , ) -> None : \"\"\" Clean the data according to the provided parameters. Args: trim_cycles (bool): Whether to remove the first and the last cycles from the data. imaging_only (bool): Flag indicating whether to include only imaging phases in the analysis. bounds (tuple[float, float, float, float]): The legal bounds for worm movement. \"\"\" data = self . data if imaging_only : mask = data [ \"phase\" ] == \"imaging\" data = data [ mask ] if bounds is not None : has_pred = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) .
all ( axis = 1 ) mask_wrm = has_pred # if there is a prediction for a frame then look at worm bbox mask_wrm &= ( data [ \"wrm_x\" ] >= bounds [ 0 ]) & ( data [ \"wrm_x\" ] + data [ \"wrm_w\" ] <= bounds [ 2 ]) mask_wrm &= ( data [ \"wrm_y\" ] >= bounds [ 1 ]) & ( data [ \"wrm_y\" ] + data [ \"wrm_h\" ] <= bounds [ 3 ]) mask_mic = ~ has_pred # if there is no prediction for a frame then look at micro bbox mask_mic &= ( data [ \"mic_x\" ] >= bounds [ 0 ]) & ( data [ \"mic_x\" ] + data [ \"mic_w\" ] <= bounds [ 2 ]) mask_mic &= ( data [ \"mic_y\" ] >= bounds [ 1 ]) & ( data [ \"mic_y\" ] + data [ \"mic_h\" ] <= bounds [ 3 ]) data = data [ mask_wrm | mask_mic ] if trim_cycles : mask = data [ \"cycle\" ] != 0 mask &= data [ \"cycle\" ] != data [ \"cycle\" ] . max () data = data [ mask ] self . data = data def reset_changes ( self ): \"\"\" Reset the data to its original state. Note that this method will not reset the unit of time and distance. \"\"\" self . data = self . _orig_data . copy () self . _unit = \"frame\" def column_names ( self ) -> list [ str ]: \"\"\" Returns a list of all column names in the analyzed data. Returns: list[str]: A list of column names. \"\"\" return self . data . columns . to_list () def change_unit ( self , unit : str ): \"\"\" Changes the unit of time and distance in the data. Args: unit (str): The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometers. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. \"\"\" assert unit in [ \"frame\" , \"sec\" ] if self . _unit == unit : return data = self . data if unit == \"sec\" : # frame -> sec dist_factor = self . time_config . mm_per_px * 1000 time_factor = self . time_config . ms_per_frame / 1000 if unit == \"frame\" : # sec -> frame dist_factor = self . time_config . px_per_mm / 1000 time_factor = self . time_config . frames_per_sec data [ \"time\" ] *= time_factor data [[ \"plt_x\" , \"plt_y\" ]] *= dist_factor data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] *= dist_factor data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] *= dist_factor data [[ \"cam_x\" , \"cam_y\" , \"cam_w\" , \"cam_h\" ]] *= dist_factor data [[ \"wrm_center_x\" , \"wrm_center_y\" ]] *= dist_factor data [[ \"mic_center_x\" , \"mic_center_y\" ]] *= dist_factor data [[ \"worm_deviation_x\" , \"worm_deviation_y\" , \"worm_deviation\" ]] *= dist_factor data [[ \"wrm_speed_x\" , \"wrm_speed_y\" , \"wrm_speed\" ]] *= dist_factor / time_factor self . _unit = unit self . data = data # TODO: TEST # TODO: MAYBE REMOVE, THE non-multithreaded version works very fast for me for some reason # perhaps SSD is required for fast analysis. def calc_precise_error_experimental ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , num_workers : int = None , chunk_size : int = 2000 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of the worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground.
A foreground object is detected if the pixel value difference with the background is greater than this threshold. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( int , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = np . ones_like ( frames , dtype = float ) mask = np . isfinite ( wrm_bboxes ) . all ( axis = 1 ) wrm_bboxes = wrm_bboxes [ mask ] mic_bboxes = mic_bboxes [ mask ] frames = frames [ mask ] num_sections = len ( frames ) // chunk_size wrm_bboxes_list = np . array_split ( wrm_bboxes , num_sections , axis = 0 ) mic_bboxes_list = np . array_split ( mic_bboxes , num_sections , axis = 0 ) frames_list = np . array_split ( frames , num_sections ) # TODO: add non-multithreaded case whenever num_workers=0 num_workers = adjust_num_workers ( len ( frames ), chunk_size , num_workers ) def calc_error ( idx : int ) -> np . ndarray : return ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes_list [ idx ], mic_bboxes = mic_bboxes_list [ idx ], frame_nums = frames_list [ idx ], worm_reader = worm_reader , diff_thresh = diff_thresh , ) results = concurrent . thread_map ( calc_error , list ( range ( len ( wrm_bboxes_list ))), max_workers = num_workers , chunksize = 1 , desc = \"Extracting bboxes\" , unit = \"fr\" , leave = False , ) # set the error in the original data errors [ mask ] = np . concatenate ( results ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_precise_error ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of the worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( np . int32 , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes , mic_bboxes = mic_bboxes , frame_nums = frames , worm_reader = worm_reader , diff_thresh = diff_thresh , ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self .
data [ \"precise_error\" ] = errors [ idx ] def calc_anomalies ( self , no_preds : bool = True , min_bbox_error : float = np . inf , min_dist_error : float = np . inf , min_speed : float = np . inf , min_size : float = np . inf , remove_anomalies : bool = False , ) -> pd . DataFrame : \"\"\" Calculate anomalies in the data based on specified criteria. Args: no_preds (bool, optional): Flag indicating whether to consider instances with missing predictions. min_bbox_error (float, optional): Minimum bounding box error threshold to consider as anomaly. min_dist_error (float, optional): Minimum distance error threshold to consider as anomaly. min_speed (float, optional): Minimum speed threshold to consider as anomaly. min_size (float, optional): Minimum size threshold to consider as anomaly. remove_anomalies (bool, optional): Flag indicating whether to remove the anomalies from the data. Returns: pd.DataFrame: DataFrame containing the anomalies found in the data. \"\"\" data = self . data mask_speed = data [ \"wrm_speed\" ] >= min_speed mask_bbox_error = data [ \"bbox_error\" ] >= min_bbox_error mask_dist_error = data [ \"worm_deviation\" ] >= min_dist_error mask_worm_width = data [ \"wrm_w\" ] >= min_size mask_worm_height = data [ \"wrm_h\" ] >= min_size mask_no_preds = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) . all ( axis = 1 ) == False mask_no_preds = no_preds & mask_no_preds mask = mask_speed | mask_bbox_error | mask_dist_error | mask_worm_width | mask_worm_height | mask_no_preds anomalies = data [ mask ] . copy () anomalies [ \"speed_anomaly\" ] = mask_speed [ mask ] anomalies [ \"bbox_error_anomaly\" ] = mask_bbox_error [ mask ] anomalies [ \"dist_error_anomaly\" ] = mask_dist_error [ mask ] anomalies [ \"width_anomaly\" ] = mask_worm_width [ mask ] anomalies [ \"height_anomaly\" ] = mask_worm_height [ mask ] anomalies [ \"no_pred_anomaly\" ] = mask_no_preds [ mask ] if remove_anomalies : self . data = self . data [ ~ mask ] return anomalies def describe ( self , columns : list [ str ] = None , num : int = 3 , percentiles : list [ float ] = None ) -> pd . DataFrame : \"\"\" Generate descriptive statistics of the specified columns in the table containing the data. Args: columns (list[str], optional): List of column names to include in the analysis. If None, all columns will be included. num (int, optional): Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. percentiles (list[float], optional): List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. Returns: pd.DataFrame: A DataFrame containing the descriptive statistics of the specified columns. \"\"\" if columns is None : columns = self . column_names () if percentiles is None : percentiles = np . linspace ( start = 0 , stop = 1.0 , num = num + 2 )[ 1 : - 1 ] return self . data [ columns ] . describe ( percentiles ) def print_stats ( self ) -> None : \"\"\" Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Total Count of No Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. \"\"\" num_removed = len ( self . _orig_data . index ) - len ( self . data . 
Classes DataAnalyzer class DataAnalyzer ( time_config : 'TimingConfig' , log_data : 'pd.DataFrame' ) A class for analyzing a simulation log. Attributes Name Type Description Default time_config TimingConfig The timing configuration. None log_data pd.DataFrame Dataframe containing the simulation log data. None View Source class DataAnalyzer : \"\"\" A class for analyzing a simulation log. Args: time_config (TimingConfig): The timing configuration. log_data (pd.DataFrame): Dataframe containing the simulation log data. \"\"\" def __init__ ( self , time_config : TimingConfig , log_data : pd . DataFrame , ): self . time_config = time_config self . data = log_data . copy () self . _orig_data = log_data self . _unit = \"frame\" @ property def unit ( self ) -> str : return self . _unit def save ( self , path : str ) -> None : \"\"\" Save the full analyzed data to a csv file. \"\"\" self . _orig_data . to_csv ( path , index = False ) @ staticmethod def load ( time_config : TimingConfig , csv_path : str ) -> DataAnalyzer : \"\"\" Create a DataAnalyzer object from a csv file containing experiment data, regardless of whether it has been analyzed or not. Args: time_config (TimingConfig): The timing configuration. csv_path (str): Path to the csv file containing the experiment data. \"\"\" data = pd . read_csv ( csv_path ) return DataAnalyzer ( time_config , data ) def initialize ( self , period : int = 10 ): \"\"\" Initializes the data analyzer. It is essential to call this function if the class was created from non-analyzed log data. Args: period (int): The period for calculating speed in frames. The speed is calculated by measuring the distance between the current frame and the frame period frames before it. \"\"\" data = self . _orig_data data [ \"time\" ] = data [ \"frame\" ] data [ \"cycle_step\" ] = data [ \"frame\" ] % self . time_config . cycle_frame_num data = DataAnalyzer . _calc_centers ( data ) data = DataAnalyzer . _calc_speed ( data , period ) data = DataAnalyzer . _calc_worm_deviation ( data ) data = DataAnalyzer . _calc_errors ( data ) data = data . round ( 5 ) self . _orig_data = data self . data = self . _orig_data . copy () @ staticmethod def _calc_centers ( data : pd . DataFrame ) -> pd . DataFrame : data [ \"wrm_center_x\" ] = data [ \"wrm_x\" ] + data [ \"wrm_w\" ] / 2 data [ \"wrm_center_y\" ] = data [ \"wrm_y\" ] + data [ \"wrm_h\" ] / 2 data [ \"mic_center_x\" ] = data [ \"mic_x\" ] + data [ \"mic_w\" ] / 2 data [ \"mic_center_y\" ] = data [ \"mic_y\" ] + data [ \"mic_h\" ] / 2 return data @ staticmethod def _calc_speed ( data : pd . DataFrame , n : int ) -> pd . DataFrame : diff = data [ \"time\" ] . diff ( n ) . to_numpy () data [ \"wrm_speed_x\" ] = data [ \"wrm_center_x\" ] . diff ( n ) / diff data [ \"wrm_speed_y\" ] = data [ \"wrm_center_y\" ] . diff ( n ) / diff data [ \"wrm_speed\" ] = np .
sqrt ( data [ \"wrm_speed_x\" ] ** 2 + data [ \"wrm_speed_y\" ] ** 2 ) return data @ staticmethod def _calc_worm_deviation ( data : pd . DataFrame ) -> pd . DataFrame : data [ \"worm_deviation_x\" ] = data [ \"wrm_center_x\" ] - data [ \"mic_center_x\" ] data [ \"worm_deviation_y\" ] = data [ \"wrm_center_y\" ] - data [ \"mic_center_y\" ] data [ \"worm_deviation\" ] = np . sqrt ( data [ \"worm_deviation_x\" ] ** 2 + data [ \"worm_deviation_y\" ] ** 2 ) return data @ staticmethod def _calc_errors ( data : pd . DataFrame ) -> pd . DataFrame : wrm_bboxes = data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () bbox_error = ErrorCalculator . calculate_bbox_error ( wrm_bboxes , mic_bboxes ) data [ \"bbox_error\" ] = bbox_error data [ \"precise_error\" ] = np . nan return data def remove_cycle ( self , cycles : int | list [ int ]): \"\"\" Remove the specified cycles from the data. Args: cycles (int | list[int]): The cycle(s) to remove from the data. \"\"\" if isinstance ( cycles , int ): cycles = [ cycles ] mask = self . data [ \"cycle\" ] . isin ( cycles ) self . data = self . data [ ~ mask ] def clean ( self , trim_cycles : bool = False , imaging_only : bool = False , bounds : tuple [ float , float , float , float ] = None , ) -> None : \"\"\" Clean the data according to the provided parameters. Args: trim_cycles (bool): Whether to remove the first and the last cycles from the data. imaging_only (bool): Flag indicating whether to include only imaging phases in the analysis. bounds (tuple[float, float, float, float]): The legal bounds for worm movement. \"\"\" data = self . data if imaging_only : mask = data [ \"phase\" ] == \"imaging\" data = data [ mask ] if bounds is not None : has_pred = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) . all ( axis = 1 ) mask_wrm = has_pred # if there is a prediction for a frame then look at worm bbox mask_wrm &= ( data [ \"wrm_x\" ] >= bounds [ 0 ]) & ( data [ \"wrm_x\" ] + data [ \"wrm_w\" ] <= bounds [ 2 ]) mask_wrm &= ( data [ \"wrm_y\" ] >= bounds [ 1 ]) & ( data [ \"wrm_y\" ] + data [ \"wrm_h\" ] <= bounds [ 3 ]) mask_mic = ~ has_pred # if there is no prediction for a frame then look at micro bbox mask_mic &= ( data [ \"mic_x\" ] >= bounds [ 0 ]) & ( data [ \"mic_x\" ] + data [ \"mic_w\" ] <= bounds [ 2 ]) mask_mic &= ( data [ \"mic_y\" ] >= bounds [ 1 ]) & ( data [ \"mic_y\" ] + data [ \"mic_h\" ] <= bounds [ 3 ]) data = data [ mask_wrm | mask_mic ] if trim_cycles : mask = data [ \"cycle\" ] != 0 mask &= data [ \"cycle\" ] != data [ \"cycle\" ] . max () data = data [ mask ] self . data = data def reset_changes ( self ): \"\"\" Reset the data to its original state. Note that this method will not reset the unit of time and distance. \"\"\" self . data = self . _orig_data . copy () self . _unit = \"frame\" def column_names ( self ) -> list [ str ]: \"\"\" Returns a list of all column names in the analyzed data. Returns: list[str]: A list of column names. \"\"\" return self . data . columns . to_list () def change_unit ( self , unit : str ): \"\"\" Changes the unit of time and distance in the data. Args: unit (str): The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometers. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels.
\"\"\" assert unit in [ \"frame\" , \"sec\" ] if self . _unit == unit : return data = self . data if unit == \"sec\" : # frame -> sec dist_factor = self . time_config . mm_per_px * 1000 time_factor = self . time_config . ms_per_frame / 1000 if unit == \"frame\" : # sec -> frame dist_factor = self . time_config . px_per_mm / 1000 time_factor = self . time_config . frames_per_sec data [ \"time\" ] *= time_factor data [[ \"plt_x\" , \"plt_y\" ]] *= dist_factor data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] *= dist_factor data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] *= dist_factor data [[ \"cam_x\" , \"cam_y\" , \"cam_w\" , \"cam_h\" ]] *= dist_factor data [[ \"wrm_center_x\" , \"wrm_center_y\" ]] *= dist_factor data [[ \"mic_center_x\" , \"mic_center_y\" ]] *= dist_factor data [[ \"worm_deviation_x\" , \"worm_deviation_y\" , \"worm_deviation\" ]] *= dist_factor data [[ \"wrm_speed_x\" , \"wrm_speed_y\" , \"wrm_speed\" ]] *= dist_factor / time_factor self . _unit = unit self . data = data # TODO: TEST # TODO: MAYBE REMOVE, THE non-multithreaded version works very fast for me for some reason # perhaps SSD is required for fast analysis. def calc_precise_error_experimental ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , num_workers : int = None , chunk_size : int = 2000 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( int , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = np . ones_like ( frames , dtype = float ) mask = np . isfinite ( wrm_bboxes ) . all ( axis = 1 ) wrm_bboxes = wrm_bboxes [ mask ] mic_bboxes = mic_bboxes [ mask ] frames = frames [ mask ] num_sections = len ( frames ) // chunk_size wrm_bboxes_list = np . array_split ( wrm_bboxes , num_sections , axis = 0 ) mic_bboxes_list = np . array_split ( mic_bboxes , num_sections , axis = 0 ) frames_list = np . array_split ( frames , num_sections ) # TODO: add non-multithreaded case whenever num_workers=0 num_workers = adjust_num_workers ( len ( frames ), chunk_size , num_workers ) def calc_error ( idx : int ) -> np . ndarray : return ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes_list [ idx ], mic_bboxes = mic_bboxes_list [ idx ], frame_nums = frames_list [ idx ], worm_reader = worm_reader , diff_thresh = diff_thresh , ) results = concurrent . 
thread_map ( calc_error , list ( range ( len ( wrm_bboxes_list ))), max_workers = num_workers , chunksize = 1 , desc = \"Extracting bboxes\" , unit = \"fr\" , leave = False , ) # set the error in the original data errors [ mask ] = np . concatenate ( results ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_precise_error ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( np . int32 , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes , mic_bboxes = mic_bboxes , frame_nums = frames , worm_reader = worm_reader , diff_thresh = diff_thresh , ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_anomalies ( self , no_preds : bool = True , min_bbox_error : float = np . inf , min_dist_error : float = np . inf , min_speed : float = np . inf , min_size : float = np . inf , remove_anomalies : bool = False , ) -> pd . DataFrame : \"\"\" Calculate anomalies in the data based on specified criteria. Args: no_preds (bool, optional): Flag indicating whether to consider instances with missing predictions. min_bbox_error (float, optional): Minimum bounding box error threshold to consider as anomaly. min_dist_error (float, optional): Minimum distance error threshold to consider as anomaly. min_speed (float, optional): Minimum speed threshold to consider as anomaly. min_size (float, optional): Minimum size threshold to consider as anomaly. remove_anomalies (bool, optional): Flag indicating whether to remove the anomalies from the data. Returns: pd.DataFrame: DataFrame containing the anomalies found in the data. \"\"\" data = self . data mask_speed = data [ \"wrm_speed\" ] >= min_speed mask_bbox_error = data [ \"bbox_error\" ] >= min_bbox_error mask_dist_error = data [ \"worm_deviation\" ] >= min_dist_error mask_worm_width = data [ \"wrm_w\" ] >= min_size mask_worm_height = data [ \"wrm_h\" ] >= min_size mask_no_preds = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) . all ( axis = 1 ) == False mask_no_preds = no_preds & mask_no_preds mask = mask_speed | mask_bbox_error | mask_dist_error | mask_worm_width | mask_worm_height | mask_no_preds anomalies = data [ mask ] . 
copy () anomalies [ \"speed_anomaly\" ] = mask_speed [ mask ] anomalies [ \"bbox_error_anomaly\" ] = mask_bbox_error [ mask ] anomalies [ \"dist_error_anomaly\" ] = mask_dist_error [ mask ] anomalies [ \"width_anomaly\" ] = mask_worm_width [ mask ] anomalies [ \"height_anomaly\" ] = mask_worm_height [ mask ] anomalies [ \"no_pred_anomaly\" ] = mask_no_preds [ mask ] if remove_anomalies : self . data = self . data [ ~ mask ] return anomalies def describe ( self , columns : list [ str ] = None , num : int = 3 , percentiles : list [ float ] = None ) -> pd . DataFrame : \"\"\" Generate descriptive statistics of the specified columns in the table containing the data. Args: columns (list[str], optional): List of column names to include in the analysis. If None, all columns will be included. num (int, optional): Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. percentiles (list[float], optional): List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. Returns: pd.DataFrame: A DataFrame containing the descriptive statistics of the specified columns. \"\"\" if columns is None : columns = self . column_names () if percentiles is None : percentiles = np . linspace ( start = 0 , stop = 1.0 , num = num + 2 )[ 1 : - 1 ] return self . data [ columns ] . describe ( percentiles ) def print_stats ( self ) -> None : \"\"\" Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Count of No-Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. \"\"\" num_removed = len ( self . _orig_data . index ) - len ( self . data . index ) print ( f \"Count of Removed Frames: {num_removed} ({round(100 * num_removed / len(self._orig_data.index), 3)}%)\" ) no_preds = self . data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . isna () . any ( axis = 1 ) . sum () print ( f \"Count of No-Pred Frames: {no_preds} ({round(100 * no_preds / len(self.data.index), 3)}%)\" ) num_cycles = self . data [ \"cycle\" ] . nunique () print ( f \"Total Num of Cycles: {num_cycles}\" ) non_perfect = ( self . data [ \"bbox_error\" ] > 1e-7 ) . sum () / len ( self . data . index ) print ( f \"Non Perfect Predictions: {round(100 * non_perfect, 3)}%\" ) Static methods load def load ( time_config : 'TimingConfig' , csv_path : 'str' ) -> 'DataAnalyzer' Create a DataAnalyzer object from a csv file containing experiment data, regardless of whether it has been analyzed or not. Parameters: Name Type Description Default time_config TimingConfig The timing configuration. None csv_path str Path to the csv file containing the experiment data. None View Source @ staticmethod def load ( time_config : TimingConfig , csv_path : str ) -> DataAnalyzer : \"\"\" Create a DataAnalyzer object from a csv file containing experiment data, regardless of whether it has been analyzed or not. Args: time_config (TimingConfig): The timing configuration. csv_path (str): Path to the csv file containing the experiment data. \"\"\" data = pd .
read_csv ( csv_path ) return DataAnalyzer ( time_config , data ) Instance variables unit Methods calc_anomalies def calc_anomalies ( self , no_preds : 'bool' = True , min_bbox_error : 'float' = inf , min_dist_error : 'float' = inf , min_speed : 'float' = inf , min_size : 'float' = inf , remove_anomalies : 'bool' = False ) -> 'pd.DataFrame' Calculate anomalies in the data based on specified criteria. Parameters: Name Type Description Default no_preds bool Flag indicating whether to consider instances with missing predictions. None min_bbox_error float Minimum bounding box error threshold to consider as anomaly. None min_dist_error float Minimum distance error threshold to consider as anomaly. None min_speed float Minimum speed threshold to consider as anomaly. None min_size float Minimum size threshold to consider as anomaly. None remove_anomalies bool Flag indicating whether to remove the anomalies from the data. None Returns: Type Description pd.DataFrame DataFrame containing the anomalies found in the data. View Source def calc_anomalies ( self , no_preds : bool = True , min_bbox_error : float = np . inf , min_dist_error : float = np . inf , min_speed : float = np . inf , min_size : float = np . inf , remove_anomalies : bool = False , ) -> pd . DataFrame : \"\"\" Calculate anomalies in the data based on specified criteria. Args: no_preds (bool, optional): Flag indicating whether to consider instances with missing predictions. min_bbox_error (float, optional): Minimum bounding box error threshold to consider as anomaly. min_dist_error (float, optional): Minimum distance error threshold to consider as anomaly. min_speed (float, optional): Minimum speed threshold to consider as anomaly. min_size (float, optional): Minimum size threshold to consider as anomaly. remove_anomalies (bool, optional): Flag indicating whether to remove the anomalies from the data. Returns: pd.DataFrame: DataFrame containing the anomalies found in the data. \"\"\" data = self . data mask_speed = data [ \"wrm_speed\" ] >= min_speed mask_bbox_error = data [ \"bbox_error\" ] >= min_bbox_error mask_dist_error = data [ \"worm_deviation\" ] >= min_dist_error mask_worm_width = data [ \"wrm_w\" ] >= min_size mask_worm_height = data [ \"wrm_h\" ] >= min_size mask_no_preds = np . isfinite ( data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ()). all ( axis = 1 ) == False mask_no_preds = no_preds & mask_no_preds mask = mask_speed | mask_bbox_error | mask_dist_error | mask_worm_width | mask_worm_height | mask_no_preds anomalies = data [ mask ] . copy () anomalies [ \"speed_anomaly\" ] = mask_speed [ mask ] anomalies [ \"bbox_error_anomaly\" ] = mask_bbox_error [ mask ] anomalies [ \"dist_error_anomaly\" ] = mask_dist_error [ mask ] anomalies [ \"width_anomaly\" ] = mask_worm_width [ mask ] anomalies [ \"height_anomaly\" ] = mask_worm_height [ mask ] anomalies [ \"no_pred_anomaly\" ] = mask_no_preds [ mask ] if remove_anomalies : self . data = self . data [ ~mask ] return anomalies calc_precise_error def calc_precise_error ( self , worm_reader : 'FrameReader' , background : 'np.ndarray' , diff_thresh = 20 ) -> 'None' Calculate the precise error between the worm and the microscope view. This error is segmentation-based and measures the proportion of the worm's head that is outside the view of the microscope. Note that this calculation might take a while.
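A minimal usage sketch (illustrative only; the names time_config, reader, background and the csv path are assumptions, not part of the original docs):
analyzer = DataAnalyzer.load(time_config, \"experiment_log.csv\")  # load a simulation log
analyzer.initialize()  # required if the log was not analyzed before
analyzer.calc_precise_error(worm_reader=reader, background=background, diff_thresh=20)
Here reader is a FrameReader yielding the per-frame worm crops and background is the background image of the entire experiment; the call populates the \"precise_error\" column of the data.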
Parameters: Name Type Description Default worm_reader FrameReader Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. None background np.ndarray The background image of the entire experiment. None diff_thresh int Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. None View Source def calc_precise_error ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation-based and measures the proportion of the worm's head that is outside the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy (). astype ( np . int32 , copy = False ) wrm_bboxes = self . _orig_data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy () mic_bboxes = self . _orig_data [ [\"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\" ] ] . to_numpy () errors = ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes , mic_bboxes = mic_bboxes , frame_nums = frames , worm_reader = worm_reader , diff_thresh = diff_thresh , ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] calc_precise_error_experimental def calc_precise_error_experimental ( self , worm_reader : 'FrameReader' , background : 'np.ndarray' , diff_thresh = 20 , num_workers : 'int' = None , chunk_size : 'int' = 2000 ) -> 'None' Calculate the precise error between the worm and the microscope view. This error is segmentation-based and measures the proportion of the worm's head that is outside the view of the microscope. Note that this calculation might take a while.
Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy (). astype ( int , copy = False ) wrm_bboxes = self . _orig_data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy () mic_bboxes = self . _orig_data [ [\"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\" ] ] . to_numpy () errors = np . ones_like ( frames , dtype = float ) mask = np . isfinite ( wrm_bboxes ). all ( axis = 1 ) wrm_bboxes = wrm_bboxes [ mask ] mic_bboxes = mic_bboxes [ mask ] frames = frames [ mask ] num_sections = len ( frames ) // chunk_size wrm_bboxes_list = np . array_split ( wrm_bboxes , num_sections , axis = 0 ) mic_bboxes_list = np . array_split ( mic_bboxes , num_sections , axis = 0 ) frames_list = np . array_split ( frames , num_sections ) # TODO : add non - multithreaded case whenever num_workers = 0 num_workers = adjust_num_workers ( len ( frames ), chunk_size , num_workers ) def calc_error ( idx : int ) -> np . ndarray : return ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes_list [ idx ] , mic_bboxes = mic_bboxes_list [ idx ] , frame_nums = frames_list [ idx ] , worm_reader = worm_reader , diff_thresh = diff_thresh , ) results = concurrent . thread_map ( calc_error , list ( range ( len ( wrm_bboxes_list ))), max_workers = num_workers , chunksize = 1 , desc = \"Extracting bboxes\" , unit = \"fr\" , leave = False , ) # set the error in the original data errors [ mask ] = np . concatenate ( results ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] change_unit def change_unit ( self , unit : 'str' ) Changes the unit of time and distance in the data. Parameters: Name Type Description Default unit str The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometer. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. None View Source def change_unit(self, unit: str): \"\"\" Changes the unit of time and distance in the data. Args: unit (str, optional): The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometer. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. 
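Example (an illustrative sketch; 'analyzer' denotes an initialized DataAnalyzer and is not part of the original docstring): analyzer.change_unit(\"sec\") converts times to seconds and distances to micrometers; calling analyzer.change_unit(\"frame\") afterwards converts them back to frames and pixels.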
\"\"\" assert unit in [\"frame\", \"sec\"] if self._unit == unit: return data = self.data if unit == \"sec\": # frame -> sec dist_factor = self.time_config.mm_per_px * 1000 time_factor = self.time_config.ms_per_frame / 1000 if unit == \"frame\": # sec -> frame dist_factor = self.time_config.px_per_mm / 1000 time_factor = self.time_config.frames_per_sec data[\"time\"] * = time_factor data[[\"plt_x\", \"plt_y\"]] *= dist_factor data[[\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\"]] * = dist_factor data[[\"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\"]] *= dist_factor data[[\"cam_x\", \"cam_y\", \"cam_w\", \"cam_h\"]] * = dist_factor data[[\"wrm_center_x\", \"wrm_center_y\"]] *= dist_factor data[[\"mic_center_x\", \"mic_center_y\"]] * = dist_factor data[[\"worm_deviation_x\", \"worm_deviation_y\", \"worm_deviation\"]] *= dist_factor data[[\"wrm_speed_x\", \"wrm_speed_y\", \"wrm_speed\"]] * = dist_factor / time_factor self._unit = unit self.data = data clean def clean ( self , trim_cycles : 'bool' = False , imaging_only : 'bool' = False , bounds : 'tuple[float, float, float, float]' = None ) -> 'None' Clean the data by the provided parameters. Parameters: Name Type Description Default trim_cycles bool whether to remove the first and the last cycles from the data. None imaging_only bool Flag indicating whether to include only imaging phases in the analysis. None legal_bounds tuple[float, float, float, float] The legal bounds for worm movement. None View Source def clean ( self , trim_cycles : bool = False , imaging_only : bool = False , bounds : tuple [ float, float, float, float ] = None , ) -> None : \"\"\" Clean the data by the provided parameters. Args: trim_cycles (bool): whether to remove the first and the last cycles from the data. imaging_only (bool): Flag indicating whether to include only imaging phases in the analysis. legal_bounds (tuple[float, float, float, float]): The legal bounds for worm movement. \"\"\" data = self . data if imaging_only : mask = data [ \"phase\" ] == \"imaging\" data = data [ mask ] if bounds is not None : has_pred = np . isfinite ( data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ()). all ( axis = 1 ) mask_wrm = has_pred # if there is a prediction for a frame then look at worm bbox mask_wrm &= ( data [ \"wrm_x\" ] >= bounds [ 0 ] ) & ( data [ \"wrm_x\" ] + data [ \"wrm_w\" ] <= bounds [ 2 ] ) mask_wrm &= ( data [ \"wrm_y\" ] >= bounds [ 1 ] ) & ( data [ \"wrm_y\" ] + data [ \"wrm_h\" ] <= bounds [ 3 ] ) mask_mic = ~ has_pred # if there is no prediction for a frame then look at micro bbox mask_mic &= ( data [ \"mic_x\" ] >= bounds [ 0 ] ) & ( data [ \"mic_x\" ] + data [ \"mic_w\" ] <= bounds [ 2 ] ) mask_mic &= ( data [ \"mic_y\" ] >= bounds [ 1 ] ) & ( data [ \"mic_y\" ] + data [ \"mic_h\" ] <= bounds [ 3 ] ) data = data [ mask_wrm | mask_mic ] if trim_cycles : mask = data [ \"cycle\" ] != 0 mask &= data [ \"cycle\" ] != data [ \"cycle\" ] . max () data = data [ mask ] self . data = data column_names def column_names ( self ) -> 'list[str]' Returns a list of all column names in the analyzed data. Returns: Type Description list[str] A list of column names. View Source def column_names ( self ) -> list [ str ] : \"\"\" Returns a list of all column names in the analyzed data. Returns: list[str]: A list of column names. \"\"\" return self . data . columns . 
to_list () describe def describe ( self , columns : 'list[str]' = None , num : 'int' = 3 , percentiles : 'list[float]' = None ) -> 'pd.DataFrame' Generate descriptive statistics of the specified columns in the table containing the data. Parameters: Name Type Description Default columns list[str] List of column names to include in the analysis. If None, all columns will be included. None num int Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. None percentiles list[float] List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. None Returns: Type Description pd.DataFrame A DataFrame containing the descriptive statistics of the specified columns. View Source def describe ( self , columns : list [ str ] = None , num : int = 3 , percentiles : list [ float ] = None ) -> pd . DataFrame : \"\"\" Generate descriptive statistics of the specified columns in the table containing the data. Args: columns (list[str], optional): List of column names to include in the analysis. If None, all columns will be included. num (int, optional): Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. percentiles (list[float], optional): List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. Returns: pd.DataFrame: A DataFrame containing the descriptive statistics of the specified columns. \"\"\" if columns is None : columns = self . column_names () if percentiles is None : percentiles = np . linspace ( start = 0 , stop = 1.0 , num = num + 2 )[ 1 :- 1 ] return self . data [ columns ]. describe ( percentiles ) initialize def initialize ( self , period : 'int' = 10 ) Initializes the data analyzer. It's essential to call this function if the class was created from non-analyzed log data. Parameters: Name Type Description Default period int The period for calculating speed in frames. The speed is calculated by measuring the distance between the current frame and the frame 'period' frames before it. None View Source def initialize(self, period: int = 10): \"\"\" Initializes the data analyzer. It's essential to call this function if the class was created from non-analyzed log data. Args: period (int): The period for calculating speed in frames. The speed is calculated by measuring the distance between the current frame and the frame 'period' frames before it. \"\"\" data = self._orig_data data[\"time\"] = data[\"frame\"] data[\"cycle_step\"] = data[\"frame\"] % self.time_config.cycle_frame_num data = DataAnalyzer._calc_centers(data) data = DataAnalyzer._calc_speed(data, period) data = DataAnalyzer._calc_worm_deviation(data) data = DataAnalyzer._calc_errors(data) data = data.round(5) self._orig_data = data self.data = self._orig_data.copy() print_stats def print_stats ( self ) -> 'None' Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Count of No-Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. View Source def print_stats ( self ) -> None : \"\"\" Prints various statistics related to the data.
This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Count of No-Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. \"\"\" num_removed = len ( self . _orig_data . index ) - len ( self . data . index ) print ( f \"Count of Removed Frames: {num_removed} ({round(100 * num_removed / len(self._orig_data.index), 3)}%)\" ) no_preds = self . data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . isna () . any ( axis = 1 ) . sum () print ( f \"Count of No-Pred Frames: {no_preds} ({round(100 * no_preds / len(self.data.index), 3)}%)\" ) num_cycles = self . data [ \"cycle\" ] . nunique () print ( f \"Total Num of Cycles: {num_cycles}\" ) non_perfect = ( self . data [ \"bbox_error\" ] > 1e-7 ) . sum () / len ( self . data . index ) print ( f \"Non Perfect Predictions: {round(100 * non_perfect, 3)}%\" ) remove_cycle def remove_cycle ( self , cycles : 'int | list[int]' ) Remove the specified cycles from the data. Parameters: Name Type Description Default cycles int | list[int] The cycle(s) to remove from the data. View Source def remove_cycle ( self , cycles : int | list [ int ] ) : \"\"\" Remove the specified cycles from the data. Args: cycles (int | list[int]): The cycle(s) to remove from the data. \"\"\" if isinstance ( cycles , int ) : cycles = [ cycles ] mask = self . data [ \"cycle\" ] . isin ( cycles ) self . data = self . data [ ~mask ] reset_changes def reset_changes ( self ) Reset the data to its original state. Note that this method will not reset the unit of time and distance. View Source def reset_changes(self): \"\"\" Reset the data to its original state. Note that this method will not reset the unit of time and distance. \"\"\" self.data = self._orig_data.copy() self._unit = \"frame\" save def save ( self , path : 'str' ) -> 'None' Save the full analyzed data to a csv file. View Source def save ( self , path : str ) -> None : \"\"\" Save the full analyzed data to a csv file. \"\"\" self . _orig_data . to_csv ( path , index = False )\",\"title\":\"Data Analyzer\"},{\"location\":\"reference/wtracker/eval/data_analyzer/#module-wtrackerevaldata_analyzer\",\"text\":\"View Source from __future__ import annotations import pandas as pd import numpy as np import tqdm.contrib.concurrent as concurrent from wtracker.sim.config import TimingConfig from wtracker.eval.error_calculator import ErrorCalculator from wtracker.utils.frame_reader import FrameReader from wtracker.utils.threading_utils import adjust_num_workers class DataAnalyzer : \"\"\" A class for analyzing a simulation log. Args: time_config (TimingConfig): The timing configuration. log_data (pd.DataFrame): Dataframe containing the simulation log data. \"\"\" def __init__ ( self , time_config : TimingConfig , log_data : pd . DataFrame , ): self . time_config = time_config self . data = log_data . copy () self . _orig_data = log_data self . _unit = \"frame\" @property def unit ( self ) -> str : return self . _unit def save ( self , path : str ) -> None : \"\"\" Save the full analyzed data to a csv file. \"\"\" self . _orig_data . to_csv ( path , index = False ) @staticmethod def load ( time_config : TimingConfig , csv_path : str ) -> DataAnalyzer : \"\"\" Create a DataAnalyzer object from a csv file containing experiment data, regardless of whether it has been analyzed or not.
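For example (an illustrative sketch; the csv path is an assumption): analyzer = DataAnalyzer.load(time_config, \"experiment_log.csv\"); if the loaded log was not analyzed before, call analyzer.initialize() afterwards.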
Args: time_config (TimingConfig): The timing configuration. csv_path (str): Path to the csv file containing the experiment data. \"\"\" data = pd . read_csv ( csv_path ) return DataAnalyzer ( time_config , data ) def initialize ( self , period : int = 10 ): \"\"\" Initializes the data analyzer. It's essential to call this function if the class was created from a non-analyzed log data. Args: period (int): The period for calculating speed in frames. The speed is calculated by measuring the distance between current frame and period frames before. \"\"\" data = self . _orig_data data [ \"time\" ] = data [ \"frame\" ] data [ \"cycle_step\" ] = data [ \"frame\" ] % self . time_config . cycle_frame_num data = DataAnalyzer . _calc_centers ( data ) data = DataAnalyzer . _calc_speed ( data , period ) data = DataAnalyzer . _calc_worm_deviation ( data ) data = DataAnalyzer . _calc_errors ( data ) data = data . round ( 5 ) self . _orig_data = data self . data = self . _orig_data . copy () @staticmethod def _calc_centers ( data : pd . DataFrame ) -> pd . DataFrame : data [ \"wrm_center_x\" ] = data [ \"wrm_x\" ] + data [ \"wrm_w\" ] / 2 data [ \"wrm_center_y\" ] = data [ \"wrm_y\" ] + data [ \"wrm_h\" ] / 2 data [ \"mic_center_x\" ] = data [ \"mic_x\" ] + data [ \"mic_w\" ] / 2 data [ \"mic_center_y\" ] = data [ \"mic_y\" ] + data [ \"mic_h\" ] / 2 return data @staticmethod def _calc_speed ( data : pd . DataFrame , n : int ) -> pd . DataFrame : diff = data [ \"time\" ] . diff ( n ) . to_numpy () data [ \"wrm_speed_x\" ] = data [ \"wrm_center_x\" ] . diff ( n ) / diff data [ \"wrm_speed_y\" ] = data [ \"wrm_center_y\" ] . diff ( n ) / diff data [ \"wrm_speed\" ] = np . sqrt ( data [ \"wrm_speed_x\" ] ** 2 + data [ \"wrm_speed_y\" ] ** 2 ) return data @staticmethod def _calc_worm_deviation ( data : pd . DataFrame ) -> pd . DataFrame : data [ \"worm_deviation_x\" ] = data [ \"wrm_center_x\" ] - data [ \"mic_center_x\" ] data [ \"worm_deviation_y\" ] = data [ \"wrm_center_y\" ] - data [ \"mic_center_y\" ] data [ \"worm_deviation\" ] = np . sqrt ( data [ \"worm_deviation_x\" ] ** 2 + data [ \"worm_deviation_y\" ] ** 2 ) return data @staticmethod def _calc_errors ( data : pd . DataFrame ) -> pd . DataFrame : wrm_bboxes = data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () bbox_error = ErrorCalculator . calculate_bbox_error ( wrm_bboxes , mic_bboxes ) data [ \"bbox_error\" ] = bbox_error data [ \"precise_error\" ] = np . nan return data def remove_cycle ( self , cycles : int | list [ int ]): \"\"\" Remove the specified cycles from the data. Args: cycles (int | list[int]): The cycle(s) to remove from the data. \"\"\" if isinstance ( cycles , int ): cycles = [ cycles ] mask = self . data [ \"cycle\" ] . isin ( cycles ) self . data = self . data [ ~ mask ] def clean ( self , trim_cycles : bool = False , imaging_only : bool = False , bounds : tuple [ float , float , float , float ] = None , ) -> None : \"\"\" Clean the data by the provided parameters. Args: trim_cycles (bool): whether to remove the first and the last cycles from the data. imaging_only (bool): Flag indicating whether to include only imaging phases in the analysis. legal_bounds (tuple[float, float, float, float]): The legal bounds for worm movement. \"\"\" data = self . data if imaging_only : mask = data [ \"phase\" ] == \"imaging\" data = data [ mask ] if bounds is not None : has_pred = np . 
isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) . all ( axis = 1 ) mask_wrm = has_pred # if there is a prediction for a frame then look at worm bbox mask_wrm &= ( data [ \"wrm_x\" ] >= bounds [ 0 ]) & ( data [ \"wrm_x\" ] + data [ \"wrm_w\" ] <= bounds [ 2 ]) mask_wrm &= ( data [ \"wrm_y\" ] >= bounds [ 1 ]) & ( data [ \"wrm_y\" ] + data [ \"wrm_h\" ] <= bounds [ 3 ]) mask_mic = ~ has_pred # if there is no prediction for a frame then look at micro bbox mask_mic &= ( data [ \"mic_x\" ] >= bounds [ 0 ]) & ( data [ \"mic_x\" ] + data [ \"mic_w\" ] <= bounds [ 2 ]) mask_mic &= ( data [ \"mic_y\" ] >= bounds [ 1 ]) & ( data [ \"mic_y\" ] + data [ \"mic_h\" ] <= bounds [ 3 ]) data = data [ mask_wrm | mask_mic ] if trim_cycles : mask = data [ \"cycle\" ] != 0 mask &= data [ \"cycle\" ] != data [ \"cycle\" ] . max () data = data [ mask ] self . data = data def reset_changes ( self ): \"\"\" Reset the data to its original state. Note, that this method will not reset the unit of time and distance. \"\"\" self . data = self . _orig_data . copy () self . _unit = \"frame\" def column_names ( self ) -> list [ str ]: \"\"\" Returns a list of all column names in the analyzed data. Returns: list[str]: A list of column names. \"\"\" return self . data . columns . to_list () def change_unit ( self , unit : str ): \"\"\" Changes the unit of time and distance in the data. Args: unit (str, optional): The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometer. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. \"\"\" assert unit in [ \"frame\" , \"sec\" ] if self . _unit == unit : return data = self . data if unit == \"sec\" : # frame -> sec dist_factor = self . time_config . mm_per_px * 1000 time_factor = self . time_config . ms_per_frame / 1000 if unit == \"frame\" : # sec -> frame dist_factor = self . time_config . px_per_mm / 1000 time_factor = self . time_config . frames_per_sec data [ \"time\" ] *= time_factor data [[ \"plt_x\" , \"plt_y\" ]] *= dist_factor data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] *= dist_factor data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] *= dist_factor data [[ \"cam_x\" , \"cam_y\" , \"cam_w\" , \"cam_h\" ]] *= dist_factor data [[ \"wrm_center_x\" , \"wrm_center_y\" ]] *= dist_factor data [[ \"mic_center_x\" , \"mic_center_y\" ]] *= dist_factor data [[ \"worm_deviation_x\" , \"worm_deviation_y\" , \"worm_deviation\" ]] *= dist_factor data [[ \"wrm_speed_x\" , \"wrm_speed_y\" , \"wrm_speed\" ]] *= dist_factor / time_factor self . _unit = unit self . data = data # TODO: TEST # TODO: MAYBE REMOVE, THE non-multithreaded version works very fast for me for some reason # perhaps SSD is required for fast analysis. def calc_precise_error_experimental ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , num_workers : int = None , chunk_size : int = 2000 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. 
diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( int , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = np . ones_like ( frames , dtype = float ) mask = np . isfinite ( wrm_bboxes ) . all ( axis = 1 ) wrm_bboxes = wrm_bboxes [ mask ] mic_bboxes = mic_bboxes [ mask ] frames = frames [ mask ] num_sections = len ( frames ) // chunk_size wrm_bboxes_list = np . array_split ( wrm_bboxes , num_sections , axis = 0 ) mic_bboxes_list = np . array_split ( mic_bboxes , num_sections , axis = 0 ) frames_list = np . array_split ( frames , num_sections ) # TODO: add non-multithreaded case whenever num_workers=0 num_workers = adjust_num_workers ( len ( frames ), chunk_size , num_workers ) def calc_error ( idx : int ) -> np . ndarray : return ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes_list [ idx ], mic_bboxes = mic_bboxes_list [ idx ], frame_nums = frames_list [ idx ], worm_reader = worm_reader , diff_thresh = diff_thresh , ) results = concurrent . thread_map ( calc_error , list ( range ( len ( wrm_bboxes_list ))), max_workers = num_workers , chunksize = 1 , desc = \"Extracting bboxes\" , unit = \"fr\" , leave = False , ) # set the error in the original data errors [ mask ] = np . concatenate ( results ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_precise_error ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( np . int32 , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes , mic_bboxes = mic_bboxes , frame_nums = frames , worm_reader = worm_reader , diff_thresh = diff_thresh , ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . 
to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_anomalies ( self , no_preds : bool = True , min_bbox_error : float = np . inf , min_dist_error : float = np . inf , min_speed : float = np . inf , min_size : float = np . inf , remove_anomalies : bool = False , ) -> pd . DataFrame : \"\"\" Calculate anomalies in the data based on specified criteria. Args: no_preds (bool, optional): Flag indicating whether to consider instances with missing predictions. min_bbox_error (float, optional): Minimum bounding box error threshold to consider as anomaly. min_dist_error (float, optional): Minimum distance error threshold to consider as anomaly. min_speed (float, optional): Minimum speed threshold to consider as anomaly. min_size (float, optional): Minimum size threshold to consider as anomaly. remove_anomalies (bool, optional): Flag indicating whether to remove the anomalies from the data. Returns: pd.DataFrame: DataFrame containing the anomalies found in the data. \"\"\" data = self . data mask_speed = data [ \"wrm_speed\" ] >= min_speed mask_bbox_error = data [ \"bbox_error\" ] >= min_bbox_error mask_dist_error = data [ \"worm_deviation\" ] >= min_dist_error mask_worm_width = data [ \"wrm_w\" ] >= min_size mask_worm_height = data [ \"wrm_h\" ] >= min_size mask_no_preds = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) . all ( axis = 1 ) == False mask_no_preds = no_preds & mask_no_preds mask = mask_speed | mask_bbox_error | mask_dist_error | mask_worm_width | mask_worm_height | mask_no_preds anomalies = data [ mask ] . copy () anomalies [ \"speed_anomaly\" ] = mask_speed [ mask ] anomalies [ \"bbox_error_anomaly\" ] = mask_bbox_error [ mask ] anomalies [ \"dist_error_anomaly\" ] = mask_dist_error [ mask ] anomalies [ \"width_anomaly\" ] = mask_worm_width [ mask ] anomalies [ \"height_anomaly\" ] = mask_worm_height [ mask ] anomalies [ \"no_pred_anomaly\" ] = mask_no_preds [ mask ] if remove_anomalies : self . data = self . data [ ~ mask ] return anomalies def describe ( self , columns : list [ str ] = None , num : int = 3 , percentiles : list [ float ] = None ) -> pd . DataFrame : \"\"\" Generate descriptive statistics of the specified columns in the table containing the data. Args: columns (list[str], optional): List of column names to include in the analysis. If None, all columns will be included. num (int, optional): Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. percentiles (list[float], optional): List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. Returns: pd.DataFrame: A DataFrame containing the descriptive statistics of the specified columns. \"\"\" if columns is None : columns = self . column_names () if percentiles is None : percentiles = np . linspace ( start = 0 , stop = 1.0 , num = num + 2 )[ 1 : - 1 ] return self . data [ columns ] . describe ( percentiles ) def print_stats ( self ) -> None : \"\"\" Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Total Count of No Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. \"\"\" num_removed = len ( self . _orig_data . 
index ) - len ( self . data . index ) print ( f \"Count of Removed Frames: { num_removed } ( { round ( 100 * num_removed / len ( self . _orig_data . index ), 3 ) } %)\" ) no_preds = self . data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . isna () . any ( axis = 1 ) . sum () print ( f \"Count of No-Pred Frames: { no_preds } ( { round ( 100 * no_preds / len ( self . data . index ), 3 ) } %)\" ) num_cycles = self . data [ \"cycle\" ] . nunique () print ( f \"Total Num of Cycles: { num_cycles } \" ) non_perfect = ( self . data [ \"bbox_error\" ] > 1e-7 ) . sum () / len ( self . data . index ) print ( f \"Non Perfect Predictions: { round ( 100 * non_perfect , 3 ) } %\" )\",\"title\":\"Module wtracker.eval.data_analyzer\"},{\"location\":\"reference/wtracker/eval/data_analyzer/#classes\",\"text\":\"\",\"title\":\"Classes\"},{\"location\":\"reference/wtracker/eval/data_analyzer/#dataanalyzer\",\"text\":\"class DataAnalyzer ( time_config : 'TimingConfig' , log_data : 'pd.DataFrame' ) A class for analyzing a simulation log.\",\"title\":\"DataAnalyzer\"},{\"location\":\"reference/wtracker/eval/data_analyzer/#attributes\",\"text\":\"Name Type Description Default time_config TimingConfig The timing configuration. None log_data pd.DataFrame Dataframe containing the simulation log data. None View Source class DataAnalyzer : \"\"\" A class for analyzing a simulation log. Args: time_config (TimingConfig): The timing configuration. log_data (pd.DataFrame): Dataframe containing the simulation log data. \"\"\" def __init__ ( self , time_config : TimingConfig , log_data : pd . DataFrame , ): self . time_config = time_config self . data = log_data . copy () self . _orig_data = log_data self . _unit = \"frame\" @ property def unit ( self ) -> str : return self . _unit def save ( self , path : str ) -> None : \"\"\" Save the full analyzed data to a csv file. \"\"\" self . _orig_data . to_csv ( path , index = False ) @ staticmethod def load ( time_config : TimingConfig , csv_path : str ) -> DataAnalyzer : \"\"\" Create a DataAnalyzer object from a csv file containing experiment data, regardless of whether it has been analyzed or not. Args: time_config (TimingConfig): The timing configuration. csv_path (str): Path to the csv file containing the experiment data. \"\"\" data = pd . read_csv ( csv_path ) return DataAnalyzer ( time_config , data ) def initialize ( self , period : int = 10 ): \"\"\" Initializes the data analyzer. It's essential to call this function if the class was created from non-analyzed log data. Args: period (int): The period for calculating speed in frames. The speed is calculated by measuring the distance between the current frame and the frame 'period' frames before it. \"\"\" data = self . _orig_data data [ \"time\" ] = data [ \"frame\" ] data [ \"cycle_step\" ] = data [ \"frame\" ] % self . time_config . cycle_frame_num data = DataAnalyzer . _calc_centers ( data ) data = DataAnalyzer . _calc_speed ( data , period ) data = DataAnalyzer . _calc_worm_deviation ( data ) data = DataAnalyzer . _calc_errors ( data ) data = data . round ( 5 ) self . _orig_data = data self . data = self . _orig_data . copy () @ staticmethod def _calc_centers ( data : pd . DataFrame ) -> pd . DataFrame : data [ \"wrm_center_x\" ] = data [ \"wrm_x\" ] + data [ \"wrm_w\" ] / 2 data [ \"wrm_center_y\" ] = data [ \"wrm_y\" ] + data [ \"wrm_h\" ] / 2 data [ \"mic_center_x\" ] = data [ \"mic_x\" ] + data [ \"mic_w\" ] / 2 data [ \"mic_center_y\" ] = data [ \"mic_y\" ] + data [ \"mic_h\" ] / 2 return data @ staticmethod def _calc_speed ( data : pd . DataFrame , n : int ) -> pd .
DataFrame : diff = data [ \"time\" ] . diff ( n ) . to_numpy () data [ \"wrm_speed_x\" ] = data [ \"wrm_center_x\" ] . diff ( n ) / diff data [ \"wrm_speed_y\" ] = data [ \"wrm_center_y\" ] . diff ( n ) / diff data [ \"wrm_speed\" ] = np . sqrt ( data [ \"wrm_speed_x\" ] ** 2 + data [ \"wrm_speed_y\" ] ** 2 ) return data @ staticmethod def _calc_worm_deviation ( data : pd . DataFrame ) -> pd . DataFrame : data [ \"worm_deviation_x\" ] = data [ \"wrm_center_x\" ] - data [ \"mic_center_x\" ] data [ \"worm_deviation_y\" ] = data [ \"wrm_center_y\" ] - data [ \"mic_center_y\" ] data [ \"worm_deviation\" ] = np . sqrt ( data [ \"worm_deviation_x\" ] ** 2 + data [ \"worm_deviation_y\" ] ** 2 ) return data @ staticmethod def _calc_errors ( data : pd . DataFrame ) -> pd . DataFrame : wrm_bboxes = data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () bbox_error = ErrorCalculator . calculate_bbox_error ( wrm_bboxes , mic_bboxes ) data [ \"bbox_error\" ] = bbox_error data [ \"precise_error\" ] = np . nan return data def remove_cycle ( self , cycles : int | list [ int ]): \"\"\" Remove the specified cycles from the data. Args: cycles (int | list[int]): The cycle(s) to remove from the data. \"\"\" if isinstance ( cycles , int ): cycles = [ cycles ] mask = self . data [ \"cycle\" ] . isin ( cycles ) self . data = self . data [ ~ mask ] def clean ( self , trim_cycles : bool = False , imaging_only : bool = False , bounds : tuple [ float , float , float , float ] = None , ) -> None : \"\"\" Clean the data by the provided parameters. Args: trim_cycles (bool): whether to remove the first and the last cycles from the data. imaging_only (bool): Flag indicating whether to include only imaging phases in the analysis. legal_bounds (tuple[float, float, float, float]): The legal bounds for worm movement. \"\"\" data = self . data if imaging_only : mask = data [ \"phase\" ] == \"imaging\" data = data [ mask ] if bounds is not None : has_pred = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ()) . all ( axis = 1 ) mask_wrm = has_pred # if there is a prediction for a frame then look at worm bbox mask_wrm &= ( data [ \"wrm_x\" ] >= bounds [ 0 ]) & ( data [ \"wrm_x\" ] + data [ \"wrm_w\" ] <= bounds [ 2 ]) mask_wrm &= ( data [ \"wrm_y\" ] >= bounds [ 1 ]) & ( data [ \"wrm_y\" ] + data [ \"wrm_h\" ] <= bounds [ 3 ]) mask_mic = ~ has_pred # if there is no prediction for a frame then look at micro bbox mask_mic &= ( data [ \"mic_x\" ] >= bounds [ 0 ]) & ( data [ \"mic_x\" ] + data [ \"mic_w\" ] <= bounds [ 2 ]) mask_mic &= ( data [ \"mic_y\" ] >= bounds [ 1 ]) & ( data [ \"mic_y\" ] + data [ \"mic_h\" ] <= bounds [ 3 ]) data = data [ mask_wrm | mask_mic ] if trim_cycles : mask = data [ \"cycle\" ] != 0 mask &= data [ \"cycle\" ] != data [ \"cycle\" ] . max () data = data [ mask ] self . data = data def reset_changes ( self ): \"\"\" Reset the data to its original state. Note, that this method will not reset the unit of time and distance. \"\"\" self . data = self . _orig_data . copy () self . _unit = \"frame\" def column_names ( self ) -> list [ str ]: \"\"\" Returns a list of all column names in the analyzed data. Returns: list[str]: A list of column names. \"\"\" return self . data . columns . to_list () def change_unit ( self , unit : str ): \"\"\" Changes the unit of time and distance in the data. Args: unit (str, optional): The new unit of time to convert into. 
Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometer. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. \"\"\" assert unit in [ \"frame\" , \"sec\" ] if self . _unit == unit : return data = self . data if unit == \"sec\" : # frame -> sec dist_factor = self . time_config . mm_per_px * 1000 time_factor = self . time_config . ms_per_frame / 1000 if unit == \"frame\" : # sec -> frame dist_factor = self . time_config . px_per_mm / 1000 time_factor = self . time_config . frames_per_sec data [ \"time\" ] *= time_factor data [[ \"plt_x\" , \"plt_y\" ]] *= dist_factor data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] *= dist_factor data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] *= dist_factor data [[ \"cam_x\" , \"cam_y\" , \"cam_w\" , \"cam_h\" ]] *= dist_factor data [[ \"wrm_center_x\" , \"wrm_center_y\" ]] *= dist_factor data [[ \"mic_center_x\" , \"mic_center_y\" ]] *= dist_factor data [[ \"worm_deviation_x\" , \"worm_deviation_y\" , \"worm_deviation\" ]] *= dist_factor data [[ \"wrm_speed_x\" , \"wrm_speed_y\" , \"wrm_speed\" ]] *= dist_factor / time_factor self . _unit = unit self . data = data # TODO: TEST # TODO: MAYBE REMOVE, THE non-multithreaded version works very fast for me for some reason # perhaps SSD is required for fast analysis. def calc_precise_error_experimental ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , num_workers : int = None , chunk_size : int = 2000 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( int , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = np . ones_like ( frames , dtype = float ) mask = np . isfinite ( wrm_bboxes ) . all ( axis = 1 ) wrm_bboxes = wrm_bboxes [ mask ] mic_bboxes = mic_bboxes [ mask ] frames = frames [ mask ] num_sections = len ( frames ) // chunk_size wrm_bboxes_list = np . array_split ( wrm_bboxes , num_sections , axis = 0 ) mic_bboxes_list = np . array_split ( mic_bboxes , num_sections , axis = 0 ) frames_list = np . array_split ( frames , num_sections ) # TODO: add non-multithreaded case whenever num_workers=0 num_workers = adjust_num_workers ( len ( frames ), chunk_size , num_workers ) def calc_error ( idx : int ) -> np . ndarray : return ErrorCalculator . 
calculate_precise ( background = background , worm_bboxes = wrm_bboxes_list [ idx ], mic_bboxes = mic_bboxes_list [ idx ], frame_nums = frames_list [ idx ], worm_reader = worm_reader , diff_thresh = diff_thresh , ) results = concurrent . thread_map ( calc_error , list ( range ( len ( wrm_bboxes_list ))), max_workers = num_workers , chunksize = 1 , desc = \"Extracting bboxes\" , unit = \"fr\" , leave = False , ) # set the error in the original data errors [ mask ] = np . concatenate ( results ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_precise_error ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy () . astype ( np . int32 , copy = False ) wrm_bboxes = self . _orig_data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy () mic_bboxes = self . _orig_data [[ \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" ]] . to_numpy () errors = ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes , mic_bboxes = mic_bboxes , frame_nums = frames , worm_reader = worm_reader , diff_thresh = diff_thresh , ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ] def calc_anomalies ( self , no_preds : bool = True , min_bbox_error : float = np . inf , min_dist_error : float = np . inf , min_speed : float = np . inf , min_size : float = np . inf , remove_anomalies : bool = False , ) -> pd . DataFrame : \"\"\" Calculate anomalies in the data based on specified criteria. Args: no_preds (bool, optional): Flag indicating whether to consider instances with missing predictions. min_bbox_error (float, optional): Minimum bounding box error threshold to consider as anomaly. min_dist_error (float, optional): Minimum distance error threshold to consider as anomaly. min_speed (float, optional): Minimum speed threshold to consider as anomaly. min_size (float, optional): Minimum size threshold to consider as anomaly. remove_anomalies (bool, optional): Flag indicating whether to remove the anomalies from the data. Returns: pd.DataFrame: DataFrame containing the anomalies found in the data. \"\"\" data = self . data mask_speed = data [ \"wrm_speed\" ] >= min_speed mask_bbox_error = data [ \"bbox_error\" ] >= min_bbox_error mask_dist_error = data [ \"worm_deviation\" ] >= min_dist_error mask_worm_width = data [ \"wrm_w\" ] >= min_size mask_worm_height = data [ \"wrm_h\" ] >= min_size mask_no_preds = np . isfinite ( data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . 
to_numpy ()) . all ( axis = 1 ) == False mask_no_preds = no_preds & mask_no_preds mask = mask_speed | mask_bbox_error | mask_dist_error | mask_worm_width | mask_worm_height | mask_no_preds anomalies = data [ mask ] . copy () anomalies [ \"speed_anomaly\" ] = mask_speed [ mask ] anomalies [ \"bbox_error_anomaly\" ] = mask_bbox_error [ mask ] anomalies [ \"dist_error_anomaly\" ] = mask_dist_error [ mask ] anomalies [ \"width_anomaly\" ] = mask_worm_width [ mask ] anomalies [ \"height_anomaly\" ] = mask_worm_height [ mask ] anomalies [ \"no_pred_anomaly\" ] = mask_no_preds [ mask ] if remove_anomalies : self . data = self . data [ ~ mask ] return anomalies def describe ( self , columns : list [ str ] = None , num : int = 3 , percentiles : list [ float ] = None ) -> pd . DataFrame : \"\"\" Generate descriptive statistics of the specified columns in the table containing the data. Args: columns (list[str], optional): List of column names to include in the analysis. If None, all columns will be included. num (int, optional): Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. percentiles (list[float], optional): List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. Returns: pd.DataFrame: A DataFrame containing the descriptive statistics of the specified columns. \"\"\" if columns is None : columns = self . column_names () if percentiles is None : percentiles = np . linspace ( start = 0 , stop = 1.0 , num = num + 2 )[ 1 : - 1 ] return self . data [ columns ] . describe ( percentiles ) def print_stats ( self ) -> None : \"\"\" Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Total Count of No Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. \"\"\" num_removed = len ( self . _orig_data . index ) - len ( self . data . index ) print ( f \"Count of Removed Frames: {num_removed} ({round(100 * num_removed / len(self._orig_data.index), 3)}%)\" ) no_preds = self . data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . isna () . any ( axis = 1 ) . sum () print ( f \"Count of No-Pred Frames: {no_preds} ({round(100 * no_preds / len(self.data.index), 3)}%)\" ) num_cycles = self . data [ \"cycle\" ] . nunique () print ( f \"Total Num of Cycles: {num_cycles}\" ) non_perfect = ( self . data [ \"bbox_error\" ] > 1e-7 ) . sum () / len ( self . data . index ) print ( f \"Non Perfect Predictions: {round(100 * non_perfect, 3)}%\" )","title":"Attributes"},{"location":"reference/wtracker/eval/data_analyzer/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/eval/data_analyzer/#load","text":"def load ( time_config : 'TimingConfig' , csv_path : 'str' ) -> 'DataAnalyzer' Create a DataAnalyzer object from a csv file containing experiment data, regardless whether if it's analyzed or not. Parameters: Name Type Description Default time_config TimingConfig The timing configuration. None csv_path str Path to the csv file containing the experiment data. 
None View Source @ staticmethod def load ( time_config : TimingConfig , csv_path : str ) -> DataAnalyzer : \"\"\" Create a DataAnalyzer object from a csv file containing experiment data, regardless whether if it's analyzed or not. Args: time_config (TimingConfig): The timing configuration. csv_path (str): Path to the csv file containing the experiment data. \"\"\" data = pd . read_csv ( csv_path ) return DataAnalyzer ( time_config , data )","title":"load"},{"location":"reference/wtracker/eval/data_analyzer/#instance-variables","text":"unit","title":"Instance variables"},{"location":"reference/wtracker/eval/data_analyzer/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/eval/data_analyzer/#calc_anomalies","text":"def calc_anomalies ( self , no_preds : 'bool' = True , min_bbox_error : 'float' = inf , min_dist_error : 'float' = inf , min_speed : 'float' = inf , min_size : 'float' = inf , remove_anomalies : 'bool' = False ) -> 'pd.DataFrame' Calculate anomalies in the data based on specified criteria. Parameters: Name Type Description Default no_preds bool Flag indicating whether to consider instances with missing predictions. None min_bbox_error float Minimum bounding box error threshold to consider as anomaly. None min_dist_error float Minimum distance error threshold to consider as anomaly. None min_speed float Minimum speed threshold to consider as anomaly. None min_size float Minimum size threshold to consider as anomaly. None remove_anomalies bool Flag indicating whether to remove the anomalies from the data. None Returns: Type Description pd.DataFrame DataFrame containing the anomalies found in the data. View Source def calc_anomalies ( self , no_preds : bool = True , min_bbox_error : float = np . inf , min_dist_error : float = np . inf , min_speed : float = np . inf , min_size : float = np . inf , remove_anomalies : bool = False , ) -> pd . DataFrame : \"\"\" Calculate anomalies in the data based on specified criteria. Args: no_preds (bool, optional): Flag indicating whether to consider instances with missing predictions. min_bbox_error (float, optional): Minimum bounding box error threshold to consider as anomaly. min_dist_error (float, optional): Minimum distance error threshold to consider as anomaly. min_speed (float, optional): Minimum speed threshold to consider as anomaly. min_size (float, optional): Minimum size threshold to consider as anomaly. remove_anomalies (bool, optional): Flag indicating whether to remove the anomalies from the data. Returns: pd.DataFrame: DataFrame containing the anomalies found in the data. \"\"\" data = self . data mask_speed = data [ \"wrm_speed\" ] >= min_speed mask_bbox_error = data [ \"bbox_error\" ] >= min_bbox_error mask_dist_error = data [ \"worm_deviation\" ] >= min_dist_error mask_worm_width = data [ \"wrm_w\" ] >= min_size mask_worm_height = data [ \"wrm_h\" ] >= min_size mask_no_preds = np . isfinite ( data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ()). all ( axis = 1 ) == False mask_no_preds = no_preds & mask_no_preds mask = mask_speed | mask_bbox_error | mask_dist_error | mask_worm_width | mask_worm_height | mask_no_preds anomalies = data [ mask ] . 
copy () anomalies [ \"speed_anomaly\" ] = mask_speed [ mask ] anomalies [ \"bbox_error_anomaly\" ] = mask_bbox_error [ mask ] anomalies [ \"dist_error_anomaly\" ] = mask_dist_error [ mask ] anomalies [ \"width_anomaly\" ] = mask_worm_width [ mask ] anomalies [ \"height_anomaly\" ] = mask_worm_height [ mask ] anomalies [ \"no_pred_anomaly\" ] = mask_no_preds [ mask ] if remove_anomalies : self . data = self . data [ ~mask ] return anomalies","title":"calc_anomalies"},{"location":"reference/wtracker/eval/data_analyzer/#calc_precise_error","text":"def calc_precise_error ( self , worm_reader : 'FrameReader' , background : 'np.ndarray' , diff_thresh = 20 ) -> 'None' Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Parameters: Name Type Description Default worm_reader FrameReader Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. None background np.ndarray The background image of the entire experiment. None diff_thresh int Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. None View Source def calc_precise_error ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy (). astype ( np . int32 , copy = False ) wrm_bboxes = self . _orig_data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy () mic_bboxes = self . _orig_data [ [\"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\" ] ] . to_numpy () errors = ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes , mic_bboxes = mic_bboxes , frame_nums = frames , worm_reader = worm_reader , diff_thresh = diff_thresh , ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . data [ \"precise_error\" ] = errors [ idx ]","title":"calc_precise_error"},{"location":"reference/wtracker/eval/data_analyzer/#calc_precise_error_experimental","text":"def calc_precise_error_experimental ( self , worm_reader : 'FrameReader' , background : 'np.ndarray' , diff_thresh = 20 , num_workers : 'int' = None , chunk_size : 'int' = 2000 ) -> 'None' Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. 
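For context, a minimal usage sketch of the DataAnalyzer flow documented above (load, initialize, calc_anomalies). The csv path, the pre-built TimingConfig instance named time_config, and the threshold values are illustrative assumptions, not values prescribed by the library:
analyzer = DataAnalyzer . load ( time_config , \"experiment_log.csv\" )
analyzer . initialize ( period = 10 ) # required when loading a non-analyzed log
anomalies = analyzer . calc_anomalies ( min_dist_error = 300.0 , min_speed = 1000.0 , remove_anomalies = True )
print ( anomalies [ \"no_pred_anomaly\" ] . sum () ) # count of frames flagged for missing predictions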
Parameters: Name Type Description Default worm_reader FrameReader Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. None background np.ndarray The background image of the entire experiment. None diff_thresh int Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None chunk_size int The size of each processing chunk. None View Source def calc_precise_error_experimental ( self , worm_reader : FrameReader , background : np . ndarray , diff_thresh = 20 , num_workers : int = None , chunk_size : int = 2000 , ) -> None : \"\"\" Calculate the precise error between the worm and the microscope view. This error is segmentation based, and measures the proportion of worm's head that is outside of the view of the microscope. Note that this calculation might take a while. Args: worm_reader (FrameReader): Images of the worm at each frame, cropped to the size of the bounding box which was detected around the worm. background (np.ndarray): The background image of the entire experiment. diff_thresh (int): Difference threshold to differentiate between the background and foreground. A foreground object is detected if the pixel value difference with the background is greater than this threshold. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. chunk_size (int, optional): The size of each processing chunk. \"\"\" frames = self . _orig_data [ \"frame\" ] . to_numpy (). astype ( int , copy = False ) wrm_bboxes = self . _orig_data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy () mic_bboxes = self . _orig_data [ [\"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\" ] ] . to_numpy () errors = np . ones_like ( frames , dtype = float ) mask = np . isfinite ( wrm_bboxes ). all ( axis = 1 ) wrm_bboxes = wrm_bboxes [ mask ] mic_bboxes = mic_bboxes [ mask ] frames = frames [ mask ] num_sections = len ( frames ) // chunk_size wrm_bboxes_list = np . array_split ( wrm_bboxes , num_sections , axis = 0 ) mic_bboxes_list = np . array_split ( mic_bboxes , num_sections , axis = 0 ) frames_list = np . array_split ( frames , num_sections ) # TODO : add non - multithreaded case whenever num_workers = 0 num_workers = adjust_num_workers ( len ( frames ), chunk_size , num_workers ) def calc_error ( idx : int ) -> np . ndarray : return ErrorCalculator . calculate_precise ( background = background , worm_bboxes = wrm_bboxes_list [ idx ] , mic_bboxes = mic_bboxes_list [ idx ] , frame_nums = frames_list [ idx ] , worm_reader = worm_reader , diff_thresh = diff_thresh , ) results = concurrent . thread_map ( calc_error , list ( range ( len ( wrm_bboxes_list ))), max_workers = num_workers , chunksize = 1 , desc = \"Extracting bboxes\" , unit = \"fr\" , leave = False , ) # set the error in the original data errors [ mask ] = np . concatenate ( results ) self . _orig_data [ \"precise_error\" ] = errors # copy relevant error entries into the work data idx = self . data [ \"frame\" ] . to_numpy ( dtype = int , copy = False ) self . 
data [ \"precise_error\" ] = errors [ idx ]","title":"calc_precise_error_experimental"},{"location":"reference/wtracker/eval/data_analyzer/#change_unit","text":"def change_unit ( self , unit : 'str' ) Changes the unit of time and distance in the data. Parameters: Name Type Description Default unit str The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometer. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. None View Source def change_unit(self, unit: str): \"\"\" Changes the unit of time and distance in the data. Args: unit (str, optional): The new unit of time to convert into. Can be \"frame\" or \"sec\". If \"sec\" is chosen, the time will be converted to seconds, and the distance metric is micrometer. If \"frame\" is chosen, the time will be in frames, and the distance metric is pixels. \"\"\" assert unit in [\"frame\", \"sec\"] if self._unit == unit: return data = self.data if unit == \"sec\": # frame -> sec dist_factor = self.time_config.mm_per_px * 1000 time_factor = self.time_config.ms_per_frame / 1000 if unit == \"frame\": # sec -> frame dist_factor = self.time_config.px_per_mm / 1000 time_factor = self.time_config.frames_per_sec data[\"time\"] * = time_factor data[[\"plt_x\", \"plt_y\"]] *= dist_factor data[[\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\"]] * = dist_factor data[[\"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\"]] *= dist_factor data[[\"cam_x\", \"cam_y\", \"cam_w\", \"cam_h\"]] * = dist_factor data[[\"wrm_center_x\", \"wrm_center_y\"]] *= dist_factor data[[\"mic_center_x\", \"mic_center_y\"]] * = dist_factor data[[\"worm_deviation_x\", \"worm_deviation_y\", \"worm_deviation\"]] *= dist_factor data[[\"wrm_speed_x\", \"wrm_speed_y\", \"wrm_speed\"]] * = dist_factor / time_factor self._unit = unit self.data = data","title":"change_unit"},{"location":"reference/wtracker/eval/data_analyzer/#clean","text":"def clean ( self , trim_cycles : 'bool' = False , imaging_only : 'bool' = False , bounds : 'tuple[float, float, float, float]' = None ) -> 'None' Clean the data by the provided parameters. Parameters: Name Type Description Default trim_cycles bool whether to remove the first and the last cycles from the data. None imaging_only bool Flag indicating whether to include only imaging phases in the analysis. None legal_bounds tuple[float, float, float, float] The legal bounds for worm movement. None View Source def clean ( self , trim_cycles : bool = False , imaging_only : bool = False , bounds : tuple [ float, float, float, float ] = None , ) -> None : \"\"\" Clean the data by the provided parameters. Args: trim_cycles (bool): whether to remove the first and the last cycles from the data. imaging_only (bool): Flag indicating whether to include only imaging phases in the analysis. legal_bounds (tuple[float, float, float, float]): The legal bounds for worm movement. \"\"\" data = self . data if imaging_only : mask = data [ \"phase\" ] == \"imaging\" data = data [ mask ] if bounds is not None : has_pred = np . isfinite ( data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ()). 
all ( axis = 1 ) mask_wrm = has_pred # if there is a prediction for a frame then look at worm bbox mask_wrm &= ( data [ \"wrm_x\" ] >= bounds [ 0 ] ) & ( data [ \"wrm_x\" ] + data [ \"wrm_w\" ] <= bounds [ 2 ] ) mask_wrm &= ( data [ \"wrm_y\" ] >= bounds [ 1 ] ) & ( data [ \"wrm_y\" ] + data [ \"wrm_h\" ] <= bounds [ 3 ] ) mask_mic = ~ has_pred # if there is no prediction for a frame then look at micro bbox mask_mic &= ( data [ \"mic_x\" ] >= bounds [ 0 ] ) & ( data [ \"mic_x\" ] + data [ \"mic_w\" ] <= bounds [ 2 ] ) mask_mic &= ( data [ \"mic_y\" ] >= bounds [ 1 ] ) & ( data [ \"mic_y\" ] + data [ \"mic_h\" ] <= bounds [ 3 ] ) data = data [ mask_wrm | mask_mic ] if trim_cycles : mask = data [ \"cycle\" ] != 0 mask &= data [ \"cycle\" ] != data [ \"cycle\" ] . max () data = data [ mask ] self . data = data","title":"clean"},{"location":"reference/wtracker/eval/data_analyzer/#column_names","text":"def column_names ( self ) -> 'list[str]' Returns a list of all column names in the analyzed data. Returns: Type Description list[str] A list of column names. View Source def column_names ( self ) -> list [ str ] : \"\"\" Returns a list of all column names in the analyzed data. Returns: list[str]: A list of column names. \"\"\" return self . data . columns . to_list ()","title":"column_names"},{"location":"reference/wtracker/eval/data_analyzer/#describe","text":"def describe ( self , columns : 'list[str]' = None , num : 'int' = 3 , percentiles : 'list[float]' = None ) -> 'pd.DataFrame' Generate descriptive statistics of the specified columns in the table containing the data. Parameters: Name Type Description Default columns list[str] List of column names to include in the analysis. If None, all columns will be included. None num int Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. None percentiles list[float] List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. None Returns: Type Description pd.DataFrame A DataFrame containing the descriptive statistics of the specified columns. View Source def describe ( self , columns : list [ str ] = None , num : int = 3 , percentiles : list [ float ] = None ) -> pd . DataFrame : \"\"\" Generate descriptive statistics of the specified columns in the table containing the data. Args: columns (list[str], optional): List of column names to include in the analysis. If None, all columns will be included. num (int, optional): Number of evenly spaced percentiles to include in the analysis. If percentiles is not None, this parameter is ignored. percentiles (list[float], optional): List of specific percentiles to include in the analysis. If None, evenly spaced percentiles will be used. Returns: pd.DataFrame: A DataFrame containing the descriptive statistics of the specified columns. \"\"\" if columns is None : columns = self . column_names () if percentiles is None : percentiles = np . linspace ( start = 0 , stop = 1.0 , num = num + 2 )[ 1 :- 1 ] return self . data [ columns ]. describe ( percentiles )","title":"describe"},{"location":"reference/wtracker/eval/data_analyzer/#initialize","text":"def initialize ( self , period : 'int' = 10 ) Initializes the data analyzer. It's essential to call this function if the class was created from non-analyzed log data. Parameters: Name Type Description Default period int The period for calculating speed in frames. The speed is calculated by measuring the distance between current frame and period frames before. None
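A hedged end-to-end sketch of the cleaning and inspection methods documented above (clean, change_unit, describe, print_stats); the bounds, unit, and column choices are illustrative assumptions:
analyzer . clean ( trim_cycles = True , imaging_only = True , bounds = ( 0 , 0 , 2000 , 2000 ) )
analyzer . change_unit ( \"sec\" ) # report time in seconds and distance in micrometers
summary = analyzer . describe ( columns = [ \"wrm_speed\" , \"worm_deviation\" ] , num = 5 )
analyzer . print_stats ()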
View Source def initialize(self, period: int = 10): \"\"\" Initializes the data analyzer. It's essential to call this function if the class was created from non-analyzed log data. Args: period (int): The period for calculating speed in frames. The speed is calculated by measuring the distance between current frame and period frames before. \"\"\" data = self._orig_data data[\"time\"] = data[\"frame\"] data[\"cycle_step\"] = data[\"frame\"] % self.time_config.cycle_frame_num data = DataAnalyzer._calc_centers(data) data = DataAnalyzer._calc_speed(data, period) data = DataAnalyzer._calc_worm_deviation(data) data = DataAnalyzer._calc_errors(data) data = data.round(5) self._orig_data = data self.data = self._orig_data.copy()","title":"initialize"},{"location":"reference/wtracker/eval/data_analyzer/#print_stats","text":"def print_stats ( self ) -> 'None' Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Total Count of No Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. View Source def print_stats ( self ) -> None : \"\"\" Prints various statistics related to the data. This method calculates and prints the following statistics: - Count of Removed Frames: The number of frames that were removed from the original data. - Total Count of No Pred Frames: The number of frames where the predictions are missing. - Total Num of Cycles: The number of unique cycles in the data. - Non Perfect Predictions: The percentage of predictions that are not perfect. \"\"\" num_removed = len ( self . _orig_data . index ) - len ( self . data . index ) print ( f \"Count of Removed Frames: {num_removed} ({round(100 * num_removed / len(self._orig_data.index), 3)}%)\" ) no_preds = self . data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . isna () . any ( axis = 1 ) . sum () print ( f \"Count of No-Pred Frames: {no_preds} ({round(100 * no_preds / len(self.data.index), 3)}%)\" ) num_cycles = self . data [ \"cycle\" ] . nunique () print ( f \"Total Num of Cycles: {num_cycles}\" ) non_perfect = ( self . data [ \"bbox_error\" ] > 1e-7 ) . sum () / len ( self . data . index ) print ( f \"Non Perfect Predictions: {round(100 * non_perfect, 3)}%\" )","title":"print_stats"},{"location":"reference/wtracker/eval/data_analyzer/#remove_cycle","text":"def remove_cycle ( self , cycles : 'int | list[int]' ) Remove the specified cycles from the data. Parameters: Name Type Description Default cycles int | list[int] The cycle(s) to remove from the data. View Source def remove_cycle ( self , cycles : int | list [ int ] ) : \"\"\" Remove the specified cycles from the data. Args: cycles (int | list[int]): The cycle(s) to remove from the data. \"\"\" if isinstance ( cycles , int ) : cycles = [ cycles ] mask = self . data [ \"cycle\" ] . isin ( cycles ) self . data = self . data [ ~mask ]","title":"remove_cycle"},{"location":"reference/wtracker/eval/data_analyzer/#reset_changes","text":"def reset_changes ( self ) Reset the data to its original state. Note that this method will not reset the unit of time and distance. View Source def reset_changes(self): \"\"\" Reset the data to its original state. Note that this method will not reset the unit of time and distance.
\"\"\" self.data = self._orig_data.copy() self._unit = \"frame\"","title":"reset_changes"},{"location":"reference/wtracker/eval/data_analyzer/#save","text":"def save ( self , path : 'str' ) -> 'None' Save the full analyzed data to a csv file. View Source def save ( self , path : str ) -> None : \"\"\" Save the full analyzed data to a csv file. \"\"\" self . _orig_data . to_csv ( path , index = False )","title":"save"},{"location":"reference/wtracker/eval/error_calculator/","text":"Module wtracker.eval.error_calculator View Source from typing import Collection import numpy as np import cv2 as cv from tqdm.auto import tqdm from typing import Callable from wtracker.utils.frame_reader import FrameReader from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat class ErrorCalculator : \"\"\" The ErrorCalculator class provides methods to calculate different types of errors based on worm position and the microscope view. \"\"\" # TODO: Kinda a weird solution, but it works for now. Maybe find a better way to do this. probe_hook : Callable [[ np . ndarray , np . ndarray ], None ] = None # takes mask and view for testing @staticmethod def calculate_segmentation ( bbox : np . ndarray , image : np . ndarray , background : np . ndarray , diff_thresh : float , ) -> np . ndarray : \"\"\" Calculates the segmentation error between a view and background image. Args: bbox (np.ndarray): The bounding box of the image, in the format (x, y, w, h). image (np.ndarray): The image to calculate segmentation from. background (np.ndarray): The background image. diff_thresh (float): The difference threshold to distinguish foreground and background objects from. Returns: np.ndarray: The segmentation mask. Raises: ValueError: If the image is not grayscale or color. \"\"\" x , y , w , h = bbox assert image . shape [: 2 ] == ( h , w ) bg_view = background [ y : y + h , x : x + w ] diff = np . abs ( image . astype ( np . int32 ) - bg_view . astype ( np . int32 )) . astype ( np . uint8 ) # if images are color, convert to grayscale if diff . ndim == 3 and diff . shape [ 2 ] == 3 : diff = cv . cvtColor ( diff , cv . COLOR_BGR2GRAY ) if diff . ndim != 2 : raise ValueError ( \"Image must be either a gray or a color image.\" ) mask_wrm = diff > diff_thresh return mask_wrm # TODO: VERY FAST FOR ME, INVESTIGATE WHY IT'S SLOW IN THE LAB # TODO: swap the FrameReader to another type. The only requirement is that accessing frame index returns the correct frame. # we should probably use something like ImageLoader, which is implemented in the analysis_experimental. @staticmethod def calculate_precise ( background : np . ndarray , worm_bboxes : np . ndarray , mic_bboxes : np . ndarray , frame_nums : np . ndarray , worm_reader : FrameReader , diff_thresh : float = 10 , ) -> np . ndarray : \"\"\" Calculates the precise error for each frame in the given sequence. This error is based on precise segmentation of the worm object from the frame, and determining the exact proportion of worm's body outside the microscope view. Args: background (np.ndarray): The background image. worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). frame_nums (np.ndarray): An array of frame numbers to calculate the error for. worm_reader (FrameReader): A frame reader containing segmented worm images for each frame. 
These worm images should match the shape of the worm bounding boxes. Frames passed in frame_nums are read from this reader by index. diff_thresh (float, optional): The difference threshold to distinguish foreground and background objects from. A foreground object is detected if the pixel value difference with the background is greater than this threshold. Returns: np.ndarray: Array of errors of shape (N,) representing the precise segmentation error for each frame. Raises: AssertionError: If the length of frame_nums, worm_bboxes, and mic_bboxes do not match. \"\"\" assert frame_nums . ndim == 1 assert len ( frame_nums ) == worm_bboxes . shape [ 0 ] == mic_bboxes . shape [ 0 ] errors = np . zeros ( len ( frame_nums ), dtype = float ) bounds = background . shape [: 2 ] worm_bboxes , is_legal = BoxUtils . discretize ( worm_bboxes , bounds = bounds , box_format = BoxFormat . XYWH ) mic_bboxes , _ = BoxUtils . discretize ( mic_bboxes , bounds = bounds , box_format = BoxFormat . XYWH ) # filter out illegal bboxes, indicting no prediction or bad prediction. errors [ ~ is_legal ] = np . nan worm_bboxes = worm_bboxes [ is_legal ] mic_bboxes = mic_bboxes [ is_legal ] frame_nums = frame_nums [ is_legal ] # convert to xyxy format for intersection calculation worm_bboxes = BoxConverter . change_format ( worm_bboxes , BoxFormat . XYWH , BoxFormat . XYXY ) mic_bboxes = BoxConverter . change_format ( mic_bboxes , BoxFormat . XYWH , BoxFormat . XYXY ) wrm_left , wrm_top , wrm_right , wrm_bottom = BoxUtils . unpack ( worm_bboxes ) mic_left , mic_top , mic_right , mic_bottom = BoxUtils . unpack ( mic_bboxes ) # calculate intersection of worm and microscope bounding boxes int_left = np . maximum ( wrm_left , mic_left ) int_top = np . maximum ( wrm_top , mic_top ) int_right = np . minimum ( wrm_right , mic_right ) int_bottom = np . minimum ( wrm_bottom , mic_bottom ) int_width = np . maximum ( 0 , int_right - int_left ) int_height = np . maximum ( 0 , int_bottom - int_top ) # shift the intersection to the worm view coordinates int_left -= wrm_left int_top -= wrm_top # pack the intersection bounding boxes and convert to xywh format int_bboxes = BoxUtils . pack ( int_left , int_top , int_width , int_height ) worm_bboxes = BoxConverter . change_format ( worm_bboxes , BoxFormat . XYXY , BoxFormat . XYWH ) mic_bboxes = BoxConverter . change_format ( mic_bboxes , BoxFormat . XYXY , BoxFormat . XYWH ) for i , frame_num in tqdm ( enumerate ( frame_nums ), total = len ( frame_nums ), desc = \"Calculating Error\" , unit = \"fr\" ): wrm_bbox = worm_bboxes [ i ] int_bbox = int_bboxes [ i ] worm_view = worm_reader [ frame_num ] mask_wrm = ErrorCalculator . calculate_segmentation ( bbox = wrm_bbox , image = worm_view , background = background , diff_thresh = diff_thresh , ) if ErrorCalculator . probe_hook is not None : ErrorCalculator . probe_hook ( worm_view , mask_wrm ) mask_mic = np . zeros_like ( mask_wrm , dtype = bool ) mask_mic [ int_bbox [ 1 ] : int_bbox [ 1 ] + int_bbox [ 3 ], int_bbox [ 0 ] : int_bbox [ 0 ] + int_bbox [ 2 ]] = True total = mask_wrm . sum () if total == 0 : errors [ i ] = 0.0 continue intersection = np . logical_and ( mask_wrm , mask_mic ) . sum () error = 1.0 - intersection / total errors [ i ] = error return errors @staticmethod def calculate_bbox_error ( worm_bboxes : np . ndarray , mic_bboxes : np . ndarray ) -> np . ndarray : \"\"\" Calculate the bounding box error between worm bounding boxes and microscope bounding boxes. 
This error calculates the proportion of the worm bounding box that is outside the microscope bounding box. Args: worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). Returns: np.ndarray: Array of errors of shape (N,) representing the bounding box error for each pair of worm and microscope bounding boxes. \"\"\" wrm_left , wrm_top , wrm_width , wrm_height = BoxUtils . unpack ( worm_bboxes ) mic_left , mic_top , mic_width , mic_height = BoxUtils . unpack ( mic_bboxes ) wrm_right , wrm_bottom = wrm_left + wrm_width , wrm_top + wrm_height mic_right , mic_bottom = mic_left + mic_width , mic_top + mic_height int_left = np . maximum ( wrm_left , mic_left ) int_top = np . maximum ( wrm_top , mic_top ) int_right = np . minimum ( wrm_right , mic_right ) int_bottom = np . minimum ( wrm_bottom , mic_bottom ) int_width = np . maximum ( 0 , int_right - int_left ) int_height = np . maximum ( 0 , int_bottom - int_top ) intersection = int_width * int_height total = wrm_width * wrm_height errors = 1.0 - intersection / total errors [ total == 0 ] = 0.0 return errors @staticmethod def calculate_mse_error ( worm_bboxes : np . ndarray , mic_bboxes : np . ndarray ) -> np . ndarray : \"\"\" Calculates the Mean Squared Error (MSE) error between the centers of worm bounding boxes and microscope bounding boxes. Args: worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). Returns: np.ndarray: Array of errors of shape (N,) representing the MSE error for each pair of worm and microscope bounding boxes. \"\"\" worm_centers = BoxUtils . center ( worm_bboxes ) mic_centers = BoxUtils . center ( mic_bboxes ) errors = np . mean (( worm_centers - mic_centers ) ** 2 , axis = 1 ) return errors Classes ErrorCalculator class ErrorCalculator ( / , * args , ** kwargs ) The ErrorCalculator class provides methods to calculate different types of errors based on worm position and the microscope view. View Source class ErrorCalculator : \"\"\" The ErrorCalculator class provides methods to calculate different types of errors based on worm position and the microscope view . \"\"\" # TODO: Kinda a weird solution, but it works for now. Maybe find a better way to do this. probe_hook : Callable [[ np . ndarray , np . ndarray ], None ] = None # takes mask and view for testing @ staticmethod def calculate_segmentation ( bbox : np . ndarray , image : np . ndarray , background : np . ndarray , diff_thresh : float , ) -> np . ndarray : \"\"\" Calculates the segmentation error between a view and background image . Args : bbox ( np . ndarray ) : The bounding box of the image , in the format ( x , y , w , h ). image ( np . ndarray ) : The image to calculate segmentation from . background ( np . ndarray ) : The background image . diff_thresh ( float ) : The difference threshold to distinguish foreground and background objects from . Returns : np . ndarray : The segmentation mask . Raises : ValueError : If the image is not grayscale or color . \"\"\" x , y , w , h = bbox assert image . shape [ : 2 ] == ( h , w ) bg_view = background [ y : y + h , x : x + w ] diff = np . abs ( image . 
astype ( np . int32 ) - bg_view . astype ( np . int32 )). astype ( np . uint8 ) # if images are color, convert to grayscale if diff . ndim == 3 and diff . shape [ 2 ] == 3 : diff = cv . cvtColor ( diff , cv . COLOR_BGR2GRAY ) if diff . ndim != 2 : raise ValueError ( \"Image must be either a gray or a color image.\" ) mask_wrm = diff > diff_thresh return mask_wrm # TODO: VERY FAST FOR ME, INVESTIGATE WHY IT'S SLOW IN THE LAB # TODO: swap the FrameReader to another type. The only requirement is that accessing frame index returns the correct frame. # we should probably use something like ImageLoader, which is implemented in the analysis_experimental. @ staticmethod def calculate_precise ( background : np . ndarray , worm_bboxes : np . ndarray , mic_bboxes : np . ndarray , frame_nums : np . ndarray , worm_reader : FrameReader , diff_thresh : float = 10 , ) -> np . ndarray : \"\"\" Calculates the precise error for each frame in the given sequence . This error is based on precise segmentation of the worm object from the frame , and determining the exact proportion of worm ' s body outside the microscope view . Args : background ( np . ndarray ) : The background image . worm_bboxes : A numpy array of shape ( N , 4 ) representing the bounding boxes of worms . The bounding boxes should be in the format ( x , y , w , h ). mic_bboxes : A numpy array of shape ( N , 4 ) representing the bounding boxes of the microscope . The bounding boxes should be in the format ( x , y , w , h ). frame_nums ( np . ndarray ) : An array of frame numbers to calculate the error for . worm_reader ( FrameReader ) : A frame reader containing segmented worm images for each frame . These worm images should match the shape of the worm bounding boxes . Frames passed in frame_nums are read from this reader by index . diff_thresh ( float , optional ) : The difference threshold to distinguish foreground and background objects from . A foreground object is detected if the pixel value difference with the background is greater than this threshold . Returns : np . ndarray : Array of errors of shape ( N ,) representing the precise segmentation error for each frame . Raises : AssertionError : If the length of frame_nums , worm_bboxes , and mic_bboxes do not match . \"\"\" assert frame_nums . ndim == 1 assert len ( frame_nums ) == worm_bboxes . shape [ 0 ] == mic_bboxes . shape [ 0 ] errors = np . zeros ( len ( frame_nums ), dtype = float ) bounds = background . shape [ : 2 ] worm_bboxes , is_legal = BoxUtils . discretize ( worm_bboxes , bounds = bounds , box_format = BoxFormat . XYWH ) mic_bboxes , _ = BoxUtils . discretize ( mic_bboxes , bounds = bounds , box_format = BoxFormat . XYWH ) # filter out illegal bboxes, indicting no prediction or bad prediction. errors [ ~ is_legal ] = np . nan worm_bboxes = worm_bboxes [ is_legal ] mic_bboxes = mic_bboxes [ is_legal ] frame_nums = frame_nums [ is_legal ] # convert to xyxy format for intersection calculation worm_bboxes = BoxConverter . change_format ( worm_bboxes , BoxFormat . XYWH , BoxFormat . XYXY ) mic_bboxes = BoxConverter . change_format ( mic_bboxes , BoxFormat . XYWH , BoxFormat . XYXY ) wrm_left , wrm_top , wrm_right , wrm_bottom = BoxUtils . unpack ( worm_bboxes ) mic_left , mic_top , mic_right , mic_bottom = BoxUtils . unpack ( mic_bboxes ) # calculate intersection of worm and microscope bounding boxes int_left = np . maximum ( wrm_left , mic_left ) int_top = np . maximum ( wrm_top , mic_top ) int_right = np . minimum ( wrm_right , mic_right ) int_bottom = np . 
minimum ( wrm_bottom , mic_bottom ) int_width = np . maximum ( 0 , int_right - int_left ) int_height = np . maximum ( 0 , int_bottom - int_top ) # shift the intersection to the worm view coordinates int_left -= wrm_left int_top -= wrm_top # pack the intersection bounding boxes and convert to xywh format int_bboxes = BoxUtils . pack ( int_left , int_top , int_width , int_height ) worm_bboxes = BoxConverter . change_format ( worm_bboxes , BoxFormat . XYXY , BoxFormat . XYWH ) mic_bboxes = BoxConverter . change_format ( mic_bboxes , BoxFormat . XYXY , BoxFormat . XYWH ) for i , frame_num in tqdm ( enumerate ( frame_nums ), total = len ( frame_nums ), desc = \"Calculating Error\" , unit = \"fr\" ) : wrm_bbox = worm_bboxes [ i ] int_bbox = int_bboxes [ i ] worm_view = worm_reader [ frame_num ] mask_wrm = ErrorCalculator . calculate_segmentation ( bbox = wrm_bbox , image = worm_view , background = background , diff_thresh = diff_thresh , ) if ErrorCalculator . probe_hook is not None : ErrorCalculator . probe_hook ( worm_view , mask_wrm ) mask_mic = np . zeros_like ( mask_wrm , dtype = bool ) mask_mic [ int_bbox [ 1 ] : int_bbox [ 1 ] + int_bbox [ 3 ], int_bbox [ 0 ] : int_bbox [ 0 ] + int_bbox [ 2 ]] = True total = mask_wrm . sum () if total == 0 : errors [ i ] = 0.0 continue intersection = np . logical_and ( mask_wrm , mask_mic ). sum () error = 1.0 - intersection / total errors [ i ] = error return errors @ staticmethod def calculate_bbox_error ( worm_bboxes : np . ndarray , mic_bboxes : np . ndarray ) -> np . ndarray : \"\"\" Calculate the bounding box error between worm bounding boxes and microscope bounding boxes . This error calculates the proportion of the worm bounding box that is outside the microscope bounding box . Args : worm_bboxes : A numpy array of shape ( N , 4 ) representing the bounding boxes of worms . The bounding boxes should be in the format ( x , y , w , h ). mic_bboxes : A numpy array of shape ( N , 4 ) representing the bounding boxes of the microscope . The bounding boxes should be in the format ( x , y , w , h ). Returns : np . ndarray : Array of errors of shape ( N ,) representing the bounding box error for each pair of worm and microscope bounding boxes . \"\"\" wrm_left , wrm_top , wrm_width , wrm_height = BoxUtils . unpack ( worm_bboxes ) mic_left , mic_top , mic_width , mic_height = BoxUtils . unpack ( mic_bboxes ) wrm_right , wrm_bottom = wrm_left + wrm_width , wrm_top + wrm_height mic_right , mic_bottom = mic_left + mic_width , mic_top + mic_height int_left = np . maximum ( wrm_left , mic_left ) int_top = np . maximum ( wrm_top , mic_top ) int_right = np . minimum ( wrm_right , mic_right ) int_bottom = np . minimum ( wrm_bottom , mic_bottom ) int_width = np . maximum ( 0 , int_right - int_left ) int_height = np . maximum ( 0 , int_bottom - int_top ) intersection = int_width * int_height total = wrm_width * wrm_height errors = 1.0 - intersection / total errors [ total == 0 ] = 0.0 return errors @ staticmethod def calculate_mse_error ( worm_bboxes : np . ndarray , mic_bboxes : np . ndarray ) -> np . ndarray : \"\"\" Calculates the Mean Squared Error ( MSE ) error between the centers of worm bounding boxes and microscope bounding boxes . Args : worm_bboxes : A numpy array of shape ( N , 4 ) representing the bounding boxes of worms . The bounding boxes should be in the format ( x , y , w , h ). mic_bboxes : A numpy array of shape ( N , 4 ) representing the bounding boxes of the microscope . The bounding boxes should be in the format ( x , y , w , h ). Returns : np . 
ndarray : Array of errors of shape ( N ,) representing the MSE error for each pair of worm and microscope bounding boxes . \"\"\" worm_centers = BoxUtils . center ( worm_bboxes ) mic_centers = BoxUtils . center ( mic_bboxes ) errors = np . mean (( worm_centers - mic_centers ) ** 2 , axis = 1 ) return errors Class variables probe_hook Static methods calculate_bbox_error def calculate_bbox_error ( worm_bboxes : numpy . ndarray , mic_bboxes : numpy . ndarray ) -> numpy . ndarray Calculate the bounding box error between worm bounding boxes and microscope bounding boxes. This error calculates the proportion of the worm bounding box that is outside the microscope bounding box. Parameters: Name Type Description Default worm_bboxes None A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). None mic_bboxes None A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). None Returns: Type Description np.ndarray Array of errors of shape (N,) representing the bounding box error for each pair of worm and microscope bounding boxes. View Source @staticmethod def calculate_bbox_error ( worm_bboxes : np . ndarray , mic_bboxes : np . ndarray ) -> np . ndarray : \"\"\" Calculate the bounding box error between worm bounding boxes and microscope bounding boxes. This error calculates the proportion of the worm bounding box that is outside the microscope bounding box. Args: worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). Returns: np.ndarray: Array of errors of shape (N,) representing the bounding box error for each pair of worm and microscope bounding boxes. \"\"\" wrm_left , wrm_top , wrm_width , wrm_height = BoxUtils . unpack ( worm_bboxes ) mic_left , mic_top , mic_width , mic_height = BoxUtils . unpack ( mic_bboxes ) wrm_right , wrm_bottom = wrm_left + wrm_width , wrm_top + wrm_height mic_right , mic_bottom = mic_left + mic_width , mic_top + mic_height int_left = np . maximum ( wrm_left , mic_left ) int_top = np . maximum ( wrm_top , mic_top ) int_right = np . minimum ( wrm_right , mic_right ) int_bottom = np . minimum ( wrm_bottom , mic_bottom ) int_width = np . maximum ( 0 , int_right - int_left ) int_height = np . maximum ( 0 , int_bottom - int_top ) intersection = int_width * int_height total = wrm_width * wrm_height errors = 1.0 - intersection / total errors [ total == 0 ] = 0.0 return errors calculate_mse_error def calculate_mse_error ( worm_bboxes : numpy . ndarray , mic_bboxes : numpy . ndarray ) -> numpy . ndarray Calculates the Mean Squared Error (MSE) error between the centers of worm bounding boxes and microscope bounding boxes. Parameters: Name Type Description Default worm_bboxes None A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). None mic_bboxes None A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). None Returns: Type Description np.ndarray Array of errors of shape (N,) representing the MSE error for each pair of worm and microscope bounding boxes. View Source @staticmethod def calculate_mse_error ( worm_bboxes : np . 
ndarray , mic_bboxes : np . ndarray ) -> np . ndarray : \"\"\" Calculates the Mean Squared Error (MSE) error between the centers of worm bounding boxes and microscope bounding boxes. Args: worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). Returns: np.ndarray: Array of errors of shape (N,) representing the MSE error for each pair of worm and microscope bounding boxes. \"\"\" worm_centers = BoxUtils . center ( worm_bboxes ) mic_centers = BoxUtils . center ( mic_bboxes ) errors = np . mean (( worm_centers - mic_centers ) ** 2 , axis = 1 ) return errors calculate_precise def calculate_precise ( background : numpy . ndarray , worm_bboxes : numpy . ndarray , mic_bboxes : numpy . ndarray , frame_nums : numpy . ndarray , worm_reader : wtracker . utils . frame_reader . FrameReader , diff_thresh : float = 10 ) -> numpy . ndarray Calculates the precise error for each frame in the given sequence. This error is based on precise segmentation of the worm object from the frame, and determining the exact proportion of worm's body outside the microscope view. Parameters: Name Type Description Default background np.ndarray The background image. None worm_bboxes None A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). None mic_bboxes None A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). None frame_nums np.ndarray An array of frame numbers to calculate the error for. None worm_reader FrameReader A frame reader containing segmented worm images for each frame. These worm images should match the shape of the worm bounding boxes. Frames passed in frame_nums are read from this reader by index. None diff_thresh float The difference threshold to distinguish foreground and background objects from. A foreground object is detected if the pixel value difference with the background is greater than this threshold. None Returns: Type Description np.ndarray Array of errors of shape (N,) representing the precise segmentation error for each frame. Raises: Type Description AssertionError If the length of frame_nums, worm_bboxes, and mic_bboxes do not match. View Source @staticmethod def calculate_precise ( background : np . ndarray , worm_bboxes : np . ndarray , mic_bboxes : np . ndarray , frame_nums : np . ndarray , worm_reader : FrameReader , diff_thresh : float = 10 , ) -> np . ndarray : \"\"\" Calculates the precise error for each frame in the given sequence. This error is based on precise segmentation of the worm object from the frame, and determining the exact proportion of worm's body outside the microscope view. Args: background (np.ndarray): The background image. worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). frame_nums (np.ndarray): An array of frame numbers to calculate the error for. worm_reader (FrameReader): A frame reader containing segmented worm images for each frame. These worm images should match the shape of the worm bounding boxes. 
Frames passed in frame_nums are read from this reader by index. diff_thresh (float, optional): The difference threshold to distinguish foreground and background objects from. A foreground object is detected if the pixel value difference with the background is greater than this threshold. Returns: np.ndarray: Array of errors of shape (N,) representing the precise segmentation error for each frame. Raises: AssertionError: If the length of frame_nums, worm_bboxes, and mic_bboxes do not match. \"\"\" assert frame_nums . ndim == 1 assert len ( frame_nums ) == worm_bboxes . shape [ 0 ] == mic_bboxes . shape [ 0 ] errors = np . zeros ( len ( frame_nums ), dtype = float ) bounds = background . shape [ :2 ] worm_bboxes , is_legal = BoxUtils . discretize ( worm_bboxes , bounds = bounds , box_format = BoxFormat . XYWH ) mic_bboxes , _ = BoxUtils . discretize ( mic_bboxes , bounds = bounds , box_format = BoxFormat . XYWH ) # filter out illegal bboxes , indicting no prediction or bad prediction . errors [ ~is_legal ] = np . nan worm_bboxes = worm_bboxes [ is_legal ] mic_bboxes = mic_bboxes [ is_legal ] frame_nums = frame_nums [ is_legal ] # convert to xyxy format for intersection calculation worm_bboxes = BoxConverter . change_format ( worm_bboxes , BoxFormat . XYWH , BoxFormat . XYXY ) mic_bboxes = BoxConverter . change_format ( mic_bboxes , BoxFormat . XYWH , BoxFormat . XYXY ) wrm_left , wrm_top , wrm_right , wrm_bottom = BoxUtils . unpack ( worm_bboxes ) mic_left , mic_top , mic_right , mic_bottom = BoxUtils . unpack ( mic_bboxes ) # calculate intersection of worm and microscope bounding boxes int_left = np . maximum ( wrm_left , mic_left ) int_top = np . maximum ( wrm_top , mic_top ) int_right = np . minimum ( wrm_right , mic_right ) int_bottom = np . minimum ( wrm_bottom , mic_bottom ) int_width = np . maximum ( 0 , int_right - int_left ) int_height = np . maximum ( 0 , int_bottom - int_top ) # shift the intersection to the worm view coordinates int_left -= wrm_left int_top -= wrm_top # pack the intersection bounding boxes and convert to xywh format int_bboxes = BoxUtils . pack ( int_left , int_top , int_width , int_height ) worm_bboxes = BoxConverter . change_format ( worm_bboxes , BoxFormat . XYXY , BoxFormat . XYWH ) mic_bboxes = BoxConverter . change_format ( mic_bboxes , BoxFormat . XYXY , BoxFormat . XYWH ) for i , frame_num in tqdm ( enumerate ( frame_nums ), total = len ( frame_nums ), desc = \"Calculating Error\" , unit = \"fr\" ) : wrm_bbox = worm_bboxes [ i ] int_bbox = int_bboxes [ i ] worm_view = worm_reader [ frame_num ] mask_wrm = ErrorCalculator . calculate_segmentation ( bbox = wrm_bbox , image = worm_view , background = background , diff_thresh = diff_thresh , ) if ErrorCalculator . probe_hook is not None : ErrorCalculator . probe_hook ( worm_view , mask_wrm ) mask_mic = np . zeros_like ( mask_wrm , dtype = bool ) mask_mic [ int_bbox[1 ] : int_bbox [ 1 ] + int_bbox [ 3 ] , int_bbox [ 0 ] : int_bbox [ 0 ] + int_bbox [ 2 ] ] = True total = mask_wrm . sum () if total == 0 : errors [ i ] = 0.0 continue intersection = np . logical_and ( mask_wrm , mask_mic ). sum () error = 1.0 - intersection / total errors [ i ] = error return errors calculate_segmentation def calculate_segmentation ( bbox : numpy . ndarray , image : numpy . ndarray , background : numpy . ndarray , diff_thresh : float ) -> numpy . ndarray Calculates the segmentation error between a view and background image. 
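A small self-contained sketch of the two closed-form error measures above, assuming ErrorCalculator is imported from wtracker.eval.error_calculator; the bounding boxes are synthetic, chosen only to illustrate the formulas:
import numpy as np
worm_bboxes = np . array ( [[ 10.0 , 10.0 , 40.0 , 20.0 ]] ) # one worm box, format (x, y, w, h)
mic_bboxes = np . array ( [[ 0.0 , 0.0 , 50.0 , 50.0 ]] ) # microscope view fully contains the worm box
ErrorCalculator . calculate_bbox_error ( worm_bboxes , mic_bboxes ) # -> array([0.0]), nothing outside the view
ErrorCalculator . calculate_mse_error ( worm_bboxes , mic_bboxes ) # -> array([25.0]), mean squared center offset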
Parameters: Name Type Description Default bbox np.ndarray The bounding box of the image, in the format (x, y, w, h). None image np.ndarray The image to calculate segmentation from. None background np.ndarray The background image. None diff_thresh float The difference threshold to distinguish foreground and background objects from. None Returns: Type Description np.ndarray The segmentation mask. Raises: Type Description ValueError If the image is not grayscale or color. View Source @ staticmethod def calculate_segmentation ( bbox : np . ndarray , image : np . ndarray , background : np . ndarray , diff_thresh : float , ) -> np . ndarray : \"\"\" Calculates the segmentation error between a view and background image . Args : bbox ( np . ndarray ) : The bounding box of the image , in the format ( x , y , w , h ). image ( np . ndarray ) : The image to calculate segmentation from . background ( np . ndarray ) : The background image . diff_thresh ( float ) : The difference threshold to distinguish foreground and background objects from . Returns : np . ndarray : The segmentation mask . Raises : ValueError : If the image is not grayscale or color . \"\"\" x , y , w , h = bbox assert image . shape [ : 2 ] == ( h , w ) bg_view = background [ y : y + h , x : x + w ] diff = np . abs ( image . astype ( np . int32 ) - bg_view . astype ( np . int32 )). astype ( np . uint8 ) # if images are color, convert to grayscale if diff . ndim == 3 and diff . shape [ 2 ] == 3 : diff = cv . cvtColor ( diff , cv . COLOR_BGR2GRAY ) if diff . ndim != 2 : raise ValueError ( \"Image must be either a gray or a color image.\" ) mask_wrm = diff > diff_thresh return mask_wrm","title":"Error Calculator"},{"location":"reference/wtracker/eval/error_calculator/#module-wtrackerevalerror_calculator","text":"View Source from typing import Collection import numpy as np import cv2 as cv from tqdm.auto import tqdm from typing import Callable from wtracker.utils.frame_reader import FrameReader from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat class ErrorCalculator : \"\"\" The ErrorCalculator class provides methods to calculate different types of errors based on worm position and the microscope view. \"\"\" # TODO: Kinda a weird solution, but it works for now. Maybe find a better way to do this. probe_hook : Callable [[ np . ndarray , np . ndarray ], None ] = None # takes mask and view for testing @staticmethod def calculate_segmentation ( bbox : np . ndarray , image : np . ndarray , background : np . ndarray , diff_thresh : float , ) -> np . ndarray : \"\"\" Calculates the segmentation error between a view and background image. Args: bbox (np.ndarray): The bounding box of the image, in the format (x, y, w, h). image (np.ndarray): The image to calculate segmentation from. background (np.ndarray): The background image. diff_thresh (float): The difference threshold to distinguish foreground and background objects from. Returns: np.ndarray: The segmentation mask. Raises: ValueError: If the image is not grayscale or color. \"\"\" x , y , w , h = bbox assert image . shape [: 2 ] == ( h , w ) bg_view = background [ y : y + h , x : x + w ] diff = np . abs ( image . astype ( np . int32 ) - bg_view . astype ( np . int32 )) . astype ( np . uint8 ) # if images are color, convert to grayscale if diff . ndim == 3 and diff . shape [ 2 ] == 3 : diff = cv . cvtColor ( diff , cv . COLOR_BGR2GRAY ) if diff . 
ndim != 2 : raise ValueError ( \"Image must be either a gray or a color image.\" ) mask_wrm = diff > diff_thresh return mask_wrm # TODO: VERY FAST FOR ME, INVESTIGATE WHY IT'S SLOW IN THE LAB # TODO: swap the FrameReader to another type. The only requirement is that accessing frame index returns the correct frame. # we should probably use something like ImageLoader, which is implemented in the analysis_experimental. @staticmethod def calculate_precise ( background : np . ndarray , worm_bboxes : np . ndarray , mic_bboxes : np . ndarray , frame_nums : np . ndarray , worm_reader : FrameReader , diff_thresh : float = 10 , ) -> np . ndarray : \"\"\" Calculates the precise error for each frame in the given sequence. This error is based on precise segmentation of the worm object from the frame, and determining the exact proportion of worm's body outside the microscope view. Args: background (np.ndarray): The background image. worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). frame_nums (np.ndarray): An array of frame numbers to calculate the error for. worm_reader (FrameReader): A frame reader containing segmented worm images for each frame. These worm images should match the shape of the worm bounding boxes. Frames passed in frame_nums are read from this reader by index. diff_thresh (float, optional): The difference threshold to distinguish foreground and background objects from. A foreground object is detected if the pixel value difference with the background is greater than this threshold. Returns: np.ndarray: Array of errors of shape (N,) representing the precise segmentation error for each frame. Raises: AssertionError: If the length of frame_nums, worm_bboxes, and mic_bboxes do not match. \"\"\" assert frame_nums . ndim == 1 assert len ( frame_nums ) == worm_bboxes . shape [ 0 ] == mic_bboxes . shape [ 0 ] errors = np . zeros ( len ( frame_nums ), dtype = float ) bounds = background . shape [: 2 ] worm_bboxes , is_legal = BoxUtils . discretize ( worm_bboxes , bounds = bounds , box_format = BoxFormat . XYWH ) mic_bboxes , _ = BoxUtils . discretize ( mic_bboxes , bounds = bounds , box_format = BoxFormat . XYWH ) # filter out illegal bboxes, indicting no prediction or bad prediction. errors [ ~ is_legal ] = np . nan worm_bboxes = worm_bboxes [ is_legal ] mic_bboxes = mic_bboxes [ is_legal ] frame_nums = frame_nums [ is_legal ] # convert to xyxy format for intersection calculation worm_bboxes = BoxConverter . change_format ( worm_bboxes , BoxFormat . XYWH , BoxFormat . XYXY ) mic_bboxes = BoxConverter . change_format ( mic_bboxes , BoxFormat . XYWH , BoxFormat . XYXY ) wrm_left , wrm_top , wrm_right , wrm_bottom = BoxUtils . unpack ( worm_bboxes ) mic_left , mic_top , mic_right , mic_bottom = BoxUtils . unpack ( mic_bboxes ) # calculate intersection of worm and microscope bounding boxes int_left = np . maximum ( wrm_left , mic_left ) int_top = np . maximum ( wrm_top , mic_top ) int_right = np . minimum ( wrm_right , mic_right ) int_bottom = np . minimum ( wrm_bottom , mic_bottom ) int_width = np . maximum ( 0 , int_right - int_left ) int_height = np . 
maximum ( 0 , int_bottom - int_top ) # shift the intersection to the worm view coordinates int_left -= wrm_left int_top -= wrm_top # pack the intersection bounding boxes and convert to xywh format int_bboxes = BoxUtils . pack ( int_left , int_top , int_width , int_height ) worm_bboxes = BoxConverter . change_format ( worm_bboxes , BoxFormat . XYXY , BoxFormat . XYWH ) mic_bboxes = BoxConverter . change_format ( mic_bboxes , BoxFormat . XYXY , BoxFormat . XYWH ) for i , frame_num in tqdm ( enumerate ( frame_nums ), total = len ( frame_nums ), desc = \"Calculating Error\" , unit = \"fr\" ): wrm_bbox = worm_bboxes [ i ] int_bbox = int_bboxes [ i ] worm_view = worm_reader [ frame_num ] mask_wrm = ErrorCalculator . calculate_segmentation ( bbox = wrm_bbox , image = worm_view , background = background , diff_thresh = diff_thresh , ) if ErrorCalculator . probe_hook is not None : ErrorCalculator . probe_hook ( worm_view , mask_wrm ) mask_mic = np . zeros_like ( mask_wrm , dtype = bool ) mask_mic [ int_bbox [ 1 ] : int_bbox [ 1 ] + int_bbox [ 3 ], int_bbox [ 0 ] : int_bbox [ 0 ] + int_bbox [ 2 ]] = True total = mask_wrm . sum () if total == 0 : errors [ i ] = 0.0 continue intersection = np . logical_and ( mask_wrm , mask_mic ) . sum () error = 1.0 - intersection / total errors [ i ] = error return errors @staticmethod def calculate_bbox_error ( worm_bboxes : np . ndarray , mic_bboxes : np . ndarray ) -> np . ndarray : \"\"\" Calculate the bounding box error between worm bounding boxes and microscope bounding boxes. This error calculates the proportion of the worm bounding box that is outside the microscope bounding box. Args: worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). Returns: np.ndarray: Array of errors of shape (N,) representing the bounding box error for each pair of worm and microscope bounding boxes. \"\"\" wrm_left , wrm_top , wrm_width , wrm_height = BoxUtils . unpack ( worm_bboxes ) mic_left , mic_top , mic_width , mic_height = BoxUtils . unpack ( mic_bboxes ) wrm_right , wrm_bottom = wrm_left + wrm_width , wrm_top + wrm_height mic_right , mic_bottom = mic_left + mic_width , mic_top + mic_height int_left = np . maximum ( wrm_left , mic_left ) int_top = np . maximum ( wrm_top , mic_top ) int_right = np . minimum ( wrm_right , mic_right ) int_bottom = np . minimum ( wrm_bottom , mic_bottom ) int_width = np . maximum ( 0 , int_right - int_left ) int_height = np . maximum ( 0 , int_bottom - int_top ) intersection = int_width * int_height total = wrm_width * wrm_height errors = 1.0 - intersection / total errors [ total == 0 ] = 0.0 return errors @staticmethod def calculate_mse_error ( worm_bboxes : np . ndarray , mic_bboxes : np . ndarray ) -> np . ndarray : \"\"\" Calculates the Mean Squared Error (MSE) error between the centers of worm bounding boxes and microscope bounding boxes. Args: worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). Returns: np.ndarray: Array of errors of shape (N,) representing the MSE error for each pair of worm and microscope bounding boxes. \"\"\" worm_centers = BoxUtils . 
center ( worm_bboxes ) mic_centers = BoxUtils . center ( mic_bboxes ) errors = np . mean (( worm_centers - mic_centers ) ** 2 , axis = 1 ) return errors","title":"Module wtracker.eval.error_calculator"},{"location":"reference/wtracker/eval/error_calculator/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/eval/error_calculator/#errorcalculator","text":"class ErrorCalculator ( / , * args , ** kwargs ) The ErrorCalculator class provides methods to calculate different types of errors based on worm position and the microscope view. View Source class ErrorCalculator : \"\"\" The ErrorCalculator class provides methods to calculate different types of errors based on worm position and the microscope view . \"\"\" # TODO: Kinda a weird solution, but it works for now. Maybe find a better way to do this. probe_hook : Callable [[ np . ndarray , np . ndarray ], None ] = None # takes mask and view for testing @ staticmethod def calculate_segmentation ( bbox : np . ndarray , image : np . ndarray , background : np . ndarray , diff_thresh : float , ) -> np . ndarray : \"\"\" Calculates the segmentation error between a view and background image . Args : bbox ( np . ndarray ) : The bounding box of the image , in the format ( x , y , w , h ). image ( np . ndarray ) : The image to calculate segmentation from . background ( np . ndarray ) : The background image . diff_thresh ( float ) : The difference threshold to distinguish foreground and background objects from . Returns : np . ndarray : The segmentation mask . Raises : ValueError : If the image is not grayscale or color . \"\"\" x , y , w , h = bbox assert image . shape [ : 2 ] == ( h , w ) bg_view = background [ y : y + h , x : x + w ] diff = np . abs ( image . astype ( np . int32 ) - bg_view . astype ( np . int32 )). astype ( np . uint8 ) # if images are color, convert to grayscale if diff . ndim == 3 and diff . shape [ 2 ] == 3 : diff = cv . cvtColor ( diff , cv . COLOR_BGR2GRAY ) if diff . ndim != 2 : raise ValueError ( \"Image must be either a gray or a color image.\" ) mask_wrm = diff > diff_thresh return mask_wrm # TODO: VERY FAST FOR ME, INVESTIGATE WHY IT'S SLOW IN THE LAB # TODO: swap the FrameReader to another type. The only requirement is that accessing frame index returns the correct frame. # we should probably use something like ImageLoader, which is implemented in the analysis_experimental. @ staticmethod def calculate_precise ( background : np . ndarray , worm_bboxes : np . ndarray , mic_bboxes : np . ndarray , frame_nums : np . ndarray , worm_reader : FrameReader , diff_thresh : float = 10 , ) -> np . ndarray : \"\"\" Calculates the precise error for each frame in the given sequence . This error is based on precise segmentation of the worm object from the frame , and determining the exact proportion of worm ' s body outside the microscope view . Args : background ( np . ndarray ) : The background image . worm_bboxes : A numpy array of shape ( N , 4 ) representing the bounding boxes of worms . The bounding boxes should be in the format ( x , y , w , h ). mic_bboxes : A numpy array of shape ( N , 4 ) representing the bounding boxes of the microscope . The bounding boxes should be in the format ( x , y , w , h ). frame_nums ( np . ndarray ) : An array of frame numbers to calculate the error for . worm_reader ( FrameReader ) : A frame reader containing segmented worm images for each frame . These worm images should match the shape of the worm bounding boxes . Frames passed in frame_nums are read from this reader by index . 
diff_thresh ( float , optional ) : The difference threshold to distinguish foreground and background objects from . A foreground object is detected if the pixel value difference with the background is greater than this threshold . Returns : np . ndarray : Array of errors of shape ( N ,) representing the precise segmentation error for each frame . Raises : AssertionError : If the length of frame_nums , worm_bboxes , and mic_bboxes do not match . \"\"\" assert frame_nums . ndim == 1 assert len ( frame_nums ) == worm_bboxes . shape [ 0 ] == mic_bboxes . shape [ 0 ] errors = np . zeros ( len ( frame_nums ), dtype = float ) bounds = background . shape [ : 2 ] worm_bboxes , is_legal = BoxUtils . discretize ( worm_bboxes , bounds = bounds , box_format = BoxFormat . XYWH ) mic_bboxes , _ = BoxUtils . discretize ( mic_bboxes , bounds = bounds , box_format = BoxFormat . XYWH ) # filter out illegal bboxes, indicating no prediction or bad prediction. errors [ ~ is_legal ] = np . nan worm_bboxes = worm_bboxes [ is_legal ] mic_bboxes = mic_bboxes [ is_legal ] frame_nums = frame_nums [ is_legal ] # convert to xyxy format for intersection calculation worm_bboxes = BoxConverter . change_format ( worm_bboxes , BoxFormat . XYWH , BoxFormat . XYXY ) mic_bboxes = BoxConverter . change_format ( mic_bboxes , BoxFormat . XYWH , BoxFormat . XYXY ) wrm_left , wrm_top , wrm_right , wrm_bottom = BoxUtils . unpack ( worm_bboxes ) mic_left , mic_top , mic_right , mic_bottom = BoxUtils . unpack ( mic_bboxes ) # calculate intersection of worm and microscope bounding boxes int_left = np . maximum ( wrm_left , mic_left ) int_top = np . maximum ( wrm_top , mic_top ) int_right = np . minimum ( wrm_right , mic_right ) int_bottom = np . minimum ( wrm_bottom , mic_bottom ) int_width = np . maximum ( 0 , int_right - int_left ) int_height = np . maximum ( 0 , int_bottom - int_top ) # shift the intersection to the worm view coordinates int_left -= wrm_left int_top -= wrm_top # pack the intersection bounding boxes and convert to xywh format int_bboxes = BoxUtils . pack ( int_left , int_top , int_width , int_height ) worm_bboxes = BoxConverter . change_format ( worm_bboxes , BoxFormat . XYXY , BoxFormat . XYWH ) mic_bboxes = BoxConverter . change_format ( mic_bboxes , BoxFormat . XYXY , BoxFormat . XYWH ) for i , frame_num in tqdm ( enumerate ( frame_nums ), total = len ( frame_nums ), desc = \"Calculating Error\" , unit = \"fr\" ) : wrm_bbox = worm_bboxes [ i ] int_bbox = int_bboxes [ i ] worm_view = worm_reader [ frame_num ] mask_wrm = ErrorCalculator . calculate_segmentation ( bbox = wrm_bbox , image = worm_view , background = background , diff_thresh = diff_thresh , ) if ErrorCalculator . probe_hook is not None : ErrorCalculator . probe_hook ( worm_view , mask_wrm ) mask_mic = np . zeros_like ( mask_wrm , dtype = bool ) mask_mic [ int_bbox [ 1 ] : int_bbox [ 1 ] + int_bbox [ 3 ], int_bbox [ 0 ] : int_bbox [ 0 ] + int_bbox [ 2 ]] = True total = mask_wrm . sum () if total == 0 : errors [ i ] = 0.0 continue intersection = np . logical_and ( mask_wrm , mask_mic ). sum () error = 1.0 - intersection / total errors [ i ] = error return errors @ staticmethod def calculate_bbox_error ( worm_bboxes : np . ndarray , mic_bboxes : np . ndarray ) -> np . ndarray : \"\"\" Calculate the bounding box error between worm bounding boxes and microscope bounding boxes . This error calculates the proportion of the worm bounding box that is outside the microscope bounding box .
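To make that proportion concrete, here is a minimal NumPy sketch (illustrative only, not part of the library) of the same intersection-over-worm-area computation for a single pair of (x, y, w, h) boxes; all values are fabricated:

import numpy as np

worm = np.array([10, 10, 20, 10])  # worm bbox: x, y, w, h -> area 200
mic = np.array([0, 0, 25, 25])     # microscope view bbox
ix1, iy1 = max(worm[0], mic[0]), max(worm[1], mic[1])
ix2 = min(worm[0] + worm[2], mic[0] + mic[2])
iy2 = min(worm[1] + worm[3], mic[1] + mic[3])
inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)  # 15 * 10 = 150
error = 1.0 - inter / (worm[2] * worm[3])      # 1 - 150/200 = 0.25

An error of 0.25 means a quarter of the worm bounding box lies outside the microscope view.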
Args : worm_bboxes : A numpy array of shape ( N , 4 ) representing the bounding boxes of worms . The bounding boxes should be in the format ( x , y , w , h ). mic_bboxes : A numpy array of shape ( N , 4 ) representing the bounding boxes of the microscope . The bounding boxes should be in the format ( x , y , w , h ). Returns : np . ndarray : Array of errors of shape ( N ,) representing the bounding box error for each pair of worm and microscope bounding boxes . \"\"\" wrm_left , wrm_top , wrm_width , wrm_height = BoxUtils . unpack ( worm_bboxes ) mic_left , mic_top , mic_width , mic_height = BoxUtils . unpack ( mic_bboxes ) wrm_right , wrm_bottom = wrm_left + wrm_width , wrm_top + wrm_height mic_right , mic_bottom = mic_left + mic_width , mic_top + mic_height int_left = np . maximum ( wrm_left , mic_left ) int_top = np . maximum ( wrm_top , mic_top ) int_right = np . minimum ( wrm_right , mic_right ) int_bottom = np . minimum ( wrm_bottom , mic_bottom ) int_width = np . maximum ( 0 , int_right - int_left ) int_height = np . maximum ( 0 , int_bottom - int_top ) intersection = int_width * int_height total = wrm_width * wrm_height errors = 1.0 - intersection / total errors [ total == 0 ] = 0.0 return errors @ staticmethod def calculate_mse_error ( worm_bboxes : np . ndarray , mic_bboxes : np . ndarray ) -> np . ndarray : \"\"\" Calculates the Mean Squared Error ( MSE ) error between the centers of worm bounding boxes and microscope bounding boxes . Args : worm_bboxes : A numpy array of shape ( N , 4 ) representing the bounding boxes of worms . The bounding boxes should be in the format ( x , y , w , h ). mic_bboxes : A numpy array of shape ( N , 4 ) representing the bounding boxes of the microscope . The bounding boxes should be in the format ( x , y , w , h ). Returns : np . ndarray : Array of errors of shape ( N ,) representing the MSE error for each pair of worm and microscope bounding boxes . \"\"\" worm_centers = BoxUtils . center ( worm_bboxes ) mic_centers = BoxUtils . center ( mic_bboxes ) errors = np . mean (( worm_centers - mic_centers ) ** 2 , axis = 1 ) return errors","title":"ErrorCalculator"},{"location":"reference/wtracker/eval/error_calculator/#class-variables","text":"probe_hook","title":"Class variables"},{"location":"reference/wtracker/eval/error_calculator/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/eval/error_calculator/#calculate_bbox_error","text":"def calculate_bbox_error ( worm_bboxes : numpy . ndarray , mic_bboxes : numpy . ndarray ) -> numpy . ndarray Calculate the bounding box error between worm bounding boxes and microscope bounding boxes. This error calculates the proportion of the worm bounding box that is outside the microscope bounding box. Parameters: Name Type Description Default worm_bboxes None A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). None mic_bboxes None A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). None Returns: Type Description np.ndarray Array of errors of shape (N,) representing the bounding box error for each pair of worm and microscope bounding boxes. View Source @staticmethod def calculate_bbox_error ( worm_bboxes : np . ndarray , mic_bboxes : np . ndarray ) -> np . ndarray : \"\"\" Calculate the bounding box error between worm bounding boxes and microscope bounding boxes. 
This error calculates the proportion of the worm bounding box that is outside the microscope bounding box. Args: worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). Returns: np.ndarray: Array of errors of shape (N,) representing the bounding box error for each pair of worm and microscope bounding boxes. \"\"\" wrm_left , wrm_top , wrm_width , wrm_height = BoxUtils . unpack ( worm_bboxes ) mic_left , mic_top , mic_width , mic_height = BoxUtils . unpack ( mic_bboxes ) wrm_right , wrm_bottom = wrm_left + wrm_width , wrm_top + wrm_height mic_right , mic_bottom = mic_left + mic_width , mic_top + mic_height int_left = np . maximum ( wrm_left , mic_left ) int_top = np . maximum ( wrm_top , mic_top ) int_right = np . minimum ( wrm_right , mic_right ) int_bottom = np . minimum ( wrm_bottom , mic_bottom ) int_width = np . maximum ( 0 , int_right - int_left ) int_height = np . maximum ( 0 , int_bottom - int_top ) intersection = int_width * int_height total = wrm_width * wrm_height errors = 1.0 - intersection / total errors [ total == 0 ] = 0.0 return errors","title":"calculate_bbox_error"},{"location":"reference/wtracker/eval/error_calculator/#calculate_mse_error","text":"def calculate_mse_error ( worm_bboxes : numpy . ndarray , mic_bboxes : numpy . ndarray ) -> numpy . ndarray Calculates the Mean Squared Error (MSE) error between the centers of worm bounding boxes and microscope bounding boxes. Parameters: Name Type Description Default worm_bboxes None A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). None mic_bboxes None A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). None Returns: Type Description np.ndarray Array of errors of shape (N,) representing the MSE error for each pair of worm and microscope bounding boxes. View Source @staticmethod def calculate_mse_error ( worm_bboxes : np . ndarray , mic_bboxes : np . ndarray ) -> np . ndarray : \"\"\" Calculates the Mean Squared Error (MSE) error between the centers of worm bounding boxes and microscope bounding boxes. Args: worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). Returns: np.ndarray: Array of errors of shape (N,) representing the MSE error for each pair of worm and microscope bounding boxes. \"\"\" worm_centers = BoxUtils . center ( worm_bboxes ) mic_centers = BoxUtils . center ( mic_bboxes ) errors = np . mean (( worm_centers - mic_centers ) ** 2 , axis = 1 ) return errors","title":"calculate_mse_error"},{"location":"reference/wtracker/eval/error_calculator/#calculate_precise","text":"def calculate_precise ( background : numpy . ndarray , worm_bboxes : numpy . ndarray , mic_bboxes : numpy . ndarray , frame_nums : numpy . ndarray , worm_reader : wtracker . utils . frame_reader . FrameReader , diff_thresh : float = 10 ) -> numpy . ndarray Calculates the precise error for each frame in the given sequence. 
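Hypothetical usage of the two array-level metrics documented above; the box values are fabricated, and the shapes follow the (N, 4) XYWH convention from the docstrings:

import numpy as np
from wtracker.eval.error_calculator import ErrorCalculator

worm_bboxes = np.array([[10, 10, 20, 10], [40, 40, 10, 10]], dtype=float)  # (N, 4), XYWH
mic_bboxes = np.array([[0, 0, 25, 25], [40, 40, 10, 10]], dtype=float)

bbox_err = ErrorCalculator.calculate_bbox_error(worm_bboxes, mic_bboxes)  # [0.25, 0.0]
mse_err = ErrorCalculator.calculate_mse_error(worm_bboxes, mic_bboxes)    # mean of squared per-axis center offsets

Both calls return an array of shape (N,), one error value per frame.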
This error is based on precise segmentation of the worm object from the frame, and determining the exact proportion of worm's body outside the microscope view. Parameters: Name Type Description Default background np.ndarray The background image. None worm_bboxes None A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). None mic_bboxes None A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). None frame_nums np.ndarray An array of frame numbers to calculate the error for. None worm_reader FrameReader A frame reader containing segmented worm images for each frame. These worm images should match the shape of the worm bounding boxes. Frames passed in frame_nums are read from this reader by index. None diff_thresh float The difference threshold to distinguish foreground and background objects from. A foreground object is detected if the pixel value difference with the background is greater than this threshold. None Returns: Type Description np.ndarray Array of errors of shape (N,) representing the precise segmentation error for each frame. Raises: Type Description AssertionError If the length of frame_nums, worm_bboxes, and mic_bboxes do not match. View Source @staticmethod def calculate_precise ( background : np . ndarray , worm_bboxes : np . ndarray , mic_bboxes : np . ndarray , frame_nums : np . ndarray , worm_reader : FrameReader , diff_thresh : float = 10 , ) -> np . ndarray : \"\"\" Calculates the precise error for each frame in the given sequence. This error is based on precise segmentation of the worm object from the frame, and determining the exact proportion of worm's body outside the microscope view. Args: background (np.ndarray): The background image. worm_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of worms. The bounding boxes should be in the format (x, y, w, h). mic_bboxes: A numpy array of shape (N, 4) representing the bounding boxes of the microscope. The bounding boxes should be in the format (x, y, w, h). frame_nums (np.ndarray): An array of frame numbers to calculate the error for. worm_reader (FrameReader): A frame reader containing segmented worm images for each frame. These worm images should match the shape of the worm bounding boxes. Frames passed in frame_nums are read from this reader by index. diff_thresh (float, optional): The difference threshold to distinguish foreground and background objects from. A foreground object is detected if the pixel value difference with the background is greater than this threshold. Returns: np.ndarray: Array of errors of shape (N,) representing the precise segmentation error for each frame. Raises: AssertionError: If the length of frame_nums, worm_bboxes, and mic_bboxes do not match. \"\"\" assert frame_nums . ndim == 1 assert len ( frame_nums ) == worm_bboxes . shape [ 0 ] == mic_bboxes . shape [ 0 ] errors = np . zeros ( len ( frame_nums ), dtype = float ) bounds = background . shape [ :2 ] worm_bboxes , is_legal = BoxUtils . discretize ( worm_bboxes , bounds = bounds , box_format = BoxFormat . XYWH ) mic_bboxes , _ = BoxUtils . discretize ( mic_bboxes , bounds = bounds , box_format = BoxFormat . XYWH ) # filter out illegal bboxes , indicating no prediction or bad prediction . errors [ ~is_legal ] = np .
nan worm_bboxes = worm_bboxes [ is_legal ] mic_bboxes = mic_bboxes [ is_legal ] frame_nums = frame_nums [ is_legal ] # convert to xyxy format for intersection calculation worm_bboxes = BoxConverter . change_format ( worm_bboxes , BoxFormat . XYWH , BoxFormat . XYXY ) mic_bboxes = BoxConverter . change_format ( mic_bboxes , BoxFormat . XYWH , BoxFormat . XYXY ) wrm_left , wrm_top , wrm_right , wrm_bottom = BoxUtils . unpack ( worm_bboxes ) mic_left , mic_top , mic_right , mic_bottom = BoxUtils . unpack ( mic_bboxes ) # calculate intersection of worm and microscope bounding boxes int_left = np . maximum ( wrm_left , mic_left ) int_top = np . maximum ( wrm_top , mic_top ) int_right = np . minimum ( wrm_right , mic_right ) int_bottom = np . minimum ( wrm_bottom , mic_bottom ) int_width = np . maximum ( 0 , int_right - int_left ) int_height = np . maximum ( 0 , int_bottom - int_top ) # shift the intersection to the worm view coordinates int_left -= wrm_left int_top -= wrm_top # pack the intersection bounding boxes and convert to xywh format int_bboxes = BoxUtils . pack ( int_left , int_top , int_width , int_height ) worm_bboxes = BoxConverter . change_format ( worm_bboxes , BoxFormat . XYXY , BoxFormat . XYWH ) mic_bboxes = BoxConverter . change_format ( mic_bboxes , BoxFormat . XYXY , BoxFormat . XYWH ) for i , frame_num in tqdm ( enumerate ( frame_nums ), total = len ( frame_nums ), desc = \"Calculating Error\" , unit = \"fr\" ) : wrm_bbox = worm_bboxes [ i ] int_bbox = int_bboxes [ i ] worm_view = worm_reader [ frame_num ] mask_wrm = ErrorCalculator . calculate_segmentation ( bbox = wrm_bbox , image = worm_view , background = background , diff_thresh = diff_thresh , ) if ErrorCalculator . probe_hook is not None : ErrorCalculator . probe_hook ( worm_view , mask_wrm ) mask_mic = np . zeros_like ( mask_wrm , dtype = bool ) mask_mic [ int_bbox[1 ] : int_bbox [ 1 ] + int_bbox [ 3 ] , int_bbox [ 0 ] : int_bbox [ 0 ] + int_bbox [ 2 ] ] = True total = mask_wrm . sum () if total == 0 : errors [ i ] = 0.0 continue intersection = np . logical_and ( mask_wrm , mask_mic ). sum () error = 1.0 - intersection / total errors [ i ] = error return errors","title":"calculate_precise"},{"location":"reference/wtracker/eval/error_calculator/#calculate_segmentation","text":"def calculate_segmentation ( bbox : numpy . ndarray , image : numpy . ndarray , background : numpy . ndarray , diff_thresh : float ) -> numpy . ndarray Calculates the segmentation error between a view and background image. Parameters: Name Type Description Default bbox np.ndarray The bounding box of the image, in the format (x, y, w, h). None image np.ndarray The image to calculate segmentation from. None background np.ndarray The background image. None diff_thresh float The difference threshold to distinguish foreground and background objects from. None Returns: Type Description np.ndarray The segmentation mask. Raises: Type Description ValueError If the image is not grayscale or color. View Source @ staticmethod def calculate_segmentation ( bbox : np . ndarray , image : np . ndarray , background : np . ndarray , diff_thresh : float , ) -> np . ndarray : \"\"\" Calculates the segmentation error between a view and background image . Args : bbox ( np . ndarray ) : The bounding box of the image , in the format ( x , y , w , h ). image ( np . ndarray ) : The image to calculate segmentation from . background ( np . ndarray ) : The background image . 
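A self-contained sketch of calling calculate_segmentation on synthetic data; the arrays are fabricated for illustration, and the bbox must describe where the view was cropped from the background:

import numpy as np
from wtracker.eval.error_calculator import ErrorCalculator

background = np.zeros((100, 100), dtype=np.uint8)  # uniform dark background
view = background[20:40, 30:60].copy()             # crop at bbox (x=30, y=20, w=30, h=20)
view[5:15, 5:25] = 200                             # bright blob standing in for the worm

mask = ErrorCalculator.calculate_segmentation(
    bbox=np.array([30, 20, 30, 20]),
    image=view,
    background=background,
    diff_thresh=10,
)
print(mask.sum())  # 200 pixels classified as worm

Note that the view's shape must equal (h, w) of the bbox, as the assertion in the source enforces.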
diff_thresh ( float ) : The difference threshold to distinguish foreground and background objects from . Returns : np . ndarray : The segmentation mask . Raises : ValueError : If the image is not grayscale or color . \"\"\" x , y , w , h = bbox assert image . shape [ : 2 ] == ( h , w ) bg_view = background [ y : y + h , x : x + w ] diff = np . abs ( image . astype ( np . int32 ) - bg_view . astype ( np . int32 )). astype ( np . uint8 ) # if images are color, convert to grayscale if diff . ndim == 3 and diff . shape [ 2 ] == 3 : diff = cv . cvtColor ( diff , cv . COLOR_BGR2GRAY ) if diff . ndim != 2 : raise ValueError ( \"Image must be either a gray or a color image.\" ) mask_wrm = diff > diff_thresh return mask_wrm","title":"calculate_segmentation"},{"location":"reference/wtracker/eval/plotter/","text":"Module wtracker.eval.plotter View Source from __future__ import annotations import pandas as pd import seaborn as sns from typing import Callable class Plotter : \"\"\" A class for plotting experiment log data. The experiment data was previously analyzed by the DataAnalyzer class. Supports analysis of multiple logs at once. Args: data_list (list[pd.DataFrame]): A list of dataframes, each holding the data of a single experiment log. plot_height (int, optional): The height of the plot. palette (str, optional): The color palette to use for the plots. \"\"\" def __init__ ( self , data_list : list [ pd . DataFrame ], plot_height : int = 7 , palette : str = \"viridis\" , ) -> None : self . plot_height = plot_height self . palette = palette for i , data in enumerate ( data_list ): data [ \"log_num\" ] = i self . data = pd . concat ([ d for d in data_list ], ignore_index = True ) def _get_error_column ( self , error_kind : str ) -> str : if error_kind == \"bbox\" : return \"bbox_error\" elif error_kind == \"dist\" : return \"worm_deviation\" elif error_kind == \"precise\" : return \"precise_error\" else : raise ValueError ( f \"Invalid error kind: { error_kind } \" ) def plot_speed ( self , log_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the speed distribution of the worm. Args: log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. \"\"\" return self . create_distplot ( x_col = \"wrm_speed\" , x_label = \"speed\" , title = \"Worm Speed Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , kde = True , ** kwargs , ) def plot_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , cycle_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the error distribution. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". log_wise (bool, optional): Whether to plot each log separately. cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". 
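A minimal, hypothetical construction of the Plotter class documented here; the DataFrames are fabricated stand-ins for logs analyzed by DataAnalyzer, and the column name "wrm_speed" follows the docstrings:

import pandas as pd
from wtracker.eval.plotter import Plotter

log0 = pd.DataFrame({"wrm_speed": [1.0, 2.0, 1.5]})
log1 = pd.DataFrame({"wrm_speed": [0.5, 0.8, 1.2]})

plotter = Plotter([log0, log1], plot_height=7, palette="viridis")
grid = plotter.plot_speed(log_wise=True)  # one facet per log, via the "log_num" column

Be aware that the constructor mutates each input DataFrame in place by adding a "log_num" column before concatenating them.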
**kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = self . data . groupby ([ \"log_num\" , \"cycle\" ])[ error_col ] . max () . reset_index () return self . create_distplot ( x_col = error_col , x_label = f \" { error_kind } error\" , title = f \" { error_kind } Error Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , data = data , ** kwargs , ) def plot_cycle_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"boxen\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the error as a function of the cycle step. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". **kwargs: Additional keyword arguments to pass the `Plotter.create_catplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) return self . create_catplot ( x_col = \"cycle_step\" , y_col = error_col , x_label = \"cycle step\" , y_label = f \" { error_kind } error\" , title = f \" { error_kind } error as function of cycle step\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , ** kwargs , ) def plot_speed_vs_error ( self , error_kind : str = \"bbox\" , cycle_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the speed of the worm vs the error. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = ( self . data . groupby ([ \"log_num\" , \"cycle\" ])[[ error_col , \"wrm_speed\" ]] . aggregate ({ error_col : \"max\" , \"wrm_speed\" : \"mean\" }) . reset_index () ) return self . create_jointplot ( x_col = \"wrm_speed\" , y_col = error_col , plot_kind = kind , x_label = \"speed\" , y_label = f \" { error_kind } error\" , title = f \"Speed vs { error_kind } Error\" , condition = condition , data = data , ** kwargs , ) def plot_trajectory ( self , hue_col = \"log_num\" , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the trajectory of the worm. Args: hue_col (str, optional): The column to use for coloring the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" plot = self . 
create_jointplot ( x_col = \"wrm_center_x\" , y_col = \"wrm_center_y\" , x_label = \"X\" , y_label = \"Y\" , title = \"Worm Trajectory\" , hue_col = hue_col , plot_kind = \"scatter\" , alpha = 1 , linewidth = 0 , condition = condition , ** kwargs , ) plot . ax_marg_x . remove () plot . ax_marg_y . remove () plot . ax_joint . invert_yaxis () return plot def plot_head_size ( self , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the size of the worm head. Args: condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" return self . create_jointplot ( x_col = \"wrm_w\" , y_col = \"wrm_h\" , x_label = \"width\" , y_label = \"height\" , title = \"Worm Head Size\" , plot_kind = plot_kind , condition = condition , ** kwargs , ) def create_distplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"hist\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a distribution plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.displot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"hist\" , \"kde\" , \"ecdf\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . displot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{ col_name }} :: { title . title () } \" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_catplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"strip\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . 
DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a categorical plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"strip\" , \"box\" , \"violin\" , \"boxen\" , \"bar\" , \"count\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . catplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{ col_name }} :: { title . title () } \" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_jointplot ( self , x_col : str , y_col : str , hue_col : str = None , plot_kind : str = \"scatter\" , x_label : str = \"\" , y_label : str = \"\" , title : str = \"\" , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Create a joint plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. plot_kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" assert plot_kind in [ \"scatter\" , \"kde\" , \"hist\" , \"hex\" , \"reg\" , \"resid\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . 
palette if hue_col is not None else None plot = sns . jointplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , kind = plot_kind , height = self . plot_height , palette = palette , marginal_kws = dict ( palette = palette ), ** kwargs , ) plot . set_axis_labels ( x_label . capitalize (), y_label . capitalize ()) plot . figure . suptitle ( title . title ()) plot . figure . tight_layout () return plot Classes Plotter class Plotter ( data_list : 'list[pd.DataFrame]' , plot_height : 'int' = 7 , palette : 'str' = 'viridis' ) A class for plotting experiment log data. The experiment data was previously analyzed by the DataAnalyzer class. Supports analysis of multiple logs at once. Attributes Name Type Description Default data_list list[pd.DataFrame] A list of dataframes, each holding the data of a single experiment log. None plot_height int The height of the plot. None palette str The color palette to use for the plots. None View Source class Plotter : \"\"\" A class for plotting experiment log data. The experiment data was previously analyzed by the DataAnalyzer class. Supports analysis of multiple logs at once. Args: data_list (list[pd.DataFrame]): A list of dataframes, each holding the data of a single experiment log. plot_height (int, optional): The height of the plot. palette (str, optional): The color palette to use for the plots. \"\"\" def __init__ ( self , data_list : list [ pd.DataFrame ] , plot_height : int = 7 , palette : str = \"viridis\" , ) -> None : self . plot_height = plot_height self . palette = palette for i , data in enumerate ( data_list ) : data [ \"log_num\" ] = i self . data = pd . concat ( [ d for d in data_list ] , ignore_index = True ) def _get_error_column ( self , error_kind : str ) -> str : if error_kind == \"bbox\" : return \"bbox_error\" elif error_kind == \"dist\" : return \"worm_deviation\" elif error_kind == \"precise\" : return \"precise_error\" else : raise ValueError ( f \"Invalid error kind: {error_kind}\" ) def plot_speed ( self , log_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the speed distribution of the worm. Args: log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. \"\"\" return self . create_distplot ( x_col = \"wrm_speed\" , x_label = \"speed\" , title = \"Worm Speed Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , kde = True , ** kwargs , ) def plot_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , cycle_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the error distribution. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". log_wise (bool, optional): Whether to plot each log separately. cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". 
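Hypothetical calls to the error-plotting methods defined here, using a fabricated log that carries the columns the methods expect per the docstrings ("cycle", "cycle_step", "bbox_error"):

import pandas as pd
from wtracker.eval.plotter import Plotter

log = pd.DataFrame({
    "cycle": [0, 0, 1, 1],
    "cycle_step": [0, 1, 0, 1],
    "bbox_error": [0.0, 0.4, 0.1, 0.2],
})
plotter = Plotter([log])
plotter.plot_error(error_kind="bbox", plot_kind="ecdf")
plotter.plot_error(error_kind="bbox", cycle_wise=True)        # reduces to the max error per cycle
plotter.plot_cycle_error(error_kind="bbox", plot_kind="boxen")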
**kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = self . data . groupby ( [ \"log_num\", \"cycle\" ] ) [ error_col ] . max (). reset_index () return self . create_distplot ( x_col = error_col , x_label = f \"{error_kind} error\" , title = f \"{error_kind} Error Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , data = data , ** kwargs , ) def plot_cycle_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"boxen\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the error as a function of the cycle step. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" strip \", \" box \", \" violin \", \" boxen \", \" bar \", or \" count \". **kwargs: Additional keyword arguments to pass the `Plotter.create_catplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) return self . create_catplot ( x_col = \"cycle_step\" , y_col = error_col , x_label = \"cycle step\" , y_label = f \"{error_kind} error\" , title = f \"{error_kind} error as function of cycle step\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , ** kwargs , ) def plot_speed_vs_error ( self , error_kind : str = \"bbox\" , cycle_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the speed of the worm vs the error. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. kind (str, optional): The kind of plot to create. Can be \" scatter \", \" kde \", \" hist \", \" hex \", \" reg \", or \" resid \". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = ( self . data . groupby ( [ \"log_num\", \"cycle\" ] ) [ [error_col, \"wrm_speed\" ] ] . aggregate ( { error_col : \"max\" , \"wrm_speed\" : \"mean\" } ) . reset_index () ) return self . create_jointplot ( x_col = \"wrm_speed\" , y_col = error_col , plot_kind = kind , x_label = \"speed\" , y_label = f \"{error_kind} error\" , title = f \"Speed vs {error_kind} Error\" , condition = condition , data = data , ** kwargs , ) def plot_trajectory ( self , hue_col = \"log_num\" , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the trajectory of the worm. Args: hue_col (str, optional): The column to use for coloring the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" plot = self . 
create_jointplot ( x_col = \"wrm_center_x\" , y_col = \"wrm_center_y\" , x_label = \"X\" , y_label = \"Y\" , title = \"Worm Trajectory\" , hue_col = hue_col , plot_kind = \"scatter\" , alpha = 1 , linewidth = 0 , condition = condition , ** kwargs , ) plot . ax_marg_x . remove () plot . ax_marg_y . remove () plot . ax_joint . invert_yaxis () return plot def plot_head_size ( self , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the size of the worm head. Args: condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" scatter \", \" kde \", \" hist \", \" hex \", \" reg \", or \" resid \". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" return self . create_jointplot ( x_col = \"wrm_w\" , y_col = \"wrm_h\" , x_label = \"width\" , y_label = \"height\" , title = \"Worm Head Size\" , plot_kind = plot_kind , condition = condition , ** kwargs , ) def create_distplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"hist\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , transform : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a distribution plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.displot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"hist\", \"kde\", \"ecdf\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition(data) ] palette = self . palette if hue_col is not None else None plot = sns . displot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_catplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"strip\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [ [pd.DataFrame ] , pd . 
DataFrame ] = None , transform : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a categorical plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \" strip \", \" box \", \" violin \", \" boxen \", \" bar \", or \" count \". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", \"count\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition(data) ] palette = self . palette if hue_col is not None else None plot = sns . catplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_jointplot ( self , x_col : str , y_col : str , hue_col : str = None , plot_kind : str = \"scatter\" , x_label : str = \"\" , y_label : str = \"\" , title : str = \"\" , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , transform : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Create a joint plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. plot_kind (str, optional): The kind of plot to create. Can be \" scatter \", \" kde \", \" hist \", \" hex \", \" reg \", or \" resid \". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" assert plot_kind in [ \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", \"resid\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition(data) ] palette = self . 
palette if hue_col is not None else None plot = sns . jointplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , kind = plot_kind , height = self . plot_height , palette = palette , marginal_kws = dict ( palette = palette ), ** kwargs , ) plot . set_axis_labels ( x_label . capitalize (), y_label . capitalize ()) plot . figure . suptitle ( title . title ()) plot . figure . tight_layout () return plot Methods create_catplot def create_catplot ( self , x_col : 'str' , y_col : 'str' = None , hue_col : 'str' = None , log_wise : 'bool' = False , plot_kind : 'str' = 'strip' , x_label : 'str' = '' , y_label : 'str' = '' , title : 'str | None' = None , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , transform : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , data : 'pd.DataFrame' = None , ** kwargs ) -> 'sns.FacetGrid' Create a categorical plot from the data. Parameters: Name Type Description Default x_col str The column to plot on the x-axis. None y_col str The column to plot on the y-axis. None hue_col str The column to use for coloring the plot. None log_wise bool Whether to plot each log separately. None plot_kind str The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". None x_label str The x-axis label. None y_label str The y-axis label. None title str The title of the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None transform Callable[[pd.DataFrame], pd.DataFrame] A function to transform the data. None data pd.DataFrame Custom data to plot from. If None, the data passed to the constructor of the class is used. None **kwargs None Additional keyword arguments to pass to the seaborn.catplot function. None Returns: Type Description sns.FacetGrid The plot object. View Source def create_catplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"strip\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a categorical plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"strip\" , \"box\" , \"violin\" , \"boxen\" , \"bar\" , \"count\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . 
catplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot create_distplot def create_distplot ( self , x_col : 'str' , y_col : 'str' = None , hue_col : 'str' = None , log_wise : 'bool' = False , plot_kind : 'str' = 'hist' , x_label : 'str' = '' , y_label : 'str' = '' , title : 'str | None' = None , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , transform : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , data : 'pd.DataFrame' = None , ** kwargs ) -> 'sns.FacetGrid' Create a distribution plot from the data. Parameters: Name Type Description Default x_col str The column to plot on the x-axis. None y_col str The column to plot on the y-axis. None hue_col str The column to use for coloring the plot. None log_wise bool Whether to plot each log separately. None plot_kind str The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". None x_label str The x-axis label. None y_label str The y-axis label. None title str The title of the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None transform Callable[[pd.DataFrame], pd.DataFrame] A function to transform the data. None data pd.DataFrame Custom data to plot from. If None, the data passed to the constructor of the class is used. None **kwargs None Additional keyword arguments to pass to the seaborn.displot function. None Returns: Type Description sns.FacetGrid The plot object. View Source def create_distplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"hist\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a distribution plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.displot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"hist\" , \"kde\" , \"ecdf\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . displot ( data = data . 
dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot create_jointplot def create_jointplot ( self , x_col : 'str' , y_col : 'str' , hue_col : 'str' = None , plot_kind : 'str' = 'scatter' , x_label : 'str' = '' , y_label : 'str' = '' , title : 'str' = '' , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , transform : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , data : 'pd.DataFrame' = None , ** kwargs ) -> 'sns.JointGrid' Create a joint plot from the data. Parameters: Name Type Description Default x_col str The column to plot on the x-axis. None y_col str The column to plot on the y-axis. None hue_col str The column to use for coloring the plot. None plot_kind str The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". None x_label str The x-axis label. None y_label str The y-axis label. None title str The title of the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None transform Callable[[pd.DataFrame], pd.DataFrame] A function to transform the data. None data pd.DataFrame Custom data to plot from. If None, the data passed to the constructor of the class is used. None **kwargs None Additional keyword arguments to pass to the seaborn.jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def create_jointplot ( self , x_col : str , y_col : str , hue_col : str = None , plot_kind : str = \"scatter\" , x_label : str = \"\" , y_label : str = \"\" , title : str = \"\" , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Create a joint plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. plot_kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" assert plot_kind in [ \"scatter\" , \"kde\" , \"hist\" , \"hex\" , \"reg\" , \"resid\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . jointplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , kind = plot_kind , height = self . 
plot_height , palette = palette , marginal_kws = dict ( palette = palette ), ** kwargs , ) plot . set_axis_labels ( x_label . capitalize (), y_label . capitalize ()) plot . figure . suptitle ( title . title ()) plot . figure . tight_layout () return plot plot_cycle_error def plot_cycle_error ( self , error_kind : 'str' = 'bbox' , log_wise : 'bool' = False , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , plot_kind : 'str' = 'boxen' , ** kwargs ) -> 'sns.JointGrid' Plot the error as a function of the cycle step. Parameters: Name Type Description Default error_kind str The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". None log_wise bool Whether to plot each log separately. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". None **kwargs None Additional keyword arguments to pass the Plotter.create_catplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def plot_cycle_error( self, error_kind: str = \"bbox\", log_wise: bool = False, condition: Callable[[pd.DataFrame], pd.DataFrame] = None, plot_kind: str = \"boxen\", **kwargs, ) -> sns.JointGrid: \"\"\" Plot the error as a function of the cycle step. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". **kwargs: Additional keyword arguments to pass the `Plotter.create_catplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self._get_error_column(error_kind) return self.create_catplot( x_col=\"cycle_step\", y_col=error_col, x_label=\"cycle step\", y_label=f\"{error_kind} error\", title=f\"{error_kind} error as function of cycle step\", plot_kind=plot_kind, log_wise=log_wise, condition=condition, **kwargs, ) plot_error def plot_error ( self , error_kind : 'str' = 'bbox' , log_wise : 'bool' = False , cycle_wise : 'bool' = False , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , plot_kind : 'str' = 'hist' , ** kwargs ) -> 'sns.FacetGrid' Plot the error distribution. Parameters: Name Type Description Default error_kind str The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". None log_wise bool Whether to plot each log separately. None cycle_wise bool Whether to plot each cycle separately. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". None **kwargs None Additional keyword arguments to pass the Plotter.create_distplot function. None Returns: Type Description sns.FacetGrid The plot object. View Source def plot_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , cycle_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the error distribution. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". log_wise (bool, optional): Whether to plot each log separately. cycle_wise (bool, optional): Whether to plot each cycle separately. 
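One subtlety worth illustrating: although `condition` is annotated as Callable[[pd.DataFrame], pd.DataFrame], the implementations above apply it as data[condition(data)], so in practice a callable returning a boolean mask is what filters rows. A short sketch with fabricated data:

import pandas as pd
from wtracker.eval.plotter import Plotter

log = pd.DataFrame({"wrm_speed": [0.5, 1.5, 2.0], "bbox_error": [0.0, 0.2, 0.6]})
plotter = Plotter([log])

# Keep only fast frames; the lambda returns a boolean Series used for row selection.
plotter.plot_error(error_kind="bbox", condition=lambda df: df["wrm_speed"] > 1.0)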
condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = self . data . groupby ( [ \"log_num\", \"cycle\" ] ) [ error_col ] . max (). reset_index () return self . create_distplot ( x_col = error_col , x_label = f \"{error_kind} error\" , title = f \"{error_kind} Error Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , data = data , ** kwargs , ) plot_head_size def plot_head_size ( self , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , plot_kind : 'str' = 'hist' , ** kwargs ) -> 'sns.JointGrid' Plot the size of the worm head. Parameters: Name Type Description Default condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". None **kwargs None Additional keyword arguments to pass the Plotter.create_jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def plot_head_size ( self , condition: Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind: str = \"hist\" , ** kwargs , ) -> sns . JointGrid: \"\"\" Plot the size of the worm head . Args: condition ( Callable [[ pd . DataFrame ], pd . DataFrame ], optional ) : A function to filter the data . plot_kind ( str , optional ) : The kind of plot to create . Can be \"scatter\" , \"kde\" , \"hist\" , \"hex\" , \"reg\" , or \"resid\" . ** kwargs: Additional keyword arguments to pass the `Plotter . create_jointplot ` function . Returns: sns . JointGrid: The plot object . \"\"\" return self . create_jointplot ( x_col = \"wrm_w\" , y_col = \"wrm_h\" , x_label = \"width\" , y_label = \"height\" , title = \"Worm Head Size\" , plot_kind = plot_kind , condition = condition , ** kwargs , ) plot_speed def plot_speed ( self , log_wise : 'bool' = False , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , plot_kind : 'str' = 'hist' , ** kwargs ) -> 'sns.FacetGrid' Plot the speed distribution of the worm. Parameters: Name Type Description Default log_wise bool Whether to plot each log separately. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". None **kwargs None Additional keyword arguments to pass the Plotter.create_distplot function. None View Source def plot_speed( self, log_wise: bool = False, condition: Callable[[pd.DataFrame], pd.DataFrame] = None, plot_kind: str = \"hist\", **kwargs, ) -> sns.FacetGrid: \"\"\" Plot the speed distribution of the worm. Args: log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. 
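The plot_* methods above are thin wrappers around seaborn figure-level functions, so a typical session builds one Plotter over the analyzed logs and calls the helpers. Below is a minimal sketch; the CSV paths are hypothetical, and the inputs are assumed to be analyzed-log DataFrames (DataAnalyzer output) that already contain the wrm_speed, wrm_w and wrm_h columns.

import pandas as pd
from wtracker.eval.plotter import Plotter

# Hypothetical paths to two analyzed experiment logs; replace with your own.
data_list = [
    pd.read_csv("logs/exp1_analyzed.csv"),
    pd.read_csv("logs/exp2_analyzed.csv"),
]

plotter = Plotter(data_list, plot_height=7, palette="viridis")

# Speed distribution pooled over all logs (plot_speed always adds a KDE overlay).
grid = plotter.plot_speed(plot_kind="hist")

# Same distribution, one facet per log (faceting is done on the log_num column).
grid = plotter.plot_speed(log_wise=True)

# Per-cycle worst-case error distribution (max error within each cycle).
grid = plotter.plot_error(error_kind="precise", cycle_wise=True, plot_kind="ecdf")

# Joint plot of worm head width vs. height.
jg = plotter.plot_head_size(plot_kind="hex")

Note that the constructor tags each input frame with a log_num column in place, which is what log_wise faceting and the trajectory hue rely on.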
\"\"\" return self.create_distplot( x_col=\"wrm_speed\", x_label=\"speed\", title=\"Worm Speed Distribution\", plot_kind=plot_kind, log_wise=log_wise, condition=condition, kde=True, **kwargs, ) plot_speed_vs_error def plot_speed_vs_error ( self , error_kind : 'str' = 'bbox' , cycle_wise : 'bool' = False , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , kind : 'str' = 'hist' , ** kwargs ) -> 'sns.JointGrid' Plot the speed of the worm vs the error. Parameters: Name Type Description Default error_kind str The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". None cycle_wise bool Whether to plot each cycle separately. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None kind str The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". None **kwargs None Additional keyword arguments to pass the Plotter.create_jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def plot_speed_vs_error ( self , error_kind: str = \"bbox\" , cycle_wise: bool = False , condition: Callable [[ pd . DataFrame ], pd . DataFrame ] = None , kind: str = \"hist\" , ** kwargs , ) -> sns . JointGrid: \"\"\" Plot the speed of the worm vs the error . Args: error_kind ( str , optional ) : The kind of error to plot . Can be \"bbox\" , \"dist\" , or \"precise\" . cycle_wise ( bool , optional ) : Whether to plot each cycle separately . condition ( Callable [[ pd . DataFrame ], pd . DataFrame ], optional ) : A function to filter the data . kind ( str , optional ) : The kind of plot to create . Can be \"scatter\" , \"kde\" , \"hist\" , \"hex\" , \"reg\" , or \"resid\" . ** kwargs: Additional keyword arguments to pass the `Plotter . create_jointplot ` function . Returns: sns . JointGrid: The plot object . \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise: data = ( self . data . groupby ([ \"log_num\" , \"cycle\" ])[[ error_col , \"wrm_speed\" ]] . aggregate ({ error_col: \"max\" , \"wrm_speed\" : \"mean\" }) . reset_index () ) return self . create_jointplot ( x_col = \"wrm_speed\" , y_col = error_col , plot_kind = kind , x_label = \"speed\" , y_label = f \"{error_kind} error\" , title = f \"Speed vs {error_kind} Error\" , condition = condition , data = data , ** kwargs , ) plot_trajectory def plot_trajectory ( self , hue_col = 'log_num' , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , ** kwargs ) -> 'sns.JointGrid' Plot the trajectory of the worm. Parameters: Name Type Description Default hue_col str The column to use for coloring the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None **kwargs None Additional keyword arguments to pass the Plotter.create_jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def plot_trajectory( self, hue_col=\"log_num\", condition: Callable[[pd.DataFrame], pd.DataFrame] = None, **kwargs, ) -> sns.JointGrid: \"\"\" Plot the trajectory of the worm. Args: hue_col (str, optional): The column to use for coloring the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. 
\"\"\" plot = self.create_jointplot( x_col=\"wrm_center_x\", y_col=\"wrm_center_y\", x_label=\"X\", y_label=\"Y\", title=\"Worm Trajectory\", hue_col=hue_col, plot_kind=\"scatter\", alpha=1, linewidth=0, condition=condition, **kwargs, ) plot.ax_marg_x.remove() plot.ax_marg_y.remove() plot.ax_joint.invert_yaxis() return plot","title":"Plotter"},{"location":"reference/wtracker/eval/plotter/#module-wtrackerevalplotter","text":"View Source from __future__ import annotations import pandas as pd import seaborn as sns from typing import Callable class Plotter : \"\"\" A class for plotting experiment log data. The experiment data was previously analyzed by the DataAnalyzer class. Supports analysis of multiple logs at once. Args: data_list (list[pd.DataFrame]): A list of dataframes, each holding the data of a single experiment log. plot_height (int, optional): The height of the plot. palette (str, optional): The color palette to use for the plots. \"\"\" def __init__ ( self , data_list : list [ pd . DataFrame ], plot_height : int = 7 , palette : str = \"viridis\" , ) -> None : self . plot_height = plot_height self . palette = palette for i , data in enumerate ( data_list ): data [ \"log_num\" ] = i self . data = pd . concat ([ d for d in data_list ], ignore_index = True ) def _get_error_column ( self , error_kind : str ) -> str : if error_kind == \"bbox\" : return \"bbox_error\" elif error_kind == \"dist\" : return \"worm_deviation\" elif error_kind == \"precise\" : return \"precise_error\" else : raise ValueError ( f \"Invalid error kind: { error_kind } \" ) def plot_speed ( self , log_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the speed distribution of the worm. Args: log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. \"\"\" return self . create_distplot ( x_col = \"wrm_speed\" , x_label = \"speed\" , title = \"Worm Speed Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , kde = True , ** kwargs , ) def plot_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , cycle_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the error distribution. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". log_wise (bool, optional): Whether to plot each log separately. cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = self . data . groupby ([ \"log_num\" , \"cycle\" ])[ error_col ] . max () . reset_index () return self . 
create_distplot ( x_col = error_col , x_label = f \" { error_kind } error\" , title = f \" { error_kind } Error Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , data = data , ** kwargs , ) def plot_cycle_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"boxen\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the error as a function of the cycle step. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". **kwargs: Additional keyword arguments to pass the `Plotter.create_catplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) return self . create_catplot ( x_col = \"cycle_step\" , y_col = error_col , x_label = \"cycle step\" , y_label = f \" { error_kind } error\" , title = f \" { error_kind } error as function of cycle step\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , ** kwargs , ) def plot_speed_vs_error ( self , error_kind : str = \"bbox\" , cycle_wise : bool = False , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the speed of the worm vs the error. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = ( self . data . groupby ([ \"log_num\" , \"cycle\" ])[[ error_col , \"wrm_speed\" ]] . aggregate ({ error_col : \"max\" , \"wrm_speed\" : \"mean\" }) . reset_index () ) return self . create_jointplot ( x_col = \"wrm_speed\" , y_col = error_col , plot_kind = kind , x_label = \"speed\" , y_label = f \" { error_kind } error\" , title = f \"Speed vs { error_kind } Error\" , condition = condition , data = data , ** kwargs , ) def plot_trajectory ( self , hue_col = \"log_num\" , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the trajectory of the worm. Args: hue_col (str, optional): The column to use for coloring the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" plot = self . create_jointplot ( x_col = \"wrm_center_x\" , y_col = \"wrm_center_y\" , x_label = \"X\" , y_label = \"Y\" , title = \"Worm Trajectory\" , hue_col = hue_col , plot_kind = \"scatter\" , alpha = 1 , linewidth = 0 , condition = condition , ** kwargs , ) plot . ax_marg_x . remove () plot . ax_marg_y . remove () plot . ax_joint . 
invert_yaxis () return plot def plot_head_size ( self , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the size of the worm head. Args: condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" return self . create_jointplot ( x_col = \"wrm_w\" , y_col = \"wrm_h\" , x_label = \"width\" , y_label = \"height\" , title = \"Worm Head Size\" , plot_kind = plot_kind , condition = condition , ** kwargs , ) def create_distplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"hist\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a distribution plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.displot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"hist\" , \"kde\" , \"ecdf\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . displot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{ col_name }} :: { title . title () } \" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_catplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"strip\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a categorical plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. 
hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"strip\" , \"box\" , \"violin\" , \"boxen\" , \"bar\" , \"count\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . catplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{ col_name }} :: { title . title () } \" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_jointplot ( self , x_col : str , y_col : str , hue_col : str = None , plot_kind : str = \"scatter\" , x_label : str = \"\" , y_label : str = \"\" , title : str = \"\" , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Create a joint plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. plot_kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" assert plot_kind in [ \"scatter\" , \"kde\" , \"hist\" , \"hex\" , \"reg\" , \"resid\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . jointplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , kind = plot_kind , height = self . plot_height , palette = palette , marginal_kws = dict ( palette = palette ), ** kwargs , ) plot . set_axis_labels ( x_label . capitalize (), y_label . capitalize ()) plot . figure . 
suptitle ( title . title ()) plot . figure . tight_layout () return plot","title":"Module wtracker.eval.plotter"},{"location":"reference/wtracker/eval/plotter/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/eval/plotter/#plotter","text":"class Plotter ( data_list : 'list[pd.DataFrame]' , plot_height : 'int' = 7 , palette : 'str' = 'viridis' ) A class for plotting experiment log data. The experiment data was previously analyzed by the DataAnalyzer class. Supports analysis of multiple logs at once.","title":"Plotter"},{"location":"reference/wtracker/eval/plotter/#attributes","text":"Name Type Description Default data_list list[pd.DataFrame] A list of dataframes, each holding the data of a single experiment log. None plot_height int The height of the plot. None palette str The color palette to use for the plots. None View Source class Plotter : \"\"\" A class for plotting experiment log data. The experiment data was previously analyzed by the DataAnalyzer class. Supports analysis of multiple logs at once. Args: data_list (list[pd.DataFrame]): A list of dataframes, each holding the data of a single experiment log. plot_height (int, optional): The height of the plot. palette (str, optional): The color palette to use for the plots. \"\"\" def __init__ ( self , data_list : list [ pd.DataFrame ] , plot_height : int = 7 , palette : str = \"viridis\" , ) -> None : self . plot_height = plot_height self . palette = palette for i , data in enumerate ( data_list ) : data [ \"log_num\" ] = i self . data = pd . concat ( [ d for d in data_list ] , ignore_index = True ) def _get_error_column ( self , error_kind : str ) -> str : if error_kind == \"bbox\" : return \"bbox_error\" elif error_kind == \"dist\" : return \"worm_deviation\" elif error_kind == \"precise\" : return \"precise_error\" else : raise ValueError ( f \"Invalid error kind: {error_kind}\" ) def plot_speed ( self , log_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the speed distribution of the worm. Args: log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. \"\"\" return self . create_distplot ( x_col = \"wrm_speed\" , x_label = \"speed\" , title = \"Worm Speed Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , kde = True , ** kwargs , ) def plot_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , cycle_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the error distribution. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". log_wise (bool, optional): Whether to plot each log separately. cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" error_col = self . 
_get_error_column ( error_kind ) data = self . data if cycle_wise : data = self . data . groupby ( [ \"log_num\", \"cycle\" ] ) [ error_col ] . max (). reset_index () return self . create_distplot ( x_col = error_col , x_label = f \"{error_kind} error\" , title = f \"{error_kind} Error Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , data = data , ** kwargs , ) def plot_cycle_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"boxen\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the error as a function of the cycle step. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" strip \", \" box \", \" violin \", \" boxen \", \" bar \", or \" count \". **kwargs: Additional keyword arguments to pass the `Plotter.create_catplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) return self . create_catplot ( x_col = \"cycle_step\" , y_col = error_col , x_label = \"cycle step\" , y_label = f \"{error_kind} error\" , title = f \"{error_kind} error as function of cycle step\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , ** kwargs , ) def plot_speed_vs_error ( self , error_kind : str = \"bbox\" , cycle_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the speed of the worm vs the error. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. kind (str, optional): The kind of plot to create. Can be \" scatter \", \" kde \", \" hist \", \" hex \", \" reg \", or \" resid \". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = ( self . data . groupby ( [ \"log_num\", \"cycle\" ] ) [ [error_col, \"wrm_speed\" ] ] . aggregate ( { error_col : \"max\" , \"wrm_speed\" : \"mean\" } ) . reset_index () ) return self . create_jointplot ( x_col = \"wrm_speed\" , y_col = error_col , plot_kind = kind , x_label = \"speed\" , y_label = f \"{error_kind} error\" , title = f \"Speed vs {error_kind} Error\" , condition = condition , data = data , ** kwargs , ) def plot_trajectory ( self , hue_col = \"log_num\" , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the trajectory of the worm. Args: hue_col (str, optional): The column to use for coloring the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" plot = self . 
create_jointplot ( x_col = \"wrm_center_x\" , y_col = \"wrm_center_y\" , x_label = \"X\" , y_label = \"Y\" , title = \"Worm Trajectory\" , hue_col = hue_col , plot_kind = \"scatter\" , alpha = 1 , linewidth = 0 , condition = condition , ** kwargs , ) plot . ax_marg_x . remove () plot . ax_marg_y . remove () plot . ax_joint . invert_yaxis () return plot def plot_head_size ( self , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . JointGrid : \"\"\" Plot the size of the worm head. Args: condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" scatter \", \" kde \", \" hist \", \" hex \", \" reg \", or \" resid \". **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" return self . create_jointplot ( x_col = \"wrm_w\" , y_col = \"wrm_h\" , x_label = \"width\" , y_label = \"height\" , title = \"Worm Head Size\" , plot_kind = plot_kind , condition = condition , ** kwargs , ) def create_distplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"hist\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , transform : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a distribution plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.displot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"hist\", \"kde\", \"ecdf\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition(data) ] palette = self . palette if hue_col is not None else None plot = sns . displot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_catplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"strip\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [ [pd.DataFrame ] , pd . 
DataFrame ] = None , transform : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a categorical plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \" strip \", \" box \", \" violin \", \" boxen \", \" bar \", or \" count \". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", \"count\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition(data) ] palette = self . palette if hue_col is not None else None plot = sns . catplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot def create_jointplot ( self , x_col : str , y_col : str , hue_col : str = None , plot_kind : str = \"scatter\" , x_label : str = \"\" , y_label : str = \"\" , title : str = \"\" , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , transform : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Create a joint plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. plot_kind (str, optional): The kind of plot to create. Can be \" scatter \", \" kde \", \" hist \", \" hex \", \" reg \", or \" resid \". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" assert plot_kind in [ \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", \"resid\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition(data) ] palette = self . 
palette if hue_col is not None else None plot = sns . jointplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , kind = plot_kind , height = self . plot_height , palette = palette , marginal_kws = dict ( palette = palette ), ** kwargs , ) plot . set_axis_labels ( x_label . capitalize (), y_label . capitalize ()) plot . figure . suptitle ( title . title ()) plot . figure . tight_layout () return plot","title":"Attributes"},{"location":"reference/wtracker/eval/plotter/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/eval/plotter/#create_catplot","text":"def create_catplot ( self , x_col : 'str' , y_col : 'str' = None , hue_col : 'str' = None , log_wise : 'bool' = False , plot_kind : 'str' = 'strip' , x_label : 'str' = '' , y_label : 'str' = '' , title : 'str | None' = None , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , transform : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , data : 'pd.DataFrame' = None , ** kwargs ) -> 'sns.FacetGrid' Create a categorical plot from the data. Parameters: Name Type Description Default x_col str The column to plot on the x-axis. None y_col str The column to plot on the y-axis. None hue_col str The column to use for coloring the plot. None log_wise bool Whether to plot each log separately. None plot_kind str The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". None x_label str The x-axis label. None y_label str The y-axis label. None title str The title of the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None transform Callable[[pd.DataFrame], pd.DataFrame] A function to transform the data. None data pd.DataFrame Custom data to plot from. If None, the data passed to the constructor of the class is used. None **kwargs None Additional keyword arguments to pass to the seaborn.catplot function. None Returns: Type Description sns.FacetGrid The plot object. View Source def create_catplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"strip\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a categorical plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" assert plot_kind in [ \"strip\" , \"box\" , \"violin\" , \"boxen\" , \"bar\" , \"count\" ] if data is None : data = self . 
data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . catplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot","title":"create_catplot"},{"location":"reference/wtracker/eval/plotter/#create_distplot","text":"def create_distplot ( self , x_col : 'str' , y_col : 'str' = None , hue_col : 'str' = None , log_wise : 'bool' = False , plot_kind : 'str' = 'hist' , x_label : 'str' = '' , y_label : 'str' = '' , title : 'str | None' = None , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , transform : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , data : 'pd.DataFrame' = None , ** kwargs ) -> 'sns.FacetGrid' Create a distribution plot from the data. Parameters: Name Type Description Default x_col str The column to plot on the x-axis. None y_col str The column to plot on the y-axis. None hue_col str The column to use for coloring the plot. None log_wise bool Whether to plot each log separately. None plot_kind str The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". None x_label str The x-axis label. None y_label str The y-axis label. None title str The title of the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None transform Callable[[pd.DataFrame], pd.DataFrame] A function to transform the data. None data pd.DataFrame Custom data to plot from. If None, the data passed to the constructor of the class is used. None **kwargs None Additional keyword arguments to pass to the seaborn.displot function. None Returns: Type Description sns.FacetGrid The plot object. View Source def create_distplot ( self , x_col : str , y_col : str = None , hue_col : str = None , log_wise : bool = False , plot_kind : str = \"hist\" , x_label : str = \"\" , y_label : str = \"\" , title : str | None = None , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . FacetGrid : \"\"\" Create a distribution plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str, optional): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. log_wise (bool, optional): Whether to plot each log separately. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.displot` function. Returns: sns.FacetGrid: The plot object. 
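For plots not covered by the plot_* convenience methods, create_catplot can be called directly. A minimal sketch, reusing the plotter instance and the analyzed-log column names from the examples above:

# Error per cycle step as a violin plot, colored by experiment log.
grid = plotter.create_catplot(
    x_col="cycle_step",
    y_col="bbox_error",
    hue_col="log_num",
    plot_kind="violin",
    x_label="cycle step",
    y_label="bbox error",
    title="error by cycle step",
)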
\"\"\" assert plot_kind in [ \"hist\" , \"kde\" , \"ecdf\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . displot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , col = \"log_num\" if log_wise else None , kind = plot_kind , height = self . plot_height , palette = palette , ** kwargs , ) plot . set_xlabels ( x_label . capitalize ()) plot . set_ylabels ( y_label . capitalize ()) if title is not None : if log_wise : title = f \"Log {{col_name}} :: {title.title()}\" plot . set_titles ( title ) else : plot . figure . suptitle ( title . title ()) plot . tight_layout () return plot","title":"create_distplot"},{"location":"reference/wtracker/eval/plotter/#create_jointplot","text":"def create_jointplot ( self , x_col : 'str' , y_col : 'str' , hue_col : 'str' = None , plot_kind : 'str' = 'scatter' , x_label : 'str' = '' , y_label : 'str' = '' , title : 'str' = '' , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , transform : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , data : 'pd.DataFrame' = None , ** kwargs ) -> 'sns.JointGrid' Create a joint plot from the data. Parameters: Name Type Description Default x_col str The column to plot on the x-axis. None y_col str The column to plot on the y-axis. None hue_col str The column to use for coloring the plot. None plot_kind str The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". None x_label str The x-axis label. None y_label str The y-axis label. None title str The title of the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None transform Callable[[pd.DataFrame], pd.DataFrame] A function to transform the data. None data pd.DataFrame Custom data to plot from. If None, the data passed to the constructor of the class is used. None **kwargs None Additional keyword arguments to pass to the seaborn.jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def create_jointplot ( self , x_col : str , y_col : str , hue_col : str = None , plot_kind : str = \"scatter\" , x_label : str = \"\" , y_label : str = \"\" , title : str = \"\" , condition : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , transform : Callable [[ pd . DataFrame ], pd . DataFrame ] = None , data : pd . DataFrame = None , ** kwargs , ) -> sns . JointGrid : \"\"\" Create a joint plot from the data. Args: x_col (str): The column to plot on the x-axis. y_col (str): The column to plot on the y-axis. hue_col (str, optional): The column to use for coloring the plot. plot_kind (str, optional): The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". x_label (str, optional): The x-axis label. y_label (str, optional): The y-axis label. title (str, optional): The title of the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data. data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used. **kwargs: Additional keyword arguments to pass to the `seaborn.jointplot` function. Returns: sns.JointGrid: The plot object. 
\"\"\" assert plot_kind in [ \"scatter\" , \"kde\" , \"hist\" , \"hex\" , \"reg\" , \"resid\" ] if data is None : data = self . data if transform is not None : data = transform ( data ) if condition is not None : data = data [ condition ( data )] palette = self . palette if hue_col is not None else None plot = sns . jointplot ( data = data . dropna (), x = x_col , y = y_col , hue = hue_col , kind = plot_kind , height = self . plot_height , palette = palette , marginal_kws = dict ( palette = palette ), ** kwargs , ) plot . set_axis_labels ( x_label . capitalize (), y_label . capitalize ()) plot . figure . suptitle ( title . title ()) plot . figure . tight_layout () return plot","title":"create_jointplot"},{"location":"reference/wtracker/eval/plotter/#plot_cycle_error","text":"def plot_cycle_error ( self , error_kind : 'str' = 'bbox' , log_wise : 'bool' = False , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , plot_kind : 'str' = 'boxen' , ** kwargs ) -> 'sns.JointGrid' Plot the error as a function of the cycle step. Parameters: Name Type Description Default error_kind str The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". None log_wise bool Whether to plot each log separately. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". None **kwargs None Additional keyword arguments to pass the Plotter.create_catplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def plot_cycle_error( self, error_kind: str = \"bbox\", log_wise: bool = False, condition: Callable[[pd.DataFrame], pd.DataFrame] = None, plot_kind: str = \"boxen\", **kwargs, ) -> sns.JointGrid: \"\"\" Plot the error as a function of the cycle step. Args: error_kind (str, optional): The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"strip\", \"box\", \"violin\", \"boxen\", \"bar\", or \"count\". **kwargs: Additional keyword arguments to pass the `Plotter.create_catplot` function. Returns: sns.JointGrid: The plot object. \"\"\" error_col = self._get_error_column(error_kind) return self.create_catplot( x_col=\"cycle_step\", y_col=error_col, x_label=\"cycle step\", y_label=f\"{error_kind} error\", title=f\"{error_kind} error as function of cycle step\", plot_kind=plot_kind, log_wise=log_wise, condition=condition, **kwargs, )","title":"plot_cycle_error"},{"location":"reference/wtracker/eval/plotter/#plot_error","text":"def plot_error ( self , error_kind : 'str' = 'bbox' , log_wise : 'bool' = False , cycle_wise : 'bool' = False , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , plot_kind : 'str' = 'hist' , ** kwargs ) -> 'sns.FacetGrid' Plot the error distribution. Parameters: Name Type Description Default error_kind str The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". None log_wise bool Whether to plot each log separately. None cycle_wise bool Whether to plot each cycle separately. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". None **kwargs None Additional keyword arguments to pass the Plotter.create_distplot function. 
None Returns: Type Description sns.FacetGrid The plot object. View Source def plot_error ( self , error_kind : str = \"bbox\" , log_wise : bool = False , cycle_wise : bool = False , condition : Callable [ [pd.DataFrame ] , pd . DataFrame ] = None , plot_kind : str = \"hist\" , ** kwargs , ) -> sns . FacetGrid : \"\"\" Plot the error distribution. Args: error_kind (str, optional): The kind of error to plot. Can be \" bbox \", \" dist \", or \" precise \". log_wise (bool, optional): Whether to plot each log separately. cycle_wise (bool, optional): Whether to plot each cycle separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \" hist \", \" kde \", or \" ecdf \". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. Returns: sns.FacetGrid: The plot object. \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise : data = self . data . groupby ( [ \"log_num\", \"cycle\" ] ) [ error_col ] . max (). reset_index () return self . create_distplot ( x_col = error_col , x_label = f \"{error_kind} error\" , title = f \"{error_kind} Error Distribution\" , plot_kind = plot_kind , log_wise = log_wise , condition = condition , data = data , ** kwargs , )","title":"plot_error"},{"location":"reference/wtracker/eval/plotter/#plot_head_size","text":"def plot_head_size ( self , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , plot_kind : 'str' = 'hist' , ** kwargs ) -> 'sns.JointGrid' Plot the size of the worm head. Parameters: Name Type Description Default condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". None **kwargs None Additional keyword arguments to pass the Plotter.create_jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def plot_head_size ( self , condition: Callable [[ pd . DataFrame ], pd . DataFrame ] = None , plot_kind: str = \"hist\" , ** kwargs , ) -> sns . JointGrid: \"\"\" Plot the size of the worm head . Args: condition ( Callable [[ pd . DataFrame ], pd . DataFrame ], optional ) : A function to filter the data . plot_kind ( str , optional ) : The kind of plot to create . Can be \"scatter\" , \"kde\" , \"hist\" , \"hex\" , \"reg\" , or \"resid\" . ** kwargs: Additional keyword arguments to pass the `Plotter . create_jointplot ` function . Returns: sns . JointGrid: The plot object . \"\"\" return self . create_jointplot ( x_col = \"wrm_w\" , y_col = \"wrm_h\" , x_label = \"width\" , y_label = \"height\" , title = \"Worm Head Size\" , plot_kind = plot_kind , condition = condition , ** kwargs , )","title":"plot_head_size"},{"location":"reference/wtracker/eval/plotter/#plot_speed","text":"def plot_speed ( self , log_wise : 'bool' = False , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , plot_kind : 'str' = 'hist' , ** kwargs ) -> 'sns.FacetGrid' Plot the speed distribution of the worm. Parameters: Name Type Description Default log_wise bool Whether to plot each log separately. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None plot_kind str The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". None **kwargs None Additional keyword arguments to pass the Plotter.create_distplot function. 
None View Source def plot_speed( self, log_wise: bool = False, condition: Callable[[pd.DataFrame], pd.DataFrame] = None, plot_kind: str = \"hist\", **kwargs, ) -> sns.FacetGrid: \"\"\" Plot the speed distribution of the worm. Args: log_wise (bool, optional): Whether to plot each log separately. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. plot_kind (str, optional): The kind of plot to create. Can be \"hist\", \"kde\", or \"ecdf\". **kwargs: Additional keyword arguments to pass the `Plotter.create_distplot` function. \"\"\" return self.create_distplot( x_col=\"wrm_speed\", x_label=\"speed\", title=\"Worm Speed Distribution\", plot_kind=plot_kind, log_wise=log_wise, condition=condition, kde=True, **kwargs, )","title":"plot_speed"},{"location":"reference/wtracker/eval/plotter/#plot_speed_vs_error","text":"def plot_speed_vs_error ( self , error_kind : 'str' = 'bbox' , cycle_wise : 'bool' = False , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , kind : 'str' = 'hist' , ** kwargs ) -> 'sns.JointGrid' Plot the speed of the worm vs the error. Parameters: Name Type Description Default error_kind str The kind of error to plot. Can be \"bbox\", \"dist\", or \"precise\". None cycle_wise bool Whether to plot each cycle separately. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. None kind str The kind of plot to create. Can be \"scatter\", \"kde\", \"hist\", \"hex\", \"reg\", or \"resid\". None **kwargs None Additional keyword arguments to pass the Plotter.create_jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def plot_speed_vs_error ( self , error_kind: str = \"bbox\" , cycle_wise: bool = False , condition: Callable [[ pd . DataFrame ], pd . DataFrame ] = None , kind: str = \"hist\" , ** kwargs , ) -> sns . JointGrid: \"\"\" Plot the speed of the worm vs the error . Args: error_kind ( str , optional ) : The kind of error to plot . Can be \"bbox\" , \"dist\" , or \"precise\" . cycle_wise ( bool , optional ) : Whether to plot each cycle separately . condition ( Callable [[ pd . DataFrame ], pd . DataFrame ], optional ) : A function to filter the data . kind ( str , optional ) : The kind of plot to create . Can be \"scatter\" , \"kde\" , \"hist\" , \"hex\" , \"reg\" , or \"resid\" . ** kwargs: Additional keyword arguments to pass the `Plotter . create_jointplot ` function . Returns: sns . JointGrid: The plot object . \"\"\" error_col = self . _get_error_column ( error_kind ) data = self . data if cycle_wise: data = ( self . data . groupby ([ \"log_num\" , \"cycle\" ])[[ error_col , \"wrm_speed\" ]] . aggregate ({ error_col: \"max\" , \"wrm_speed\" : \"mean\" }) . reset_index () ) return self . create_jointplot ( x_col = \"wrm_speed\" , y_col = error_col , plot_kind = kind , x_label = \"speed\" , y_label = f \"{error_kind} error\" , title = f \"Speed vs {error_kind} Error\" , condition = condition , data = data , ** kwargs , )","title":"plot_speed_vs_error"},{"location":"reference/wtracker/eval/plotter/#plot_trajectory","text":"def plot_trajectory ( self , hue_col = 'log_num' , condition : 'Callable[[pd.DataFrame], pd.DataFrame]' = None , ** kwargs ) -> 'sns.JointGrid' Plot the trajectory of the worm. Parameters: Name Type Description Default hue_col str The column to use for coloring the plot. None condition Callable[[pd.DataFrame], pd.DataFrame] A function to filter the data. 
None **kwargs None Additional keyword arguments to pass the Plotter.create_jointplot function. None Returns: Type Description sns.JointGrid The plot object. View Source def plot_trajectory( self, hue_col=\"log_num\", condition: Callable[[pd.DataFrame], pd.DataFrame] = None, **kwargs, ) -> sns.JointGrid: \"\"\" Plot the trajectory of the worm. Args: hue_col (str, optional): The column to use for coloring the plot. condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data. **kwargs: Additional keyword arguments to pass the `Plotter.create_jointplot` function. Returns: sns.JointGrid: The plot object. \"\"\" plot = self.create_jointplot( x_col=\"wrm_center_x\", y_col=\"wrm_center_y\", x_label=\"X\", y_label=\"Y\", title=\"Worm Trajectory\", hue_col=hue_col, plot_kind=\"scatter\", alpha=1, linewidth=0, condition=condition, **kwargs, ) plot.ax_marg_x.remove() plot.ax_marg_y.remove() plot.ax_joint.invert_yaxis() return plot","title":"plot_trajectory"},{"location":"reference/wtracker/eval/vlc/","text":"Module wtracker.eval.vlc View Source import pandas as pd import numpy as np from math import ceil , floor import os import cv2 as cv from typing import Callable from dataclasses import dataclass , field import matplotlib matplotlib . use ( \"QTAgg\" ) from wtracker.utils.path_utils import Files , create_directory , join_paths from wtracker.utils.io_utils import ImageSaver from wtracker.utils.frame_reader import FrameReader , DummyReader from wtracker.sim.config import TimingConfig @dataclass class HotKey : \"\"\" Represents a hotkey that can be used to trigger a specific function. Attributes: key (str): The key for the hotkey. func (Callable[[str], None]): The function to be called when the hotkey is triggered. description (str): The description of the hotkey (optional). \"\"\" key : str func : Callable [[ str ], None ] description : str = field ( default = \"\" ) def __post_init__ ( self ): self . key = self . key . lower () class StreamViewer : \"\"\" A class for viewing and interacting with photos and video streams. Args: window_name (str, optional): The name of the window. Example: with StreamViewer() as streamer: streamer.imshow(image) streamer.waitKey() \"\"\" def __init__ ( self , window_name : str = \"streamer\" ) -> None : self . window_name = window_name self . window = None self . hotkeys : list [ HotKey ] = [] self . register_hotkey ( HotKey ( \"q\" , self . close , \"close the window\" )) def register_hotkey ( self , hotkey : HotKey ): \"\"\" Registers a hotkey. Args: hotkey (HotKey): The hotkey to register. \"\"\" self . hotkeys . append ( hotkey ) def create_trackbar ( self , name : str , val : int , maxval : int , onChange = lambda x : x ): \"\"\" Creates a trackbar. Args: name (str): The name of the trackbar. val (int): The initial value of the trackbar. maxval (int): The maximum value of the trackbar. onChange (function): The function to call when the trackbar value changes. \"\"\" cv . createTrackbar ( name , self . window_name , val , maxval , onChange ) def update_trackbar ( self , name : str , val : int ): \"\"\" Updates the value of a trackbar. Args: name (str): The name of the trackbar. val (int): The new value of the trackbar. \"\"\" cv . setTrackbarPos ( name , self . window_name , val ) def set_title ( self , title : str ): \"\"\" Sets the title of the window. Args: title (str): The new title of the window. \"\"\" cv . setWindowTitle ( self . window_name , title ) def __enter__ ( self ): \"\"\" Enters the context manager. \"\"\" self . 
open () return self def __exit__ ( self , exc_type , exc_value , traceback ): \"\"\" Exits the context manager. \"\"\" self . close () def __del__ ( self ): \"\"\" Destructor method. \"\"\" self . close () def update ( self , image : np . ndarray , wait : int = 1 ): \"\"\" Updates the window with a new image. Args: image (np.ndarray): The image to display. wait (int): The delay in milliseconds before updating the window. \"\"\" cv . imshow ( self . window_name , image ) self . waitKey ( wait ) def waitKey ( self , timeout : int = 0 ): \"\"\" Waits for a key press. This Function also triggers the hotkeys. Args: timeout (int): The timeout in milliseconds. Returns: str: The key that was pressed. \"\"\" key = cv . waitKey ( timeout ) if key <= 0 : return key key = chr ( key ) . lower () for hotkey in self . hotkeys : if key in hotkey . key : hotkey . func ( key ) return key def open ( self ): \"\"\" Opens the window. \"\"\" self . close () self . window = cv . namedWindow ( self . window_name , flags = cv . WINDOW_GUI_EXPANDED ) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1) self . set_title ( self . window_name ) def close ( self , key : str = \"q\" ): \"\"\" Closes the window. Args: key (str): The key to close the window. \"\"\" if self . window is not None : cv . destroyWindow ( self . window_name ) self . window = None def imshow ( self , image : np . ndarray , title : str = \"image\" ): \"\"\" Displays an image in the window. Args: image (np.ndarray): The image to display. title (str): The title of the image. \"\"\" self . update ( image , wait = 0 ) self . set_title ( title ) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1) class VLC : \"\"\" The VLC class represents a video player for visualizing Simulations. This class supports saving Simulation frames (with or without boxes overlay) as well. Args: files (Files): The files to read frames from. If None, the video player will present the log data (simulation) on a white background. config (TimingConfig): The timing configuration of the system. log_path (str): The path to the log file. cam_type (str): The type of camera. This should match the prefix of the corresponding columns in the log file. show_pred (bool, optional): Whether to show the prediction box. show_micro (bool, optional): Whether to show the microscope box. show_cam (bool, optional): Whether to show the camera box. \"\"\" def __init__ ( self , files : Files | None , config : TimingConfig , log_path : str , cam_type : str , show_pred : bool = True , show_micro : bool = False , show_cam : bool = False , ) -> None : self . streamer = StreamViewer ( window_name = \"VLC\" ) self . index = 0 self . _curr_row = None self . exit = False self . delay = 0 self . play = False self . show_pred = show_pred self . show_micro = show_micro self . show_cam = show_cam self . cam_type : str = cam_type self . config : TimingConfig = config self . log : pd . DataFrame = self . _load_log ( log_path ) self . reader : FrameReader = self . _create_reader ( files ) def initialize ( self ) -> None : \"\"\" Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. \"\"\" self . _init_hotkeys () self . _create_window () self . streamer . update_trackbar ( \"delay\" , round ( self . config . ms_per_frame )) self . print_hotkeys () def _load_log ( self , log_path : str ) -> pd . DataFrame : if log_path is None : return None log = pd . read_csv ( log_path , index_col = \"frame\" ) if self . 
cam_type == \"plt\" : log [ \"plt_x\" ] = 0 log [ \"plt_y\" ] = 0 log [ \"plt_h\" ] = max ( log [ \"cam_y\" ]) + max ( log [ \"cam_h\" ]) log [ \"plt_w\" ] = max ( log [ \"cam_x\" ]) + max ( log [ \"cam_w\" ]) # assert len(log.index) == len(self.reader) self . _curr_row = log . iloc [ self . index ] return log def _init_hotkeys ( self ) -> None : self . streamer . register_hotkey ( HotKey ( \"q\" , self . close , \"close VLC\" )) self . streamer . register_hotkey ( HotKey ( \"d\" , self . next , \"next frame\" )) self . streamer . register_hotkey ( HotKey ( \"a\" , self . prev , \"previous frame\" )) self . streamer . register_hotkey ( HotKey ( \"p\" , self . toggle_play , \"play/pause\" )) self . streamer . register_hotkey ( HotKey ( \"h\" , self . toggle_pred , \"toggle prediction box\" )) self . streamer . register_hotkey ( HotKey ( \"m\" , self . toggle_micro , \"toggle microscope box\" )) self . streamer . register_hotkey ( HotKey ( \"c\" , self . toggle_cam , \"toggle camera box\" )) def print_hotkeys ( self ): print ( \"Hotkeys:\" ) for hotkey in self . streamer . hotkeys : print ( f \" - { hotkey . key } : { hotkey . description } \" ) def _create_window ( self ): self . streamer . open () self . streamer . create_trackbar ( \"delay\" , 0 , 250 , self . set_delay ) self . streamer . create_trackbar ( \"#frame\" , 0 , len ( self . reader ), self . seek ) def _create_reader ( self , files : Files ) -> FrameReader : if files is None : frame_num = len ( self . log . index ) frame_size = ( self . get_attribute ( self . cam_type + \"_h\" ), self . get_attribute ( self . cam_type + \"_w\" ), ) return DummyReader ( frame_num , frame_size ) filenames = [ f for f in files ] reader = FrameReader ( files . root , filenames ) return reader def __enter__ ( self ): return self def __exit__ ( self , exc_type , exc_value , traceback ): self . streamer . close () def _get_title ( self ): curr_phase = self . get_attribute ( \"phase\" ) phase_title = f \"Action: { curr_phase } \" cycle_len = self . config . imaging_frame_num + self . config . moving_frame_num cycle_progress = 1 + self . index % cycle_len cycle_title = ( f \"cycle progress [ { cycle_progress } / { cycle_len } ]: \" + cycle_progress * \"#\" + ( cycle_len - cycle_progress ) * \"_\" ) title = f \" { phase_title } :: { cycle_title } \" return title def get_attribute ( self , col_name : str ): return self . _curr_row [ col_name ] def update_curr_row ( self ): self . _curr_row = self . log . iloc [ self . index ] def get_photo ( self ) -> np . ndarray : photo = self . reader [ self . index ] if self . show_pred : self . add_pred ( photo ) if self . show_micro : self . add_micro_box ( photo ) if self . show_cam : self . add_cam_box ( photo ) self . draw_center ( photo ) return photo def seek ( self , pos : int ): self . index = ( pos ) % len ( self . reader ) self . update_curr_row () self . streamer . update ( self . get_photo ()) self . streamer . set_title ( self . _get_title ()) def next ( self , key = None ): self . index = ( self . index + 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def prev ( self , key = None ): self . index = ( self . index - 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def close ( self , key = None ): self . exit = True def set_delay ( self , delay : int ): self . delay = delay def toggle_play ( self , key : str = None ): self . play = not self . play def toggle_pred ( self , key : str = None ): self . 
show_pred = not self . show_pred def toggle_micro ( self , key : str = None ): self . show_micro = not self . show_micro def toggle_cam ( self , key : str = None ): self . show_cam = not self . show_cam def mainloop ( self ): \"\"\" Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the `exit` flag is set to True (by self.close() (called by an hotkey)). It checks the `play` flag to determine if the player should continue playing or pause. The `delay` variable is used to control the delay between each iteration of the loop and is set to 0 to pause. \"\"\" with self as vlc : while not self . exit : delay = 0 if not self . play else self . delay if self . play : self . next () vlc . streamer . waitKey ( delay ) def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ]: x = self . get_attribute ( prefix + \"_x\" ) y = self . get_attribute ( prefix + \"_y\" ) w = self . get_attribute ( prefix + \"_w\" ) h = self . get_attribute ( prefix + \"_h\" ) return ( x , y , w , h ) def draw_box ( self , photo : np . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 , ) -> None : if not np . isfinite ( bbox ) . all (): return x , y , w , h = self . get_bbox ( self . cam_type ) pred_x , pred_y , pred_w , pred_h = bbox pred_x = floor ( pred_x - x ) pred_y = floor ( pred_y - y ) pred_w = ceil ( pred_w ) pred_h = ceil ( pred_h ) cv . rectangle ( photo , ( pred_x , pred_y ), ( pred_x + pred_w , pred_y + pred_h ), color , width ) def draw_marker ( self , photo : np . ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = cv . MARKER_CROSS , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 , ) -> None : frame_x , frame_y , frame_w , frame_h = self . get_bbox ( self . cam_type ) x , y = floor ( x - frame_x ), floor ( y - frame_y ) cv . drawMarker ( photo , ( x , y ), color , marker_type , marker_size , thickness ) def draw_center ( self , photo : np . ndarray ): x , y , w , h = self . get_bbox ( \"mic\" ) center = ( x + w // 2 , y + h // 2 ) cv . drawMarker ( photo , center , ( 0 , 0 , 255 ), cv . MARKER_CROSS , 7 , 1 ) def add_pred ( self , photo : np . ndarray ) -> None : worm_bbox = self . get_bbox ( \"wrm\" ) self . draw_box ( photo , worm_bbox , ( 0 , 0 , 0 ), 1 ) def add_micro_box ( self , photo : np . ndarray ) -> None : mic_bbox = self . get_bbox ( \"mic\" ) self . draw_box ( photo , mic_bbox , ( 0 , 0 , 255 ), 1 ) def add_cam_box ( self , photo : np . ndarray ) -> None : cam_bbox = self . get_bbox ( \"cam\" ) self . draw_box ( photo , cam_bbox , ( 128 , 0 , 0 ), 2 ) def save_stream ( self , folder_path : str , ) -> None : create_directory ( folder_path ) filename = f \" { self . cam_type } _\" + \" {:07d} .png\" with ImageSaver ( folder_path , tqdm_kwargs = { \"total\" : len ( self . log . index )}) as worker : for index in range ( len ( self . log . index )): self . index = index self . update_curr_row () path = join_paths ( folder_path , filename . format ( index )) img = self . get_photo () worker . schedule_save ( img , path ) image_format = filename . replace ( \"{:\" , \"%\" ) . replace ( \"}\" , \"\" ) self . make_vid ( folder_path , image_format , folder_path ) def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None : fps = self . config . 
frames_per_sec command = f \"ffmpeg -framerate { fps } -start_number 0 -i { join_paths ( folder_path , img_name_format ) } -c:v copy { join_paths ( output_dir , 'video.mp4' ) } \" print ( command ) os . system ( command ) Classes HotKey class HotKey ( key : str , func : Callable [[ str ], NoneType ], description : str = '' ) Represents a hotkey that can be used to trigger a specific function. Attributes Name Type Description Default key str The key for the hotkey. None func Callable[[str], None] The function to be called when the hotkey is triggered. None description str The description of the hotkey (optional). None View Source @dataclass class HotKey : \"\"\" Represents a hotkey that can be used to trigger a specific function. Attributes: key (str): The key for the hotkey. func (Callable[[str], None]): The function to be called when the hotkey is triggered. description (str): The description of the hotkey (optional). \"\"\" key : str func : Callable [ [str ] , None ] description : str = field ( default = \"\" ) def __post_init__ ( self ) : self . key = self . key . lower () Class variables description StreamViewer class StreamViewer ( window_name : str = 'streamer' ) A class for viewing and interacting with photos and video streams. Attributes Name Type Description Default window_name str The name of the window. None View Source class StreamViewer : \"\"\" A class for viewing and interacting with photos and video streams. Args: window_name (str, optional): The name of the window. Example: with StreamViewer() as streamer: streamer.imshow(image) streamer.waitKey() \"\"\" def __init__ ( self , window_name : str = \"streamer\" ) -> None : self . window_name = window_name self . window = None self . hotkeys : list [ HotKey ] = [] self . register_hotkey ( HotKey ( \"q\" , self . close , \"close the window\" )) def register_hotkey ( self , hotkey : HotKey ) : \"\"\" Registers a hotkey. Args: hotkey (HotKey): The hotkey to register. \"\"\" self . hotkeys . append ( hotkey ) def create_trackbar ( self , name : str , val : int , maxval : int , onChange = lambda x : x ) : \"\"\" Creates a trackbar. Args: name (str): The name of the trackbar. val (int): The initial value of the trackbar. maxval (int): The maximum value of the trackbar. onChange (function): The function to call when the trackbar value changes. \"\"\" cv . createTrackbar ( name , self . window_name , val , maxval , onChange ) def update_trackbar ( self , name : str , val : int ) : \"\"\" Updates the value of a trackbar. Args: name (str): The name of the trackbar. val (int): The new value of the trackbar. \"\"\" cv . setTrackbarPos ( name , self . window_name , val ) def set_title ( self , title : str ) : \"\"\" Sets the title of the window. Args: title (str): The new title of the window. \"\"\" cv . setWindowTitle ( self . window_name , title ) def __enter__ ( self ) : \"\"\" Enters the context manager. \"\"\" self . open () return self def __exit__ ( self , exc_type , exc_value , traceback ) : \"\"\" Exits the context manager. \"\"\" self . close () def __del__ ( self ) : \"\"\" Destructor method. \"\"\" self . close () def update ( self , image : np . ndarray , wait : int = 1 ) : \"\"\" Updates the window with a new image. Args: image (np.ndarray): The image to display. wait (int): The delay in milliseconds before updating the window. \"\"\" cv . imshow ( self . window_name , image ) self . waitKey ( wait ) def waitKey ( self , timeout : int = 0 ) : \"\"\" Waits for a key press. This Function also triggers the hotkeys. 
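As a quick illustration of the `HotKey`/`StreamViewer` pairing documented here, the sketch below registers a custom hotkey and lets `waitKey` dispatch it. It requires a GUI-capable OpenCV build, and the snapshot action is a placeholder:

    import numpy as np
    from wtracker.eval.vlc import HotKey, StreamViewer

    def on_save(key: str) -> None:
        print(f"'{key}' pressed - snapshot placeholder")

    image = np.zeros((240, 320), dtype=np.uint8)  # dummy frame

    with StreamViewer(window_name="demo") as viewer:
        # HotKey.__post_init__ lower-cases the key, so "S" and "s" both match.
        viewer.register_hotkey(HotKey("S", on_save, description="save a snapshot"))
        # imshow() shows the frame and blocks in waitKey(0), which also
        # dispatches registered hotkeys ('q' is pre-registered to close).
        viewer.imshow(image)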
Args: timeout (int): The timeout in milliseconds. Returns: str: The key that was pressed. \"\"\" key = cv . waitKey ( timeout ) if key <= 0 : return key key = chr ( key ). lower () for hotkey in self . hotkeys : if key in hotkey . key : hotkey . func ( key ) return key def open ( self ) : \"\"\" Opens the window. \"\"\" self . close () self . window = cv . namedWindow ( self . window_name , flags = cv . WINDOW_GUI_EXPANDED ) # cv . setWindowProperty ( self . window_name , cv . WND_PROP_TOPMOST , 1 ) self . set_title ( self . window_name ) def close ( self , key : str = \"q\" ) : \"\"\" Closes the window. Args: key (str): The key to close the window. \"\"\" if self . window is not None : cv . destroyWindow ( self . window_name ) self . window = None def imshow ( self , image : np . ndarray , title : str = \"image\" ) : \"\"\" Displays an image in the window. Args: image (np.ndarray): The image to display. title (str): The title of the image. \"\"\" self . update ( image , wait = 0 ) self . set_title ( title ) # cv . setWindowProperty ( self . window_name , cv . WND_PROP_TOPMOST , 1 ) Methods close def close ( self , key : str = 'q' ) Closes the window. Parameters: Name Type Description Default key str The key to close the window. None View Source def close(self, key: str = \"q\"): \"\"\" Closes the window. Args: key (str): The key to close the window. \"\"\" if self.window is not None: cv.destroyWindow(self.window_name) self.window = None create_trackbar def create_trackbar ( self , name : str , val : int , maxval : int , onChange =< function StreamViewer .< lambda > at 0x7f9303198160 > ) Creates a trackbar. Parameters: Name Type Description Default name str The name of the trackbar. None val int The initial value of the trackbar. None maxval int The maximum value of the trackbar. None onChange function The function to call when the trackbar value changes. None View Source def create_trackbar(self, name: str, val: int, maxval: int, onChange=lambda x: x): \"\"\" Creates a trackbar. Args: name (str): The name of the trackbar. val (int): The initial value of the trackbar. maxval (int): The maximum value of the trackbar. onChange (function): The function to call when the trackbar value changes. \"\"\" cv.createTrackbar(name, self.window_name, val, maxval, onChange) imshow def imshow ( self , image : numpy . ndarray , title : str = 'image' ) Displays an image in the window. Parameters: Name Type Description Default image np.ndarray The image to display. None title str The title of the image. None View Source def imshow(self, image: np.ndarray, title: str = \"image\"): \"\"\" Displays an image in the window. Args: image (np.ndarray): The image to display. title (str): The title of the image. \"\"\" self.update(image, wait=0) self.set_title(title) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1) open def open ( self ) Opens the window. View Source def open(self): \"\"\" Opens the window. \"\"\" self.close() self.window = cv.namedWindow(self.window_name, flags=cv.WINDOW_GUI_EXPANDED) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1) self.set_title(self.window_name) register_hotkey def register_hotkey ( self , hotkey : wtracker . eval . vlc . HotKey ) Registers a hotkey. Parameters: Name Type Description Default hotkey HotKey The hotkey to register. None View Source def register_hotkey ( self , hotkey: HotKey ) : \"\"\" Registers a hotkey . Args: hotkey ( HotKey ) : The hotkey to register . \"\"\" self . hotkeys . 
append ( hotkey ) set_title def set_title ( self , title : str ) Sets the title of the window. Parameters: Name Type Description Default title str The new title of the window. None View Source def set_title(self, title: str): \"\"\" Sets the title of the window. Args: title (str): The new title of the window. \"\"\" cv.setWindowTitle(self.window_name, title) update def update ( self , image : numpy . ndarray , wait : int = 1 ) Updates the window with a new image. Parameters: Name Type Description Default image np.ndarray The image to display. None wait int The delay in milliseconds before updating the window. None View Source def update(self, image: np.ndarray, wait: int = 1): \"\"\" Updates the window with a new image. Args: image (np.ndarray): The image to display. wait (int): The delay in milliseconds before updating the window. \"\"\" cv.imshow(self.window_name, image) self.waitKey(wait) update_trackbar def update_trackbar ( self , name : str , val : int ) Updates the value of a trackbar. Parameters: Name Type Description Default name str The name of the trackbar. None val int The new value of the trackbar. None View Source def update_trackbar(self, name: str, val: int): \"\"\" Updates the value of a trackbar. Args: name (str): The name of the trackbar. val (int): The new value of the trackbar. \"\"\" cv.setTrackbarPos(name, self.window_name, val) waitKey def waitKey ( self , timeout : int = 0 ) Waits for a key press. This Function also triggers the hotkeys. Parameters: Name Type Description Default timeout int The timeout in milliseconds. None Returns: Type Description str The key that was pressed. View Source def waitKey ( self , timeout : int = 0 ) : \"\" \" Waits for a key press. This Function also triggers the hotkeys. Args: timeout (int): The timeout in milliseconds. Returns: str: The key that was pressed. \"\" \" key = cv.waitKey(timeout) if key <= 0: return key key = chr(key).lower() for hotkey in self.hotkeys: if key in hotkey.key: hotkey.func(key) return key VLC class VLC ( files : wtracker . utils . path_utils . Files | None , config : wtracker . sim . config . TimingConfig , log_path : str , cam_type : str , show_pred : bool = True , show_micro : bool = False , show_cam : bool = False ) The VLC class represents a video player for visualizing Simulations. This class supports saving Simulation frames (with or without boxes overlay) as well. Attributes Name Type Description Default files Files The files to read frames from. If None, the video player will present the log data (simulation) on a white background. None config TimingConfig The timing configuration of the system. None log_path str The path to the log file. None cam_type str The type of camera. This should match the prefix of the corresponding columns in the log file. None show_pred bool Whether to show the prediction box. None show_micro bool Whether to show the microscope box. None show_cam bool Whether to show the camera box. None View Source class VLC : \"\"\" The VLC class represents a video player for visualizing Simulations. This class supports saving Simulation frames (with or without boxes overlay) as well. Args: files (Files): The files to read frames from. If None, the video player will present the log data (simulation) on a white background. config (TimingConfig): The timing configuration of the system. log_path (str): The path to the log file. cam_type (str): The type of camera. This should match the prefix of the corresponding columns in the log file. 
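The trackbar methods above map directly onto `cv.createTrackbar`/`cv.setTrackbarPos`. A small sketch follows; note the window must exist before a trackbar can be attached, which the context manager guarantees by calling `open()`:

    import numpy as np
    from wtracker.eval.vlc import StreamViewer

    frame = np.full((240, 320), 255, dtype=np.uint8)  # blank white frame

    with StreamViewer(window_name="trackbar-demo") as viewer:
        viewer.create_trackbar("threshold", val=128, maxval=255,
                               onChange=lambda v: print("threshold:", v))
        viewer.update_trackbar("threshold", 64)  # move the slider programmatically
        viewer.imshow(frame, title="trackbar demo")

This is exactly how `VLC._create_window` wires its "delay" and "#frame" trackbars to `set_delay` and `seek`.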
show_pred (bool, optional): Whether to show the prediction box. show_micro (bool, optional): Whether to show the microscope box. show_cam (bool, optional): Whether to show the camera box. \"\"\" def __init__ ( self , files : Files | None , config : TimingConfig , log_path : str , cam_type : str , show_pred : bool = True , show_micro : bool = False , show_cam : bool = False , ) -> None : self . streamer = StreamViewer ( window_name = \"VLC\" ) self . index = 0 self . _curr_row = None self . exit = False self . delay = 0 self . play = False self . show_pred = show_pred self . show_micro = show_micro self . show_cam = show_cam self . cam_type : str = cam_type self . config : TimingConfig = config self . log : pd . DataFrame = self . _load_log ( log_path ) self . reader : FrameReader = self . _create_reader ( files ) def initialize ( self ) -> None : \"\"\" Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. \"\"\" self . _init_hotkeys () self . _create_window () self . streamer . update_trackbar ( \"delay\" , round ( self . config . ms_per_frame )) self . print_hotkeys () def _load_log ( self , log_path : str ) -> pd . DataFrame : if log_path is None : return None log = pd . read_csv ( log_path , index_col = \"frame\" ) if self . cam_type == \"plt\" : log [ \"plt_x\" ] = 0 log [ \"plt_y\" ] = 0 log [ \"plt_h\" ] = max ( log [ \"cam_y\" ]) + max ( log [ \"cam_h\" ]) log [ \"plt_w\" ] = max ( log [ \"cam_x\" ]) + max ( log [ \"cam_w\" ]) # assert len(log.index) == len(self.reader) self . _curr_row = log . iloc [ self . index ] return log def _init_hotkeys ( self ) -> None : self . streamer . register_hotkey ( HotKey ( \"q\" , self . close , \"close VLC\" )) self . streamer . register_hotkey ( HotKey ( \"d\" , self . next , \"next frame\" )) self . streamer . register_hotkey ( HotKey ( \"a\" , self . prev , \"previous frame\" )) self . streamer . register_hotkey ( HotKey ( \"p\" , self . toggle_play , \"play/pause\" )) self . streamer . register_hotkey ( HotKey ( \"h\" , self . toggle_pred , \"toggle prediction box\" )) self . streamer . register_hotkey ( HotKey ( \"m\" , self . toggle_micro , \"toggle microscope box\" )) self . streamer . register_hotkey ( HotKey ( \"c\" , self . toggle_cam , \"toggle camera box\" )) def print_hotkeys ( self ): print ( \"Hotkeys:\" ) for hotkey in self . streamer . hotkeys : print ( f \" - {hotkey.key} : {hotkey.description}\" ) def _create_window ( self ): self . streamer . open () self . streamer . create_trackbar ( \"delay\" , 0 , 250 , self . set_delay ) self . streamer . create_trackbar ( \"#frame\" , 0 , len ( self . reader ), self . seek ) def _create_reader ( self , files : Files ) -> FrameReader : if files is None : frame_num = len ( self . log . index ) frame_size = ( self . get_attribute ( self . cam_type + \"_h\" ), self . get_attribute ( self . cam_type + \"_w\" ), ) return DummyReader ( frame_num , frame_size ) filenames = [ f for f in files ] reader = FrameReader ( files . root , filenames ) return reader def __enter__ ( self ): return self def __exit__ ( self , exc_type , exc_value , traceback ): self . streamer . close () def _get_title ( self ): curr_phase = self . get_attribute ( \"phase\" ) phase_title = f \"Action: {curr_phase}\" cycle_len = self . config . imaging_frame_num + self . config . moving_frame_num cycle_progress = 1 + self . 
index % cycle_len cycle_title = ( f \"cycle progress [{cycle_progress}/{cycle_len}]: \" + cycle_progress * \"#\" + ( cycle_len - cycle_progress ) * \"_\" ) title = f \"{phase_title} :: {cycle_title}\" return title def get_attribute ( self , col_name : str ): return self . _curr_row [ col_name ] def update_curr_row ( self ): self . _curr_row = self . log . iloc [ self . index ] def get_photo ( self ) -> np . ndarray : photo = self . reader [ self . index ] if self . show_pred : self . add_pred ( photo ) if self . show_micro : self . add_micro_box ( photo ) if self . show_cam : self . add_cam_box ( photo ) self . draw_center ( photo ) return photo def seek ( self , pos : int ): self . index = ( pos ) % len ( self . reader ) self . update_curr_row () self . streamer . update ( self . get_photo ()) self . streamer . set_title ( self . _get_title ()) def next ( self , key = None ): self . index = ( self . index + 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def prev ( self , key = None ): self . index = ( self . index - 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def close ( self , key = None ): self . exit = True def set_delay ( self , delay : int ): self . delay = delay def toggle_play ( self , key : str = None ): self . play = not self . play def toggle_pred ( self , key : str = None ): self . show_pred = not self . show_pred def toggle_micro ( self , key : str = None ): self . show_micro = not self . show_micro def toggle_cam ( self , key : str = None ): self . show_cam = not self . show_cam def mainloop ( self ): \"\"\" Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the `exit` flag is set to True (by self.close() (called by an hotkey)). It checks the `play` flag to determine if the player should continue playing or pause. The `delay` variable is used to control the delay between each iteration of the loop and is set to 0 to pause. \"\"\" with self as vlc : while not self . exit : delay = 0 if not self . play else self . delay if self . play : self . next () vlc . streamer . waitKey ( delay ) def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ]: x = self . get_attribute ( prefix + \"_x\" ) y = self . get_attribute ( prefix + \"_y\" ) w = self . get_attribute ( prefix + \"_w\" ) h = self . get_attribute ( prefix + \"_h\" ) return ( x , y , w , h ) def draw_box ( self , photo : np . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 , ) -> None : if not np . isfinite ( bbox ) . all (): return x , y , w , h = self . get_bbox ( self . cam_type ) pred_x , pred_y , pred_w , pred_h = bbox pred_x = floor ( pred_x - x ) pred_y = floor ( pred_y - y ) pred_w = ceil ( pred_w ) pred_h = ceil ( pred_h ) cv . rectangle ( photo , ( pred_x , pred_y ), ( pred_x + pred_w , pred_y + pred_h ), color , width ) def draw_marker ( self , photo : np . ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = cv . MARKER_CROSS , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 , ) -> None : frame_x , frame_y , frame_w , frame_h = self . get_bbox ( self . cam_type ) x , y = floor ( x - frame_x ), floor ( y - frame_y ) cv . drawMarker ( photo , ( x , y ), color , marker_type , marker_size , thickness ) def draw_center ( self , photo : np . 
ndarray ): x , y , w , h = self . get_bbox ( \"mic\" ) center = ( x + w // 2 , y + h // 2 ) cv . drawMarker ( photo , center , ( 0 , 0 , 255 ), cv . MARKER_CROSS , 7 , 1 ) def add_pred ( self , photo : np . ndarray ) -> None : worm_bbox = self . get_bbox ( \"wrm\" ) self . draw_box ( photo , worm_bbox , ( 0 , 0 , 0 ), 1 ) def add_micro_box ( self , photo : np . ndarray ) -> None : mic_bbox = self . get_bbox ( \"mic\" ) self . draw_box ( photo , mic_bbox , ( 0 , 0 , 255 ), 1 ) def add_cam_box ( self , photo : np . ndarray ) -> None : cam_bbox = self . get_bbox ( \"cam\" ) self . draw_box ( photo , cam_bbox , ( 128 , 0 , 0 ), 2 ) def save_stream ( self , folder_path : str , ) -> None : create_directory ( folder_path ) filename = f \"{self.cam_type}_\" + \"{:07d}.png\" with ImageSaver ( folder_path , tqdm_kwargs = { \"total\" : len ( self . log . index )}) as worker : for index in range ( len ( self . log . index )): self . index = index self . update_curr_row () path = join_paths ( folder_path , filename . format ( index )) img = self . get_photo () worker . schedule_save ( img , path ) image_format = filename . replace ( \"{:\" , \"%\" ) . replace ( \"}\" , \"\" ) self . make_vid ( folder_path , image_format , folder_path ) def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None : fps = self . config . frames_per_sec command = f \"ffmpeg -framerate {fps} -start_number 0 -i {join_paths(folder_path, img_name_format)} -c:v copy {join_paths(output_dir, 'video.mp4')}\" print ( command ) os . system ( command ) Methods add_cam_box def add_cam_box ( self , photo : numpy . ndarray ) -> None View Source def add_cam_box ( self , photo : np . ndarray ) -> None : cam_bbox = self . get_bbox ( \"cam\" ) self . draw_box ( photo , cam_bbox , ( 128 , 0 , 0 ), 2 ) add_micro_box def add_micro_box ( self , photo : numpy . ndarray ) -> None View Source def add_micro_box ( self , photo : np . ndarray ) -> None : mic_bbox = self . get_bbox ( \"mic\" ) self . draw_box ( photo , mic_bbox , ( 0 , 0 , 255 ), 1 ) add_pred def add_pred ( self , photo : numpy . ndarray ) -> None View Source def add_pred ( self , photo : np . ndarray ) -> None : worm_bbox = self . get_bbox ( \"wrm\" ) self . draw_box ( photo , worm_bbox , ( 0 , 0 , 0 ), 1 ) close def close ( self , key = None ) View Source def close ( self , key = None ) : self . exit = True draw_box def draw_box ( self , photo : numpy . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 ) -> None View Source def draw_box ( self , photo : np . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 , ) -> None : if not np . isfinite ( bbox ). all (): return x , y , w , h = self . get_bbox ( self . cam_type ) pred_x , pred_y , pred_w , pred_h = bbox pred_x = floor ( pred_x - x ) pred_y = floor ( pred_y - y ) pred_w = ceil ( pred_w ) pred_h = ceil ( pred_h ) cv . rectangle ( photo , ( pred_x , pred_y ), ( pred_x + pred_w , pred_y + pred_h ), color , width ) draw_center def draw_center ( self , photo : numpy . ndarray ) View Source def draw_center(self, photo: np.ndarray): x, y, w, h = self.get_bbox(\"mic\") center = (x + w // 2, y + h // 2) cv.drawMarker(photo, center, (0, 0, 255), cv.MARKER_CROSS, 7, 1) draw_marker def draw_marker ( self , photo : numpy . 
ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = 0 , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 ) -> None View Source def draw_marker ( self , photo : np . ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = cv . MARKER_CROSS , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 , ) -> None : frame_x , frame_y , frame_w , frame_h = self . get_bbox ( self . cam_type ) x , y = floor ( x - frame_x ), floor ( y - frame_y ) cv . drawMarker ( photo , ( x , y ), color , marker_type , marker_size , thickness ) get_attribute def get_attribute ( self , col_name : str ) View Source def get_attribute ( self , col_name : str ) : return self . _curr_row [ col_name ] get_bbox def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ] View Source def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ] : x = self . get_attribute ( prefix + \"_x\" ) y = self . get_attribute ( prefix + \"_y\" ) w = self . get_attribute ( prefix + \"_w\" ) h = self . get_attribute ( prefix + \"_h\" ) return ( x , y , w , h ) get_photo def get_photo ( self ) -> numpy . ndarray View Source def get_photo ( self ) -> np . ndarray : photo = self . reader [ self . index ] if self . show_pred : self . add_pred ( photo ) if self . show_micro : self . add_micro_box ( photo ) if self . show_cam : self . add_cam_box ( photo ) self . draw_center ( photo ) return photo initialize def initialize ( self ) -> None Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. View Source def initialize ( self ) -> None : \"\"\" Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. \"\"\" self . _init_hotkeys () self . _create_window () self . streamer . update_trackbar ( \"delay\" , round ( self . config . ms_per_frame )) self . print_hotkeys () mainloop def mainloop ( self ) Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the exit flag is set to True (by self.close() (called by an hotkey)). It checks the play flag to determine if the player should continue playing or pause. The delay variable is used to control the delay between each iteration of the loop and is set to 0 to pause. View Source def mainloop ( self ) : \" \"\" Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the `exit` flag is set to True (by self.close() (called by an hotkey)). It checks the `play` flag to determine if the player should continue playing or pause. The `delay` variable is used to control the delay between each iteration of the loop and is set to 0 to pause. \"\" \" with self as vlc : while not self . exit : delay = 0 if not self . play else self . delay if self . play : self . next () vlc . streamer . waitKey ( delay ) make_vid def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None View Source def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None : fps = self . config . frames_per_sec command = f \"ffmpeg -framerate {fps} -start_number 0 -i {join_paths(folder_path, img_name_format)} -c:v copy {join_paths(output_dir, ' video . 
mp4 ')}\" print ( command ) os . system ( command ) next def next ( self , key = None ) View Source def next(self, key=None): self.index = (self.index + 1) % len(self.reader) self.streamer.update_trackbar(\"#frame\", self.index) prev def prev ( self , key = None ) View Source def prev(self, key=None): self.index = (self.index - 1) % len(self.reader) self.streamer.update_trackbar(\"#frame\", self.index) print_hotkeys def print_hotkeys ( self ) View Source def print_hotkeys(self): print(\"Hotkeys:\") for hotkey in self.streamer.hotkeys: print(f\" - {hotkey.key} : {hotkey.description}\") save_stream def save_stream ( self , folder_path : str ) -> None View Source def save_stream ( self , folder_path : str , ) -> None : create_directory ( folder_path ) filename = f \"{self.cam_type}_\" + \"{:07d}.png\" with ImageSaver ( folder_path , tqdm_kwargs ={ \"total\" : len ( self . log . index )}) as worker : for index in range ( len ( self . log . index )): self . index = index self . update_curr_row () path = join_paths ( folder_path , filename . format ( index )) img = self . get_photo () worker . schedule_save ( img , path ) image_format = filename . replace ( \"{:\" , \"%\" ). replace ( \"}\" , \"\" ) self . make_vid ( folder_path , image_format , folder_path ) seek def seek ( self , pos : int ) View Source def seek(self, pos: int): self.index = (pos) % len(self.reader) self.update_curr_row() self.streamer.update(self.get_photo()) self.streamer.set_title(self._get_title()) set_delay def set_delay ( self , delay : int ) View Source def set_delay(self, delay: int): self.delay = delay toggle_cam def toggle_cam ( self , key : str = None ) View Source def toggle_cam(self, key: str = None): self.show_cam = not self.show_cam toggle_micro def toggle_micro ( self , key : str = None ) View Source def toggle_micro(self, key: str = None): self.show_micro = not self.show_micro toggle_play def toggle_play ( self , key : str = None ) View Source def toggle_play(self, key: str = None): self.play = not self.play toggle_pred def toggle_pred ( self , key : str = None ) View Source def toggle_pred(self, key: str = None): self.show_pred = not self.show_pred update_curr_row def update_curr_row ( self ) View Source def update_curr_row(self): self._curr_row = self.log.iloc[self.index]","title":"Vlc"},{"location":"reference/wtracker/eval/vlc/#module-wtrackerevalvlc","text":"View Source import pandas as pd import numpy as np from math import ceil , floor import os import cv2 as cv from typing import Callable from dataclasses import dataclass , field import matplotlib matplotlib . use ( \"QTAgg\" ) from wtracker.utils.path_utils import Files , create_directory , join_paths from wtracker.utils.io_utils import ImageSaver from wtracker.utils.frame_reader import FrameReader , DummyReader from wtracker.sim.config import TimingConfig @dataclass class HotKey : \"\"\" Represents a hotkey that can be used to trigger a specific function. Attributes: key (str): The key for the hotkey. func (Callable[[str], None]): The function to be called when the hotkey is triggered. description (str): The description of the hotkey (optional). \"\"\" key : str func : Callable [[ str ], None ] description : str = field ( default = \"\" ) def __post_init__ ( self ): self . key = self . key . lower () class StreamViewer : \"\"\" A class for viewing and interacting with photos and video streams. Args: window_name (str, optional): The name of the window. 
Example: with StreamViewer() as streamer: streamer.imshow(image) streamer.waitKey() \"\"\" def __init__ ( self , window_name : str = \"streamer\" ) -> None : self . window_name = window_name self . window = None self . hotkeys : list [ HotKey ] = [] self . register_hotkey ( HotKey ( \"q\" , self . close , \"close the window\" )) def register_hotkey ( self , hotkey : HotKey ): \"\"\" Registers a hotkey. Args: hotkey (HotKey): The hotkey to register. \"\"\" self . hotkeys . append ( hotkey ) def create_trackbar ( self , name : str , val : int , maxval : int , onChange = lambda x : x ): \"\"\" Creates a trackbar. Args: name (str): The name of the trackbar. val (int): The initial value of the trackbar. maxval (int): The maximum value of the trackbar. onChange (function): The function to call when the trackbar value changes. \"\"\" cv . createTrackbar ( name , self . window_name , val , maxval , onChange ) def update_trackbar ( self , name : str , val : int ): \"\"\" Updates the value of a trackbar. Args: name (str): The name of the trackbar. val (int): The new value of the trackbar. \"\"\" cv . setTrackbarPos ( name , self . window_name , val ) def set_title ( self , title : str ): \"\"\" Sets the title of the window. Args: title (str): The new title of the window. \"\"\" cv . setWindowTitle ( self . window_name , title ) def __enter__ ( self ): \"\"\" Enters the context manager. \"\"\" self . open () return self def __exit__ ( self , exc_type , exc_value , traceback ): \"\"\" Exits the context manager. \"\"\" self . close () def __del__ ( self ): \"\"\" Destructor method. \"\"\" self . close () def update ( self , image : np . ndarray , wait : int = 1 ): \"\"\" Updates the window with a new image. Args: image (np.ndarray): The image to display. wait (int): The delay in milliseconds before updating the window. \"\"\" cv . imshow ( self . window_name , image ) self . waitKey ( wait ) def waitKey ( self , timeout : int = 0 ): \"\"\" Waits for a key press. This Function also triggers the hotkeys. Args: timeout (int): The timeout in milliseconds. Returns: str: The key that was pressed. \"\"\" key = cv . waitKey ( timeout ) if key <= 0 : return key key = chr ( key ) . lower () for hotkey in self . hotkeys : if key in hotkey . key : hotkey . func ( key ) return key def open ( self ): \"\"\" Opens the window. \"\"\" self . close () self . window = cv . namedWindow ( self . window_name , flags = cv . WINDOW_GUI_EXPANDED ) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1) self . set_title ( self . window_name ) def close ( self , key : str = \"q\" ): \"\"\" Closes the window. Args: key (str): The key to close the window. \"\"\" if self . window is not None : cv . destroyWindow ( self . window_name ) self . window = None def imshow ( self , image : np . ndarray , title : str = \"image\" ): \"\"\" Displays an image in the window. Args: image (np.ndarray): The image to display. title (str): The title of the image. \"\"\" self . update ( image , wait = 0 ) self . set_title ( title ) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1) class VLC : \"\"\" The VLC class represents a video player for visualizing Simulations. This class supports saving Simulation frames (with or without boxes overlay) as well. Args: files (Files): The files to read frames from. If None, the video player will present the log data (simulation) on a white background. config (TimingConfig): The timing configuration of the system. log_path (str): The path to the log file. 
cam_type (str): The type of camera. This should match the prefix of the corresponding columns in the log file. show_pred (bool, optional): Whether to show the prediction box. show_micro (bool, optional): Whether to show the microscope box. show_cam (bool, optional): Whether to show the camera box. \"\"\" def __init__ ( self , files : Files | None , config : TimingConfig , log_path : str , cam_type : str , show_pred : bool = True , show_micro : bool = False , show_cam : bool = False , ) -> None : self . streamer = StreamViewer ( window_name = \"VLC\" ) self . index = 0 self . _curr_row = None self . exit = False self . delay = 0 self . play = False self . show_pred = show_pred self . show_micro = show_micro self . show_cam = show_cam self . cam_type : str = cam_type self . config : TimingConfig = config self . log : pd . DataFrame = self . _load_log ( log_path ) self . reader : FrameReader = self . _create_reader ( files ) def initialize ( self ) -> None : \"\"\" Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. \"\"\" self . _init_hotkeys () self . _create_window () self . streamer . update_trackbar ( \"delay\" , round ( self . config . ms_per_frame )) self . print_hotkeys () def _load_log ( self , log_path : str ) -> pd . DataFrame : if log_path is None : return None log = pd . read_csv ( log_path , index_col = \"frame\" ) if self . cam_type == \"plt\" : log [ \"plt_x\" ] = 0 log [ \"plt_y\" ] = 0 log [ \"plt_h\" ] = max ( log [ \"cam_y\" ]) + max ( log [ \"cam_h\" ]) log [ \"plt_w\" ] = max ( log [ \"cam_x\" ]) + max ( log [ \"cam_w\" ]) # assert len(log.index) == len(self.reader) self . _curr_row = log . iloc [ self . index ] return log def _init_hotkeys ( self ) -> None : self . streamer . register_hotkey ( HotKey ( \"q\" , self . close , \"close VLC\" )) self . streamer . register_hotkey ( HotKey ( \"d\" , self . next , \"next frame\" )) self . streamer . register_hotkey ( HotKey ( \"a\" , self . prev , \"previous frame\" )) self . streamer . register_hotkey ( HotKey ( \"p\" , self . toggle_play , \"play/pause\" )) self . streamer . register_hotkey ( HotKey ( \"h\" , self . toggle_pred , \"toggle prediction box\" )) self . streamer . register_hotkey ( HotKey ( \"m\" , self . toggle_micro , \"toggle microscope box\" )) self . streamer . register_hotkey ( HotKey ( \"c\" , self . toggle_cam , \"toggle camera box\" )) def print_hotkeys ( self ): print ( \"Hotkeys:\" ) for hotkey in self . streamer . hotkeys : print ( f \" - { hotkey . key } : { hotkey . description } \" ) def _create_window ( self ): self . streamer . open () self . streamer . create_trackbar ( \"delay\" , 0 , 250 , self . set_delay ) self . streamer . create_trackbar ( \"#frame\" , 0 , len ( self . reader ), self . seek ) def _create_reader ( self , files : Files ) -> FrameReader : if files is None : frame_num = len ( self . log . index ) frame_size = ( self . get_attribute ( self . cam_type + \"_h\" ), self . get_attribute ( self . cam_type + \"_w\" ), ) return DummyReader ( frame_num , frame_size ) filenames = [ f for f in files ] reader = FrameReader ( files . root , filenames ) return reader def __enter__ ( self ): return self def __exit__ ( self , exc_type , exc_value , traceback ): self . streamer . close () def _get_title ( self ): curr_phase = self . get_attribute ( \"phase\" ) phase_title = f \"Action: { curr_phase } \" cycle_len = self . config . imaging_frame_num + self . config . moving_frame_num cycle_progress = 1 + self . 
index % cycle_len cycle_title = ( f \"cycle progress [ { cycle_progress } / { cycle_len } ]: \" + cycle_progress * \"#\" + ( cycle_len - cycle_progress ) * \"_\" ) title = f \" { phase_title } :: { cycle_title } \" return title def get_attribute ( self , col_name : str ): return self . _curr_row [ col_name ] def update_curr_row ( self ): self . _curr_row = self . log . iloc [ self . index ] def get_photo ( self ) -> np . ndarray : photo = self . reader [ self . index ] if self . show_pred : self . add_pred ( photo ) if self . show_micro : self . add_micro_box ( photo ) if self . show_cam : self . add_cam_box ( photo ) self . draw_center ( photo ) return photo def seek ( self , pos : int ): self . index = ( pos ) % len ( self . reader ) self . update_curr_row () self . streamer . update ( self . get_photo ()) self . streamer . set_title ( self . _get_title ()) def next ( self , key = None ): self . index = ( self . index + 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def prev ( self , key = None ): self . index = ( self . index - 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def close ( self , key = None ): self . exit = True def set_delay ( self , delay : int ): self . delay = delay def toggle_play ( self , key : str = None ): self . play = not self . play def toggle_pred ( self , key : str = None ): self . show_pred = not self . show_pred def toggle_micro ( self , key : str = None ): self . show_micro = not self . show_micro def toggle_cam ( self , key : str = None ): self . show_cam = not self . show_cam def mainloop ( self ): \"\"\" Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the `exit` flag is set to True (by self.close() (called by an hotkey)). It checks the `play` flag to determine if the player should continue playing or pause. The `delay` variable is used to control the delay between each iteration of the loop and is set to 0 to pause. \"\"\" with self as vlc : while not self . exit : delay = 0 if not self . play else self . delay if self . play : self . next () vlc . streamer . waitKey ( delay ) def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ]: x = self . get_attribute ( prefix + \"_x\" ) y = self . get_attribute ( prefix + \"_y\" ) w = self . get_attribute ( prefix + \"_w\" ) h = self . get_attribute ( prefix + \"_h\" ) return ( x , y , w , h ) def draw_box ( self , photo : np . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 , ) -> None : if not np . isfinite ( bbox ) . all (): return x , y , w , h = self . get_bbox ( self . cam_type ) pred_x , pred_y , pred_w , pred_h = bbox pred_x = floor ( pred_x - x ) pred_y = floor ( pred_y - y ) pred_w = ceil ( pred_w ) pred_h = ceil ( pred_h ) cv . rectangle ( photo , ( pred_x , pred_y ), ( pred_x + pred_w , pred_y + pred_h ), color , width ) def draw_marker ( self , photo : np . ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = cv . MARKER_CROSS , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 , ) -> None : frame_x , frame_y , frame_w , frame_h = self . get_bbox ( self . cam_type ) x , y = floor ( x - frame_x ), floor ( y - frame_y ) cv . 
drawMarker ( photo , ( x , y ), color , marker_type , marker_size , thickness ) def draw_center ( self , photo : np . ndarray ): x , y , w , h = self . get_bbox ( \"mic\" ) center = ( x + w // 2 , y + h // 2 ) cv . drawMarker ( photo , center , ( 0 , 0 , 255 ), cv . MARKER_CROSS , 7 , 1 ) def add_pred ( self , photo : np . ndarray ) -> None : worm_bbox = self . get_bbox ( \"wrm\" ) self . draw_box ( photo , worm_bbox , ( 0 , 0 , 0 ), 1 ) def add_micro_box ( self , photo : np . ndarray ) -> None : mic_bbox = self . get_bbox ( \"mic\" ) self . draw_box ( photo , mic_bbox , ( 0 , 0 , 255 ), 1 ) def add_cam_box ( self , photo : np . ndarray ) -> None : cam_bbox = self . get_bbox ( \"cam\" ) self . draw_box ( photo , cam_bbox , ( 128 , 0 , 0 ), 2 ) def save_stream ( self , folder_path : str , ) -> None : create_directory ( folder_path ) filename = f \" { self . cam_type } _\" + \" {:07d} .png\" with ImageSaver ( folder_path , tqdm_kwargs = { \"total\" : len ( self . log . index )}) as worker : for index in range ( len ( self . log . index )): self . index = index self . update_curr_row () path = join_paths ( folder_path , filename . format ( index )) img = self . get_photo () worker . schedule_save ( img , path ) image_format = filename . replace ( \"{:\" , \"%\" ) . replace ( \"}\" , \"\" ) self . make_vid ( folder_path , image_format , folder_path ) def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None : fps = self . config . frames_per_sec command = f \"ffmpeg -framerate { fps } -start_number 0 -i { join_paths ( folder_path , img_name_format ) } -c:v copy { join_paths ( output_dir , 'video.mp4' ) } \" print ( command ) os . system ( command )","title":"Module wtracker.eval.vlc"},{"location":"reference/wtracker/eval/vlc/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/eval/vlc/#hotkey","text":"class HotKey ( key : str , func : Callable [[ str ], NoneType ], description : str = '' ) Represents a hotkey that can be used to trigger a specific function.","title":"HotKey"},{"location":"reference/wtracker/eval/vlc/#attributes","text":"Name Type Description Default key str The key for the hotkey. None func Callable[[str], None] The function to be called when the hotkey is triggered. None description str The description of the hotkey (optional). None View Source @dataclass class HotKey : \"\"\" Represents a hotkey that can be used to trigger a specific function. Attributes: key (str): The key for the hotkey. func (Callable[[str], None]): The function to be called when the hotkey is triggered. description (str): The description of the hotkey (optional). \"\"\" key : str func : Callable [ [str ] , None ] description : str = field ( default = \"\" ) def __post_init__ ( self ) : self . key = self . key . lower ()","title":"Attributes"},{"location":"reference/wtracker/eval/vlc/#class-variables","text":"description","title":"Class variables"},{"location":"reference/wtracker/eval/vlc/#streamviewer","text":"class StreamViewer ( window_name : str = 'streamer' ) A class for viewing and interacting with photos and video streams.","title":"StreamViewer"},{"location":"reference/wtracker/eval/vlc/#attributes_1","text":"Name Type Description Default window_name str The name of the window. None View Source class StreamViewer : \"\"\" A class for viewing and interacting with photos and video streams. Args: window_name (str, optional): The name of the window. 
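A two-line check of the dataclass behavior documented above (key normalization happens in `__post_init__`, and `description` defaults to an empty string):

    from wtracker.eval.vlc import HotKey

    hk = HotKey("Q", func=lambda key: print("got", key))
    assert hk.key == "q"          # __post_init__ lower-cases the key
    assert hk.description == ""   # field(default="")
    hk.func(hk.key)               # prints: got q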
Example: with StreamViewer() as streamer: streamer.imshow(image) streamer.waitKey() \"\"\" def __init__ ( self , window_name : str = \"streamer\" ) -> None : self . window_name = window_name self . window = None self . hotkeys : list [ HotKey ] = [] self . register_hotkey ( HotKey ( \"q\" , self . close , \"close the window\" )) def register_hotkey ( self , hotkey : HotKey ) : \"\"\" Registers a hotkey. Args: hotkey (HotKey): The hotkey to register. \"\"\" self . hotkeys . append ( hotkey ) def create_trackbar ( self , name : str , val : int , maxval : int , onChange = lambda x : x ) : \"\"\" Creates a trackbar. Args: name (str): The name of the trackbar. val (int): The initial value of the trackbar. maxval (int): The maximum value of the trackbar. onChange (function): The function to call when the trackbar value changes. \"\"\" cv . createTrackbar ( name , self . window_name , val , maxval , onChange ) def update_trackbar ( self , name : str , val : int ) : \"\"\" Updates the value of a trackbar. Args: name (str): The name of the trackbar. val (int): The new value of the trackbar. \"\"\" cv . setTrackbarPos ( name , self . window_name , val ) def set_title ( self , title : str ) : \"\"\" Sets the title of the window. Args: title (str): The new title of the window. \"\"\" cv . setWindowTitle ( self . window_name , title ) def __enter__ ( self ) : \"\"\" Enters the context manager. \"\"\" self . open () return self def __exit__ ( self , exc_type , exc_value , traceback ) : \"\"\" Exits the context manager. \"\"\" self . close () def __del__ ( self ) : \"\"\" Destructor method. \"\"\" self . close () def update ( self , image : np . ndarray , wait : int = 1 ) : \"\"\" Updates the window with a new image. Args: image (np.ndarray): The image to display. wait (int): The delay in milliseconds before updating the window. \"\"\" cv . imshow ( self . window_name , image ) self . waitKey ( wait ) def waitKey ( self , timeout : int = 0 ) : \"\"\" Waits for a key press. This Function also triggers the hotkeys. Args: timeout (int): The timeout in milliseconds. Returns: str: The key that was pressed. \"\"\" key = cv . waitKey ( timeout ) if key <= 0 : return key key = chr ( key ). lower () for hotkey in self . hotkeys : if key in hotkey . key : hotkey . func ( key ) return key def open ( self ) : \"\"\" Opens the window. \"\"\" self . close () self . window = cv . namedWindow ( self . window_name , flags = cv . WINDOW_GUI_EXPANDED ) # cv . setWindowProperty ( self . window_name , cv . WND_PROP_TOPMOST , 1 ) self . set_title ( self . window_name ) def close ( self , key : str = \"q\" ) : \"\"\" Closes the window. Args: key (str): The key to close the window. \"\"\" if self . window is not None : cv . destroyWindow ( self . window_name ) self . window = None def imshow ( self , image : np . ndarray , title : str = \"image\" ) : \"\"\" Displays an image in the window. Args: image (np.ndarray): The image to display. title (str): The title of the image. \"\"\" self . update ( image , wait = 0 ) self . set_title ( title ) # cv . setWindowProperty ( self . window_name , cv . WND_PROP_TOPMOST , 1 )","title":"Attributes"},{"location":"reference/wtracker/eval/vlc/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/eval/vlc/#close","text":"def close ( self , key : str = 'q' ) Closes the window. Parameters: Name Type Description Default key str The key to close the window. None View Source def close(self, key: str = \"q\"): \"\"\" Closes the window. Args: key (str): The key to close the window. 
\"\"\" if self.window is not None: cv.destroyWindow(self.window_name) self.window = None","title":"close"},{"location":"reference/wtracker/eval/vlc/#create_trackbar","text":"def create_trackbar ( self , name : str , val : int , maxval : int , onChange =< function StreamViewer .< lambda > at 0x7f9303198160 > ) Creates a trackbar. Parameters: Name Type Description Default name str The name of the trackbar. None val int The initial value of the trackbar. None maxval int The maximum value of the trackbar. None onChange function The function to call when the trackbar value changes. None View Source def create_trackbar(self, name: str, val: int, maxval: int, onChange=lambda x: x): \"\"\" Creates a trackbar. Args: name (str): The name of the trackbar. val (int): The initial value of the trackbar. maxval (int): The maximum value of the trackbar. onChange (function): The function to call when the trackbar value changes. \"\"\" cv.createTrackbar(name, self.window_name, val, maxval, onChange)","title":"create_trackbar"},{"location":"reference/wtracker/eval/vlc/#imshow","text":"def imshow ( self , image : numpy . ndarray , title : str = 'image' ) Displays an image in the window. Parameters: Name Type Description Default image np.ndarray The image to display. None title str The title of the image. None View Source def imshow(self, image: np.ndarray, title: str = \"image\"): \"\"\" Displays an image in the window. Args: image (np.ndarray): The image to display. title (str): The title of the image. \"\"\" self.update(image, wait=0) self.set_title(title) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1)","title":"imshow"},{"location":"reference/wtracker/eval/vlc/#open","text":"def open ( self ) Opens the window. View Source def open(self): \"\"\" Opens the window. \"\"\" self.close() self.window = cv.namedWindow(self.window_name, flags=cv.WINDOW_GUI_EXPANDED) # cv.setWindowProperty(self.window_name, cv.WND_PROP_TOPMOST, 1) self.set_title(self.window_name)","title":"open"},{"location":"reference/wtracker/eval/vlc/#register_hotkey","text":"def register_hotkey ( self , hotkey : wtracker . eval . vlc . HotKey ) Registers a hotkey. Parameters: Name Type Description Default hotkey HotKey The hotkey to register. None View Source def register_hotkey ( self , hotkey: HotKey ) : \"\"\" Registers a hotkey . Args: hotkey ( HotKey ) : The hotkey to register . \"\"\" self . hotkeys . append ( hotkey )","title":"register_hotkey"},{"location":"reference/wtracker/eval/vlc/#set_title","text":"def set_title ( self , title : str ) Sets the title of the window. Parameters: Name Type Description Default title str The new title of the window. None View Source def set_title(self, title: str): \"\"\" Sets the title of the window. Args: title (str): The new title of the window. \"\"\" cv.setWindowTitle(self.window_name, title)","title":"set_title"},{"location":"reference/wtracker/eval/vlc/#update","text":"def update ( self , image : numpy . ndarray , wait : int = 1 ) Updates the window with a new image. Parameters: Name Type Description Default image np.ndarray The image to display. None wait int The delay in milliseconds before updating the window. None View Source def update(self, image: np.ndarray, wait: int = 1): \"\"\" Updates the window with a new image. Args: image (np.ndarray): The image to display. wait (int): The delay in milliseconds before updating the window. 
\"\"\" cv.imshow(self.window_name, image) self.waitKey(wait)","title":"update"},{"location":"reference/wtracker/eval/vlc/#update_trackbar","text":"def update_trackbar ( self , name : str , val : int ) Updates the value of a trackbar. Parameters: Name Type Description Default name str The name of the trackbar. None val int The new value of the trackbar. None View Source def update_trackbar(self, name: str, val: int): \"\"\" Updates the value of a trackbar. Args: name (str): The name of the trackbar. val (int): The new value of the trackbar. \"\"\" cv.setTrackbarPos(name, self.window_name, val)","title":"update_trackbar"},{"location":"reference/wtracker/eval/vlc/#waitkey","text":"def waitKey ( self , timeout : int = 0 ) Waits for a key press. This Function also triggers the hotkeys. Parameters: Name Type Description Default timeout int The timeout in milliseconds. None Returns: Type Description str The key that was pressed. View Source def waitKey ( self , timeout : int = 0 ) : \"\" \" Waits for a key press. This Function also triggers the hotkeys. Args: timeout (int): The timeout in milliseconds. Returns: str: The key that was pressed. \"\" \" key = cv.waitKey(timeout) if key <= 0: return key key = chr(key).lower() for hotkey in self.hotkeys: if key in hotkey.key: hotkey.func(key) return key","title":"waitKey"},{"location":"reference/wtracker/eval/vlc/#vlc","text":"class VLC ( files : wtracker . utils . path_utils . Files | None , config : wtracker . sim . config . TimingConfig , log_path : str , cam_type : str , show_pred : bool = True , show_micro : bool = False , show_cam : bool = False ) The VLC class represents a video player for visualizing Simulations. This class supports saving Simulation frames (with or without boxes overlay) as well.","title":"VLC"},{"location":"reference/wtracker/eval/vlc/#attributes_2","text":"Name Type Description Default files Files The files to read frames from. If None, the video player will present the log data (simulation) on a white background. None config TimingConfig The timing configuration of the system. None log_path str The path to the log file. None cam_type str The type of camera. This should match the prefix of the corresponding columns in the log file. None show_pred bool Whether to show the prediction box. None show_micro bool Whether to show the microscope box. None show_cam bool Whether to show the camera box. None View Source class VLC : \"\"\" The VLC class represents a video player for visualizing Simulations. This class supports saving Simulation frames (with or without boxes overlay) as well. Args: files (Files): The files to read frames from. If None, the video player will present the log data (simulation) on a white background. config (TimingConfig): The timing configuration of the system. log_path (str): The path to the log file. cam_type (str): The type of camera. This should match the prefix of the corresponding columns in the log file. show_pred (bool, optional): Whether to show the prediction box. show_micro (bool, optional): Whether to show the microscope box. show_cam (bool, optional): Whether to show the camera box. \"\"\" def __init__ ( self , files : Files | None , config : TimingConfig , log_path : str , cam_type : str , show_pred : bool = True , show_micro : bool = False , show_cam : bool = False , ) -> None : self . streamer = StreamViewer ( window_name = \"VLC\" ) self . index = 0 self . _curr_row = None self . exit = False self . delay = 0 self . play = False self . show_pred = show_pred self . 
show_micro = show_micro self . show_cam = show_cam self . cam_type : str = cam_type self . config : TimingConfig = config self . log : pd . DataFrame = self . _load_log ( log_path ) self . reader : FrameReader = self . _create_reader ( files ) def initialize ( self ) -> None : \"\"\" Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. \"\"\" self . _init_hotkeys () self . _create_window () self . streamer . update_trackbar ( \"delay\" , round ( self . config . ms_per_frame )) self . print_hotkeys () def _load_log ( self , log_path : str ) -> pd . DataFrame : if log_path is None : return None log = pd . read_csv ( log_path , index_col = \"frame\" ) if self . cam_type == \"plt\" : log [ \"plt_x\" ] = 0 log [ \"plt_y\" ] = 0 log [ \"plt_h\" ] = max ( log [ \"cam_y\" ]) + max ( log [ \"cam_h\" ]) log [ \"plt_w\" ] = max ( log [ \"cam_x\" ]) + max ( log [ \"cam_w\" ]) # assert len(log.index) == len(self.reader) self . _curr_row = log . iloc [ self . index ] return log def _init_hotkeys ( self ) -> None : self . streamer . register_hotkey ( HotKey ( \"q\" , self . close , \"close VLC\" )) self . streamer . register_hotkey ( HotKey ( \"d\" , self . next , \"next frame\" )) self . streamer . register_hotkey ( HotKey ( \"a\" , self . prev , \"previous frame\" )) self . streamer . register_hotkey ( HotKey ( \"p\" , self . toggle_play , \"play/pause\" )) self . streamer . register_hotkey ( HotKey ( \"h\" , self . toggle_pred , \"toggle prediction box\" )) self . streamer . register_hotkey ( HotKey ( \"m\" , self . toggle_micro , \"toggle microscope box\" )) self . streamer . register_hotkey ( HotKey ( \"c\" , self . toggle_cam , \"toggle camera box\" )) def print_hotkeys ( self ): print ( \"Hotkeys:\" ) for hotkey in self . streamer . hotkeys : print ( f \" - {hotkey.key} : {hotkey.description}\" ) def _create_window ( self ): self . streamer . open () self . streamer . create_trackbar ( \"delay\" , 0 , 250 , self . set_delay ) self . streamer . create_trackbar ( \"#frame\" , 0 , len ( self . reader ), self . seek ) def _create_reader ( self , files : Files ) -> FrameReader : if files is None : frame_num = len ( self . log . index ) frame_size = ( self . get_attribute ( self . cam_type + \"_h\" ), self . get_attribute ( self . cam_type + \"_w\" ), ) return DummyReader ( frame_num , frame_size ) filenames = [ f for f in files ] reader = FrameReader ( files . root , filenames ) return reader def __enter__ ( self ): return self def __exit__ ( self , exc_type , exc_value , traceback ): self . streamer . close () def _get_title ( self ): curr_phase = self . get_attribute ( \"phase\" ) phase_title = f \"Action: {curr_phase}\" cycle_len = self . config . imaging_frame_num + self . config . moving_frame_num cycle_progress = 1 + self . index % cycle_len cycle_title = ( f \"cycle progress [{cycle_progress}/{cycle_len}]: \" + cycle_progress * \"#\" + ( cycle_len - cycle_progress ) * \"_\" ) title = f \"{phase_title} :: {cycle_title}\" return title def get_attribute ( self , col_name : str ): return self . _curr_row [ col_name ] def update_curr_row ( self ): self . _curr_row = self . log . iloc [ self . index ] def get_photo ( self ) -> np . ndarray : photo = self . reader [ self . index ] if self . show_pred : self . add_pred ( photo ) if self . show_micro : self . add_micro_box ( photo ) if self . show_cam : self . add_cam_box ( photo ) self . draw_center ( photo ) return photo def seek ( self , pos : int ): self . index = ( pos ) % len ( self . 
reader ) self . update_curr_row () self . streamer . update ( self . get_photo ()) self . streamer . set_title ( self . _get_title ()) def next ( self , key = None ): self . index = ( self . index + 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def prev ( self , key = None ): self . index = ( self . index - 1 ) % len ( self . reader ) self . streamer . update_trackbar ( \"#frame\" , self . index ) def close ( self , key = None ): self . exit = True def set_delay ( self , delay : int ): self . delay = delay def toggle_play ( self , key : str = None ): self . play = not self . play def toggle_pred ( self , key : str = None ): self . show_pred = not self . show_pred def toggle_micro ( self , key : str = None ): self . show_micro = not self . show_micro def toggle_cam ( self , key : str = None ): self . show_cam = not self . show_cam def mainloop ( self ): \"\"\" Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the `exit` flag is set to True (by self.close() (called by a hotkey)). It checks the `play` flag to determine if the player should continue playing or pause. The `delay` variable is used to control the delay between each iteration of the loop and is set to 0 to pause. \"\"\" with self as vlc : while not self . exit : delay = 0 if not self . play else self . delay if self . play : self . next () vlc . streamer . waitKey ( delay ) def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ]: x = self . get_attribute ( prefix + \"_x\" ) y = self . get_attribute ( prefix + \"_y\" ) w = self . get_attribute ( prefix + \"_w\" ) h = self . get_attribute ( prefix + \"_h\" ) return ( x , y , w , h ) def draw_box ( self , photo : np . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 , ) -> None : if not np . isfinite ( bbox ) . all (): return x , y , w , h = self . get_bbox ( self . cam_type ) pred_x , pred_y , pred_w , pred_h = bbox pred_x = floor ( pred_x - x ) pred_y = floor ( pred_y - y ) pred_w = ceil ( pred_w ) pred_h = ceil ( pred_h ) cv . rectangle ( photo , ( pred_x , pred_y ), ( pred_x + pred_w , pred_y + pred_h ), color , width ) def draw_marker ( self , photo : np . ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = cv . MARKER_CROSS , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 , ) -> None : frame_x , frame_y , frame_w , frame_h = self . get_bbox ( self . cam_type ) x , y = floor ( x - frame_x ), floor ( y - frame_y ) cv . drawMarker ( photo , ( x , y ), color , marker_type , marker_size , thickness ) def draw_center ( self , photo : np . ndarray ): x , y , w , h = self . get_bbox ( \"mic\" ) center = ( x + w // 2 , y + h // 2 ) cv . drawMarker ( photo , center , ( 0 , 0 , 255 ), cv . MARKER_CROSS , 7 , 1 ) def add_pred ( self , photo : np . ndarray ) -> None : worm_bbox = self . get_bbox ( \"wrm\" ) self . draw_box ( photo , worm_bbox , ( 0 , 0 , 0 ), 1 ) def add_micro_box ( self , photo : np . ndarray ) -> None : mic_bbox = self . get_bbox ( \"mic\" ) self . draw_box ( photo , mic_bbox , ( 0 , 0 , 255 ), 1 ) def add_cam_box ( self , photo : np . ndarray ) -> None : cam_bbox = self . get_bbox ( \"cam\" ) self .
draw_box ( photo , cam_bbox , ( 128 , 0 , 0 ), 2 ) def save_stream ( self , folder_path : str , ) -> None : create_directory ( folder_path ) filename = f \"{self.cam_type}_\" + \"{:07d}.png\" with ImageSaver ( folder_path , tqdm_kwargs = { \"total\" : len ( self . log . index )}) as worker : for index in range ( len ( self . log . index )): self . index = index self . update_curr_row () path = join_paths ( folder_path , filename . format ( index )) img = self . get_photo () worker . schedule_save ( img , path ) image_format = filename . replace ( \"{:\" , \"%\" ) . replace ( \"}\" , \"\" ) self . make_vid ( folder_path , image_format , folder_path ) def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None : fps = self . config . frames_per_sec command = f \"ffmpeg -framerate {fps} -start_number 0 -i {join_paths(folder_path, img_name_format)} -c:v copy {join_paths(output_dir, 'video.mp4')}\" print ( command ) os . system ( command )","title":"Attributes"},{"location":"reference/wtracker/eval/vlc/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/eval/vlc/#add_cam_box","text":"def add_cam_box ( self , photo : numpy . ndarray ) -> None View Source def add_cam_box ( self , photo : np . ndarray ) -> None : cam_bbox = self . get_bbox ( \"cam\" ) self . draw_box ( photo , cam_bbox , ( 128 , 0 , 0 ), 2 )","title":"add_cam_box"},{"location":"reference/wtracker/eval/vlc/#add_micro_box","text":"def add_micro_box ( self , photo : numpy . ndarray ) -> None View Source def add_micro_box ( self , photo : np . ndarray ) -> None : mic_bbox = self . get_bbox ( \"mic\" ) self . draw_box ( photo , mic_bbox , ( 0 , 0 , 255 ), 1 )","title":"add_micro_box"},{"location":"reference/wtracker/eval/vlc/#add_pred","text":"def add_pred ( self , photo : numpy . ndarray ) -> None View Source def add_pred ( self , photo : np . ndarray ) -> None : worm_bbox = self . get_bbox ( \"wrm\" ) self . draw_box ( photo , worm_bbox , ( 0 , 0 , 0 ), 1 )","title":"add_pred"},{"location":"reference/wtracker/eval/vlc/#close_1","text":"def close ( self , key = None ) View Source def close ( self , key = None ) : self . exit = True","title":"close"},{"location":"reference/wtracker/eval/vlc/#draw_box","text":"def draw_box ( self , photo : numpy . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 ) -> None View Source def draw_box ( self , photo : np . ndarray , bbox : tuple [ float , float , float , float ], color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), width : int = 1 , ) -> None : if not np . isfinite ( bbox ). all (): return x , y , w , h = self . get_bbox ( self . cam_type ) pred_x , pred_y , pred_w , pred_h = bbox pred_x = floor ( pred_x - x ) pred_y = floor ( pred_y - y ) pred_w = ceil ( pred_w ) pred_h = ceil ( pred_h ) cv . rectangle ( photo , ( pred_x , pred_y ), ( pred_x + pred_w , pred_y + pred_h ), color , width )","title":"draw_box"},{"location":"reference/wtracker/eval/vlc/#draw_center","text":"def draw_center ( self , photo : numpy . ndarray ) View Source def draw_center(self, photo: np.ndarray): x, y, w, h = self.get_bbox(\"mic\") center = (x + w // 2, y + h // 2) cv.drawMarker(photo, center, (0, 0, 255), cv.MARKER_CROSS, 7, 1)","title":"draw_center"},{"location":"reference/wtracker/eval/vlc/#draw_marker","text":"def draw_marker ( self , photo : numpy . 
ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = 0 , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 ) -> None View Source def draw_marker ( self , photo : np . ndarray , x : float , y : float , marker_size : int = 5 , marker_type : int = cv . MARKER_CROSS , color : tuple [ int , int , int ] = ( 0 , 0 , 255 ), thickness : int = 1 , ) -> None : frame_x , frame_y , frame_w , frame_h = self . get_bbox ( self . cam_type ) x , y = floor ( x - frame_x ), floor ( y - frame_y ) cv . drawMarker ( photo , ( x , y ), color , marker_type , marker_size , thickness )","title":"draw_marker"},{"location":"reference/wtracker/eval/vlc/#get_attribute","text":"def get_attribute ( self , col_name : str ) View Source def get_attribute ( self , col_name : str ) : return self . _curr_row [ col_name ]","title":"get_attribute"},{"location":"reference/wtracker/eval/vlc/#get_bbox","text":"def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ] View Source def get_bbox ( self , prefix : str ) -> tuple [ float , float , float , float ] : x = self . get_attribute ( prefix + \"_x\" ) y = self . get_attribute ( prefix + \"_y\" ) w = self . get_attribute ( prefix + \"_w\" ) h = self . get_attribute ( prefix + \"_h\" ) return ( x , y , w , h )","title":"get_bbox"},{"location":"reference/wtracker/eval/vlc/#get_photo","text":"def get_photo ( self ) -> numpy . ndarray View Source def get_photo ( self ) -> np . ndarray : photo = self . reader [ self . index ] if self . show_pred : self . add_pred ( photo ) if self . show_micro : self . add_micro_box ( photo ) if self . show_cam : self . add_cam_box ( photo ) self . draw_center ( photo ) return photo","title":"get_photo"},{"location":"reference/wtracker/eval/vlc/#initialize","text":"def initialize ( self ) -> None Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. View Source def initialize ( self ) -> None : \"\"\" Initializes the VLC player by setting up hotkeys, opening the streamer, creating a window, and updating the trackbar. \"\"\" self . _init_hotkeys () self . _create_window () self . streamer . update_trackbar ( \"delay\" , round ( self . config . ms_per_frame )) self . print_hotkeys ()","title":"initialize"},{"location":"reference/wtracker/eval/vlc/#mainloop","text":"def mainloop ( self ) Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the exit flag is set to True (by self.close() (called by a hotkey)). It checks the play flag to determine if the player should continue playing or pause. The delay variable is used to control the delay between each iteration of the loop and is set to 0 to pause. View Source def mainloop ( self ) : \"\"\" Main loop for the VLC player. This method makes the VLC player interactive by allowing the user to control the player using hotkeys. This method continuously runs the VLC player until the `exit` flag is set to True (by self.close() (called by a hotkey)). It checks the `play` flag to determine if the player should continue playing or pause. The `delay` variable is used to control the delay between each iteration of the loop and is set to 0 to pause. \"\"\" with self as vlc : while not self . exit : delay = 0 if not self . play else self . delay if self . play : self . next () vlc . streamer .
waitKey ( delay )","title":"mainloop"},{"location":"reference/wtracker/eval/vlc/#make_vid","text":"def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None View Source def make_vid ( self , folder_path : str , img_name_format : str , output_dir : str ) -> None : fps = self . config . frames_per_sec command = f \"ffmpeg -framerate {fps} -start_number 0 -i {join_paths(folder_path, img_name_format)} -c:v copy {join_paths(output_dir, 'video.mp4')}\" print ( command ) os . system ( command )","title":"make_vid"},{"location":"reference/wtracker/eval/vlc/#next","text":"def next ( self , key = None ) View Source def next(self, key=None): self.index = (self.index + 1) % len(self.reader) self.streamer.update_trackbar(\"#frame\", self.index)","title":"next"},{"location":"reference/wtracker/eval/vlc/#prev","text":"def prev ( self , key = None ) View Source def prev(self, key=None): self.index = (self.index - 1) % len(self.reader) self.streamer.update_trackbar(\"#frame\", self.index)","title":"prev"},{"location":"reference/wtracker/eval/vlc/#print_hotkeys","text":"def print_hotkeys ( self ) View Source def print_hotkeys(self): print(\"Hotkeys:\") for hotkey in self.streamer.hotkeys: print(f\" - {hotkey.key} : {hotkey.description}\")","title":"print_hotkeys"},{"location":"reference/wtracker/eval/vlc/#save_stream","text":"def save_stream ( self , folder_path : str ) -> None View Source def save_stream ( self , folder_path : str , ) -> None : create_directory ( folder_path ) filename = f \"{self.cam_type}_\" + \"{:07d}.png\" with ImageSaver ( folder_path , tqdm_kwargs ={ \"total\" : len ( self . log . index )}) as worker : for index in range ( len ( self . log . index )): self . index = index self . update_curr_row () path = join_paths ( folder_path , filename . format ( index )) img = self . get_photo () worker . schedule_save ( img , path ) image_format = filename . replace ( \"{:\" , \"%\" ). replace ( \"}\" , \"\" ) self . make_vid ( folder_path , image_format , folder_path )","title":"save_stream"},
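For orientation, a minimal sketch of driving the player end to end. It is illustrative only: the file paths are placeholders, and it assumes that TimingConfig, like the config classes documented below, inherits from ConfigBase and can therefore be loaded with load_json.

from wtracker.eval.vlc import VLC
from wtracker.sim.config import TimingConfig

config = TimingConfig.load_json("timing_config.json")  # placeholder path
player = VLC(
    files=None,  # None renders the log data on a white background (see the VLC docs above)
    config=config,
    log_path="logs/sim_log.csv",  # placeholder path
    cam_type="cam",  # must match the log column prefix ("cam_x", "cam_y", ...)
)
player.initialize()  # registers hotkeys, opens the window, prints the hotkey list
player.mainloop()    # 'p' toggles play/pause, 'a'/'d' step frames, 'q' quits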
{"location":"reference/wtracker/eval/vlc/#seek","text":"def seek ( self , pos : int ) View Source def seek(self, pos: int): self.index = (pos) % len(self.reader) self.update_curr_row() self.streamer.update(self.get_photo()) self.streamer.set_title(self._get_title())","title":"seek"},{"location":"reference/wtracker/eval/vlc/#set_delay","text":"def set_delay ( self , delay : int ) View Source def set_delay(self, delay: int): self.delay = delay","title":"set_delay"},{"location":"reference/wtracker/eval/vlc/#toggle_cam","text":"def toggle_cam ( self , key : str = None ) View Source def toggle_cam(self, key: str = None): self.show_cam = not self.show_cam","title":"toggle_cam"},{"location":"reference/wtracker/eval/vlc/#toggle_micro","text":"def toggle_micro ( self , key : str = None ) View Source def toggle_micro(self, key: str = None): self.show_micro = not self.show_micro","title":"toggle_micro"},{"location":"reference/wtracker/eval/vlc/#toggle_play","text":"def toggle_play ( self , key : str = None ) View Source def toggle_play(self, key: str = None): self.play = not self.play","title":"toggle_play"},{"location":"reference/wtracker/eval/vlc/#toggle_pred","text":"def toggle_pred ( self , key : str = None ) View Source def toggle_pred(self, key: str = None): self.show_pred = not self.show_pred","title":"toggle_pred"},{"location":"reference/wtracker/eval/vlc/#update_curr_row","text":"def update_curr_row ( self ) View Source def update_curr_row(self): self._curr_row = self.log.iloc[self.index]","title":"update_curr_row"},{"location":"reference/wtracker/neural/","text":"Namespace wtracker.neural Sub-modules wtracker.neural.config wtracker.neural.dataset wtracker.neural.mlp wtracker.neural.train_results wtracker.neural.training","title":"Index"},{"location":"reference/wtracker/neural/#namespace-wtrackerneural","text":"","title":"Namespace wtracker.neural"},{"location":"reference/wtracker/neural/#sub-modules","text":"wtracker.neural.config wtracker.neural.dataset wtracker.neural.mlp wtracker.neural.train_results wtracker.neural.training","title":"Sub-modules"},{"location":"reference/wtracker/neural/config/","text":"Module wtracker.neural.config View Source from __future__ import annotations import torch from torch import nn from torch.optim import Optimizer from torch.utils.data import Dataset , DataLoader , random_split from dataclasses import dataclass , field from wtracker.utils.config_base import ConfigBase @dataclass class DatasetConfig ( ConfigBase ): input_frames : list [ int ] # The frames to use as input for the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). pred_frames : list [ int ] # The frames to predict. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). log_path : list [ str ] # The path to the log file containing the worm head predictions (by YOLO). def __post_init__ ( self ) -> None : if self . input_frames [ 0 ] != 0 : print ( \"WARNING::DatasetConfig::frames_for_pred should contain 0 as first element. Please verify your parameters.\" ) @staticmethod def from_io_config ( io : IOConfig , log_path : str ) -> DatasetConfig : return DatasetConfig ( io . input_frames , io . pred_frames , log_path ) OPTIMIZERS = { \"adam\" : torch . optim . Adam , \"sgd\" : torch . optim . SGD , \"rmsprop\" : torch . optim . RMSprop , \"adamw\" : torch . optim . AdamW , }
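# Illustrative note (not part of the source): these registries map the TrainConfig strings
# to classes, so e.g. OPTIMIZERS["adam"](model.parameters(), lr=0.001) builds a
# torch.optim.Adam instance, and LOSSES["mse"]() builds an nn.MSELoss criterion.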
LOSSES = { \"mse\" : nn . MSELoss , \"l1\" : nn . L1Loss , } @dataclass class TrainConfig ( ConfigBase ): # general parameters seed : int = field ( default = 42 , kw_only = True ) # Random seed for reproducibility dataset : DatasetConfig # The dataset to use for training, can also be a config object (if Dataset, it will be used as is) # trainer parameters model : nn . Module | str # The model to train, can also be a pretrained model (if str, it will be loaded from disk) loss_fn : str # The loss function to use, can be any of the keys in the LOSSES dict optimizer : str # The optimizer to use, can be any of the keys in the OPTIMIZERS dict device : str = \"cuda\" # 'cuda' for training on GPU or 'cpu' otherwise log : bool = False # Whether to log and save the training process with tensorboard # training parameters num_epochs : int = 100 # Number of times to iterate over the dataset checkpoints : str = None # Path to save model checkpoints, including the checkpoint name. early_stopping : int = None # Number of epochs to wait before stopping training if no improvement was made print_every : int = 5 # How often (#epochs) to print training progress # optimizer parameters learning_rate : float = 0.001 # Learning rate for the optimizer weight_decay : float = ( 1e-5 # Weight decay for the optimizer (regularization, values typically in range [0.0, 1e-4] but can be bigger) ) # dataloader parameters batch_size : int = 256 # Number of samples in each batch shuffle : bool = True # Whether to shuffle the dataset at the beginning of each epoch num_workers : int = 0 # Number of subprocesses to use for data loading train_test_split : float = 0.8 # Fraction of the dataset to use for training, the rest will be used for testing dl_train : DataLoader = field ( init = False ) dl_test : DataLoader = field ( init = False ) @dataclass class IOConfig ( ConfigBase ): \"\"\" Configuration for the basic input/output of the network The input_frames and pred_frames are lists of integers that represent the frames that will be used as input and output of the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). To calculate in_dim,out_dim we assume that each input frame has 4 features (x,y,w,h), representing the worm bbox in that frame and each prediction frame has 2 features (x,y), representing the worm center in that frame. \"\"\" input_frames : list [ int ] pred_frames : list [ int ] in_dim : int = field ( init = False ) out_dim : int = field ( init = False ) def __post_init__ ( self ): if 0 not in self . input_frames : print ( \"WARNING::IOConfig::__post_init__::input_frames doesn't contain 0 (the prediction frame). Please verify your parameters.\" ) self . in_dim = len ( self . input_frames ) * 4 self . out_dim = len ( self . pred_frames ) * 2 @staticmethod def from_datasetConfig ( config : DatasetConfig ) -> IOConfig : return IOConfig ( config . input_frames , config . pred_frames ) Variables LOSSES OPTIMIZERS Classes DatasetConfig class DatasetConfig ( input_frames : 'list[int]' , pred_frames : 'list[int]' , log_path : 'list[str]' ) DatasetConfig(input_frames: 'list[int]', pred_frames: 'list[int]', log_path: 'list[str]') View Source @dataclass class DatasetConfig ( ConfigBase ) : input_frames : list [ int ] # The frames to use as input for the network . The frames are in the format of the number of frames before ( negative ) or after ( positive ) the prediction frame ( 0 ). pred_frames : list [ int ] # The frames to predict . The frames are in the format of the number of frames before ( negative ) or after ( positive ) the prediction frame ( 0 ).
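# Illustrative example (not part of the source): input_frames=[0, -5, -10] feeds the bboxes
# of the current frame and of the frames 5 and 10 steps back, while pred_frames=[15] asks
# the network to predict the worm center 15 frames ahead.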
log_path : list [ str ] # The path to the log file containing the worm head predictions ( by YOLO ). def __post_init__ ( self ) -> None : if self . input_frames [ 0 ] != 0 : print ( \"WARNING::DatasetConfig::frames_for_pred should contain 0 as first element. Please verify your parameters.\" ) @staticmethod def from_io_config ( io : IOConfig , log_path : str ) -> DatasetConfig : return DatasetConfig ( io . input_frames , io . pred_frames , log_path ) Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Static methods from_io_config def from_io_config ( io : 'IOConfig' , log_path : 'str' ) -> 'DatasetConfig' View Source @staticmethod def from_io_config ( io : IOConfig , log_path : str ) -> DatasetConfig : return DatasetConfig ( io . input_frames , io . pred_frames , log_path ) load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods save_json def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt .
save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) IOConfig class IOConfig ( input_frames : 'list[int]' , pred_frames : 'list[int]' ) Configuration for the basic input/output of the network The input_frames and pred_frames are lists of integers that represent the frames that will be used as input and output of the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). To calculate in_dim,out_dim we assume that each input frame has 4 features (x,y,w,h), representing the worm bbox in that frame and each prediction frame has 2 features (x,y), representing the worm center in that frame. View Source @dataclass class IOConfig ( ConfigBase ) : \"\"\" Configuration for the basic input/output of the network The input_frames and pred_frames are lists of integers that represent the frames that will be used as input and output of the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). To calculate in_dim,out_dim we assume that each input frame has 4 features (x,y,w,h), representing the worm bbox in that frame and each prediction frame has 2 features (x,y), representing the worm center in that frame. \"\"\" input_frames : list [ int ] pred_frames : list [ int ] in_dim : int = field ( init = False ) out_dim : int = field ( init = False ) def __post_init__ ( self ) : if 0 not in self . input_frames : print ( \"WARNING::IOConfig::__post_init__::input_frames doesn't contain 0 (the prediction frame). Please verify your parameters.\" ) self . in_dim = len ( self . input_frames ) * 4 self . out_dim = len ( self . pred_frames ) * 2 @staticmethod def from_datasetConfig ( config : DatasetConfig ) -> IOConfig : return IOConfig ( config . input_frames , config . pred_frames ) Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Static methods from_datasetConfig def from_datasetConfig ( config : 'DatasetConfig' ) -> 'IOConfig' View Source @staticmethod def from_datasetConfig ( config : DatasetConfig ) -> IOConfig : return IOConfig ( config . input_frames , config . pred_frames ) load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . 
open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods save_json def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) TrainConfig class TrainConfig ( dataset : 'DatasetConfig' , model : 'nn.Module | str' , loss_fn : 'str' , optimizer : 'str' , device : 'str' = 'cuda' , log : 'bool' = False , num_epochs : 'int' = 100 , checkpoints : 'str' = None , early_stopping : 'int' = None , print_every : 'int' = 5 , learning_rate : 'float' = 0.001 , weight_decay : 'float' = 1e-05 , batch_size : 'int' = 256 , shuffle : 'bool' = True , num_workers : 'int' = 0 , train_test_split : 'float' = 0.8 , * , seed : 'int' = 42 ) TrainConfig(dataset: 'DatasetConfig', model: 'nn.Module | str', loss_fn: 'str', optimizer: 'str', device: 'str' = 'cuda', log: 'bool' = False, num_epochs: 'int' = 100, checkpoints: 'str' = None, early_stopping: 'int' = None, print_every: 'int' = 5, learning_rate: 'float' = 0.001, weight_decay: 'float' = 1e-05, batch_size: 'int' = 256, shuffle: 'bool' = True, num_workers: 'int' = 0, train_test_split: 'float' = 0.8, *, seed: 'int' = 42) View Source @ dataclass class TrainConfig ( ConfigBase ): # general parameters seed : int = field ( default = 42 , kw_only = True ) # Random seed for reproducibility dataset : DatasetConfig # The dataset to use for training, can also be a config object (if Dataset, it will be used as is) # trainer parameters model : nn . Module | str # The model to train, can also be a pretrained model (if str, it will be loaded from disk) loss_fn : str # The loss function to use, can be any of the keys in the LOSSES dict optimizer : str # The optimizer to use, can be any of the keys in the OPTIMIZERS dict device : str = \"cuda\" # 'cuda' for training on GPU or 'cpu' otherwise log : bool = False # Whether to log and save the training process with tensorboard # training parameters num_epochs : int = 100 # Number of times to iterate over the dataset checkpoints : str = None # Path to save model checkpoints, including the checkpoint name.
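# e.g. checkpoints="checkpoints/mlp" (illustrative value, not from the source;
# with the default None, presumably no checkpoints are written)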
early_stopping : int = None # Number of epochs to wait before stopping training if no improvement was made print_every : int = 5 # How often (#epochs) to print training progress # optimizer parameters learning_rate : float = 0.001 # Learning rate for the optimizer weight_decay : float = ( 1e-5 # Weight decay for the optimizer (regularization, values typically in range [0.0, 1e-4] but can be bigger) ) # dataloader parameters batch_size : int = 256 # Number of samples in each batch shuffle : bool = True # Whether to shuffle the dataset at the beginning of each epoch num_workers : int = 0 # Number of subprocesses to use for data loading train_test_split : float = 0.8 # Fraction of the dataset to use for training, the rest will be used for testing dl_train : DataLoader = field ( init = False ) dl_test : DataLoader = field ( init = False ) Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Class variables batch_size checkpoints device early_stopping learning_rate log num_epochs num_workers print_every seed shuffle train_test_split weight_decay Static methods load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods save_json def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . 
save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"Config"},{"location":"reference/wtracker/neural/config/#module-wtrackerneuralconfig","text":"View Source from __future__ import annotations import torch from torch import nn from torch.optim import Optimizer from torch.utils.data import Dataset , DataLoader , random_split from dataclasses import dataclass , field from wtracker.utils.config_base import ConfigBase @dataclass class DatasetConfig ( ConfigBase ): input_frames : list [ int ] # The frames to use as input for the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). pred_frames : list [ int ] # The frames to predict. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). log_path : list [ str ] # The path to the log file containing the worm head predictions (by YOLO). def __post_init__ ( self ) -> None : if self . input_frames [ 0 ] != 0 : print ( \"WARNING::DatasetConfig::frames_for_pred should contain 0 as first element. Please verify your parameters.\" ) @staticmethod def from_io_config ( io : IOConfig , log_path : str ) -> DatasetConfig : return DatasetConfig ( io . input_frames , io . pred_frames , log_path ) OPTIMIZERS = { \"adam\" : torch . optim . Adam , \"sgd\" : torch . optim . SGD , \"rmsprop\" : torch . optim . RMSprop , \"adamw\" : torch . optim . AdamW , } LOSSES = { \"mse\" : nn . MSELoss , \"l1\" : nn . L1Loss , } @dataclass class TrainConfig ( ConfigBase ): # general parameters seed : int = field ( default = 42 , kw_only = True ) # Random seed for reproducibility dataset : DatasetConfig # The dataset to use for training, can also be a config object (if Dataset, it will be used as is) # trainer parameters model : nn . Module | str # The model to train, can also be a pretrained model (if str, it will be loaded from disk) loss_fn : str # The loss function to use, can be any of the keys in the LOSSES dict optimizer : str # The optimizer to use, can be any of the keys in the OPTIMIZERS dict device : str = \"cuda\" # 'cuda' for training on GPU or 'cpu' otherwise log : bool = False # Whether to log and save the training process with tensorboard # training parameters num_epochs : int = 100 # Number of times to iterate over the dataset checkpoints : str = None # Path to save model checkpoints, including the checkpoint name.
early_stopping : int = None # Number of epochs to wait before stopping training if no improvement was made print_every : int = 5 # How often (#epochs) to print training progress # optimizer parameters learning_rate : float = 0.001 # Learning rate for the optimizer weight_decay : float = ( 1e-5 # Weight decay for the optimizer (regularization, values typically in range [0.0, 1e-4] but can be bigger) ) # dataloader parameters batch_size : int = 256 # Number of samples in each batch shuffle : bool = True # Whether to shuffle the dataset at the beginning of each epoch num_workers : int = 0 # Number of subprocesses to use for data loading train_test_split : float = 0.8 # Fraction of the dataset to use for training, the rest will be used for testing dl_train : DataLoader = field ( init = False ) dl_test : DataLoader = field ( init = False ) @dataclass class IOConfig ( ConfigBase ): \"\"\" Configuration for the basic input/output of the network The input_frames and pred_frames are lists of integers that represent the frames that will be used as input and output of the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). To calculate in_dim,out_dim we assume that each input frame has 4 features (x,y,w,h), representing the worm bbox in that frame and each prediction frame has 2 features (x,y), representing the worm center in that frame. \"\"\" input_frames : list [ int ] pred_frames : list [ int ] in_dim : int = field ( init = False ) out_dim : int = field ( init = False ) def __post_init__ ( self ): if 0 not in self . input_frames : print ( \"WARNING::IOConfig::__post_init__::input_frames doesn't contain 0 (the prediction frame). Please verify your parameters.\" ) self . in_dim = len ( self . input_frames ) * 4 self . out_dim = len ( self . pred_frames ) * 2 @staticmethod def from_datasetConfig ( config : DatasetConfig ) -> IOConfig : return IOConfig ( config . input_frames , config . pred_frames )","title":"Module wtracker.neural.config"},{"location":"reference/wtracker/neural/config/#variables","text":"LOSSES OPTIMIZERS","title":"Variables"},{"location":"reference/wtracker/neural/config/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/neural/config/#datasetconfig","text":"class DatasetConfig ( input_frames : 'list[int]' , pred_frames : 'list[int]' , log_path : 'list[str]' ) DatasetConfig(input_frames: 'list[int]', pred_frames: 'list[int]', log_path: 'list[str]') View Source @dataclass class DatasetConfig ( ConfigBase ) : input_frames : list [ int ] # The frames to use as input for the network . The frames are in the format of the number of frames before ( negative ) or after ( positive ) the prediction frame ( 0 ). pred_frames : list [ int ] # The frames to predict . The frames are in the format of the number of frames before ( negative ) or after ( positive ) the prediction frame ( 0 ). log_path : list [ str ] # The path to the log file containing the worm head predictions ( by YOLO ). def __post_init__ ( self ) -> None : if self . input_frames [ 0 ] != 0 : print ( \"WARNING::DatasetConfig::frames_for_pred should contain 0 as first element. Please verify your parameters.\" ) @staticmethod def from_io_config ( io : IOConfig , log_path : str ) -> DatasetConfig : return DatasetConfig ( io . input_frames , io .
pred_frames , log_path )","title":"DatasetConfig"},{"location":"reference/wtracker/neural/config/#ancestors-in-mro","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/config/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/neural/config/#from_io_config","text":"def from_io_config ( io : 'IOConfig' , log_path : 'str' ) -> 'DatasetConfig' View Source @staticmethod def from_io_config ( io : IOConfig , log_path : str ) -> DatasetConfig : return DatasetConfig ( io . input_frames , io . pred_frames , log_path )","title":"from_io_config"},{"location":"reference/wtracker/neural/config/#load_json","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/neural/config/#load_pickle","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/neural/config/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/neural/config/#save_json","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/neural/config/#save_pickle","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . 
save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/neural/config/#ioconfig","text":"class IOConfig ( input_frames : 'list[int]' , pred_frames : 'list[int]' ) Configuration for the basic input/output of the network The input_frames and pred_frames are lists of integers that represent the frames that will be used as input and output of the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). To calculate in_dim,out_dim we assume that each input frame has 4 features (x,y,w,h), representing the worm bbox in that frame and each prediction frame has 2 features (x,y), representing the worm center in that frame. View Source @dataclass class IOConfig ( ConfigBase ) : \"\"\" Configuration for the basic input/output of the network The input_frames and pred_frames are lists of integers that represent the frames that will be used as input and output of the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). To calculate in_dim,out_dim we assume that each input frame has 4 features (x,y,w,h), representing the worm bbox in that frame and each prediction frame has 2 features (x,y), representing the worm center in that frame. \"\"\" input_frames : list [ int ] pred_frames : list [ int ] in_dim : int = field ( init = False ) out_dim : int = field ( init = False ) def __post_init__ ( self ) : if 0 not in self . input_frames : print ( \"WARNING::IOConfig::__post_init__::input_frames doesn't contain 0 (the prediction frame). Please verify your parameters.\" ) self . in_dim = len ( self . input_frames ) * 4 self . out_dim = len ( self . pred_frames ) * 2 @staticmethod def from_datasetConfig ( config : DatasetConfig ) -> IOConfig : return IOConfig ( config . input_frames , config . pred_frames )","title":"IOConfig"},{"location":"reference/wtracker/neural/config/#ancestors-in-mro_1","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/config/#static-methods_1","text":"","title":"Static methods"},{"location":"reference/wtracker/neural/config/#from_datasetconfig","text":"def from_datasetConfig ( config : 'DatasetConfig' ) -> 'IOConfig' View Source @staticmethod def from_datasetConfig ( config : DatasetConfig ) -> IOConfig : return IOConfig ( config . input_frames , config . pred_frames )","title":"from_datasetConfig"},{"location":"reference/wtracker/neural/config/#load_json_1","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . 
update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/neural/config/#load_pickle_1","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/neural/config/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/neural/config/#save_json_1","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/neural/config/#save_pickle_1","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/neural/config/#trainconfig","text":"class TrainConfig ( dataset : 'DatasetConfig' , model : 'nn.Module | str' , loss_fn : 'str' , optimizer : 'str' , device : 'str' = 'cuda' , log : 'bool' = False , num_epochs : 'int' = 100 , checkpoints : 'str' = None , early_stopping : 'int' = None , print_every : 'int' = 5 , learning_rate : 'float' = 0.001 , weight_decay : 'float' = 1e-05 , batch_size : 'int' = 256 , shuffle : 'bool' = True , num_workers : 'int' = 0 , train_test_split : 'float' = 0.8 , * , seed : 'int' = 42 ) TrainConfig(dataset: 'DatasetConfig', model: 'nn.Module | str', loss_fn: 'str', optimizer: 'str', device: 'str' = 'cuda', log: 'bool' = False, num_epochs: 'int' = 100, checkpoints: 'str' = None, early_stopping: 'int' = None, print_every: 'int' = 5, learning_rate: 'float' = 0.001, weight_decay: 'float' = 1e-05, batch_size: 'int' = 256, shuffle: 'bool' = True, num_workers: 'int' = 0, train_test_split: 'float' = 0.8, *, seed: 'int' = 42) View Source @ dataclass class TrainConfig ( ConfigBase ): # general parameters seed : int = field ( default = 42 , kw_only = True ) # Random seed for reproducibility dataset : DatasetConfig # The dataset to use for training, can also be a config object (if Dataset, it will be used as is) # trainer parameters model : nn . 
Module | str # The model to train, can also be a pretrained model (if str, it will be loaded from disk) loss_fn : str # The loss function to use, can be any of the keys in the LOSSES dict optimizer : str # The optimizer to use, can be any of the keys in the OPTIMIZERS dict device : str = \"cuda\" # 'cuda' for training on GPU or 'cpu' otherwise log : bool = False # Whether to log and save the training process with tensorboard # training parameters num_epochs : int = 100 # Number of times to iterate over the dataset checkpoints : str = None # Path to save model checkpoints, including the checkpoint name. early_stopping : int = None # Number of epochs to wait before stopping training if no improvement was made print_every : int = 5 # How often (#epochs) to print training progress # optimizer parameters learning_rate : float = 0.001 # Learning rate for the optimizer weight_decay : float = ( 1e-5 # Weight decay for the optimizer (regularization, values typically in range [0.0, 1e-4] but can be bigger) ) # dataloader parameters batch_size : int = 256 # Number of samples in each batch shuffle : bool = True # Whether to shuffle the dataset at the beginning of each epoch num_workers : int = 0 # Number of subprocesses to use for data loading train_test_split : float = 0.8 # Fraction of the dataset to use for training, the rest will be used for testing dl_train : DataLoader = field ( init = False ) dl_test : DataLoader = field ( init = False )","title":"TrainConfig"},{"location":"reference/wtracker/neural/config/#ancestors-in-mro_2","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/config/#class-variables","text":"batch_size checkpoints device early_stopping learning_rate log num_epochs num_workers print_every seed shuffle train_test_split weight_decay","title":"Class variables"},{"location":"reference/wtracker/neural/config/#static-methods_2","text":"","title":"Static methods"},{"location":"reference/wtracker/neural/config/#load_json_2","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/neural/config/#load_pickle_2","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt .
open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/neural/config/#methods_2","text":"","title":"Methods"},{"location":"reference/wtracker/neural/config/#save_json_2","text":"def save_json ( self , path : 'str' = None ) Saves the class as a JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as a JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/neural/config/#save_pickle_2","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/neural/dataset/","text":"Module wtracker.neural.dataset View Source from __future__ import annotations from torch.utils.data import Dataset from torch import Tensor import torch import pandas as pd import numpy as np from wtracker.neural.config import DatasetConfig from wtracker.utils.bbox_utils import BoxUtils class NumpyDataset ( Dataset ): \"\"\" A custom Dataset class used to train the neural network. This class is used to create a PyTorch Dataset from a numpy array, and can be initialized with 'ndarrays' of the samples and labels, as well as a DatasetConfig configuration, in which the samples (X) and labels(y) will be created automatically. Args: X (np.ndarray): The input data as a numpy array. y (np.ndarray): The output data as a numpy array. config (DatasetConfig, optional): The configuration object for the dataset. \"\"\" def __init__ ( self , X : np . ndarray , y : np . ndarray , config : DatasetConfig = None ): self . config = config self . X = Tensor ( X ) self . y = Tensor ( y ) def __len__ ( self ): return self . X . shape [ 0 ] def __getitem__ ( self , idx ): return self . X [ idx , :], self . y [ idx , :] def save ( self , path : str ) -> None : torch . save ( self , path ) @staticmethod def load ( path : str ) -> None : return torch . load ( path ) @staticmethod def create_from_config ( config : DatasetConfig , save_path : str | None = None ) -> NumpyDataset : data = pd . read_csv ( config . log_path ) start_idx = abs ( min ( config . input_frames )) + 1 X_mask = np . asanyarray ( config . input_frames ) y_mask = np . asanyarray ( config . pred_frames ) wrm_boxes = data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ( dtype = np . float64 ) wrm_centers = BoxUtils . center ( wrm_boxes ) # Create columns for X and y X_cols_prefix = [ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ] y_cols_prefix = [ \"wrm_center_x\" , \"wrm_center_y\" ] X_cols = [] y_cols = [] for i in config .
input_frames : X_cols += [ col + str ( i ) for col in X_cols_prefix ] for i in config . pred_frames : y_cols += [ col + str ( i ) for col in y_cols_prefix ] # Create X and y X = pd . DataFrame ( index = data . index , columns = X_cols ) y = pd . DataFrame ( index = data . index , columns = y_cols ) for i in range ( start_idx , len ( data ) - max ( config . pred_frames ) - 1 ): X . iloc [ i ] = wrm_boxes [ i + X_mask ] . reshape ( 1 , - 1 ) y . iloc [ i ] = wrm_centers [ i + y_mask ] . reshape ( 1 , - 1 ) # Drop rows with NaN values na_mask = np . ma . mask_or ( X . isna () . any ( axis = 1 ), y . isna () . any ( axis = 1 )) X = X . loc [ ~ na_mask ] y = y . loc [ ~ na_mask ] X = X . to_numpy ( dtype = np . float32 , copy = True ) y = y . to_numpy ( dtype = np . float32 , copy = True ) # make X and y coordinates relative to the prediction frame x_cords = X [:, 0 ] . reshape ( - 1 , 1 ) y_cords = X [:, 1 ] . reshape ( - 1 , 1 ) x_cord_mask = np . arange ( y . shape [ 1 ]) % 2 == 0 y_cord_mask = np . arange ( y . shape [ 1 ]) % 2 == 1 y [:, x_cord_mask ] -= x_cords y [:, y_cord_mask ] -= y_cords x_cord_mask = np . arange ( X . shape [ 1 ]) % 4 == 0 y_cord_mask = np . arange ( X . shape [ 1 ]) % 4 == 1 X [:, x_cord_mask ] -= x_cords # X [:, y_cord_mask ] -= y_cords # .reshape(-1, 1) dataset = NumpyDataset ( X , y , config ) if save_path is not None : dataset . save ( save_path ) return dataset Classes NumpyDataset class NumpyDataset ( X : 'np.ndarray' , y : 'np.ndarray' , config : 'DatasetConfig' = None ) A custom Dataset class used to train the neural network. This class is used to create a PyTorch Dataset from a numpy array, and can be initialized with 'ndarrays' of the samples and labels, as well as a DatasetConfig configuration, in which the samples (X) and labels(y) will be created automatically. Attributes Name Type Description Default X np.ndarray The input data as a numpy array. None y np.ndarray The output data as a numpy array. None config DatasetConfig The configuration object for the dataset. None View Source class NumpyDataset ( Dataset ) : \"\"\" A custom Dataset class used to train the neural network. This class is used to create a PyTorch Dataset from a numpy array, and can be initialized with 'ndarrays' of the samples and labels, as well as a DatasetConfig configuration, in which the samples (X) and labels(y) will be created automatically. Args: X (np.ndarray): The input data as a numpy array. y (np.ndarray): The output data as a numpy array. config (DatasetConfig, optional): The configuration object for the dataset. \"\"\" def __init__ ( self , X : np . ndarray , y : np . ndarray , config : DatasetConfig = None ) : self . config = config self . X = Tensor ( X ) self . y = Tensor ( y ) def __len__ ( self ) : return self . X . shape [ 0 ] def __getitem__ ( self , idx ) : return self . X [ idx, : ] , self . y [ idx, : ] def save ( self , path : str ) -> None : torch . save ( self , path ) @staticmethod def load ( path : str ) -> None : return torch . load ( path ) @staticmethod def create_from_config ( config : DatasetConfig , save_path : str | None = None ) -> NumpyDataset : data = pd . read_csv ( config . log_path ) start_idx = abs ( min ( config . input_frames )) + 1 X_mask = np . asanyarray ( config . input_frames ) y_mask = np . asanyarray ( config . pred_frames ) wrm_boxes = data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ( dtype = np . float64 ) wrm_centers = BoxUtils . 
center ( wrm_boxes ) # Create columns for X and y X_cols_prefix = [ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] y_cols_prefix = [ \"wrm_center_x\", \"wrm_center_y\" ] X_cols = [] y_cols = [] for i in config . input_frames : X_cols += [ col + str(i) for col in X_cols_prefix ] for i in config . pred_frames : y_cols += [ col + str(i) for col in y_cols_prefix ] # Create X and y X = pd . DataFrame ( index = data . index , columns = X_cols ) y = pd . DataFrame ( index = data . index , columns = y_cols ) for i in range ( start_idx , len ( data ) - max ( config . pred_frames ) - 1 ) : X . iloc [ i ] = wrm_boxes [ i + X_mask ] . reshape ( 1 , - 1 ) y . iloc [ i ] = wrm_centers [ i + y_mask ] . reshape ( 1 , - 1 ) # Drop rows with NaN values na_mask = np . ma . mask_or ( X . isna (). any ( axis = 1 ), y . isna (). any ( axis = 1 )) X = X . loc [ ~na_mask ] y = y . loc [ ~na_mask ] X = X . to_numpy ( dtype = np . float32 , copy = True ) y = y . to_numpy ( dtype = np . float32 , copy = True ) # make X and y coordinates relative to the prediction frame x_cords = X [ :, 0 ] . reshape ( - 1 , 1 ) y_cords = X [ :, 1 ] . reshape ( - 1 , 1 ) x_cord_mask = np . arange ( y . shape [ 1 ] ) % 2 == 0 y_cord_mask = np . arange ( y . shape [ 1 ] ) % 2 == 1 y [ :, x_cord_mask ] -= x_cords y [ :, y_cord_mask ] -= y_cords x_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 0 y_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 1 X [ :, x_cord_mask ] -= x_cords # X [ :, y_cord_mask ] -= y_cords # . reshape ( - 1 , 1 ) dataset = NumpyDataset ( X , y , config ) if save_path is not None : dataset . save ( save_path ) return dataset Ancestors (in MRO) torch.utils.data.dataset.Dataset typing.Generic Static methods create_from_config def create_from_config ( config : 'DatasetConfig' , save_path : 'str | None' = None ) -> 'NumpyDataset' View Source @staticmethod def create_from_config ( config : DatasetConfig , save_path : str | None = None ) -> NumpyDataset : data = pd . read_csv ( config . log_path ) start_idx = abs ( min ( config . input_frames )) + 1 X_mask = np . asanyarray ( config . input_frames ) y_mask = np . asanyarray ( config . pred_frames ) wrm_boxes = data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ( dtype = np . float64 ) wrm_centers = BoxUtils . center ( wrm_boxes ) # Create columns for X and y X_cols_prefix = [ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] y_cols_prefix = [ \"wrm_center_x\", \"wrm_center_y\" ] X_cols = [] y_cols = [] for i in config . input_frames : X_cols += [ col + str(i) for col in X_cols_prefix ] for i in config . pred_frames : y_cols += [ col + str(i) for col in y_cols_prefix ] # Create X and y X = pd . DataFrame ( index = data . index , columns = X_cols ) y = pd . DataFrame ( index = data . index , columns = y_cols ) for i in range ( start_idx , len ( data ) - max ( config . pred_frames ) - 1 ) : X . iloc [ i ] = wrm_boxes [ i + X_mask ] . reshape ( 1 , - 1 ) y . iloc [ i ] = wrm_centers [ i + y_mask ] . reshape ( 1 , - 1 ) # Drop rows with NaN values na_mask = np . ma . mask_or ( X . isna (). any ( axis = 1 ), y . isna (). any ( axis = 1 )) X = X . loc [ ~na_mask ] y = y . loc [ ~na_mask ] X = X . to_numpy ( dtype = np . float32 , copy = True ) y = y . to_numpy ( dtype = np . float32 , copy = True ) # make X and y coordinates relative to the prediction frame x_cords = X [ :, 0 ] . reshape ( - 1 , 1 ) y_cords = X [ :, 1 ] . reshape ( - 1 , 1 ) x_cord_mask = np . arange ( y . shape [ 1 ] ) % 2 == 0 y_cord_mask = np . arange ( y . 
shape [ 1 ] ) % 2 == 1 y [ :, x_cord_mask ] -= x_cords y [ :, y_cord_mask ] -= y_cords x_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 0 y_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 1 X [ :, x_cord_mask ] -= x_cords # X [ :, y_cord_mask ] -= y_cords # . reshape ( - 1 , 1 ) dataset = NumpyDataset ( X , y , config ) if save_path is not None : dataset . save ( save_path ) return dataset load def load ( path : 'str' ) -> 'None' View Source @ staticmethod def load ( path : str ) -> None : return torch . load ( path ) Methods save def save ( self , path : 'str' ) -> 'None' View Source def save ( self , path : str ) -> None : torch . save ( self , path )","title":"Dataset"},{"location":"reference/wtracker/neural/dataset/#module-wtrackerneuraldataset","text":"View Source from __future__ import annotations from torch.utils.data import Dataset from torch import Tensor import torch import pandas as pd import numpy as np from wtracker.neural.config import DatasetConfig from wtracker.utils.bbox_utils import BoxUtils class NumpyDataset ( Dataset ): \"\"\" A custom Dataset class used to train the neural network. This class is used to create a PyTorch Dataset from a numpy array, and can be initialized with 'ndarrays' of the samples and labels, as well as a DatasetConfig configuration, in which the samples (X) and labels(y) will be created automatically. Args: X (np.ndarray): The input data as a numpy array. y (np.ndarray): The output data as a numpy array. config (DatasetConfig, optional): The configuration object for the dataset. \"\"\" def __init__ ( self , X : np . ndarray , y : np . ndarray , config : DatasetConfig = None ): self . config = config self . X = Tensor ( X ) self . y = Tensor ( y ) def __len__ ( self ): return self . X . shape [ 0 ] def __getitem__ ( self , idx ): return self . X [ idx , :], self . y [ idx , :] def save ( self , path : str ) -> None : torch . save ( self , path ) @staticmethod def load ( path : str ) -> None : return torch . load ( path ) @staticmethod def create_from_config ( config : DatasetConfig , save_path : str | None = None ) -> NumpyDataset : data = pd . read_csv ( config . log_path ) start_idx = abs ( min ( config . input_frames )) + 1 X_mask = np . asanyarray ( config . input_frames ) y_mask = np . asanyarray ( config . pred_frames ) wrm_boxes = data [[ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]] . to_numpy ( dtype = np . float64 ) wrm_centers = BoxUtils . center ( wrm_boxes ) # Create columns for X and y X_cols_prefix = [ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ] y_cols_prefix = [ \"wrm_center_x\" , \"wrm_center_y\" ] X_cols = [] y_cols = [] for i in config . input_frames : X_cols += [ col + str ( i ) for col in X_cols_prefix ] for i in config . pred_frames : y_cols += [ col + str ( i ) for col in y_cols_prefix ] # Create X and y X = pd . DataFrame ( index = data . index , columns = X_cols ) y = pd . DataFrame ( index = data . index , columns = y_cols ) for i in range ( start_idx , len ( data ) - max ( config . pred_frames ) - 1 ): X . iloc [ i ] = wrm_boxes [ i + X_mask ] . reshape ( 1 , - 1 ) y . iloc [ i ] = wrm_centers [ i + y_mask ] . reshape ( 1 , - 1 ) # Drop rows with NaN values na_mask = np . ma . mask_or ( X . isna () . any ( axis = 1 ), y . isna () . any ( axis = 1 )) X = X . loc [ ~ na_mask ] y = y . loc [ ~ na_mask ] X = X . to_numpy ( dtype = np . float32 , copy = True ) y = y . to_numpy ( dtype = np . float32 , copy = True ) # make X and y coordinates relative to the prediction frame x_cords = X [:, 0 ] . 
reshape ( - 1 , 1 ) y_cords = X [:, 1 ] . reshape ( - 1 , 1 ) x_cord_mask = np . arange ( y . shape [ 1 ]) % 2 == 0 y_cord_mask = np . arange ( y . shape [ 1 ]) % 2 == 1 y [:, x_cord_mask ] -= x_cords y [:, y_cord_mask ] -= y_cords x_cord_mask = np . arange ( X . shape [ 1 ]) % 4 == 0 y_cord_mask = np . arange ( X . shape [ 1 ]) % 4 == 1 X [:, x_cord_mask ] -= x_cords # X [:, y_cord_mask ] -= y_cords # .reshape(-1, 1) dataset = NumpyDataset ( X , y , config ) if save_path is not None : dataset . save ( save_path ) return dataset","title":"Module wtracker.neural.dataset"},{"location":"reference/wtracker/neural/dataset/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/neural/dataset/#numpydataset","text":"class NumpyDataset ( X : 'np.ndarray' , y : 'np.ndarray' , config : 'DatasetConfig' = None ) A custom Dataset class used to train the neural network. This class is used to create a PyTorch Dataset from a numpy array, and can be initialized with 'ndarrays' of the samples and labels, as well as a DatasetConfig configuration, in which the samples (X) and labels(y) will be created automatically.","title":"NumpyDataset"},{"location":"reference/wtracker/neural/dataset/#attributes","text":"Name Type Description Default X np.ndarray The input data as a numpy array. None y np.ndarray The output data as a numpy array. None config DatasetConfig The configuration object for the dataset. None View Source class NumpyDataset ( Dataset ) : \"\"\" A custom Dataset class used to train the neural network. This class is used to create a PyTorch Dataset from a numpy array, and can be initialized with 'ndarrays' of the samples and labels, as well as a DatasetConfig configuration, in which the samples (X) and labels(y) will be created automatically. Args: X (np.ndarray): The input data as a numpy array. y (np.ndarray): The output data as a numpy array. config (DatasetConfig, optional): The configuration object for the dataset. \"\"\" def __init__ ( self , X : np . ndarray , y : np . ndarray , config : DatasetConfig = None ) : self . config = config self . X = Tensor ( X ) self . y = Tensor ( y ) def __len__ ( self ) : return self . X . shape [ 0 ] def __getitem__ ( self , idx ) : return self . X [ idx, : ] , self . y [ idx, : ] def save ( self , path : str ) -> None : torch . save ( self , path ) @staticmethod def load ( path : str ) -> None : return torch . load ( path ) @staticmethod def create_from_config ( config : DatasetConfig , save_path : str | None = None ) -> NumpyDataset : data = pd . read_csv ( config . log_path ) start_idx = abs ( min ( config . input_frames )) + 1 X_mask = np . asanyarray ( config . input_frames ) y_mask = np . asanyarray ( config . pred_frames ) wrm_boxes = data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ( dtype = np . float64 ) wrm_centers = BoxUtils . center ( wrm_boxes ) # Create columns for X and y X_cols_prefix = [ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] y_cols_prefix = [ \"wrm_center_x\", \"wrm_center_y\" ] X_cols = [] y_cols = [] for i in config . input_frames : X_cols += [ col + str(i) for col in X_cols_prefix ] for i in config . pred_frames : y_cols += [ col + str(i) for col in y_cols_prefix ] # Create X and y X = pd . DataFrame ( index = data . index , columns = X_cols ) y = pd . DataFrame ( index = data . index , columns = y_cols ) for i in range ( start_idx , len ( data ) - max ( config . pred_frames ) - 1 ) : X . iloc [ i ] = wrm_boxes [ i + X_mask ] . reshape ( 1 , - 1 ) y . iloc [ i ] = wrm_centers [ i + y_mask ] . 
reshape ( 1 , - 1 ) # Drop rows with NaN values na_mask = np . ma . mask_or ( X . isna (). any ( axis = 1 ), y . isna (). any ( axis = 1 )) X = X . loc [ ~na_mask ] y = y . loc [ ~na_mask ] X = X . to_numpy ( dtype = np . float32 , copy = True ) y = y . to_numpy ( dtype = np . float32 , copy = True ) # make X and y coordinates relative to the prediction frame x_cords = X [ :, 0 ] . reshape ( - 1 , 1 ) y_cords = X [ :, 1 ] . reshape ( - 1 , 1 ) x_cord_mask = np . arange ( y . shape [ 1 ] ) % 2 == 0 y_cord_mask = np . arange ( y . shape [ 1 ] ) % 2 == 1 y [ :, x_cord_mask ] -= x_cords y [ :, y_cord_mask ] -= y_cords x_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 0 y_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 1 X [ :, x_cord_mask ] -= x_cords # X [ :, y_cord_mask ] -= y_cords # . reshape ( - 1 , 1 ) dataset = NumpyDataset ( X , y , config ) if save_path is not None : dataset . save ( save_path ) return dataset","title":"Attributes"},{"location":"reference/wtracker/neural/dataset/#ancestors-in-mro","text":"torch.utils.data.dataset.Dataset typing.Generic","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/dataset/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/neural/dataset/#create_from_config","text":"def create_from_config ( config : 'DatasetConfig' , save_path : 'str | None' = None ) -> 'NumpyDataset' View Source @staticmethod def create_from_config ( config : DatasetConfig , save_path : str | None = None ) -> NumpyDataset : data = pd . read_csv ( config . log_path ) start_idx = abs ( min ( config . input_frames )) + 1 X_mask = np . asanyarray ( config . input_frames ) y_mask = np . asanyarray ( config . pred_frames ) wrm_boxes = data [ [\"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ] . to_numpy ( dtype = np . float64 ) wrm_centers = BoxUtils . center ( wrm_boxes ) # Create columns for X and y X_cols_prefix = [ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] y_cols_prefix = [ \"wrm_center_x\", \"wrm_center_y\" ] X_cols = [] y_cols = [] for i in config . input_frames : X_cols += [ col + str(i) for col in X_cols_prefix ] for i in config . pred_frames : y_cols += [ col + str(i) for col in y_cols_prefix ] # Create X and y X = pd . DataFrame ( index = data . index , columns = X_cols ) y = pd . DataFrame ( index = data . index , columns = y_cols ) for i in range ( start_idx , len ( data ) - max ( config . pred_frames ) - 1 ) : X . iloc [ i ] = wrm_boxes [ i + X_mask ] . reshape ( 1 , - 1 ) y . iloc [ i ] = wrm_centers [ i + y_mask ] . reshape ( 1 , - 1 ) # Drop rows with NaN values na_mask = np . ma . mask_or ( X . isna (). any ( axis = 1 ), y . isna (). any ( axis = 1 )) X = X . loc [ ~na_mask ] y = y . loc [ ~na_mask ] X = X . to_numpy ( dtype = np . float32 , copy = True ) y = y . to_numpy ( dtype = np . float32 , copy = True ) # make X and y coordinates relative to the prediction frame x_cords = X [ :, 0 ] . reshape ( - 1 , 1 ) y_cords = X [ :, 1 ] . reshape ( - 1 , 1 ) x_cord_mask = np . arange ( y . shape [ 1 ] ) % 2 == 0 y_cord_mask = np . arange ( y . shape [ 1 ] ) % 2 == 1 y [ :, x_cord_mask ] -= x_cords y [ :, y_cord_mask ] -= y_cords x_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 0 y_cord_mask = np . arange ( X . shape [ 1 ] ) % 4 == 1 X [ :, x_cord_mask ] -= x_cords # X [ :, y_cord_mask ] -= y_cords # . reshape ( - 1 , 1 ) dataset = NumpyDataset ( X , y , config ) if save_path is not None : dataset . 
save ( save_path ) return dataset","title":"create_from_config"},{"location":"reference/wtracker/neural/dataset/#load","text":"def load ( path : 'str' ) -> 'None' View Source @ staticmethod def load ( path : str ) -> None : return torch . load ( path )","title":"load"},{"location":"reference/wtracker/neural/dataset/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/neural/dataset/#save","text":"def save ( self , path : 'str' ) -> 'None' View Source def save ( self , path : str ) -> None : torch . save ( self , path )","title":"save"},{"location":"reference/wtracker/neural/mlp/","text":"Module wtracker.neural.mlp View Source from torch import Tensor , nn from typing import Union , Sequence from collections import defaultdict from wtracker.neural.config import IOConfig ACTIVATIONS = { \"relu\" : nn . ReLU , \"tanh\" : nn . Tanh , \"sigmoid\" : nn . Sigmoid , \"softmax\" : nn . Softmax , \"logsoftmax\" : nn . LogSoftmax , \"lrelu\" : nn . LeakyReLU , \"none\" : nn . Identity , None : nn . Identity , } # Default keyword arguments to pass to activation class constructors, e.g. # activation_cls(**ACTIVATION_DEFAULT_KWARGS[name]) ACTIVATION_DEFAULT_KWARGS = defaultdict ( dict , { ### \"softmax\" : dict ( dim = 1 ), \"logsoftmax\" : dict ( dim = 1 ), }, ) class WormPredictor ( nn . Module ): \"\"\" A class that represents neural network models that predict worm behavior. After a model is created from several layers or blocks, it is wrapped in this class so that it can be distinguished from other models that don't predict worm behavior (for example the layers/blocks that make this model). This class also holds the IOConfig object that is used to determine the input and output shapes of the model, and the specific frames it expects as input and output. Attributes: model: The neural network model that predicts worm behavior. io_config: The IOConfig object of the model. \"\"\" def __init__ ( self , model : nn . Module , io_config : IOConfig ): super () . __init__ () self . io_config : IOConfig = io_config self . model : nn . Module = model def forward ( self , x : Tensor ) -> Tensor : return self . model ( x ) class MLPLayer ( nn . Module ): \"\"\" A single layer perceptron that can hold batch-norm and activation layers as well. \"\"\" def __init__ ( self , in_dim : int , out_dim : Sequence [ int ], nonlin : Union [ str , nn . Module ], batch_norm : bool = True , ) -> None : super () . __init__ () layers = [] layers . append ( nn . Linear ( in_dim , out_dim )) in_dim = out_dim if batch_norm and nonlin not in [ \"none\" , None ]: layers . append ( nn . BatchNorm1d ( out_dim )) layers . append ( self . _make_activation ( nonlin )) self . mlp_layer = nn . Sequential ( * layers ) def _make_activation ( self , act : Union [ str , nn . Module ]) -> nn . Module : if isinstance ( act , str ): return ACTIVATIONS [ act ]( ** ACTIVATION_DEFAULT_KWARGS [ act ]) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . mlp_layer . forward ( x . reshape ( x . size ( 0 ), - 1 )) class MlpBlock ( nn . Module ): \"\"\" A general-purpose MLP. Args: in_dim: Input dimension. dims: Hidden dimensions, including output dimension. nonlins: Non-linearities to apply after each one of the hidden dimensions. Can be either a sequence of strings which are keys in the ACTIVATIONS dict, or instances of nn.Module (e.g.
an instance of nn.ReLU()). Length should match 'dims'. \"\"\" def __init__ ( self , in_dim : int , dims : Sequence [ int ], nonlins : Sequence [ Union [ str , nn . Module ]], batch_norm : bool = True , ): assert len ( nonlins ) == len ( dims ) self . in_dim = in_dim self . out_dim = dims [ - 1 ] self . dims = dims self . nonlins = nonlins super () . __init__ () layers = [] for i , out_dim in enumerate ( self . dims ): layers . append ( MLPLayer ( in_dim , out_dim , nonlins [ i ], batch_norm )) in_dim = out_dim self . sequence = nn . Sequential ( * layers ) def _make_activation ( self , act : Union [ str , nn . Module ]) -> nn . Module : if isinstance ( act , str ): return ACTIVATIONS [ act ]( ** ACTIVATION_DEFAULT_KWARGS [ act ]) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . sequence . forward ( x . reshape ( x . size ( 0 ), - 1 )) class RMLP ( nn . Module ): def __init__ ( self , block_in_dim : int , block_dims : Sequence [ int ], block_nonlins : Sequence [ Union [ str , nn . Module ]], n_blocks : int , out_dim : int , in_dim : int = None , # if in_dim is an int, then a first layer will be made batch_norm : bool = True , ) -> None : super () . __init__ () # Create first layer if in_dim is not None self . input = nn . Identity () if in_dim is not None : self . input = MLPLayer ( in_dim , block_in_dim , block_nonlins [ 0 ], batch_norm ) # Create blocks layers = [] for i in range ( n_blocks ): layers . append ( MlpBlock ( block_in_dim , block_dims , block_nonlins , batch_norm )) self . blocks = nn . ModuleList ( layers ) # Create output layer self . output = nn . Linear ( block_dims [ - 1 ], out_dim ) def _make_activation ( self , act : Union [ str , nn . Module ]) -> nn . Module : if isinstance ( act , str ): return ACTIVATIONS [ act ]( ** ACTIVATION_DEFAULT_KWARGS [ act ]) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" x = self . input ( x ) for block in self . blocks : out = block ( x ) x = x + out return self . output ( x ) Variables ACTIVATIONS ACTIVATION_DEFAULT_KWARGS Classes MLPLayer class MLPLayer ( in_dim : int , out_dim : Sequence [ int ], nonlin : Union [ str , torch . nn . modules . module . Module ], batch_norm : bool = True ) A single layer perceptron that can hold batch-norm and activation layers as well. View Source class MLPLayer ( nn . Module ) : \"\"\" A single layer perceptron that can hold batch-norm and activation layers as well. \"\"\" def __init__ ( self , in_dim : int , out_dim : Sequence [ int ] , nonlin : Union [ str, nn.Module ] , batch_norm : bool = True , ) -> None : super (). __init__ () layers = [] layers . append ( nn . Linear ( in_dim , out_dim )) in_dim = out_dim if batch_norm and nonlin not in [ \"none\", None ] : layers . append ( nn . BatchNorm1d ( out_dim )) layers . append ( self . _make_activation ( nonlin )) self . mlp_layer = nn . Sequential ( * layers ) def _make_activation ( self , act : Union [ str, nn.Module ] ) -> nn . Module : if isinstance ( act , str ) : return ACTIVATIONS [ act ] ( ** ACTIVATION_DEFAULT_KWARGS [ act ] ) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features.
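For orientation, here is a minimal usage sketch of the MLP building blocks documented above. It assumes only what the quoted source shows (that MlpBlock and RMLP are importable from wtracker.neural.mlp); treat it as an illustration rather than the library's official example:

```python
import torch

from wtracker.neural.mlp import MlpBlock, RMLP  # module path as documented above

# An MLP mapping 8 input features -> 64 -> 64 -> 2 outputs, with ReLU
# (plus batch-norm) after each hidden layer and no activation on the output.
block = MlpBlock(in_dim=8, dims=[64, 64, 2], nonlins=["relu", "relu", "none"])
x = torch.randn(16, 8)   # a batch of 16 samples with 8 features each
y = block(x)             # shape: (16, 2)

# A residual MLP: an 8 -> 64 input projection, two residual blocks at
# width 64, and a final linear layer producing 2 outputs.
rmlp = RMLP(
    block_in_dim=64,
    block_dims=[64, 64],
    block_nonlins=["relu", "none"],
    n_blocks=2,
    out_dim=2,
    in_dim=8,
)
y = rmlp(x)              # shape: (16, 2)
```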
Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . mlp_layer . forward ( x . reshape ( x . size ( 0 ), - 1 )) Ancestors (in MRO) torch.nn.modules.module.Module Class variables T_destination call_super_init dump_patches Methods add_module def add_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Add a child module to the current module. The module can be accessed as an attribute using the given name. Parameters: Name Type Description Default name str name of the child module. The child module can be accessed from this module using the given name None module Module child module to be added to the module. None View Source def add_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \"\"\"Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (str): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module. \"\"\" if not isinstance ( module , Module ) and module is not None : raise TypeError ( f \"{torch.typename(module)} is not a Module subclass\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"module name should be a string. Got {torch.typename(name)}\" ) elif hasattr ( self , name ) and name not in self . _modules : raise KeyError ( f \"attribute '{name}' already exists\" ) elif '.' in name : raise KeyError ( f \"module name can't contain \\\" . \\ \", got: {name}\" ) elif name == '' : raise KeyError ( \"module name can't be empty string \\\" \\ \"\" ) for hook in _global_module_registration_hooks . values () : output = hook ( self , name , module ) if output is not None : module = output self . _modules [ name ] = module apply def apply ( self : ~ T , fn : Callable [[ ForwardRef ( 'Module' )], NoneType ] ) -> ~ T Apply fn recursively to every submodule (as returned by .children() ) as well as self. Typical use includes initializing the parameters of a model (see also :ref: nn-init-doc ). Parameters: Name Type Description Default fn ( None class: Module -> None): function to be applied to each submodule None Returns: Type Description Module self View Source def apply ( self : T , fn : Callable [[ 'Module' ] , None ] ) -> T : r \" \"\" Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example:: >>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) \"\" \" for module in self . children () : module . apply ( fn ) fn ( self ) return self bfloat16 def bfloat16 ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to bfloat16 datatype. .. note:: This method modifies the module in-place. 
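The remaining entries on this page are helpers inherited from torch.nn.Module. As a quick, library-agnostic sketch of the casting and mode methods (standard PyTorch API, nothing wtracker-specific):

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)

layer = layer.bfloat16()  # cast floating-point parameters/buffers in place; returns self
print(next(layer.parameters()).dtype)  # torch.bfloat16

layer = layer.float()     # cast back to float32
layer = layer.eval()      # evaluation mode (affects Dropout/BatchNorm behavior)
if torch.cuda.is_available():
    layer = layer.cuda()  # move parameters and buffers to the GPU
```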
Returns: Type Description Module self View Source def bfloat16 ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``bfloat16`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . bfloat16 () if t . is_floating_point () else t ) buffers def buffers ( self , recurse : bool = True ) -> Iterator [ torch . Tensor ] Return an iterator over module buffers. Parameters: Name Type Description Default recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. None Yields: Type Description torch.Tensor module buffer View Source def buffers ( self , recurse : bool = True ) -> Iterator [ Tensor ] : r \"\"\"Return an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: torch.Tensor: module buffer Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for buf in model.buffers(): >>> print(type(buf), buf.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for _ , buf in self . named_buffers ( recurse = recurse ) : yield buf children def children ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over immediate children modules. Yields: Type Description Module a child module View Source def children ( self ) -> Iterator [ 'Module']: r \"\"\"Return an iterator over immediate children modules. Yields: Module: a child module \"\"\" for name , module in self . named_children (): yield module compile def compile ( self , * args , ** kwargs ) Compile this Module's forward using :func: torch.compile . This Module's __call__ method is compiled and all arguments are passed as-is to :func: torch.compile . See :func: torch.compile for details on the arguments for this function. View Source def compile ( self , * args , ** kwargs ) : \" \"\" Compile this Module's forward using :func:`torch.compile`. This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`. See :func:`torch.compile` for details on the arguments for this function. \"\" \" self . _compiled_call_impl = torch . compile ( self . _call_impl , * args , ** kwargs ) cpu def cpu ( self : ~ T ) -> ~ T Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def cpu ( self : T ) -> T : r \"\"\"Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cpu ()) cuda def cuda ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def cuda ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. 
So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cuda ( device )) double def double ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to double datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def double ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``double`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . double () if t . is_floating_point () else t ) eval def eval ( self : ~ T ) -> ~ T Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. This is equivalent with :meth: self.train(False) . See :ref: locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it. Returns: Type Description Module self View Source def eval ( self : T ) -> T : r \" \"\" Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent with :meth:`self.train(False) `. See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it. Returns: Module: self \"\" \" return self . train ( False ) extra_repr def extra_repr ( self ) -> str Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. View Source def extra_repr ( self ) -> str : r \"\"\"Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. \"\"\" return '' float def float ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to float datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def float ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``float`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . float () if t . is_floating_point () else t ) forward def forward ( self , x : torch . Tensor ) -> torch . Tensor Parameters: Name Type Description Default x None An input tensor, of shape (N, D) containing N samples with D features. None Returns: Type Description None An output tensor of shape (N, D_out) where D_out is the output dim. View Source def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . mlp_layer . forward ( x . reshape ( x . 
size ( 0 ), - 1 )) get_buffer def get_buffer ( self , target : str ) -> 'Tensor' Return the buffer given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.Tensor The buffer referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not a buffer View Source def get_buffer ( self , target : str ) -> \"Tensor\" : \" \"\" Return the buffer given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the buffer to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.Tensor: The buffer referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not a buffer \"\" \" module_path , _ , buffer_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , buffer_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + buffer_name + \"`\" ) buffer : torch . Tensor = getattr ( mod , buffer_name ) if buffer_name not in mod . _buffers : raise AttributeError ( \"`\" + buffer_name + \"` is not a buffer\" ) return buffer get_extra_state def get_extra_state ( self ) -> Any Return any extra state to include in the module's state_dict. Implement this and a corresponding :func: set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict() . Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: Type Description object Any extra state to store in the module's state_dict View Source def get_extra_state ( self ) -> Any : \" \"\" Return any extra state to include in the module's state_dict. Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`. Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: object: Any extra state to store in the module's state_dict \"\" \" raise RuntimeError ( \"Reached a code path in Module.get_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" ) get_parameter def get_parameter ( self , target : str ) -> 'Parameter' Return the parameter given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target .
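Since get_parameter, get_buffer, and get_submodule all take the same dotted target strings, a small self-contained sketch (plain PyTorch) may help:

```python
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))

weight = net.get_parameter("0.weight")   # parameter of the first Linear layer
mean = net.get_buffer("1.running_mean")  # buffer of the BatchNorm1d layer

print(tuple(weight.shape), tuple(mean.shape))  # (8, 4) (8,)
```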
Parameters: Name Type Description Default target None The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Parameter The Parameter referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Parameter View Source def get_parameter ( self , target : str ) -> \"Parameter\" : \" \"\" Return the parameter given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the Parameter to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.nn.Parameter: The Parameter referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Parameter`` \"\" \" module_path , _ , param_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , param_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + param_name + \"`\" ) param : torch . nn . Parameter = getattr ( mod , param_name ) if not isinstance ( param , torch . nn . Parameter ) : raise AttributeError ( \"`\" + param_name + \"` is not an \" \"nn.Parameter\" ) return param get_submodule def get_submodule ( self , target : str ) -> 'Module' Return the submodule given by target if it exists, otherwise throw an error. For example, let's say you have an nn.Module A that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an nn.Module A . A has a nested submodule net_b , which itself has two submodules net_c and linear . net_c then has a submodule conv .) To check whether or not we have the linear submodule, we would call get_submodule(\"net_b.linear\") . To check whether we have the conv submodule, we would call get_submodule(\"net_b.net_c.conv\") . The runtime of get_submodule is bounded by the degree of module nesting in target . A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used. Parameters: Name Type Description Default target None The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Module The submodule referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Module View Source def get_submodule ( self , target : str ) -> \"Module\" : \" \"\" Return the submodule given by ``target`` if it exists, otherwise throw an error. For example, let's say you have an ``nn.Module`` ``A`` that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. 
``net_c`` then has a submodule ``conv``.) To check whether or not we have the ``linear`` submodule, we would call ``get_submodule(\" net_b . linear \")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule(\" net_b . net_c . conv \")``. The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used. Args: target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) Returns: torch.nn.Module: The submodule referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Module`` \"\" \" if target == \"\" : return self atoms : List [ str ] = target . split ( \".\" ) mod : torch . nn . Module = self for item in atoms : if not hasattr ( mod , item ) : raise AttributeError ( mod . _get_name () + \" has no \" \"attribute `\" + item + \"`\" ) mod = getattr ( mod , item ) if not isinstance ( mod , torch . nn . Module ) : raise AttributeError ( \"`\" + item + \"` is not \" \"an nn.Module\" ) return mod half def half ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to half datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def half ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``half`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . half () if t . is_floating_point () else t ) ipu def ipu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def ipu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . ipu ( device )) load_state_dict def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) Copy parameters and buffers from :attr: state_dict into this module and its descendants. If :attr: strict is True , then the keys of :attr: state_dict must exactly match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. .. warning:: If :attr: assign is True the optimizer must be created after the call to :attr: load_state_dict unless :func: ~torch.__future__.get_swap_module_params_on_conversion is True . 
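A minimal state_dict round trip showing the returned missing/unexpected keys (again standard torch.nn.Module behavior):

```python
from torch import nn

src = nn.Linear(4, 2)
dst = nn.Linear(4, 2)

result = dst.load_state_dict(src.state_dict())      # strict=True by default
print(result.missing_keys, result.unexpected_keys)  # [] []
```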
Parameters: Name Type Description Default state_dict dict a dict containing parameters and persistent buffers. None strict bool whether to strictly enforce that the keys in :attr: state_dict match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. Default: True None assign bool When False , the properties of the tensors in the current module are preserved while when True , the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of :class: ~torch.nn.Parameter s for which the value from the module is preserved. Default: False None Returns: Type Description None NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys unexpected_keys is a list of str containing the unexpected keys View Source def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) : r \"\"\"Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. \"\"\" if not isinstance ( state_dict , Mapping ) : raise TypeError ( f \"Expected state_dict to be dict-like, got {type(state_dict)}.\" ) missing_keys : List [ str ] = [] unexpected_keys : List [ str ] = [] error_msgs : List [ str ] = [] # copy state_dict so _ load_from_state_dict can modify it metadata = getattr ( state_dict , '_metadata' , None ) state_dict = OrderedDict ( state_dict ) if metadata is not None : # mypy isn't aware that \"_metadata\" exists in state_dict state_dict._metadata = metadata # type: ignore[attr-defined] def load(module, local_state_dict, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) if assign: local_metadata['assign_to_params_buffers'] = assign module._load_from_state_dict( local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name + ' . 
' child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} load(child, child_state_dict, child_prefix) # noqa: F821 # Note that the hook can modify missing_keys and unexpected_keys. incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) for hook in module._load_state_dict_post_hooks.values(): out = hook(module, incompatible_keys) assert out is None, ( \"Hooks registered with ``register_load_state_dict_post_hook`` are not\" \"expected to return new values, if incompatible_keys need to be modified,\" \"it should be done inplace.\" ) load(self, state_dict) del load if strict: if len(unexpected_keys) > 0: error_msgs.insert( 0, ' Unexpected key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in unexpected_keys))) if len(missing_keys) > 0: error_msgs.insert( 0, ' Missing key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in missing_keys))) if len(error_msgs) > 0: raise RuntimeError(' Error ( s ) in loading state_dict for {} : \\n\\t {} ' . format ( self . __ class__ . __ name__ , \"\\n\\t\" . join ( error_msgs ))) return _ IncompatibleKeys ( missing_keys , unexpected_keys ) modules def modules ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over all modules in the network. Yields: Type Description Module a module in the network View Source def modules ( self ) -> Iterator [ 'Module' ] : r \" \"\" Return an iterator over all modules in the network. Yields: Module: a module in the network Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True) \"\" \" for _ , module in self . named_modules () : yield module named_buffers def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . Tensor ]] Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Parameters: Name Type Description Default prefix str prefix to prepend to all buffer names. None recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. None remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True Yields: Type Description None (str, torch.Tensor): Tuple containing the name and buffer View Source def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Tensor ]]: r \"\"\"Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Args: prefix (str): prefix to prepend to all buffer names. recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. 
Yields: (str, torch.Tensor): Tuple containing the name and buffer Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size()) \"\"\" gen = self . _named_members ( lambda module : module . _buffers . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen named_children def named_children ( self ) -> Iterator [ Tuple [ str , ForwardRef ( 'Module' )]] Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: Type Description None (str, Module): Tuple containing a name and child module View Source def named_children ( self ) -> Iterator [ Tuple [ str , 'Module' ]]: r \"\"\"Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: (str, Module): Tuple containing a name and child module Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module) \"\"\" memo = set () for name , module in self . _modules . items (): if module is not None and module not in memo : memo . add ( module ) yield name , module named_modules def named_modules ( self , memo : Optional [ Set [ ForwardRef ( 'Module' )]] = None , prefix : str = '' , remove_duplicate : bool = True ) Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Parameters: Name Type Description Default memo None a memo to store the set of modules already added to the result None prefix None a prefix that will be added to the name of the module None remove_duplicate None whether to remove the duplicated module instances in the result or not None Yields: Type Description None (str, Module): Tuple of name and module View Source def named_modules ( self , memo : Optional [ Set [ 'Module' ]] = None , prefix : str = '' , remove_duplicate : bool = True ) : r \" \"\" Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: memo: a memo to store the set of modules already added to the result prefix: a prefix that will be added to the name of the module remove_duplicate: whether to remove the duplicated module instances in the result or not Yields: (str, Module): Tuple of name and module Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) \"\" \" if memo is None : memo = set () if self not in memo : if remove_duplicate : memo . add ( self ) yield prefix , self for name , module in self . _modules . items () : if module is None : continue submodule_prefix = prefix + ( '.' if prefix else '' ) + name yield from module . named_modules ( memo , submodule_prefix , remove_duplicate ) named_parameters def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . nn . parameter . Parameter ]] Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. 
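For example, iterating over the named parameters of a small container module (standard PyTorch usage):

```python
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
for name, param in model.named_parameters():
    print(name, tuple(param.shape))
# 0.weight (8, 4)
# 0.bias (8,)
# 1.weight (2, 8)
# 1.bias (2,)
```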
Parameters: Name Type Description Default prefix str prefix to prepend to all parameter names. None recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. None Yields: Type Description None (str, Parameter): Tuple containing the name and parameter View Source def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Parameter ]]: r \"\"\"Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True. Yields: (str, Parameter): Tuple containing the name and parameter Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) \"\"\" gen = self . _named_members ( lambda module : module . _parameters . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen parameters def parameters ( self , recurse : bool = True ) -> Iterator [ torch . nn . parameter . Parameter ] Return an iterator over module parameters. This is typically passed to an optimizer. Parameters: Name Type Description Default recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None Yields: Type Description Parameter module parameter View Source def parameters ( self , recurse : bool = True ) -> Iterator [ Parameter ] : r \"\"\"Return an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Yields: Parameter: module parameter Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for param in model.parameters(): >>> print(type(param), param.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for name , param in self . named_parameters ( recurse = recurse ) : yield param register_backward_hook def register_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]] ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. This function is deprecated in favor of :meth: ~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_backward_hook ( self , hook: Callable [[' Module ', _grad_t , _grad_t ], Union [ None , _grad_t ]] ) -> RemovableHandle: r \"\"\"Register a backward hook on the module. This function is deprecated in favor of : meth: ` ~ torch . nn . Module . 
register_full_backward_hook` and the behavior of this function will change in future versions. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\"\" if self._is_full_backward_hook is True: raise RuntimeError( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self._is_full_backward_hook = False handle = hooks.RemovableHandle(self._backward_hooks) self._backward_hooks[handle.id] = hook return handle register_buffer def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Parameters: Name Type Description Default name str name of the buffer. The buffer can be accessed from this module using the given name None tensor Tensor or None buffer to be registered. If None, then operations that run on buffers, such as :attr:`cuda`, are ignored. If None, the buffer is not included in the module's :attr:`state_dict`. None persistent bool whether the buffer is part of this module's :attr:`state_dict`. None View Source def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None: r\"\"\"Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`. Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> self.register_buffer('running_mean', torch.zeros(num_features)) \"\"\" if persistent is False and isinstance(self, torch.jit.ScriptModule): raise RuntimeError(\"ScriptModule does not support non-persistent buffers\") if '_buffers' not in self.__dict__: raise AttributeError(\"cannot assign buffer before Module.__init__() call\") elif not isinstance(name, str): raise TypeError(f\"buffer name should be a string. Got {torch.typename(name)}\") elif '.' in name: raise KeyError ( \"buffer name can't contain \\\" .
\\\" \" ) elif name == '' : raise KeyError ( \"buffer name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _buffers : raise KeyError ( f \"attribute '{name}' already exists\" ) elif tensor is not None and not isinstance ( tensor , torch . Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self . _buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name ) register_forward_hook def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward hook on the module. The hook will be called every time after :func: forward has computed an output. If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func: forward is called. The hook should have the following signature:: hook ( module , args , output ) -> None or modified output If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook ( module , args , kwargs , output ) -> None or modified output Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward hooks on this :class: torch.nn.modules.Module . Note that global forward hooks registered with :func: register_module_forward_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ] , Any ] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ] , Any ] , Optional [ Any ]] , ] , * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. 
It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature:: hook(module, args, output) -> None or modified output If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook(module, args, kwargs, output) -> None or modified output Args: hook (Callable): The user defined hook to be registered. prepend (bool): If ``True``, the provided ``hook`` will be fired before all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward`` hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` always_call (bool): If ``True`` the ``hook`` will be run regardless of whether an exception is raised while calling the Module. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_hooks , extra_dict = [ self . _forward_hooks_with_kwargs , self . _forward_hooks_always_called ] , ) self . _forward_hooks [ handle . id ] = hook if with_kwargs : self . _forward_hooks_with_kwargs [ handle . id ] = True if always_call : self . _forward_hooks_always_called [ handle . id ] = True if prepend : self . _forward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_forward_pre_hook def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward pre-hook on the module. The hook will be called every time before :func: forward is invoked. If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook ( module , args ) -> None or modified input If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook ( module , args , kwargs ) -> None or a tuple of modified input and kwargs Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward_pre hooks on this :class: torch.nn.modules.Module . 
Note that global forward_pre hooks registered with :func: register_module_forward_pre_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_pre_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ]] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ]] , Optional [ Tuple [ Any , Dict [ str , Any ]]]] , ] , * , prepend : bool = False , with_kwargs : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_full_backward_hook def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. 
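Before the signature details that follow, a short sketch of the gradient-inspection pattern this hook enables (the reporting hook is illustrative only):
>>> import torch
>>> from torch import nn
>>> model = nn.Linear(3, 1)
>>> def report_grads(module, grad_input, grad_output):
...     print([g.shape if g is not None else None for g in grad_output])
>>> handle = model.register_full_backward_hook(report_grads)
>>> model(torch.randn(4, 3)).sum().backward()
[torch.Size([4, 1])]
>>> handle.remove()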
The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. 
Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . _is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_full_backward_pre_hook def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. 
.. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_load_state_dict_post_hook def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . id ] = hook return handle register_module def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . 
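To make the ``incompatible_keys`` mechanics of ``register_load_state_dict_post_hook`` concrete before the ``register_module`` source below, a sketch (the hook only reports; it does not modify the key lists):
>>> import torch
>>> from torch import nn
>>> model = nn.Linear(2, 2)
>>> def report(module, incompatible_keys):
...     print(incompatible_keys.missing_keys, incompatible_keys.unexpected_keys)
>>> handle = model.register_load_state_dict_post_hook(report)
>>> _ = model.load_state_dict(model.state_dict())
[] []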
View Source def register_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \" \"\" Alias for :func:`add_module`. \"\" \" self . add_module ( name , module ) register_parameter def register_parameter ( self , name : str , param : Optional [ torch . nn . parameter . Parameter ] ) -> None Add a parameter to the module. The parameter can be accessed as an attribute using given name. Parameters: Name Type Description Default name str name of the parameter. The parameter can be accessed from this module using the given name None param Parameter or None parameter to be added to the module. If None , then operations that run on parameters, such as :attr: cuda , are ignored. If None , the parameter is not included in the module's :attr: state_dict . None View Source def register_parameter ( self , name : str , param : Optional [ Parameter ] ) -> None : r \" \"\" Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. \"\" \" if '_parameters' not in self . __dict__ : raise AttributeError ( \"cannot assign parameter before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"parameter name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"parameter name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"parameter name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _parameters : raise KeyError ( f \"attribute '{name}' already exists\" ) if param is None : self . _parameters [ name ] = None elif not isinstance ( param , Parameter ) : raise TypeError ( f \"cannot assign '{torch.typename(param)}' object to parameter '{name}' \" \"(torch.nn.Parameter or None required)\" ) elif param . grad_fn : raise ValueError ( f \"Cannot assign non-leaf Tensor to parameter '{name}'. Model \" f \"parameters must be created explicitly. To express '{name}' \" \"as a function of another Tensor, compute the value in \" \"the forward() method.\" ) else : for hook in _global_parameter_registration_hooks . values () : output = hook ( self , name , param ) if output is not None : param = output self . _parameters [ name ] = param register_state_dict_pre_hook def register_state_dict_pre_hook ( self , hook ) Register a pre-hook for the :meth: ~torch.nn.Module.state_dict method. These hooks will be called with arguments: self , prefix , and keep_vars before calling state_dict on self . The registered hooks can be used to perform pre-processing before the state_dict call is made. View Source def register_state_dict_pre_hook ( self , hook ) : r \" \"\" Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. \"\" \" handle = hooks . RemovableHandle ( self . _state_dict_pre_hooks ) self . _state_dict_pre_hooks [ handle . 
id ] = hook return handle requires_grad_ def requires_grad_(self: ~T, requires_grad: bool = True) -> ~T Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it. Parameters: Name Type Description Default requires_grad bool whether autograd should record operations on parameters in this module. Default: True. None Returns: Type Description Module self View Source def requires_grad_(self: T, requires_grad: bool = True) -> T: r\"\"\"Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it. Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self \"\"\" for p in self.parameters(): p.requires_grad_(requires_grad) return self set_extra_state def set_extra_state(self, state: Any) -> None Set extra state contained in the loaded state_dict. This function is called from :func:`load_state_dict` to handle any extra state found within the state_dict. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its state_dict. View Source def set_extra_state(self, state: Any) -> None: \"\"\"Set extra state contained in the loaded `state_dict`. This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`. Args: state (dict): Extra state from the `state_dict` \"\"\" raise RuntimeError( \"Reached a code path in Module.set_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" ) share_memory def share_memory(self: ~T) -> ~T See :meth:`torch.Tensor.share_memory_`. View Source def share_memory(self: T) -> T: r\"\"\"See :meth:`torch.Tensor.share_memory_`.\"\"\" return self._apply(lambda t: t.share_memory_()) state_dict def state_dict(self, *args, destination=None, prefix='', keep_vars=False) Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument destination as it is not designed for end-users.
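Since ``state_dict`` and ``load_state_dict`` are almost always exercised together, a round-trip sketch (the checkpoint path is hypothetical):
>>> import torch
>>> from torch import nn
>>> model = nn.Linear(2, 2)
>>> torch.save(model.state_dict(), 'checkpoint.pt')
>>> fresh = nn.Linear(2, 2)
>>> fresh.load_state_dict(torch.load('checkpoint.pt'))
<All keys matched successfully>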
Parameters: Name Type Description Default destination dict If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None. None prefix str a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''. None keep_vars bool by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to True, detaching will not be performed. Default: False. None Returns: Type Description dict a dictionary containing a whole state of the module View Source def state_dict(self, *args, destination=None, prefix='', keep_vars=False): r\"\"\"Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars`` in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument ``destination`` as it is not designed for end-users. Args: destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``. prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. Returns: dict: a dictionary containing a whole state of the module Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> module.state_dict().keys() ['bias', 'weight'] \"\"\" # TODO: Remove `args` and the parsing logic when BC allows. if len(args) > 0: if destination is None: destination = args[0] if len(args) > 1 and prefix == '': prefix = args[1] if len(args) > 2 and keep_vars is False: keep_vars = args[2] # DeprecationWarning is ignored by default warnings.warn( \"Positional args are being deprecated, use kwargs instead. Refer to \" \"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict\" \" for details.\" ) if destination is None: destination = OrderedDict() destination._metadata = OrderedDict() local_metadata = dict(version=self._version) if hasattr(destination, \"_metadata\"): destination._metadata[prefix[:-1]] = local_metadata for hook in self._state_dict_pre_hooks.values(): hook(self, prefix, keep_vars) self._save_to_state_dict(destination, prefix, keep_vars) for name, module in self._modules.items(): if module is not None: module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) for hook in self._state_dict_hooks.values(): hook_result = hook(self, destination, prefix, local_metadata) if hook_result is not None: destination = hook_result return destination to def to(self, *args, **kwargs) Move and/or cast the parameters and buffers. This can be called as ..
function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth: torch.Tensor.to , but only accepts floating point or complex :attr: dtype \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr: dtype (if given). The integral parameters and buffers will be moved :attr: device , if that is given, but with dtypes unchanged. When :attr: non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device ( None class: torch.device ): the desired device of the parameters and buffers in this module None dtype ( None class: torch.dtype ): the desired floating point or complex dtype of the parameters and buffers in this module None tensor torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module None memory_format ( None class: torch.memory_format ): the desired memory format for 4D parameters and buffers in this module (keyword only argument) None Returns: Type Description Module self View Source def to ( self , * args , ** kwargs ) : r \" \"\" Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype` \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. 
Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self Examples:: >>> # xdoctest: +IGNORE_WANT(\" non - deterministic \") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device(\" cuda : 1 \") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device(\" cpu \") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) \"\" \" device , dtype , non_blocking , convert_to_format = torch . _C . _nn . _parse_to ( * args , ** kwargs ) if dtype is not None : if not ( dtype . is_floating_point or dtype . is_complex ) : raise TypeError ( 'nn.Module.to only accepts floating point or complex ' f 'dtypes, but got desired dtype={dtype}' ) if dtype . is_complex : warnings . warn ( \"Complex modules are a new feature under active development whose design may change, \" \"and some modules might not work as expected when using complex tensors as parameters or buffers. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"if a complex module does not work as expected.\" ) def convert ( t ) : try : if convert_to_format is not None and t . dim () in ( 4 , 5 ) : return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , memory_format = convert_to_format , ) return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , ) except NotImplementedError as e : if str ( e ) == \"Cannot copy out of meta tensor; no data!\" : raise NotImplementedError ( f \"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() \" f \"when moving module from meta to a different device.\" ) from None else : raise return self . _apply ( convert ) to_empty def to_empty ( self : ~ T , * , device : Union [ int , str , torch . device , NoneType ], recurse : bool = True ) -> ~ T Move the parameters and buffers to the specified device without copying storage. Parameters: Name Type Description Default device ( None class: torch.device ): The desired device of the parameters and buffers in this module. 
None recurse bool Whether parameters and buffers of submodules should be recursively moved to the specified device. None Returns: Type Description Module self View Source def to_empty ( self : T , * , device : Optional [ DeviceLikeType ] , recurse : bool = True ) -> T : r \"\"\"Move the parameters and buffers to the specified device without copying storage. Args: device (:class:`torch.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. Returns: Module: self \"\"\" return self . _apply ( lambda t : torch . empty_like ( t , device = device ), recurse = recurse ) train def train ( self : ~ T , mode : bool = True ) -> ~ T Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. Parameters: Name Type Description Default mode bool whether to set training mode ( True ) or evaluation mode ( False ). Default: True . None Returns: Type Description Module self View Source def train ( self : T , mode : bool = True ) -> T : r \" \"\" Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self \"\" \" if not isinstance ( mode , bool ) : raise ValueError ( \"training mode is expected to be boolean\" ) self . training = mode for module in self . children () : module . train ( mode ) return self type def type ( self : ~ T , dst_type : Union [ torch . dtype , str ] ) -> ~ T Casts all parameters and buffers to :attr: dst_type . .. note:: This method modifies the module in-place. Parameters: Name Type Description Default dst_type type or string the desired type None Returns: Type Description Module self View Source def type ( self : T , dst_type : Union [ dtype , str ] ) -> T : r \" \"\" Casts all parameters and buffers to :attr:`dst_type`. .. note:: This method modifies the module in-place. Args: dst_type (type or string): the desired type Returns: Module: self \"\" \" return self . _apply ( lambda t : t . type ( dst_type )) xpu def xpu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def xpu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . 
xpu ( device )) zero_grad def zero_grad(self, set_to_none: bool = True) -> None Reset gradients of all model parameters. See similar function under :class:`torch.optim.Optimizer` for more context. Parameters: Name Type Description Default set_to_none bool instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. None View Source def zero_grad(self, set_to_none: bool = True) -> None: r\"\"\"Reset gradients of all model parameters. See similar function under :class:`torch.optim.Optimizer` for more context. Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. \"\"\" if getattr(self, '_is_replica', False): warnings.warn( \"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. \" \"The parameters are copied (in a differentiable manner) from the original module. \" \"This means they are not leaf nodes in autograd and so don't accumulate gradients. \" \"If you need gradients in your forward method, consider using autograd.grad instead.\" ) for p in self.parameters(): if p.grad is not None: if set_to_none: p.grad = None else: if p.grad.grad_fn is not None: p.grad.detach_() else: p.grad.requires_grad_(False) p.grad.zero_() MlpBlock class MlpBlock(in_dim: int, dims: Sequence[int], nonlins: Sequence[Union[str, torch.nn.modules.module.Module]], batch_norm: bool = True) A general-purpose MLP. Attributes Name Type Description Default in_dim None Input dimension. None dims None Hidden dimensions, including output dimension. None nonlins None Non-linearities to apply after each one of the hidden dimensions. Can be either a sequence of strings which are keys in the ACTIVATIONS dict, or instances of nn.Module (e.g. an instance of nn.ReLU()). Length should match 'dims'. None View Source class MlpBlock(nn.Module): \"\"\"A general-purpose MLP. Args: in_dim: Input dimension. dims: Hidden dimensions, including output dimension. nonlins: Non-linearities to apply after each one of the hidden dimensions. Can be either a sequence of strings which are keys in the ACTIVATIONS dict, or instances of nn.Module (e.g. an instance of nn.ReLU()). Length should match 'dims'. \"\"\" def __init__(self, in_dim: int, dims: Sequence[int], nonlins: Sequence[Union[str, nn.Module]], batch_norm: bool = True): assert len(nonlins) == len(dims) self.in_dim = in_dim self.out_dim = dims[-1] self.dims = dims self.nonlins = nonlins super().__init__() layers = [] for i, out_dim in enumerate(self.dims): layers.append(MLPLayer(in_dim, out_dim, nonlins[i], batch_norm)) in_dim = out_dim self.sequence = nn.Sequential(*layers) def _make_activation(self, act: Union[str, nn.Module]) -> nn.Module: if isinstance(act, str): return ACTIVATIONS[act](**ACTIVATION_DEFAULT_KWARGS[act]) return act def forward(self, x: Tensor) -> Tensor: \"\"\"Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self.sequence.forward(x.reshape(x.size(0), -1)) Ancestors (in MRO) torch.nn.modules.module.Module Class variables T_destination call_super_init dump_patches
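A usage sketch for ``MlpBlock`` itself, before the inherited ``nn.Module`` members are listed. Passing ``nn.Module`` instances for ``nonlins`` avoids assuming particular ``ACTIVATIONS`` keys, and the exact import path depends on the package layout:
>>> import torch
>>> from torch import nn
>>> block = MlpBlock(in_dim=4, dims=[16, 2], nonlins=[nn.ReLU(), nn.Identity()], batch_norm=False)
>>> block(torch.randn(8, 4)).shape  # input is flattened to (N, D) internally
torch.Size([8, 2])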
Methods add_module def add_module(self, name: str, module: Optional[ForwardRef('Module')]) -> None Add a child module to the current module. The module can be accessed as an attribute using the given name. Parameters: Name Type Description Default name str name of the child module. The child module can be accessed from this module using the given name None module Module child module to be added to the module. None View Source def add_module(self, name: str, module: Optional['Module']) -> None: r\"\"\"Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (str): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module. \"\"\" if not isinstance(module, Module) and module is not None: raise TypeError(f\"{torch.typename(module)} is not a Module subclass\") elif not isinstance(name, str): raise TypeError(f\"module name should be a string. Got {torch.typename(name)}\") elif hasattr(self, name) and name not in self._modules: raise KeyError(f\"attribute '{name}' already exists\") elif '.' in name: raise KeyError(f\"module name can't contain \\\".\\\", got: {name}\") elif name == '': raise KeyError(\"module name can't be empty string \\\"\\\"\") for hook in _global_module_registration_hooks.values(): output = hook(self, name, module) if output is not None: module = output self._modules[name] = module apply def apply(self: ~T, fn: Callable[[ForwardRef('Module')], NoneType]) -> ~T Apply fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Parameters: Name Type Description Default fn (:class:`Module` -> None) function to be applied to each submodule None Returns: Type Description Module self View Source def apply(self: T, fn: Callable[['Module'], None]) -> T: r\"\"\"Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example:: >>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) \"\"\" for module in self.children(): module.apply(fn) fn(self) return self bfloat16 def bfloat16(self: ~T) -> ~T Casts all floating point parameters and buffers to bfloat16 datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def bfloat16(self: T) -> T: r\"\"\"Casts all floating point parameters and buffers to ``bfloat16`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t) buffers def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor] Return an iterator over module buffers.
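Returning to ``add_module`` documented at the top of this block, a minimal sketch of dynamic submodule registration:
>>> from torch import nn
>>> net = nn.Sequential()
>>> net.add_module('fc1', nn.Linear(2, 4))
>>> net.add_module('fc2', nn.Linear(4, 1))
>>> net.fc1
Linear(in_features=2, out_features=4, bias=True)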
Parameters: Name Type Description Default recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. None Yields: Type Description torch.Tensor module buffer View Source def buffers ( self , recurse : bool = True ) -> Iterator [ Tensor ] : r \"\"\"Return an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: torch.Tensor: module buffer Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for buf in model.buffers(): >>> print(type(buf), buf.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for _ , buf in self . named_buffers ( recurse = recurse ) : yield buf children def children ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over immediate children modules. Yields: Type Description Module a child module View Source def children ( self ) -> Iterator [ 'Module']: r \"\"\"Return an iterator over immediate children modules. Yields: Module: a child module \"\"\" for name , module in self . named_children (): yield module compile def compile ( self , * args , ** kwargs ) Compile this Module's forward using :func: torch.compile . This Module's __call__ method is compiled and all arguments are passed as-is to :func: torch.compile . See :func: torch.compile for details on the arguments for this function. View Source def compile ( self , * args , ** kwargs ) : \" \"\" Compile this Module's forward using :func:`torch.compile`. This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`. See :func:`torch.compile` for details on the arguments for this function. \"\" \" self . _compiled_call_impl = torch . compile ( self . _call_impl , * args , ** kwargs ) cpu def cpu ( self : ~ T ) -> ~ T Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def cpu ( self : T ) -> T : r \"\"\"Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cpu ()) cuda def cuda ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def cuda ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cuda ( device )) double def double ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to double datatype. .. note:: This method modifies the module in-place. 
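The dtype- and device-movers above compose in the usual way; a guarded sketch (the CUDA branch runs only when a GPU is present):
>>> import torch
>>> from torch import nn
>>> model = nn.Linear(2, 2).double()  # parameters become float64 in-place
>>> model.weight.dtype
torch.float64
>>> if torch.cuda.is_available():
...     model = model.cuda()  # move parameters and buffers to the default GPU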
double

    def double(self: T) -> T

Casts all floating point parameters and buffers to double datatype.

Note: this method modifies the module in-place.

Returns:
    Module: self

View Source:

    def double(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``double`` datatype."""
        return self._apply(lambda t: t.double() if t.is_floating_point() else t)

eval

    def eval(self: T) -> T

Set the module in evaluation mode. This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc. This is equivalent to self.train(False). See the PyTorch locally-disable-grad documentation for a comparison between .eval() and several similar mechanisms that may be confused with it.

Returns:
    Module: self

View Source:

    def eval(self: T) -> T:
        r"""Set the module in evaluation mode."""
        return self.train(False)

extra_repr

    def extra_repr(self) -> str

Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

View Source:

    def extra_repr(self) -> str:
        r"""Set the extra representation of the module."""
        return ''

float

    def float(self: T) -> T

Casts all floating point parameters and buffers to float datatype.

Note: this method modifies the module in-place.

Returns:
    Module: self

View Source:

    def float(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``float`` datatype."""
        return self._apply(lambda t: t.float() if t.is_floating_point() else t)

forward

    def forward(self, x: torch.Tensor) -> torch.Tensor

Parameters:
    x: an input tensor of shape (N, D) containing N samples with D features.

Returns:
    An output tensor of shape (N, D_out), where D_out is the output dimension.

View Source:

    def forward(self, x: Tensor) -> Tensor:
        """Flatten each sample and run it through the network."""
        return self.sequence.forward(x.reshape(x.size(0), -1))
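Since eval() only toggles module behavior (not autograd), a typical inference pattern combines it with torch.no_grad(); a minimal sketch with an illustrative model:

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
    x = torch.randn(4, 8)

    model.train()          # dropout active
    y_train = model(x)

    model.eval()           # dropout disabled; equivalent to model.train(False)
    with torch.no_grad():  # additionally disable autograd for pure inference
        y_eval = model(x)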
get_buffer

    def get_buffer(self, target: str) -> 'Tensor'

Return the buffer given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target.

Parameters:
    target: the fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.)

Returns:
    torch.Tensor: the buffer referenced by target

Raises:
    AttributeError: if the target string references an invalid path or resolves to something that is not a buffer.

View Source:

    def get_buffer(self, target: str) -> "Tensor":
        """Return the buffer given by ``target`` if it exists, otherwise throw an error."""
        module_path, _, buffer_name = target.rpartition(".")
        mod: torch.nn.Module = self.get_submodule(module_path)
        if not hasattr(mod, buffer_name):
            raise AttributeError(mod._get_name() + " has no attribute `" + buffer_name + "`")
        buffer: torch.Tensor = getattr(mod, buffer_name)
        if buffer_name not in mod._buffers:
            raise AttributeError("`" + buffer_name + "` is not a buffer")
        return buffer

get_extra_state

    def get_extra_state(self) -> Any

Return any extra state to include in the module's state_dict. Implement this and a corresponding set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict(). Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns:
    object: any extra state to store in the module's state_dict

View Source:

    def get_extra_state(self) -> Any:
        """Return any extra state to include in the module's state_dict."""
        raise RuntimeError(
            "Reached a code path in Module.get_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
            "to report this bug.")
get_parameter

    def get_parameter(self, target: str) -> 'Parameter'

Return the parameter given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target.

Parameters:
    target: the fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.)

Returns:
    torch.nn.Parameter: the Parameter referenced by target

Raises:
    AttributeError: if the target string references an invalid path or resolves to something that is not an nn.Parameter.

View Source:

    def get_parameter(self, target: str) -> "Parameter":
        """Return the parameter given by ``target`` if it exists, otherwise throw an error."""
        module_path, _, param_name = target.rpartition(".")
        mod: torch.nn.Module = self.get_submodule(module_path)
        if not hasattr(mod, param_name):
            raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")
        param: torch.nn.Parameter = getattr(mod, param_name)
        if not isinstance(param, torch.nn.Parameter):
            raise AttributeError("`" + param_name + "` is not an nn.Parameter")
        return param

get_submodule

    def get_submodule(self, target: str) -> 'Module'

Return the submodule given by target if it exists, otherwise throw an error. For example, let's say you have an nn.Module A that looks like this:

    A(
        (net_b): Module(
            (net_c): Module(
                (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
            )
            (linear): Linear(in_features=100, out_features=200, bias=True)
        )
    )

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)

To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").

The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.

Parameters:
    target: the fully-qualified string name of the submodule to look for. (See the above example for how to specify a fully-qualified string.)

Returns:
    torch.nn.Module: the submodule referenced by target

Raises:
    AttributeError: if the target string references an invalid path or resolves to something that is not an nn.Module.

View Source:

    def get_submodule(self, target: str) -> "Module":
        """Return the submodule given by ``target`` if it exists, otherwise throw an error."""
        if target == "":
            return self
        atoms: List[str] = target.split(".")
        mod: torch.nn.Module = self
        for item in atoms:
            if not hasattr(mod, item):
                raise AttributeError(mod._get_name() + " has no attribute `" + item + "`")
            mod = getattr(mod, item)
            if not isinstance(mod, torch.nn.Module):
                raise AttributeError("`" + item + "` is not an nn.Module")
        return mod
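A minimal sketch mirroring the nested structure above, showing fully-qualified lookups with get_submodule and get_parameter (all names are illustrative):

    import torch.nn as nn

    net = nn.Sequential()
    net.add_module("net_b", nn.Sequential())
    net.net_b.add_module("net_c", nn.Sequential())
    net.net_b.net_c.add_module("conv", nn.Conv2d(16, 33, 3, stride=2))
    net.net_b.add_module("linear", nn.Linear(100, 200))

    conv = net.get_submodule("net_b.net_c.conv")   # the Conv2d above
    w = net.get_parameter("net_b.linear.weight")   # its weight Parameter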
half

    def half(self: T) -> T

Casts all floating point parameters and buffers to half datatype.

Note: this method modifies the module in-place.

Returns:
    Module: self

View Source:

    def half(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``half`` datatype."""
        return self._apply(lambda t: t.half() if t.is_floating_point() else t)

ipu

    def ipu(self: T, device: Optional[Union[int, torch.device]] = None) -> T

Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on IPU while being optimized.

Note: this method modifies the module in-place.

Parameters:
    device (int, optional): if specified, all parameters will be copied to that device.

Returns:
    Module: self

View Source:

    def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Move all model parameters and buffers to the IPU."""
        return self._apply(lambda t: t.ipu(device))
load_state_dict

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)

Copy parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.

Warning: if assign is True, the optimizer must be created after the call to load_state_dict unless torch.__future__.get_swap_module_params_on_conversion() is True.

Parameters:
    state_dict (dict): a dict containing parameters and persistent buffers.
    strict (bool, optional): whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Default: True.
    assign (bool, optional): when False, the properties of the tensors in the current module are preserved, while when True, the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of Parameters, for which the value from the module is preserved. Default: False.

Returns:
    NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys, and unexpected_keys is a list of str containing the unexpected keys.

Note: if a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

View Source:

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
        r"""Copy parameters and buffers from ``state_dict`` into this module and its descendants."""
        if not isinstance(state_dict, Mapping):
            raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")

        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            # mypy isn't aware that "_metadata" exists in state_dict
            state_dict._metadata = metadata  # type: ignore[attr-defined]

        def load(module, local_state_dict, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            if assign:
                local_metadata['assign_to_params_buffers'] = assign
            module._load_from_state_dict(
                local_state_dict, prefix, local_metadata, True,
                missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    child_prefix = prefix + name + '.'
                    child_state_dict = {k: v for k, v in local_state_dict.items()
                                        if k.startswith(child_prefix)}
                    load(child, child_state_dict, child_prefix)  # noqa: F821

            # Note that the hook can modify missing_keys and unexpected_keys.
            incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
            for hook in module._load_state_dict_post_hooks.values():
                out = hook(module, incompatible_keys)
                assert out is None, (
                    "Hooks registered with ``register_load_state_dict_post_hook`` are not"
                    "expected to return new values, if incompatible_keys need to be modified,"
                    "it should be done inplace.")

        load(self, state_dict)
        del load

        if strict:
            if len(unexpected_keys) > 0:
                error_msgs.insert(
                    0, 'Unexpected key(s) in state_dict: {}. '.format(
                        ', '.join(f'"{k}"' for k in unexpected_keys)))
            if len(missing_keys) > 0:
                error_msgs.insert(
                    0, 'Missing key(s) in state_dict: {}. '.format(
                        ', '.join(f'"{k}"' for k in missing_keys)))

        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                self.__class__.__name__, "\n\t".join(error_msgs)))
        return _IncompatibleKeys(missing_keys, unexpected_keys)

modules

    def modules(self) -> Iterator['Module']

Return an iterator over all modules in the network.

Yields:
    Module: a module in the network

Note: duplicate modules are returned only once. In the following example, l will be returned only once.

Example:

    >>> l = nn.Linear(2, 2)
    >>> net = nn.Sequential(l, l)
    >>> for idx, m in enumerate(net.modules()):
    ...     print(idx, '->', m)
    0 -> Sequential(
        (0): Linear(in_features=2, out_features=2, bias=True)
        (1): Linear(in_features=2, out_features=2, bias=True)
    )
    1 -> Linear(in_features=2, out_features=2, bias=True)

View Source:

    def modules(self) -> Iterator['Module']:
        r"""Return an iterator over all modules in the network."""
        for _, module in self.named_modules():
            yield module
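The common save/restore round trip ties state_dict() and load_state_dict() together; a minimal sketch (the file name "weights.pt" is a placeholder, not a path used by this library):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    torch.save(model.state_dict(), "weights.pt")

    clone = nn.Linear(4, 2)
    result = clone.load_state_dict(torch.load("weights.pt"), strict=True)
    print(result.missing_keys, result.unexpected_keys)  # both empty on an exact match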
named_buffers

    def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:
    prefix (str): prefix to prepend to all buffer names.
    recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.
    remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.

Yields:
    (str, torch.Tensor): tuple containing the name and buffer

View Source:

    def named_buffers(self, prefix: str = '', recurse: bool = True,
                      remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
        r"""Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself."""
        gen = self._named_members(
            lambda module: module._buffers.items(),
            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
        yield from gen

named_children

    def named_children(self) -> Iterator[Tuple[str, 'Module']]

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields:
    (str, Module): tuple containing a name and child module

View Source:

    def named_children(self) -> Iterator[Tuple[str, 'Module']]:
        r"""Return an iterator over immediate children modules, yielding both the name of the module and the module itself."""
        memo = set()
        for name, module in self._modules.items():
            if module is not None and module not in memo:
                memo.add(module)
                yield name, module

named_modules

    def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True)

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

Parameters:
    memo: a memo to store the set of modules already added to the result.
    prefix: a prefix that will be added to the name of the module.
    remove_duplicate: whether to remove the duplicated module instances in the result or not.

Yields:
    (str, Module): tuple of name and module

Note: duplicate modules are returned only once, exactly as for modules() above.

View Source:

    def named_modules(self, memo: Optional[Set['Module']] = None,
                      prefix: str = '', remove_duplicate: bool = True):
        r"""Return an iterator over all modules in the network, yielding both the name of the module and the module itself."""
        if memo is None:
            memo = set()
        if self not in memo:
            if remove_duplicate:
                memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + name
                yield from module.named_modules(memo, submodule_prefix, remove_duplicate)
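As a quick orientation aid, a minimal sketch of walking a model with named_modules() (the model is illustrative; the root module is yielded with the empty name):

    import torch.nn as nn

    net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
    for name, module in net.named_modules():
        print(name or "<root>", "->", type(module).__name__)
    # <root> -> Sequential, 0 -> Linear, 1 -> ReLU, 2 -> Linear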
named_parameters

    def named_parameters(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:
    prefix (str): prefix to prepend to all parameter names.
    recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
    remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True.

Yields:
    (str, Parameter): tuple containing the name and parameter

View Source:

    def named_parameters(self, prefix: str = '', recurse: bool = True,
                         remove_duplicate: bool = True) -> Iterator[Tuple[str, Parameter]]:
        r"""Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself."""
        gen = self._named_members(
            lambda module: module._parameters.items(),
            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
        yield from gen

parameters

    def parameters(self, recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]

Return an iterator over module parameters. This is typically passed to an optimizer.

Parameters:
    recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields:
    Parameter: module parameter

View Source:

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        r"""Return an iterator over module parameters."""
        for name, param in self.named_parameters(recurse=recurse):
            yield param
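A common use of named_parameters() is selective freezing by name; a minimal sketch (layer indices are illustrative):

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
    # Freeze everything except the final layer by matching parameter names
    # ("1.weight", "1.bias" belong to the second Linear).
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith("1.")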
register_backward_hook

    def register_backward_hook(self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]]) -> torch.utils.hooks.RemovableHandle

Register a backward hook on the module. This function is deprecated in favor of register_full_backward_hook, and the behavior of this function will change in future versions.

Returns:
    torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()

View Source:

    def register_backward_hook(
        self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]]
    ) -> RemovableHandle:
        r"""Register a backward hook on the module."""
        if self._is_full_backward_hook is True:
            raise RuntimeError(
                "Cannot use both regular backward hooks and full backward hooks on a "
                "single Module. Please use only one of them.")
        self._is_full_backward_hook = False
        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle
register_buffer

    def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None

Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's state_dict.

Buffers can be accessed as attributes using given names.

Parameters:
    name (str): name of the buffer. The buffer can be accessed from this module using the given name.
    tensor (Tensor or None): buffer to be registered. If None, then operations that run on buffers, such as cuda, are ignored, and the buffer is not included in the module's state_dict.
    persistent (bool): whether the buffer is part of this module's state_dict.

Example:

    >>> self.register_buffer('running_mean', torch.zeros(num_features))

View Source:

    def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
        r"""Add a buffer to the module."""
        if persistent is False and isinstance(self, torch.jit.ScriptModule):
            raise RuntimeError("ScriptModule does not support non-persistent buffers")
        if '_buffers' not in self.__dict__:
            raise AttributeError("cannot assign buffer before Module.__init__() call")
        elif not isinstance(name, str):
            raise TypeError(f"buffer name should be a string. Got {torch.typename(name)}")
        elif '.' in name:
            raise KeyError("buffer name can't contain \".\"")
        elif name == '':
            raise KeyError("buffer name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._buffers:
            raise KeyError(f"attribute '{name}' already exists")
        elif tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError(f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' "
                            "(torch Tensor or None required)")
        else:
            for hook in _global_buffer_registration_hooks.values():
                output = hook(self, name, tensor)
                if output is not None:
                    tensor = output
            self._buffers[name] = tensor
            if persistent:
                self._non_persistent_buffers_set.discard(name)
            else:
                self._non_persistent_buffers_set.add(name)
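A minimal sketch of a module that keeps running statistics in a buffer (the class name and update rule are illustrative): the buffer is saved in state_dict and moved by .to()/.cuda(), but it is not a Parameter and receives no gradients.

    import torch
    import torch.nn as nn

    class RunningMean(nn.Module):
        def __init__(self, num_features: int):
            super().__init__()
            self.register_buffer("mean", torch.zeros(num_features))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            with torch.no_grad():  # statistics update stays outside autograd
                self.mean = 0.9 * self.mean + 0.1 * x.mean(dim=0)
            return x - self.mean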
register_forward_hook

    def register_forward_hook(
        self,
        hook: Union[Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
                    Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]],
        *,
        prepend: bool = False,
        with_kwargs: bool = False,
        always_call: bool = False
    ) -> torch.utils.hooks.RemovableHandle

Register a forward hook on the module. The hook will be called every time after forward() has computed an output.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward. The hook can modify the output. It can modify the input in-place, but that has no effect on forward, since the hook is called after forward() runs. The hook should have the following signature:

    hook(module, args, output) -> None or modified output

If with_kwargs is True, the forward hook will be passed the kwargs given to the forward function and is expected to return the output, possibly modified. In that case the hook should have the following signature:

    hook(module, args, kwargs, output) -> None or modified output

Parameters:
    hook (Callable): the user defined hook to be registered.
    prepend (bool): if True, the provided hook will be fired before all existing forward hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing forward hooks on this torch.nn.modules.Module. Note that global forward hooks registered with register_module_forward_hook will fire before all hooks registered by this method. Default: False.
    with_kwargs (bool): if True, the hook will be passed the kwargs given to the forward function. Default: False.
    always_call (bool): if True, the hook will be run regardless of whether an exception is raised while calling the Module. Default: False.

Returns:
    torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()

View Source:

    def register_forward_hook(
        self,
        hook: Union[
            Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
            Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
        ],
        *,
        prepend: bool = False,
        with_kwargs: bool = False,
        always_call: bool = False,
    ) -> RemovableHandle:
        r"""Register a forward hook on the module."""
        handle = hooks.RemovableHandle(
            self._forward_hooks,
            extra_dict=[self._forward_hooks_with_kwargs, self._forward_hooks_always_called],
        )
        self._forward_hooks[handle.id] = hook
        if with_kwargs:
            self._forward_hooks_with_kwargs[handle.id] = True
        if always_call:
            self._forward_hooks_always_called[handle.id] = True
        if prepend:
            self._forward_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle
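A typical use of a forward hook is capturing intermediate activations; a minimal sketch (the model and the "hidden" key are illustrative):

    import torch
    import torch.nn as nn

    activations = {}

    def save_activation(module, args, output):
        activations["hidden"] = output.detach()

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    handle = model[1].register_forward_hook(save_activation)

    model(torch.randn(3, 4))
    print(activations["hidden"].shape)  # torch.Size([3, 8])
    handle.remove()  # detach the hook when no longer needed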
register_forward_pre_hook

    def register_forward_pre_hook(
        self,
        hook: Union[Callable[[T, Tuple[Any, ...]], Optional[Any]],
                    Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]]],
        *,
        prepend: bool = False,
        with_kwargs: bool = False
    ) -> torch.utils.hooks.RemovableHandle

Register a forward pre-hook on the module. The hook will be called every time before forward() is invoked.

If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward. The hook can modify the input. The user can return either a tuple or a single modified value from the hook; a single returned value will be wrapped into a tuple (unless that value is already a tuple). The hook should have the following signature:

    hook(module, args) -> None or modified input

If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function, and if the hook modifies the input, both the args and kwargs should be returned. In that case the hook should have the following signature:

    hook(module, args, kwargs) -> None or a tuple of modified input and kwargs

Parameters:
    hook (Callable): the user defined hook to be registered.
    prepend (bool): if true, the provided hook will be fired before all existing forward_pre hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing forward_pre hooks on this torch.nn.modules.Module. Note that global forward_pre hooks registered with register_module_forward_pre_hook will fire before all hooks registered by this method. Default: False.
    with_kwargs (bool): if true, the hook will be passed the kwargs given to the forward function. Default: False.

Returns:
    torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()

View Source:

    def register_forward_pre_hook(
        self,
        hook: Union[
            Callable[[T, Tuple[Any, ...]], Optional[Any]],
            Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]],
        ],
        *,
        prepend: bool = False,
        with_kwargs: bool = False,
    ) -> RemovableHandle:
        r"""Register a forward pre-hook on the module."""
        handle = hooks.RemovableHandle(
            self._forward_pre_hooks, extra_dict=self._forward_pre_hooks_with_kwargs)
        self._forward_pre_hooks[handle.id] = hook
        if with_kwargs:
            self._forward_pre_hooks_with_kwargs[handle.id] = True
        if prepend:
            self._forward_pre_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle
register_full_backward_hook

    def register_full_backward_hook(
        self,
        hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]],
        prepend: bool = False
    ) -> torch.utils.hooks.RemovableHandle

Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:

    hook(module, grad_input, grad_output) -> tuple(Tensor) or None

grad_input and grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments, and all kwarg arguments are ignored. Entries in grad_input and grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly, the caller will receive a view of each Tensor returned by the Module's forward function.

Warning: modifying inputs or outputs in-place is not allowed when using backward hooks and will raise an error.

Parameters:
    hook (Callable): the user-defined hook to be registered.
    prepend (bool): if true, the provided hook will be fired before all existing backward hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing backward hooks on this torch.nn.modules.Module. Note that global backward hooks registered with register_module_full_backward_hook will fire before all hooks registered by this method.

Returns:
    torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()

View Source:

    def register_full_backward_hook(
        self,
        hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Register a backward hook on the module."""
        if self._is_full_backward_hook is False:
            raise RuntimeError(
                "Cannot use both regular backward hooks and full backward hooks on a "
                "single Module. Please use only one of them.")
        self._is_full_backward_hook = True
        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        if prepend:
            self._backward_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle
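A minimal sketch of a full backward hook used for gradient inspection (the reporting logic is illustrative; returning None leaves the gradients unchanged):

    import torch
    import torch.nn as nn

    def report_grad(module, grad_input, grad_output):
        # grad_output[0] is dLoss/dOutput for this module.
        print(type(module).__name__, "grad_output norm:", grad_output[0].norm().item())

    model = nn.Linear(4, 2)
    handle = model.register_full_backward_hook(report_grad)

    loss = model(torch.randn(3, 4)).sum()
    loss.backward()  # triggers the hook
    handle.remove()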
register_full_backward_pre_hook

    def register_full_backward_pre_hook(
        self,
        hook: Callable[['Module', _grad_t], Union[None, _grad_t]],
        prepend: bool = False
    ) -> torch.utils.hooks.RemovableHandle

Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:

    hook(module, grad_output) -> tuple[Tensor] or None

grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of grad_output in subsequent computations. Entries in grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly, the caller will receive a view of each Tensor returned by the Module's forward function.

Warning: modifying inputs in-place is not allowed when using backward hooks and will raise an error.

Parameters:
    hook (Callable): the user-defined hook to be registered.
    prepend (bool): if true, the provided hook will be fired before all existing backward_pre hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing backward_pre hooks on this torch.nn.modules.Module. Note that global backward_pre hooks registered with register_module_full_backward_pre_hook will fire before all hooks registered by this method.

Returns:
    torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()

View Source:

    def register_full_backward_pre_hook(
        self,
        hook: Callable[["Module", _grad_t], Union[None, _grad_t]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Register a backward pre-hook on the module."""
        handle = hooks.RemovableHandle(self._backward_pre_hooks)
        self._backward_pre_hooks[handle.id] = hook
        if prepend:
            self._backward_pre_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

register_load_state_dict_post_hook

    def register_load_state_dict_post_hook(self, hook)

Register a post-hook to be run after the module's load_state_dict is called. It should have the following signature:

    hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys, and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified in-place if needed.

Note that the checks performed when calling load_state_dict() with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error.

Returns:
    torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()

View Source:

    def register_load_state_dict_post_hook(self, hook):
        r"""Register a post hook to be run after module's ``load_state_dict`` is called."""
        handle = hooks.RemovableHandle(self._load_state_dict_post_hooks)
        self._load_state_dict_post_hooks[handle.id] = hook
        return handle
register_module

    def register_module(self, name: str, module: Optional['Module']) -> None

Alias for add_module.

View Source:

    def register_module(self, name: str, module: Optional['Module']) -> None:
        r"""Alias for :func:`add_module`."""
        self.add_module(name, module)

register_parameter

    def register_parameter(self, name: str, param: Optional[torch.nn.parameter.Parameter]) -> None

Add a parameter to the module. The parameter can be accessed as an attribute using the given name.

Parameters:
    name (str): name of the parameter. The parameter can be accessed from this module using the given name.
    param (Parameter or None): parameter to be added to the module. If None, then operations that run on parameters, such as cuda, are ignored, and the parameter is not included in the module's state_dict.

View Source:

    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        r"""Add a parameter to the module."""
        if '_parameters' not in self.__dict__:
            raise AttributeError("cannot assign parameter before Module.__init__() call")
        elif not isinstance(name, str):
            raise TypeError(f"parameter name should be a string. Got {torch.typename(name)}")
        elif '.' in name:
            raise KeyError("parameter name can't contain \".\"")
        elif name == '':
            raise KeyError("parameter name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError(f"attribute '{name}' already exists")

        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError(f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
                            "(torch.nn.Parameter or None required)")
        elif param.grad_fn:
            raise ValueError(
                f"Cannot assign non-leaf Tensor to parameter '{name}'. Model "
                f"parameters must be created explicitly. To express '{name}' "
                "as a function of another Tensor, compute the value in "
                "the forward() method.")
        else:
            for hook in _global_parameter_registration_hooks.values():
                output = hook(self, name, param)
                if output is not None:
                    param = output
            self._parameters[name] = param

register_state_dict_pre_hook

    def register_state_dict_pre_hook(self, hook)

Register a pre-hook for the state_dict() method. These hooks will be called with arguments self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.

View Source:

    def register_state_dict_pre_hook(self, hook):
        r"""Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method."""
        handle = hooks.RemovableHandle(self._state_dict_pre_hooks)
        self._state_dict_pre_hooks[handle.id] = hook
        return handle
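A minimal sketch of register_parameter in a custom module (the class name is illustrative); direct attribute assignment of an nn.Parameter is equivalent, and the chosen name becomes the state_dict key:

    import torch
    import torch.nn as nn

    class Scale(nn.Module):
        def __init__(self):
            super().__init__()
            # Equivalent to `self.scale = nn.Parameter(torch.ones(1))`.
            self.register_parameter("scale", nn.Parameter(torch.ones(1)))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * self.scale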
This method is helpful for freezing part of the module for finetuning, or for training parts of a model individually (e.g., GAN training). See the PyTorch locally-disable-grad documentation for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it.

Parameters:
    requires_grad (bool): whether autograd should record operations on parameters in this module. Default: True.

Returns:
    Module: self

View Source:

    def requires_grad_(self: T, requires_grad: bool = True) -> T:
        r"""Change if autograd should record operations on parameters in this module."""
        for p in self.parameters():
            p.requires_grad_(requires_grad)
        return self

set_extra_state

    def set_extra_state(self, state: Any) -> None

Set extra state contained in the loaded state_dict. This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state for your module if you need to store extra state within its state_dict.

Parameters:
    state (dict): extra state from the state_dict.

View Source:

    def set_extra_state(self, state: Any) -> None:
        """Set extra state contained in the loaded `state_dict`."""
        raise RuntimeError(
            "Reached a code path in Module.set_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
            "to report this bug.")

share_memory

    def share_memory(self: T) -> T

See torch.Tensor.share_memory_().

View Source:

    def share_memory(self: T) -> T:
        r"""See :meth:`torch.Tensor.share_memory_`."""
        return self._apply(lambda t: t.share_memory_())
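The freezing use case for requires_grad_() in one minimal sketch (module names and sizes are illustrative):

    import torch.nn as nn

    backbone = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
    head = nn.Linear(16, 2)

    backbone.requires_grad_(False)          # freeze all backbone parameters in-place
    model = nn.Sequential(backbone, head)   # only `head` will accumulate gradients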
to

```python
def to(self, *args, **kwargs)
```

Move and/or cast the parameters and buffers.

This can be called as

.. function:: to(device=None, dtype=None, non_blocking=False)
.. function:: to(dtype, non_blocking=False)
.. function:: to(tensor, non_blocking=False)
.. function:: to(memory_format=torch.channels_last)

Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex ``dtype``\ s.
In addition, this method will only cast the floating point or complex parameters and buffers to ``dtype`` (if given). The integral parameters and buffers will be moved to ``device``, if that is given, but with dtypes unchanged. When ``non_blocking`` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples.

.. note:: This method modifies the module in-place.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| device | :class:`torch.device` | The desired device of the parameters and buffers in this module. |
| dtype | :class:`torch.dtype` | The desired floating point or complex dtype of the parameters and buffers in this module. |
| tensor | torch.Tensor | Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module. |
| memory_format | :class:`torch.memory_format` | The desired memory format for 4D parameters and buffers in this module (keyword-only argument). |

Returns:

| Type | Description |
| --- | --- |
| Module | self |
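Because ``train()`` (documented below) and ``eval()`` only flip the ``training`` flag recursively, modules such as ``Dropout`` change behavior without any reconstruction. A minimal sketch (illustrative, not from this library):

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.ones(1, 8)

net.train()            # dropout active: repeated calls differ
net.eval()             # dropout disabled: output is deterministic
y1, y2 = net(x), net(x)
assert torch.equal(y1, y2)
```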
Examples:

```python
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
```

View Source

```python
def to(self, *args, **kwargs):
    r"""Move and/or cast the parameters and buffers."""
    device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

    if dtype is not None:
        if not (dtype.is_floating_point or dtype.is_complex):
            raise TypeError('nn.Module.to only accepts floating point or complex '
                            f'dtypes, but got desired dtype={dtype}')
        if dtype.is_complex:
            warnings.warn(
                "Complex modules are a new feature under active development whose design may change, "
                "and some modules might not work as expected when using complex tensors as parameters or buffers. "
                "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
                "if a complex module does not work as expected.")

    def convert(t):
        try:
            if convert_to_format is not None and t.dim() in (4, 5):
                return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
                            non_blocking, memory_format=convert_to_format)
            return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
        except NotImplementedError as e:
            if str(e) == "Cannot copy out of meta tensor; no data!":
                raise NotImplementedError(
                    f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
                    f"when moving module from meta to a different device.") from None
            else:
                raise

    return self._apply(convert)
```

to_empty

```python
def to_empty(self: ~T, *, device: Union[int, str, torch.device, None], recurse: bool = True) -> ~T
```

Move the parameters and buffers to the specified device without copying storage.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| device | :class:`torch.device` | The desired device of the parameters and buffers in this module. |
| recurse | bool | Whether parameters and buffers of submodules should be recursively moved to the specified device. |

Returns:

| Type | Description |
| --- | --- |
| Module | self |

View Source

```python
def to_empty(self: T, *, device: Optional[DeviceLikeType], recurse: bool = True) -> T:
    r"""Move the parameters and buffers to the specified device without copying storage."""
    return self._apply(lambda t: torch.empty_like(t, device=device), recurse=recurse)
```

train

```python
def train(self: ~T, mode: bool = True) -> ~T
```

Set the module in training mode. This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| mode | bool | Whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. |

Returns:

| Type | Description |
| --- | --- |
| Module | self |

View Source

```python
def train(self: T, mode: bool = True) -> T:
    r"""Set the module in training mode."""
    if not isinstance(mode, bool):
        raise ValueError("training mode is expected to be boolean")
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self
```

type

```python
def type(self: ~T, dst_type: Union[torch.dtype, str]) -> ~T
```

Casts all parameters and buffers to ``dst_type``.

.. note:: This method modifies the module in-place.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| dst_type | type or string | The desired type. |

Returns:

| Type | Description |
| --- | --- |
| Module | self |

View Source

```python
def type(self: T, dst_type: Union[dtype, str]) -> T:
    r"""Casts all parameters and buffers to :attr:`dst_type`."""
    return self._apply(lambda t: t.type(dst_type))
```

xpu

```python
def xpu(self: ~T, device: Union[int, torch.device, None] = None) -> ~T
```

Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on XPU while being optimized.

.. note:: This method modifies the module in-place.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| device | int, optional | If specified, all parameters will be copied to that device. |

Returns:

| Type | Description |
| --- | --- |
| Module | self |

View Source

```python
def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:
    r"""Move all model parameters and buffers to the XPU."""
    return self._apply(lambda t: t.xpu(device))
```

zero_grad

```python
def zero_grad(self, set_to_none: bool = True) -> None
```

Reset gradients of all model parameters. See the similar function under :class:`torch.optim.Optimizer` for more context.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| set_to_none | bool | Instead of setting to zero, set the grads to ``None``. See :meth:`torch.optim.Optimizer.zero_grad` for details. |

View Source

```python
def zero_grad(self, set_to_none: bool = True) -> None:
    r"""Reset gradients of all model parameters."""
    if getattr(self, '_is_replica', False):
        warnings.warn(
            "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
            "The parameters are copied (in a differentiable manner) from the original module. "
            "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
            "If you need gradients in your forward method, consider using autograd.grad instead.")

    for p in self.parameters():
        if p.grad is not None:
            if set_to_none:
                p.grad = None
            else:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()
```
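A sketch (illustrative, not from this library) of the difference between ``set_to_none=True`` (the default) and ``set_to_none=False``:

```python
import torch
import torch.nn as nn

lin = nn.Linear(2, 1)
lin(torch.ones(1, 2)).sum().backward()

lin.zero_grad(set_to_none=False)
assert torch.all(lin.weight.grad == 0)   # grad tensors kept, zero-filled

lin.zero_grad()                          # set_to_none=True
assert lin.weight.grad is None           # grad tensors released entirely
```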
RMLP

```python
class RMLP(
    block_in_dim: int,
    block_dims: Sequence[int],
    block_nonlins: Sequence[Union[str, torch.nn.modules.module.Module]],
    n_blocks: int,
    out_dim: int,
    in_dim: int = None,
    batch_norm: bool = True,
)
```

Base class for all neural network modules. Your models should also subclass this class. Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

```python
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
```

Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:`to`, etc.

.. note:: As per the example above, an ``__init__()`` call to the parent class must be made before assignment on the child.

View Source

```python
class RMLP(nn.Module):
    def __init__(
        self,
        block_in_dim: int,
        block_dims: Sequence[int],
        block_nonlins: Sequence[Union[str, nn.Module]],
        n_blocks: int,
        out_dim: int,
        in_dim: int = None,  # if in_dim is an int, then a first layer will be made
        batch_norm: bool = True,
    ) -> None:
        super().__init__()

        # Create first layer if in_dim is not None
        self.input = nn.Identity()
        if in_dim is not None:
            self.input = MLPLayer(in_dim, block_in_dim, block_nonlins[0], batch_norm)

        # Create blocks
        layers = []
        for i in range(n_blocks):
            layers.append(MlpBlock(block_in_dim, block_dims, block_nonlins, batch_norm))
        self.blocks = nn.ModuleList(layers)

        # Create output layer
        self.output = nn.Linear(block_dims[-1], out_dim)

    def _make_activation(self, act: Union[str, nn.Module]) -> nn.Module:
        if isinstance(act, str):
            return ACTIVATIONS[act](**ACTIVATION_DEFAULT_KWARGS[act])
        return act

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: An input tensor, of shape (N, D) containing N samples with D features.
        Returns:
            An output tensor of shape (N, D_out) where D_out is the output dim.
        """
        x = self.input(x)
        for block in self.blocks:
            out = block(x)
            x = x + out
        return self.output(x)
```
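Based on the constructor above, a hypothetical instantiation of ``RMLP`` might look as follows. The import path and all dimensions are assumptions for illustration, not taken from the library's docs; note that the residual connection (``x = x + out``) requires each block to preserve its input width:

```python
import torch
from wtracker.neural.mlp import RMLP   # assumed path; adjust to where RMLP lives

model = RMLP(
    block_in_dim=64,
    block_dims=[64, 64],        # last block dim must equal block_in_dim for the residual add
    block_nonlins=["relu", "relu"],
    n_blocks=3,
    out_dim=2,
    in_dim=10,                  # adds a first MLPLayer mapping 10 -> 64
    batch_norm=True,
)
y = model(torch.randn(32, 10))  # (N, D) in, (N, D_out) out
assert y.shape == (32, 2)
```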
Ancestors (in MRO)

- torch.nn.modules.module.Module

Class variables

- T_destination
- call_super_init
- dump_patches

Methods

add_module

```python
def add_module(self, name: str, module: Optional[ForwardRef('Module')]) -> None
```

Add a child module to the current module. The module can be accessed as an attribute using the given name.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| name | str | Name of the child module. The child module can be accessed from this module using the given name. |
| module | Module | Child module to be added to the module. |

View Source

```python
def add_module(self, name: str, module: Optional['Module']) -> None:
    r"""Add a child module to the current module."""
    if not isinstance(module, Module) and module is not None:
        raise TypeError(f"{torch.typename(module)} is not a Module subclass")
    elif not isinstance(name, str):
        raise TypeError(f"module name should be a string. Got {torch.typename(name)}")
    elif hasattr(self, name) and name not in self._modules:
        raise KeyError(f"attribute '{name}' already exists")
    elif '.' in name:
        raise KeyError(f"module name can't contain \".\", got: {name}")
    elif name == '':
        raise KeyError("module name can't be empty string \"\"")
    for hook in _global_module_registration_hooks.values():
        output = hook(self, name, module)
        if output is not None:
            module = output
    self._modules[name] = module
```
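``add_module`` is what attribute assignment calls under the hood; it is mainly useful when layer names are computed at runtime. A small illustrative sketch (names are hypothetical):

```python
import torch.nn as nn

net = nn.Module()
for i in range(3):
    net.add_module(f"fc{i}", nn.Linear(4, 4))  # same effect as net.fc0 = nn.Linear(4, 4)

print([name for name, _ in net.named_children()])  # ['fc0', 'fc1', 'fc2']
```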
apply

```python
def apply(self: ~T, fn: Callable[[ForwardRef('Module')], None]) -> ~T
```

Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`).

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| fn | :class:`Module` -> None | Function to be applied to each submodule. |

Returns:

| Type | Description |
| --- | --- |
| Module | self |

Example:

```python
>>> @torch.no_grad()
>>> def init_weights(m):
>>>     print(m)
>>>     if type(m) == nn.Linear:
>>>         m.weight.fill_(1.0)
>>>         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
```

View Source

```python
def apply(self: T, fn: Callable[['Module'], None]) -> T:
    r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self."""
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self
```

bfloat16

```python
def bfloat16(self: ~T) -> ~T
```

Casts all floating point parameters and buffers to ``bfloat16`` datatype.

.. note:: This method modifies the module in-place.

Returns:

| Type | Description |
| --- | --- |
| Module | self |

View Source

```python
def bfloat16(self: T) -> T:
    r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype."""
    return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t)
```

buffers

```python
def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]
```

Return an iterator over module buffers.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| recurse | bool | If True, yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. |

Yields:

| Type | Description |
| --- | --- |
| torch.Tensor | module buffer |

Example:

```python
>>> # xdoctest: +SKIP("undefined vars")
>>> for buf in model.buffers():
>>>     print(type(buf), buf.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
```

View Source

```python
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
    r"""Return an iterator over module buffers."""
    for _, buf in self.named_buffers(recurse=recurse):
        yield buf
```

children

```python
def children(self) -> Iterator[ForwardRef('Module')]
```

Return an iterator over immediate children modules.

Yields:

| Type | Description |
| --- | --- |
| Module | a child module |

View Source

```python
def children(self) -> Iterator['Module']:
    r"""Return an iterator over immediate children modules."""
    for name, module in self.named_children():
        yield module
```
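A short illustrative sketch (not from this library) of the iterator methods above: ``children()`` walks immediate submodules, while ``buffers()`` walks registered buffer tensors (here, BatchNorm's running statistics):

```python
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
print([type(m).__name__ for m in net.children()])   # ['Linear', 'BatchNorm1d']
print(sum(b.numel() for b in net.buffers()))        # running_mean + running_var + num_batches_tracked
```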
compile

```python
def compile(self, *args, **kwargs)
```

Compile this Module's forward using :func:`torch.compile`. This Module's ``__call__`` method is compiled and all arguments are passed as-is to :func:`torch.compile`. See :func:`torch.compile` for details on the arguments for this function.

View Source

```python
def compile(self, *args, **kwargs):
    """Compile this Module's forward using :func:`torch.compile`."""
    self._compiled_call_impl = torch.compile(self._call_impl, *args, **kwargs)
```

cpu

```python
def cpu(self: ~T) -> ~T
```

Move all model parameters and buffers to the CPU.

.. note:: This method modifies the module in-place.

Returns:

| Type | Description |
| --- | --- |
| Module | self |

View Source

```python
def cpu(self: T) -> T:
    r"""Move all model parameters and buffers to the CPU."""
    return self._apply(lambda t: t.cpu())
```

cuda

```python
def cuda(self: ~T, device: Union[int, torch.device, None] = None) -> ~T
```

Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on GPU while being optimized.

.. note:: This method modifies the module in-place.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| device | int, optional | If specified, all parameters will be copied to that device. |

Returns:

| Type | Description |
| --- | --- |
| Module | self |

View Source

```python
def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
    r"""Move all model parameters and buffers to the GPU."""
    return self._apply(lambda t: t.cuda(device))
```

double

```python
def double(self: ~T) -> ~T
```

Casts all floating point parameters and buffers to ``double`` datatype.

.. note:: This method modifies the module in-place.

Returns:

| Type | Description |
| --- | --- |
| Module | self |

View Source

```python
def double(self: T) -> T:
    r"""Casts all floating point parameters and buffers to ``double`` datatype."""
    return self._apply(lambda t: t.double() if t.is_floating_point() else t)
```

eval

```python
def eval(self: ~T) -> ~T
```

Set the module in evaluation mode. This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent to :meth:`self.train(False)`. See :ref:`locally-disable-grad-doc` for a comparison between ``.eval()`` and several similar mechanisms that may be confused with it.

Returns:

| Type | Description |
| --- | --- |
| Module | self |

View Source

```python
def eval(self: T) -> T:
    r"""Set the module in evaluation mode."""
    return self.train(False)
```
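``eval()`` only changes module behavior; it does not disable gradient tracking. The usual inference pattern therefore combines it with ``torch.no_grad()``, as in this illustrative sketch (not from this library):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4))
model.eval()                       # e.g. BatchNorm uses running stats, not batch stats
with torch.no_grad():              # additionally skip autograd bookkeeping
    out = model(torch.randn(2, 8))
```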
extra_repr

```python
def extra_repr(self) -> str
```

Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

View Source

```python
def extra_repr(self) -> str:
    r"""Set the extra representation of the module."""
    return ''
```

float

```python
def float(self: ~T) -> ~T
```

Casts all floating point parameters and buffers to ``float`` datatype.

.. note:: This method modifies the module in-place.

Returns:

| Type | Description |
| --- | --- |
| Module | self |

View Source

```python
def float(self: T) -> T:
    r"""Casts all floating point parameters and buffers to ``float`` datatype."""
    return self._apply(lambda t: t.float() if t.is_floating_point() else t)
```

forward

```python
def forward(self, x: torch.Tensor) -> torch.Tensor
```

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| x | torch.Tensor | An input tensor of shape (N, D) containing N samples with D features. |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | An output tensor of shape (N, D_out), where D_out is the output dimension. |

View Source

```python
def forward(self, x: Tensor) -> Tensor:
    """
    Args:
        x: An input tensor, of shape (N, D) containing N samples with D features.
    Returns:
        An output tensor of shape (N, D_out) where D_out is the output dim.
    """
    x = self.input(x)
    for block in self.blocks:
        out = block(x)
        x = x + out
    return self.output(x)
```

get_buffer

```python
def get_buffer(self, target: str) -> 'Tensor'
```

Return the buffer given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| target | str | The fully-qualified string name of the buffer to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | The buffer referenced by ``target``. |

Raises:

| Type | Description |
| --- | --- |
| AttributeError | If the target string references an invalid path or resolves to something that is not a buffer. |

View Source

```python
def get_buffer(self, target: str) -> "Tensor":
    """Return the buffer given by ``target`` if it exists, otherwise throw an error."""
    module_path, _, buffer_name = target.rpartition(".")

    mod: torch.nn.Module = self.get_submodule(module_path)

    if not hasattr(mod, buffer_name):
        raise AttributeError(mod._get_name() + " has no attribute `" + buffer_name + "`")

    buffer: torch.Tensor = getattr(mod, buffer_name)

    if buffer_name not in mod._buffers:
        raise AttributeError("`" + buffer_name + "` is not a buffer")

    return buffer
```
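A small sketch of ``get_buffer`` against a module that registers buffers of its own (BatchNorm's ``running_mean``); the wrapper network is illustrative:

```python
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
mean = net.get_buffer("1.running_mean")   # fully-qualified: buffer "running_mean" on submodule "1"
print(mean.shape)                         # torch.Size([4])
```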
get_extra_state

```python
def get_extra_state(self) -> Any
```

Return any extra state to include in the module's state_dict. Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's ``state_dict()``. Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns:

| Type | Description |
| --- | --- |
| object | Any extra state to store in the module's state_dict. |

View Source

```python
def get_extra_state(self) -> Any:
    """Return any extra state to include in the module's state_dict."""
    raise RuntimeError(
        "Reached a code path in Module.get_extra_state() that should never be called. "
        "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
        "to report this bug.")
```

get_parameter

```python
def get_parameter(self, target: str) -> 'Parameter'
```

Return the parameter given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| target | str | The fully-qualified string name of the Parameter to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) |

Returns:

| Type | Description |
| --- | --- |
| torch.nn.Parameter | The Parameter referenced by ``target``. |

Raises:

| Type | Description |
| --- | --- |
| AttributeError | If the target string references an invalid path or resolves to something that is not an ``nn.Parameter``. |

View Source

```python
def get_parameter(self, target: str) -> "Parameter":
    """Return the parameter given by ``target`` if it exists, otherwise throw an error."""
    module_path, _, param_name = target.rpartition(".")

    mod: torch.nn.Module = self.get_submodule(module_path)

    if not hasattr(mod, param_name):
        raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")

    param: torch.nn.Parameter = getattr(mod, param_name)

    if not isinstance(param, torch.nn.Parameter):
        raise AttributeError("`" + param_name + "` is not an nn.Parameter")

    return param
```
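``get_parameter`` resolves the same dotted paths as ``get_submodule``, but the leaf must be an ``nn.Parameter``. An illustrative sketch:

```python
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
w = net.get_parameter("0.weight")   # parameter "weight" on submodule "0"
assert w is net[0].weight
```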
get_submodule

```python
def get_submodule(self, target: str) -> 'Module'
```

Return the submodule given by ``target`` if it exists, otherwise throw an error.

For example, let's say you have an ``nn.Module`` ``A`` that looks like this:

```text
A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)
```

(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.)

To check whether or not we have the ``linear`` submodule, we would call ``get_submodule("net_b.linear")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule("net_b.net_c.conv")``.

The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| target | str | The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) |

Returns:

| Type | Description |
| --- | --- |
| torch.nn.Module | The submodule referenced by ``target``. |

Raises:

| Type | Description |
| --- | --- |
| AttributeError | If the target string references an invalid path or resolves to something that is not an ``nn.Module``. |

View Source

```python
def get_submodule(self, target: str) -> "Module":
    """Return the submodule given by ``target`` if it exists, otherwise throw an error."""
    if target == "":
        return self

    atoms: List[str] = target.split(".")
    mod: torch.nn.Module = self

    for item in atoms:
        if not hasattr(mod, item):
            raise AttributeError(mod._get_name() + " has no attribute `" + item + "`")

        mod = getattr(mod, item)

        if not isinstance(mod, torch.nn.Module):
            raise AttributeError("`" + item + "` is not an nn.Module")

    return mod
```
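Following the nested example above, a runnable sketch of the same lookup (module names mirror the diagram and are illustrative):

```python
import torch.nn as nn

class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_b = nn.Module()
        self.net_b.net_c = nn.Module()
        self.net_b.net_c.conv = nn.Conv2d(16, 33, kernel_size=3, stride=2)
        self.net_b.linear = nn.Linear(100, 200)

a = A()
print(a.get_submodule("net_b.net_c.conv"))   # the Conv2d layer
```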
half

```python
def half(self: ~T) -> ~T
```

Casts all floating point parameters and buffers to ``half`` datatype.

.. note:: This method modifies the module in-place.

Returns:

| Type | Description |
| --- | --- |
| Module | self |

View Source

```python
def half(self: T) -> T:
    r"""Casts all floating point parameters and buffers to ``half`` datatype."""
    return self._apply(lambda t: t.half() if t.is_floating_point() else t)
```

ipu

```python
def ipu(self: ~T, device: Union[int, torch.device, None] = None) -> ~T
```

Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on IPU while being optimized.

.. note:: This method modifies the module in-place.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| device | int, optional | If specified, all parameters will be copied to that device. |

Returns:

| Type | Description |
| --- | --- |
| Module | self |

View Source

```python
def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:
    r"""Move all model parameters and buffers to the IPU."""
    return self._apply(lambda t: t.ipu(device))
```

load_state_dict

```python
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)
```

Copy parameters and buffers from ``state_dict`` into this module and its descendants. If ``strict`` is ``True``, then the keys of ``state_dict`` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function.

.. warning:: If ``assign`` is ``True``, the optimizer must be created after the call to ``load_state_dict`` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| state_dict | dict | A dict containing parameters and persistent buffers. |
| strict | bool, optional | Whether to strictly enforce that the keys in ``state_dict`` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True``. |
| assign | bool, optional | When ``False``, the properties of the tensors in the current module are preserved, while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter` s, for which the value from the module is preserved. Default: ``False``. |

Returns:

| Type | Description |
| --- | --- |
| NamedTuple | ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: ``missing_keys`` is a list of str containing the missing keys; ``unexpected_keys`` is a list of str containing the unexpected keys. |

.. note:: If a parameter or buffer is registered as ``None`` and its corresponding key exists in ``state_dict``, ``load_state_dict`` will raise a ``RuntimeError``.
View Source

```python
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
    r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants."""
    if not isinstance(state_dict, Mapping):
        raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")

    missing_keys: List[str] = []
    unexpected_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        # mypy isn't aware that "_metadata" exists in state_dict
        state_dict._metadata = metadata  # type: ignore[attr-defined]

    def load(module, local_state_dict, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        if assign:
            local_metadata['assign_to_params_buffers'] = assign
        module._load_from_state_dict(
            local_state_dict, prefix, local_metadata, True,
            missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                child_prefix = prefix + name + '.'
                child_state_dict = {k: v for k, v in local_state_dict.items()
                                    if k.startswith(child_prefix)}
                load(child, child_state_dict, child_prefix)  # noqa: F821

        # Note that the hook can modify missing_keys and unexpected_keys.
        incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
        for hook in module._load_state_dict_post_hooks.values():
            out = hook(module, incompatible_keys)
            assert out is None, (
                "Hooks registered with ``register_load_state_dict_post_hook`` are not"
                "expected to return new values, if incompatible_keys need to be modified,"
                "it should be done inplace.")

    load(self, state_dict)
    del load

    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join(f'"{k}"' for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join(f'"{k}"' for k in missing_keys)))

    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                               self.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)
```
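When loading a checkpoint into a model whose architecture differs slightly (e.g., a replaced head in transfer learning), ``strict=False`` reports the mismatches instead of raising. An illustrative sketch:

```python
import torch.nn as nn

src = nn.Sequential(nn.Linear(4, 4))
dst = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))   # extra layer "1"

result = dst.load_state_dict(src.state_dict(), strict=False)
print(result.missing_keys)      # ['1.weight', '1.bias']
print(result.unexpected_keys)   # []
```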
modules

```python
def modules(self) -> Iterator[ForwardRef('Module')]
```

Return an iterator over all modules in the network.

Yields:

| Type | Description |
| --- | --- |
| Module | a module in the network |

.. note:: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once.

Example:

```python
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)

0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
```

View Source

```python
def modules(self) -> Iterator['Module']:
    r"""Return an iterator over all modules in the network."""
    for _, module in self.named_modules():
        yield module
```

named_buffers

```python
def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]
```

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| prefix | str | Prefix to prepend to all buffer names. |
| recurse | bool, optional | If True, yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. |
| remove_duplicate | bool, optional | Whether to remove the duplicated buffers in the result. Defaults to True. |

Yields:

| Type | Description |
| --- | --- |
| (str, torch.Tensor) | Tuple containing the name and buffer. |

Example:

```python
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
```

View Source

```python
def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
    r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself."""
    gen = self._named_members(
        lambda module: module._buffers.items(),
        prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
    yield from gen
```
named_children

```python
def named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]
```

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields:

| Type | Description |
| --- | --- |
| (str, Module) | Tuple containing a name and child module. |

Example:

```python
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, module in model.named_children():
>>>     if name in ['conv4', 'conv5']:
>>>         print(module)
```

View Source

```python
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
    r"""Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself."""
    memo = set()
    for name, module in self._modules.items():
        if module is not None and module not in memo:
            memo.add(module)
            yield name, module
```

named_modules

```python
def named_modules(self, memo: Optional[Set[ForwardRef('Module')]] = None, prefix: str = '', remove_duplicate: bool = True)
```

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| memo | set, optional | A memo to store the set of modules already added to the result. |
| prefix | str | A prefix that will be added to the name of the module. |
| remove_duplicate | bool | Whether to remove the duplicated module instances in the result or not. |

Yields:

| Type | Description |
| --- | --- |
| (str, Module) | Tuple of name and module. |

.. note:: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once.

Example:

```python
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
```

View Source

```python
def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
    r"""Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself."""
    if memo is None:
        memo = set()
    if self not in memo:
        if remove_duplicate:
            memo.add(self)
        yield prefix, self
        for name, module in self._modules.items():
            if module is None:
                continue
            submodule_prefix = prefix + ('.' if prefix else '') + name
            yield from module.named_modules(memo, submodule_prefix, remove_duplicate)
```

named_parameters

```python
def named_parameters(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]
```

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| prefix | str | Prefix to prepend to all parameter names. |
| recurse | bool | If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. |
| remove_duplicate | bool, optional | Whether to remove the duplicated parameters in the result. Defaults to True. |

Yields:

| Type | Description |
| --- | --- |
| (str, Parameter) | Tuple containing the name and parameter. |

Example:

```python
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
```

View Source

```python
def named_parameters(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Parameter]]:
    r"""Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself."""
    gen = self._named_members(
        lambda module: module._parameters.items(),
        prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
    yield from gen
```
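``named_parameters`` makes selective freezing straightforward, since decisions can key off the fully-qualified name. An illustrative sketch (not from this library):

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
for name, param in model.named_parameters():
    if name.startswith("0."):        # freeze the first layer only
        param.requires_grad_(False)
```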
parameters

```python
def parameters(self, recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]
```

Return an iterator over module parameters. This is typically passed to an optimizer.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| recurse | bool | If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. |

Yields:

| Type | Description |
| --- | --- |
| Parameter | module parameter |

Example:

```python
>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
```

View Source

```python
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
    r"""Return an iterator over module parameters."""
    for name, param in self.named_parameters(recurse=recurse):
        yield param
```
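As the description notes, ``parameters()`` exists mainly to feed an optimizer. A minimal sketch (illustrative, not from this library):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```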
register_backward_hook

```python
def register_backward_hook(
    self,
    hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[None, Tuple[torch.Tensor, ...], torch.Tensor]],
) -> torch.utils.hooks.RemovableHandle
```

Register a backward hook on the module. This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook`, and the behavior of this function will change in future versions.

Returns:

| Type | Description |
| --- | --- |
| :class:`torch.utils.hooks.RemovableHandle` | A handle that can be used to remove the added hook by calling ``handle.remove()``. |

View Source

```python
def register_backward_hook(
    self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]]
) -> RemovableHandle:
    r"""Register a backward hook on the module."""
    if self._is_full_backward_hook is True:
        raise RuntimeError(
            "Cannot use both regular backward hooks and full backward hooks on a "
            "single Module. Please use only one of them.")

    self._is_full_backward_hook = False

    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle
```

register_buffer

```python
def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None
```

Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting ``persistent`` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's ``state_dict``. Buffers can be accessed as attributes using given names.

Parameters:

| Name | Type | Description |
| --- | --- | --- |
| name | str | Name of the buffer. The buffer can be accessed from this module using the given name. |
| tensor | Tensor or None | Buffer to be registered. If ``None``, then operations that run on buffers, such as ``cuda``, are ignored, and the buffer is **not** included in the module's ``state_dict``. |
| persistent | bool | Whether the buffer is part of this module's ``state_dict``. |

Example:

```python
>>> # xdoctest: +SKIP("undefined vars")
>>> self.register_buffer('running_mean', torch.zeros(num_features))
```

View Source

```python
def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
    r"""Add a buffer to the module."""
    if persistent is False and isinstance(self, torch.jit.ScriptModule):
        raise RuntimeError("ScriptModule does not support non-persistent buffers")

    if '_buffers' not in self.__dict__:
        raise AttributeError("cannot assign buffer before Module.__init__() call")
    elif not isinstance(name, str):
        raise TypeError(f"buffer name should be a string. Got {torch.typename(name)}")
    elif '.' in name:
        raise KeyError("buffer name can't contain \".\"")
    elif name == '':
        raise KeyError("buffer name can't be empty string \"\"")
    elif hasattr(self, name) and name not in self._buffers:
        raise KeyError(f"attribute '{name}' already exists")
    elif tensor is not None and not isinstance(tensor, torch.Tensor):
        raise TypeError(f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' "
                        "(torch Tensor or None required)")
    else:
        for hook in _global_buffer_registration_hooks.values():
            output = hook(self, name, tensor)
            if output is not None:
                tensor = output
        self._buffers[name] = tensor
        if persistent:
            self._non_persistent_buffers_set.discard(name)
        else:
            self._non_persistent_buffers_set.add(name)
```
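A sketch distinguishing persistent and non-persistent buffers; the module and buffer names are illustrative:

```python
import torch
import torch.nn as nn

class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.ones(1))                     # saved in state_dict
        self.register_buffer("cache", torch.zeros(1), persistent=False)  # not saved

print(Scaler().state_dict().keys())   # odict_keys(['scale'])
```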
\\\" \" ) elif name == '' : raise KeyError ( \"buffer name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _buffers : raise KeyError ( f \"attribute '{name}' already exists\" ) elif tensor is not None and not isinstance ( tensor , torch . Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self . _buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name ) register_forward_hook def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward hook on the module. The hook will be called every time after :func: forward has computed an output. If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func: forward is called. The hook should have the following signature:: hook ( module , args , output ) -> None or modified output If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook ( module , args , kwargs , output ) -> None or modified output Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward hooks on this :class: torch.nn.modules.Module . Note that global forward hooks registered with :func: register_module_forward_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ] , Any ] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ] , Any ] , Optional [ Any ]] , ] , * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. 
It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature:: hook(module, args, output) -> None or modified output If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook(module, args, kwargs, output) -> None or modified output Args: hook (Callable): The user defined hook to be registered. prepend (bool): If ``True``, the provided ``hook`` will be fired before all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward`` hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` always_call (bool): If ``True`` the ``hook`` will be run regardless of whether an exception is raised while calling the Module. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_hooks , extra_dict = [ self . _forward_hooks_with_kwargs , self . _forward_hooks_always_called ] , ) self . _forward_hooks [ handle . id ] = hook if with_kwargs : self . _forward_hooks_with_kwargs [ handle . id ] = True if always_call : self . _forward_hooks_always_called [ handle . id ] = True if prepend : self . _forward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_forward_pre_hook def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward pre-hook on the module. The hook will be called every time before :func: forward is invoked. If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook ( module , args ) -> None or modified input If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook ( module , args , kwargs ) -> None or a tuple of modified input and kwargs Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward_pre hooks on this :class: torch.nn.modules.Module . 
Note that global forward_pre hooks registered with :func: register_module_forward_pre_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_pre_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ]] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ]] , Optional [ Tuple [ Any , Dict [ str , Any ]]]] , ] , * , prepend : bool = False , with_kwargs : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_full_backward_hook def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. 
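As a quick, hedged illustration before the formal signature below (all names here are made up): a full backward hook that records the norm of the gradient flowing out of a module:

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
grad_log = []

def log_grad_norm(module, grad_input, grad_output):
    # grad_output is a tuple of gradients w.r.t. the module's outputs;
    # returning None leaves the gradients unchanged.
    grad_log.append(grad_output[0].norm().item())

handle = layer.register_full_backward_hook(log_grad_norm)
layer(torch.randn(3, 4)).sum().backward()
handle.remove()
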
The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. 
Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . _is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_full_backward_pre_hook def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. 
.. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_load_state_dict_post_hook def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . id ] = hook return handle register_module def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . 
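A short sketch of registering children dynamically through this alias (module names are illustrative):

import torch.nn as nn

class Backbone(nn.Module):
    def __init__(self, depth: int):
        super().__init__()
        for i in range(depth):
            # Same effect as add_module: the child becomes an attribute
            # and is picked up by parameters(), children(), state_dict().
            self.register_module(f"block{i}", nn.Linear(8, 8))

net = Backbone(depth=3)
print([name for name, _ in net.named_children()])  # ['block0', 'block1', 'block2']

add_module behaves identically; register_module presumably exists for naming consistency with register_buffer and register_parameter.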
View Source def register_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \" \"\" Alias for :func:`add_module`. \"\" \" self . add_module ( name , module ) register_parameter def register_parameter ( self , name : str , param : Optional [ torch . nn . parameter . Parameter ] ) -> None Add a parameter to the module. The parameter can be accessed as an attribute using given name. Parameters: Name Type Description Default name str name of the parameter. The parameter can be accessed from this module using the given name None param Parameter or None parameter to be added to the module. If None , then operations that run on parameters, such as :attr: cuda , are ignored. If None , the parameter is not included in the module's :attr: state_dict . None View Source def register_parameter ( self , name : str , param : Optional [ Parameter ] ) -> None : r \" \"\" Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. \"\" \" if '_parameters' not in self . __dict__ : raise AttributeError ( \"cannot assign parameter before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"parameter name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"parameter name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"parameter name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _parameters : raise KeyError ( f \"attribute '{name}' already exists\" ) if param is None : self . _parameters [ name ] = None elif not isinstance ( param , Parameter ) : raise TypeError ( f \"cannot assign '{torch.typename(param)}' object to parameter '{name}' \" \"(torch.nn.Parameter or None required)\" ) elif param . grad_fn : raise ValueError ( f \"Cannot assign non-leaf Tensor to parameter '{name}'. Model \" f \"parameters must be created explicitly. To express '{name}' \" \"as a function of another Tensor, compute the value in \" \"the forward() method.\" ) else : for hook in _global_parameter_registration_hooks . values () : output = hook ( self , name , param ) if output is not None : param = output self . _parameters [ name ] = param register_state_dict_pre_hook def register_state_dict_pre_hook ( self , hook ) Register a pre-hook for the :meth: ~torch.nn.Module.state_dict method. These hooks will be called with arguments: self , prefix , and keep_vars before calling state_dict on self . The registered hooks can be used to perform pre-processing before the state_dict call is made. View Source def register_state_dict_pre_hook ( self , hook ) : r \" \"\" Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. \"\" \" handle = hooks . RemovableHandle ( self . _state_dict_pre_hooks ) self . _state_dict_pre_hooks [ handle . 
id ] = hook return handle requires_grad_ def requires_grad_ ( self : ~ T , requires_grad : bool = True ) -> ~ T Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr: requires_grad attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref: locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it. Parameters: Name Type Description Default requires_grad bool whether autograd should record operations on parameters in this module. Default: True . None Returns: Type Description Module self View Source def requires_grad_ ( self : T , requires_grad : bool = True ) -> T : r \" \"\" Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it. Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self \"\" \" for p in self . parameters () : p . requires_grad_ ( requires_grad ) return self set_extra_state def set_extra_state ( self , state : Any ) -> None Set extra state contained in the loaded state_dict . This function is called from :func: load_state_dict to handle any extra state found within the state_dict . Implement this function and a corresponding :func: get_extra_state for your module if you need to store extra state within its state_dict . View Source def set_extra_state ( self , state : Any ) -> None : \" \"\" Set extra state contained in the loaded `state_dict`. This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`. Args: state (dict): Extra state from the `state_dict` \"\" \" raise RuntimeError ( \"Reached a code path in Module.set_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" ) share_memory def share_memory ( self : ~ T ) -> ~ T See :meth: torch.Tensor.share_memory_ . View Source def share_memory ( self : T ) -> T : r \"\"\"See :meth:`torch.Tensor.share_memory_`.\"\"\" return self . _apply ( lambda t : t . share_memory_ ()) state_dict def state_dict ( self , * args , destination = None , prefix = '' , keep_vars = False ) Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently state_dict() also accepts positional arguments for destination , prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument destination as it is not designed for end-users.
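Before the parameter table, a minimal save/restore round trip using only keyword arguments (a hedged sketch; the checkpoint file name is arbitrary):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
torch.save(model.state_dict(), "checkpoint.pt")  # keys like '0.weight', '2.bias'

restored = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
restored.load_state_dict(torch.load("checkpoint.pt"))
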
Parameters: Name Type Description Default destination dict If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None . None prefix str a prefix added to parameter and buffer names to compose the keys in state_dict. Default: '' . None keep_vars bool by default the :class: ~torch.Tensor s returned in the state dict are detached from autograd. If it's set to True , detaching will not be performed. Default: False . None Returns: Type Description dict a dictionary containing a whole state of the module View Source def state_dict ( self , * args , destination = None , prefix='' , keep_vars = False ) : r \"\"\"Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars`` in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument ``destination`` as it is not designed for end-users. Args: destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``. prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. Returns: dict: a dictionary containing a whole state of the module Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> module.state_dict().keys() ['bias', 'weight'] \"\"\" # TODO : Remove ` args ` and the parsing logic when BC allows . if len ( args ) > 0 : if destination is None : destination = args [ 0 ] if len ( args ) > 1 and prefix == '' : prefix = args [ 1 ] if len ( args ) > 2 and keep_vars is False : keep_vars = args [ 2 ] # DeprecationWarning is ignored by default warnings . warn ( \"Positional args are being deprecated, use kwargs instead. Refer to \" \"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict\" \" for details.\" ) if destination is None : destination = OrderedDict () destination . _ metadata = OrderedDict () local_metadata = dict ( version = self . _ version ) if hasattr ( destination , \"_metadata\" ) : destination . _ metadata [ prefix [ :- 1 ]] = local_metadata for hook in self . _ state_dict_pre_hooks . val ues () : hook ( self , prefix , keep_vars ) self . _ save_to_state_dict ( destination , prefix , keep_vars ) for name , module in self . _ modules . items () : if module is not None : module . state_dict ( destination = destination , prefix = prefix + name + '.' , keep_vars = keep_vars ) for hook in self . _ state_dict_hooks . val ues () : hook_result = hook ( self , destination , prefix , local_metadata ) if hook_result is not None : destination = hook_result return destination to def to ( self , * args , ** kwargs ) Move and/or cast the parameters and buffers. This can be called as .. 
function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth: torch.Tensor.to , but only accepts floating point or complex :attr: dtype \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr: dtype (if given). The integral parameters and buffers will be moved :attr: device , if that is given, but with dtypes unchanged. When :attr: non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device ( None class: torch.device ): the desired device of the parameters and buffers in this module None dtype ( None class: torch.dtype ): the desired floating point or complex dtype of the parameters and buffers in this module None tensor torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module None memory_format ( None class: torch.memory_format ): the desired memory format for 4D parameters and buffers in this module (keyword only argument) None Returns: Type Description Module self View Source def to ( self , * args , ** kwargs ) : r \" \"\" Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype` \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. 
Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self Examples:: >>> # xdoctest: +IGNORE_WANT(\" non - deterministic \") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device(\" cuda : 1 \") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device(\" cpu \") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) \"\" \" device , dtype , non_blocking , convert_to_format = torch . _C . _nn . _parse_to ( * args , ** kwargs ) if dtype is not None : if not ( dtype . is_floating_point or dtype . is_complex ) : raise TypeError ( 'nn.Module.to only accepts floating point or complex ' f 'dtypes, but got desired dtype={dtype}' ) if dtype . is_complex : warnings . warn ( \"Complex modules are a new feature under active development whose design may change, \" \"and some modules might not work as expected when using complex tensors as parameters or buffers. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"if a complex module does not work as expected.\" ) def convert ( t ) : try : if convert_to_format is not None and t . dim () in ( 4 , 5 ) : return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , memory_format = convert_to_format , ) return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , ) except NotImplementedError as e : if str ( e ) == \"Cannot copy out of meta tensor; no data!\" : raise NotImplementedError ( f \"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() \" f \"when moving module from meta to a different device.\" ) from None else : raise return self . _apply ( convert ) to_empty def to_empty ( self : ~ T , * , device : Union [ int , str , torch . device , NoneType ], recurse : bool = True ) -> ~ T Move the parameters and buffers to the specified device without copying storage. Parameters: Name Type Description Default device ( None class: torch.device ): The desired device of the parameters and buffers in this module. 
None recurse bool Whether parameters and buffers of submodules should be recursively moved to the specified device. None Returns: Type Description Module self View Source def to_empty ( self : T , * , device : Optional [ DeviceLikeType ] , recurse : bool = True ) -> T : r \"\"\"Move the parameters and buffers to the specified device without copying storage. Args: device (:class:`torch.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. Returns: Module: self \"\"\" return self . _apply ( lambda t : torch . empty_like ( t , device = device ), recurse = recurse ) train def train ( self : ~ T , mode : bool = True ) -> ~ T Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. Parameters: Name Type Description Default mode bool whether to set training mode ( True ) or evaluation mode ( False ). Default: True . None Returns: Type Description Module self View Source def train ( self : T , mode : bool = True ) -> T : r \" \"\" Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self \"\" \" if not isinstance ( mode , bool ) : raise ValueError ( \"training mode is expected to be boolean\" ) self . training = mode for module in self . children () : module . train ( mode ) return self type def type ( self : ~ T , dst_type : Union [ torch . dtype , str ] ) -> ~ T Casts all parameters and buffers to :attr: dst_type . .. note:: This method modifies the module in-place. Parameters: Name Type Description Default dst_type type or string the desired type None Returns: Type Description Module self View Source def type ( self : T , dst_type : Union [ dtype , str ] ) -> T : r \" \"\" Casts all parameters and buffers to :attr:`dst_type`. .. note:: This method modifies the module in-place. Args: dst_type (type or string): the desired type Returns: Module: self \"\" \" return self . _apply ( lambda t : t . type ( dst_type )) xpu def xpu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def xpu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . 
xpu ( device )) zero_grad def zero_grad ( self , set_to_none : bool = True ) -> None Reset gradients of all model parameters. See similar function under :class: torch.optim.Optimizer for more context. Parameters: Name Type Description Default set_to_none bool instead of setting to zero, set the grads to None. See :meth: torch.optim.Optimizer.zero_grad for details. None View Source def zero_grad ( self , set_to_none : bool = True ) -> None : r \"\"\"Reset gradients of all model parameters. See similar function under :class:`torch.optim.Optimizer` for more context. Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. \"\"\" if getattr ( self , '_is_replica' , False ) : warnings . warn ( \"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. \" \"The parameters are copied (in a differentiable manner) from the original module. \" \"This means they are not leaf nodes in autograd and so don' t accumulate gradients . \" \" If you need gradients in your forward method , consider using autograd . grad instead . \") for p in self.parameters(): if p.grad is not None: if set_to_none: p.grad = None else: if p.grad.grad_fn is not None: p.grad.detach_() else: p.grad.requires_grad_(False) p.grad.zero_() WormPredictor class WormPredictor ( model : torch . nn . modules . module . Module , io_config : wtracker . neural . config . IOConfig ) A class that represents neural network models that predict worm behavior. After a model is created from several layers or blocks, it is wrapped in this class so that it can be distinguished from other models that don't predict worm behavior (for example the layers/blocks that make this model). This class also holds the IOConfig object that is used to determine the input and output shapes of the model, and the specific frames it expects as input and output. Attributes Name Type Description Default model None The neural network model that predicts worm behavior. None io_config None The IOConfig object of the model. None View Source class WormPredictor ( nn . Module ): \"\"\" A class that represents neural network models that predict worm behavior. After a model is created from several layers or blocks, it is wrapped in this class so that it can be distinguished from other models that don't predict worm behavior (for example the layers/blocks that make this model). This class also holds the IOConfig object that is used to determine the input and output shapes of the model, and the specific frames it expects as input and output. Attributes: model: The neural network model that predicts worm behavior. io_config: The IOConfig object of the model. \"\"\" def __init__ ( self , model: nn . Module , io_config: IOConfig ): super (). __init__ () self . io_config: IOConfig = io_config self . model: nn . Module = model def forward ( self , x : Tensor ) -> Tensor: return self . model ( x ) Ancestors (in MRO) torch.nn.modules.module.Module Class variables T_destination call_super_init dump_patches Methods add_module def add_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Add a child module to the current module. The module can be accessed as an attribute using the given name. Parameters: Name Type Description Default name str name of the child module. The child module can be accessed from this module using the given name None module Module child module to be added to the module. 
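Stepping back to this library's own class: a hedged sketch of wrapping a plain network in WormPredictor (the WormPredictor import path and the construction of IOConfig are assumptions; only the constructor signature shown above is taken from the docs):

import torch.nn as nn
from wtracker.neural.config import IOConfig

def build_predictor(net: nn.Module, io_config: IOConfig):
    # WormPredictor stores the net and its IOConfig; its forward() simply
    # delegates to the wrapped model, so predictor(x) == net(x).
    from wtracker.neural.mlp import WormPredictor  # import path is an assumption
    return WormPredictor(model=net, io_config=io_config)
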
None View Source def add_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \"\"\"Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (str): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module. \"\"\" if not isinstance ( module , Module ) and module is not None : raise TypeError ( f \"{torch.typename(module)} is not a Module subclass\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"module name should be a string. Got {torch.typename(name)}\" ) elif hasattr ( self , name ) and name not in self . _modules : raise KeyError ( f \"attribute '{name}' already exists\" ) elif '.' in name : raise KeyError ( f \"module name can't contain \\\" . \\ \", got: {name}\" ) elif name == '' : raise KeyError ( \"module name can't be empty string \\\" \\ \"\" ) for hook in _global_module_registration_hooks . values () : output = hook ( self , name , module ) if output is not None : module = output self . _modules [ name ] = module apply def apply ( self : ~ T , fn : Callable [[ ForwardRef ( 'Module' )], NoneType ] ) -> ~ T Apply fn recursively to every submodule (as returned by .children() ) as well as self. Typical use includes initializing the parameters of a model (see also :ref: nn-init-doc ). Parameters: Name Type Description Default fn ( None class: Module -> None): function to be applied to each submodule None Returns: Type Description Module self View Source def apply ( self : T , fn : Callable [[ 'Module' ] , None ] ) -> T : r \" \"\" Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example:: >>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) \"\" \" for module in self . children () : module . apply ( fn ) fn ( self ) return self bfloat16 def bfloat16 ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to bfloat16 datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def bfloat16 ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``bfloat16`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . bfloat16 () if t . is_floating_point () else t ) buffers def buffers ( self , recurse : bool = True ) -> Iterator [ torch . Tensor ] Return an iterator over module buffers. Parameters: Name Type Description Default recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. 
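For example, a small sketch that lists the buffers of a model containing BatchNorm (illustrative names):

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
for name, buf in model.named_buffers():
    # prints '1.running_mean', '1.running_var', '1.num_batches_tracked'
    print(name, tuple(buf.shape))
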
None Yields: Type Description torch.Tensor module buffer View Source def buffers ( self , recurse : bool = True ) -> Iterator [ Tensor ] : r \"\"\"Return an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: torch.Tensor: module buffer Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for buf in model.buffers(): >>> print(type(buf), buf.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for _ , buf in self . named_buffers ( recurse = recurse ) : yield buf children def children ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over immediate children modules. Yields: Type Description Module a child module View Source def children ( self ) -> Iterator [ 'Module']: r \"\"\"Return an iterator over immediate children modules. Yields: Module: a child module \"\"\" for name , module in self . named_children (): yield module compile def compile ( self , * args , ** kwargs ) Compile this Module's forward using :func: torch.compile . This Module's __call__ method is compiled and all arguments are passed as-is to :func: torch.compile . See :func: torch.compile for details on the arguments for this function. View Source def compile ( self , * args , ** kwargs ) : \" \"\" Compile this Module's forward using :func:`torch.compile`. This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`. See :func:`torch.compile` for details on the arguments for this function. \"\" \" self . _compiled_call_impl = torch . compile ( self . _call_impl , * args , ** kwargs ) cpu def cpu ( self : ~ T ) -> ~ T Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def cpu ( self : T ) -> T : r \"\"\"Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cpu ()) cuda def cuda ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def cuda ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cuda ( device )) double def double ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to double datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def double ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``double`` datatype. .. note:: This method modifies the module in-place. 
Returns: Module: self \"\" \" return self . _apply ( lambda t : t . double () if t . is_floating_point () else t ) eval def eval ( self : ~ T ) -> ~ T Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. This is equivalent with :meth: self.train(False) . See :ref: locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it. Returns: Type Description Module self View Source def eval ( self : T ) -> T : r \" \"\" Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent with :meth:`self.train(False) `. See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it. Returns: Module: self \"\" \" return self . train ( False ) extra_repr def extra_repr ( self ) -> str Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. View Source def extra_repr ( self ) -> str : r \"\"\"Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. \"\"\" return '' float def float ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to float datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def float ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``float`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . float () if t . is_floating_point () else t ) forward def forward ( self , x : torch . Tensor ) -> torch . Tensor Define the computation performed at every call. Should be overridden by all subclasses. .. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class: Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them. View Source def forward ( self , x : Tensor ) -> Tensor : return self . model ( x ) get_buffer def get_buffer ( self , target : str ) -> 'Tensor' Return the buffer given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.Tensor The buffer referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not a buffer View Source def get_buffer ( self , target : str ) -> \"Tensor\" : \" \"\" Return the buffer given by ``target`` if it exists, otherwise throw an error. 
See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the buffer to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.Tensor: The buffer referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not a buffer \"\" \" module_path , _ , buffer_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , buffer_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + buffer_name + \"`\" ) buffer : torch . Tensor = getattr ( mod , buffer_name ) if buffer_name not in mod . _buffers : raise AttributeError ( \"`\" + buffer_name + \"` is not a buffer\" ) return buffer get_extra_state def get_extra_state ( self ) -> Any Return any extra state to include in the module's state_dict. Implement this and a corresponding :func: set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict() . Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: Type Description object Any extra state to store in the module's state_dict View Source def get_extra_state ( self ) -> Any : \" \"\" Return any extra state to include in the module's state_dict. Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`. Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: object: Any extra state to store in the module's state_dict \"\" \" raise RuntimeError ( \"Reached a code path in Module.get_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" ) get_parameter def get_parameter ( self , target : str ) -> 'Parameter' Return the parameter given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Parameter The Parameter referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Parameter View Source def get_parameter ( self , target : str ) -> \"Parameter\" : \" \"\" Return the parameter given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the Parameter to look for.
(See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.nn.Parameter: The Parameter referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Parameter`` \"\" \" module_path , _ , param_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , param_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + param_name + \"`\" ) param : torch . nn . Parameter = getattr ( mod , param_name ) if not isinstance ( param , torch . nn . Parameter ) : raise AttributeError ( \"`\" + param_name + \"` is not an \" \"nn.Parameter\" ) return param get_submodule def get_submodule ( self , target : str ) -> 'Module' Return the submodule given by target if it exists, otherwise throw an error. For example, let's say you have an nn.Module A that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an nn.Module A . A has a nested submodule net_b , which itself has two submodules net_c and linear . net_c then has a submodule conv .) To check whether or not we have the linear submodule, we would call get_submodule(\"net_b.linear\") . To check whether we have the conv submodule, we would call get_submodule(\"net_b.net_c.conv\") . The runtime of get_submodule is bounded by the degree of module nesting in target . A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used. Parameters: Name Type Description Default target None The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Module The submodule referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Module View Source def get_submodule ( self , target : str ) -> \"Module\" : \" \"\" Return the submodule given by ``target`` if it exists, otherwise throw an error. For example, let's say you have an ``nn.Module`` ``A`` that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.) To check whether or not we have the ``linear`` submodule, we would call ``get_submodule(\" net_b . linear \")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule(\" net_b . net_c . conv \")``. The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used. Args: target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) 
Returns: torch.nn.Module: The submodule referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Module`` \"\" \" if target == \"\" : return self atoms : List [ str ] = target . split ( \".\" ) mod : torch . nn . Module = self for item in atoms : if not hasattr ( mod , item ) : raise AttributeError ( mod . _get_name () + \" has no \" \"attribute `\" + item + \"`\" ) mod = getattr ( mod , item ) if not isinstance ( mod , torch . nn . Module ) : raise AttributeError ( \"`\" + item + \"` is not \" \"an nn.Module\" ) return mod half def half ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to half datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def half ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``half`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . half () if t . is_floating_point () else t ) ipu def ipu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def ipu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . ipu ( device )) load_state_dict def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) Copy parameters and buffers from :attr: state_dict into this module and its descendants. If :attr: strict is True , then the keys of :attr: state_dict must exactly match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. .. warning:: If :attr: assign is True the optimizer must be created after the call to :attr: load_state_dict unless :func: ~torch.__future__.get_swap_module_params_on_conversion is True . Parameters: Name Type Description Default state_dict dict a dict containing parameters and persistent buffers. None strict bool whether to strictly enforce that the keys in :attr: state_dict match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. Default: True None assign bool When False , the properties of the tensors in the current module are preserved while when True , the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of :class: ~torch.nn.Parameter s for which the value from the module is preserved. 
Default: False None Returns: Type Description None NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys unexpected_keys is a list of str containing the unexpected keys View Source def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) : r \"\"\"Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. \"\"\" if not isinstance ( state_dict , Mapping ) : raise TypeError ( f \"Expected state_dict to be dict-like, got {type(state_dict)}.\" ) missing_keys : List [ str ] = [] unexpected_keys : List [ str ] = [] error_msgs : List [ str ] = [] # copy state_dict so _ load_from_state_dict can modify it metadata = getattr ( state_dict , '_metadata' , None ) state_dict = OrderedDict ( state_dict ) if metadata is not None : # mypy isn't aware that \"_metadata\" exists in state_dict state_dict._metadata = metadata # type: ignore[attr-defined] def load(module, local_state_dict, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) if assign: local_metadata['assign_to_params_buffers'] = assign module._load_from_state_dict( local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name + ' . ' child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} load(child, child_state_dict, child_prefix) # noqa: F821 # Note that the hook can modify missing_keys and unexpected_keys. incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) for hook in module._load_state_dict_post_hooks.values(): out = hook(module, incompatible_keys) assert out is None, ( \"Hooks registered with ``register_load_state_dict_post_hook`` are not\" \"expected to return new values, if incompatible_keys need to be modified,\" \"it should be done inplace.\" ) load(self, state_dict) del load if strict: if len(unexpected_keys) > 0: error_msgs.insert( 0, ' Unexpected key ( s ) in state_dict : {}. 
'.format( ' , '.join(f' \"{k}\" ' for k in unexpected_keys))) if len(missing_keys) > 0: error_msgs.insert( 0, ' Missing key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in missing_keys))) if len(error_msgs) > 0: raise RuntimeError(' Error ( s ) in loading state_dict for {} : \\n\\t {} ' . format ( self . __ class__ . __ name__ , \"\\n\\t\" . join ( error_msgs ))) return _ IncompatibleKeys ( missing_keys , unexpected_keys ) modules def modules ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over all modules in the network. Yields: Type Description Module a module in the network View Source def modules ( self ) -> Iterator [ 'Module' ] : r \" \"\" Return an iterator over all modules in the network. Yields: Module: a module in the network Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True) \"\" \" for _ , module in self . named_modules () : yield module named_buffers def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . Tensor ]] Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Parameters: Name Type Description Default prefix str prefix to prepend to all buffer names. None recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. None remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True Yields: Type Description None (str, torch.Tensor): Tuple containing the name and buffer View Source def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Tensor ]]: r \"\"\"Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Args: prefix (str): prefix to prepend to all buffer names. recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. Yields: (str, torch.Tensor): Tuple containing the name and buffer Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size()) \"\"\" gen = self . _named_members ( lambda module : module . _buffers . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen named_children def named_children ( self ) -> Iterator [ Tuple [ str , ForwardRef ( 'Module' )]] Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: Type Description None (str, Module): Tuple containing a name and child module View Source def named_children ( self ) -> Iterator [ Tuple [ str , 'Module' ]]: r \"\"\"Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. 
Yields: (str, Module): Tuple containing a name and child module Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module) \"\"\" memo = set () for name , module in self . _modules . items (): if module is not None and module not in memo : memo . add ( module ) yield name , module named_modules def named_modules ( self , memo : Optional [ Set [ ForwardRef ( 'Module' )]] = None , prefix : str = '' , remove_duplicate : bool = True ) Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Parameters: Name Type Description Default memo None a memo to store the set of modules already added to the result None prefix None a prefix that will be added to the name of the module None remove_duplicate None whether to remove the duplicated module instances in the result or not None Yields: Type Description None (str, Module): Tuple of name and module View Source def named_modules ( self , memo : Optional [ Set [ 'Module' ]] = None , prefix : str = '' , remove_duplicate : bool = True ) : r \" \"\" Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: memo: a memo to store the set of modules already added to the result prefix: a prefix that will be added to the name of the module remove_duplicate: whether to remove the duplicated module instances in the result or not Yields: (str, Module): Tuple of name and module Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) \"\" \" if memo is None : memo = set () if self not in memo : if remove_duplicate : memo . add ( self ) yield prefix , self for name , module in self . _modules . items () : if module is None : continue submodule_prefix = prefix + ( '.' if prefix else '' ) + name yield from module . named_modules ( memo , submodule_prefix , remove_duplicate ) named_parameters def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . nn . parameter . Parameter ]] Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Parameters: Name Type Description Default prefix str prefix to prepend to all parameter names. None recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. None Yields: Type Description None (str, Parameter): Tuple containing the name and parameter View Source def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Parameter ]]: r \"\"\"Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. 
Otherwise, yields only parameters that are direct members of this module. remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True. Yields: (str, Parameter): Tuple containing the name and parameter Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) \"\"\" gen = self . _named_members ( lambda module : module . _parameters . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen parameters def parameters ( self , recurse : bool = True ) -> Iterator [ torch . nn . parameter . Parameter ] Return an iterator over module parameters. This is typically passed to an optimizer. Parameters: Name Type Description Default recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None Yields: Type Description Parameter module parameter View Source def parameters ( self , recurse : bool = True ) -> Iterator [ Parameter ] : r \"\"\"Return an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Yields: Parameter: module parameter Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for param in model.parameters(): >>> print(type(param), param.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for name , param in self . named_parameters ( recurse = recurse ) : yield param register_backward_hook def register_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]] ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. This function is deprecated in favor of :meth: ~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_backward_hook ( self , hook: Callable [[' Module ', _grad_t , _grad_t ], Union [ None , _grad_t ]] ) -> RemovableHandle: r \"\"\"Register a backward hook on the module. This function is deprecated in favor of : meth: ` ~ torch . nn . Module . register_full_backward_hook ` and the behavior of this function will change in future versions . Returns: : class: `torch . utils . hooks . RemovableHandle ` : a handle that can be used to remove the added hook by calling ` `handle . remove () `` \"\"\" if self . _is_full_backward_hook is True: raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = False handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook return handle register_buffer def register_buffer ( self , name : str , tensor : Optional [ torch . Tensor ], persistent : bool = True ) -> None Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state.
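Editor's note: an illustrative sketch of buffer registration and persistence (the ``RunningStat`` module is hypothetical):: >>> import torch >>> from torch import nn >>> class RunningStat(nn.Module): ... def __init__(self): ... super().__init__() ... self.register_buffer('running_mean', torch.zeros(4)) ... self.register_buffer('scratch', torch.zeros(4), persistent=False) >>> m = RunningStat() >>> list(m.state_dict().keys()) # non-persistent buffers are excluded from state_dict ['running_mean'] >>> sorted(dict(m.named_buffers()).keys()) # but both remain buffers ['running_mean', 'scratch']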
Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr: persistent to False . The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr: state_dict . Buffers can be accessed as attributes using given names. Parameters: Name Type Description Default name str name of the buffer. The buffer can be accessed from this module using the given name None tensor Tensor or None buffer to be registered. If None , then operations that run on buffers, such as :attr: cuda , are ignored. If None , the buffer is not included in the module's :attr: state_dict . None persistent bool whether the buffer is part of this module's :attr: state_dict . None View Source def register_buffer ( self , name : str , tensor : Optional [ Tensor ] , persistent : bool = True ) -> None : r \" \"\" Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`. Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> self.register_buffer('running_mean', torch.zeros(num_features)) \"\" \" if persistent is False and isinstance ( self , torch . jit . ScriptModule ) : raise RuntimeError ( \"ScriptModule does not support non-persistent buffers\" ) if '_buffers' not in self . __dict__ : raise AttributeError ( \"cannot assign buffer before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"buffer name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"buffer name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"buffer name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _buffers : raise KeyError ( f \"attribute '{name}' already exists\" ) elif tensor is not None and not isinstance ( tensor , torch . Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self . _buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name ) register_forward_hook def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils .
hooks . RemovableHandle Register a forward hook on the module. The hook will be called every time after :func: forward has computed an output. If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func: forward is called. The hook should have the following signature:: hook ( module , args , output ) -> None or modified output If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook ( module , args , kwargs , output ) -> None or modified output Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward hooks on this :class: torch.nn.modules.Module . Note that global forward hooks registered with :func: register_module_forward_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ] , Any ] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ] , Any ] , Optional [ Any ]] , ] , * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature:: hook(module, args, output) -> None or modified output If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook(module, args, kwargs, output) -> None or modified output Args: hook (Callable): The user defined hook to be registered. prepend (bool): If ``True``, the provided ``hook`` will be fired before all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward`` hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs given to the forward function. 
Default: ``False`` always_call (bool): If ``True`` the ``hook`` will be run regardless of whether an exception is raised while calling the Module. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_hooks , extra_dict = [ self . _forward_hooks_with_kwargs , self . _forward_hooks_always_called ] , ) self . _forward_hooks [ handle . id ] = hook if with_kwargs : self . _forward_hooks_with_kwargs [ handle . id ] = True if always_call : self . _forward_hooks_always_called [ handle . id ] = True if prepend : self . _forward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_forward_pre_hook def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward pre-hook on the module. The hook will be called every time before :func: forward is invoked. If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook ( module , args ) -> None or modified input If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook ( module , args , kwargs ) -> None or a tuple of modified input and kwargs Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward_pre hooks on this :class: torch.nn.modules.Module . Note that global forward_pre hooks registered with :func: register_module_forward_pre_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_pre_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ]] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ]] , Optional [ Tuple [ Any , Dict [ str , Any ]]]] , ] , * , prepend : bool = False , with_kwargs : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. 
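Editor's note: a minimal sketch of the forward pre-hook and forward hook firing order documented above (hypothetical layer; hooks that return ``None`` leave the inputs and outputs unchanged):: >>> import torch >>> from torch import nn >>> layer = nn.Linear(2, 2) >>> calls = [] >>> pre = layer.register_forward_pre_hook(lambda mod, args: calls.append('pre')) >>> post = layer.register_forward_hook(lambda mod, args, out: calls.append('post')) >>> _ = layer(torch.randn(1, 2)) >>> calls ['pre', 'post'] >>> pre.remove(); post.remove()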
User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_full_backward_hook def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. 
None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . _is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_full_backward_pre_hook def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. 
The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle register_load_state_dict_post_hook def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. 
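Editor's note: an illustrative sketch of ``register_full_backward_pre_hook`` from the section above, scaling upstream gradients (hypothetical layer; the returned tuple replaces ``grad_output``):: >>> import torch >>> from torch import nn >>> layer = nn.Linear(2, 2) >>> h = layer.register_full_backward_pre_hook(lambda mod, grad_output: tuple(g * 0.5 for g in grad_output)) >>> layer(torch.ones(1, 2)).sum().backward() >>> layer.bias.grad # the bias gradient reflects the halved upstream gradient tensor([0.5000, 0.5000]) >>> h.remove()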
It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . id ] = hook return handle register_module def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . View Source def register_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \" \"\" Alias for :func:`add_module`. \"\" \" self . add_module ( name , module ) register_parameter def register_parameter ( self , name : str , param : Optional [ torch . nn . parameter . Parameter ] ) -> None Add a parameter to the module. The parameter can be accessed as an attribute using given name. Parameters: Name Type Description Default name str name of the parameter. The parameter can be accessed from this module using the given name None param Parameter or None parameter to be added to the module. If None , then operations that run on parameters, such as :attr: cuda , are ignored. If None , the parameter is not included in the module's :attr: state_dict . None View Source def register_parameter ( self , name : str , param : Optional [ Parameter ] ) -> None : r \" \"\" Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. 
If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. \"\" \" if '_parameters' not in self . __dict__ : raise AttributeError ( \"cannot assign parameter before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"parameter name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"parameter name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"parameter name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _parameters : raise KeyError ( f \"attribute '{name}' already exists\" ) if param is None : self . _parameters [ name ] = None elif not isinstance ( param , Parameter ) : raise TypeError ( f \"cannot assign '{torch.typename(param)}' object to parameter '{name}' \" \"(torch.nn.Parameter or None required)\" ) elif param . grad_fn : raise ValueError ( f \"Cannot assign non-leaf Tensor to parameter '{name}'. Model \" f \"parameters must be created explicitly. To express '{name}' \" \"as a function of another Tensor, compute the value in \" \"the forward() method.\" ) else : for hook in _global_parameter_registration_hooks . values () : output = hook ( self , name , param ) if output is not None : param = output self . _parameters [ name ] = param register_state_dict_pre_hook def register_state_dict_pre_hook ( self , hook ) Register a pre-hook for the :meth: ~torch.nn.Module.state_dict method. These hooks will be called with arguments: self , prefix , and keep_vars before calling state_dict on self . The registered hooks can be used to perform pre-processing before the state_dict call is made. View Source def register_state_dict_pre_hook ( self , hook ) : r \" \"\" Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. \"\" \" handle = hooks . RemovableHandle ( self . _state_dict_pre_hooks ) self . _state_dict_pre_hooks [ handle . id ] = hook return handle requires_grad_ def requires_grad_ ( self : ~ T , requires_grad : bool = True ) -> ~ T Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr: requires_grad attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref: locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it. Parameters: Name Type Description Default requires_grad bool whether autograd should record operations on parameters in this module. Default: True . None Returns: Type Description Module self View Source def requires_grad_ ( self : T , requires_grad : bool = True ) -> T : r \" \"\" Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it. 
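Editor's note: a minimal freezing sketch with ``requires_grad_`` (hypothetical backbone/head split; only the head is optimized):: >>> import torch >>> from torch import nn >>> backbone , head = nn.Linear(8, 8), nn.Linear(8, 2) >>> _ = backbone.requires_grad_(False) # freeze the backbone for fine-tuning >>> all(not p.requires_grad for p in backbone.parameters()) True >>> optimizer = torch.optim.SGD(head.parameters(), lr=0.1)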
Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self \"\" \" for p in self . parameters () : p . requires_grad_ ( requires_grad ) return self set_extra_state def set_extra_state ( self , state : Any ) -> None Set extra state contained in the loaded state_dict . This function is called from :func: load_state_dict to handle any extra state found within the state_dict . Implement this function and a corresponding :func: get_extra_state for your module if you need to store extra state within its state_dict . View Source def set_extra_state ( self , state : Any ) -> None : \" \"\" Set extra state contained in the loaded `state_dict`. This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`. Args: state (dict): Extra state from the `state_dict` \"\" \" raise RuntimeError ( \"Reached a code path in Module.set_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" ) share_memory def share_memory ( self : ~ T ) -> ~ T See :meth: torch.Tensor.share_memory_ . View Source def share_memory ( self : T ) -> T : r \"\"\"See :meth:`torch.Tensor.share_memory_`.\"\"\" return self . _apply ( lambda t : t . share_memory_ ()) state_dict def state_dict ( self , * args , destination = None , prefix = '' , keep_vars = False ) Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently state_dict() also accepts positional arguments for destination , prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument destination as it is not designed for end-users. Parameters: Name Type Description Default destination dict If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None . None prefix str a prefix added to parameter and buffer names to compose the keys in state_dict. Default: '' . None keep_vars bool by default the :class: ~torch.Tensor s returned in the state dict are detached from autograd. If it's set to True , detaching will not be performed. Default: False . None Returns: Type Description dict a dictionary containing a whole state of the module View Source def state_dict ( self , * args , destination = None , prefix='' , keep_vars = False ) : r \"\"\"Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars`` in order. However, this is being deprecated and keyword arguments will be enforced in future releases. ..
warning:: Please avoid the use of argument ``destination`` as it is not designed for end-users. Args: destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``. prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. Returns: dict: a dictionary containing a whole state of the module Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> module.state_dict().keys() ['bias', 'weight'] \"\"\" # TODO : Remove ` args ` and the parsing logic when BC allows . if len ( args ) > 0 : if destination is None : destination = args [ 0 ] if len ( args ) > 1 and prefix == '' : prefix = args [ 1 ] if len ( args ) > 2 and keep_vars is False : keep_vars = args [ 2 ] # DeprecationWarning is ignored by default warnings . warn ( \"Positional args are being deprecated, use kwargs instead. Refer to \" \"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict\" \" for details.\" ) if destination is None : destination = OrderedDict () destination . _ metadata = OrderedDict () local_metadata = dict ( version = self . _ version ) if hasattr ( destination , \"_metadata\" ) : destination . _ metadata [ prefix [ :- 1 ]] = local_metadata for hook in self . _ state_dict_pre_hooks . val ues () : hook ( self , prefix , keep_vars ) self . _ save_to_state_dict ( destination , prefix , keep_vars ) for name , module in self . _ modules . items () : if module is not None : module . state_dict ( destination = destination , prefix = prefix + name + '.' , keep_vars = keep_vars ) for hook in self . _ state_dict_hooks . val ues () : hook_result = hook ( self , destination , prefix , local_metadata ) if hook_result is not None : destination = hook_result return destination to def to ( self , * args , ** kwargs ) Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth: torch.Tensor.to , but only accepts floating point or complex :attr: dtype \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr: dtype (if given). The integral parameters and buffers will be moved :attr: device , if that is given, but with dtypes unchanged. When :attr: non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. 
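Editor's note: a short illustrative sketch of the ``state_dict`` keys described above, together with an in-place dtype cast via ``to`` (hypothetical module):: >>> import torch >>> from torch import nn >>> module = nn.Linear(2, 2) >>> list(module.state_dict().keys()) ['weight', 'bias'] >>> module.to(torch.float64).weight.dtype # ``to`` modifies in-place and returns self torch.float64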
Parameters: Name Type Description Default device ( None class: torch.device ): the desired device of the parameters and buffers in this module None dtype ( None class: torch.dtype ): the desired floating point or complex dtype of the parameters and buffers in this module None tensor torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module None memory_format ( None class: torch.memory_format ): the desired memory format for 4D parameters and buffers in this module (keyword only argument) None Returns: Type Description Module self View Source def to ( self , * args , ** kwargs ) : r \" \"\" Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype` \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self Examples:: >>> # xdoctest: +IGNORE_WANT(\" non - deterministic \") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device(\" cuda : 1 \") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device(\" cpu \") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) \"\" \" device , dtype , non_blocking , convert_to_format = torch . _C . _nn . _parse_to ( * args , ** kwargs ) if dtype is not None : if not ( dtype . is_floating_point or dtype . 
is_complex ) : raise TypeError ( 'nn.Module.to only accepts floating point or complex ' f 'dtypes, but got desired dtype={dtype}' ) if dtype . is_complex : warnings . warn ( \"Complex modules are a new feature under active development whose design may change, \" \"and some modules might not work as expected when using complex tensors as parameters or buffers. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"if a complex module does not work as expected.\" ) def convert ( t ) : try : if convert_to_format is not None and t . dim () in ( 4 , 5 ) : return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , memory_format = convert_to_format , ) return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , ) except NotImplementedError as e : if str ( e ) == \"Cannot copy out of meta tensor; no data!\" : raise NotImplementedError ( f \"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() \" f \"when moving module from meta to a different device.\" ) from None else : raise return self . _apply ( convert ) to_empty def to_empty ( self : ~ T , * , device : Union [ int , str , torch . device , NoneType ], recurse : bool = True ) -> ~ T Move the parameters and buffers to the specified device without copying storage. Parameters: Name Type Description Default device ( None class: torch.device ): The desired device of the parameters and buffers in this module. None recurse bool Whether parameters and buffers of submodules should be recursively moved to the specified device. None Returns: Type Description Module self View Source def to_empty ( self : T , * , device : Optional [ DeviceLikeType ] , recurse : bool = True ) -> T : r \"\"\"Move the parameters and buffers to the specified device without copying storage. Args: device (:class:`torch.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. Returns: Module: self \"\"\" return self . _apply ( lambda t : torch . empty_like ( t , device = device ), recurse = recurse ) train def train ( self : ~ T , mode : bool = True ) -> ~ T Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. Parameters: Name Type Description Default mode bool whether to set training mode ( True ) or evaluation mode ( False ). Default: True . None Returns: Type Description Module self View Source def train ( self : T , mode : bool = True ) -> T : r \" \"\" Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self \"\" \" if not isinstance ( mode , bool ) : raise ValueError ( \"training mode is expected to be boolean\" ) self . training = mode for module in self . children () : module . train ( mode ) return self type def type ( self : ~ T , dst_type : Union [ torch . dtype , str ] ) -> ~ T Casts all parameters and buffers to :attr: dst_type . .. 
note:: This method modifies the module in-place. Parameters: Name Type Description Default dst_type type or string the desired type None Returns: Type Description Module self View Source def type ( self : T , dst_type : Union [ dtype , str ] ) -> T : r \" \"\" Casts all parameters and buffers to :attr:`dst_type`. .. note:: This method modifies the module in-place. Args: dst_type (type or string): the desired type Returns: Module: self \"\" \" return self . _apply ( lambda t : t . type ( dst_type )) xpu def xpu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def xpu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . xpu ( device )) zero_grad def zero_grad ( self , set_to_none : bool = True ) -> None Reset gradients of all model parameters. See similar function under :class: torch.optim.Optimizer for more context. Parameters: Name Type Description Default set_to_none bool instead of setting to zero, set the grads to None. See :meth: torch.optim.Optimizer.zero_grad for details. None View Source def zero_grad ( self , set_to_none : bool = True ) -> None : r \"\"\"Reset gradients of all model parameters. See similar function under :class:`torch.optim.Optimizer` for more context. Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. \"\"\" if getattr ( self , '_is_replica' , False ) : warnings . warn ( \"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. \" \"The parameters are copied (in a differentiable manner) from the original module. \" \"This means they are not leaf nodes in autograd and so don' t accumulate gradients . \" \" If you need gradients in your forward method , consider using autograd . grad instead . \") for p in self.parameters(): if p.grad is not None: if set_to_none: p.grad = None else: if p.grad.grad_fn is not None: p.grad.detach_() else: p.grad.requires_grad_(False) p.grad.zero_()","title":"Mlp"},{"location":"reference/wtracker/neural/mlp/#module-wtrackerneuralmlp","text":"View Source from torch import Tensor , nn from typing import Union , Sequence from collections import defaultdict from wtracker.neural.config import IOConfig ACTIVATIONS = { \"relu\" : nn . ReLU , \"tanh\" : nn . Tanh , \"sigmoid\" : nn . Sigmoid , \"softmax\" : nn . Softmax , \"logsoftmax\" : nn . LogSoftmax , \"lrelu\" : nn . LeakyReLU , \"none\" : nn . Identity , None : nn . Identity , } # Default keyword arguments to pass to activation class constructors, e.g. 
# activation_cls(**ACTIVATION_DEFAULT_KWARGS[name]) ACTIVATION_DEFAULT_KWARGS = defaultdict ( dict , { ### \"softmax\" : dict ( dim = 1 ), \"logsoftmax\" : dict ( dim = 1 ), }, ) class WormPredictor ( nn . Module ): \"\"\" A class that represents neural network models that predict worm behavior. After a model is created from several layers or blocks, it is wrapped in this class so that it can be distinguished from other models that don't predict worm behavior (for example the layers/blocks that make this model). This class also holds the IOConfig object that is used to determine the input and output shapes of the model, and the specific frames it expects as input and output. Attributes: model: The neural network model that predicts worm behavior. io_config: The IOConfig object of the model. \"\"\" def __init__ ( self , model : nn . Module , io_config : IOConfig ): super () . __init__ () self . io_config : IOConfig = io_config self . model : nn . Module = model def forward ( self , x : Tensor ) -> Tensor : return self . model ( x ) class MLPLayer ( nn . Module ): \"\"\" A single layer perceptron that can hold batch-norm and activation layers as well. \"\"\" def __init__ ( self , in_dim : int , out_dim : Sequence [ int ], nonlin : Union [ str , nn . Module ], batch_norm : bool = True , ) -> None : super () . __init__ () layers = [] layers . append ( nn . Linear ( in_dim , out_dim )) in_dim = out_dim if batch_norm and nonlin not in [ \"none\" , None ]: layers . append ( nn . BatchNorm1d ( out_dim )) layers . append ( self . _make_activation ( nonlin )) self . mlp_layer = nn . Sequential ( * layers ) def _make_activation ( self , act : Union [ str , nn . Module ]) -> nn . Module : if isinstance ( act , str ): return ACTIVATIONS [ act ]( ** ACTIVATION_DEFAULT_KWARGS [ act ]) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . mlp_layer . forward ( x . reshape ( x . size ( 0 ), - 1 )) class MlpBlock ( nn . Module ): \"\"\" A general-purpose MLP. Args: in_dim: Input dimension. dims: Hidden dimensions, including output dimension. nonlins: Non-linearities to apply after each one of the hidden dimensions. Can be either a sequence of strings which are keys in the ACTIVATIONS dict, or instances of nn.Module (e.g. an instance of nn.ReLU()). Length should match 'dims'. \"\"\" def __init__ ( self , in_dim : int , dims : Sequence [ int ], nonlins : Sequence [ Union [ str , nn . Module ]], batch_norm : bool = True , ): assert len ( nonlins ) == len ( dims ) self . in_dim = in_dim self . out_dim = dims [ - 1 ] self . dims = dims self . nonlins = nonlins super () . __init__ () layers = [] for i , out_dim in enumerate ( self . dims ): layers . append ( MLPLayer ( in_dim , out_dim , nonlins [ i ], batch_norm )) in_dim = out_dim self . sequence = nn . Sequential ( * layers ) def _make_activation ( self , act : Union [ str , nn . Module ]) -> nn . Module : if isinstance ( act , str ): return ACTIVATIONS [ act ]( ** ACTIVATION_DEFAULT_KWARGS [ act ]) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . sequence . forward ( x . reshape ( x . size ( 0 ), - 1 )) class RMLP ( nn .
Module ): def __init__ ( self , block_in_dim : int , block_dims : Sequence [ int ], block_nonlins : Sequence [ Union [ str , nn . Module ]], n_blocks : int , out_dim : int , in_dim : int = None , # if in_dim is an int, then a first layer will be made batch_norm : bool = True , ) -> None : super () . __init__ () # Create first layer if in_dim is not None self . input = nn . Identity () if in_dim is not None : self . input = MLPLayer ( in_dim , block_in_dim , block_nonlins [ 0 ], batch_norm ) # Create blocks layers = [] for i in range ( n_blocks ): layers . append ( MlpBlock ( block_in_dim , block_dims , block_nonlins , batch_norm )) self . blocks = nn . ModuleList ( layers ) # Create output layer self . output = nn . Linear ( block_dims [ - 1 ], out_dim ) def _make_activation ( self , act : Union [ str , nn . Module ]) -> nn . Module : if isinstance ( act , str ): return ACTIVATIONS [ act ]( ** ACTIVATION_DEFAULT_KWARGS [ act ]) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" x = self . input ( x ) for block in self . blocks : out = block ( x ) x = x + out return self . output ( x )","title":"Module wtracker.neural.mlp"},{"location":"reference/wtracker/neural/mlp/#variables","text":"ACTIVATIONS ACTIVATION_DEFAULT_KWARGS","title":"Variables"},{"location":"reference/wtracker/neural/mlp/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/neural/mlp/#mlplayer","text":"class MLPLayer ( in_dim : int , out_dim : Sequence [ int ], nonlin : Union [ str , torch . nn . modules . module . Module ], batch_norm : bool = True ) A single layer perceptron that can hold batch-norm and activation layers as well. View Source class MLPLayer ( nn . Module ) : \"\"\" A single layer perceptron that can hold batch-norm and activation layers as well. \"\"\" def __init__ ( self , in_dim : int , out_dim : Sequence [ int ] , nonlin : Union [ str, nn.Module ] , batch_norm : bool = True , ) -> None : super (). __init__ () layers = [] layers . append ( nn . Linear ( in_dim , out_dim )) in_dim = out_dim if batch_norm and nonlin not in [ \"none\", None ] : layers . append ( nn . BatchNorm1d ( out_dim )) layers . append ( self . _make_activation ( nonlin )) self . mlp_layer = nn . Sequential ( * layers ) def _make_activation ( self , act : Union [ str, nn.Module ] ) -> nn . Module : if isinstance ( act , str ) : return ACTIVATIONS [ act ] ( ** ACTIVATION_DEFAULT_KWARGS [ act ] ) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . mlp_layer . forward ( x . reshape ( x . size ( 0 ), - 1 ))","title":"MLPLayer"},{"location":"reference/wtracker/neural/mlp/#ancestors-in-mro","text":"torch.nn.modules.module.Module","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/mlp/#class-variables","text":"T_destination call_super_init dump_patches","title":"Class variables"},{"location":"reference/wtracker/neural/mlp/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/neural/mlp/#add_module","text":"def add_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Add a child module to the current module. The module can be accessed as an attribute using the given name.
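To complement the wtracker.neural.mlp listing above, here is a minimal usage sketch based on the constructor signatures shown in the source; all dimensions, activation names, and batch sizes below are illustrative choices, not values prescribed by wtracker:

```python
import torch
from wtracker.neural.mlp import MlpBlock, RMLP

x = torch.randn(8, 10)  # a batch of 8 samples with 10 features each

# MlpBlock: 10 -> 64 -> 64 -> 3, ReLU after the first two layers, none after the last.
block = MlpBlock(in_dim=10, dims=[64, 64, 3], nonlins=["relu", "relu", "none"])
block.eval()  # use BatchNorm1d running statistics instead of batch statistics
print(block(x).shape)  # torch.Size([8, 3])

# RMLP: an input MLPLayer (10 -> 32), two residual MlpBlocks (32 -> 64 -> 32),
# and a Linear output head (32 -> 3). Each block's output is added to its input,
# so block_dims[-1] must equal block_in_dim.
rmlp = RMLP(
    block_in_dim=32,
    block_dims=[64, 32],
    block_nonlins=["relu", "none"],
    n_blocks=2,
    out_dim=3,
    in_dim=10,
)
rmlp.eval()
print(rmlp(x).shape)  # torch.Size([8, 3])
```

Passing a string selects the matching entry from the ACTIVATIONS dict, while passing an nn.Module instance uses it directly.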
Parameters: Name Type Description Default name str name of the child module. The child module can be accessed from this module using the given name None module Module child module to be added to the module. None View Source def add_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \"\"\"Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (str): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module. \"\"\" if not isinstance ( module , Module ) and module is not None : raise TypeError ( f \"{torch.typename(module)} is not a Module subclass\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"module name should be a string. Got {torch.typename(name)}\" ) elif hasattr ( self , name ) and name not in self . _modules : raise KeyError ( f \"attribute '{name}' already exists\" ) elif '.' in name : raise KeyError ( f \"module name can't contain \\\" . \\ \", got: {name}\" ) elif name == '' : raise KeyError ( \"module name can't be empty string \\\" \\ \"\" ) for hook in _global_module_registration_hooks . values () : output = hook ( self , name , module ) if output is not None : module = output self . _modules [ name ] = module","title":"add_module"},{"location":"reference/wtracker/neural/mlp/#apply","text":"def apply ( self : ~ T , fn : Callable [[ ForwardRef ( 'Module' )], NoneType ] ) -> ~ T Apply fn recursively to every submodule (as returned by .children() ) as well as self. Typical use includes initializing the parameters of a model (see also :ref: nn-init-doc ). Parameters: Name Type Description Default fn ( None class: Module -> None): function to be applied to each submodule None Returns: Type Description Module self View Source def apply ( self : T , fn : Callable [[ 'Module' ] , None ] ) -> T : r \" \"\" Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example:: >>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) \"\" \" for module in self . children () : module . apply ( fn ) fn ( self ) return self","title":"apply"},{"location":"reference/wtracker/neural/mlp/#bfloat16","text":"def bfloat16 ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to bfloat16 datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def bfloat16 ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``bfloat16`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . bfloat16 () if t . 
is_floating_point () else t )","title":"bfloat16"},{"location":"reference/wtracker/neural/mlp/#buffers","text":"def buffers ( self , recurse : bool = True ) -> Iterator [ torch . Tensor ] Return an iterator over module buffers. Parameters: Name Type Description Default recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. None Yields: Type Description torch.Tensor module buffer View Source def buffers ( self , recurse : bool = True ) -> Iterator [ Tensor ] : r \"\"\"Return an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: torch.Tensor: module buffer Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for buf in model.buffers(): >>> print(type(buf), buf.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for _ , buf in self . named_buffers ( recurse = recurse ) : yield buf","title":"buffers"},{"location":"reference/wtracker/neural/mlp/#children","text":"def children ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over immediate children modules. Yields: Type Description Module a child module View Source def children ( self ) -> Iterator [ 'Module']: r \"\"\"Return an iterator over immediate children modules. Yields: Module: a child module \"\"\" for name , module in self . named_children (): yield module","title":"children"},{"location":"reference/wtracker/neural/mlp/#compile","text":"def compile ( self , * args , ** kwargs ) Compile this Module's forward using :func: torch.compile . This Module's __call__ method is compiled and all arguments are passed as-is to :func: torch.compile . See :func: torch.compile for details on the arguments for this function. View Source def compile ( self , * args , ** kwargs ) : \" \"\" Compile this Module's forward using :func:`torch.compile`. This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`. See :func:`torch.compile` for details on the arguments for this function. \"\" \" self . _compiled_call_impl = torch . compile ( self . _call_impl , * args , ** kwargs )","title":"compile"},{"location":"reference/wtracker/neural/mlp/#cpu","text":"def cpu ( self : ~ T ) -> ~ T Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def cpu ( self : T ) -> T : r \"\"\"Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cpu ())","title":"cpu"},{"location":"reference/wtracker/neural/mlp/#cuda","text":"def cuda ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def cuda ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. 
So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cuda ( device ))","title":"cuda"},{"location":"reference/wtracker/neural/mlp/#double","text":"def double ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to double datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def double ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``double`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . double () if t . is_floating_point () else t )","title":"double"},{"location":"reference/wtracker/neural/mlp/#eval","text":"def eval ( self : ~ T ) -> ~ T Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. This is equivalent with :meth: self.train(False) . See :ref: locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it. Returns: Type Description Module self View Source def eval ( self : T ) -> T : r \" \"\" Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent with :meth:`self.train(False) `. See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it. Returns: Module: self \"\" \" return self . train ( False )","title":"eval"},{"location":"reference/wtracker/neural/mlp/#extra_repr","text":"def extra_repr ( self ) -> str Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. View Source def extra_repr ( self ) -> str : r \"\"\"Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. \"\"\" return ''","title":"extra_repr"},{"location":"reference/wtracker/neural/mlp/#float","text":"def float ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to float datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def float ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``float`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . float () if t . is_floating_point () else t )","title":"float"},{"location":"reference/wtracker/neural/mlp/#forward","text":"def forward ( self , x : torch . Tensor ) -> torch . Tensor Parameters: Name Type Description Default x None An input tensor, of shape (N, D) containing N samples with D features. None Returns: Type Description None An output tensor of shape (N, D_out) where D_out is the output dim. 
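As a concrete illustration of the (N, D) flattening contract described above, here is a minimal sketch; the layer sizes are made up for the example:

```python
import torch
from wtracker.neural.mlp import MLPLayer

layer = MLPLayer(in_dim=6, out_dim=4, nonlin="relu")
layer.eval()  # keep the BatchNorm1d layer deterministic for this toy input

x = torch.randn(8, 2, 3)  # 8 samples whose features total 2 * 3 = 6
y = layer(x)              # forward reshapes the input to (8, 6) internally
print(y.shape)            # torch.Size([8, 4])
```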
View Source def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . mlp_layer . forward ( x . reshape ( x . size ( 0 ), - 1 ))","title":"forward"},{"location":"reference/wtracker/neural/mlp/#get_buffer","text":"def get_buffer ( self , target : str ) -> 'Tensor' Return the buffer given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.Tensor The buffer referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not a buffer View Source def get_buffer ( self , target : str ) -> \"Tensor\" : \" \"\" Return the buffer given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the buffer to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.Tensor: The buffer referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not a buffer \"\" \" module_path , _ , buffer_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , buffer_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + buffer_name + \"`\" ) buffer : torch . Tensor = getattr ( mod , buffer_name ) if buffer_name not in mod . _buffers : raise AttributeError ( \"`\" + buffer_name + \"` is not a buffer\" ) return buffer","title":"get_buffer"},{"location":"reference/wtracker/neural/mlp/#get_extra_state","text":"def get_extra_state ( self ) -> Any Return any extra state to include in the module's state_dict. Implement this and a corresponding :func: set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict() . Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: Type Description object Any extra state to store in the module's state_dict View Source def get_extra_state ( self ) -> Any : \" \"\" Return any extra state to include in the module's state_dict. Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`. Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.
Returns: object: Any extra state to store in the module's state_dict \"\" \" raise RuntimeError ( \"Reached a code path in Module.get_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"get_extra_state"},{"location":"reference/wtracker/neural/mlp/#get_parameter","text":"def get_parameter ( self , target : str ) -> 'Parameter' Return the parameter given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Parameter The Parameter referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Parameter View Source def get_parameter ( self , target : str ) -> \"Parameter\" : \" \"\" Return the parameter given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the Parameter to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.nn.Parameter: The Parameter referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Parameter`` \"\" \" module_path , _ , param_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , param_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + param_name + \"`\" ) param : torch . nn . Parameter = getattr ( mod , param_name ) if not isinstance ( param , torch . nn . Parameter ) : raise AttributeError ( \"`\" + param_name + \"` is not an \" \"nn.Parameter\" ) return param","title":"get_parameter"},{"location":"reference/wtracker/neural/mlp/#get_submodule","text":"def get_submodule ( self , target : str ) -> 'Module' Return the submodule given by target if it exists, otherwise throw an error. For example, let's say you have an nn.Module A that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an nn.Module A . A has a nested submodule net_b , which itself has two submodules net_c and linear . net_c then has a submodule conv .) To check whether or not we have the linear submodule, we would call get_submodule(\"net_b.linear\") . To check whether we have the conv submodule, we would call get_submodule(\"net_b.net_c.conv\") . The runtime of get_submodule is bounded by the degree of module nesting in target . A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used. Parameters: Name Type Description Default target None The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) 
None Returns: Type Description torch.nn.Module The submodule referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Module View Source def get_submodule ( self , target : str ) -> \"Module\" : \" \"\" Return the submodule given by ``target`` if it exists, otherwise throw an error. For example, let's say you have an ``nn.Module`` ``A`` that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.) To check whether or not we have the ``linear`` submodule, we would call ``get_submodule(\" net_b . linear \")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule(\" net_b . net_c . conv \")``. The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used. Args: target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) Returns: torch.nn.Module: The submodule referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Module`` \"\" \" if target == \"\" : return self atoms : List [ str ] = target . split ( \".\" ) mod : torch . nn . Module = self for item in atoms : if not hasattr ( mod , item ) : raise AttributeError ( mod . _get_name () + \" has no \" \"attribute `\" + item + \"`\" ) mod = getattr ( mod , item ) if not isinstance ( mod , torch . nn . Module ) : raise AttributeError ( \"`\" + item + \"` is not \" \"an nn.Module\" ) return mod","title":"get_submodule"},{"location":"reference/wtracker/neural/mlp/#half","text":"def half ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to half datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def half ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``half`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . half () if t . is_floating_point () else t )","title":"half"},{"location":"reference/wtracker/neural/mlp/#ipu","text":"def ipu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def ipu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. 
So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . ipu ( device ))","title":"ipu"},{"location":"reference/wtracker/neural/mlp/#load_state_dict","text":"def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) Copy parameters and buffers from :attr: state_dict into this module and its descendants. If :attr: strict is True , then the keys of :attr: state_dict must exactly match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. .. warning:: If :attr: assign is True the optimizer must be created after the call to :attr: load_state_dict unless :func: ~torch.__future__.get_swap_module_params_on_conversion is True . Parameters: Name Type Description Default state_dict dict a dict containing parameters and persistent buffers. None strict bool whether to strictly enforce that the keys in :attr: state_dict match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. Default: True None assign bool When False , the properties of the tensors in the current module are preserved while when True , the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of :class: ~torch.nn.Parameter s for which the value from the module is preserved. Default: False None Returns: Type Description None NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys unexpected_keys is a list of str containing the unexpected keys View Source def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) : r \"\"\"Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. 
\"\"\" if not isinstance ( state_dict , Mapping ) : raise TypeError ( f \"Expected state_dict to be dict-like, got {type(state_dict)}.\" ) missing_keys : List [ str ] = [] unexpected_keys : List [ str ] = [] error_msgs : List [ str ] = [] # copy state_dict so _ load_from_state_dict can modify it metadata = getattr ( state_dict , '_metadata' , None ) state_dict = OrderedDict ( state_dict ) if metadata is not None : # mypy isn't aware that \"_metadata\" exists in state_dict state_dict._metadata = metadata # type: ignore[attr-defined] def load(module, local_state_dict, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) if assign: local_metadata['assign_to_params_buffers'] = assign module._load_from_state_dict( local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name + ' . ' child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} load(child, child_state_dict, child_prefix) # noqa: F821 # Note that the hook can modify missing_keys and unexpected_keys. incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) for hook in module._load_state_dict_post_hooks.values(): out = hook(module, incompatible_keys) assert out is None, ( \"Hooks registered with ``register_load_state_dict_post_hook`` are not\" \"expected to return new values, if incompatible_keys need to be modified,\" \"it should be done inplace.\" ) load(self, state_dict) del load if strict: if len(unexpected_keys) > 0: error_msgs.insert( 0, ' Unexpected key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in unexpected_keys))) if len(missing_keys) > 0: error_msgs.insert( 0, ' Missing key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in missing_keys))) if len(error_msgs) > 0: raise RuntimeError(' Error ( s ) in loading state_dict for {} : \\n\\t {} ' . format ( self . __ class__ . __ name__ , \"\\n\\t\" . join ( error_msgs ))) return _ IncompatibleKeys ( missing_keys , unexpected_keys )","title":"load_state_dict"},{"location":"reference/wtracker/neural/mlp/#modules","text":"def modules ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over all modules in the network. Yields: Type Description Module a module in the network View Source def modules ( self ) -> Iterator [ 'Module' ] : r \" \"\" Return an iterator over all modules in the network. Yields: Module: a module in the network Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True) \"\" \" for _ , module in self . named_modules () : yield module","title":"modules"},{"location":"reference/wtracker/neural/mlp/#named_buffers","text":"def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . Tensor ]] Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Parameters: Name Type Description Default prefix str prefix to prepend to all buffer names. None recurse bool if True, then yields buffers of this module and all submodules. 
Otherwise, yields only buffers that are direct members of this module. Defaults to True. None remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True Yields: Type Description None (str, torch.Tensor): Tuple containing the name and buffer View Source def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Tensor ]]: r \"\"\"Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Args: prefix (str): prefix to prepend to all buffer names. recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. Yields: (str, torch.Tensor): Tuple containing the name and buffer Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size()) \"\"\" gen = self . _named_members ( lambda module : module . _buffers . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_buffers"},{"location":"reference/wtracker/neural/mlp/#named_children","text":"def named_children ( self ) -> Iterator [ Tuple [ str , ForwardRef ( 'Module' )]] Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: Type Description None (str, Module): Tuple containing a name and child module View Source def named_children ( self ) -> Iterator [ Tuple [ str , 'Module' ]]: r \"\"\"Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: (str, Module): Tuple containing a name and child module Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module) \"\"\" memo = set () for name , module in self . _modules . items (): if module is not None and module not in memo : memo . add ( module ) yield name , module","title":"named_children"},{"location":"reference/wtracker/neural/mlp/#named_modules","text":"def named_modules ( self , memo : Optional [ Set [ ForwardRef ( 'Module' )]] = None , prefix : str = '' , remove_duplicate : bool = True ) Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Parameters: Name Type Description Default memo None a memo to store the set of modules already added to the result None prefix None a prefix that will be added to the name of the module None remove_duplicate None whether to remove the duplicated module instances in the result or not None Yields: Type Description None (str, Module): Tuple of name and module View Source def named_modules ( self , memo : Optional [ Set [ 'Module' ]] = None , prefix : str = '' , remove_duplicate : bool = True ) : r \" \"\" Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: memo: a memo to store the set of modules already added to the result prefix: a prefix that will be added to the name of the module remove_duplicate: whether to remove the duplicated module instances in the result or not Yields: (str, Module): Tuple of name and module Note: Duplicate modules are returned only once. 
In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) \"\" \" if memo is None : memo = set () if self not in memo : if remove_duplicate : memo . add ( self ) yield prefix , self for name , module in self . _modules . items () : if module is None : continue submodule_prefix = prefix + ( '.' if prefix else '' ) + name yield from module . named_modules ( memo , submodule_prefix , remove_duplicate )","title":"named_modules"},{"location":"reference/wtracker/neural/mlp/#named_parameters","text":"def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . nn . parameter . Parameter ]] Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Parameters: Name Type Description Default prefix str prefix to prepend to all parameter names. None recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. None Yields: Type Description None (str, Parameter): Tuple containing the name and parameter View Source def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Parameter ]]: r \"\"\"Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True. Yields: (str, Parameter): Tuple containing the name and parameter Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) \"\"\" gen = self . _named_members ( lambda module : module . _parameters . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_parameters"},{"location":"reference/wtracker/neural/mlp/#parameters","text":"def parameters ( self , recurse : bool = True ) -> Iterator [ torch . nn . parameter . Parameter ] Return an iterator over module parameters. This is typically passed to an optimizer. Parameters: Name Type Description Default recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None Yields: Type Description Parameter module parameter View Source def parameters ( self , recurse : bool = True ) -> Iterator [ Parameter ] : r \"\"\"Return an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. 
Yields: Parameter: module parameter Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for param in model.parameters(): >>> print(type(param), param.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for name , param in self . named_parameters ( recurse = recurse ) : yield param","title":"parameters"},{"location":"reference/wtracker/neural/mlp/#register_backward_hook","text":"def register_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]] ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. This function is deprecated in favor of :meth: ~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_backward_hook ( self , hook: Callable [[' Module ', _grad_t , _grad_t ], Union [ None , _grad_t ]] ) -> RemovableHandle: r \"\"\"Register a backward hook on the module. This function is deprecated in favor of : meth: ` ~ torch . nn . Module . register_full_backward_hook ` and the behavior of this function will change in future versions . Returns: : class: `torch . utils . hooks . RemovableHandle ` : a handle that can be used to remove the added hook by calling ` `handle . remove () `` \"\"\" if self . _is_full_backward_hook is True: raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = False handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook return handle","title":"register_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_buffer","text":"def register_buffer ( self , name : str , tensor : Optional [ torch . Tensor ], persistent : bool = True ) -> None Add a buffer to the module. This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr: persistent to False . The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr: state_dict . Buffers can be accessed as attributes using given names. Parameters: Name Type Description Default name str name of the buffer. The buffer can be accessed from this module using the given name None tensor Tensor or None buffer to be registered. If None , then operations that run on buffers, such as :attr: cuda , are ignored. If None , the buffer is not included in the module's :attr: state_dict . None persistent bool whether the buffer is part of this module's :attr: state_dict . None View Source def register_buffer ( self , name : str , tensor : Optional [ Tensor ] , persistent : bool = True ) -> None : r \" \"\" Add a buffer to the module. This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. 
This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`. Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> self.register_buffer('running_mean', torch.zeros(num_features)) \"\" \" if persistent is False and isinstance ( self , torch . jit . ScriptModule ) : raise RuntimeError ( \"ScriptModule does not support non-persistent buffers\" ) if '_buffers' not in self . __dict__ : raise AttributeError ( \"cannot assign buffer before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"buffer name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"buffer name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"buffer name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _buffers : raise KeyError ( f \"attribute '{name}' already exists\" ) elif tensor is not None and not isinstance ( tensor , torch . Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self . _buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name )","title":"register_buffer"},{"location":"reference/wtracker/neural/mlp/#register_forward_hook","text":"def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward hook on the module. The hook will be called every time after :func: forward has computed an output. If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func: forward is called. The hook should have the following signature:: hook ( module , args , output ) -> None or modified output If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook ( module , args , kwargs , output ) -> None or modified output Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class: torch.nn.modules.Module . 
Otherwise, the provided hook will be fired after all existing forward hooks on this :class: torch.nn.modules.Module . Note that global forward hooks registered with :func: register_module_forward_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ] , Any ] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ] , Any ] , Optional [ Any ]] , ] , * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature:: hook(module, args, output) -> None or modified output If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook(module, args, kwargs, output) -> None or modified output Args: hook (Callable): The user defined hook to be registered. prepend (bool): If ``True``, the provided ``hook`` will be fired before all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward`` hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` always_call (bool): If ``True`` the ``hook`` will be run regardless of whether an exception is raised while calling the Module. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_hooks , extra_dict = [ self . _forward_hooks_with_kwargs , self . _forward_hooks_always_called ] , ) self . _forward_hooks [ handle . id ] = hook if with_kwargs : self . _forward_hooks_with_kwargs [ handle . id ] = True if always_call : self . _forward_hooks_always_called [ handle . id ] = True if prepend : self . _forward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_hook"},{"location":"reference/wtracker/neural/mlp/#register_forward_pre_hook","text":"def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... 
], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward pre-hook on the module. The hook will be called every time before :func: forward is invoked. If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook ( module , args ) -> None or modified input If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook ( module , args , kwargs ) -> None or a tuple of modified input and kwargs Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward_pre hooks on this :class: torch.nn.modules.Module . Note that global forward_pre hooks registered with :func: register_module_forward_pre_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_pre_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ]] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ]] , Optional [ Tuple [ Any , Dict [ str , Any ]]]] , ] , * , prepend : bool = False , with_kwargs : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. 
Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_hook","text":"def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. 
the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . _is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_pre_hook","text":"def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. 
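A minimal sketch of the backward pre-hook contract described above (module and hook names are illustrative)::

    import torch
    import torch.nn as nn

    layer = nn.Linear(3, 3)

    def clamp_grad_output(module, grad_output):
        # A backward pre-hook sees the gradients w.r.t. the module's
        # outputs and may return replacements before they propagate.
        return tuple(g.clamp(-1.0, 1.0) if g is not None else None
                     for g in grad_output)

    handle = layer.register_full_backward_pre_hook(clamp_grad_output)
    layer(torch.randn(2, 3)).sum().backward()  # triggers the hook
    handle.remove()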
Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_load_state_dict_post_hook","text":"def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. 
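For example, a post hook that simply reports the incompatible keys after every load might look like this (a sketch; the model and hook names are illustrative)::

    import torch.nn as nn

    model = nn.Linear(2, 2)

    def report_keys(module, incompatible_keys):
        # incompatible_keys is a NamedTuple with .missing_keys and
        # .unexpected_keys; mutating the lists in place changes what a
        # strict=True load will complain about.
        print(incompatible_keys.missing_keys, incompatible_keys.unexpected_keys)

    model.register_load_state_dict_post_hook(report_keys)
    model.load_state_dict(model.state_dict())  # round-trip load; hook fires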
Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . id ] = hook return handle","title":"register_load_state_dict_post_hook"},{"location":"reference/wtracker/neural/mlp/#register_module","text":"def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . View Source def register_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \" \"\" Alias for :func:`add_module`. \"\" \" self . add_module ( name , module )","title":"register_module"},{"location":"reference/wtracker/neural/mlp/#register_parameter","text":"def register_parameter ( self , name : str , param : Optional [ torch . nn . parameter . Parameter ] ) -> None Add a parameter to the module. The parameter can be accessed as an attribute using given name. Parameters: Name Type Description Default name str name of the parameter. The parameter can be accessed from this module using the given name None param Parameter or None parameter to be added to the module. If None , then operations that run on parameters, such as :attr: cuda , are ignored. If None , the parameter is not included in the module's :attr: state_dict . None View Source def register_parameter ( self , name : str , param : Optional [ Parameter ] ) -> None : r \" \"\" Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. \"\" \" if '_parameters' not in self . __dict__ : raise AttributeError ( \"cannot assign parameter before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"parameter name should be a string. Got {torch.typename(name)}\" ) elif '.' 
in name : raise KeyError ( \"parameter name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"parameter name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _parameters : raise KeyError ( f \"attribute '{name}' already exists\" ) if param is None : self . _parameters [ name ] = None elif not isinstance ( param , Parameter ) : raise TypeError ( f \"cannot assign '{torch.typename(param)}' object to parameter '{name}' \" \"(torch.nn.Parameter or None required)\" ) elif param . grad_fn : raise ValueError ( f \"Cannot assign non-leaf Tensor to parameter '{name}'. Model \" f \"parameters must be created explicitly. To express '{name}' \" \"as a function of another Tensor, compute the value in \" \"the forward() method.\" ) else : for hook in _global_parameter_registration_hooks . values () : output = hook ( self , name , param ) if output is not None : param = output self . _parameters [ name ] = param","title":"register_parameter"},{"location":"reference/wtracker/neural/mlp/#register_state_dict_pre_hook","text":"def register_state_dict_pre_hook ( self , hook ) Register a pre-hook for the :meth: ~torch.nn.Module.state_dict method. These hooks will be called with arguments: self , prefix , and keep_vars before calling state_dict on self . The registered hooks can be used to perform pre-processing before the state_dict call is made. View Source def register_state_dict_pre_hook ( self , hook ) : r \" \"\" Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. \"\" \" handle = hooks . RemovableHandle ( self . _state_dict_pre_hooks ) self . _state_dict_pre_hooks [ handle . id ] = hook return handle","title":"register_state_dict_pre_hook"},{"location":"reference/wtracker/neural/mlp/#requires_grad_","text":"def requires_grad_ ( self : ~ T , requires_grad : bool = True ) -> ~ T Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr: requires_grad attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref: locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it. Parameters: Name Type Description Default requires_grad bool whether autograd should record operations on parameters in this module. Default: True . None Returns: Type Description Module self View Source def requires_grad_ ( self : T , requires_grad : bool = True ) -> T : r \" \"\" Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it. Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self \"\" \" for p in self . parameters () : p . 
requires_grad_ ( requires_grad ) return self","title":"requires_grad_"},{"location":"reference/wtracker/neural/mlp/#set_extra_state","text":"def set_extra_state ( self , state : Any ) -> None Set extra state contained in the loaded state_dict . This function is called from :func: load_state_dict to handle any extra state found within the state_dict . Implement this function and a corresponding :func: get_extra_state for your module if you need to store extra state within its state_dict . View Source def set_extra_state ( self , state : Any ) -> None : \" \"\" Set extra state contained in the loaded `state_dict`. This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`. Args: state (dict): Extra state from the `state_dict` \"\" \" raise RuntimeError ( \"Reached a code path in Module.set_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"set_extra_state"},{"location":"reference/wtracker/neural/mlp/#share_memory","text":"def share_memory ( self : ~ T ) -> ~ T See :meth: torch.Tensor.share_memory_ . View Source def share_memory ( self : T ) -> T : r \"\"\"See :meth:`torch.Tensor.share_memory_`.\"\"\" return self . _apply ( lambda t : t . share_memory_ ())","title":"share_memory"},{"location":"reference/wtracker/neural/mlp/#state_dict","text":"def state_dict ( self , * args , destination = None , prefix = '' , keep_vars = False ) Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently state_dict() also accepts positional arguments for destination , prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument destination as it is not designed for end-users. Parameters: Name Type Description Default destination dict If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None . None prefix str a prefix added to parameter and buffer names to compose the keys in state_dict. Default: '' . None keep_vars bool by default the :class: ~torch.Tensor s returned in the state dict are detached from autograd. If it's set to True , detaching will not be performed. Default: False . None Returns: Type Description dict a dictionary containing a whole state of the module View Source def state_dict ( self , * args , destination = None , prefix='' , keep_vars = False ) : r \"\"\"Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars`` in order.
However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument ``destination`` as it is not designed for end-users. Args: destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``. prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. Returns: dict: a dictionary containing a whole state of the module Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> module.state_dict().keys() ['bias', 'weight'] \"\"\" # TODO : Remove ` args ` and the parsing logic when BC allows . if len ( args ) > 0 : if destination is None : destination = args [ 0 ] if len ( args ) > 1 and prefix == '' : prefix = args [ 1 ] if len ( args ) > 2 and keep_vars is False : keep_vars = args [ 2 ] # DeprecationWarning is ignored by default warnings . warn ( \"Positional args are being deprecated, use kwargs instead. Refer to \" \"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict\" \" for details.\" ) if destination is None : destination = OrderedDict () destination . _metadata = OrderedDict () local_metadata = dict ( version = self . _version ) if hasattr ( destination , \"_metadata\" ) : destination . _metadata [ prefix [ :- 1 ]] = local_metadata for hook in self . _state_dict_pre_hooks . values () : hook ( self , prefix , keep_vars ) self . _save_to_state_dict ( destination , prefix , keep_vars ) for name , module in self . _modules . items () : if module is not None : module . state_dict ( destination = destination , prefix = prefix + name + '.' , keep_vars = keep_vars ) for hook in self . _state_dict_hooks . values () : hook_result = hook ( self , destination , prefix , local_metadata ) if hook_result is not None : destination = hook_result return destination","title":"state_dict"},{"location":"reference/wtracker/neural/mlp/#to","text":"def to ( self , * args , ** kwargs ) Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth: torch.Tensor.to , but only accepts floating point or complex :attr: dtype \ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr: dtype (if given). The integral parameters and buffers will be moved :attr: device , if that is given, but with dtypes unchanged. When :attr: non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place.
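A small usage sketch of this method (the CUDA guard reflects that device availability is an assumption)::

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    model.to(torch.float64)            # cast floating-point params/buffers
    if torch.cuda.is_available():
        model.to(torch.device("cuda"), dtype=torch.float16, non_blocking=True)
    print(next(model.parameters()).dtype)  # reflects the in-place change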
Parameters: Name Type Description Default device ( None class: torch.device ): the desired device of the parameters and buffers in this module None dtype ( None class: torch.dtype ): the desired floating point or complex dtype of the parameters and buffers in this module None tensor torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module None memory_format ( None class: torch.memory_format ): the desired memory format for 4D parameters and buffers in this module (keyword only argument) None Returns: Type Description Module self View Source def to ( self , * args , ** kwargs ) : r \" \"\" Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype` \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self Examples:: >>> # xdoctest: +IGNORE_WANT(\" non - deterministic \") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device(\" cuda : 1 \") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device(\" cpu \") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) \"\" \" device , dtype , non_blocking , convert_to_format = torch . _C . _nn . _parse_to ( * args , ** kwargs ) if dtype is not None : if not ( dtype . is_floating_point or dtype . 
is_complex ) : raise TypeError ( 'nn.Module.to only accepts floating point or complex ' f 'dtypes, but got desired dtype={dtype}' ) if dtype . is_complex : warnings . warn ( \"Complex modules are a new feature under active development whose design may change, \" \"and some modules might not work as expected when using complex tensors as parameters or buffers. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"if a complex module does not work as expected.\" ) def convert ( t ) : try : if convert_to_format is not None and t . dim () in ( 4 , 5 ) : return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , memory_format = convert_to_format , ) return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , ) except NotImplementedError as e : if str ( e ) == \"Cannot copy out of meta tensor; no data!\" : raise NotImplementedError ( f \"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() \" f \"when moving module from meta to a different device.\" ) from None else : raise return self . _apply ( convert )","title":"to"},{"location":"reference/wtracker/neural/mlp/#to_empty","text":"def to_empty ( self : ~ T , * , device : Union [ int , str , torch . device , NoneType ], recurse : bool = True ) -> ~ T Move the parameters and buffers to the specified device without copying storage. Parameters: Name Type Description Default device ( None class: torch.device ): The desired device of the parameters and buffers in this module. None recurse bool Whether parameters and buffers of submodules should be recursively moved to the specified device. None Returns: Type Description Module self View Source def to_empty ( self : T , * , device : Optional [ DeviceLikeType ] , recurse : bool = True ) -> T : r \"\"\"Move the parameters and buffers to the specified device without copying storage. Args: device (:class:`torch.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. Returns: Module: self \"\"\" return self . _apply ( lambda t : torch . empty_like ( t , device = device ), recurse = recurse )","title":"to_empty"},{"location":"reference/wtracker/neural/mlp/#train","text":"def train ( self : ~ T , mode : bool = True ) -> ~ T Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. Parameters: Name Type Description Default mode bool whether to set training mode ( True ) or evaluation mode ( False ). Default: True . None Returns: Type Description Module self View Source def train ( self : T , mode : bool = True ) -> T : r \" \"\" Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self \"\" \" if not isinstance ( mode , bool ) : raise ValueError ( \"training mode is expected to be boolean\" ) self . training = mode for module in self . children () : module . 
train ( mode ) return self","title":"train"},{"location":"reference/wtracker/neural/mlp/#type","text":"def type ( self : ~ T , dst_type : Union [ torch . dtype , str ] ) -> ~ T Casts all parameters and buffers to :attr: dst_type . .. note:: This method modifies the module in-place. Parameters: Name Type Description Default dst_type type or string the desired type None Returns: Type Description Module self View Source def type ( self : T , dst_type : Union [ dtype , str ] ) -> T : r \" \"\" Casts all parameters and buffers to :attr:`dst_type`. .. note:: This method modifies the module in-place. Args: dst_type (type or string): the desired type Returns: Module: self \"\" \" return self . _apply ( lambda t : t . type ( dst_type ))","title":"type"},{"location":"reference/wtracker/neural/mlp/#xpu","text":"def xpu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def xpu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . xpu ( device ))","title":"xpu"},{"location":"reference/wtracker/neural/mlp/#zero_grad","text":"def zero_grad ( self , set_to_none : bool = True ) -> None Reset gradients of all model parameters. See similar function under :class: torch.optim.Optimizer for more context. Parameters: Name Type Description Default set_to_none bool instead of setting to zero, set the grads to None. See :meth: torch.optim.Optimizer.zero_grad for details. None View Source def zero_grad ( self , set_to_none : bool = True ) -> None : r \"\"\"Reset gradients of all model parameters. See similar function under :class:`torch.optim.Optimizer` for more context. Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. \"\"\" if getattr ( self , '_is_replica' , False ) : warnings . warn ( \"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. \" \"The parameters are copied (in a differentiable manner) from the original module. \" \"This means they are not leaf nodes in autograd and so don' t accumulate gradients . \" \" If you need gradients in your forward method , consider using autograd . grad instead . \") for p in self.parameters(): if p.grad is not None: if set_to_none: p.grad = None else: if p.grad.grad_fn is not None: p.grad.detach_() else: p.grad.requires_grad_(False) p.grad.zero_()","title":"zero_grad"},{"location":"reference/wtracker/neural/mlp/#mlpblock","text":"class MlpBlock ( in_dim : int , dims : Sequence [ int ], nonlins : Sequence [ Union [ str , torch . nn . modules . module . 
Module ]], batch_norm : bool = True ) A general-purpose MLP.","title":"MlpBlock"},{"location":"reference/wtracker/neural/mlp/#attributes","text":"Name Type Description Default in_dim None Input dimension. None dims None Hidden dimensions, including output dimension. None nonlins None Non-linearities to apply after each one of the hidden dimensions. Can be either a sequence of strings which are keys in the ACTIVATIONS dict, or instances of nn.Module (e.g. an instance of nn.ReLU()). Length should match 'dims'. None View Source class MlpBlock ( nn . Module ) : \"\"\" A general-purpose MLP. Args: in_dim: Input dimension. dims: Hidden dimensions, including output dimension. nonlins: Non-linearities to apply after each one of the hidden dimensions. Can be either a sequence of strings which are keys in the ACTIVATIONS dict, or instances of nn.Module (e.g. an instance of nn.ReLU()). Length should match 'dims'. \"\"\" def __init__ ( self , in_dim : int , dims : Sequence [ int ] , nonlins : Sequence [ Union[str, nn.Module ] ] , batch_norm : bool = True , ) : assert len ( nonlins ) == len ( dims ) self . in_dim = in_dim self . out_dim = dims [ -1 ] self . dims = dims self . nonlins = nonlins super (). __init__ () layers = [] for i , out_dim in enumerate ( self . dims ) : layers . append ( MLPLayer ( in_dim , out_dim , nonlins [ i ] , batch_norm )) in_dim = out_dim self . sequence = nn . Sequential ( * layers ) def _make_activation ( self , act : Union [ str, nn.Module ] ) -> nn . Module : if isinstance ( act , str ) : return ACTIVATIONS [ act ] ( ** ACTIVATION_DEFAULT_KWARGS [ act ] ) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . sequence . forward ( x . reshape ( x . size ( 0 ), - 1 ))","title":"Attributes"},{"location":"reference/wtracker/neural/mlp/#ancestors-in-mro_1","text":"torch.nn.modules.module.Module","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/mlp/#class-variables_1","text":"T_destination call_super_init dump_patches","title":"Class variables"},{"location":"reference/wtracker/neural/mlp/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/neural/mlp/#add_module_1","text":"def add_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Add a child module to the current module. The module can be accessed as an attribute using the given name. Parameters: Name Type Description Default name str name of the child module. The child module can be accessed from this module using the given name None module Module child module to be added to the module. None View Source def add_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \"\"\"Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (str): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module. \"\"\" if not isinstance ( module , Module ) and module is not None : raise TypeError ( f \"{torch.typename(module)} is not a Module subclass\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"module name should be a string. Got {torch.typename(name)}\" ) elif hasattr ( self , name ) and name not in self . _modules : raise KeyError ( f \"attribute '{name}' already exists\" ) elif '.' 
in name : raise KeyError ( f \"module name can't contain \\\" . \\ \", got: {name}\" ) elif name == '' : raise KeyError ( \"module name can't be empty string \\\" \\ \"\" ) for hook in _global_module_registration_hooks . values () : output = hook ( self , name , module ) if output is not None : module = output self . _modules [ name ] = module","title":"add_module"},{"location":"reference/wtracker/neural/mlp/#apply_1","text":"def apply ( self : ~ T , fn : Callable [[ ForwardRef ( 'Module' )], NoneType ] ) -> ~ T Apply fn recursively to every submodule (as returned by .children() ) as well as self. Typical use includes initializing the parameters of a model (see also :ref: nn-init-doc ). Parameters: Name Type Description Default fn ( None class: Module -> None): function to be applied to each submodule None Returns: Type Description Module self View Source def apply ( self : T , fn : Callable [[ 'Module' ] , None ] ) -> T : r \" \"\" Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example:: >>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) \"\" \" for module in self . children () : module . apply ( fn ) fn ( self ) return self","title":"apply"},{"location":"reference/wtracker/neural/mlp/#bfloat16_1","text":"def bfloat16 ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to bfloat16 datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def bfloat16 ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``bfloat16`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . bfloat16 () if t . is_floating_point () else t )","title":"bfloat16"},{"location":"reference/wtracker/neural/mlp/#buffers_1","text":"def buffers ( self , recurse : bool = True ) -> Iterator [ torch . Tensor ] Return an iterator over module buffers. Parameters: Name Type Description Default recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. None Yields: Type Description torch.Tensor module buffer View Source def buffers ( self , recurse : bool = True ) -> Iterator [ Tensor ] : r \"\"\"Return an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: torch.Tensor: module buffer Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for buf in model.buffers(): >>> print(type(buf), buf.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for _ , buf in self . 
named_buffers ( recurse = recurse ) : yield buf","title":"buffers"},{"location":"reference/wtracker/neural/mlp/#children_1","text":"def children ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over immediate children modules. Yields: Type Description Module a child module View Source def children ( self ) -> Iterator [ 'Module']: r \"\"\"Return an iterator over immediate children modules. Yields: Module: a child module \"\"\" for name , module in self . named_children (): yield module","title":"children"},{"location":"reference/wtracker/neural/mlp/#compile_1","text":"def compile ( self , * args , ** kwargs ) Compile this Module's forward using :func: torch.compile . This Module's __call__ method is compiled and all arguments are passed as-is to :func: torch.compile . See :func: torch.compile for details on the arguments for this function. View Source def compile ( self , * args , ** kwargs ) : \" \"\" Compile this Module's forward using :func:`torch.compile`. This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`. See :func:`torch.compile` for details on the arguments for this function. \"\" \" self . _compiled_call_impl = torch . compile ( self . _call_impl , * args , ** kwargs )","title":"compile"},{"location":"reference/wtracker/neural/mlp/#cpu_1","text":"def cpu ( self : ~ T ) -> ~ T Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def cpu ( self : T ) -> T : r \"\"\"Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cpu ())","title":"cpu"},{"location":"reference/wtracker/neural/mlp/#cuda_1","text":"def cuda ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def cuda ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cuda ( device ))","title":"cuda"},{"location":"reference/wtracker/neural/mlp/#double_1","text":"def double ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to double datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def double ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``double`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . double () if t . 
is_floating_point () else t )","title":"double"},{"location":"reference/wtracker/neural/mlp/#eval_1","text":"def eval ( self : ~ T ) -> ~ T Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. This is equivalent with :meth: self.train(False) . See :ref: locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it. Returns: Type Description Module self View Source def eval ( self : T ) -> T : r \" \"\" Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent with :meth:`self.train(False) `. See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it. Returns: Module: self \"\" \" return self . train ( False )","title":"eval"},{"location":"reference/wtracker/neural/mlp/#extra_repr_1","text":"def extra_repr ( self ) -> str Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. View Source def extra_repr ( self ) -> str : r \"\"\"Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. \"\"\" return ''","title":"extra_repr"},{"location":"reference/wtracker/neural/mlp/#float_1","text":"def float ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to float datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def float ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``float`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . float () if t . is_floating_point () else t )","title":"float"},{"location":"reference/wtracker/neural/mlp/#forward_1","text":"def forward ( self , x : torch . Tensor ) -> torch . Tensor Parameters: Name Type Description Default x None An input tensor, of shape (N, D) containing N samples with D features. None Returns: Type Description None An output tensor of shape (N, D_out) where D_out is the output dim. View Source def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" return self . sequence . forward ( x . reshape ( x . size ( 0 ), - 1 ))","title":"forward"},{"location":"reference/wtracker/neural/mlp/#get_buffer_1","text":"def get_buffer ( self , target : str ) -> 'Tensor' Return the buffer given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.) 
None Returns: Type Description torch.Tensor The buffer referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not a buffer View Source def get_buffer ( self , target : str ) -> \"Tensor\" : \" \"\" Return the buffer given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the buffer to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.Tensor: The buffer referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not a buffer \"\" \" module_path , _ , buffer_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , buffer_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + buffer_name + \"`\" ) buffer : torch . Tensor = getattr ( mod , buffer_name ) if buffer_name not in mod . _buffers : raise AttributeError ( \"`\" + buffer_name + \"` is not a buffer\" ) return buffer","title":"get_buffer"},{"location":"reference/wtracker/neural/mlp/#get_extra_state_1","text":"def get_extra_state ( self ) -> Any Return any extra state to include in the module's state_dict. Implement this and a corresponding :func: set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict() . Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: Type Description object Any extra state to store in the module's state_dict View Source def get_extra_state ( self ) -> Any : \" \"\" Return any extra state to include in the module's state_dict. Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`. Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: object: Any extra state to store in the module's state_dict \"\" \" raise RuntimeError ( \"Reached a code path in Module.get_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"get_extra_state"},{"location":"reference/wtracker/neural/mlp/#get_parameter_1","text":"def get_parameter ( self , target : str ) -> 'Parameter' Return the parameter given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.)
None Returns: Type Description torch.nn.Parameter The Parameter referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Parameter View Source def get_parameter ( self , target : str ) -> \"Parameter\" : \" \"\" Return the parameter given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the Parameter to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.nn.Parameter: The Parameter referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Parameter`` \"\" \" module_path , _ , param_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , param_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + param_name + \"`\" ) param : torch . nn . Parameter = getattr ( mod , param_name ) if not isinstance ( param , torch . nn . Parameter ) : raise AttributeError ( \"`\" + param_name + \"` is not an \" \"nn.Parameter\" ) return param","title":"get_parameter"},{"location":"reference/wtracker/neural/mlp/#get_submodule_1","text":"def get_submodule ( self , target : str ) -> 'Module' Return the submodule given by target if it exists, otherwise throw an error. For example, let's say you have an nn.Module A that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an nn.Module A . A has a nested submodule net_b , which itself has two submodules net_c and linear . net_c then has a submodule conv .) To check whether or not we have the linear submodule, we would call get_submodule(\"net_b.linear\") . To check whether we have the conv submodule, we would call get_submodule(\"net_b.net_c.conv\") . The runtime of get_submodule is bounded by the degree of module nesting in target . A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used. Parameters: Name Type Description Default target None The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Module The submodule referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Module View Source def get_submodule ( self , target : str ) -> \"Module\" : \" \"\" Return the submodule given by ``target`` if it exists, otherwise throw an error. For example, let's say you have an ``nn.Module`` ``A`` that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.) 
To check whether or not we have the ``linear`` submodule, we would call ``get_submodule(\" net_b . linear \")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule(\" net_b . net_c . conv \")``. The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used. Args: target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) Returns: torch.nn.Module: The submodule referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Module`` \"\" \" if target == \"\" : return self atoms : List [ str ] = target . split ( \".\" ) mod : torch . nn . Module = self for item in atoms : if not hasattr ( mod , item ) : raise AttributeError ( mod . _get_name () + \" has no \" \"attribute `\" + item + \"`\" ) mod = getattr ( mod , item ) if not isinstance ( mod , torch . nn . Module ) : raise AttributeError ( \"`\" + item + \"` is not \" \"an nn.Module\" ) return mod","title":"get_submodule"},{"location":"reference/wtracker/neural/mlp/#half_1","text":"def half ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to half datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def half ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``half`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . half () if t . is_floating_point () else t )","title":"half"},{"location":"reference/wtracker/neural/mlp/#ipu_1","text":"def ipu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def ipu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . ipu ( device ))","title":"ipu"},{"location":"reference/wtracker/neural/mlp/#load_state_dict_1","text":"def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) Copy parameters and buffers from :attr: state_dict into this module and its descendants. If :attr: strict is True , then the keys of :attr: state_dict must exactly match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. .. 
warning:: If :attr: assign is True the optimizer must be created after the call to :attr: load_state_dict unless :func: ~torch.__future__.get_swap_module_params_on_conversion is True . Parameters: Name Type Description Default state_dict dict a dict containing parameters and persistent buffers. None strict bool whether to strictly enforce that the keys in :attr: state_dict match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. Default: True None assign bool When False , the properties of the tensors in the current module are preserved while when True , the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of :class: ~torch.nn.Parameter s for which the value from the module is preserved. Default: False None Returns: Type Description None NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys unexpected_keys is a list of str containing the unexpected keys View Source def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) : r \"\"\"Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. \"\"\" if not isinstance ( state_dict , Mapping ) : raise TypeError ( f \"Expected state_dict to be dict-like, got {type(state_dict)}.\" ) missing_keys : List [ str ] = [] unexpected_keys : List [ str ] = [] error_msgs : List [ str ] = [] # copy state_dict so _ load_from_state_dict can modify it metadata = getattr ( state_dict , '_metadata' , None ) state_dict = OrderedDict ( state_dict ) if metadata is not None : # mypy isn't aware that \"_metadata\" exists in state_dict state_dict._metadata = metadata # type: ignore[attr-defined] def load(module, local_state_dict, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) if assign: local_metadata['assign_to_params_buffers'] = assign module._load_from_state_dict( local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name + ' . 
' child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} load(child, child_state_dict, child_prefix) # noqa: F821 # Note that the hook can modify missing_keys and unexpected_keys. incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) for hook in module._load_state_dict_post_hooks.values(): out = hook(module, incompatible_keys) assert out is None, ( \"Hooks registered with ``register_load_state_dict_post_hook`` are not\" \"expected to return new values, if incompatible_keys need to be modified,\" \"it should be done inplace.\" ) load(self, state_dict) del load if strict: if len(unexpected_keys) > 0: error_msgs.insert( 0, ' Unexpected key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in unexpected_keys))) if len(missing_keys) > 0: error_msgs.insert( 0, ' Missing key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in missing_keys))) if len(error_msgs) > 0: raise RuntimeError(' Error ( s ) in loading state_dict for {} : \\n\\t {} ' . format ( self . __ class__ . __ name__ , \"\\n\\t\" . join ( error_msgs ))) return _ IncompatibleKeys ( missing_keys , unexpected_keys )","title":"load_state_dict"},{"location":"reference/wtracker/neural/mlp/#modules_1","text":"def modules ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over all modules in the network. Yields: Type Description Module a module in the network View Source def modules ( self ) -> Iterator [ 'Module' ] : r \" \"\" Return an iterator over all modules in the network. Yields: Module: a module in the network Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True) \"\" \" for _ , module in self . named_modules () : yield module","title":"modules"},{"location":"reference/wtracker/neural/mlp/#named_buffers_1","text":"def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . Tensor ]] Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Parameters: Name Type Description Default prefix str prefix to prepend to all buffer names. None recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. None remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True Yields: Type Description None (str, torch.Tensor): Tuple containing the name and buffer View Source def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Tensor ]]: r \"\"\"Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Args: prefix (str): prefix to prepend to all buffer names. recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. 
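For instance (a sketch; BatchNorm is chosen here only because it registers buffers by default):
>>> import torch.nn as nn
>>> bn = nn.BatchNorm1d(3)
>>> [name for name, _ in bn.named_buffers()]  # buffer names, without prefix
['running_mean', 'running_var', 'num_batches_tracked']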
Yields: (str, torch.Tensor): Tuple containing the name and buffer Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size()) \"\"\" gen = self . _named_members ( lambda module : module . _buffers . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_buffers"},{"location":"reference/wtracker/neural/mlp/#named_children_1","text":"def named_children ( self ) -> Iterator [ Tuple [ str , ForwardRef ( 'Module' )]] Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: Type Description None (str, Module): Tuple containing a name and child module View Source def named_children ( self ) -> Iterator [ Tuple [ str , 'Module' ]]: r \"\"\"Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: (str, Module): Tuple containing a name and child module Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module) \"\"\" memo = set () for name , module in self . _modules . items (): if module is not None and module not in memo : memo . add ( module ) yield name , module","title":"named_children"},{"location":"reference/wtracker/neural/mlp/#named_modules_1","text":"def named_modules ( self , memo : Optional [ Set [ ForwardRef ( 'Module' )]] = None , prefix : str = '' , remove_duplicate : bool = True ) Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Parameters: Name Type Description Default memo None a memo to store the set of modules already added to the result None prefix None a prefix that will be added to the name of the module None remove_duplicate None whether to remove the duplicated module instances in the result or not None Yields: Type Description None (str, Module): Tuple of name and module View Source def named_modules ( self , memo : Optional [ Set [ 'Module' ]] = None , prefix : str = '' , remove_duplicate : bool = True ) : r \" \"\" Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: memo: a memo to store the set of modules already added to the result prefix: a prefix that will be added to the name of the module remove_duplicate: whether to remove the duplicated module instances in the result or not Yields: (str, Module): Tuple of name and module Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) \"\" \" if memo is None : memo = set () if self not in memo : if remove_duplicate : memo . add ( self ) yield prefix , self for name , module in self . _modules . items () : if module is None : continue submodule_prefix = prefix + ( '.' if prefix else '' ) + name yield from module . 
named_modules ( memo , submodule_prefix , remove_duplicate )","title":"named_modules"},{"location":"reference/wtracker/neural/mlp/#named_parameters_1","text":"def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . nn . parameter . Parameter ]] Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Parameters: Name Type Description Default prefix str prefix to prepend to all parameter names. None recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. None Yields: Type Description None (str, Parameter): Tuple containing the name and parameter View Source def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Parameter ]]: r \"\"\"Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True. Yields: (str, Parameter): Tuple containing the name and parameter Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) \"\"\" gen = self . _named_members ( lambda module : module . _parameters . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_parameters"},{"location":"reference/wtracker/neural/mlp/#parameters_1","text":"def parameters ( self , recurse : bool = True ) -> Iterator [ torch . nn . parameter . Parameter ] Return an iterator over module parameters. This is typically passed to an optimizer. Parameters: Name Type Description Default recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None Yields: Type Description Parameter module parameter View Source def parameters ( self , recurse : bool = True ) -> Iterator [ Parameter ] : r \"\"\"Return an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Yields: Parameter: module parameter Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for param in model.parameters(): >>> print(type(param), param.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for name , param in self . named_parameters ( recurse = recurse ) : yield param","title":"parameters"},{"location":"reference/wtracker/neural/mlp/#register_backward_hook_1","text":"def register_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]] ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. 
This function is deprecated in favor of :meth: ~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_backward_hook ( self , hook: Callable [[' Module ', _grad_t , _grad_t ], Union [ None , _grad_t ]] ) -> RemovableHandle: r \"\"\"Register a backward hook on the module. This function is deprecated in favor of : meth: ` ~ torch . nn . Module . register_full_backward_hook ` and the behavior of this function will change in future versions . Returns: : class: `torch . utils . hooks . RemovableHandle ` : a handle that can be used to remove the added hook by calling ` `handle . remove () `` \"\"\" if self . _is_full_backward_hook is True: raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = False handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook return handle","title":"register_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_buffer_1","text":"def register_buffer ( self , name : str , tensor : Optional [ torch . Tensor ], persistent : bool = True ) -> None Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr: persistent to False . The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr: state_dict . Buffers can be accessed as attributes using given names. Parameters: Name Type Description Default name str name of the buffer. The buffer can be accessed from this module using the given name None tensor Tensor or None buffer to be registered. If None , then operations that run on buffers, such as :attr: cuda , are ignored. If None , the buffer is not included in the module's :attr: state_dict . None persistent bool whether the buffer is part of this module's :attr: state_dict . None View Source def register_buffer ( self , name : str , tensor : Optional [ Tensor ] , persistent : bool = True ) -> None : r \" \"\" Add a buffer to the module. This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`.
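A small sketch of the persistent flag (the buffer names are arbitrary): a non-persistent buffer behaves identically at runtime but is skipped by state_dict:
>>> import torch
>>> import torch.nn as nn
>>> m = nn.Module()
>>> m.register_buffer('running_mean', torch.zeros(4))
>>> m.register_buffer('scratch', torch.zeros(4), persistent=False)  # excluded from state_dict
>>> list(m.state_dict().keys())
['running_mean']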
Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> self.register_buffer('running_mean', torch.zeros(num_features)) \"\" \" if persistent is False and isinstance ( self , torch . jit . ScriptModule ) : raise RuntimeError ( \"ScriptModule does not support non-persistent buffers\" ) if '_buffers' not in self . __dict__ : raise AttributeError ( \"cannot assign buffer before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"buffer name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"buffer name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"buffer name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _buffers : raise KeyError ( f \"attribute '{name}' already exists\" ) elif tensor is not None and not isinstance ( tensor , torch . Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self . _buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name )","title":"register_buffer"},{"location":"reference/wtracker/neural/mlp/#register_forward_hook_1","text":"def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward hook on the module. The hook will be called every time after :func: forward has computed an output. If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func: forward is called. The hook should have the following signature:: hook ( module , args , output ) -> None or modified output If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook ( module , args , kwargs , output ) -> None or modified output Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward hooks on this :class: torch.nn.modules.Module . Note that global forward hooks registered with :func: register_module_forward_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. 
Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ] , Any ] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ] , Any ] , Optional [ Any ]] , ] , * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature:: hook(module, args, output) -> None or modified output If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook(module, args, kwargs, output) -> None or modified output Args: hook (Callable): The user defined hook to be registered. prepend (bool): If ``True``, the provided ``hook`` will be fired before all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward`` hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` always_call (bool): If ``True`` the ``hook`` will be run regardless of whether an exception is raised while calling the Module. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_hooks , extra_dict = [ self . _forward_hooks_with_kwargs , self . _forward_hooks_always_called ] , ) self . _forward_hooks [ handle . id ] = hook if with_kwargs : self . _forward_hooks_with_kwargs [ handle . id ] = True if always_call : self . _forward_hooks_always_called [ handle . id ] = True if prepend : self . _forward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_hook"},{"location":"reference/wtracker/neural/mlp/#register_forward_pre_hook_1","text":"def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward pre-hook on the module. The hook will be called every time before :func: forward is invoked. If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input. 
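A minimal sketch of that behaviour (the doubling hook is purely illustrative): returning a tuple from the pre-hook replaces the positional arguments seen by forward:
>>> import torch
>>> import torch.nn as nn
>>> layer = nn.Linear(2, 2)
>>> def double_input(module, args):
...     return (2 * args[0],)  # replaces the positional arguments
>>> handle = layer.register_forward_pre_hook(double_input)
>>> x = torch.ones(1, 2)
>>> out = layer(x)
>>> handle.remove()
>>> torch.allclose(out, layer(2 * x))
True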
User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook ( module , args ) -> None or modified input If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook ( module , args , kwargs ) -> None or a tuple of modified input and kwargs Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward_pre hooks on this :class: torch.nn.modules.Module . Note that global forward_pre hooks registered with :func: register_module_forward_pre_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_pre_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ]] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ]] , Optional [ Tuple [ Any , Dict [ str , Any ]]]] , ] , * , prepend : bool = False , with_kwargs : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . 
id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_hook_1","text":"def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. 
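A brief sketch of when the hook fires and what it receives (the printing hook is illustrative only):
>>> import torch
>>> import torch.nn as nn
>>> layer = nn.Linear(2, 2)
>>> def report(module, grad_input, grad_output):
...     print(len(grad_input), len(grad_output))  # one positional input, one output
>>> handle = layer.register_full_backward_hook(report)
>>> layer(torch.ones(1, 2, requires_grad=True)).sum().backward()
1 1
>>> handle.remove()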
Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . _is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_pre_hook_1","text":"def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. 
None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_load_state_dict_post_hook_1","text":"def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. 
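For example (a sketch; the reporting hook is hypothetical):
>>> import torch.nn as nn
>>> m = nn.Linear(2, 2)
>>> def report(module, incompatible_keys):
...     print('missing:', incompatible_keys.missing_keys)
>>> _ = m.register_load_state_dict_post_hook(report)
>>> _ = m.load_state_dict(m.state_dict())  # round-trip load: nothing missing
missing: []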
It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . id ] = hook return handle","title":"register_load_state_dict_post_hook"},{"location":"reference/wtracker/neural/mlp/#register_module_1","text":"def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . View Source def register_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \" \"\" Alias for :func:`add_module`. \"\" \" self . add_module ( name , module )","title":"register_module"},{"location":"reference/wtracker/neural/mlp/#register_parameter_1","text":"def register_parameter ( self , name : str , param : Optional [ torch . nn . parameter . Parameter ] ) -> None Add a parameter to the module. The parameter can be accessed as an attribute using given name. Parameters: Name Type Description Default name str name of the parameter. The parameter can be accessed from this module using the given name None param Parameter or None parameter to be added to the module. If None , then operations that run on parameters, such as :attr: cuda , are ignored. If None , the parameter is not included in the module's :attr: state_dict . None View Source def register_parameter ( self , name : str , param : Optional [ Parameter ] ) -> None : r \" \"\" Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. \"\" \" if '_parameters' not in self . __dict__ : raise AttributeError ( \"cannot assign parameter before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"parameter name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"parameter name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"parameter name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _parameters : raise KeyError ( f \"attribute '{name}' already exists\" ) if param is None : self . 
_parameters [ name ] = None elif not isinstance ( param , Parameter ) : raise TypeError ( f \"cannot assign '{torch.typename(param)}' object to parameter '{name}' \" \"(torch.nn.Parameter or None required)\" ) elif param . grad_fn : raise ValueError ( f \"Cannot assign non-leaf Tensor to parameter '{name}'. Model \" f \"parameters must be created explicitly. To express '{name}' \" \"as a function of another Tensor, compute the value in \" \"the forward() method.\" ) else : for hook in _global_parameter_registration_hooks . values () : output = hook ( self , name , param ) if output is not None : param = output self . _parameters [ name ] = param","title":"register_parameter"},{"location":"reference/wtracker/neural/mlp/#register_state_dict_pre_hook_1","text":"def register_state_dict_pre_hook ( self , hook ) Register a pre-hook for the :meth: ~torch.nn.Module.state_dict method. These hooks will be called with arguments: self , prefix , and keep_vars before calling state_dict on self . The registered hooks can be used to perform pre-processing before the state_dict call is made. View Source def register_state_dict_pre_hook ( self , hook ) : r \" \"\" Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. \"\" \" handle = hooks . RemovableHandle ( self . _state_dict_pre_hooks ) self . _state_dict_pre_hooks [ handle . id ] = hook return handle","title":"register_state_dict_pre_hook"},{"location":"reference/wtracker/neural/mlp/#requires_grad__1","text":"def requires_grad_ ( self : ~ T , requires_grad : bool = True ) -> ~ T Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr: requires_grad attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref: locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it. Parameters: Name Type Description Default requires_grad bool whether autograd should record operations on parameters in this module. Default: True . None Returns: Type Description Module self View Source def requires_grad_ ( self : T , requires_grad : bool = True ) -> T : r \" \"\" Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it. Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self \"\" \" for p in self . parameters () : p . requires_grad_ ( requires_grad ) return self","title":"requires_grad_"},{"location":"reference/wtracker/neural/mlp/#set_extra_state_1","text":"def set_extra_state ( self , state : Any ) -> None Set extra state contained in the loaded state_dict . This function is called from :func: load_state_dict to handle any extra state found within the state_dict . 
Implement this function and a corresponding :func: get_extra_state for your module if you need to store extra state within its state_dict . View Source def set_extra_state ( self , state : Any ) -> None : \" \"\" Set extra state contained in the loaded `state_dict`. This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`. Args: state (dict): Extra state from the `state_dict` \"\" \" raise RuntimeError ( \"Reached a code path in Module.set_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"set_extra_state"},{"location":"reference/wtracker/neural/mlp/#share_memory_1","text":"def share_memory ( self : ~ T ) -> ~ T See :meth: torch.Tensor.share_memory_ . View Source def share_memory ( self : T ) -> T : r \"\"\"See :meth:`torch.Tensor.share_memory_`.\"\"\" return self . _apply ( lambda t : t . share_memory_ ())","title":"share_memory"},{"location":"reference/wtracker/neural/mlp/#state_dict_1","text":"def state_dict ( self , * args , destination = None , prefix = '' , keep_vars = False ) Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently state_dict() also accepts positional arguments for destination , prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument destination as it is not designed for end-users. Parameters: Name Type Description Default destination dict If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None . None prefix str a prefix added to parameter and buffer names to compose the keys in state_dict. Default: '' . None keep_vars bool by default the :class: ~torch.Tensor s returned in the state dict are detached from autograd. If it's set to True , detaching will not be performed. Default: False . None Returns: Type Description dict a dictionary containing a whole state of the module View Source def state_dict ( self , * args , destination = None , prefix='' , keep_vars = False ) : r \"\"\"Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars`` in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument ``destination`` as it is not designed for end-users. Args: destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``.
prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. Returns: dict: a dictionary containing a whole state of the module Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> module.state_dict().keys() ['bias', 'weight'] \"\"\" # TODO : Remove ` args ` and the parsing logic when BC allows . if len ( args ) > 0 : if destination is None : destination = args [ 0 ] if len ( args ) > 1 and prefix == '' : prefix = args [ 1 ] if len ( args ) > 2 and keep_vars is False : keep_vars = args [ 2 ] # DeprecationWarning is ignored by default warnings . warn ( \"Positional args are being deprecated, use kwargs instead. Refer to \" \"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict\" \" for details.\" ) if destination is None : destination = OrderedDict () destination . _metadata = OrderedDict () local_metadata = dict ( version = self . _version ) if hasattr ( destination , \"_metadata\" ) : destination . _metadata [ prefix [ :- 1 ]] = local_metadata for hook in self . _state_dict_pre_hooks . values () : hook ( self , prefix , keep_vars ) self . _save_to_state_dict ( destination , prefix , keep_vars ) for name , module in self . _modules . items () : if module is not None : module . state_dict ( destination = destination , prefix = prefix + name + '.' , keep_vars = keep_vars ) for hook in self . _state_dict_hooks . values () : hook_result = hook ( self , destination , prefix , local_metadata ) if hook_result is not None : destination = hook_result return destination","title":"state_dict"},{"location":"reference/wtracker/neural/mlp/#to_1","text":"def to ( self , * args , ** kwargs ) Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth: torch.Tensor.to , but only accepts floating point or complex :attr: dtype \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr: dtype (if given). The integral parameters and buffers will be moved :attr: device , if that is given, but with dtypes unchanged. When :attr: non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device ( None class: torch.device ): the desired device of the parameters and buffers in this module None dtype ( None class: torch.dtype ): the desired floating point or complex dtype of the parameters and buffers in this module None tensor torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module None memory_format ( None class: torch.memory_format ): the desired memory format for 4D parameters and buffers in this module (keyword only argument) None Returns: Type Description Module self View Source def to ( self , * args , ** kwargs ) : r \" \"\" Move and/or cast the parameters and buffers. This can be called as ..
function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype` \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self Examples:: >>> # xdoctest: +IGNORE_WANT(\" non - deterministic \") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device(\" cuda : 1 \") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device(\" cpu \") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) \"\" \" device , dtype , non_blocking , convert_to_format = torch . _C . _nn . _parse_to ( * args , ** kwargs ) if dtype is not None : if not ( dtype . is_floating_point or dtype . is_complex ) : raise TypeError ( 'nn.Module.to only accepts floating point or complex ' f 'dtypes, but got desired dtype={dtype}' ) if dtype . is_complex : warnings . warn ( \"Complex modules are a new feature under active development whose design may change, \" \"and some modules might not work as expected when using complex tensors as parameters or buffers. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"if a complex module does not work as expected.\" ) def convert ( t ) : try : if convert_to_format is not None and t . dim () in ( 4 , 5 ) : return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , memory_format = convert_to_format , ) return t . 
to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , ) except NotImplementedError as e : if str ( e ) == \"Cannot copy out of meta tensor; no data!\" : raise NotImplementedError ( f \"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() \" f \"when moving module from meta to a different device.\" ) from None else : raise return self . _apply ( convert )","title":"to"},{"location":"reference/wtracker/neural/mlp/#to_empty_1","text":"def to_empty ( self : ~ T , * , device : Union [ int , str , torch . device , NoneType ], recurse : bool = True ) -> ~ T Move the parameters and buffers to the specified device without copying storage. Parameters: Name Type Description Default device ( None class: torch.device ): The desired device of the parameters and buffers in this module. None recurse bool Whether parameters and buffers of submodules should be recursively moved to the specified device. None Returns: Type Description Module self View Source def to_empty ( self : T , * , device : Optional [ DeviceLikeType ] , recurse : bool = True ) -> T : r \"\"\"Move the parameters and buffers to the specified device without copying storage. Args: device (:class:`torch.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. Returns: Module: self \"\"\" return self . _apply ( lambda t : torch . empty_like ( t , device = device ), recurse = recurse )","title":"to_empty"},{"location":"reference/wtracker/neural/mlp/#train_1","text":"def train ( self : ~ T , mode : bool = True ) -> ~ T Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. Parameters: Name Type Description Default mode bool whether to set training mode ( True ) or evaluation mode ( False ). Default: True . None Returns: Type Description Module self View Source def train ( self : T , mode : bool = True ) -> T : r \" \"\" Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self \"\" \" if not isinstance ( mode , bool ) : raise ValueError ( \"training mode is expected to be boolean\" ) self . training = mode for module in self . children () : module . train ( mode ) return self","title":"train"},{"location":"reference/wtracker/neural/mlp/#type_1","text":"def type ( self : ~ T , dst_type : Union [ torch . dtype , str ] ) -> ~ T Casts all parameters and buffers to :attr: dst_type . .. note:: This method modifies the module in-place. Parameters: Name Type Description Default dst_type type or string the desired type None Returns: Type Description Module self View Source def type ( self : T , dst_type : Union [ dtype , str ] ) -> T : r \" \"\" Casts all parameters and buffers to :attr:`dst_type`. .. note:: This method modifies the module in-place. Args: dst_type (type or string): the desired type Returns: Module: self \"\" \" return self . _apply ( lambda t : t . 
type ( dst_type ))","title":"type"},{"location":"reference/wtracker/neural/mlp/#xpu_1","text":"def xpu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def xpu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . xpu ( device ))","title":"xpu"},{"location":"reference/wtracker/neural/mlp/#zero_grad_1","text":"def zero_grad ( self , set_to_none : bool = True ) -> None Reset gradients of all model parameters. See similar function under :class: torch.optim.Optimizer for more context. Parameters: Name Type Description Default set_to_none bool instead of setting to zero, set the grads to None. See :meth: torch.optim.Optimizer.zero_grad for details. None View Source def zero_grad ( self , set_to_none : bool = True ) -> None : r \"\"\"Reset gradients of all model parameters. See similar function under :class:`torch.optim.Optimizer` for more context. Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. \"\"\" if getattr ( self , '_is_replica' , False ) : warnings . warn ( \"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. \" \"The parameters are copied (in a differentiable manner) from the original module. \" \"This means they are not leaf nodes in autograd and so don' t accumulate gradients . \" \" If you need gradients in your forward method , consider using autograd . grad instead . \") for p in self.parameters(): if p.grad is not None: if set_to_none: p.grad = None else: if p.grad.grad_fn is not None: p.grad.detach_() else: p.grad.requires_grad_(False) p.grad.zero_()","title":"zero_grad"},{"location":"reference/wtracker/neural/mlp/#rmlp","text":"class RMLP ( block_in_dim : int , block_dims : Sequence [ int ], block_nonlins : Sequence [ Union [ str , torch . nn . modules . module . Module ]], n_blocks : int , out_dim : int , in_dim : int = None , batch_norm : bool = True ) Base class for all neural network modules. Your models should also subclass this class. Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:: import torch.nn as nn import torch.nn.functional as F class Model ( nn . Module ): def __init__ ( self ): super () . __init__ () self . conv1 = nn . Conv2d ( 1 , 20 , 5 ) self . conv2 = nn . Conv2d ( 20 , 20 , 5 ) def forward ( self , x ): x = F . relu ( self . conv1 ( x )) return F . relu ( self . conv2 ( x )) Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth: to , etc. .. 
note:: As per the example above, an __init__() call to the parent class must be made before assignment on the child. View Source class RMLP ( nn . Module ) : def __init__ ( self , block_in_dim : int , block_dims : Sequence [ int ] , block_nonlins : Sequence [ Union[str, nn.Module ] ] , n_blocks : int , out_dim : int , in_dim : int = None , # if in_dim is an int , then a first layer will be made batch_norm : bool = True , ) -> None : super (). __init__ () # Create first layer if in_dim is not None self . input = nn . Identity () if in_dim is not None : self . input = MLPLayer ( in_dim , block_in_dim , block_nonlins [ 0 ] , batch_norm ) # Create blocks layers = [] for i in range ( n_blocks ) : layers . append ( MlpBlock ( block_in_dim , block_dims , block_nonlins , batch_norm )) self . blocks = nn . ModuleList ( layers ) # Create output layer self . output = nn . Linear ( block_dims [ -1 ] , out_dim ) def _make_activation ( self , act : Union [ str, nn.Module ] ) -> nn . Module : if isinstance ( act , str ) : return ACTIVATIONS [ act ] ( ** ACTIVATION_DEFAULT_KWARGS [ act ] ) return act def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" x = self . input ( x ) for block in self . blocks : out = block ( x ) x = x + out return self . output ( x )","title":"RMLP"},{"location":"reference/wtracker/neural/mlp/#ancestors-in-mro_2","text":"torch.nn.modules.module.Module","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/mlp/#class-variables_2","text":"T_destination call_super_init dump_patches","title":"Class variables"},{"location":"reference/wtracker/neural/mlp/#methods_2","text":"","title":"Methods"},{"location":"reference/wtracker/neural/mlp/#add_module_2","text":"def add_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Add a child module to the current module. The module can be accessed as an attribute using the given name. Parameters: Name Type Description Default name str name of the child module. The child module can be accessed from this module using the given name None module Module child module to be added to the module. None View Source def add_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \"\"\"Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (str): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module. \"\"\" if not isinstance ( module , Module ) and module is not None : raise TypeError ( f \"{torch.typename(module)} is not a Module subclass\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"module name should be a string. Got {torch.typename(name)}\" ) elif hasattr ( self , name ) and name not in self . _modules : raise KeyError ( f \"attribute '{name}' already exists\" ) elif '.' in name : raise KeyError ( f \"module name can't contain \\\" . \\ \", got: {name}\" ) elif name == '' : raise KeyError ( \"module name can't be empty string \\\" \\ \"\" ) for hook in _global_module_registration_hooks . values () : output = hook ( self , name , module ) if output is not None : module = output self . 
_modules [ name ] = module","title":"add_module"},{"location":"reference/wtracker/neural/mlp/#apply_2","text":"def apply ( self : ~ T , fn : Callable [[ ForwardRef ( 'Module' )], NoneType ] ) -> ~ T Apply fn recursively to every submodule (as returned by .children() ) as well as self. Typical use includes initializing the parameters of a model (see also :ref: nn-init-doc ). Parameters: Name Type Description Default fn ( None class: Module -> None): function to be applied to each submodule None Returns: Type Description Module self View Source def apply ( self : T , fn : Callable [[ 'Module' ] , None ] ) -> T : r \" \"\" Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example:: >>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) \"\" \" for module in self . children () : module . apply ( fn ) fn ( self ) return self","title":"apply"},{"location":"reference/wtracker/neural/mlp/#bfloat16_2","text":"def bfloat16 ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to bfloat16 datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def bfloat16 ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``bfloat16`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . bfloat16 () if t . is_floating_point () else t )","title":"bfloat16"},{"location":"reference/wtracker/neural/mlp/#buffers_2","text":"def buffers ( self , recurse : bool = True ) -> Iterator [ torch . Tensor ] Return an iterator over module buffers. Parameters: Name Type Description Default recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. None Yields: Type Description torch.Tensor module buffer View Source def buffers ( self , recurse : bool = True ) -> Iterator [ Tensor ] : r \"\"\"Return an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: torch.Tensor: module buffer Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for buf in model.buffers(): >>> print(type(buf), buf.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for _ , buf in self . named_buffers ( recurse = recurse ) : yield buf","title":"buffers"},{"location":"reference/wtracker/neural/mlp/#children_2","text":"def children ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over immediate children modules. Yields: Type Description Module a child module View Source def children ( self ) -> Iterator [ 'Module']: r \"\"\"Return an iterator over immediate children modules. 
Yields: Module: a child module \"\"\" for name , module in self . named_children (): yield module","title":"children"},{"location":"reference/wtracker/neural/mlp/#compile_2","text":"def compile ( self , * args , ** kwargs ) Compile this Module's forward using :func: torch.compile . This Module's __call__ method is compiled and all arguments are passed as-is to :func: torch.compile . See :func: torch.compile for details on the arguments for this function. View Source def compile ( self , * args , ** kwargs ) : \" \"\" Compile this Module's forward using :func:`torch.compile`. This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`. See :func:`torch.compile` for details on the arguments for this function. \"\" \" self . _compiled_call_impl = torch . compile ( self . _call_impl , * args , ** kwargs )","title":"compile"},{"location":"reference/wtracker/neural/mlp/#cpu_2","text":"def cpu ( self : ~ T ) -> ~ T Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def cpu ( self : T ) -> T : r \"\"\"Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cpu ())","title":"cpu"},{"location":"reference/wtracker/neural/mlp/#cuda_2","text":"def cuda ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def cuda ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cuda ( device ))","title":"cuda"},{"location":"reference/wtracker/neural/mlp/#double_2","text":"def double ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to double datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def double ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``double`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . double () if t . is_floating_point () else t )","title":"double"},{"location":"reference/wtracker/neural/mlp/#eval_2","text":"def eval ( self : ~ T ) -> ~ T Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. This is equivalent with :meth: self.train(False) . 
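Example (a minimal train/eval sketch, illustrative only and not from the wtracker source; the Dropout toy network is an assumption)::
>>> import torch
>>> import torch.nn as nn
>>> net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
>>> net = net.eval()  # disable dropout; equivalent to net.train(False)
>>> x = torch.ones(1, 4)
>>> with torch.no_grad():
...     same = torch.equal(net(x), net(x))
...
>>> same  # eval mode makes the two passes deterministic
True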
See :ref: locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it. Returns: Type Description Module self View Source def eval ( self : T ) -> T : r \" \"\" Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent with :meth:`self.train(False) `. See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it. Returns: Module: self \"\" \" return self . train ( False )","title":"eval"},{"location":"reference/wtracker/neural/mlp/#extra_repr_2","text":"def extra_repr ( self ) -> str Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. View Source def extra_repr ( self ) -> str : r \"\"\"Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. \"\"\" return ''","title":"extra_repr"},{"location":"reference/wtracker/neural/mlp/#float_2","text":"def float ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to float datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def float ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``float`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . float () if t . is_floating_point () else t )","title":"float"},{"location":"reference/wtracker/neural/mlp/#forward_2","text":"def forward ( self , x : torch . Tensor ) -> torch . Tensor Parameters: Name Type Description Default x None An input tensor, of shape (N, D) containing N samples with D features. None Returns: Type Description None An output tensor of shape (N, D_out) where D_out is the output dim. View Source def forward ( self , x : Tensor ) -> Tensor : \"\"\" Args: x: An input tensor, of shape (N, D) containing N samples with D features. Returns: An output tensor of shape (N, D_out) where D_out is the output dim. \"\"\" x = self . input ( x ) for block in self . blocks : out = block ( x ) x = x + out return self . output ( x )","title":"forward"},{"location":"reference/wtracker/neural/mlp/#get_buffer_2","text":"def get_buffer ( self , target : str ) -> 'Tensor' Return the buffer given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.Tensor The buffer referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not a buffer View Source def get_buffer ( self , target : str ) -> \"Tensor\" : \" \"\" Return the buffer given by ``target`` if it exists, otherwise throw an error. 
See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the buffer to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.Tensor: The buffer referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not a buffer \"\" \" module_path , _ , buffer_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , buffer_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + buffer_name + \"`\" ) buffer : torch . Tensor = getattr ( mod , buffer_name ) if buffer_name not in mod . _buffers : raise AttributeError ( \"`\" + buffer_name + \"` is not a buffer\" ) return buffer","title":"get_buffer"},{"location":"reference/wtracker/neural/mlp/#get_extra_state_2","text":"def get_extra_state ( self ) -> Any Return any extra state to include in the module's state_dict. Implement this and a corresponding :func: set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict() . Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: Type Description object Any extra state to store in the module's state_dict View Source def get_extra_state ( self ) -> Any : \" \"\" Return any extra state to include in the module's state_dict. Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`. Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: object: Any extra state to store in the module's state_dict \"\" \" raise RuntimeError ( \"Reached a code path in Module.get_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"get_extra_state"},{"location":"reference/wtracker/neural/mlp/#get_parameter_2","text":"def get_parameter ( self , target : str ) -> 'Parameter' Return the parameter given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Parameter The Parameter referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Parameter View Source def get_parameter ( self , target : str ) -> \"Parameter\" : \" \"\" Return the parameter given by ``target`` if it exists, otherwise throw an error.
See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the Parameter to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.nn.Parameter: The Parameter referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Parameter`` \"\" \" module_path , _ , param_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , param_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + param_name + \"`\" ) param : torch . nn . Parameter = getattr ( mod , param_name ) if not isinstance ( param , torch . nn . Parameter ) : raise AttributeError ( \"`\" + param_name + \"` is not an \" \"nn.Parameter\" ) return param","title":"get_parameter"},{"location":"reference/wtracker/neural/mlp/#get_submodule_2","text":"def get_submodule ( self , target : str ) -> 'Module' Return the submodule given by target if it exists, otherwise throw an error. For example, let's say you have an nn.Module A that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an nn.Module A . A has a nested submodule net_b , which itself has two submodules net_c and linear . net_c then has a submodule conv .) To check whether or not we have the linear submodule, we would call get_submodule(\"net_b.linear\") . To check whether we have the conv submodule, we would call get_submodule(\"net_b.net_c.conv\") . The runtime of get_submodule is bounded by the degree of module nesting in target . A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used. Parameters: Name Type Description Default target None The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Module The submodule referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Module View Source def get_submodule ( self , target : str ) -> \"Module\" : \" \"\" Return the submodule given by ``target`` if it exists, otherwise throw an error. For example, let's say you have an ``nn.Module`` ``A`` that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.) To check whether or not we have the ``linear`` submodule, we would call ``get_submodule(\" net_b . linear \")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule(\" net_b . net_c . conv \")``. The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. 
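Example (a short sketch of qualified-name lookup, illustrative only; the A/net_b/linear layout simply mirrors the diagram above)::
>>> import torch.nn as nn
>>> class A(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.net_b = nn.Module()
...         self.net_b.linear = nn.Linear(100, 200)
...
>>> a = A()
>>> a.get_submodule('net_b.linear')
Linear(in_features=100, out_features=200, bias=True)
>>> a.get_parameter('net_b.linear.weight').shape
torch.Size([200, 100])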
So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used. Args: target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) Returns: torch.nn.Module: The submodule referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Module`` \"\" \" if target == \"\" : return self atoms : List [ str ] = target . split ( \".\" ) mod : torch . nn . Module = self for item in atoms : if not hasattr ( mod , item ) : raise AttributeError ( mod . _get_name () + \" has no \" \"attribute `\" + item + \"`\" ) mod = getattr ( mod , item ) if not isinstance ( mod , torch . nn . Module ) : raise AttributeError ( \"`\" + item + \"` is not \" \"an nn.Module\" ) return mod","title":"get_submodule"},{"location":"reference/wtracker/neural/mlp/#half_2","text":"def half ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to half datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def half ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``half`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . half () if t . is_floating_point () else t )","title":"half"},{"location":"reference/wtracker/neural/mlp/#ipu_2","text":"def ipu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def ipu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . ipu ( device ))","title":"ipu"},{"location":"reference/wtracker/neural/mlp/#load_state_dict_2","text":"def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) Copy parameters and buffers from :attr: state_dict into this module and its descendants. If :attr: strict is True , then the keys of :attr: state_dict must exactly match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. .. warning:: If :attr: assign is True the optimizer must be created after the call to :attr: load_state_dict unless :func: ~torch.__future__.get_swap_module_params_on_conversion is True . Parameters: Name Type Description Default state_dict dict a dict containing parameters and persistent buffers. None strict bool whether to strictly enforce that the keys in :attr: state_dict match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. 
Default: True None assign bool When False , the properties of the tensors in the current module are preserved while when True , the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of :class: ~torch.nn.Parameter s for which the value from the module is preserved. Default: False None Returns: Type Description None NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys unexpected_keys is a list of str containing the unexpected keys View Source def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) : r \"\"\"Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. \"\"\" if not isinstance ( state_dict , Mapping ) : raise TypeError ( f \"Expected state_dict to be dict-like, got {type(state_dict)}.\" ) missing_keys : List [ str ] = [] unexpected_keys : List [ str ] = [] error_msgs : List [ str ] = [] # copy state_dict so _ load_from_state_dict can modify it metadata = getattr ( state_dict , '_metadata' , None ) state_dict = OrderedDict ( state_dict ) if metadata is not None : # mypy isn't aware that \"_metadata\" exists in state_dict state_dict._metadata = metadata # type: ignore[attr-defined] def load(module, local_state_dict, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) if assign: local_metadata['assign_to_params_buffers'] = assign module._load_from_state_dict( local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name + ' . ' child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} load(child, child_state_dict, child_prefix) # noqa: F821 # Note that the hook can modify missing_keys and unexpected_keys. 
incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) for hook in module._load_state_dict_post_hooks.values(): out = hook(module, incompatible_keys) assert out is None, ( \"Hooks registered with ``register_load_state_dict_post_hook`` are not\" \"expected to return new values, if incompatible_keys need to be modified,\" \"it should be done inplace.\" ) load(self, state_dict) del load if strict: if len(unexpected_keys) > 0: error_msgs.insert( 0, ' Unexpected key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in unexpected_keys))) if len(missing_keys) > 0: error_msgs.insert( 0, ' Missing key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in missing_keys))) if len(error_msgs) > 0: raise RuntimeError(' Error ( s ) in loading state_dict for {} : \\n\\t {} ' . format ( self . __ class__ . __ name__ , \"\\n\\t\" . join ( error_msgs ))) return _ IncompatibleKeys ( missing_keys , unexpected_keys )","title":"load_state_dict"},{"location":"reference/wtracker/neural/mlp/#modules_2","text":"def modules ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over all modules in the network. Yields: Type Description Module a module in the network View Source def modules ( self ) -> Iterator [ 'Module' ] : r \" \"\" Return an iterator over all modules in the network. Yields: Module: a module in the network Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True) \"\" \" for _ , module in self . named_modules () : yield module","title":"modules"},{"location":"reference/wtracker/neural/mlp/#named_buffers_2","text":"def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . Tensor ]] Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Parameters: Name Type Description Default prefix str prefix to prepend to all buffer names. None recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. None remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True Yields: Type Description None (str, torch.Tensor): Tuple containing the name and buffer View Source def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Tensor ]]: r \"\"\"Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Args: prefix (str): prefix to prepend to all buffer names. recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. Yields: (str, torch.Tensor): Tuple containing the name and buffer Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size()) \"\"\" gen = self . 
_named_members ( lambda module : module . _buffers . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_buffers"},{"location":"reference/wtracker/neural/mlp/#named_children_2","text":"def named_children ( self ) -> Iterator [ Tuple [ str , ForwardRef ( 'Module' )]] Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: Type Description None (str, Module): Tuple containing a name and child module View Source def named_children ( self ) -> Iterator [ Tuple [ str , 'Module' ]]: r \"\"\"Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: (str, Module): Tuple containing a name and child module Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module) \"\"\" memo = set () for name , module in self . _modules . items (): if module is not None and module not in memo : memo . add ( module ) yield name , module","title":"named_children"},{"location":"reference/wtracker/neural/mlp/#named_modules_2","text":"def named_modules ( self , memo : Optional [ Set [ ForwardRef ( 'Module' )]] = None , prefix : str = '' , remove_duplicate : bool = True ) Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Parameters: Name Type Description Default memo None a memo to store the set of modules already added to the result None prefix None a prefix that will be added to the name of the module None remove_duplicate None whether to remove the duplicated module instances in the result or not None Yields: Type Description None (str, Module): Tuple of name and module View Source def named_modules ( self , memo : Optional [ Set [ 'Module' ]] = None , prefix : str = '' , remove_duplicate : bool = True ) : r \" \"\" Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: memo: a memo to store the set of modules already added to the result prefix: a prefix that will be added to the name of the module remove_duplicate: whether to remove the duplicated module instances in the result or not Yields: (str, Module): Tuple of name and module Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) \"\" \" if memo is None : memo = set () if self not in memo : if remove_duplicate : memo . add ( self ) yield prefix , self for name , module in self . _modules . items () : if module is None : continue submodule_prefix = prefix + ( '.' if prefix else '' ) + name yield from module . named_modules ( memo , submodule_prefix , remove_duplicate )","title":"named_modules"},{"location":"reference/wtracker/neural/mlp/#named_parameters_2","text":"def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . nn . parameter . Parameter ]] Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. 
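Example (a small sketch of splitting parameters by name, e.g. for per-group optimizer settings; the weight/bias split here is an assumption for illustration, not wtracker policy)::
>>> import torch.nn as nn
>>> net = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
>>> decay = [p for n, p in net.named_parameters() if n.endswith('weight')]
>>> no_decay = [p for n, p in net.named_parameters() if n.endswith('bias')]
>>> len(decay), len(no_decay)
(2, 2)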
Parameters: Name Type Description Default prefix str prefix to prepend to all parameter names. None recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. None Yields: Type Description None (str, Parameter): Tuple containing the name and parameter View Source def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Parameter ]]: r \"\"\"Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True. Yields: (str, Parameter): Tuple containing the name and parameter Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) \"\"\" gen = self . _named_members ( lambda module : module . _parameters . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_parameters"},{"location":"reference/wtracker/neural/mlp/#parameters_2","text":"def parameters ( self , recurse : bool = True ) -> Iterator [ torch . nn . parameter . Parameter ] Return an iterator over module parameters. This is typically passed to an optimizer. Parameters: Name Type Description Default recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None Yields: Type Description Parameter module parameter View Source def parameters ( self , recurse : bool = True ) -> Iterator [ Parameter ] : r \"\"\"Return an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Yields: Parameter: module parameter Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for param in model.parameters(): >>> print(type(param), param.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for name , param in self . named_parameters ( recurse = recurse ) : yield param","title":"parameters"},{"location":"reference/wtracker/neural/mlp/#register_backward_hook_2","text":"def register_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]] ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. This function is deprecated in favor of :meth: ~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions. 
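Because of the deprecation above, a sketch of the replacement, ``register_full_backward_hook``, may be more useful (illustrative only; the gradient-recording hook is an assumption)::
>>> import torch
>>> import torch.nn as nn
>>> net = nn.Linear(3, 1)
>>> grads = []
>>> def hook(module, grad_input, grad_output):
...     grads.append(grad_output[0].detach().clone())
...
>>> handle = net.register_full_backward_hook(hook)
>>> net(torch.ones(2, 3)).sum().backward()
>>> len(grads)
1
>>> handle.remove()  # always detach hooks when done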
Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_backward_hook ( self , hook: Callable [[' Module ', _grad_t , _grad_t ], Union [ None , _grad_t ]] ) -> RemovableHandle: r \"\"\"Register a backward hook on the module. This function is deprecated in favor of : meth: ` ~ torch . nn . Module . register_full_backward_hook ` and the behavior of this function will change in future versions . Returns: : class: `torch . utils . hooks . RemovableHandle ` : a handle that can be used to remove the added hook by calling ` `handle . remove () `` \"\"\" if self . _is_full_backward_hook is True: raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = False handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook return handle","title":"register_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_buffer_2","text":"def register_buffer ( self , name : str , tensor : Optional [ torch . Tensor ], persistent : bool = True ) -> None Add a buffer to the module. This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr: persistent to False . The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr: state_dict . Buffers can be accessed as attributes using given names. Parameters: Name Type Description Default name str name of the buffer. The buffer can be accessed from this module using the given name None tensor Tensor or None buffer to be registered. If None , then operations that run on buffers, such as :attr: cuda , are ignored. If None , the buffer is not included in the module's :attr: state_dict . None persistent bool whether the buffer is part of this module's :attr: state_dict . None View Source def register_buffer ( self , name : str , tensor : Optional [ Tensor ] , persistent : bool = True ) -> None : r \" \"\" Add a buffer to the module. This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`. Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> self.register_buffer('running_mean', torch.zeros(num_features)) \"\" \" if persistent is False and isinstance ( self , torch . jit . 
ScriptModule ) : raise RuntimeError ( \"ScriptModule does not support non-persistent buffers\" ) if '_buffers' not in self . __dict__ : raise AttributeError ( \"cannot assign buffer before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"buffer name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"buffer name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"buffer name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _buffers : raise KeyError ( f \"attribute '{name}' already exists\" ) elif tensor is not None and not isinstance ( tensor , torch . Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self . _buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name )","title":"register_buffer"},{"location":"reference/wtracker/neural/mlp/#register_forward_hook_2","text":"def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward hook on the module. The hook will be called every time after :func: forward has computed an output. If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func: forward is called. The hook should have the following signature:: hook ( module , args , output ) -> None or modified output If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook ( module , args , kwargs , output ) -> None or modified output Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward hooks on this :class: torch.nn.modules.Module . Note that global forward hooks registered with :func: register_module_forward_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ] , Any ] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... 
] , Dict [ str , Any ] , Any ] , Optional [ Any ]] , ] , * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature:: hook(module, args, output) -> None or modified output If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook(module, args, kwargs, output) -> None or modified output Args: hook (Callable): The user defined hook to be registered. prepend (bool): If ``True``, the provided ``hook`` will be fired before all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward`` hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` always_call (bool): If ``True`` the ``hook`` will be run regardless of whether an exception is raised while calling the Module. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_hooks , extra_dict = [ self . _forward_hooks_with_kwargs , self . _forward_hooks_always_called ] , ) self . _forward_hooks [ handle . id ] = hook if with_kwargs : self . _forward_hooks_with_kwargs [ handle . id ] = True if always_call : self . _forward_hooks_always_called [ handle . id ] = True if prepend : self . _forward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_hook"},{"location":"reference/wtracker/neural/mlp/#register_forward_pre_hook_2","text":"def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward pre-hook on the module. The hook will be called every time before :func: forward is invoked. If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook ( module , args ) -> None or modified input If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. 
And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook ( module , args , kwargs ) -> None or a tuple of modified input and kwargs Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward_pre hooks on this :class: torch.nn.modules.Module . Note that global forward_pre hooks registered with :func: register_module_forward_pre_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_pre_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ]] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ]] , Optional [ Tuple [ Any , Dict [ str , Any ]]]] , ] , * , prepend : bool = False , with_kwargs : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . 
id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_hook_2","text":"def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. 
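Example (a minimal backward pre-hook sketch, illustrative only; scaling the incoming gradient is an assumption, chosen just to show the return convention)::
>>> import torch
>>> import torch.nn as nn
>>> net = nn.Linear(3, 1)
>>> def pre_hook(module, grad_output):
...     return tuple(g * 0.5 for g in grad_output)  # replaces grad_output downstream
...
>>> handle = net.register_full_backward_pre_hook(pre_hook)
>>> net(torch.ones(1, 3)).sum().backward()
>>> handle.remove()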
For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . _is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_pre_hook_2","text":"def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. 
None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_load_state_dict_post_hook_2","text":"def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. 
It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . id ] = hook return handle","title":"register_load_state_dict_post_hook"},{"location":"reference/wtracker/neural/mlp/#register_module_2","text":"def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . View Source def register_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \" \"\" Alias for :func:`add_module`. \"\" \" self . add_module ( name , module )","title":"register_module"},{"location":"reference/wtracker/neural/mlp/#register_parameter_2","text":"def register_parameter ( self , name : str , param : Optional [ torch . nn . parameter . Parameter ] ) -> None Add a parameter to the module. The parameter can be accessed as an attribute using given name. Parameters: Name Type Description Default name str name of the parameter. The parameter can be accessed from this module using the given name None param Parameter or None parameter to be added to the module. If None , then operations that run on parameters, such as :attr: cuda , are ignored. If None , the parameter is not included in the module's :attr: state_dict . None View Source def register_parameter ( self , name : str , param : Optional [ Parameter ] ) -> None : r \" \"\" Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. \"\" \" if '_parameters' not in self . __dict__ : raise AttributeError ( \"cannot assign parameter before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"parameter name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"parameter name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"parameter name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _parameters : raise KeyError ( f \"attribute '{name}' already exists\" ) if param is None : self . 
_parameters [ name ] = None elif not isinstance ( param , Parameter ) : raise TypeError ( f \"cannot assign '{torch.typename(param)}' object to parameter '{name}' \" \"(torch.nn.Parameter or None required)\" ) elif param . grad_fn : raise ValueError ( f \"Cannot assign non-leaf Tensor to parameter '{name}'. Model \" f \"parameters must be created explicitly. To express '{name}' \" \"as a function of another Tensor, compute the value in \" \"the forward() method.\" ) else : for hook in _global_parameter_registration_hooks . values () : output = hook ( self , name , param ) if output is not None : param = output self . _parameters [ name ] = param","title":"register_parameter"},{"location":"reference/wtracker/neural/mlp/#register_state_dict_pre_hook_2","text":"def register_state_dict_pre_hook ( self , hook ) Register a pre-hook for the :meth: ~torch.nn.Module.state_dict method. These hooks will be called with arguments: self , prefix , and keep_vars before calling state_dict on self . The registered hooks can be used to perform pre-processing before the state_dict call is made. View Source def register_state_dict_pre_hook ( self , hook ) : r \" \"\" Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. \"\" \" handle = hooks . RemovableHandle ( self . _state_dict_pre_hooks ) self . _state_dict_pre_hooks [ handle . id ] = hook return handle","title":"register_state_dict_pre_hook"},{"location":"reference/wtracker/neural/mlp/#requires_grad__2","text":"def requires_grad_ ( self : ~ T , requires_grad : bool = True ) -> ~ T Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr: requires_grad attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref: locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it. Parameters: Name Type Description Default requires_grad bool whether autograd should record operations on parameters in this module. Default: True . None Returns: Type Description Module self View Source def requires_grad_ ( self : T , requires_grad : bool = True ) -> T : r \" \"\" Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it. Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self \"\" \" for p in self . parameters () : p . requires_grad_ ( requires_grad ) return self","title":"requires_grad_"},{"location":"reference/wtracker/neural/mlp/#set_extra_state_2","text":"def set_extra_state ( self , state : Any ) -> None Set extra state contained in the loaded state_dict . This function is called from :func: load_state_dict to handle any extra state found within the state_dict . 
Implement this function and a corresponding :func: get_extra_state for your module if you need to store extra state within its state_dict . View Source def set_extra_state ( self , state : Any ) -> None : \" \"\" Set extra state contained in the loaded `state_dict`. This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`. Args: state (dict): Extra state from the `state_dict` \"\" \" raise RuntimeError ( \"Reached a code path in Module.set_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"set_extra_state"},{"location":"reference/wtracker/neural/mlp/#share_memory_2","text":"def share_memory ( self : ~ T ) -> ~ T See :meth: torch.Tensor.share_memory_ . View Source def share_memory ( self : T ) -> T : r \"\"\"See :meth:`torch.Tensor.share_memory_`.\"\"\" return self . _apply ( lambda t : t . share_memory_ ())","title":"share_memory"},{"location":"reference/wtracker/neural/mlp/#state_dict_2","text":"def state_dict ( self , * args , destination = None , prefix = '' , keep_vars = False ) Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently state_dict() also accepts positional arguments for destination , prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument destination as it is not designed for end-users. Parameters: Name Type Description Default destination dict If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None . None prefix str a prefix added to parameter and buffer names to compose the keys in state_dict. Default: '' . None keep_vars bool by default the :class: ~torch.Tensor s returned in the state dict are detached from autograd. If it's set to True , detaching will not be performed. Default: False . None Returns: Type Description dict a dictionary containing a whole state of the module View Source def state_dict ( self , * args , destination = None , prefix = '' , keep_vars = False ) : r \"\"\"Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars`` in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument ``destination`` as it is not designed for end-users. Args: destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``.
prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. Returns: dict: a dictionary containing a whole state of the module Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> module.state_dict().keys() ['bias', 'weight'] \"\"\" # TODO: Remove `args` and the parsing logic when BC allows. if len ( args ) > 0 : if destination is None : destination = args [ 0 ] if len ( args ) > 1 and prefix == '' : prefix = args [ 1 ] if len ( args ) > 2 and keep_vars is False : keep_vars = args [ 2 ] # DeprecationWarning is ignored by default warnings . warn ( \"Positional args are being deprecated, use kwargs instead. Refer to \" \"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict\" \" for details.\" ) if destination is None : destination = OrderedDict () destination . _metadata = OrderedDict () local_metadata = dict ( version = self . _version ) if hasattr ( destination , \"_metadata\" ) : destination . _metadata [ prefix [ :- 1 ]] = local_metadata for hook in self . _state_dict_pre_hooks . values () : hook ( self , prefix , keep_vars ) self . _save_to_state_dict ( destination , prefix , keep_vars ) for name , module in self . _modules . items () : if module is not None : module . state_dict ( destination = destination , prefix = prefix + name + '.' , keep_vars = keep_vars ) for hook in self . _state_dict_hooks . values () : hook_result = hook ( self , destination , prefix , local_metadata ) if hook_result is not None : destination = hook_result return destination","title":"state_dict"},{"location":"reference/wtracker/neural/mlp/#to_2","text":"def to ( self , * args , ** kwargs ) Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth: torch.Tensor.to , but only accepts floating point or complex :attr: dtype \ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr: dtype (if given). The integral parameters and buffers will be moved :attr: device , if that is given, but with dtypes unchanged. When :attr: non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device ( None class: torch.device ): the desired device of the parameters and buffers in this module None dtype ( None class: torch.dtype ): the desired floating point or complex dtype of the parameters and buffers in this module None tensor torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module None memory_format ( None class: torch.memory_format ): the desired memory format for 4D parameters and buffers in this module (keyword only argument) None Returns: Type Description Module self View Source def to ( self , * args , ** kwargs ) : r \" \"\" Move and/or cast the parameters and buffers. This can be called as ..
function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype` \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self Examples:: >>> # xdoctest: +IGNORE_WANT(\" non - deterministic \") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device(\" cuda : 1 \") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device(\" cpu \") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) \"\" \" device , dtype , non_blocking , convert_to_format = torch . _C . _nn . _parse_to ( * args , ** kwargs ) if dtype is not None : if not ( dtype . is_floating_point or dtype . is_complex ) : raise TypeError ( 'nn.Module.to only accepts floating point or complex ' f 'dtypes, but got desired dtype={dtype}' ) if dtype . is_complex : warnings . warn ( \"Complex modules are a new feature under active development whose design may change, \" \"and some modules might not work as expected when using complex tensors as parameters or buffers. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"if a complex module does not work as expected.\" ) def convert ( t ) : try : if convert_to_format is not None and t . dim () in ( 4 , 5 ) : return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , memory_format = convert_to_format , ) return t . 
to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , ) except NotImplementedError as e : if str ( e ) == \"Cannot copy out of meta tensor; no data!\" : raise NotImplementedError ( f \"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() \" f \"when moving module from meta to a different device.\" ) from None else : raise return self . _apply ( convert )","title":"to"},{"location":"reference/wtracker/neural/mlp/#to_empty_2","text":"def to_empty ( self : ~ T , * , device : Union [ int , str , torch . device , NoneType ], recurse : bool = True ) -> ~ T Move the parameters and buffers to the specified device without copying storage. Parameters: Name Type Description Default device ( None class: torch.device ): The desired device of the parameters and buffers in this module. None recurse bool Whether parameters and buffers of submodules should be recursively moved to the specified device. None Returns: Type Description Module self View Source def to_empty ( self : T , * , device : Optional [ DeviceLikeType ] , recurse : bool = True ) -> T : r \"\"\"Move the parameters and buffers to the specified device without copying storage. Args: device (:class:`torch.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. Returns: Module: self \"\"\" return self . _apply ( lambda t : torch . empty_like ( t , device = device ), recurse = recurse )","title":"to_empty"},{"location":"reference/wtracker/neural/mlp/#train_2","text":"def train ( self : ~ T , mode : bool = True ) -> ~ T Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. Parameters: Name Type Description Default mode bool whether to set training mode ( True ) or evaluation mode ( False ). Default: True . None Returns: Type Description Module self View Source def train ( self : T , mode : bool = True ) -> T : r \" \"\" Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self \"\" \" if not isinstance ( mode , bool ) : raise ValueError ( \"training mode is expected to be boolean\" ) self . training = mode for module in self . children () : module . train ( mode ) return self","title":"train"},{"location":"reference/wtracker/neural/mlp/#type_2","text":"def type ( self : ~ T , dst_type : Union [ torch . dtype , str ] ) -> ~ T Casts all parameters and buffers to :attr: dst_type . .. note:: This method modifies the module in-place. Parameters: Name Type Description Default dst_type type or string the desired type None Returns: Type Description Module self View Source def type ( self : T , dst_type : Union [ dtype , str ] ) -> T : r \" \"\" Casts all parameters and buffers to :attr:`dst_type`. .. note:: This method modifies the module in-place. Args: dst_type (type or string): the desired type Returns: Module: self \"\" \" return self . _apply ( lambda t : t . 
type ( dst_type ))","title":"type"},{"location":"reference/wtracker/neural/mlp/#xpu_2","text":"def xpu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def xpu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . xpu ( device ))","title":"xpu"},{"location":"reference/wtracker/neural/mlp/#zero_grad_2","text":"def zero_grad ( self , set_to_none : bool = True ) -> None Reset gradients of all model parameters. See similar function under :class: torch.optim.Optimizer for more context. Parameters: Name Type Description Default set_to_none bool instead of setting to zero, set the grads to None. See :meth: torch.optim.Optimizer.zero_grad for details. None View Source def zero_grad ( self , set_to_none : bool = True ) -> None : r \"\"\"Reset gradients of all model parameters. See similar function under :class:`torch.optim.Optimizer` for more context. Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. \"\"\" if getattr ( self , '_is_replica' , False ) : warnings . warn ( \"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. \" \"The parameters are copied (in a differentiable manner) from the original module. \" \"This means they are not leaf nodes in autograd and so don't accumulate gradients. \" \"If you need gradients in your forward method, consider using autograd.grad instead.\" ) for p in self.parameters(): if p.grad is not None: if set_to_none: p.grad = None else: if p.grad.grad_fn is not None: p.grad.detach_() else: p.grad.requires_grad_(False) p.grad.zero_()","title":"zero_grad"},{"location":"reference/wtracker/neural/mlp/#wormpredictor","text":"class WormPredictor ( model : torch . nn . modules . module . Module , io_config : wtracker . neural . config . IOConfig ) A class that represents neural network models that predict worm behavior. After a model is created from several layers or blocks, it is wrapped in this class so that it can be distinguished from other models that don't predict worm behavior (for example the layers/blocks that make this model). This class also holds the IOConfig object that is used to determine the input and output shapes of the model, and the specific frames it expects as input and output.","title":"WormPredictor"},{"location":"reference/wtracker/neural/mlp/#attributes_1","text":"Name Type Description Default model None The neural network model that predicts worm behavior. None io_config None The IOConfig object of the model. None View Source class WormPredictor ( nn .
Module ): \"\"\" A class that represents neural network models that predict worm behavior. After a model is created from several layers or blocks, it is wrapped in this class so that it can be distinguished from other models that don't predict worm behavior (for example the layers/blocks that make this model). This class also holds the IOConfig object that is used to determine the input and output shapes of the model, and the specific frames it expects as input and output. Attributes: model: The neural network model that predicts worm behavior. io_config: The IOConfig object of the model. \"\"\" def __init__ ( self , model: nn . Module , io_config: IOConfig ): super (). __init__ () self . io_config: IOConfig = io_config self . model: nn . Module = model def forward ( self , x : Tensor ) -> Tensor: return self . model ( x )","title":"Attributes"},{"location":"reference/wtracker/neural/mlp/#ancestors-in-mro_3","text":"torch.nn.modules.module.Module","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/mlp/#class-variables_3","text":"T_destination call_super_init dump_patches","title":"Class variables"},{"location":"reference/wtracker/neural/mlp/#methods_3","text":"","title":"Methods"},{"location":"reference/wtracker/neural/mlp/#add_module_3","text":"def add_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Add a child module to the current module. The module can be accessed as an attribute using the given name. Parameters: Name Type Description Default name str name of the child module. The child module can be accessed from this module using the given name None module Module child module to be added to the module. None View Source def add_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \"\"\"Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (str): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module. \"\"\" if not isinstance ( module , Module ) and module is not None : raise TypeError ( f \"{torch.typename(module)} is not a Module subclass\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"module name should be a string. Got {torch.typename(name)}\" ) elif hasattr ( self , name ) and name not in self . _modules : raise KeyError ( f \"attribute '{name}' already exists\" ) elif '.' in name : raise KeyError ( f \"module name can't contain \\\" . \\ \", got: {name}\" ) elif name == '' : raise KeyError ( \"module name can't be empty string \\\" \\ \"\" ) for hook in _global_module_registration_hooks . values () : output = hook ( self , name , module ) if output is not None : module = output self . _modules [ name ] = module","title":"add_module"},{"location":"reference/wtracker/neural/mlp/#apply_3","text":"def apply ( self : ~ T , fn : Callable [[ ForwardRef ( 'Module' )], NoneType ] ) -> ~ T Apply fn recursively to every submodule (as returned by .children() ) as well as self. Typical use includes initializing the parameters of a model (see also :ref: nn-init-doc ). Parameters: Name Type Description Default fn ( None class: Module -> None): function to be applied to each submodule None Returns: Type Description Module self View Source def apply ( self : T , fn : Callable [[ 'Module' ] , None ] ) -> T : r \" \"\" Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. 
Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example:: >>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) \"\" \" for module in self . children () : module . apply ( fn ) fn ( self ) return self","title":"apply"},{"location":"reference/wtracker/neural/mlp/#bfloat16_3","text":"def bfloat16 ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to bfloat16 datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def bfloat16 ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``bfloat16`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . bfloat16 () if t . is_floating_point () else t )","title":"bfloat16"},{"location":"reference/wtracker/neural/mlp/#buffers_3","text":"def buffers ( self , recurse : bool = True ) -> Iterator [ torch . Tensor ] Return an iterator over module buffers. Parameters: Name Type Description Default recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. None Yields: Type Description torch.Tensor module buffer View Source def buffers ( self , recurse : bool = True ) -> Iterator [ Tensor ] : r \"\"\"Return an iterator over module buffers. Args: recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Yields: torch.Tensor: module buffer Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for buf in model.buffers(): >>> print(type(buf), buf.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for _ , buf in self . named_buffers ( recurse = recurse ) : yield buf","title":"buffers"},{"location":"reference/wtracker/neural/mlp/#children_3","text":"def children ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over immediate children modules. Yields: Type Description Module a child module View Source def children ( self ) -> Iterator [ 'Module']: r \"\"\"Return an iterator over immediate children modules. Yields: Module: a child module \"\"\" for name , module in self . named_children (): yield module","title":"children"},{"location":"reference/wtracker/neural/mlp/#compile_3","text":"def compile ( self , * args , ** kwargs ) Compile this Module's forward using :func: torch.compile . This Module's __call__ method is compiled and all arguments are passed as-is to :func: torch.compile . See :func: torch.compile for details on the arguments for this function. View Source def compile ( self , * args , ** kwargs ) : \" \"\" Compile this Module's forward using :func:`torch.compile`. This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`. 
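A minimal usage sketch of Module.compile (an added example, assuming a PyTorch build where torch.compile is available):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
model.compile()  # subsequent forward calls run through torch.compile
out = model(torch.randn(4, 8))
```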
See :func:`torch.compile` for details on the arguments for this function. \"\" \" self . _compiled_call_impl = torch . compile ( self . _call_impl , * args , ** kwargs )","title":"compile"},{"location":"reference/wtracker/neural/mlp/#cpu_3","text":"def cpu ( self : ~ T ) -> ~ T Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def cpu ( self : T ) -> T : r \"\"\"Move all model parameters and buffers to the CPU. .. note:: This method modifies the module in-place. Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cpu ())","title":"cpu"},{"location":"reference/wtracker/neural/mlp/#cuda_3","text":"def cuda ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def cuda ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. .. note:: This method modifies the module in-place. Args: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . cuda ( device ))","title":"cuda"},{"location":"reference/wtracker/neural/mlp/#double_3","text":"def double ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to double datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def double ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``double`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . double () if t . is_floating_point () else t )","title":"double"},{"location":"reference/wtracker/neural/mlp/#eval_3","text":"def eval ( self : ~ T ) -> ~ T Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. This is equivalent with :meth: self.train(False) . See :ref: locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it. Returns: Type Description Module self View Source def eval ( self : T ) -> T : r \" \"\" Set the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent with :meth:`self.train(False) `. See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it. Returns: Module: self \"\" \" return self . 
train ( False )","title":"eval"},{"location":"reference/wtracker/neural/mlp/#extra_repr_3","text":"def extra_repr ( self ) -> str Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. View Source def extra_repr ( self ) -> str : r \"\"\"Set the extra representation of the module. To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable. \"\"\" return ''","title":"extra_repr"},{"location":"reference/wtracker/neural/mlp/#float_3","text":"def float ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to float datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def float ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``float`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . float () if t . is_floating_point () else t )","title":"float"},{"location":"reference/wtracker/neural/mlp/#forward_3","text":"def forward ( self , x : torch . Tensor ) -> torch . Tensor Define the computation performed at every call. Should be overridden by all subclasses. .. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class: Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them. View Source def forward ( self , x : Tensor ) -> Tensor : return self . model ( x )","title":"forward"},{"location":"reference/wtracker/neural/mlp/#get_buffer_3","text":"def get_buffer ( self , target : str ) -> 'Tensor' Return the buffer given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.Tensor The buffer referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not a buffer View Source def get_buffer ( self , target : str ) -> \"Tensor\" : \" \"\" Return the buffer given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the buffer to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.Tensor: The buffer referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not a buffer \"\" \" module_path , _ , buffer_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , buffer_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + buffer_name + \"`\" ) buffer : torch . Tensor = getattr ( mod , buffer_name ) if buffer_name not in mod . 
_buffers : raise AttributeError ( \"`\" + buffer_name + \"` is not a buffer\" ) return buffer","title":"get_buffer"},{"location":"reference/wtracker/neural/mlp/#get_extra_state_3","text":"def get_extra_state ( self ) -> Any Return any extra state to include in the module's state_dict. Implement this and a corresponding :func: set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict() . Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: Type Description object Any extra state to store in the module's state_dict View Source def get_extra_state ( self ) -> Any : \" \"\" Return any extra state to include in the module's state_dict. Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`. Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes. Returns: object: Any extra state to store in the module's state_dict \"\" \" raise RuntimeError ( \"Reached a code path in Module.get_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"get_extra_state"},{"location":"reference/wtracker/neural/mlp/#get_parameter_3","text":"def get_parameter ( self , target : str ) -> 'Parameter' Return the parameter given by target if it exists, otherwise throw an error. See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target . Parameters: Name Type Description Default target None The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Parameter The Parameter referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Parameter View Source def get_parameter ( self , target : str ) -> \"Parameter\" : \" \"\" Return the parameter given by ``target`` if it exists, otherwise throw an error. See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``. Args: target: The fully-qualified string name of the Parameter to look for. (See ``get_submodule`` for how to specify a fully-qualified string.) Returns: torch.nn.Parameter: The Parameter referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Parameter`` \"\" \" module_path , _ , param_name = target . rpartition ( \".\" ) mod : torch . nn . Module = self . get_submodule ( module_path ) if not hasattr ( mod , param_name ) : raise AttributeError ( mod . _get_name () + \" has no attribute `\" + param_name + \"`\" ) param : torch . nn . Parameter = getattr ( mod , param_name ) if not isinstance ( param , torch . nn .
Parameter ) : raise AttributeError ( \"`\" + param_name + \"` is not an \" \"nn.Parameter\" ) return param","title":"get_parameter"},{"location":"reference/wtracker/neural/mlp/#get_submodule_3","text":"def get_submodule ( self , target : str ) -> 'Module' Return the submodule given by target if it exists, otherwise throw an error. For example, let's say you have an nn.Module A that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an nn.Module A . A has a nested submodule net_b , which itself has two submodules net_c and linear . net_c then has a submodule conv .) To check whether or not we have the linear submodule, we would call get_submodule(\"net_b.linear\") . To check whether we have the conv submodule, we would call get_submodule(\"net_b.net_c.conv\") . The runtime of get_submodule is bounded by the degree of module nesting in target . A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used. Parameters: Name Type Description Default target None The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) None Returns: Type Description torch.nn.Module The submodule referenced by target Raises: Type Description AttributeError If the target string references an invalid path or resolves to something that is not an nn.Module View Source def get_submodule ( self , target : str ) -> \"Module\" : \" \"\" Return the submodule given by ``target`` if it exists, otherwise throw an error. For example, let's say you have an ``nn.Module`` ``A`` that looks like this: .. code-block:: text A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) ) (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.) To check whether or not we have the ``linear`` submodule, we would call ``get_submodule(\" net_b . linear \")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule(\" net_b . net_c . conv \")``. The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used. Args: target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.) Returns: torch.nn.Module: The submodule referenced by ``target`` Raises: AttributeError: If the target string references an invalid path or resolves to something that is not an ``nn.Module`` \"\" \" if target == \"\" : return self atoms : List [ str ] = target . split ( \".\" ) mod : torch . nn . Module = self for item in atoms : if not hasattr ( mod , item ) : raise AttributeError ( mod . _get_name () + \" has no \" \"attribute `\" + item + \"`\" ) mod = getattr ( mod , item ) if not isinstance ( mod , torch . nn . 
Module ) : raise AttributeError ( \"`\" + item + \"` is not \" \"an nn.Module\" ) return mod","title":"get_submodule"},{"location":"reference/wtracker/neural/mlp/#half_3","text":"def half ( self : ~ T ) -> ~ T Casts all floating point parameters and buffers to half datatype. .. note:: This method modifies the module in-place. Returns: Type Description Module self View Source def half ( self : T ) -> T : r \" \"\" Casts all floating point parameters and buffers to ``half`` datatype. .. note:: This method modifies the module in-place. Returns: Module: self \"\" \" return self . _apply ( lambda t : t . half () if t . is_floating_point () else t )","title":"half"},{"location":"reference/wtracker/neural/mlp/#ipu_3","text":"def ipu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def ipu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . ipu ( device ))","title":"ipu"},{"location":"reference/wtracker/neural/mlp/#load_state_dict_3","text":"def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) Copy parameters and buffers from :attr: state_dict into this module and its descendants. If :attr: strict is True , then the keys of :attr: state_dict must exactly match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. .. warning:: If :attr: assign is True the optimizer must be created after the call to :attr: load_state_dict unless :func: ~torch.__future__.get_swap_module_params_on_conversion is True . Parameters: Name Type Description Default state_dict dict a dict containing parameters and persistent buffers. None strict bool whether to strictly enforce that the keys in :attr: state_dict match the keys returned by this module's :meth: ~torch.nn.Module.state_dict function. Default: True None assign bool When False , the properties of the tensors in the current module are preserved while when True , the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of :class: ~torch.nn.Parameter s for which the value from the module is preserved. Default: False None Returns: Type Description None NamedTuple with missing_keys and unexpected_keys fields: missing_keys is a list of str containing the missing keys unexpected_keys is a list of str containing the unexpected keys View Source def load_state_dict ( self , state_dict : Mapping [ str , Any ], strict : bool = True , assign : bool = False ) : r \"\"\"Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. 
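An added sketch of the strict and non-strict loading behavior documented here (standard PyTorch APIs only; the layers are arbitrary):

```python
import torch
import torch.nn as nn

src, dst = nn.Linear(4, 2), nn.Linear(4, 2)

# Strict loading: the key sets must match exactly.
result = dst.load_state_dict(src.state_dict(), strict=True)
print(result.missing_keys, result.unexpected_keys)  # [] []

# Non-strict loading reports mismatches instead of raising.
result = dst.load_state_dict({"weight": torch.zeros(2, 4)}, strict=False)
print(result.missing_keys)  # ['bias']
```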
If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. \"\"\" if not isinstance ( state_dict , Mapping ) : raise TypeError ( f \"Expected state_dict to be dict-like, got {type(state_dict)}.\" ) missing_keys : List [ str ] = [] unexpected_keys : List [ str ] = [] error_msgs : List [ str ] = [] # copy state_dict so _ load_from_state_dict can modify it metadata = getattr ( state_dict , '_metadata' , None ) state_dict = OrderedDict ( state_dict ) if metadata is not None : # mypy isn't aware that \"_metadata\" exists in state_dict state_dict._metadata = metadata # type: ignore[attr-defined] def load(module, local_state_dict, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) if assign: local_metadata['assign_to_params_buffers'] = assign module._load_from_state_dict( local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name + ' . ' child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} load(child, child_state_dict, child_prefix) # noqa: F821 # Note that the hook can modify missing_keys and unexpected_keys. incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) for hook in module._load_state_dict_post_hooks.values(): out = hook(module, incompatible_keys) assert out is None, ( \"Hooks registered with ``register_load_state_dict_post_hook`` are not\" \"expected to return new values, if incompatible_keys need to be modified,\" \"it should be done inplace.\" ) load(self, state_dict) del load if strict: if len(unexpected_keys) > 0: error_msgs.insert( 0, ' Unexpected key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in unexpected_keys))) if len(missing_keys) > 0: error_msgs.insert( 0, ' Missing key ( s ) in state_dict : {}. '.format( ' , '.join(f' \"{k}\" ' for k in missing_keys))) if len(error_msgs) > 0: raise RuntimeError(' Error ( s ) in loading state_dict for {} : \\n\\t {} ' . format ( self . __ class__ . __ name__ , \"\\n\\t\" . 
join ( error_msgs ))) return _ IncompatibleKeys ( missing_keys , unexpected_keys )","title":"load_state_dict"},{"location":"reference/wtracker/neural/mlp/#modules_3","text":"def modules ( self ) -> Iterator [ ForwardRef ( 'Module' )] Return an iterator over all modules in the network. Yields: Type Description Module a module in the network View Source def modules ( self ) -> Iterator [ 'Module' ] : r \" \"\" Return an iterator over all modules in the network. Yields: Module: a module in the network Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True) \"\" \" for _ , module in self . named_modules () : yield module","title":"modules"},{"location":"reference/wtracker/neural/mlp/#named_buffers_3","text":"def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . Tensor ]] Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Parameters: Name Type Description Default prefix str prefix to prepend to all buffer names. None recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. None remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True Yields: Type Description None (str, torch.Tensor): Tuple containing the name and buffer View Source def named_buffers ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Tensor ]]: r \"\"\"Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Args: prefix (str): prefix to prepend to all buffer names. recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. Yields: (str, torch.Tensor): Tuple containing the name and buffer Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size()) \"\"\" gen = self . _named_members ( lambda module : module . _buffers . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_buffers"},{"location":"reference/wtracker/neural/mlp/#named_children_3","text":"def named_children ( self ) -> Iterator [ Tuple [ str , ForwardRef ( 'Module' )]] Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: Type Description None (str, Module): Tuple containing a name and child module View Source def named_children ( self ) -> Iterator [ Tuple [ str , 'Module' ]]: r \"\"\"Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. 
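As a minimal sketch of how the `load_state_dict` API documented above behaves (the model layout and key names here are arbitrary, not part of wtracker):

```python
import torch
import torch.nn as nn

# A small model whose parameter names ("0.weight", "0.bias", ...) come from nn.Sequential.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# state_dict() returns an OrderedDict of parameter/buffer tensors keyed by name.
state = model.state_dict()

# Loading the exact same keys with strict=True succeeds and reports no mismatches.
result = model.load_state_dict(state, strict=True)
assert result.missing_keys == [] and result.unexpected_keys == []

# With strict=False, absent or extra keys are reported instead of raising.
state.pop("0.bias")                        # simulate a missing key
state["extra.weight"] = torch.zeros(1)     # simulate an unexpected key
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)      # ['0.bias']
print(result.unexpected_keys)   # ['extra.weight']
```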
Yields: (str, Module): Tuple containing a name and child module Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module) \"\"\" memo = set () for name , module in self . _modules . items (): if module is not None and module not in memo : memo . add ( module ) yield name , module","title":"named_children"},{"location":"reference/wtracker/neural/mlp/#named_modules_3","text":"def named_modules ( self , memo : Optional [ Set [ ForwardRef ( 'Module' )]] = None , prefix : str = '' , remove_duplicate : bool = True ) Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Parameters: Name Type Description Default memo None a memo to store the set of modules already added to the result None prefix None a prefix that will be added to the name of the module None remove_duplicate None whether to remove the duplicated module instances in the result or not None Yields: Type Description None (str, Module): Tuple of name and module View Source def named_modules ( self , memo : Optional [ Set [ 'Module' ]] = None , prefix : str = '' , remove_duplicate : bool = True ) : r \" \"\" Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: memo: a memo to store the set of modules already added to the result prefix: a prefix that will be added to the name of the module remove_duplicate: whether to remove the duplicated module instances in the result or not Yields: (str, Module): Tuple of name and module Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. Example:: >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) \"\" \" if memo is None : memo = set () if self not in memo : if remove_duplicate : memo . add ( self ) yield prefix , self for name , module in self . _modules . items () : if module is None : continue submodule_prefix = prefix + ( '.' if prefix else '' ) + name yield from module . named_modules ( memo , submodule_prefix , remove_duplicate )","title":"named_modules"},{"location":"reference/wtracker/neural/mlp/#named_parameters_3","text":"def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , torch . nn . parameter . Parameter ]] Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. Parameters: Name Type Description Default prefix str prefix to prepend to all parameter names. None recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. None Yields: Type Description None (str, Parameter): Tuple containing the name and parameter View Source def named_parameters ( self , prefix : str = '' , recurse : bool = True , remove_duplicate : bool = True ) -> Iterator [ Tuple [ str , Parameter ]]: r \"\"\"Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. 
Args: prefix (str): prefix to prepend to all parameter names. recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True. Yields: (str, Parameter): Tuple containing the name and parameter Example:: >>> # xdoctest: +SKIP(\"undefined vars\") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) \"\"\" gen = self . _named_members ( lambda module : module . _parameters . items (), prefix = prefix , recurse = recurse , remove_duplicate = remove_duplicate ) yield from gen","title":"named_parameters"},{"location":"reference/wtracker/neural/mlp/#parameters_3","text":"def parameters ( self , recurse : bool = True ) -> Iterator [ torch . nn . parameter . Parameter ] Return an iterator over module parameters. This is typically passed to an optimizer. Parameters: Name Type Description Default recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. None Yields: Type Description Parameter module parameter View Source def parameters ( self , recurse : bool = True ) -> Iterator [ Parameter ] : r \"\"\"Return an iterator over module parameters. This is typically passed to an optimizer. Args: recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Yields: Parameter: module parameter Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> for param in model.parameters(): >>> print(type(param), param.size()) (20L,) (20L, 1L, 5L, 5L) \"\"\" for name , param in self . named_parameters ( recurse = recurse ) : yield param","title":"parameters"},{"location":"reference/wtracker/neural/mlp/#register_backward_hook_3","text":"def register_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]] ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. This function is deprecated in favor of :meth: ~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_backward_hook ( self , hook: Callable [[' Module ', _grad_t , _grad_t ], Union [ None , _grad_t ]] ) -> RemovableHandle: r \"\"\"Register a backward hook on the module. This function is deprecated in favor of : meth: ` ~ torch . nn . Module . register_full_backward_hook ` and the behavior of this function will change in future versions . Returns: : class: `torch . utils . hooks . RemovableHandle ` : a handle that can be used to remove the added hook by calling ` `handle . remove () `` \"\"\" if self . _is_full_backward_hook is True: raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = False handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . 
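A short illustration of the `named_parameters`/`parameters` iterators documented above; freezing all but the last layer is a common use, though the layer choice here is just an example:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# named_parameters() yields (name, Parameter) pairs, recursing into submodules.
for name, param in model.named_parameters():
    print(name, tuple(param.shape))
# 0.weight (8, 4), 0.bias (8,), 2.weight (2, 8), 2.bias (2,)

# Freeze everything except the final layer before fine-tuning.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("2.")

# parameters() (without names) is what is typically handed to an optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
```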
id ] = hook return handle","title":"register_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_buffer_3","text":"def register_buffer ( self , name : str , tensor : Optional [ torch . Tensor ], persistent : bool = True ) -> None Add a buffer to the module. This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr: persistent to False . The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr: state_dict . Buffers can be accessed as attributes using given names. Parameters: Name Type Description Default name str name of the buffer. The buffer can be accessed from this module using the given name None tensor Tensor or None buffer to be registered. If None , then operations that run on buffers, such as :attr: cuda , are ignored. If None , the buffer is not included in the module's :attr: state_dict . None persistent bool whether the buffer is part of this module's :attr: state_dict . None View Source def register_buffer ( self , name : str , tensor : Optional [ Tensor ] , persistent : bool = True ) -> None : r \" \"\" Add a buffer to the module. This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`. Buffers can be accessed as attributes using given names. Args: name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, the buffer is **not** included in the module's :attr:`state_dict`. persistent (bool): whether the buffer is part of this module's :attr:`state_dict`. Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> self.register_buffer('running_mean', torch.zeros(num_features)) \"\" \" if persistent is False and isinstance ( self , torch . jit . ScriptModule ) : raise RuntimeError ( \"ScriptModule does not support non-persistent buffers\" ) if '_buffers' not in self . __dict__ : raise AttributeError ( \"cannot assign buffer before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"buffer name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"buffer name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"buffer name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _buffers : raise KeyError ( f \"attribute '{name}' already exists\" ) elif tensor is not None and not isinstance ( tensor , torch . Tensor ) : raise TypeError ( f \"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' \" \"(torch Tensor or None required)\" ) else : for hook in _global_buffer_registration_hooks . values () : output = hook ( self , name , tensor ) if output is not None : tensor = output self . 
_buffers [ name ] = tensor if persistent : self . _non_persistent_buffers_set . discard ( name ) else : self . _non_persistent_buffers_set . add ( name )","title":"register_buffer"},{"location":"reference/wtracker/neural/mlp/#register_forward_hook_3","text":"def register_forward_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ], Any ], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ], Any ], Optional [ Any ]]], * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward hook on the module. The hook will be called every time after :func: forward has computed an output. If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func: forward is called. The hook should have the following signature:: hook ( module , args , output ) -> None or modified output If with_kwargs is True , the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:: hook ( module , args , kwargs , output ) -> None or modified output Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If True , the provided hook will be fired before all existing forward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward hooks on this :class: torch.nn.modules.Module . Note that global forward hooks registered with :func: register_module_forward_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If True , the hook will be passed the kwargs given to the forward function. Default: False None always_call bool If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ] , Any ] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ] , Any ] , Optional [ Any ]] , ] , * , prepend : bool = False , with_kwargs : bool = False , always_call : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature:: hook(module, args, output) -> None or modified output If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. 
The hook should have the following signature:: hook(module, args, kwargs, output) -> None or modified output Args: hook (Callable): The user defined hook to be registered. prepend (bool): If ``True``, the provided ``hook`` will be fired before all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward`` hooks registered with :func:`register_module_forward_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` always_call (bool): If ``True`` the ``hook`` will be run regardless of whether an exception is raised while calling the Module. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_hooks , extra_dict = [ self . _forward_hooks_with_kwargs , self . _forward_hooks_always_called ] , ) self . _forward_hooks [ handle . id ] = hook if with_kwargs : self . _forward_hooks_with_kwargs [ handle . id ] = True if always_call : self . _forward_hooks_always_called [ handle . id ] = True if prepend : self . _forward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_hook"},{"location":"reference/wtracker/neural/mlp/#register_forward_pre_hook_3","text":"def register_forward_pre_hook ( self , hook : Union [ Callable [[ ~ T , Tuple [ Any , ... ]], Optional [ Any ]], Callable [[ ~ T , Tuple [ Any , ... ], Dict [ str , Any ]], Optional [ Tuple [ Any , Dict [ str , Any ]]]]], * , prepend : bool = False , with_kwargs : bool = False ) -> torch . utils . hooks . RemovableHandle Register a forward pre-hook on the module. The hook will be called every time before :func: forward is invoked. If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward . The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook ( module , args ) -> None or modified input If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook ( module , args , kwargs ) -> None or a tuple of modified input and kwargs Parameters: Name Type Description Default hook Callable The user defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing forward_pre hooks on this :class: torch.nn.modules.Module . Note that global forward_pre hooks registered with :func: register_module_forward_pre_hook will fire before all hooks registered by this method. Default: False None with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. 
Default: False None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_forward_pre_hook ( self , hook : Union [ Callable [[ T , Tuple [ Any , ... ]] , Optional [ Any ]] , Callable [[ T , Tuple [ Any , ... ] , Dict [ str , Any ]] , Optional [ Tuple [ Any , Dict [ str , Any ]]]] , ] , * , prepend : bool = False , with_kwargs : bool = False , ) -> RemovableHandle : r \" \"\" Register a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:: hook(module, args) -> None or modified input If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:: hook(module, args, kwargs) -> None or a tuple of modified input and kwargs Args: hook (Callable): The user defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``forward_pre`` hooks registered with :func:`register_module_forward_pre_hook` will fire before all hooks registered by this method. Default: ``False`` with_kwargs (bool): If true, the ``hook`` will be passed the kwargs given to the forward function. Default: ``False`` Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _forward_pre_hooks , extra_dict = self . _forward_pre_hooks_with_kwargs ) self . _forward_pre_hooks [ handle . id ] = hook if with_kwargs : self . _forward_pre_hooks_with_kwargs [ handle . id ] = True if prepend : self . _forward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_forward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_hook_3","text":"def register_full_backward_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ], Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook ( module , grad_input , grad_output ) -> tuple ( Tensor ) or None The :attr: grad_input and :attr: grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. 
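A sketch of the forward-hook machinery described above, using the `hook(module, args, output)` and `hook(module, args)` signatures from the reference (the model and hook bodies are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
activations = {}

def save_activation(module, args, output):
    # Forward hooks run after forward(); returning None leaves the output unchanged.
    activations["relu"] = output.detach()

def log_input(module, args):
    # Forward pre-hooks run before forward(); returning None leaves the input unchanged.
    print("input shape:", args[0].shape)

h1 = model[1].register_forward_hook(save_activation)
h2 = model[0].register_forward_pre_hook(log_input)

model(torch.randn(3, 4))    # triggers both hooks
h1.remove()                 # hooks stay active until their handles are removed
h2.remove()
```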
The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr: grad_input in subsequent computations. :attr: grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr: grad_input and :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward hooks on this :class: torch.nn.modules.Module . Note that global backward hooks registered with :func: register_module_full_backward_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_hook ( self , hook : Callable [[ \"Module\" , _grad_t , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward hook on the module. The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward`` hooks registered with :func:`register_module_full_backward_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" if self . 
_is_full_backward_hook is False : raise RuntimeError ( \"Cannot use both regular backward hooks and full backward hooks on a \" \"single Module. Please use only one of them.\" ) self . _is_full_backward_hook = True handle = hooks . RemovableHandle ( self . _backward_hooks ) self . _backward_hooks [ handle . id ] = hook if prepend : self . _backward_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_hook"},{"location":"reference/wtracker/neural/mlp/#register_full_backward_pre_hook_3","text":"def register_full_backward_pre_hook ( self , hook : Callable [[ ForwardRef ( 'Module' ), Union [ Tuple [ torch . Tensor , ... ], torch . Tensor ]], Union [ NoneType , Tuple [ torch . Tensor , ... ], torch . Tensor ]], prepend : bool = False ) -> torch . utils . hooks . RemovableHandle Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook ( module , grad_output ) -> tuple [ Tensor ] or None The :attr: grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr: grad_output in subsequent computations. Entries in :attr: grad_output will be None for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Parameters: Name Type Description Default hook Callable The user-defined hook to be registered. None prepend bool If true, the provided hook will be fired before all existing backward_pre hooks on this :class: torch.nn.modules.Module . Otherwise, the provided hook will be fired after all existing backward_pre hooks on this :class: torch.nn.modules.Module . Note that global backward_pre hooks registered with :func: register_module_full_backward_pre_hook will fire before all hooks registered by this method. None Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_full_backward_pre_hook ( self , hook : Callable [[ \"Module\" , _grad_t ] , Union [ None , _grad_t ]] , prepend : bool = False , ) -> RemovableHandle : r \" \"\" Register a backward pre-hook on the module. The hook will be called every time the gradients for the module are computed. The hook should have the following signature:: hook(module, grad_output) -> tuple[Tensor] or None The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments. For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function. .. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error. Args: hook (Callable): The user-defined hook to be registered. 
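For the full backward hook just documented, a minimal sketch that inspects gradient magnitudes as they are computed (the layer and the printed statistic are arbitrary choices):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)

def report_grad(module, grad_input, grad_output):
    # grad_input/grad_output are tuples; entries are None for non-Tensor arguments.
    print("grad_output norm:", grad_output[0].norm().item())
    return None  # returning None keeps the gradients unchanged

handle = layer.register_full_backward_hook(report_grad)

out = layer(torch.randn(2, 4, requires_grad=True))
out.sum().backward()    # the hook fires when gradients for `layer` are computed
handle.remove()
```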
prepend (bool): If true, the provided ``hook`` will be fired before all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Otherwise, the provided ``hook`` will be fired after all existing ``backward_pre`` hooks on this :class:`torch.nn.modules.Module`. Note that global ``backward_pre`` hooks registered with :func:`register_module_full_backward_pre_hook` will fire before all hooks registered by this method. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _backward_pre_hooks ) self . _backward_pre_hooks [ handle . id ] = hook if prepend : self . _backward_pre_hooks . move_to_end ( handle . id , last = False ) # type: ignore[attr-defined] return handle","title":"register_full_backward_pre_hook"},{"location":"reference/wtracker/neural/mlp/#register_load_state_dict_post_hook_3","text":"def register_load_state_dict_post_hook ( self , hook ) Register a post hook to be run after module's load_state_dict is called. It should have the following signature:: hook(module, incompatible_keys) -> None The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys . missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func: load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys , as expected. Additions to either set of keys will result in an error being thrown when strict=True , and clearing out both missing and unexpected keys will avoid an error. Returns: Type Description None :class: torch.utils.hooks.RemovableHandle : a handle that can be used to remove the added hook by calling handle.remove() View Source def register_load_state_dict_post_hook ( self , hook ) : r \" \"\" Register a post hook to be run after module's ``load_state_dict`` is called. It should have the following signature:: hook(module, incompatible_keys) -> None The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. The given incompatible_keys can be modified inplace if needed. Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` \"\" \" handle = hooks . RemovableHandle ( self . _load_state_dict_post_hooks ) self . _load_state_dict_post_hooks [ handle . 
id ] = hook return handle","title":"register_load_state_dict_post_hook"},{"location":"reference/wtracker/neural/mlp/#register_module_3","text":"def register_module ( self , name : str , module : Optional [ ForwardRef ( 'Module' )] ) -> None Alias for :func: add_module . View Source def register_module ( self , name : str , module : Optional [ 'Module' ] ) -> None : r \" \"\" Alias for :func:`add_module`. \"\" \" self . add_module ( name , module )","title":"register_module"},{"location":"reference/wtracker/neural/mlp/#register_parameter_3","text":"def register_parameter ( self , name : str , param : Optional [ torch . nn . parameter . Parameter ] ) -> None Add a parameter to the module. The parameter can be accessed as an attribute using given name. Parameters: Name Type Description Default name str name of the parameter. The parameter can be accessed from this module using the given name None param Parameter or None parameter to be added to the module. If None , then operations that run on parameters, such as :attr: cuda , are ignored. If None , the parameter is not included in the module's :attr: state_dict . None View Source def register_parameter ( self , name : str , param : Optional [ Parameter ] ) -> None : r \" \"\" Add a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. \"\" \" if '_parameters' not in self . __dict__ : raise AttributeError ( \"cannot assign parameter before Module.__init__() call\" ) elif not isinstance ( name , str ) : raise TypeError ( f \"parameter name should be a string. Got {torch.typename(name)}\" ) elif '.' in name : raise KeyError ( \"parameter name can't contain \\\" . \\\" \" ) elif name == '' : raise KeyError ( \"parameter name can't be empty string \\\"\\\" \" ) elif hasattr ( self , name ) and name not in self . _parameters : raise KeyError ( f \"attribute '{name}' already exists\" ) if param is None : self . _parameters [ name ] = None elif not isinstance ( param , Parameter ) : raise TypeError ( f \"cannot assign '{torch.typename(param)}' object to parameter '{name}' \" \"(torch.nn.Parameter or None required)\" ) elif param . grad_fn : raise ValueError ( f \"Cannot assign non-leaf Tensor to parameter '{name}'. Model \" f \"parameters must be created explicitly. To express '{name}' \" \"as a function of another Tensor, compute the value in \" \"the forward() method.\" ) else : for hook in _global_parameter_registration_hooks . values () : output = hook ( self , name , param ) if output is not None : param = output self . _parameters [ name ] = param","title":"register_parameter"},{"location":"reference/wtracker/neural/mlp/#register_state_dict_pre_hook_3","text":"def register_state_dict_pre_hook ( self , hook ) Register a pre-hook for the :meth: ~torch.nn.Module.state_dict method. These hooks will be called with arguments: self , prefix , and keep_vars before calling state_dict on self . The registered hooks can be used to perform pre-processing before the state_dict call is made. View Source def register_state_dict_pre_hook ( self , hook ) : r \" \"\" Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. 
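The `register_buffer`/`register_parameter` entries above can be tied together with a small custom module; this mirrors the BatchNorm `running_mean` example from the reference, but the class itself is a made-up illustration:

```python
import torch
import torch.nn as nn

class RunningMean(nn.Module):
    """Tracks a running mean of its inputs (illustrative only)."""

    def __init__(self, num_features: int):
        super().__init__()
        # A learnable tensor must be registered as a Parameter...
        self.register_parameter("scale", nn.Parameter(torch.ones(num_features)))
        # ...while non-learnable state goes in a buffer. Persistent buffers
        # are saved in state_dict(); non-persistent ones are not.
        self.register_buffer("mean", torch.zeros(num_features), persistent=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            self.mean.lerp_(x.mean(dim=0), 0.1)  # update the tracked statistic
        return (x - self.mean) * self.scale

m = RunningMean(4)
print(list(m.state_dict().keys()))  # ['scale', 'mean']
```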
These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made. \"\" \" handle = hooks . RemovableHandle ( self . _state_dict_pre_hooks ) self . _state_dict_pre_hooks [ handle . id ] = hook return handle","title":"register_state_dict_pre_hook"},{"location":"reference/wtracker/neural/mlp/#requires_grad__3","text":"def requires_grad_ ( self : ~ T , requires_grad : bool = True ) -> ~ T Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr: requires_grad attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref: locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it. Parameters: Name Type Description Default requires_grad bool whether autograd should record operations on parameters in this module. Default: True . None Returns: Type Description Module self View Source def requires_grad_ ( self : T , requires_grad : bool = True ) -> T : r \" \"\" Change if autograd should record operations on parameters in this module. This method sets the parameters' :attr:`requires_grad` attributes in-place. This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it. Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. Returns: Module: self \"\" \" for p in self . parameters () : p . requires_grad_ ( requires_grad ) return self","title":"requires_grad_"},{"location":"reference/wtracker/neural/mlp/#set_extra_state_3","text":"def set_extra_state ( self , state : Any ) -> None Set extra state contained in the loaded state_dict . This function is called from :func: load_state_dict to handle any extra state found within the state_dict . Implement this function and a corresponding View Source def set _extra_state ( self , state : Any ) -> None : \" \"\" Set extra state contained in the loaded `state_dict`. This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`. Args: state (dict): Extra state from the `state_dict` \"\" \" raise RuntimeError ( \"Reached a code path in Module.set_extra_state() that should never be called. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"to report this bug.\" )","title":"set_extra_state"},{"location":"reference/wtracker/neural/mlp/#share_memory_3","text":"def share_memory ( self : ~ T ) -> ~ T See :meth: torch.Tensor.share_memory_ . View Source def share_memory ( self : T ) -> T : r \"\"\"See :meth:`torch.Tensor.share_memory_`.\"\"\" return self . _apply ( lambda t : t . share_memory_ ())","title":"share_memory"},{"location":"reference/wtracker/neural/mlp/#state_dict_3","text":"def state_dict ( self , * args , destination = None , prefix = '' , keep_vars = False ) Return a dictionary containing references to the whole state of the module. 
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently state_dict() also accepts positional arguments for destination , prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument destination as it is not designed for end-users. Parameters: Name Type Description Default destination dict If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None . None prefix str a prefix added to parameter and buffer names to compose the keys in state_dict. Default: '' . None keep_vars bool by default the :class: ~torch.Tensor s returned in the state dict are detached from autograd. If it's set to True , detaching will not be performed. Default: False . None Returns: Type Description dict a dictionary containing a whole state of the module View Source def state_dict ( self , * args , destination = None , prefix='' , keep_vars = False ) : r \"\"\"Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. .. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers. .. warning:: Currently ``state_dict()`` also accepts positional arguments for ``destination``, ``prefix`` and ``keep_vars`` in order. However, this is being deprecated and keyword arguments will be enforced in future releases. .. warning:: Please avoid the use of argument ``destination`` as it is not designed for end-users. Args: destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``. prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. Returns: dict: a dictionary containing a whole state of the module Example:: >>> # xdoctest: +SKIP(\" undefined vars \") >>> module.state_dict().keys() ['bias', 'weight'] \"\"\" # TODO : Remove ` args ` and the parsing logic when BC allows . if len ( args ) > 0 : if destination is None : destination = args [ 0 ] if len ( args ) > 1 and prefix == '' : prefix = args [ 1 ] if len ( args ) > 2 and keep_vars is False : keep_vars = args [ 2 ] # DeprecationWarning is ignored by default warnings . warn ( \"Positional args are being deprecated, use kwargs instead. Refer to \" \"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict\" \" for details.\" ) if destination is None : destination = OrderedDict () destination . _ metadata = OrderedDict () local_metadata = dict ( version = self . _ version ) if hasattr ( destination , \"_metadata\" ) : destination . _ metadata [ prefix [ :- 1 ]] = local_metadata for hook in self . _ state_dict_pre_hooks . 
val ues () : hook ( self , prefix , keep_vars ) self . _ save_to_state_dict ( destination , prefix , keep_vars ) for name , module in self . _ modules . items () : if module is not None : module . state_dict ( destination = destination , prefix = prefix + name + '.' , keep_vars = keep_vars ) for hook in self . _ state_dict_hooks . val ues () : hook_result = hook ( self , destination , prefix , local_metadata ) if hook_result is not None : destination = hook_result return destination","title":"state_dict"},{"location":"reference/wtracker/neural/mlp/#to_3","text":"def to ( self , * args , ** kwargs ) Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth: torch.Tensor.to , but only accepts floating point or complex :attr: dtype \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr: dtype (if given). The integral parameters and buffers will be moved :attr: device , if that is given, but with dtypes unchanged. When :attr: non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device ( None class: torch.device ): the desired device of the parameters and buffers in this module None dtype ( None class: torch.dtype ): the desired floating point or complex dtype of the parameters and buffers in this module None tensor torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module None memory_format ( None class: torch.memory_format ): the desired memory format for 4D parameters and buffers in this module (keyword only argument) None Returns: Type Description Module self View Source def to ( self , * args , ** kwargs ) : r \" \"\" Move and/or cast the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) :noindex: .. function:: to(dtype, non_blocking=False) :noindex: .. function:: to(tensor, non_blocking=False) :noindex: .. function:: to(memory_format=torch.channels_last) :noindex: Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype` \\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. 
Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point or complex dtype of the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory format for 4D parameters and buffers in this module (keyword only argument) Returns: Module: self Examples:: >>> # xdoctest: +IGNORE_WANT(\" non - deterministic \") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device(\" cuda : 1 \") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device(\" cpu \") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) \"\" \" device , dtype , non_blocking , convert_to_format = torch . _C . _nn . _parse_to ( * args , ** kwargs ) if dtype is not None : if not ( dtype . is_floating_point or dtype . is_complex ) : raise TypeError ( 'nn.Module.to only accepts floating point or complex ' f 'dtypes, but got desired dtype={dtype}' ) if dtype . is_complex : warnings . warn ( \"Complex modules are a new feature under active development whose design may change, \" \"and some modules might not work as expected when using complex tensors as parameters or buffers. \" \"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml \" \"if a complex module does not work as expected.\" ) def convert ( t ) : try : if convert_to_format is not None and t . dim () in ( 4 , 5 ) : return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , memory_format = convert_to_format , ) return t . to ( device , dtype if t . is_floating_point () or t . is_complex () else None , non_blocking , ) except NotImplementedError as e : if str ( e ) == \"Cannot copy out of meta tensor; no data!\" : raise NotImplementedError ( f \"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() \" f \"when moving module from meta to a different device.\" ) from None else : raise return self . _apply ( convert )","title":"to"},{"location":"reference/wtracker/neural/mlp/#to_empty_3","text":"def to_empty ( self : ~ T , * , device : Union [ int , str , torch . device , NoneType ], recurse : bool = True ) -> ~ T Move the parameters and buffers to the specified device without copying storage. 
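A brief sketch of the `to()` semantics documented above; the accelerator branch is guarded because device availability depends on the machine:

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 2)

# Casting affects only floating point (or complex) parameters and buffers.
model.to(torch.float64)
print(model.weight.dtype)  # torch.float64

# Move to an accelerator only if one is present.
if torch.cuda.is_available():
    model.to(torch.device("cuda:0"), dtype=torch.float16, non_blocking=True)
```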
Parameters: Name Type Description Default device ( None class: torch.device ): The desired device of the parameters and buffers in this module. None recurse bool Whether parameters and buffers of submodules should be recursively moved to the specified device. None Returns: Type Description Module self View Source def to_empty ( self : T , * , device : Optional [ DeviceLikeType ] , recurse : bool = True ) -> T : r \"\"\"Move the parameters and buffers to the specified device without copying storage. Args: device (:class:`torch.device`): The desired device of the parameters and buffers in this module. recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device. Returns: Module: self \"\"\" return self . _apply ( lambda t : torch . empty_like ( t , device = device ), recurse = recurse )","title":"to_empty"},{"location":"reference/wtracker/neural/mlp/#train_3","text":"def train ( self : ~ T , mode : bool = True ) -> ~ T Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class: Dropout , :class: BatchNorm , etc. Parameters: Name Type Description Default mode bool whether to set training mode ( True ) or evaluation mode ( False ). Default: True . None Returns: Type Description Module self View Source def train ( self : T , mode : bool = True ) -> T : r \" \"\" Set the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self \"\" \" if not isinstance ( mode , bool ) : raise ValueError ( \"training mode is expected to be boolean\" ) self . training = mode for module in self . children () : module . train ( mode ) return self","title":"train"},{"location":"reference/wtracker/neural/mlp/#type_3","text":"def type ( self : ~ T , dst_type : Union [ torch . dtype , str ] ) -> ~ T Casts all parameters and buffers to :attr: dst_type . .. note:: This method modifies the module in-place. Parameters: Name Type Description Default dst_type type or string the desired type None Returns: Type Description Module self View Source def type ( self : T , dst_type : Union [ dtype , str ] ) -> T : r \" \"\" Casts all parameters and buffers to :attr:`dst_type`. .. note:: This method modifies the module in-place. Args: dst_type (type or string): the desired type Returns: Module: self \"\" \" return self . _apply ( lambda t : t . type ( dst_type ))","title":"type"},{"location":"reference/wtracker/neural/mlp/#xpu_3","text":"def xpu ( self : ~ T , device : Union [ int , torch . device , NoneType ] = None ) -> ~ T Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Parameters: Name Type Description Default device int if specified, all parameters will be copied to that device None Returns: Type Description Module self View Source def xpu ( self : T , device : Optional [ Union [ int , device ]] = None ) -> T : r \"\"\"Move all model parameters and buffers to the XPU. 
This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized. .. note:: This method modifies the module in-place. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self \"\"\" return self . _apply ( lambda t : t . xpu ( device ))","title":"xpu"},{"location":"reference/wtracker/neural/mlp/#zero_grad_3","text":"def zero_grad ( self , set_to_none : bool = True ) -> None Reset gradients of all model parameters. See similar function under :class: torch.optim.Optimizer for more context. Parameters: Name Type Description Default set_to_none bool instead of setting to zero, set the grads to None. See :meth: torch.optim.Optimizer.zero_grad for details. None View Source def zero_grad ( self , set_to_none : bool = True ) -> None : r \"\"\"Reset gradients of all model parameters. See similar function under :class:`torch.optim.Optimizer` for more context. Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. \"\"\" if getattr ( self , '_is_replica' , False ) : warnings . warn ( \"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. \" \"The parameters are copied (in a differentiable manner) from the original module. \" \"This means they are not leaf nodes in autograd and so don' t accumulate gradients . \" \" If you need gradients in your forward method , consider using autograd . grad instead . \") for p in self.parameters(): if p.grad is not None: if set_to_none: p.grad = None else: if p.grad.grad_fn is not None: p.grad.detach_() else: p.grad.requires_grad_(False) p.grad.zero_()","title":"zero_grad"},{"location":"reference/wtracker/neural/train_results/","text":"Module wtracker.neural.train_results View Source from typing import List , NamedTuple class BatchResult ( NamedTuple ): \"\"\" Represents the result of training for a single batch: the loss and number of correct classifications. \"\"\" loss : float num_correct : int class EpochResult ( NamedTuple ): \"\"\" Represents the result of training for a single epoch: the loss per batch and accuracy on the dataset (train or test). \"\"\" losses : List [ float ] accuracy : float class FitResult ( NamedTuple ): \"\"\" Represents the result of fitting a model for multiple epochs given a training and test (or validation) set. The losses are for each batch and the accuracies are per epoch. \"\"\" num_epochs : int train_loss : List [ float ] train_acc : List [ float ] test_loss : List [ float ] test_acc : List [ float ] Classes BatchResult class BatchResult ( / , * args , ** kwargs ) Represents the result of training for a single batch: the loss and number of correct classifications. View Source class BatchResult ( NamedTuple ): \"\"\" Represents the result of training for a single batch: the loss and number of correct classifications. \"\"\" loss: float num_correct: int Ancestors (in MRO) builtins.tuple Class variables loss num_correct Methods count def count ( self , value , / ) Return number of occurrences of value. index def index ( self , value , start = 0 , stop = 9223372036854775807 , / ) Return first index of value. Raises ValueError if the value is not present. EpochResult class EpochResult ( / , * args , ** kwargs ) Represents the result of training for a single epoch: the loss per batch and accuracy on the dataset (train or test). 
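Since the full source of `wtracker.neural.train_results` is shown above, a usage sketch of its NamedTuples is straightforward; the numeric values and the aggregation pattern are illustrative, not taken from wtracker's training loop:

```python
from wtracker.neural.train_results import BatchResult, EpochResult, FitResult

# One BatchResult per batch, as a training loop would produce them.
batches = [BatchResult(loss=0.71, num_correct=12), BatchResult(loss=0.52, num_correct=17)]

# EpochResult keeps the per-batch losses plus one accuracy figure for the epoch.
epoch = EpochResult(losses=[b.loss for b in batches], accuracy=72.5)

# FitResult collects per-batch losses and per-epoch accuracies across the whole run.
fit = FitResult(
    num_epochs=1,
    train_loss=epoch.losses,
    train_acc=[epoch.accuracy],
    test_loss=[0.60],
    test_acc=[70.0],
)
print(fit.train_acc[-1])  # fields are accessed by name, e.g. 72.5
```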
View Source class EpochResult ( NamedTuple ) : \"\"\" Represents the result of training for a single epoch: the loss per batch and accuracy on the dataset (train or test). \"\"\" losses : List [ float ] accuracy : float Ancestors (in MRO) builtins.tuple Class variables accuracy losses Methods count def count ( self , value , / ) Return number of occurrences of value. index def index ( self , value , start = 0 , stop = 9223372036854775807 , / ) Return first index of value. Raises ValueError if the value is not present. FitResult class FitResult ( / , * args , ** kwargs ) Represents the result of fitting a model for multiple epochs given a training and test (or validation) set. The losses are for each batch and the accuracies are per epoch. View Source class FitResult ( NamedTuple ) : \"\"\" Represents the result of fitting a model for multiple epochs given a training and test (or validation) set. The losses are for each batch and the accuracies are per epoch. \"\"\" num_epochs : int train_loss : List [ float ] train_acc : List [ float ] test_loss : List [ float ] test_acc : List [ float ] Ancestors (in MRO) builtins.tuple Class variables num_epochs test_acc test_loss train_acc train_loss Methods count def count ( self , value , / ) Return number of occurrences of value. index def index ( self , value , start = 0 , stop = 9223372036854775807 , / ) Return first index of value. Raises ValueError if the value is not present.","title":"Train Results"},{"location":"reference/wtracker/neural/train_results/#module-wtrackerneuraltrain_results","text":"View Source from typing import List , NamedTuple class BatchResult ( NamedTuple ): \"\"\" Represents the result of training for a single batch: the loss and number of correct classifications. \"\"\" loss : float num_correct : int class EpochResult ( NamedTuple ): \"\"\" Represents the result of training for a single epoch: the loss per batch and accuracy on the dataset (train or test). \"\"\" losses : List [ float ] accuracy : float class FitResult ( NamedTuple ): \"\"\" Represents the result of fitting a model for multiple epochs given a training and test (or validation) set. The losses are for each batch and the accuracies are per epoch. \"\"\" num_epochs : int train_loss : List [ float ] train_acc : List [ float ] test_loss : List [ float ] test_acc : List [ float ]","title":"Module wtracker.neural.train_results"},{"location":"reference/wtracker/neural/train_results/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/neural/train_results/#batchresult","text":"class BatchResult ( / , * args , ** kwargs ) Represents the result of training for a single batch: the loss and number of correct classifications. View Source class BatchResult ( NamedTuple ): \"\"\" Represents the result of training for a single batch: the loss and number of correct classifications. 
\"\"\" loss: float num_correct: int","title":"BatchResult"},{"location":"reference/wtracker/neural/train_results/#ancestors-in-mro","text":"builtins.tuple","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/train_results/#class-variables","text":"loss num_correct","title":"Class variables"},{"location":"reference/wtracker/neural/train_results/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/neural/train_results/#count","text":"def count ( self , value , / ) Return number of occurrences of value.","title":"count"},{"location":"reference/wtracker/neural/train_results/#index","text":"def index ( self , value , start = 0 , stop = 9223372036854775807 , / ) Return first index of value. Raises ValueError if the value is not present.","title":"index"},{"location":"reference/wtracker/neural/train_results/#epochresult","text":"class EpochResult ( / , * args , ** kwargs ) Represents the result of training for a single epoch: the loss per batch and accuracy on the dataset (train or test). View Source class EpochResult ( NamedTuple ) : \"\"\" Represents the result of training for a single epoch: the loss per batch and accuracy on the dataset (train or test). \"\"\" losses : List [ float ] accuracy : float","title":"EpochResult"},{"location":"reference/wtracker/neural/train_results/#ancestors-in-mro_1","text":"builtins.tuple","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/train_results/#class-variables_1","text":"accuracy losses","title":"Class variables"},{"location":"reference/wtracker/neural/train_results/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/neural/train_results/#count_1","text":"def count ( self , value , / ) Return number of occurrences of value.","title":"count"},{"location":"reference/wtracker/neural/train_results/#index_1","text":"def index ( self , value , start = 0 , stop = 9223372036854775807 , / ) Return first index of value. Raises ValueError if the value is not present.","title":"index"},{"location":"reference/wtracker/neural/train_results/#fitresult","text":"class FitResult ( / , * args , ** kwargs ) Represents the result of fitting a model for multiple epochs given a training and test (or validation) set. The losses are for each batch and the accuracies are per epoch. View Source class FitResult ( NamedTuple ) : \"\"\" Represents the result of fitting a model for multiple epochs given a training and test (or validation) set. The losses are for each batch and the accuracies are per epoch. \"\"\" num_epochs : int train_loss : List [ float ] train_acc : List [ float ] test_loss : List [ float ] test_acc : List [ float ]","title":"FitResult"},{"location":"reference/wtracker/neural/train_results/#ancestors-in-mro_2","text":"builtins.tuple","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/train_results/#class-variables_2","text":"num_epochs test_acc test_loss train_acc train_loss","title":"Class variables"},{"location":"reference/wtracker/neural/train_results/#methods_2","text":"","title":"Methods"},{"location":"reference/wtracker/neural/train_results/#count_2","text":"def count ( self , value , / ) Return number of occurrences of value.","title":"count"},{"location":"reference/wtracker/neural/train_results/#index_2","text":"def index ( self , value , start = 0 , stop = 9223372036854775807 , / ) Return first index of value. 
Raises ValueError if the value is not present.","title":"index"},{"location":"reference/wtracker/neural/training/","text":"Module wtracker.neural.training View Source import os import abc import sys import torch import torch.nn as nn import torch.nn.functional import tqdm.auto from torch import Tensor from typing import Any , Tuple , Callable , Optional from torch.optim import Optimizer from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from wtracker.neural.train_results import FitResult , BatchResult , EpochResult class Trainer ( abc . ABC ): \"\"\" A class abstracting the various tasks of training models. Provides methods at multiple levels of granularity: - Multiple epochs (fit) - Single epoch (train_epoch/test_epoch) - Single batch (train_batch/test_batch) Args: model (nn.Module): The model to train. device (Optional[torch.device], optional): The device to run training on (CPU or GPU). log (bool, optional): Whether to log training progress with tensorboard. \"\"\" def __init__ ( self , model : nn . Module , device : Optional [ torch . device ] = None , log : bool = False , ): self . model = model self . device = device self . logger = None if not log else SummaryWriter () if self . logger is not None : self . logger . add_hparams ({ \"model\" : model . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"device\" : str ( device )}, {}, run_name = \"hparams\" ) if self . device : model . to ( self . device ) def _make_batch_result ( self , loss , num_correct ) -> BatchResult : loss = loss . item () if isinstance ( loss , Tensor ) else loss num_correct = num_correct . item () if isinstance ( num_correct , Tensor ) else num_correct return BatchResult ( float ( loss ), int ( num_correct )) def _make_fit_result ( self , num_epochs , train_losses , train_acc , test_losses , test_acc ) -> FitResult : num_epochs = num_epochs . item () if isinstance ( num_epochs , Tensor ) else num_epochs train_losses = [ x . item () if isinstance ( x , Tensor ) else x for x in train_losses ] train_acc = [ x . item () if isinstance ( x , Tensor ) else x for x in train_acc ] test_losses = [ x . item () if isinstance ( x , Tensor ) else x for x in test_losses ] test_acc = [ x . item () if isinstance ( x , Tensor ) else x for x in test_acc ] return FitResult ( int ( num_epochs ), train_losses , train_acc , test_losses , test_acc ) def fit ( self , dl_train : DataLoader , dl_test : DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw , ) -> FitResult : \"\"\" Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Args: dl_train (DataLoader): Dataloader for the training set. dl_test (DataLoader): Dataloader for the test set. num_epochs (int): Number of epochs to train for. checkpoints (str, optional): Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. early_stopping (int, optional): Whether to stop training early if there is no test loss improvement for this number of epochs. print_every (int, optional): Print progress every this number of epochs. Returns: FitResult: A FitResult object containing train and test losses per epoch. \"\"\" actual_epoch_num = 0 epochs_without_improvement = 0 train_loss , train_acc , test_loss , test_acc = [], [], [], [] best_val_loss = None # add graph to tensorboard if self . logger is not None : self . 
logger . add_graph ( self . model , next ( iter ( dl_train ))[ 0 ]) for epoch in range ( num_epochs ): actual_epoch_num += 1 verbose = False # pass this to train/test_epoch. if print_every > 0 and ( epoch % print_every == 0 or epoch == num_epochs - 1 ): verbose = True self . _print ( f \"--- EPOCH { epoch + 1 } / { num_epochs } ---\" , verbose ) train_result = self . train_epoch ( dl_train , verbose = verbose , ** kw ) test_result = self . test_epoch ( dl_test , verbose = verbose , ** kw ) train_loss . extend ( train_result . losses ) train_acc . append ( train_result . accuracy ) test_loss . extend ( test_result . losses ) test_acc . append ( test_result . accuracy ) # log results to tensorboard if self . logger is not None : self . logger . add_scalar ( \"loss/train\" , Tensor ( train_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"loss/test\" , Tensor ( test_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"accuracy/train\" , train_result . accuracy , epoch ) self . logger . add_scalar ( \"accuracy/test\" , test_result . accuracy , epoch ) self . logger . add_scalar ( \"learning_rate\" , self . optimizer . param_groups [ 0 ][ \"lr\" ], epoch ) curr_val_loss = Tensor ( test_result . losses ) . mean () . item () if best_val_loss is None or curr_val_loss < best_val_loss : best_val_loss = curr_val_loss epochs_without_improvement = 0 if checkpoints is not None : self . save_checkpoint ( checkpoints , curr_val_loss ) else : epochs_without_improvement += 1 if early_stopping is not None and epochs_without_improvement >= early_stopping : break return self . _make_fit_result ( actual_epoch_num , train_loss , train_acc , test_loss , test_acc ) def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None : \"\"\" Saves the model in its current state to a file with the given name (treated as a relative path). Args: checkpoint_filename (str): File name or relative path to save to. \"\"\" if self . logger is not None : checkpoint_filename = f \" { self . logger . log_dir } / { checkpoint_filename } \" torch . save ( self . model , checkpoint_filename ) print ( f \" \\n *** Saved checkpoint { checkpoint_filename } :: val_loss= { loss : .3f } \" ) def train_epoch ( self , dl_train : DataLoader , ** kw ) -> EpochResult : \"\"\" Train once over a training set (single epoch). Args: dl_train (DataLoader): DataLoader for the training set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self . model . train ( True ) # set train mode return self . _foreach_batch ( dl_train , self . train_batch , ** kw ) def test_epoch ( self , dl_test : DataLoader , ** kw ) -> EpochResult : \"\"\" Evaluate model once over a test set (single epoch). Args: dl_test (DataLoader): DataLoader for the test set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self . model . train ( False ) # set evaluation (test) mode return self . _foreach_batch ( dl_test , self . test_batch , ** kw ) @abc . abstractmethod def train_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model, calculates loss, performs back-propagation and updates weights. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset).
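Because train_batch and test_batch are abstract, every concrete trainer supplies them; a minimal sketch of such a subclass, assuming the module's own imports (the class name and argmax-based accuracy rule are illustrative only, not part of the library):
class ClassifierTrainer(Trainer):
    def __init__(self, model, loss_fn, optimizer, device=None):
        super().__init__(model, device)
        self.loss_fn = loss_fn
        self.optimizer = optimizer

    def train_batch(self, batch) -> BatchResult:
        X, y = batch
        self.optimizer.zero_grad()
        preds = self.model(X)
        loss = self.loss_fn(preds, y)
        loss.backward()        # back-propagate
        self.optimizer.step()  # update weights
        num_correct = (preds.argmax(dim=1) == y).sum()
        return self._make_batch_result(loss, num_correct)

    @torch.no_grad()
    def test_batch(self, batch) -> BatchResult:
        X, y = batch
        preds = self.model(X)
        num_correct = (preds.argmax(dim=1) == y).sum()
        return self._make_batch_result(self.loss_fn(preds, y), num_correct)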
Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () @abc . abstractmethod def test_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model and calculates loss. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () @staticmethod def _print ( message , verbose = True ): \"\"\"Simple wrapper around print to make it conditional\"\"\" if verbose : print ( message ) @staticmethod def _foreach_batch ( dl : DataLoader , forward_fn : Callable [[ Any ], BatchResult ], verbose = True , max_batches = None , ) -> EpochResult : \"\"\" Evaluates the given forward-function on batches from the given dataloader, and prints progress along the way. \"\"\" losses = [] num_correct = 0 num_samples = len ( dl . sampler ) num_batches = len ( dl . batch_sampler ) if max_batches is not None : if max_batches < num_batches : num_batches = max_batches num_samples = num_batches * dl . batch_size if verbose : pbar_fn = tqdm . auto . tqdm pbar_file = sys . stdout else : pbar_fn = tqdm . tqdm pbar_file = open ( os . devnull , \"w\" ) pbar_name = forward_fn . __name__ with pbar_fn ( desc = pbar_name , total = num_batches , file = pbar_file ) as pbar : dl_iter = iter ( dl ) for batch_idx in range ( num_batches ): data = next ( dl_iter ) batch_res = forward_fn ( data ) pbar . set_description ( f \" { pbar_name } ( { batch_res . loss : .3f } )\" ) pbar . update () losses . append ( batch_res . loss ) num_correct += batch_res . num_correct avg_loss = sum ( losses ) / num_batches accuracy = 100.0 * num_correct / num_samples pbar . set_description ( f \" { pbar_name } \" f \"(Avg. Loss { avg_loss : .3f } , \" f \"Accuracy { accuracy : .2f } %)\" ) if not verbose : pbar_file . close () return EpochResult ( losses = losses , accuracy = accuracy ) def log_hparam ( self , hparam_dict : dict [ str , Any ], metric_dict : dict [ str , Any ] = {}, run_name : str = \"hparams\" ): if self . logger is not None : self . logger . add_hparams ( hparam_dict , metric_dict , run_name = run_name ) class MLPTrainer ( Trainer ): \"\"\" The `MLPTrainer` class is responsible for training and testing multi-layer perceptron (MLP) models. Args: model (nn.Module): The MLP model to be trained. loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. device (Optional[torch.device], optional): The device on which the model and data should be loaded. log (bool, optional): Whether to log training progress with tensorboard. Attributes: loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. \"\"\" def __init__ ( self , model : nn . Module , loss_fn : nn . Module , optimizer : Optimizer , device : Optional [ torch . device ] = None , log : bool = False , ): super () . __init__ ( model , device , log = log ) self . loss_fn = loss_fn self . optimizer = optimizer if self . logger is not None : self . logger . add_hparams ({ \"loss_fn\" : loss_fn . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"optimizer\" : optimizer . __class__ .
__name__ }, {}, run_name = \"hparams\" ) optimizer_params = {} for key , val in optimizer . param_groups [ 0 ] . items (): optimizer_params [ key ] = str ( val ) optimizer_params . update ({ \"params\" : \"\" }) self . logger . add_hparams ( optimizer_params , {}, run_name = \"hparams\" ) def train_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) self . model : nn . Module self . optimizer . zero_grad () preds = self . model . forward ( X ) loss = self . loss_fn ( preds , y ) loss . backward () self . optimizer . step () num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) return self . _make_batch_result ( loss , num_correct ) @torch . no_grad () def test_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) preds = self . model . forward ( X ) num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) loss = self . loss_fn ( preds , y ) return self . _make_batch_result ( loss , num_correct ) Classes MLPTrainer class MLPTrainer ( model : torch . nn . modules . module . Module , loss_fn : torch . nn . modules . module . Module , optimizer : torch . optim . optimizer . Optimizer , device : Optional [ torch . device ] = None , log : bool = False ) The MLPTrainer class is responsible for training and testing multi-layer perceptron (MLP) models. Attributes Name Type Description Default model nn.Module The MLP model to be trained. None loss_fn nn.Module The loss function used for training. None optimizer Optimizer The optimizer used for updating the model's parameters. None device Optional[torch.device] The device on which the model and data should be loaded. None log bool Whether to log training progress with tensorboard. None loss_fn nn.Module The loss function used for training. None optimizer Optimizer The optimizer used for updating the model's parameters. None View Source class MLPTrainer ( Trainer ): \"\"\" The `MLPTrainer` class is responsible for training and testing multi-layer perceptron (MLP) models. Args: model (nn.Module): The MLP model to be trained. loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. device (Optional[torch.device], optional): The device on which the model and data should be loaded. log (bool, optional): Whether to log training progress with tensorboard. Attributes: loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. \"\"\" def __init__ ( self , model : nn . Module , loss_fn : nn . Module , optimizer : Optimizer , device : Optional [ torch . device ] = None , log : bool = False , ): super () . __init__ ( model , device , log = log ) self . loss_fn = loss_fn self . optimizer = optimizer if self . logger is not None : self . logger . add_hparams ({ \"loss_fn\" : loss_fn . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"optimizer\" : optimizer . __class__ . __name__ }, {}, run_name = \"hparams\" ) optimizer_params = {} for key , val in optimizer . param_groups [ 0 ] . items (): optimizer_params [ key ] = str ( val ) optimizer_params . update ({ \"params\" : \"\" }) self . logger . add_hparams ( optimizer_params , {}, run_name = \"hparams\" ) def train_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self .
device ) self . model : nn . Module self . optimizer . zero_grad () preds = self . model . forward ( X ) loss = self . loss_fn ( preds , y ) loss . backward () self . optimizer . step () num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) return self . _make_batch_result ( loss , num_correct ) @ torch . no_grad () def test_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) preds = self . model . forward ( X ) num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) loss = self . loss_fn ( preds , y ) return self . _make_batch_result ( loss , num_correct ) Ancestors (in MRO) wtracker.neural.training.Trainer abc.ABC Methods fit def fit ( self , dl_train : torch . utils . data . dataloader . DataLoader , dl_test : torch . utils . data . dataloader . DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw ) -> wtracker . neural . train_results . FitResult Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Parameters: Name Type Description Default dl_train DataLoader Dataloader for the training set. None dl_test DataLoader Dataloader for the test set. None num_epochs int Number of epochs to train for. None checkpoints str Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. None early_stopping int Whether to stop training early if there is no test loss improvement for this number of epochs. None print_every int Print progress every this number of epochs. None Returns: Type Description FitResult A FitResult object containing train and test losses per epoch. View Source def fit ( self , dl_train : DataLoader , dl_test : DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw , ) -> FitResult : \"\"\" Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Args: dl_train (DataLoader): Dataloader for the training set. dl_test (DataLoader): Dataloader for the test set. num_epochs (int): Number of epochs to train for. checkpoints (str, optional): Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. early_stopping (int, optional): Whether to stop training early if there is no test loss improvement for this number of epochs. print_every (int, optional): Print progress every this number of epochs. Returns: FitResult: A FitResult object containing train and test losses per epoch. \"\"\" actual_epoch_num = 0 epochs_without_improvement = 0 train_loss , train_acc , test_loss , test_acc = [], [], [], [] best_val_loss = None # add graph to tensorboard if self . logger is not None : self . logger . add_graph ( self . model , next ( iter ( dl_train ))[ 0 ]) for epoch in range ( num_epochs ): actual_epoch_num += 1 verbose = False # pass this to train/test_epoch. if print_every > 0 and ( epoch % print_every == 0 or epoch == num_epochs - 1 ): verbose = True self . _print ( f \"--- EPOCH {epoch+1}/{num_epochs} ---\" , verbose ) train_result = self . train_epoch ( dl_train , verbose = verbose , ** kw ) test_result = self . test_epoch ( dl_test , verbose = verbose , ** kw ) train_loss . extend ( train_result . losses ) train_acc . append ( train_result . accuracy ) test_loss . 
extend ( test_result . losses ) test_acc . append ( test_result . accuracy ) # log results to tensorboard if self . logger is not None : self . logger . add_scalar ( \"loss/train\" , Tensor ( train_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"loss/test\" , Tensor ( test_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"accuracy/train\" , train_result . accuracy , epoch ) self . logger . add_scalar ( \"accuracy/test\" , test_result . accuracy , epoch ) self . logger . add_scalar ( \"learning_rate\" , self . optimizer . param_groups [ 0 ][ \"lr\" ], epoch ) curr_val_loss = Tensor ( test_result . losses ) . mean () . item () if best_val_loss is None or curr_val_loss < best_val_loss : best_val_loss = curr_val_loss epochs_without_improvement = 0 if checkpoints is not None : self . save_checkpoint ( checkpoints , curr_val_loss ) else : epochs_without_improvement += 1 if early_stopping is not None and epochs_without_improvement >= early_stopping : break return self . _make_fit_result ( actual_epoch_num , train_loss , train_acc , test_loss , test_acc ) log_hparam def log_hparam ( self , hparam_dict : dict [ str , typing . Any ], metric_dict : dict [ str , typing . Any ] = {}, run_name : str = 'hparams' ) View Source def log_hparam(self, hparam_dict: dict[str, Any], metric_dict: dict[str, Any] = {}, run_name: str = \"hparams\"): if self.logger is not None: self.logger.add_hparams(hparam_dict, metric_dict, run_name=run_name) save_checkpoint def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None Saves the model in its current state to a file with the given name (treated as a relative path). Parameters: Name Type Description Default checkpoint_filename str File name or relative path to save to. None View Source def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None : \"\"\" Saves the model in its current state to a file with the given name (treated as a relative path). Args: checkpoint_filename (str): File name or relative path to save to. \"\"\" if self . logger is not None : checkpoint_filename = f \"{self.logger.log_dir}/{checkpoint_filename}\" torch . save ( self . model , checkpoint_filename ) print ( f \"\\n*** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}\" ) test_batch def test_batch ( self , batch ) -> wtracker . neural . train_results . BatchResult Runs a single batch forward through the model and calculates loss. Parameters: Name Type Description Default batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). None Returns: Type Description BatchResult A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. View Source @torch . no_grad () def test_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) preds = self . model . forward ( X ) num_correct = torch . sum (( preds - y ). norm ( dim = 1 ) < 1.0 ) loss = self . loss_fn ( preds , y ) return self . _make_batch_result ( loss , num_correct ) test_epoch def test_epoch ( self , dl_test : torch . utils . data . dataloader . DataLoader , ** kw ) -> wtracker . neural . train_results . EpochResult Evaluate model once over a test set (single epoch). Parameters: Name Type Description Default dl_test DataLoader DataLoader for the test set.
None kw None Keyword args supported by _foreach_batch. None Returns: Type Description EpochResult An EpochResult for the epoch. View Source def test_epoch(self, dl_test: DataLoader, **kw) -> EpochResult: \"\"\" Evaluate model once over a test set (single epoch). Args: dl_test (DataLoader): DataLoader for the test set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self.model.train(False) # set evaluation (test) mode return self._foreach_batch(dl_test, self.test_batch, **kw) train_batch def train_batch ( self , batch ) -> wtracker . neural . train_results . BatchResult Runs a single batch forward through the model, calculates loss, performs back-propagation and updates weights. Parameters: Name Type Description Default batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). None Returns: Type Description BatchResult A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. View Source def train_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) self . model : nn . Module self . optimizer . zero_grad () preds = self . model . forward ( X ) loss = self . loss_fn ( preds , y ) loss . backward () self . optimizer . step () num_correct = torch . sum (( preds - y ). norm ( dim = 1 ) < 1.0 ) return self . _make_batch_result ( loss , num_correct ) train_epoch def train_epoch ( self , dl_train : torch . utils . data . dataloader . DataLoader , ** kw ) -> wtracker . neural . train_results . EpochResult Train once over a training set (single epoch). Parameters: Name Type Description Default dl_train DataLoader DataLoader for the training set. None kw None Keyword args supported by _foreach_batch. None Returns: Type Description EpochResult An EpochResult for the epoch. View Source def train_epoch(self, dl_train: DataLoader, **kw) -> EpochResult: \"\"\" Train once over a training set (single epoch). Args: dl_train (DataLoader): DataLoader for the training set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self.model.train(True) # set train mode return self._foreach_batch(dl_train, self.train_batch, **kw) Trainer class Trainer ( model : torch . nn . modules . module . Module , device : Optional [ torch . device ] = None , log : bool = False ) A class abstracting the various tasks of training models. Provides methods at multiple levels of granularity: - Multiple epochs (fit) - Single epoch (train_epoch/test_epoch) - Single batch (train_batch/test_batch) Attributes Name Type Description Default model nn.Module The model to train. None device Optional[torch.device] The device to run training on (CPU or GPU). None log bool Whether to log training progress with tensorboard. None View Source class Trainer ( abc . ABC ): \"\"\" A class abstracting the various tasks of training models. Provides methods at multiple levels of granularity: - Multiple epochs (fit) - Single epoch (train_epoch/test_epoch) - Single batch (train_batch/test_batch) Args: model (nn.Module): The model to train. device (Optional[torch.device], optional): The device to run training on (CPU or GPU). log (bool, optional): Whether to log training progress with tensorboard. \"\"\" def __init__ ( self , model : nn . Module , device : Optional [ torch . device ] = None , log : bool = False , ): self .
model = model self . device = device self . logger = None if not log else SummaryWriter () if self . logger is not None : self . logger . add_hparams ({ \"model\" : model . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"device\" : str ( device )}, {}, run_name = \"hparams\" ) if self . device : model . to ( self . device ) def _make_batch_result ( self , loss , num_correct ) -> BatchResult : loss = loss . item () if isinstance ( loss , Tensor ) else loss num_correct = num_correct . item () if isinstance ( num_correct , Tensor ) else num_correct return BatchResult ( float ( loss ), int ( num_correct )) def _make_fit_result ( self , num_epochs , train_losses , train_acc , test_losses , test_acc ) -> FitResult : num_epochs = num_epochs . item () if isinstance ( num_epochs , Tensor ) else num_epochs train_losses = [ x . item () if isinstance ( x , Tensor ) else x for x in train_losses ] train_acc = [ x . item () if isinstance ( x , Tensor ) else x for x in train_acc ] test_losses = [ x . item () if isinstance ( x , Tensor ) else x for x in test_losses ] test_acc = [ x . item () if isinstance ( x , Tensor ) else x for x in test_acc ] return FitResult ( int ( num_epochs ), train_losses , train_acc , test_losses , test_acc ) def fit ( self , dl_train : DataLoader , dl_test : DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw , ) -> FitResult : \"\"\" Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Args: dl_train (DataLoader): Dataloader for the training set. dl_test (DataLoader): Dataloader for the test set. num_epochs (int): Number of epochs to train for. checkpoints (str, optional): Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. early_stopping (int, optional): Whether to stop training early if there is no test loss improvement for this number of epochs. print_every (int, optional): Print progress every this number of epochs. Returns: FitResult: A FitResult object containing train and test losses per epoch. \"\"\" actual_epoch_num = 0 epochs_without_improvement = 0 train_loss , train_acc , test_loss , test_acc = [], [], [], [] best_val_loss = None # add graph to tensorboard if self . logger is not None : self . logger . add_graph ( self . model , next ( iter ( dl_train ))[ 0 ]) for epoch in range ( num_epochs ): actual_epoch_num += 1 verbose = False # pass this to train/test_epoch. if print_every > 0 and ( epoch % print_every == 0 or epoch == num_epochs - 1 ): verbose = True self . _print ( f \"--- EPOCH {epoch+1}/{num_epochs} ---\" , verbose ) train_result = self . train_epoch ( dl_train , verbose = verbose , ** kw ) test_result = self . test_epoch ( dl_test , verbose = verbose , ** kw ) train_loss . extend ( train_result . losses ) train_acc . append ( train_result . accuracy ) test_loss . extend ( test_result . losses ) test_acc . append ( test_result . accuracy ) # log results to tensorboard if self . logger is not None : self . logger . add_scalar ( \"loss/train\" , Tensor ( train_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"loss/test\" , Tensor ( test_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"accuracy/train\" , train_result . accuracy , epoch ) self . logger . add_scalar ( \"accuracy/test\" , test_result . accuracy , epoch ) self . logger . 
add_scalar ( \"learning_rate\" , self . optimizer . param_groups [ 0 ][ \"lr\" ], epoch ) curr_val_loss = Tensor ( test_result . losses ) . mean () . item () if best_val_loss is None or curr_val_loss < best_val_loss : best_val_loss = curr_val_loss epochs_without_improvement = 0 if checkpoints is not None : self . save_checkpoint ( checkpoints , curr_val_loss ) else : epochs_without_improvement += 1 if early_stopping is not None and epochs_without_improvement >= early_stopping : break return self . _make_fit_result ( actual_epoch_num , train_loss , train_acc , test_loss , test_acc ) def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None : \"\"\" Saves the model in it's current state to a file with the given name (treated as a relative path). Args: checkpoint_filename (str): File name or relative path to save to. \"\"\" if self . logger is not None : checkpoint_filename = f \"{self.logger.log_dir}/{checkpoint_filename}\" torch . save ( self . model , checkpoint_filename ) print ( f \" \\n *** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}\" ) def train_epoch ( self , dl_train : DataLoader , ** kw ) -> EpochResult : \"\"\" Train once over a training set (single epoch). Args: dl_train (DataLoader): DataLoader for the training set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self . model . train ( True ) # set train mode return self . _foreach_batch ( dl_train , self . train_batch , ** kw ) def test_epoch ( self , dl_test : DataLoader , ** kw ) -> EpochResult : \"\"\" Evaluate model once over a test set (single epoch). Args: dl_test (DataLoader): DataLoader for the test set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self . model . train ( False ) # set evaluation (test) mode return self . _foreach_batch ( dl_test , self . test_batch , ** kw ) @ abc . abstractmethod def train_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model, calculates loss, preforms back-propagation and updates weights. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () @ abc . abstractmethod def test_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model and calculates loss. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () @ staticmethod def _print ( message , verbose = True ): \"\"\"Simple wrapper around print to make it conditional\"\"\" if verbose : print ( message ) @ staticmethod def _foreach_batch ( dl : DataLoader , forward_fn : Callable [[ Any ], BatchResult ], verbose = True , max_batches = None , ) -> EpochResult : \"\"\" Evaluates the given forward-function on batches from the given dataloader, and prints progress along the way. \"\"\" losses = [] num_correct = 0 num_samples = len ( dl . sampler ) num_batches = len ( dl . 
batch_sampler ) if max_batches is not None : if max_batches < num_batches : num_batches = max_batches num_samples = num_batches * dl . batch_size if verbose : pbar_fn = tqdm . auto . tqdm pbar_file = sys . stdout else : pbar_fn = tqdm . tqdm pbar_file = open ( os . devnull , \"w\" ) pbar_name = forward_fn . __name__ with pbar_fn ( desc = pbar_name , total = num_batches , file = pbar_file ) as pbar : dl_iter = iter ( dl ) for batch_idx in range ( num_batches ): data = next ( dl_iter ) batch_res = forward_fn ( data ) pbar . set_description ( f \"{pbar_name} ({batch_res.loss:.3f})\" ) pbar . update () losses . append ( batch_res . loss ) num_correct += batch_res . num_correct avg_loss = sum ( losses ) / num_batches accuracy = 100.0 * num_correct / num_samples pbar . set_description ( f \"{pbar_name} \" f \"(Avg. Loss {avg_loss:.3f}, \" f \"Accuracy {accuracy:.2f}%)\" ) if not verbose : pbar_file . close () return EpochResult ( losses = losses , accuracy = accuracy ) def log_hparam ( self , hparam_dict : dict [ str , Any ], metric_dict : dict [ str , Any ] = {}, run_name : str = \"hparams\" ): if self . logger is not None : self . logger . add_hparams ( hparam_dict , metric_dict , run_name = run_name ) Ancestors (in MRO) abc.ABC Descendants wtracker.neural.training.MLPTrainer Methods fit def fit ( self , dl_train : torch . utils . data . dataloader . DataLoader , dl_test : torch . utils . data . dataloader . DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw ) -> wtracker . neural . train_results . FitResult Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Parameters: Name Type Description Default dl_train DataLoader Dataloader for the training set. None dl_test DataLoader Dataloader for the test set. None num_epochs int Number of epochs to train for. None checkpoints str Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. None early_stopping int Whether to stop training early if there is no test loss improvement for this number of epochs. None print_every int Print progress every this number of epochs. None Returns: Type Description FitResult A FitResult object containing train and test losses per epoch. View Source def fit ( self , dl_train : DataLoader , dl_test : DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw , ) -> FitResult : \"\"\" Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Args: dl_train (DataLoader): Dataloader for the training set. dl_test (DataLoader): Dataloader for the test set. num_epochs (int): Number of epochs to train for. checkpoints (str, optional): Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. early_stopping (int, optional): Whether to stop training early if there is no test loss improvement for this number of epochs. print_every (int, optional): Print progress every this number of epochs. Returns: FitResult: A FitResult object containing train and test losses per epoch. \"\"\" actual_epoch_num = 0 epochs_without_improvement = 0 train_loss , train_acc , test_loss , test_acc = [], [], [], [] best_val_loss = None # add graph to tensorboard if self . logger is not None : self . logger . add_graph ( self . 
model , next ( iter ( dl_train ))[ 0 ]) for epoch in range ( num_epochs ): actual_epoch_num += 1 verbose = False # pass this to train/test_epoch. if print_every > 0 and ( epoch % print_every == 0 or epoch == num_epochs - 1 ): verbose = True self . _print ( f \"--- EPOCH {epoch+1}/{num_epochs} ---\" , verbose ) train_result = self . train_epoch ( dl_train , verbose = verbose , ** kw ) test_result = self . test_epoch ( dl_test , verbose = verbose , ** kw ) train_loss . extend ( train_result . losses ) train_acc . append ( train_result . accuracy ) test_loss . extend ( test_result . losses ) test_acc . append ( test_result . accuracy ) # log results to tensorboard if self . logger is not None : self . logger . add_scalar ( \"loss/train\" , Tensor ( train_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"loss/test\" , Tensor ( test_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"accuracy/train\" , train_result . accuracy , epoch ) self . logger . add_scalar ( \"accuracy/test\" , test_result . accuracy , epoch ) self . logger . add_scalar ( \"learning_rate\" , self . optimizer . param_groups [ 0 ][ \"lr\" ], epoch ) curr_val_loss = Tensor ( test_result . losses ) . mean () . item () if best_val_loss is None or curr_val_loss < best_val_loss : best_val_loss = curr_val_loss epochs_without_improvement = 0 if checkpoints is not None : self . save_checkpoint ( checkpoints , curr_val_loss ) else : epochs_without_improvement += 1 if early_stopping is not None and epochs_without_improvement >= early_stopping : break return self . _make_fit_result ( actual_epoch_num , train_loss , train_acc , test_loss , test_acc ) log_hparam def log_hparam ( self , hparam_dict : dict [ str , typing . Any ], metric_dict : dict [ str , typing . Any ] = {}, run_name : str = 'hparams' ) View Source def log_hparam(self, hparam_dict: dict[str, Any], metric_dict: dict[str, Any] = {}, run_name: str = \"hparams\"): if self.logger is not None: self.logger.add_hparams(hparam_dict, metric_dict, run_name=run_name) save_checkpoint def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None Saves the model in its current state to a file with the given name (treated as a relative path). Parameters: Name Type Description Default checkpoint_filename str File name or relative path to save to. None View Source def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None : \"\"\" Saves the model in its current state to a file with the given name (treated as a relative path). Args: checkpoint_filename (str): File name or relative path to save to. \"\"\" if self . logger is not None : checkpoint_filename = f \"{self.logger.log_dir}/{checkpoint_filename}\" torch . save ( self . model , checkpoint_filename ) print ( f \"\\n*** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}\" ) test_batch def test_batch ( self , batch ) -> wtracker . neural . train_results . BatchResult Runs a single batch forward through the model and calculates loss. Parameters: Name Type Description Default batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). None Returns: Type Description BatchResult A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. View Source @ abc .
abstractmethod def test_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model and calculates loss. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () test_epoch def test_epoch ( self , dl_test : torch . utils . data . dataloader . DataLoader , ** kw ) -> wtracker . neural . train_results . EpochResult Evaluate model once over a test set (single epoch). Parameters: Name Type Description Default dl_test DataLoader DataLoader for the test set. None kw None Keyword args supported by _foreach_batch. None Returns: Type Description EpochResult An EpochResult for the epoch. View Source def test_epoch(self, dl_test: DataLoader, **kw) -> EpochResult: \"\"\" Evaluate model once over a test set (single epoch). Args: dl_test (DataLoader): DataLoader for the test set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self.model.train(False) # set evaluation (test) mode return self._foreach_batch(dl_test, self.test_batch, **kw) train_batch def train_batch ( self , batch ) -> wtracker . neural . train_results . BatchResult Runs a single batch forward through the model, calculates loss, performs back-propagation and updates weights. Parameters: Name Type Description Default batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). None Returns: Type Description BatchResult A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. View Source @ abc . abstractmethod def train_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model, calculates loss, performs back-propagation and updates weights. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () train_epoch def train_epoch ( self , dl_train : torch . utils . data . dataloader . DataLoader , ** kw ) -> wtracker . neural . train_results . EpochResult Train once over a training set (single epoch). Parameters: Name Type Description Default dl_train DataLoader DataLoader for the training set. None kw None Keyword args supported by _foreach_batch. None Returns: Type Description EpochResult An EpochResult for the epoch. View Source def train_epoch(self, dl_train: DataLoader, **kw) -> EpochResult: \"\"\" Train once over a training set (single epoch). Args: dl_train (DataLoader): DataLoader for the training set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch.
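Both epoch methods forward extra keyword arguments to _foreach_batch, so an epoch can be shortened or silenced; a brief sketch, where the trainer and loader names are placeholders:
result = trainer.train_epoch(dl_train, verbose=False, max_batches=10)
print(result.accuracy, len(result.losses))  # accuracy over the 10 batches seen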
\"\"\" self.model.train(True) # set train mode return self._foreach_batch(dl_train, self.train_batch, **kw)","title":"Training"},{"location":"reference/wtracker/neural/training/#module-wtrackerneuraltraining","text":"View Source import os import abc import sys import torch import torch.nn as nn import torch.nn.functional import tqdm.auto from torch import Tensor from typing import Any , Tuple , Callable , Optional from torch.optim import Optimizer from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from wtracker.neural.train_results import FitResult , BatchResult , EpochResult class Trainer ( abc . ABC ): \"\"\" A class abstracting the various tasks of training models. Provides methods at multiple levels of granularity: - Multiple epochs (fit) - Single epoch (train_epoch/test_epoch) - Single batch (train_batch/test_batch) Args: model (nn.Module): The model to train. device (Optional[torch.device], optional): The device to run training on (CPU or GPU). log (bool, optional): Whether to log training progress with tensorboard. \"\"\" def __init__ ( self , model : nn . Module , device : Optional [ torch . device ] = None , log : bool = False , ): self . model = model self . device = device self . logger = None if not log else SummaryWriter () if self . logger is not None : self . logger . add_hparams ({ \"model\" : model . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"device\" : str ( device )}, {}, run_name = \"hparams\" ) if self . device : model . to ( self . device ) def _make_batch_result ( self , loss , num_correct ) -> BatchResult : loss = loss . item () if isinstance ( loss , Tensor ) else loss num_correct = num_correct . item () if isinstance ( num_correct , Tensor ) else num_correct return BatchResult ( float ( loss ), int ( num_correct )) def _make_fit_result ( self , num_epochs , train_losses , train_acc , test_losses , test_acc ) -> FitResult : num_epochs = num_epochs . item () if isinstance ( num_epochs , Tensor ) else num_epochs train_losses = [ x . item () if isinstance ( x , Tensor ) else x for x in train_losses ] train_acc = [ x . item () if isinstance ( x , Tensor ) else x for x in train_acc ] test_losses = [ x . item () if isinstance ( x , Tensor ) else x for x in test_losses ] test_acc = [ x . item () if isinstance ( x , Tensor ) else x for x in test_acc ] return FitResult ( int ( num_epochs ), train_losses , train_acc , test_losses , test_acc ) def fit ( self , dl_train : DataLoader , dl_test : DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw , ) -> FitResult : \"\"\" Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Args: dl_train (DataLoader): Dataloader for the training set. dl_test (DataLoader): Dataloader for the test set. num_epochs (int): Number of epochs to train for. checkpoints (str, optional): Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. early_stopping (int, optional): Whether to stop training early if there is no test loss improvement for this number of epochs. print_every (int, optional): Print progress every this number of epochs. Returns: FitResult: A FitResult object containing train and test losses per epoch. 
\"\"\" actual_epoch_num = 0 epochs_without_improvement = 0 train_loss , train_acc , test_loss , test_acc = [], [], [], [] best_val_loss = None # add graph to tensorboard if self . logger is not None : self . logger . add_graph ( self . model , next ( iter ( dl_train ))[ 0 ]) for epoch in range ( num_epochs ): actual_epoch_num += 1 verbose = False # pass this to train/test_epoch. if print_every > 0 and ( epoch % print_every == 0 or epoch == num_epochs - 1 ): verbose = True self . _print ( f \"--- EPOCH { epoch + 1 } / { num_epochs } ---\" , verbose ) train_result = self . train_epoch ( dl_train , verbose = verbose , ** kw ) test_result = self . test_epoch ( dl_test , verbose = verbose , ** kw ) train_loss . extend ( train_result . losses ) train_acc . append ( train_result . accuracy ) test_loss . extend ( test_result . losses ) test_acc . append ( test_result . accuracy ) # log results to tensorboard if self . logger is not None : self . logger . add_scalar ( \"loss/train\" , Tensor ( train_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"loss/test\" , Tensor ( test_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"accuracy/train\" , train_result . accuracy , epoch ) self . logger . add_scalar ( \"accuracy/test\" , test_result . accuracy , epoch ) self . logger . add_scalar ( \"learning_rate\" , self . optimizer . param_groups [ 0 ][ \"lr\" ], epoch ) curr_val_loss = Tensor ( test_result . losses ) . mean () . item () if best_val_loss is None or curr_val_loss < best_val_loss : best_val_loss = curr_val_loss epochs_without_improvement = 0 if checkpoints is not None : self . save_checkpoint ( checkpoints , curr_val_loss ) else : epochs_without_improvement += 1 if early_stopping is not None and epochs_without_improvement >= early_stopping : break return self . _make_fit_result ( actual_epoch_num , train_loss , train_acc , test_loss , test_acc ) def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None : \"\"\" Saves the model in it's current state to a file with the given name (treated as a relative path). Args: checkpoint_filename (str): File name or relative path to save to. \"\"\" if self . logger is not None : checkpoint_filename = f \" { self . logger . log_dir } / { checkpoint_filename } \" torch . save ( self . model , checkpoint_filename ) print ( f \" \\n *** Saved checkpoint { checkpoint_filename } :: val_loss= { loss : .3f } \" ) def train_epoch ( self , dl_train : DataLoader , ** kw ) -> EpochResult : \"\"\" Train once over a training set (single epoch). Args: dl_train (DataLoader): DataLoader for the training set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self . model . train ( True ) # set train mode return self . _foreach_batch ( dl_train , self . train_batch , ** kw ) def test_epoch ( self , dl_test : DataLoader , ** kw ) -> EpochResult : \"\"\" Evaluate model once over a test set (single epoch). Args: dl_test (DataLoader): DataLoader for the test set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self . model . train ( False ) # set evaluation (test) mode return self . _foreach_batch ( dl_test , self . test_batch , ** kw ) @abc . abstractmethod def train_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model, calculates loss, preforms back-propagation and updates weights. 
Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () @abc . abstractmethod def test_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model and calculates loss. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () @staticmethod def _print ( message , verbose = True ): \"\"\"Simple wrapper around print to make it conditional\"\"\" if verbose : print ( message ) @staticmethod def _foreach_batch ( dl : DataLoader , forward_fn : Callable [[ Any ], BatchResult ], verbose = True , max_batches = None , ) -> EpochResult : \"\"\" Evaluates the given forward-function on batches from the given dataloader, and prints progress along the way. \"\"\" losses = [] num_correct = 0 num_samples = len ( dl . sampler ) num_batches = len ( dl . batch_sampler ) if max_batches is not None : if max_batches < num_batches : num_batches = max_batches num_samples = num_batches * dl . batch_size if verbose : pbar_fn = tqdm . auto . tqdm pbar_file = sys . stdout else : pbar_fn = tqdm . tqdm pbar_file = open ( os . devnull , \"w\" ) pbar_name = forward_fn . __name__ with pbar_fn ( desc = pbar_name , total = num_batches , file = pbar_file ) as pbar : dl_iter = iter ( dl ) for batch_idx in range ( num_batches ): data = next ( dl_iter ) batch_res = forward_fn ( data ) pbar . set_description ( f \" { pbar_name } ( { batch_res . loss : .3f } )\" ) pbar . update () losses . append ( batch_res . loss ) num_correct += batch_res . num_correct avg_loss = sum ( losses ) / num_batches accuracy = 100.0 * num_correct / num_samples pbar . set_description ( f \" { pbar_name } \" f \"(Avg. Loss { avg_loss : .3f } , \" f \"Accuracy { accuracy : .2f } %)\" ) if not verbose : pbar_file . close () return EpochResult ( losses = losses , accuracy = accuracy ) def log_hparam ( self , hparam_dict : dict [ str , Any ], metric_dict : dict [ str , Any ] = {}, run_name : str = \"hparams\" ): if self . logger is not None : self . logger . add_hparams ( hparam_dict , metric_dict , run_name = run_name ) class MLPTrainer ( Trainer ): \"\"\" The `MLPTrainer` class is responsible for training and testing a multi-layer perceptron (MLP) models. Args: model (nn.Module): The MLP model to be trained. loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. device (Optional[torch.device], optional): The device on which the model and data should be loaded. log (bool, optional): Whether to log training progress with tensorboard. Attributes: loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. \"\"\" def __init__ ( self , model : nn . Module , loss_fn : nn . Module , optimizer : Optimizer , device : Optional [ torch . device ] = None , log : bool = False , ): super () . __init__ ( model , device , log = log ) self . loss_fn = loss_fn self . optimizer = optimizer if self . logger is not None : self . logger . 
add_hparams ({ \"loss_fn\" : loss_fn . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"optimizer\" : optimizer . __class__ . __name__ }, {}, run_name = \"hparams\" ) optimizer_params = {} for key , val in optimizer . param_groups [ 0 ] . items (): optimizer_params [ key ] = str ( val ) optimizer_params . update ({ \"params\" : \"\" }) self . logger . add_hparams ( optimizer_params , {}, run_name = \"hparams\" ) def train_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) self . model : nn . Module self . optimizer . zero_grad () preds = self . model . forward ( X ) loss = self . loss_fn ( preds , y ) loss . backward () self . optimizer . step () num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) return self . _make_batch_result ( loss , num_correct ) @torch . no_grad () def test_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) preds = self . model . forward ( X ) num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) loss = self . loss_fn ( preds , y ) return self . _make_batch_result ( loss , num_correct )","title":"Module wtracker.neural.training"},{"location":"reference/wtracker/neural/training/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/neural/training/#mlptrainer","text":"class MLPTrainer ( model : torch . nn . modules . module . Module , loss_fn : torch . nn . modules . module . Module , optimizer : torch . optim . optimizer . Optimizer , device : Optional [ torch . device ] = None , log : bool = False ) The MLPTrainer class is responsible for training and testing a multi-layer perceptron (MLP) models.","title":"MLPTrainer"},{"location":"reference/wtracker/neural/training/#attributes","text":"Name Type Description Default model nn.Module The MLP model to be trained. None loss_fn nn.Module The loss function used for training. None optimizer Optimizer The optimizer used for updating the model's parameters. None device Optional[torch.device] The device on which the model and data should be loaded. None log bool Whether to log training progress with tensorboard. None loss_fn nn.Module The loss function used for training. None optimizer Optimizer The optimizer used for updating the model's parameters. None View Source class MLPTrainer ( Trainer ): \"\"\" The `MLPTrainer` class is responsible for training and testing a multi-layer perceptron (MLP) models. Args: model (nn.Module): The MLP model to be trained. loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. device (Optional[torch.device], optional): The device on which the model and data should be loaded. log (bool, optional): Whether to log training progress with tensorboard. Attributes: loss_fn (nn.Module): The loss function used for training. optimizer (Optimizer): The optimizer used for updating the model's parameters. \"\"\" def __init__ ( self , model : nn . Module , loss_fn : nn . Module , optimizer : Optimizer , device : Optional [ torch . device ] = None , log : bool = False , ): super () . __init__ ( model , device , log = log ) self . loss_fn = loss_fn self . optimizer = optimizer if self . logger is not None : self . logger . add_hparams ({ \"loss_fn\" : loss_fn . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"optimizer\" : optimizer . __class__ . 
__name__ }, {}, run_name = \"hparams\" ) optimizer_params = {} for key , val in optimizer . param_groups [ 0 ] . items (): optimizer_params [ key ] = str ( val ) optimizer_params . update ({ \"params\" : \"\" }) self . logger . add_hparams ( optimizer_params , {}, run_name = \"hparams\" ) def train_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) self . model : nn . Module self . optimizer . zero_grad () preds = self . model . forward ( X ) loss = self . loss_fn ( preds , y ) loss . backward () self . optimizer . step () num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) return self . _make_batch_result ( loss , num_correct ) @ torch . no_grad () def test_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) preds = self . model . forward ( X ) num_correct = torch . sum (( preds - y ) . norm ( dim = 1 ) < 1.0 ) loss = self . loss_fn ( preds , y ) return self . _make_batch_result ( loss , num_correct )","title":"Attributes"},{"location":"reference/wtracker/neural/training/#ancestors-in-mro","text":"wtracker.neural.training.Trainer abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/training/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/neural/training/#fit","text":"def fit ( self , dl_train : torch . utils . data . dataloader . DataLoader , dl_test : torch . utils . data . dataloader . DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw ) -> wtracker . neural . train_results . FitResult Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Parameters: Name Type Description Default dl_train DataLoader Dataloader for the training set. None dl_test DataLoader Dataloader for the test set. None num_epochs int Number of epochs to train for. None checkpoints str Whether to save the model to a file every time the test loss improves. Should be a string containing a filename without extension. None early_stopping int Whether to stop training early if there is no test loss improvement for this number of epochs. None print_every int Print progress once every this many epochs. None Returns: Type Description FitResult A FitResult object containing train and test losses per epoch. View Source def fit ( self , dl_train : DataLoader , dl_test : DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw , ) -> FitResult : \"\"\" Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Args: dl_train (DataLoader): Dataloader for the training set. dl_test (DataLoader): Dataloader for the test set. num_epochs (int): Number of epochs to train for. checkpoints (str, optional): Whether to save the model to a file every time the test loss improves. Should be a string containing a filename without extension. early_stopping (int, optional): Whether to stop training early if there is no test loss improvement for this number of epochs. print_every (int, optional): Print progress once every this many epochs. Returns: FitResult: A FitResult object containing train and test losses per epoch.
\"\"\" actual_epoch_num = 0 epochs_without_improvement = 0 train_loss , train_acc , test_loss , test_acc = [], [], [], [] best_val_loss = None # add graph to tensorboard if self . logger is not None : self . logger . add_graph ( self . model , next ( iter ( dl_train ))[ 0 ]) for epoch in range ( num_epochs ): actual_epoch_num += 1 verbose = False # pass this to train/test_epoch. if print_every > 0 and ( epoch % print_every == 0 or epoch == num_epochs - 1 ): verbose = True self . _print ( f \"--- EPOCH {epoch+1}/{num_epochs} ---\" , verbose ) train_result = self . train_epoch ( dl_train , verbose = verbose , ** kw ) test_result = self . test_epoch ( dl_test , verbose = verbose , ** kw ) train_loss . extend ( train_result . losses ) train_acc . append ( train_result . accuracy ) test_loss . extend ( test_result . losses ) test_acc . append ( test_result . accuracy ) # log results to tensorboard if self . logger is not None : self . logger . add_scalar ( \"loss/train\" , Tensor ( train_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"loss/test\" , Tensor ( test_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"accuracy/train\" , train_result . accuracy , epoch ) self . logger . add_scalar ( \"accuracy/test\" , test_result . accuracy , epoch ) self . logger . add_scalar ( \"learning_rate\" , self . optimizer . param_groups [ 0 ][ \"lr\" ], epoch ) curr_val_loss = Tensor ( test_result . losses ) . mean () . item () if best_val_loss is None or curr_val_loss < best_val_loss : best_val_loss = curr_val_loss epochs_without_improvement = 0 if checkpoints is not None : self . save_checkpoint ( checkpoints , curr_val_loss ) else : epochs_without_improvement += 1 if early_stopping is not None and epochs_without_improvement >= early_stopping : break return self . _make_fit_result ( actual_epoch_num , train_loss , train_acc , test_loss , test_acc )","title":"fit"},{"location":"reference/wtracker/neural/training/#log_hparam","text":"def log_hparam ( self , hparam_dict : dict [ str , typing . Any ], metric_dict : dict [ str , typing . Any ] = {}, run_name : str = 'hparams' ) View Source def log_hparam(self, hparam_dict: dict[str, Any], metric_dict: dict[str, Any] = {}, run_name: str = \"hparams\"): if self.logger is not None: self.logger.add_hparams(hparam_dict, metric_dict, run_name=run_name)","title":"log_hparam"},{"location":"reference/wtracker/neural/training/#save_checkpoint","text":"def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None Saves the model in it's current state to a file with the given name (treated as a relative path). Parameters: Name Type Description Default checkpoint_filename str File name or relative path to save to. None View Source def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None : \"\"\" Saves the model in it's current state to a file with the given name (treated as a relative path). Args: checkpoint_filename (str): File name or relative path to save to. \"\"\" if self . logger is not None : checkpoint_filename = f \"{self.logger.log_dir}/{checkpoint_filename}\" torch . save ( self . model , checkpoint_filename ) print ( f \"\\n*** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}\" )","title":"save_checkpoint"},{"location":"reference/wtracker/neural/training/#test_batch","text":"def test_batch ( self , batch ) -> wtracker . neural . train_results . BatchResult Runs a single batch forward through the model and calculates loss. 
Parameters: Name Type Description Default batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). None Returns: Type Description BatchResult A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. View Source @torch . no_grad () def test_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) preds = self . model . forward ( X ) num_correct = torch . sum (( preds - y ). norm ( dim = 1 ) < 1.0 ) loss = self . loss_fn ( preds , y ) return self . _make_batch_result ( loss , num_correct )","title":"test_batch"},{"location":"reference/wtracker/neural/training/#test_epoch","text":"def test_epoch ( self , dl_test : torch . utils . data . dataloader . DataLoader , ** kw ) -> wtracker . neural . train_results . EpochResult Evaluate model once over a test set (single epoch). Parameters: Name Type Description Default dl_test DataLoader DataLoader for the test set. None kw None Keyword args supported by _foreach_batch. None Returns: Type Description EpochResult An EpochResult for the epoch. View Source def test_epoch(self, dl_test: DataLoader, **kw) -> EpochResult: \"\"\" Evaluate model once over a test set (single epoch). Args: dl_test (DataLoader): DataLoader for the test set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self.model.train(False) # set evaluation (test) mode return self._foreach_batch(dl_test, self.test_batch, **kw)","title":"test_epoch"},{"location":"reference/wtracker/neural/training/#train_batch","text":"def train_batch ( self , batch ) -> wtracker . neural . train_results . BatchResult Runs a single batch forward through the model, calculates loss, performs back-propagation and updates weights. Parameters: Name Type Description Default batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). None Returns: Type Description BatchResult A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. View Source def train_batch ( self , batch ) -> BatchResult : X , y = batch if self . device : X = X . to ( self . device ) y = y . to ( self . device ) self . model : nn . Module self . optimizer . zero_grad () preds = self . model . forward ( X ) loss = self . loss_fn ( preds , y ) loss . backward () self . optimizer . step () num_correct = torch . sum (( preds - y ). norm ( dim = 1 ) < 1.0 ) return self . _make_batch_result ( loss , num_correct )","title":"train_batch"},{"location":"reference/wtracker/neural/training/#train_epoch","text":"def train_epoch ( self , dl_train : torch . utils . data . dataloader . DataLoader , ** kw ) -> wtracker . neural . train_results . EpochResult Train once over a training set (single epoch). Parameters: Name Type Description Default dl_train DataLoader DataLoader for the training set. None kw None Keyword args supported by _foreach_batch. None Returns: Type Description EpochResult An EpochResult for the epoch. View Source def train_epoch(self, dl_train: DataLoader, **kw) -> EpochResult: \"\"\" Train once over a training set (single epoch). Args: dl_train (DataLoader): DataLoader for the training set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch.
\"\"\" self.model.train(True) # set train mode return self._foreach_batch(dl_train, self.train_batch, **kw)","title":"train_epoch"},{"location":"reference/wtracker/neural/training/#trainer","text":"class Trainer ( model : torch . nn . modules . module . Module , device : Optional [ torch . device ] = None , log : bool = False ) A class abstracting the various tasks of training models. Provides methods at multiple levels of granularity: - Multiple epochs (fit) - Single epoch (train_epoch/test_epoch) - Single batch (train_batch/test_batch)","title":"Trainer"},{"location":"reference/wtracker/neural/training/#attributes_1","text":"Name Type Description Default model nn.Module The model to train. None device Optional[torch.device] The device to run training on (CPU or GPU). None log bool Whether to log training progress with tensorboard. None View Source class Trainer ( abc . ABC ): \"\"\" A class abstracting the various tasks of training models. Provides methods at multiple levels of granularity: - Multiple epochs (fit) - Single epoch (train_epoch/test_epoch) - Single batch (train_batch/test_batch) Args: model (nn.Module): The model to train. device (Optional[torch.device], optional): The device to run training on (CPU or GPU). log (bool, optional): Whether to log training progress with tensorboard. \"\"\" def __init__ ( self , model : nn . Module , device : Optional [ torch . device ] = None , log : bool = False , ): self . model = model self . device = device self . logger = None if not log else SummaryWriter () if self . logger is not None : self . logger . add_hparams ({ \"model\" : model . __class__ . __name__ }, {}, run_name = \"hparams\" ) self . logger . add_hparams ({ \"device\" : str ( device )}, {}, run_name = \"hparams\" ) if self . device : model . to ( self . device ) def _make_batch_result ( self , loss , num_correct ) -> BatchResult : loss = loss . item () if isinstance ( loss , Tensor ) else loss num_correct = num_correct . item () if isinstance ( num_correct , Tensor ) else num_correct return BatchResult ( float ( loss ), int ( num_correct )) def _make_fit_result ( self , num_epochs , train_losses , train_acc , test_losses , test_acc ) -> FitResult : num_epochs = num_epochs . item () if isinstance ( num_epochs , Tensor ) else num_epochs train_losses = [ x . item () if isinstance ( x , Tensor ) else x for x in train_losses ] train_acc = [ x . item () if isinstance ( x , Tensor ) else x for x in train_acc ] test_losses = [ x . item () if isinstance ( x , Tensor ) else x for x in test_losses ] test_acc = [ x . item () if isinstance ( x , Tensor ) else x for x in test_acc ] return FitResult ( int ( num_epochs ), train_losses , train_acc , test_losses , test_acc ) def fit ( self , dl_train : DataLoader , dl_test : DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw , ) -> FitResult : \"\"\" Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Args: dl_train (DataLoader): Dataloader for the training set. dl_test (DataLoader): Dataloader for the test set. num_epochs (int): Number of epochs to train for. checkpoints (str, optional): Whether to save model to file every time the test set accuracy improves. Should be a string containing a filename without extension. early_stopping (int, optional): Whether to stop training early if there is no test loss improvement for this number of epochs. 
print_every (int, optional): Print progress once every this many epochs. Returns: FitResult: A FitResult object containing train and test losses per epoch. \"\"\" actual_epoch_num = 0 epochs_without_improvement = 0 train_loss , train_acc , test_loss , test_acc = [], [], [], [] best_val_loss = None # add graph to tensorboard if self . logger is not None : self . logger . add_graph ( self . model , next ( iter ( dl_train ))[ 0 ]) for epoch in range ( num_epochs ): actual_epoch_num += 1 verbose = False # pass this to train/test_epoch. if print_every > 0 and ( epoch % print_every == 0 or epoch == num_epochs - 1 ): verbose = True self . _print ( f \"--- EPOCH {epoch+1}/{num_epochs} ---\" , verbose ) train_result = self . train_epoch ( dl_train , verbose = verbose , ** kw ) test_result = self . test_epoch ( dl_test , verbose = verbose , ** kw ) train_loss . extend ( train_result . losses ) train_acc . append ( train_result . accuracy ) test_loss . extend ( test_result . losses ) test_acc . append ( test_result . accuracy ) # log results to tensorboard if self . logger is not None : self . logger . add_scalar ( \"loss/train\" , Tensor ( train_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"loss/test\" , Tensor ( test_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"accuracy/train\" , train_result . accuracy , epoch ) self . logger . add_scalar ( \"accuracy/test\" , test_result . accuracy , epoch ) self . logger . add_scalar ( \"learning_rate\" , self . optimizer . param_groups [ 0 ][ \"lr\" ], epoch ) curr_val_loss = Tensor ( test_result . losses ) . mean () . item () if best_val_loss is None or curr_val_loss < best_val_loss : best_val_loss = curr_val_loss epochs_without_improvement = 0 if checkpoints is not None : self . save_checkpoint ( checkpoints , curr_val_loss ) else : epochs_without_improvement += 1 if early_stopping is not None and epochs_without_improvement >= early_stopping : break return self . _make_fit_result ( actual_epoch_num , train_loss , train_acc , test_loss , test_acc ) def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None : \"\"\" Saves the model in its current state to a file with the given name (treated as a relative path). Args: checkpoint_filename (str): File name or relative path to save to. \"\"\" if self . logger is not None : checkpoint_filename = f \"{self.logger.log_dir}/{checkpoint_filename}\" torch . save ( self . model , checkpoint_filename ) print ( f \" \\n *** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}\" ) def train_epoch ( self , dl_train : DataLoader , ** kw ) -> EpochResult : \"\"\" Train once over a training set (single epoch). Args: dl_train (DataLoader): DataLoader for the training set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self . model . train ( True ) # set train mode return self . _foreach_batch ( dl_train , self . train_batch , ** kw ) def test_epoch ( self , dl_test : DataLoader , ** kw ) -> EpochResult : \"\"\" Evaluate model once over a test set (single epoch). Args: dl_test (DataLoader): DataLoader for the test set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. \"\"\" self . model . train ( False ) # set evaluation (test) mode return self . _foreach_batch ( dl_test , self . test_batch , ** kw ) @ abc .
abstractmethod def train_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model, calculates loss, performs back-propagation and updates weights. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () @ abc . abstractmethod def test_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model and calculates loss. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError () @ staticmethod def _print ( message , verbose = True ): \"\"\"Simple wrapper around print to make it conditional\"\"\" if verbose : print ( message ) @ staticmethod def _foreach_batch ( dl : DataLoader , forward_fn : Callable [[ Any ], BatchResult ], verbose = True , max_batches = None , ) -> EpochResult : \"\"\" Evaluates the given forward-function on batches from the given dataloader, and prints progress along the way. \"\"\" losses = [] num_correct = 0 num_samples = len ( dl . sampler ) num_batches = len ( dl . batch_sampler ) if max_batches is not None : if max_batches < num_batches : num_batches = max_batches num_samples = num_batches * dl . batch_size if verbose : pbar_fn = tqdm . auto . tqdm pbar_file = sys . stdout else : pbar_fn = tqdm . tqdm pbar_file = open ( os . devnull , \"w\" ) pbar_name = forward_fn . __name__ with pbar_fn ( desc = pbar_name , total = num_batches , file = pbar_file ) as pbar : dl_iter = iter ( dl ) for batch_idx in range ( num_batches ): data = next ( dl_iter ) batch_res = forward_fn ( data ) pbar . set_description ( f \"{pbar_name} ({batch_res.loss:.3f})\" ) pbar . update () losses . append ( batch_res . loss ) num_correct += batch_res . num_correct avg_loss = sum ( losses ) / num_batches accuracy = 100.0 * num_correct / num_samples pbar . set_description ( f \"{pbar_name} \" f \"(Avg. Loss {avg_loss:.3f}, \" f \"Accuracy {accuracy:.2f}%)\" ) if not verbose : pbar_file . close () return EpochResult ( losses = losses , accuracy = accuracy ) def log_hparam ( self , hparam_dict : dict [ str , Any ], metric_dict : dict [ str , Any ] = {}, run_name : str = \"hparams\" ): if self . logger is not None : self . logger . add_hparams ( hparam_dict , metric_dict , run_name = run_name )","title":"Attributes"},{"location":"reference/wtracker/neural/training/#ancestors-in-mro_1","text":"abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/neural/training/#descendants","text":"wtracker.neural.training.MLPTrainer","title":"Descendants"},{"location":"reference/wtracker/neural/training/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/neural/training/#fit_1","text":"def fit ( self , dl_train : torch . utils . data . dataloader . DataLoader , dl_test : torch . utils . data . dataloader . DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw ) -> wtracker . neural . train_results . FitResult Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set.
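Before the parameter table below, a minimal end-to-end sketch of a fit call; the toy model, data, and hyperparameters are illustrative placeholders rather than part of the library:

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from wtracker.neural.training import MLPTrainer

    # Toy regression data, purely for illustration.
    X, y = torch.randn(256, 8), torch.randn(256, 2)
    dl_train = DataLoader(TensorDataset(X[:200], y[:200]), batch_size=32)
    dl_test = DataLoader(TensorDataset(X[200:], y[200:]), batch_size=32)

    model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 2))
    trainer = MLPTrainer(model, loss_fn=nn.MSELoss(), optimizer=torch.optim.Adam(model.parameters()))

    # Checkpoint on every validation-loss improvement; stop after 5 stagnant epochs.
    fit_res = trainer.fit(dl_train, dl_test, num_epochs=20, checkpoints=\"best_model\", early_stopping=5)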
Parameters: Name Type Description Default dl_train DataLoader Dataloader for the training set. None dl_test DataLoader Dataloader for the test set. None num_epochs int Number of epochs to train for. None checkpoints str Whether to save the model to a file every time the test loss improves. Should be a string containing a filename without extension. None early_stopping int Whether to stop training early if there is no test loss improvement for this number of epochs. None print_every int Print progress once every this many epochs. None Returns: Type Description FitResult A FitResult object containing train and test losses per epoch. View Source def fit ( self , dl_train : DataLoader , dl_test : DataLoader , num_epochs : int , checkpoints : str = None , early_stopping : int = None , print_every : int = 1 , ** kw , ) -> FitResult : \"\"\" Trains the model for multiple epochs with a given training set, and calculates validation loss over a given validation set. Args: dl_train (DataLoader): Dataloader for the training set. dl_test (DataLoader): Dataloader for the test set. num_epochs (int): Number of epochs to train for. checkpoints (str, optional): Whether to save the model to a file every time the test loss improves. Should be a string containing a filename without extension. early_stopping (int, optional): Whether to stop training early if there is no test loss improvement for this number of epochs. print_every (int, optional): Print progress once every this many epochs. Returns: FitResult: A FitResult object containing train and test losses per epoch. \"\"\" actual_epoch_num = 0 epochs_without_improvement = 0 train_loss , train_acc , test_loss , test_acc = [], [], [], [] best_val_loss = None # add graph to tensorboard if self . logger is not None : self . logger . add_graph ( self . model , next ( iter ( dl_train ))[ 0 ]) for epoch in range ( num_epochs ): actual_epoch_num += 1 verbose = False # pass this to train/test_epoch. if print_every > 0 and ( epoch % print_every == 0 or epoch == num_epochs - 1 ): verbose = True self . _print ( f \"--- EPOCH {epoch+1}/{num_epochs} ---\" , verbose ) train_result = self . train_epoch ( dl_train , verbose = verbose , ** kw ) test_result = self . test_epoch ( dl_test , verbose = verbose , ** kw ) train_loss . extend ( train_result . losses ) train_acc . append ( train_result . accuracy ) test_loss . extend ( test_result . losses ) test_acc . append ( test_result . accuracy ) # log results to tensorboard if self . logger is not None : self . logger . add_scalar ( \"loss/train\" , Tensor ( train_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"loss/test\" , Tensor ( test_result . losses ) . mean (), epoch ) self . logger . add_scalar ( \"accuracy/train\" , train_result . accuracy , epoch ) self . logger . add_scalar ( \"accuracy/test\" , test_result . accuracy , epoch ) self . logger . add_scalar ( \"learning_rate\" , self . optimizer . param_groups [ 0 ][ \"lr\" ], epoch ) curr_val_loss = Tensor ( test_result . losses ) . mean () . item () if best_val_loss is None or curr_val_loss < best_val_loss : best_val_loss = curr_val_loss epochs_without_improvement = 0 if checkpoints is not None : self . save_checkpoint ( checkpoints , curr_val_loss ) else : epochs_without_improvement += 1 if early_stopping is not None and epochs_without_improvement >= early_stopping : break return self .
_make_fit_result ( actual_epoch_num , train_loss , train_acc , test_loss , test_acc )","title":"fit"},{"location":"reference/wtracker/neural/training/#log_hparam_1","text":"def log_hparam ( self , hparam_dict : dict [ str , typing . Any ], metric_dict : dict [ str , typing . Any ] = {}, run_name : str = 'hparams' ) View Source def log_hparam(self, hparam_dict: dict[str, Any], metric_dict: dict[str, Any] = {}, run_name: str = \"hparams\"): if self.logger is not None: self.logger.add_hparams(hparam_dict, metric_dict, run_name=run_name)","title":"log_hparam"},{"location":"reference/wtracker/neural/training/#save_checkpoint_1","text":"def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None Saves the model in its current state to a file with the given name (treated as a relative path). Parameters: Name Type Description Default checkpoint_filename str File name or relative path to save to. None View Source def save_checkpoint ( self , checkpoint_filename : str , loss : Optional [ float ] = None ) -> None : \"\"\" Saves the model in its current state to a file with the given name (treated as a relative path). Args: checkpoint_filename (str): File name or relative path to save to. \"\"\" if self . logger is not None : checkpoint_filename = f \"{self.logger.log_dir}/{checkpoint_filename}\" torch . save ( self . model , checkpoint_filename ) print ( f \"\\n*** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}\" )","title":"save_checkpoint"},{"location":"reference/wtracker/neural/training/#test_batch_1","text":"def test_batch ( self , batch ) -> wtracker . neural . train_results . BatchResult Runs a single batch forward through the model and calculates loss. Parameters: Name Type Description Default batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). None Returns: Type Description BatchResult A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. View Source @ abc . abstractmethod def test_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model and calculates loss. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError ()","title":"test_batch"},{"location":"reference/wtracker/neural/training/#test_epoch_1","text":"def test_epoch ( self , dl_test : torch . utils . data . dataloader . DataLoader , ** kw ) -> wtracker . neural . train_results . EpochResult Evaluate model once over a test set (single epoch). Parameters: Name Type Description Default dl_test DataLoader DataLoader for the test set. None kw None Keyword args supported by _foreach_batch. None Returns: Type Description EpochResult An EpochResult for the epoch. View Source def test_epoch(self, dl_test: DataLoader, **kw) -> EpochResult: \"\"\" Evaluate model once over a test set (single epoch). Args: dl_test (DataLoader): DataLoader for the test set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch.
\"\"\" self.model.train(False) # set evaluation (test) mode return self._foreach_batch(dl_test, self.test_batch, **kw)","title":"test_epoch"},{"location":"reference/wtracker/neural/training/#train_batch_1","text":"def train_batch ( self , batch ) -> wtracker . neural . train_results . BatchResult Runs a single batch forward through the model, calculates loss, preforms back-propagation and updates weights. Parameters: Name Type Description Default batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). None Returns: Type Description BatchResult A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. View Source @ abc . abstractmethod def train_batch ( self , batch ) -> BatchResult : \"\"\" Runs a single batch forward through the model, calculates loss, preforms back-propagation and updates weights. Args: batch: A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). Returns: BatchResult: A BatchResult containing the value of the loss function and the number of correctly classified samples in the batch. \"\"\" raise NotImplementedError ()","title":"train_batch"},{"location":"reference/wtracker/neural/training/#train_epoch_1","text":"def train_epoch ( self , dl_train : torch . utils . data . dataloader . DataLoader , ** kw ) -> wtracker . neural . train_results . EpochResult Train once over a training set (single epoch). Parameters: Name Type Description Default dl_train DataLoader DataLoader for the training set. None kw None Keyword args supported by _foreach_batch. None Returns: Type Description EpochResult An EpochResult for the epoch. View Source def train_epoch(self, dl_train: DataLoader, **kw) -> EpochResult: \"\"\" Train once over a training set (single epoch). Args: dl_train (DataLoader): DataLoader for the training set. kw: Keyword args supported by _foreach_batch. Returns: EpochResult: An EpochResult for the epoch. 
\"\"\" self.model.train(True) # set train mode return self._foreach_batch(dl_train, self.train_batch, **kw)","title":"train_epoch"},{"location":"reference/wtracker/sim/","text":"Module wtracker.sim View Source from wtracker.sim.config import TimingConfig , ExperimentConfig from wtracker.sim.motor_controllers import MotorController , StepMotorController , SineMotorController from wtracker.sim.simulator import Simulator , SimController from wtracker.sim.view_controller import ViewController Sub-modules wtracker.sim.config wtracker.sim.motor_controllers wtracker.sim.sim_controllers wtracker.sim.simulator wtracker.sim.view_controller","title":"Index"},{"location":"reference/wtracker/sim/#module-wtrackersim","text":"View Source from wtracker.sim.config import TimingConfig , ExperimentConfig from wtracker.sim.motor_controllers import MotorController , StepMotorController , SineMotorController from wtracker.sim.simulator import Simulator , SimController from wtracker.sim.view_controller import ViewController","title":"Module wtracker.sim"},{"location":"reference/wtracker/sim/#sub-modules","text":"wtracker.sim.config wtracker.sim.motor_controllers wtracker.sim.sim_controllers wtracker.sim.simulator wtracker.sim.view_controller","title":"Sub-modules"},{"location":"reference/wtracker/sim/config/","text":"Module wtracker.sim.config View Source from __future__ import annotations from dataclasses import dataclass , field import math from wtracker.utils.config_base import ConfigBase from wtracker.utils.frame_reader import FrameReader @dataclass class TimingConfig ( ConfigBase ): \"\"\" Configuration for timing parameters of the experiment. These parameters should not change between different experiments. This class affects the timings of the simulation. \"\"\" experiment_config : ExperimentConfig = field ( repr = False ) \"\"\"The configuration of the experiment parameters.\"\"\" px_per_mm : int = field ( init = False ) mm_per_px : float = field ( init = False ) frames_per_sec : int = field ( init = False ) ms_per_frame : float = field ( init = False ) imaging_time_ms : float imaging_frame_num : int = field ( init = False ) pred_time_ms : float pred_frame_num : int = field ( init = False ) moving_time_ms : float moving_frame_num : int = field ( init = False ) camera_size_mm : tuple [ float , float ] camera_size_px : tuple [ int , int ] = field ( init = False ) micro_size_mm : tuple [ float , float ] micro_size_px : tuple [ int , int ] = field ( init = False ) def __post_init__ ( self ): self . frames_per_sec = self . experiment_config . frames_per_sec self . ms_per_frame = self . experiment_config . ms_per_frame self . imaging_frame_num = math . ceil ( self . imaging_time_ms / self . ms_per_frame ) self . pred_frame_num = math . ceil ( self . pred_time_ms / self . ms_per_frame ) self . moving_frame_num = math . ceil ( self . moving_time_ms / self . ms_per_frame ) self . mm_per_px = self . experiment_config . mm_per_px self . px_per_mm = self . experiment_config . px_per_mm self . camera_size_px = ( round ( self . px_per_mm * self . camera_size_mm [ 0 ]), round ( self . px_per_mm * self . camera_size_mm [ 1 ]), ) self . micro_size_px = ( round ( self . px_per_mm * self . micro_size_mm [ 0 ]), round ( self . px_per_mm * self . micro_size_mm [ 1 ]), ) del self . experiment_config # experiment_config was temporary, only for the constructor @property def cycle_frame_num ( self ) -> int : return self . imaging_frame_num + self . moving_frame_num @property def cycle_time_ms ( self ) -> float : return self . 
cycle_frame_num * self . ms_per_frame @dataclass class ExperimentConfig ( ConfigBase ): \"\"\" Configuration for the experiment parameters. These parameters can change between different experiments. \"\"\" name : str \"\"\"Experiment name\"\"\" num_frames : int \"\"\"total number of frames of the experiment\"\"\" frames_per_sec : float \"\"\"Number of frames per second that the experiment was recorded at\"\"\" orig_resolution : tuple [ int , int ] \"\"\"Original resolution of the frames in pixels, in format (h, w)\"\"\" px_per_mm : float \"\"\"Number of pixels in a single millimeter\"\"\" init_position : tuple [ int , int ] \"\"\"The initial position of the center of the platform, in pixels (x, y) format. Platform's initial position should point to the worm, or close to it.\"\"\" comments : str = \"\" \"\"\"Additional comments about the experiment\"\"\" mm_per_px : float = field ( init = False ) \"\"\"Number of millimeters in a single pixel\"\"\" ms_per_frame : float = field ( init = False ) \"\"\"Number of milliseconds per frame\"\"\" def __post_init__ ( self ): self . ms_per_frame = 1000 / self . frames_per_sec self . mm_per_px = 1 / self . px_per_mm @classmethod def from_frame_reader ( cls , reader : FrameReader , name : str , frames_per_sec : int , px_per_mm : float , init_position : tuple [ int , int ], ) -> ExperimentConfig : return ExperimentConfig ( name = name , num_frames = len ( reader ), frames_per_sec = frames_per_sec , orig_resolution = reader . frame_size , px_per_mm = px_per_mm , init_position = init_position , ) Classes ExperimentConfig class ExperimentConfig ( name : 'str' , num_frames : 'int' , frames_per_sec : 'float' , orig_resolution : 'tuple[int, int]' , px_per_mm : 'float' , init_position : 'tuple[int, int]' , comments : 'str' = '' ) Configuration for the experiment parameters. These parameters can change between different experiments. View Source @dataclass class ExperimentConfig ( ConfigBase ) : \"\"\" Configuration for the experiment parameters. These parameters can change between different experiments. \"\"\" name : str \"\"\"Experiment name\"\"\" num_frames : int \"\"\"total number of frames of the experiment\"\"\" frames_per_sec : float \"\"\"Number of frames per second that the experiment was recorded at\"\"\" orig_resolution : tuple [ int, int ] \"\"\"Original resolution of the frames in pixels, in format (h, w)\"\"\" px_per_mm : float \"\"\"Number of pixels in a single millimeter\"\"\" init_position : tuple [ int, int ] \"\"\"The initial position of the center of the platform, in pixels (x, y) format. Platform's initial position should point to the worm, or close to it.\"\"\" comments : str = \"\" \"\"\"Additional comments about the experiment\"\"\" mm_per_px : float = field ( init = False ) \"\"\"Number of millimeters in a single pixel\"\"\" ms_per_frame : float = field ( init = False ) \"\"\"Number of milliseconds per frame\"\"\" def __post_init__ ( self ) : self . ms_per_frame = 1000 / self . frames_per_sec self . mm_per_px = 1 / self . px_per_mm @classmethod def from_frame_reader ( cls , reader : FrameReader , name : str , frames_per_sec : int , px_per_mm : float , init_position : tuple [ int, int ] , ) -> ExperimentConfig : return ExperimentConfig ( name = name , num_frames = len ( reader ), frames_per_sec = frames_per_sec , orig_resolution = reader . 
frame_size , px_per_mm = px_per_mm , init_position = init_position , ) Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Class variables comments Static methods from_frame_reader def from_frame_reader ( reader : 'FrameReader' , name : 'str' , frames_per_sec : 'int' , px_per_mm : 'float' , init_position : 'tuple[int, int]' ) -> 'ExperimentConfig' View Source @classmethod def from_frame_reader ( cls , reader : FrameReader , name : str , frames_per_sec : int , px_per_mm : float , init_position : tuple [ int, int ] , ) -> ExperimentConfig : return ExperimentConfig ( name = name , num_frames = len ( reader ), frames_per_sec = frames_per_sec , orig_resolution = reader . frame_size , px_per_mm = px_per_mm , init_position = init_position , ) load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods save_json def save_json ( self , path : 'str' = None ) Saves the class as a JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as a JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt .
save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) TimingConfig class TimingConfig ( experiment_config : 'ExperimentConfig' , imaging_time_ms : 'float' , pred_time_ms : 'float' , moving_time_ms : 'float' , camera_size_mm : 'tuple[float, float]' , micro_size_mm : 'tuple[float, float]' ) Configuration for timing parameters of the experiment. These parameters should not change between different experiments. This class affects the timings of the simulation. View Source @ dataclass class TimingConfig ( ConfigBase ): \"\"\" Configuration for timing parameters of the experiment. These parameters should not change between different experiments. This class affects the timings of the simulation. \"\"\" experiment_config : ExperimentConfig = field ( repr = False ) \"\"\"The configuration of the experiment parameters.\"\"\" px_per_mm : int = field ( init = False ) mm_per_px : float = field ( init = False ) frames_per_sec : int = field ( init = False ) ms_per_frame : float = field ( init = False ) imaging_time_ms : float imaging_frame_num : int = field ( init = False ) pred_time_ms : float pred_frame_num : int = field ( init = False ) moving_time_ms : float moving_frame_num : int = field ( init = False ) camera_size_mm : tuple [ float , float ] camera_size_px : tuple [ int , int ] = field ( init = False ) micro_size_mm : tuple [ float , float ] micro_size_px : tuple [ int , int ] = field ( init = False ) def __post_init__ ( self ): self . frames_per_sec = self . experiment_config . frames_per_sec self . ms_per_frame = self . experiment_config . ms_per_frame self . imaging_frame_num = math . ceil ( self . imaging_time_ms / self . ms_per_frame ) self . pred_frame_num = math . ceil ( self . pred_time_ms / self . ms_per_frame ) self . moving_frame_num = math . ceil ( self . moving_time_ms / self . ms_per_frame ) self . mm_per_px = self . experiment_config . mm_per_px self . px_per_mm = self . experiment_config . px_per_mm self . camera_size_px = ( round ( self . px_per_mm * self . camera_size_mm [ 0 ]), round ( self . px_per_mm * self . camera_size_mm [ 1 ]), ) self . micro_size_px = ( round ( self . px_per_mm * self . micro_size_mm [ 0 ]), round ( self . px_per_mm * self . micro_size_mm [ 1 ]), ) del self . experiment_config # experiment_config was temporary, only for the constructor @ property def cycle_frame_num ( self ) -> int : return self . imaging_frame_num + self . moving_frame_num @ property def cycle_time_ms ( self ) -> float : return self . cycle_frame_num * self . ms_per_frame Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Static methods load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. 
Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Instance variables cycle_frame_num cycle_time_ms Methods save_json def save_json ( self , path : 'str' = None ) Saves the class as a JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as a JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"Config"},{"location":"reference/wtracker/sim/config/#module-wtrackersimconfig","text":"View Source from __future__ import annotations from dataclasses import dataclass , field import math from wtracker.utils.config_base import ConfigBase from wtracker.utils.frame_reader import FrameReader @dataclass class TimingConfig ( ConfigBase ): \"\"\" Configuration for timing parameters of the experiment. These parameters should not change between different experiments. This class affects the timings of the simulation. \"\"\" experiment_config : ExperimentConfig = field ( repr = False ) \"\"\"The configuration of the experiment parameters.\"\"\" px_per_mm : int = field ( init = False ) mm_per_px : float = field ( init = False ) frames_per_sec : int = field ( init = False ) ms_per_frame : float = field ( init = False ) imaging_time_ms : float imaging_frame_num : int = field ( init = False ) pred_time_ms : float pred_frame_num : int = field ( init = False ) moving_time_ms : float moving_frame_num : int = field ( init = False ) camera_size_mm : tuple [ float , float ] camera_size_px : tuple [ int , int ] = field ( init = False ) micro_size_mm : tuple [ float , float ] micro_size_px : tuple [ int , int ] = field ( init = False ) def __post_init__ ( self ): self . frames_per_sec = self . experiment_config . frames_per_sec self . ms_per_frame = self . experiment_config . ms_per_frame self . imaging_frame_num = math . ceil ( self . imaging_time_ms / self . ms_per_frame ) self . pred_frame_num = math . ceil ( self . pred_time_ms / self . ms_per_frame ) self . moving_frame_num = math . ceil ( self . moving_time_ms / self . ms_per_frame ) self . mm_per_px = self . experiment_config . mm_per_px self . px_per_mm = self . experiment_config . px_per_mm self .
camera_size_px = ( round ( self . px_per_mm * self . camera_size_mm [ 0 ]), round ( self . px_per_mm * self . camera_size_mm [ 1 ]), ) self . micro_size_px = ( round ( self . px_per_mm * self . micro_size_mm [ 0 ]), round ( self . px_per_mm * self . micro_size_mm [ 1 ]), ) del self . experiment_config # experiment_config was temporary, only for the constructor @property def cycle_frame_num ( self ) -> int : return self . imaging_frame_num + self . moving_frame_num @property def cycle_time_ms ( self ) -> float : return self . cycle_frame_num * self . ms_per_frame @dataclass class ExperimentConfig ( ConfigBase ): \"\"\" Configuration for the experiment parameters. These parameters can change between different experiments. \"\"\" name : str \"\"\"Experiment name\"\"\" num_frames : int \"\"\"total number of frames of the experiment\"\"\" frames_per_sec : float \"\"\"Number of frames per second that the experiment was recorded at\"\"\" orig_resolution : tuple [ int , int ] \"\"\"Original resolution of the frames in pixels, in format (h, w)\"\"\" px_per_mm : float \"\"\"Number of pixels in a single millimeter\"\"\" init_position : tuple [ int , int ] \"\"\"The initial position of the center of the platform, in pixels (x, y) format. Platform's initial position should point to the worm, or close to it.\"\"\" comments : str = \"\" \"\"\"Additional comments about the experiment\"\"\" mm_per_px : float = field ( init = False ) \"\"\"Number of millimeters in a single pixel\"\"\" ms_per_frame : float = field ( init = False ) \"\"\"Number of milliseconds per frame\"\"\" def __post_init__ ( self ): self . ms_per_frame = 1000 / self . frames_per_sec self . mm_per_px = 1 / self . px_per_mm @classmethod def from_frame_reader ( cls , reader : FrameReader , name : str , frames_per_sec : int , px_per_mm : float , init_position : tuple [ int , int ], ) -> ExperimentConfig : return ExperimentConfig ( name = name , num_frames = len ( reader ), frames_per_sec = frames_per_sec , orig_resolution = reader . frame_size , px_per_mm = px_per_mm , init_position = init_position , )","title":"Module wtracker.sim.config"},{"location":"reference/wtracker/sim/config/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/config/#experimentconfig","text":"class ExperimentConfig ( name : 'str' , num_frames : 'int' , frames_per_sec : 'float' , orig_resolution : 'tuple[int, int]' , px_per_mm : 'float' , init_position : 'tuple[int, int]' , comments : 'str' = '' ) Configuration for the experiment parameters. These parameters can change between different experiments. View Source @dataclass class ExperimentConfig ( ConfigBase ) : \"\"\" Configuration for the experiment parameters. These parameters can change between different experiments. \"\"\" name : str \"\"\"Experiment name\"\"\" num_frames : int \"\"\"total number of frames of the experiment\"\"\" frames_per_sec : float \"\"\"Number of frames per second that the experiment was recorded at\"\"\" orig_resolution : tuple [ int, int ] \"\"\"Original resolution of the frames in pixels, in format (h, w)\"\"\" px_per_mm : float \"\"\"Number of pixels in a single millimeter\"\"\" init_position : tuple [ int, int ] \"\"\"The initial position of the center of the platform, in pixels (x, y) format. 
Platform's initial position should point to the worm, or close to it.\"\"\" comments : str = \"\" \"\"\"Additional comments about the experiment\"\"\" mm_per_px : float = field ( init = False ) \"\"\"Number of millimeters in a single pixel\"\"\" ms_per_frame : float = field ( init = False ) \"\"\"Number of milliseconds per frame\"\"\" def __post_init__ ( self ) : self . ms_per_frame = 1000 / self . frames_per_sec self . mm_per_px = 1 / self . px_per_mm @classmethod def from_frame_reader ( cls , reader : FrameReader , name : str , frames_per_sec : int , px_per_mm : float , init_position : tuple [ int, int ] , ) -> ExperimentConfig : return ExperimentConfig ( name = name , num_frames = len ( reader ), frames_per_sec = frames_per_sec , orig_resolution = reader . frame_size , px_per_mm = px_per_mm , init_position = init_position , )","title":"ExperimentConfig"},{"location":"reference/wtracker/sim/config/#ancestors-in-mro","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/config/#class-variables","text":"comments","title":"Class variables"},{"location":"reference/wtracker/sim/config/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/sim/config/#from_frame_reader","text":"def from_frame_reader ( reader : 'FrameReader' , name : 'str' , frames_per_sec : 'int' , px_per_mm : 'float' , init_position : 'tuple[int, int]' ) -> 'ExperimentConfig' View Source @classmethod def from_frame_reader ( cls , reader : FrameReader , name : str , frames_per_sec : int , px_per_mm : float , init_position : tuple [ int, int ] , ) -> ExperimentConfig : return ExperimentConfig ( name = name , num_frames = len ( reader ), frames_per_sec = frames_per_sec , orig_resolution = reader . frame_size , px_per_mm = px_per_mm , init_position = init_position , )","title":"from_frame_reader"},{"location":"reference/wtracker/sim/config/#load_json","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/sim/config/#load_pickle","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . 
open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/sim/config/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/config/#save_json","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/sim/config/#save_pickle","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/sim/config/#timingconfig","text":"class TimingConfig ( experiment_config : 'ExperimentConfig' , imaging_time_ms : 'float' , pred_time_ms : 'float' , moving_time_ms : 'float' , camera_size_mm : 'tuple[float, float]' , micro_size_mm : 'tuple[float, float]' ) Configuration for timing parameters of the experiment. These parameters should not change between different experiments. This class affects the timings of the simulation. View Source @ dataclass class TimingConfig ( ConfigBase ): \"\"\" Configuration for timing parameters of the experiment. These parameters should not change between different experiments. This class affects the timings of the simulation. \"\"\" experiment_config : ExperimentConfig = field ( repr = False ) \"\"\"The configuration of the experiment parameters.\"\"\" px_per_mm : int = field ( init = False ) mm_per_px : float = field ( init = False ) frames_per_sec : int = field ( init = False ) ms_per_frame : float = field ( init = False ) imaging_time_ms : float imaging_frame_num : int = field ( init = False ) pred_time_ms : float pred_frame_num : int = field ( init = False ) moving_time_ms : float moving_frame_num : int = field ( init = False ) camera_size_mm : tuple [ float , float ] camera_size_px : tuple [ int , int ] = field ( init = False ) micro_size_mm : tuple [ float , float ] micro_size_px : tuple [ int , int ] = field ( init = False ) def __post_init__ ( self ): self . frames_per_sec = self . experiment_config . frames_per_sec self . ms_per_frame = self . experiment_config . ms_per_frame self . imaging_frame_num = math . ceil ( self . imaging_time_ms / self . ms_per_frame ) self . pred_frame_num = math . ceil ( self . pred_time_ms / self . ms_per_frame ) self . moving_frame_num = math . ceil ( self . moving_time_ms / self . ms_per_frame ) self . mm_per_px = self . experiment_config . mm_per_px self . px_per_mm = self . experiment_config . px_per_mm self . camera_size_px = ( round ( self . px_per_mm * self . 
camera_size_mm [ 0 ]), round ( self . px_per_mm * self . camera_size_mm [ 1 ]), ) self . micro_size_px = ( round ( self . px_per_mm * self . micro_size_mm [ 0 ]), round ( self . px_per_mm * self . micro_size_mm [ 1 ]), ) del self . experiment_config # experiment_config was temporary, only for the constructor @ property def cycle_frame_num ( self ) -> int : return self . imaging_frame_num + self . moving_frame_num @ property def cycle_time_ms ( self ) -> float : return self . cycle_frame_num * self . ms_per_frame","title":"TimingConfig"},{"location":"reference/wtracker/sim/config/#ancestors-in-mro_1","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/config/#static-methods_1","text":"","title":"Static methods"},{"location":"reference/wtracker/sim/config/#load_json_1","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/sim/config/#load_pickle_1","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/sim/config/#instance-variables","text":"cycle_frame_num cycle_time_ms","title":"Instance variables"},{"location":"reference/wtracker/sim/config/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/sim/config/#save_json_1","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/sim/config/#save_pickle_1","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. 
None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/sim/motor_controllers/","text":"Module wtracker.sim.motor_controllers View Source import abc import numpy as np from wtracker.sim.config import TimingConfig class MotorController ( abc . ABC ): \"\"\" Abstract base class for motor controllers used in the Simulator class. This motor controls the movement of the simulated platform. Args: timing_config (TimingConfig): The timing configuration of the simulation. Attributes: timing_config (TimingConfig): The timing configuration for the motor controller. movement_steps (int): The number of movement steps (in units of frames) based on the timing configuration. \"\"\" def __init__ ( self , timing_config : TimingConfig ): self . timing_config = timing_config self . movement_steps = self . timing_config . moving_frame_num @abc . abstractmethod def register_move ( self , dx : int , dy : int ): pass @abc . abstractmethod def step ( self ) -> tuple [ int , int ]: pass class StepMotorController ( MotorController ): \"\"\" A simple motor controller that manages the movement of a motor. The motor moved the entire distance in one step, the movement happens after 'move_after_ratio' percent of 'movement_steps' have passed. Args: timing_config (TimingConfig): The timing configuration of the simulation. move_after_ratio (float, optional): The ratio of movement steps after which the motor should move. \"\"\" def __init__ ( self , timing_config : TimingConfig , move_after_ratio : float = 0.5 ): assert 0 <= move_after_ratio <= 1 super () . __init__ ( timing_config ) self . queue : list = [] self . move_at_step = round ( self . movement_steps * move_after_ratio ) def register_move ( self , dx : int , dy : int ): for _ in range ( self . movement_steps - 1 ): self . queue . append (( 0 , 0 )) self . queue . insert ( self . move_at_step , ( dx , dy )) def step ( self ) -> tuple [ int , int ]: return self . queue . pop ( 0 ) class SineMotorController ( MotorController ): \"\"\" A motor controller that generates sinusoidal movements. Args: timing_config (TimingConfig): The timing configuration of the simulation. \"\"\" def __init__ ( self , timing_config : TimingConfig ): super () . __init__ ( timing_config ) self . queue : list = [] def register_move ( self , dx : int , dy : int ) -> None : assert len ( self . queue ) == 0 for i in range ( self . movement_steps ): step_size = ( np . cos (( i * np . pi ) / self . movement_steps ) - np . cos ((( i + 1 ) * np . pi ) / self . movement_steps ) ) / 2 step = ( step_size * dx , step_size * dy ) self . queue . append ( step ) def step ( self ) -> tuple [ int , int ]: dx , dy = self . queue . pop ( 0 ) rdx , rdy = ( round ( dx ), round ( dy )) resid_x , resid_y = dx - rdx , dy - rdy if self . queue : self . queue [ 0 ] = ( self . queue [ 0 ][ 0 ] + resid_x , self . queue [ 0 ][ 1 ] + resid_y ) return ( rdx , rdy ) Classes MotorController class MotorController ( timing_config : wtracker . sim . config . TimingConfig ) Abstract base class for motor controllers used in the Simulator class. This motor controls the movement of the simulated platform. 
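To make the configuration classes documented above concrete, the following is a minimal sketch that wires an ExperimentConfig and a TimingConfig into one of the motor controllers from this module. Every numeric value below is a hypothetical placeholder, not a library default:

```python
from wtracker.sim.config import ExperimentConfig, TimingConfig
from wtracker.sim.motor_controllers import SineMotorController

# Hypothetical experiment parameters -- placeholders, not values from the library.
experiment_config = ExperimentConfig(
    name="example",
    num_frames=10_000,
    frames_per_sec=60.0,
    orig_resolution=(2048, 2048),  # (h, w), in pixels
    px_per_mm=90.0,
    init_position=(1024, 1024),  # (x, y), in pixels
)

timing_config = TimingConfig(
    experiment_config=experiment_config,
    imaging_time_ms=800.0,
    pred_time_ms=100.0,
    moving_time_ms=200.0,
    camera_size_mm=(10.0, 10.0),
    micro_size_mm=(1.0, 1.0),
)
timing_config.save_json("timing_config.json")  # optional round-trip via the ConfigBase helpers

# A registered move is spread across `movement_steps` frames by the controller.
motor = SineMotorController(timing_config)
motor.register_move(dx=50, dy=-20)
moved = (0, 0)
for _ in range(motor.movement_steps):
    dx, dy = motor.step()
    moved = (moved[0] + dx, moved[1] + dy)
print(moved)  # the residual bookkeeping in step() sums the rounded steps back to (50, -20)
```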
Attributes Name Type Description Default timing_config TimingConfig The timing configuration of the simulation. None timing_config TimingConfig The timing configuration for the motor controller. None movement_steps int The number of movement steps (in units of frames) based on the timing configuration. None View Source class MotorController ( abc . ABC ) : \"\"\" Abstract base class for motor controllers used in the Simulator class. This motor controls the movement of the simulated platform. Args: timing_config (TimingConfig): The timing configuration of the simulation. Attributes: timing_config (TimingConfig): The timing configuration for the motor controller. movement_steps (int): The number of movement steps (in units of frames) based on the timing configuration. \"\"\" def __init__ ( self , timing_config : TimingConfig ) : self . timing_config = timing_config self . movement_steps = self . timing_config . moving_frame_num @abc . abstractmethod def register_move ( self , dx : int , dy : int ) : pass @abc . abstractmethod def step ( self ) -> tuple [ int, int ] : pass Ancestors (in MRO) abc.ABC Descendants wtracker.sim.motor_controllers.StepMotorController wtracker.sim.motor_controllers.SineMotorController Methods register_move def register_move ( self , dx : int , dy : int ) View Source @abc . abstractmethod def register_move ( self , dx : int , dy : int ) : pass step def step ( self ) -> tuple [ int , int ] View Source @abc . abstractmethod def step ( self ) -> tuple [ int, int ] : pass SineMotorController class SineMotorController ( timing_config : wtracker . sim . config . TimingConfig ) A motor controller that generates sinusoidal movements. Attributes Name Type Description Default timing_config TimingConfig The timing configuration of the simulation. None View Source class SineMotorController ( MotorController ) : \"\"\" A motor controller that generates sinusoidal movements . Args: timing_config ( TimingConfig ) : The timing configuration of the simulation . \"\"\" def __init__ ( self , timing_config: TimingConfig ) : super (). __init__ ( timing_config ) self . queue: list = [] def register_move ( self , dx: int , dy: int ) -> None: assert len ( self . queue ) == 0 for i in range ( self . movement_steps ) : step_size = ( np . cos (( i * np . pi ) / self . movement_steps ) - np . cos ((( i + 1 ) * np . pi ) / self . movement_steps ) ) / 2 step = ( step_size * dx , step_size * dy ) self . queue . append ( step ) def step ( self ) -> tuple [ int , int ] : dx , dy = self . queue . pop ( 0 ) rdx , rdy = ( round ( dx ), round ( dy )) resid_x , resid_y = dx - rdx , dy - rdy if self . queue: self . queue [ 0 ] = ( self . queue [ 0 ][ 0 ] + resid_x , self . queue [ 0 ][ 1 ] + resid_y ) return ( rdx , rdy ) Ancestors (in MRO) wtracker.sim.motor_controllers.MotorController abc.ABC Methods register_move def register_move ( self , dx : int , dy : int ) -> None View Source def register_move ( self , dx: int , dy: int ) -> None: assert len ( self . queue ) == 0 for i in range ( self . movement_steps ) : step_size = ( np . cos (( i * np . pi ) / self . movement_steps ) - np . cos ((( i + 1 ) * np . pi ) / self . movement_steps ) ) / 2 step = ( step_size * dx , step_size * dy ) self . queue . append ( step ) step def step ( self ) -> tuple [ int , int ] View Source def step ( self ) -> tuple [ int , int ] : dx , dy = self . queue . pop ( 0 ) rdx , rdy = ( round ( dx ), round ( dy )) resid_x , resid_y = dx - rdx , dy - rdy if self . queue : self . queue [ 0 ] = ( self . 
queue [ 0 ][ 0 ] + resid_x , self . queue [ 0 ][ 1 ] + resid_y ) return ( rdx , rdy ) StepMotorController class StepMotorController ( timing_config : wtracker . sim . config . TimingConfig , move_after_ratio : float = 0.5 ) A simple motor controller that manages the movement of a motor. The motor moves the entire distance in one step; the movement happens after a 'move_after_ratio' fraction of 'movement_steps' has passed. Attributes Name Type Description Default timing_config TimingConfig The timing configuration of the simulation. None move_after_ratio float The ratio of movement steps after which the motor should move. None View Source class StepMotorController ( MotorController ): \"\"\" A simple motor controller that manages the movement of a motor. The motor moves the entire distance in one step; the movement happens after a 'move_after_ratio' fraction of 'movement_steps' has passed. Args: timing_config (TimingConfig): The timing configuration of the simulation. move_after_ratio (float, optional): The ratio of movement steps after which the motor should move. \"\"\" def __init__ ( self , timing_config : TimingConfig , move_after_ratio : float = 0.5 ): assert 0 <= move_after_ratio <= 1 super (). __init__ ( timing_config ) self . queue : list = [] self . move_at_step = round ( self . movement_steps * move_after_ratio ) def register_move ( self , dx : int , dy : int ): for _ in range ( self . movement_steps - 1 ): self . queue . append (( 0 , 0 )) self . queue . insert ( self . move_at_step , ( dx , dy )) def step ( self ) -> tuple [ int , int ]: return self . queue . pop ( 0 ) Ancestors (in MRO) wtracker.sim.motor_controllers.MotorController abc.ABC Methods register_move def register_move ( self , dx : int , dy : int ) View Source def register_move ( self , dx: int , dy: int ) : for _ in range ( self . movement_steps - 1 ) : self . queue . append (( 0 , 0 )) self . queue . insert ( self . move_at_step , ( dx , dy )) step def step ( self ) -> tuple [ int , int ] View Source def step ( self ) -> tuple [ int , int ] : return self . queue . pop ( 0 )","title":"StepMotorController"},{"location":"reference/wtracker/sim/motor_controllers/#module-wtrackersimmotor_controllers","text":"View Source import abc import numpy as np from wtracker.sim.config import TimingConfig class MotorController ( abc . ABC ): \"\"\" Abstract base class for motor controllers used in the Simulator class. This motor controls the movement of the simulated platform. Args: timing_config (TimingConfig): The timing configuration of the simulation. Attributes: timing_config (TimingConfig): The timing configuration for the motor controller. movement_steps (int): The number of movement steps (in units of frames) based on the timing configuration. \"\"\" def __init__ ( self , timing_config : TimingConfig ): self . timing_config = timing_config self . movement_steps = self . timing_config . moving_frame_num @abc . abstractmethod def register_move ( self , dx : int , dy : int ): pass @abc . abstractmethod def step ( self ) -> tuple [ int , int ]: pass class StepMotorController ( MotorController ): \"\"\" A simple motor controller that manages the movement of a motor. The motor moves the entire distance in one step; the movement happens after a 'move_after_ratio' fraction of 'movement_steps' has passed. Args: timing_config (TimingConfig): The timing configuration of the simulation. move_after_ratio (float, optional): The ratio of movement steps after which the motor should move.
\"\"\" def __init__ ( self , timing_config : TimingConfig , move_after_ratio : float = 0.5 ): assert 0 <= move_after_ratio <= 1 super () . __init__ ( timing_config ) self . queue : list = [] self . move_at_step = round ( self . movement_steps * move_after_ratio ) def register_move ( self , dx : int , dy : int ): for _ in range ( self . movement_steps - 1 ): self . queue . append (( 0 , 0 )) self . queue . insert ( self . move_at_step , ( dx , dy )) def step ( self ) -> tuple [ int , int ]: return self . queue . pop ( 0 ) class SineMotorController ( MotorController ): \"\"\" A motor controller that generates sinusoidal movements. Args: timing_config (TimingConfig): The timing configuration of the simulation. \"\"\" def __init__ ( self , timing_config : TimingConfig ): super () . __init__ ( timing_config ) self . queue : list = [] def register_move ( self , dx : int , dy : int ) -> None : assert len ( self . queue ) == 0 for i in range ( self . movement_steps ): step_size = ( np . cos (( i * np . pi ) / self . movement_steps ) - np . cos ((( i + 1 ) * np . pi ) / self . movement_steps ) ) / 2 step = ( step_size * dx , step_size * dy ) self . queue . append ( step ) def step ( self ) -> tuple [ int , int ]: dx , dy = self . queue . pop ( 0 ) rdx , rdy = ( round ( dx ), round ( dy )) resid_x , resid_y = dx - rdx , dy - rdy if self . queue : self . queue [ 0 ] = ( self . queue [ 0 ][ 0 ] + resid_x , self . queue [ 0 ][ 1 ] + resid_y ) return ( rdx , rdy )","title":"Module wtracker.sim.motor_controllers"},{"location":"reference/wtracker/sim/motor_controllers/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/motor_controllers/#motorcontroller","text":"class MotorController ( timing_config : wtracker . sim . config . TimingConfig ) Abstract base class for motor controllers used in the Simulator class. This motor controls the movement of the simulated platform.","title":"MotorController"},{"location":"reference/wtracker/sim/motor_controllers/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration of the simulation. None timing_config TimingConfig The timing configuration for the motor controller. None movement_steps int The number of movement steps (in units of frames) based on the timing configuration. None View Source class MotorController ( abc . ABC ) : \"\"\" Abstract base class for motor controllers used in the Simulator class. This motor controls the movement of the simulated platform. Args: timing_config (TimingConfig): The timing configuration of the simulation. Attributes: timing_config (TimingConfig): The timing configuration for the motor controller. movement_steps (int): The number of movement steps (in units of frames) based on the timing configuration. \"\"\" def __init__ ( self , timing_config : TimingConfig ) : self . timing_config = timing_config self . movement_steps = self . timing_config . moving_frame_num @abc . abstractmethod def register_move ( self , dx : int , dy : int ) : pass @abc . 
abstractmethod def step ( self ) -> tuple [ int, int ] : pass","title":"Attributes"},{"location":"reference/wtracker/sim/motor_controllers/#ancestors-in-mro","text":"abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/motor_controllers/#descendants","text":"wtracker.sim.motor_controllers.StepMotorController wtracker.sim.motor_controllers.SineMotorController","title":"Descendants"},{"location":"reference/wtracker/sim/motor_controllers/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/motor_controllers/#register_move","text":"def register_move ( self , dx : int , dy : int ) View Source @abc . abstractmethod def register_move ( self , dx : int , dy : int ) : pass","title":"register_move"},{"location":"reference/wtracker/sim/motor_controllers/#step","text":"def step ( self ) -> tuple [ int , int ] View Source @abc . abstractmethod def step ( self ) -> tuple [ int, int ] : pass","title":"step"},{"location":"reference/wtracker/sim/motor_controllers/#sinemotorcontroller","text":"class SineMotorController ( timing_config : wtracker . sim . config . TimingConfig ) A motor controller that generates sinusoidal movements.","title":"SineMotorController"},{"location":"reference/wtracker/sim/motor_controllers/#attributes_1","text":"Name Type Description Default timing_config TimingConfig The timing configuration of the simulation. None View Source class SineMotorController ( MotorController ) : \"\"\" A motor controller that generates sinusoidal movements . Args: timing_config ( TimingConfig ) : The timing configuration of the simulation . \"\"\" def __init__ ( self , timing_config: TimingConfig ) : super (). __init__ ( timing_config ) self . queue: list = [] def register_move ( self , dx: int , dy: int ) -> None: assert len ( self . queue ) == 0 for i in range ( self . movement_steps ) : step_size = ( np . cos (( i * np . pi ) / self . movement_steps ) - np . cos ((( i + 1 ) * np . pi ) / self . movement_steps ) ) / 2 step = ( step_size * dx , step_size * dy ) self . queue . append ( step ) def step ( self ) -> tuple [ int , int ] : dx , dy = self . queue . pop ( 0 ) rdx , rdy = ( round ( dx ), round ( dy )) resid_x , resid_y = dx - rdx , dy - rdy if self . queue: self . queue [ 0 ] = ( self . queue [ 0 ][ 0 ] + resid_x , self . queue [ 0 ][ 1 ] + resid_y ) return ( rdx , rdy )","title":"Attributes"},{"location":"reference/wtracker/sim/motor_controllers/#ancestors-in-mro_1","text":"wtracker.sim.motor_controllers.MotorController abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/motor_controllers/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/sim/motor_controllers/#register_move_1","text":"def register_move ( self , dx : int , dy : int ) -> None View Source def register_move ( self , dx: int , dy: int ) -> None: assert len ( self . queue ) == 0 for i in range ( self . movement_steps ) : step_size = ( np . cos (( i * np . pi ) / self . movement_steps ) - np . cos ((( i + 1 ) * np . pi ) / self . movement_steps ) ) / 2 step = ( step_size * dx , step_size * dy ) self . queue . append ( step )","title":"register_move"},{"location":"reference/wtracker/sim/motor_controllers/#step_1","text":"def step ( self ) -> tuple [ int , int ] View Source def step ( self ) -> tuple [ int , int ] : dx , dy = self . queue . pop ( 0 ) rdx , rdy = ( round ( dx ), round ( dy )) resid_x , resid_y = dx - rdx , dy - rdy if self . queue : self . queue [ 0 ] = ( self . queue [ 0 ][ 0 ] + resid_x , self . 
queue [ 0 ][ 1 ] + resid_y ) return ( rdx , rdy )","title":"step"},{"location":"reference/wtracker/sim/motor_controllers/#stepmotorcontroller","text":"class StepMotorController ( timing_config : wtracker . sim . config . TimingConfig , move_after_ratio : float = 0.5 ) A simple motor controller that manages the movement of a motor. The motor moves the entire distance in one step; the movement happens after a 'move_after_ratio' fraction of 'movement_steps' has passed.","title":"StepMotorController"},{"location":"reference/wtracker/sim/motor_controllers/#attributes_2","text":"Name Type Description Default timing_config TimingConfig The timing configuration of the simulation. None move_after_ratio float The ratio of movement steps after which the motor should move. None View Source class StepMotorController ( MotorController ): \"\"\" A simple motor controller that manages the movement of a motor. The motor moves the entire distance in one step; the movement happens after a 'move_after_ratio' fraction of 'movement_steps' has passed. Args: timing_config (TimingConfig): The timing configuration of the simulation. move_after_ratio (float, optional): The ratio of movement steps after which the motor should move. \"\"\" def __init__ ( self , timing_config : TimingConfig , move_after_ratio : float = 0.5 ): assert 0 <= move_after_ratio <= 1 super (). __init__ ( timing_config ) self . queue : list = [] self . move_at_step = round ( self . movement_steps * move_after_ratio ) def register_move ( self , dx : int , dy : int ): for _ in range ( self . movement_steps - 1 ): self . queue . append (( 0 , 0 )) self . queue . insert ( self . move_at_step , ( dx , dy )) def step ( self ) -> tuple [ int , int ]: return self . queue . pop ( 0 )","title":"Attributes"},{"location":"reference/wtracker/sim/motor_controllers/#ancestors-in-mro_2","text":"wtracker.sim.motor_controllers.MotorController abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/motor_controllers/#methods_2","text":"","title":"Methods"},{"location":"reference/wtracker/sim/motor_controllers/#register_move_2","text":"def register_move ( self , dx : int , dy : int ) View Source def register_move ( self , dx: int , dy: int ) : for _ in range ( self . movement_steps - 1 ) : self . queue . append (( 0 , 0 )) self . queue . insert ( self . move_at_step , ( dx , dy ))","title":"register_move"},{"location":"reference/wtracker/sim/motor_controllers/#step_2","text":"def step ( self ) -> tuple [ int , int ] View Source def step ( self ) -> tuple [ int , int ] : return self . queue . pop ( 0 )","title":"step"},{"location":"reference/wtracker/sim/simulator/","text":"Module wtracker.sim.simulator View Source from __future__ import annotations import numpy as np import abc from tqdm.auto import tqdm from wtracker.sim.view_controller import ViewController from wtracker.sim.config import TimingConfig , ExperimentConfig from wtracker.sim.motor_controllers import MotorController , SineMotorController from wtracker.utils.frame_reader import FrameReader , DummyReader class Simulator : \"\"\" A class representing a simulator for a biological experiment. Args: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration. sim_controller (SimController): The simulation controller. reader (FrameReader, optional): The frame reader. motor_controller (MotorController, optional): The motor controller.
Attributes: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration. \"\"\" def __init__ ( self , timing_config : TimingConfig , experiment_config : ExperimentConfig , sim_controller : SimController , reader : FrameReader = None , motor_controller : MotorController = None , ) -> None : self . timing_config = timing_config self . experiment_config = experiment_config self . _sim_controller = sim_controller if reader is None : num_frames = experiment_config . num_frames padding_size = ( timing_config . camera_size_px [ 0 ] // 2 * 2 , timing_config . camera_size_px [ 1 ] // 2 * 2 ) resolution = tuple ([ sum ( x ) for x in zip ( experiment_config . orig_resolution , padding_size )]) reader = DummyReader ( num_frames , resolution , colored = True ) if motor_controller is None : motor_controller = SineMotorController ( timing_config ) self . _motor_controller = motor_controller self . _view = ViewController ( frame_reader = reader , camera_size = timing_config . camera_size_px , micro_size = timing_config . micro_size_px , init_position = experiment_config . init_position , ) @property def view ( self ) -> ViewController : \"\"\" Get the view controller. Returns: ViewController: The view controller. \"\"\" return self . _view @property def position ( self ) -> tuple [ int , int ]: \"\"\" Get the current position. Returns: tuple[int, int]: The current position. \"\"\" return self . _view . position @property def cycle_number ( self ) -> int : \"\"\" Get the current cycle number. Returns: int: The current cycle number. \"\"\" return self . _view . index // self . timing_config . cycle_frame_num @property def frame_number ( self ) -> int : \"\"\" Get the current frame number. Returns: int: The current frame number. \"\"\" return self . _view . index @property def cycle_step ( self ) -> int : \"\"\" Get the current cycle step. Returns: int: The current cycle step. \"\"\" return self . _view . index % self . timing_config . cycle_frame_num def camera_view ( self ) -> np . ndarray : \"\"\" Get the view that the camera sees. Returns: np.ndarray: The camera view. \"\"\" return self . _view . camera_view () def micro_view ( self ) -> np . ndarray : \"\"\" Get the view that the microscope sees. Returns: np.ndarray: The micro view. \"\"\" return self . _view . micro_view () def _reset ( self ): \"\"\" Reset the simulator. \"\"\" self . view . reset () self . view . set_position ( * self . experiment_config . init_position ) def run ( self , visualize : bool = False , wait_key : bool = False ): \"\"\" Run the simulation. Args: visualize (bool, optional): Whether to visualize the simulation. wait_key (bool, optional): Whether to wait for a key press to advance the simulation during visualization. \"\"\" config = self . timing_config total_cycles = len ( self . _view ) // config . cycle_frame_num pbar = tqdm ( total = total_cycles , desc = \"Simulation Progress\" , unit = \"cycle\" ) self . _reset () self . _sim_controller . on_sim_start ( self ) while self . _view . progress (): if self . cycle_step == 0 : if self . cycle_number > 0 : self . _sim_controller . on_movement_end ( self ) self . _sim_controller . on_cycle_end ( self ) self . _sim_controller . on_cycle_start ( self ) self . _sim_controller . on_camera_frame ( self ) if self . cycle_step == 0 : self . _sim_controller . on_imaging_start ( self ) if self . cycle_step < config . imaging_frame_num : self . _sim_controller . on_micro_frame ( self ) if self . 
cycle_step == config . imaging_frame_num - config . pred_frame_num : self . _sim_controller . begin_movement_prediction ( self ) if self . cycle_step == config . imaging_frame_num : self . _sim_controller . on_imaging_end ( self ) dx , dy = self . _sim_controller . provide_movement_vector ( self ) self . _sim_controller . on_movement_start ( self ) self . _motor_controller . register_move ( dx , dy ) if config . imaging_frame_num <= self . cycle_step < config . imaging_frame_num + config . moving_frame_num : dx , dy = self . _motor_controller . step () self . _view . move_position ( dx , dy ) if self . cycle_step == config . cycle_frame_num - 1 : pbar . update ( 1 ) if visualize : self . _view . visualize_world ( timeout = 0 if wait_key else 1 ) self . _sim_controller . on_sim_end ( self ) pbar . close () class SimController ( abc . ABC ): \"\"\" Abstract base class for simulator controllers. Attributes: timing_config (TimingConfig): The timing configuration for the simulator. \"\"\" def __init__ ( self , timing_config : TimingConfig ): self . timing_config = timing_config def on_sim_start ( self , sim : Simulator ): \"\"\" Called when the simulation starts. \"\"\" pass def on_sim_end ( self , sim : Simulator ): \"\"\" Called when the simulation ends. \"\"\" pass def on_cycle_start ( self , sim : Simulator ): \"\"\" Called when a new cycle starts. \"\"\" pass def on_cycle_end ( self , sim : Simulator ): \"\"\" Called when a cycle ends. \"\"\" pass def on_camera_frame ( self , sim : Simulator ): \"\"\" Called when a camera frame is captured. Happens every frame. \"\"\" pass def on_imaging_start ( self , sim : Simulator ): \"\"\" Called when imaging phase starts. \"\"\" pass def on_micro_frame ( self , sim : Simulator ): \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass def on_imaging_end ( self , sim : Simulator ): \"\"\" Called when imaging phase ends. \"\"\" pass def on_movement_start ( self , sim : Simulator ): \"\"\" Called when movement phase starts. \"\"\" pass def on_movement_end ( self , sim : Simulator ): \"\"\" Called when movement phase ends. \"\"\" pass @abc . abstractmethod def begin_movement_prediction ( self , sim : Simulator ) -> None : \"\"\" Called when the movement prediction begins. \"\"\" raise NotImplementedError () @abc . abstractmethod def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: \"\"\" Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: tuple[int, int]: The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. \"\"\" raise NotImplementedError () @abc . abstractmethod def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : \"\"\" Returns a list of bbox predictions of the worm, for each frame of the current cycle. If a prediction is not available, return None for that frame. Used internally for logging. \"\"\" raise NotImplementedError () Classes SimController class SimController ( timing_config : 'TimingConfig' ) Abstract base class for simulator controllers. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class SimController ( abc . ABC ) : \"\"\" Abstract base class for simulator controllers. Attributes: timing_config (TimingConfig): The timing configuration for the simulator. \"\"\" def __init__ ( self , timing_config : TimingConfig ) : self . 
timing_config = timing_config def on_sim_start ( self , sim : Simulator ) : \"\"\" Called when the simulation starts. \"\"\" pass def on_sim_end ( self , sim : Simulator ) : \"\"\" Called when the simulation ends. \"\"\" pass def on_cycle_start ( self , sim : Simulator ) : \"\"\" Called when a new cycle starts. \"\"\" pass def on_cycle_end ( self , sim : Simulator ) : \"\"\" Called when a cycle ends. \"\"\" pass def on_camera_frame ( self , sim : Simulator ) : \"\"\" Called when a camera frame is captured. Happens every frame. \"\"\" pass def on_imaging_start ( self , sim : Simulator ) : \"\"\" Called when imaging phase starts. \"\"\" pass def on_micro_frame ( self , sim : Simulator ) : \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass def on_imaging_end ( self , sim : Simulator ) : \"\"\" Called when imaging phase ends. \"\"\" pass def on_movement_start ( self , sim : Simulator ) : \"\"\" Called when movement phase starts. \"\"\" pass def on_movement_end ( self , sim : Simulator ) : \"\"\" Called when movement phase ends. \"\"\" pass @abc . abstractmethod def begin_movement_prediction ( self , sim : Simulator ) -> None : \"\"\" Called when the movement prediction begins. \"\"\" raise NotImplementedError () @abc . abstractmethod def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : \"\"\" Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: tuple[int, int]: The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. \"\"\" raise NotImplementedError () @abc . abstractmethod def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : \"\"\" Returns a list of bbox predictions of the worm, for each frame of the current cycle. If a prediction is not available, return None for that frame. Used internally for logging. \"\"\" raise NotImplementedError () Ancestors (in MRO) abc.ABC Descendants wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.sim_controllers.logging_controller.LoggingController wtracker.sim.sim_controllers.yolo_controller.YoloController Methods begin_movement_prediction def begin_movement_prediction ( self , sim : 'Simulator' ) -> 'None' Called when the movement prediction begins. View Source @abc . abstractmethod def begin_movement_prediction ( self , sim : Simulator ) -> None : \"\"\" Called when the movement prediction begins. \"\"\" raise NotImplementedError () on_camera_frame def on_camera_frame ( self , sim : 'Simulator' ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : \"\" \" Called when a camera frame is captured. Happens every frame. \" \"\" pass on_cycle_end def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass on_cycle_start def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass on_imaging_end def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass on_imaging_start def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. 
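Since SimController leaves only three methods abstract, a minimal concrete controller is short. The sketch below keeps the platform stationary; the class name IdleController and its zero-movement policy are illustrative assumptions, not part of the library:

```python
import numpy as np

from wtracker.sim.simulator import Simulator, SimController


class IdleController(SimController):
    """Illustrative controller that never moves the platform."""

    def begin_movement_prediction(self, sim: Simulator) -> None:
        pass  # a zero move needs no prediction work

    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        return (0, 0)  # keep the platform where it is

    def _cycle_predict_all(self, sim: Simulator) -> np.ndarray:
        # One empty (all-NaN) bbox per frame of the current cycle; the docs
        # above allow "no prediction" entries for frames without a detection.
        return np.full((sim.timing_config.cycle_frame_num, 4), np.nan)
```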
View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass on_micro_frame def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass on_movement_end def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass on_movement_start def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass on_sim_end def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass on_sim_start def on_sim_start ( self , sim : 'Simulator' ) Called when the simulation starts. View Source def on_sim_start(self, sim: Simulator): \"\"\" Called when the simulation starts. \"\"\" pass provide_movement_vector def provide_movement_vector ( self , sim : 'Simulator' ) -> 'tuple[int, int]' Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source @abc . abstractmethod def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : \"\"\" Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: tuple[int, int]: The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. \"\"\" raise NotImplementedError () Simulator class Simulator ( timing_config : 'TimingConfig' , experiment_config : 'ExperimentConfig' , sim_controller : 'SimController' , reader : 'FrameReader' = None , motor_controller : 'MotorController' = None ) A class representing a simulator for a biological experiment. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the experiment. None experiment_config ExperimentConfig The experiment configuration. None sim_controller SimController The simulation controller. None reader FrameReader The frame reader. None motor_controller MotorController The motor controller. None timing_config TimingConfig The timing configuration for the experiment. None experiment_config ExperimentConfig The experiment configuration. None View Source class Simulator : \"\"\" A class representing a simulator for a biological experiment. Args: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration. sim_controller (SimController): The simulation controller. reader (FrameReader, optional): The frame reader. motor_controller (MotorController, optional): The motor controller. Attributes: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration.
\"\"\" def __init__ ( self , timing_config : TimingConfig , experiment_config : ExperimentConfig , sim_controller : SimController , reader : FrameReader = None , motor_controller : MotorController = None , ) -> None : self . timing_config = timing_config self . experiment_config = experiment_config self . _sim_controller = sim_controller if reader is None : num_frames = experiment_config . num_frames padding_size = ( timing_config . camera_size_px [ 0 ] // 2 * 2 , timing_config . camera_size_px [ 1 ] // 2 * 2 ) resolution = tuple ( [ sum(x) for x in zip(experiment_config.orig_resolution, padding_size) ] ) reader = DummyReader ( num_frames , resolution , colored = True ) if motor_controller is None : motor_controller = SineMotorController ( timing_config ) self . _motor_controller = motor_controller self . _view = ViewController ( frame_reader = reader , camera_size = timing_config . camera_size_px , micro_size = timing_config . micro_size_px , init_position = experiment_config . init_position , ) @property def view ( self ) -> ViewController : \"\"\" Get the view controller. Returns: ViewController: The view controller. \"\"\" return self . _view @property def position ( self ) -> tuple [ int, int ] : \"\"\" Get the current position. Returns: tuple[int, int]: The current position. \"\"\" return self . _view . position @property def cycle_number ( self ) -> int : \"\"\" Get the current cycle number. Returns: int: The current cycle number. \"\"\" return self . _view . index // self . timing_config . cycle_frame_num @property def frame_number ( self ) -> int : \"\"\" Get the current frame number. Returns: int: The current frame number. \"\"\" return self . _view . index @property def cycle_step ( self ) -> int : \"\"\" Get the current cycle step. Returns: int: The current cycle step. \"\"\" return self . _view . index % self . timing_config . cycle_frame_num def camera_view ( self ) -> np . ndarray : \"\"\" Get the view that the camera sees. Returns: np.ndarray: The camera view. \"\"\" return self . _view . camera_view () def micro_view ( self ) -> np . ndarray : \"\"\" Get the view that the microscope sees. Returns: np.ndarray: The micro view. \"\"\" return self . _view . micro_view () def _reset ( self ) : \"\"\" Reset the simulator. \"\"\" self . view . reset () self . view . set_position ( * self . experiment_config . init_position ) def run ( self , visualize : bool = False , wait_key : bool = False ) : \"\"\" Run the simulation. Args: visualize (bool, optional): Whether to visualize the simulation. wait_key (bool, optional): Whether to wait for a key press to advance the simulation during visualization. \"\"\" config = self . timing_config total_cycles = len ( self . _view ) // config . cycle_frame_num pbar = tqdm ( total = total_cycles , desc = \"Simulation Progress\" , unit = \"cycle\" ) self . _reset () self . _sim_controller . on_sim_start ( self ) while self . _view . progress () : if self . cycle_step == 0 : if self . cycle_number > 0 : self . _sim_controller . on_movement_end ( self ) self . _sim_controller . on_cycle_end ( self ) self . _sim_controller . on_cycle_start ( self ) self . _sim_controller . on_camera_frame ( self ) if self . cycle_step == 0 : self . _sim_controller . on_imaging_start ( self ) if self . cycle_step < config . imaging_frame_num : self . _sim_controller . on_micro_frame ( self ) if self . cycle_step == config . imaging_frame_num - config . pred_frame_num : self . _sim_controller . begin_movement_prediction ( self ) if self . cycle_step == config . 
imaging_frame_num : self . _sim_controller . on_imaging_end ( self ) dx , dy = self . _sim_controller . provide_movement_vector ( self ) self . _sim_controller . on_movement_start ( self ) self . _motor_controller . register_move ( dx , dy ) if config . imaging_frame_num <= self . cycle_step < config . imaging_frame_num + config . moving_frame_num : dx , dy = self . _motor_controller . step () self . _view . move_position ( dx , dy ) if self . cycle_step == config . cycle_frame_num - 1 : pbar . update ( 1 ) if visualize : self . _view . visualize_world ( timeout = 0 if wait_key else 1 ) self . _sim_controller . on_sim_end ( self ) pbar . close () Instance variables cycle_number Get the current cycle number. cycle_step Get the current cycle step. frame_number Get the current frame number. position Get the current position. view Get the view controller. Methods camera_view def camera_view ( self ) -> 'np.ndarray' Get the view that the camera sees. Returns: Type Description np.ndarray The camera view. View Source def camera_view ( self ) - > np . ndarray : \"\" \" Get the view that the camera sees. Returns: np.ndarray: The camera view. \" \"\" return self . _view . camera_view () micro_view def micro_view ( self ) -> 'np.ndarray' Get the view that the microscope sees. Returns: Type Description np.ndarray The micro view. View Source def micro_view ( self ) -> np . ndarray : \"\"\" Get the view that the microscope sees. Returns: np.ndarray: The micro view. \"\"\" return self . _view . micro_view () run def run ( self , visualize : 'bool' = False , wait_key : 'bool' = False ) Run the simulation. Parameters: Name Type Description Default visualize bool Whether to visualize the simulation. None wait_key bool Whether to wait for a key press to advance the simulation during visualization. None View Source def run ( self , visualize: bool = False , wait_key: bool = False ) : \"\"\" Run the simulation . Args: visualize ( bool , optional ) : Whether to visualize the simulation . wait_key ( bool , optional ) : Whether to wait for a key press to advance the simulation during visualization . \"\"\" config = self . timing_config total_cycles = len ( self . _view ) // config.cycle_frame_num pbar = tqdm ( total = total_cycles , desc = \"Simulation Progress\" , unit = \"cycle\" ) self . _reset () self . _sim_controller . on_sim_start ( self ) while self . _view . progress () : if self . cycle_step == 0 : if self . cycle_number > 0 : self . _sim_controller . on_movement_end ( self ) self . _sim_controller . on_cycle_end ( self ) self . _sim_controller . on_cycle_start ( self ) self . _sim_controller . on_camera_frame ( self ) if self . cycle_step == 0 : self . _sim_controller . on_imaging_start ( self ) if self . cycle_step < config . imaging_frame_num: self . _sim_controller . on_micro_frame ( self ) if self . cycle_step == config . imaging_frame_num - config . pred_frame_num: self . _sim_controller . begin_movement_prediction ( self ) if self . cycle_step == config . imaging_frame_num: self . _sim_controller . on_imaging_end ( self ) dx , dy = self . _sim_controller . provide_movement_vector ( self ) self . _sim_controller . on_movement_start ( self ) self . _motor_controller . register_move ( dx , dy ) if config . imaging_frame_num <= self . cycle_step < config . imaging_frame_num + config . moving_frame_num: dx , dy = self . _motor_controller . step () self . _view . move_position ( dx , dy ) if self . cycle_step == config . cycle_frame_num - 1 : pbar . update ( 1 ) if visualize: self . _view . 
visualize_world ( timeout = 0 if wait_key else 1 ) self . _sim_controller . on_sim_end ( self ) pbar . close ()","title":"Simulator"},{"location":"reference/wtracker/sim/simulator/#module-wtrackersimsimulator","text":"View Source from __future__ import annotations import numpy as np import abc from tqdm.auto import tqdm from wtracker.sim.view_controller import ViewController from wtracker.sim.config import TimingConfig , ExperimentConfig from wtracker.sim.motor_controllers import MotorController , SineMotorController from wtracker.utils.frame_reader import FrameReader , DummyReader class Simulator : \"\"\" A class representing a simulator for a biological experiment. Args: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration. sim_controller (SimController): The simulation controller. reader (FrameReader, optional): The frame reader. motor_controller (MotorController, optional): The motor controller. Attributes: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration. \"\"\" def __init__ ( self , timing_config : TimingConfig , experiment_config : ExperimentConfig , sim_controller : SimController , reader : FrameReader = None , motor_controller : MotorController = None , ) -> None : self . timing_config = timing_config self . experiment_config = experiment_config self . _sim_controller = sim_controller if reader is None : num_frames = experiment_config . num_frames padding_size = ( timing_config . camera_size_px [ 0 ] // 2 * 2 , timing_config . camera_size_px [ 1 ] // 2 * 2 ) resolution = tuple ([ sum ( x ) for x in zip ( experiment_config . orig_resolution , padding_size )]) reader = DummyReader ( num_frames , resolution , colored = True ) if motor_controller is None : motor_controller = SineMotorController ( timing_config ) self . _motor_controller = motor_controller self . _view = ViewController ( frame_reader = reader , camera_size = timing_config . camera_size_px , micro_size = timing_config . micro_size_px , init_position = experiment_config . init_position , ) @property def view ( self ) -> ViewController : \"\"\" Get the view controller. Returns: ViewController: The view controller. \"\"\" return self . _view @property def position ( self ) -> tuple [ int , int ]: \"\"\" Get the current position. Returns: tuple[int, int]: The current position. \"\"\" return self . _view . position @property def cycle_number ( self ) -> int : \"\"\" Get the current cycle number. Returns: int: The current cycle number. \"\"\" return self . _view . index // self . timing_config . cycle_frame_num @property def frame_number ( self ) -> int : \"\"\" Get the current frame number. Returns: int: The current frame number. \"\"\" return self . _view . index @property def cycle_step ( self ) -> int : \"\"\" Get the current cycle step. Returns: int: The current cycle step. \"\"\" return self . _view . index % self . timing_config . cycle_frame_num def camera_view ( self ) -> np . ndarray : \"\"\" Get the view that the camera sees. Returns: np.ndarray: The camera view. \"\"\" return self . _view . camera_view () def micro_view ( self ) -> np . ndarray : \"\"\" Get the view that the microscope sees. Returns: np.ndarray: The micro view. \"\"\" return self . _view . micro_view () def _reset ( self ): \"\"\" Reset the simulator. \"\"\" self . view . reset () self . view . set_position ( * self . experiment_config . 
init_position ) def run ( self , visualize : bool = False , wait_key : bool = False ): \"\"\" Run the simulation. Args: visualize (bool, optional): Whether to visualize the simulation. wait_key (bool, optional): Whether to wait for a key press to advance the simulation during visualization. \"\"\" config = self . timing_config total_cycles = len ( self . _view ) // config . cycle_frame_num pbar = tqdm ( total = total_cycles , desc = \"Simulation Progress\" , unit = \"cycle\" ) self . _reset () self . _sim_controller . on_sim_start ( self ) while self . _view . progress (): if self . cycle_step == 0 : if self . cycle_number > 0 : self . _sim_controller . on_movement_end ( self ) self . _sim_controller . on_cycle_end ( self ) self . _sim_controller . on_cycle_start ( self ) self . _sim_controller . on_camera_frame ( self ) if self . cycle_step == 0 : self . _sim_controller . on_imaging_start ( self ) if self . cycle_step < config . imaging_frame_num : self . _sim_controller . on_micro_frame ( self ) if self . cycle_step == config . imaging_frame_num - config . pred_frame_num : self . _sim_controller . begin_movement_prediction ( self ) if self . cycle_step == config . imaging_frame_num : self . _sim_controller . on_imaging_end ( self ) dx , dy = self . _sim_controller . provide_movement_vector ( self ) self . _sim_controller . on_movement_start ( self ) self . _motor_controller . register_move ( dx , dy ) if config . imaging_frame_num <= self . cycle_step < config . imaging_frame_num + config . moving_frame_num : dx , dy = self . _motor_controller . step () self . _view . move_position ( dx , dy ) if self . cycle_step == config . cycle_frame_num - 1 : pbar . update ( 1 ) if visualize : self . _view . visualize_world ( timeout = 0 if wait_key else 1 ) self . _sim_controller . on_sim_end ( self ) pbar . close () class SimController ( abc . ABC ): \"\"\" Abstract base class for simulator controllers. Attributes: timing_config (TimingConfig): The timing configuration for the simulator. \"\"\" def __init__ ( self , timing_config : TimingConfig ): self . timing_config = timing_config def on_sim_start ( self , sim : Simulator ): \"\"\" Called when the simulation starts. \"\"\" pass def on_sim_end ( self , sim : Simulator ): \"\"\" Called when the simulation ends. \"\"\" pass def on_cycle_start ( self , sim : Simulator ): \"\"\" Called when a new cycle starts. \"\"\" pass def on_cycle_end ( self , sim : Simulator ): \"\"\" Called when a cycle ends. \"\"\" pass def on_camera_frame ( self , sim : Simulator ): \"\"\" Called when a camera frame is captured. Happens every frame. \"\"\" pass def on_imaging_start ( self , sim : Simulator ): \"\"\" Called when imaging phase starts. \"\"\" pass def on_micro_frame ( self , sim : Simulator ): \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass def on_imaging_end ( self , sim : Simulator ): \"\"\" Called when imaging phase ends. \"\"\" pass def on_movement_start ( self , sim : Simulator ): \"\"\" Called when movement phase starts. \"\"\" pass def on_movement_end ( self , sim : Simulator ): \"\"\" Called when movement phase ends. \"\"\" pass @abc . abstractmethod def begin_movement_prediction ( self , sim : Simulator ) -> None : \"\"\" Called when the movement prediction begins. \"\"\" raise NotImplementedError () @abc . abstractmethod def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: \"\"\" Provides the movement vector for the simulator. The platform is moved by the provided vector. 
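Putting the pieces together, constructing and running a simulation might look like the sketch below. It reuses the hypothetical timing_config and experiment_config from the configuration example earlier, plus the illustrative IdleController; reader and motor_controller are omitted here, so the constructor falls back to a DummyReader and a SineMotorController as shown in __init__ above:

```python
from wtracker.sim.simulator import Simulator

sim = Simulator(
    timing_config=timing_config,
    experiment_config=experiment_config,
    sim_controller=IdleController(timing_config),
)
sim.run(visualize=False)  # progress is reported per cycle via tqdm
```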
Returns: tuple[int, int]: The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. \"\"\" raise NotImplementedError () @abc . abstractmethod def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : \"\"\" Returns a list of bbox predictions of the worm, for each frame of the current cycle. If a prediction is not available, return None for that frame. Used internally for logging. \"\"\" raise NotImplementedError ()","title":"Module wtracker.sim.simulator"},{"location":"reference/wtracker/sim/simulator/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/simulator/#simcontroller","text":"class SimController ( timing_config : 'TimingConfig' ) Abstract base class for simulator controllers.","title":"SimController"},{"location":"reference/wtracker/sim/simulator/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class SimController ( abc . ABC ) : \"\"\" Abstract base class for simulator controllers. Attributes: timing_config (TimingConfig): The timing configuration for the simulator. \"\"\" def __init__ ( self , timing_config : TimingConfig ) : self . timing_config = timing_config def on_sim_start ( self , sim : Simulator ) : \"\"\" Called when the simulation starts. \"\"\" pass def on_sim_end ( self , sim : Simulator ) : \"\"\" Called when the simulation ends. \"\"\" pass def on_cycle_start ( self , sim : Simulator ) : \"\"\" Called when a new cycle starts. \"\"\" pass def on_cycle_end ( self , sim : Simulator ) : \"\"\" Called when a cycle ends. \"\"\" pass def on_camera_frame ( self , sim : Simulator ) : \"\"\" Called when a camera frame is captured. Happens every frame. \"\"\" pass def on_imaging_start ( self , sim : Simulator ) : \"\"\" Called when imaging phase starts. \"\"\" pass def on_micro_frame ( self , sim : Simulator ) : \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass def on_imaging_end ( self , sim : Simulator ) : \"\"\" Called when imaging phase ends. \"\"\" pass def on_movement_start ( self , sim : Simulator ) : \"\"\" Called when movement phase starts. \"\"\" pass def on_movement_end ( self , sim : Simulator ) : \"\"\" Called when movement phase ends. \"\"\" pass @abc . abstractmethod def begin_movement_prediction ( self , sim : Simulator ) -> None : \"\"\" Called when the movement prediction begins. \"\"\" raise NotImplementedError () @abc . abstractmethod def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : \"\"\" Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: tuple[int, int]: The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. \"\"\" raise NotImplementedError () @abc . abstractmethod def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : \"\"\" Returns a list of bbox predictions of the worm, for each frame of the current cycle. If a prediction is not available, return None for that frame. Used internally for logging. 
\"\"\" raise NotImplementedError ()","title":"Attributes"},{"location":"reference/wtracker/sim/simulator/#ancestors-in-mro","text":"abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/simulator/#descendants","text":"wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.sim_controllers.logging_controller.LoggingController wtracker.sim.sim_controllers.yolo_controller.YoloController","title":"Descendants"},{"location":"reference/wtracker/sim/simulator/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/simulator/#begin_movement_prediction","text":"def begin_movement_prediction ( self , sim : 'Simulator' ) -> 'None' Called when the movement prediction begins. View Source @abc . abstractmethod def begin_movement_prediction ( self , sim : Simulator ) -> None : \"\"\" Called when the movement prediction begins. \"\"\" raise NotImplementedError ()","title":"begin_movement_prediction"},{"location":"reference/wtracker/sim/simulator/#on_camera_frame","text":"def on_camera_frame ( self , sim : 'Simulator' ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : \"\" \" Called when a camera frame is captured. Happens every frame. \" \"\" pass","title":"on_camera_frame"},{"location":"reference/wtracker/sim/simulator/#on_cycle_end","text":"def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass","title":"on_cycle_end"},{"location":"reference/wtracker/sim/simulator/#on_cycle_start","text":"def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass","title":"on_cycle_start"},{"location":"reference/wtracker/sim/simulator/#on_imaging_end","text":"def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass","title":"on_imaging_end"},{"location":"reference/wtracker/sim/simulator/#on_imaging_start","text":"def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass","title":"on_imaging_start"},{"location":"reference/wtracker/sim/simulator/#on_micro_frame","text":"def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass","title":"on_micro_frame"},{"location":"reference/wtracker/sim/simulator/#on_movement_end","text":"def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass","title":"on_movement_end"},{"location":"reference/wtracker/sim/simulator/#on_movement_start","text":"def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass","title":"on_movement_start"},{"location":"reference/wtracker/sim/simulator/#on_sim_end","text":"def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. 
View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass","title":"on_sim_end"},{"location":"reference/wtracker/sim/simulator/#on_sim_start","text":"def on_sim_start ( self , sim : 'Simulator' ) Called when the simulation starts. View Source def on_sim_start(self, sim: Simulator): \"\"\" Called when the simulation starts. \"\"\" pass","title":"on_sim_start"},{"location":"reference/wtracker/sim/simulator/#provide_movement_vector","text":"def provide_movement_vector ( self , sim : 'Simulator' ) -> 'tuple[int, int]' Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source @abc . abstractmethod def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : \"\"\" Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: tuple[int, int]: The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. \"\"\" raise NotImplementedError ()","title":"provide_movement_vector"},{"location":"reference/wtracker/sim/simulator/#simulator","text":"class Simulator ( timing_config : 'TimingConfig' , experiment_config : 'ExperimentConfig' , sim_controller : 'SimController' , reader : 'FrameReader' = None , motor_controller : 'MotorController' = None ) A class representing a simulator for a biological experiment.","title":"Simulator"},{"location":"reference/wtracker/sim/simulator/#attributes_1","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the experiment. None experiment_config ExperimentConfig The experiment configuration. None sim_controller SimController The simulation controller. None reader FrameReader The frame reader. None motor_controller MotorController The motor controller. None timing_config TimingConfig The timing configuration for the experiment. None experiment_config ExperimentConfig The experiment configuration. None View Source class Simulator : \"\"\" A class representing a simulator for a biological experiment. Args: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration. sim_controller (SimController): The simulation controller. reader (FrameReader, optional): The frame reader. motor_controller (MotorController, optional): The motor controller. Attributes: timing_config (TimingConfig): The timing configuration for the experiment. experiment_config (ExperimentConfig): The experiment configuration. \"\"\" def __init__ ( self , timing_config : TimingConfig , experiment_config : ExperimentConfig , sim_controller : SimController , reader : FrameReader = None , motor_controller : MotorController = None , ) -> None : self . timing_config = timing_config self . experiment_config = experiment_config self . _sim_controller = sim_controller if reader is None : num_frames = experiment_config . num_frames padding_size = ( timing_config . camera_size_px [ 0 ] // 2 * 2 , timing_config . camera_size_px [ 1 ] // 2 * 2 ) resolution = tuple ( [ sum(x) for x in zip(experiment_config.orig_resolution, padding_size) ] ) reader = DummyReader ( num_frames , resolution , colored = True ) if motor_controller is None : motor_controller = SineMotorController ( timing_config ) self . 
_motor_controller = motor_controller self . _view = ViewController ( frame_reader = reader , camera_size = timing_config . camera_size_px , micro_size = timing_config . micro_size_px , init_position = experiment_config . init_position , ) @property def view ( self ) -> ViewController : \"\"\" Get the view controller. Returns: ViewController: The view controller. \"\"\" return self . _view @property def position ( self ) -> tuple [ int, int ] : \"\"\" Get the current position. Returns: tuple[int, int]: The current position. \"\"\" return self . _view . position @property def cycle_number ( self ) -> int : \"\"\" Get the current cycle number. Returns: int: The current cycle number. \"\"\" return self . _view . index // self . timing_config . cycle_frame_num @property def frame_number ( self ) -> int : \"\"\" Get the current frame number. Returns: int: The current frame number. \"\"\" return self . _view . index @property def cycle_step ( self ) -> int : \"\"\" Get the current cycle step. Returns: int: The current cycle step. \"\"\" return self . _view . index % self . timing_config . cycle_frame_num def camera_view ( self ) -> np . ndarray : \"\"\" Get the view that the camera sees. Returns: np.ndarray: The camera view. \"\"\" return self . _view . camera_view () def micro_view ( self ) -> np . ndarray : \"\"\" Get the view that the microscope sees. Returns: np.ndarray: The micro view. \"\"\" return self . _view . micro_view () def _reset ( self ) : \"\"\" Reset the simulator. \"\"\" self . view . reset () self . view . set_position ( * self . experiment_config . init_position ) def run ( self , visualize : bool = False , wait_key : bool = False ) : \"\"\" Run the simulation. Args: visualize (bool, optional): Whether to visualize the simulation. wait_key (bool, optional): Whether to wait for a key press to advance the simulation during visualization. \"\"\" config = self . timing_config total_cycles = len ( self . _view ) // config . cycle_frame_num pbar = tqdm ( total = total_cycles , desc = \"Simulation Progress\" , unit = \"cycle\" ) self . _reset () self . _sim_controller . on_sim_start ( self ) while self . _view . progress () : if self . cycle_step == 0 : if self . cycle_number > 0 : self . _sim_controller . on_movement_end ( self ) self . _sim_controller . on_cycle_end ( self ) self . _sim_controller . on_cycle_start ( self ) self . _sim_controller . on_camera_frame ( self ) if self . cycle_step == 0 : self . _sim_controller . on_imaging_start ( self ) if self . cycle_step < config . imaging_frame_num : self . _sim_controller . on_micro_frame ( self ) if self . cycle_step == config . imaging_frame_num - config . pred_frame_num : self . _sim_controller . begin_movement_prediction ( self ) if self . cycle_step == config . imaging_frame_num : self . _sim_controller . on_imaging_end ( self ) dx , dy = self . _sim_controller . provide_movement_vector ( self ) self . _sim_controller . on_movement_start ( self ) self . _motor_controller . register_move ( dx , dy ) if config . imaging_frame_num <= self . cycle_step < config . imaging_frame_num + config . moving_frame_num : dx , dy = self . _motor_controller . step () self . _view . move_position ( dx , dy ) if self . cycle_step == config . cycle_frame_num - 1 : pbar . update ( 1 ) if visualize : self . _view . visualize_world ( timeout = 0 if wait_key else 1 ) self . _sim_controller . on_sim_end ( self ) pbar . 
close ()","title":"Attributes"},{"location":"reference/wtracker/sim/simulator/#instance-variables","text":"cycle_number Get the current cycle number. cycle_step Get the current cycle step. frame_number Get the current frame number. position Get the current position. view Get the view controller.","title":"Instance variables"},{"location":"reference/wtracker/sim/simulator/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/sim/simulator/#camera_view","text":"def camera_view ( self ) -> 'np.ndarray' Get the view that the camera sees. Returns: Type Description np.ndarray The camera view. View Source def camera_view ( self ) - > np . ndarray : \"\" \" Get the view that the camera sees. Returns: np.ndarray: The camera view. \" \"\" return self . _view . camera_view ()","title":"camera_view"},{"location":"reference/wtracker/sim/simulator/#micro_view","text":"def micro_view ( self ) -> 'np.ndarray' Get the view that the microscope sees. Returns: Type Description np.ndarray The micro view. View Source def micro_view ( self ) -> np . ndarray : \"\"\" Get the view that the microscope sees. Returns: np.ndarray: The micro view. \"\"\" return self . _view . micro_view ()","title":"micro_view"},{"location":"reference/wtracker/sim/simulator/#run","text":"def run ( self , visualize : 'bool' = False , wait_key : 'bool' = False ) Run the simulation. Parameters: Name Type Description Default visualize bool Whether to visualize the simulation. None wait_key bool Whether to wait for a key press to advance the simulation during visualization. None View Source def run ( self , visualize: bool = False , wait_key: bool = False ) : \"\"\" Run the simulation . Args: visualize ( bool , optional ) : Whether to visualize the simulation . wait_key ( bool , optional ) : Whether to wait for a key press to advance the simulation during visualization . \"\"\" config = self . timing_config total_cycles = len ( self . _view ) // config.cycle_frame_num pbar = tqdm ( total = total_cycles , desc = \"Simulation Progress\" , unit = \"cycle\" ) self . _reset () self . _sim_controller . on_sim_start ( self ) while self . _view . progress () : if self . cycle_step == 0 : if self . cycle_number > 0 : self . _sim_controller . on_movement_end ( self ) self . _sim_controller . on_cycle_end ( self ) self . _sim_controller . on_cycle_start ( self ) self . _sim_controller . on_camera_frame ( self ) if self . cycle_step == 0 : self . _sim_controller . on_imaging_start ( self ) if self . cycle_step < config . imaging_frame_num: self . _sim_controller . on_micro_frame ( self ) if self . cycle_step == config . imaging_frame_num - config . pred_frame_num: self . _sim_controller . begin_movement_prediction ( self ) if self . cycle_step == config . imaging_frame_num: self . _sim_controller . on_imaging_end ( self ) dx , dy = self . _sim_controller . provide_movement_vector ( self ) self . _sim_controller . on_movement_start ( self ) self . _motor_controller . register_move ( dx , dy ) if config . imaging_frame_num <= self . cycle_step < config . imaging_frame_num + config . moving_frame_num: dx , dy = self . _motor_controller . step () self . _view . move_position ( dx , dy ) if self . cycle_step == config . cycle_frame_num - 1 : pbar . update ( 1 ) if visualize: self . _view . visualize_world ( timeout = 0 if wait_key else 1 ) self . _sim_controller . on_sim_end ( self ) pbar . 
close ()","title":"run"},{"location":"reference/wtracker/sim/view_controller/","text":"Module wtracker.sim.view_controller View Source import cv2 as cv import numpy as np from wtracker.utils.frame_reader import FrameReader , FrameStream class ViewController ( FrameStream ): \"\"\" A class representing a view controller for a frame stream. This class allows for easy manipulation of the camera and microscope positions, and provides their corresponding views. Args: frame_reader (FrameReader): The frame reader object. camera_size (tuple[int, int], optional): The size of the camera frame. micro_size (tuple[int, int], optional): The size of the micro frame. init_position (tuple[int, int], optional): The initial position of the view. Attributes: frame_reader (FrameReader): The frame reader object. camera_size (tuple[int, int]): The size of the camera view (w, h). micro_size (tuple[int, int]): The size of the micro view (w, h). position (tuple[int, int]): The current position of the center of the view (x, y). \"\"\" def __init__ ( self , frame_reader : FrameReader , camera_size : tuple [ int , int ] = ( 251 , 251 ), micro_size : tuple [ int , int ] = ( 45 , 45 ), init_position : tuple [ int , int ] = ( 0 , 0 ), ): super () . __init__ ( frame_reader ) assert camera_size [ 0 ] >= micro_size [ 0 ] assert camera_size [ 1 ] >= micro_size [ 1 ] self . _padding_size : tuple [ int , int ] = ( camera_size [ 0 ] // 2 , camera_size [ 1 ] // 2 ) self . _camera_size = camera_size self . _micro_size = micro_size self . _position = init_position self . set_position ( * init_position ) def read ( self ) -> np . ndarray : \"\"\" Read a frame from the frame reader and apply padding. Returns: np.ndarray: The padded frame. \"\"\" frame = super () . read () frame = cv . copyMakeBorder ( src = frame , left = self . _padding_size [ 0 ], right = self . _padding_size [ 0 ], top = self . _padding_size [ 1 ], bottom = self . _padding_size [ 1 ], borderType = cv . BORDER_REPLICATE , ) return frame @property def position ( self ) -> tuple [ int , int ]: \"\"\" Get the current position of the view controller. Returns: tuple[int, int]: The current position (x, y). \"\"\" return self . _position @property def camera_size ( self ) -> tuple [ int , int ]: \"\"\" Get the size of the camera view. Returns: tuple[int, int]: The size of the camera view (w, h). \"\"\" return self . _camera_size @property def micro_size ( self ) -> tuple [ int , int ]: \"\"\" Get the size of the micro view. Returns: tuple[int, int]: The size of the micro view (w, h). \"\"\" return self . _micro_size @property def camera_position ( self ) -> tuple [ int , int , int , int ]: \"\"\" Get the position of the camera view. Returns: tuple[int, int, int, int]: The position of the camera view (x, y, w, h). \"\"\" w , h = self . camera_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h @property def micro_position ( self ) -> tuple [ int , int , int , int ]: \"\"\" Get the position of the micro view. Returns: tuple[int, int, int, int]: The position of the micro view (x, y, w, h). \"\"\" w , h = self . micro_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h def set_position ( self , x : int , y : int ): \"\"\" Set the position of the view controller. Note, that the position is clamped to the frame size. Args: x (int): The x-coordinate of the position. y (int): The y-coordinate of the position. \"\"\" x = np . clip ( x , 0 , self . _frame_reader . 
frame_shape [ 1 ] - 1 ) y = np . clip ( y , 0 , self . _frame_reader . frame_shape [ 0 ] - 1 ) self . _position = ( x , y ) def move_position ( self , dx : int , dy : int ): \"\"\" Move the position of the view controller by dx and dy. Args: dx (int): The amount to move in the x-direction. dy (int): The amount to move in the y-direction. \"\"\" self . set_position ( self . _position [ 0 ] + dx , self . _position [ 1 ] + dy ) def _calc_view_bbox ( self , w : int , h : int ) -> tuple [ int , int , int , int ]: \"\"\" Calculate the bbox of the view, while taking padding into account. Args: w (int): The width of the view. h (int): The height of the view. Returns: tuple[int, int, int, int]: The bounding box of the view (x, y, w, h). \"\"\" x = self . _position [ 0 ] + self . _padding_size [ 0 ] - w // 2 y = self . _position [ 1 ] + self . _padding_size [ 1 ] - h // 2 return x , y , w , h def _custom_view ( self , w : int , h : int ) -> np . ndarray : \"\"\" Get a custom view of the frame. Args: w (int): The width of the view. h (int): The height of the view. Returns: np.ndarray: The custom view of the frame. \"\"\" x , y , w , h = self . _calc_view_bbox ( w , h ) frame = self . read () slice = frame [ y : y + w , x : x + h ] return slice def camera_view ( self ) -> np . ndarray : \"\"\" Get the camera view. Returns: np.ndarray: The camera view. \"\"\" return self . _custom_view ( * self . camera_size ) def micro_view ( self ) -> np . ndarray : \"\"\" Get the micro view. Returns: np.ndarray: The micro view. \"\"\" return self . _custom_view ( * self . micro_size ) def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ): \"\"\" Visualize the world view with bounding boxes. Both the camera and micro views are visualized, along with the center point. Args: line_width (int): The width of the bounding box lines. \"\"\" x_mid , y_mid , _ , _ = self . _calc_view_bbox ( 0 , 0 ) x_cam , y_cam , w_cam , h_cam = self . _calc_view_bbox ( * self . camera_size ) x_mic , y_mic , w_mic , h_mic = self . _calc_view_bbox ( * self . micro_size ) world = self . read () if len ( self . _frame_reader . frame_shape ) == 2 : world = cv . cvtColor ( world , cv . COLOR_GRAY2BGR ) cv . rectangle ( world , ( x_cam , y_cam ), ( x_cam + w_cam , y_cam + h_cam ), ( 0 , 0 , 255 ), line_width ) cv . rectangle ( world , ( x_mic , y_mic ), ( x_mic + w_mic , y_mic + h_mic ), ( 0 , 255 , 0 ), line_width ) cv . circle ( world , ( x_mid , y_mid ), 1 , ( 255 , 0 , 0 ), line_width ) cv . imshow ( \"World View\" , world ) cv . waitKey ( timeout ) Classes ViewController class ViewController ( frame_reader : wtracker . utils . frame_reader . FrameReader , camera_size : tuple [ int , int ] = ( 251 , 251 ), micro_size : tuple [ int , int ] = ( 45 , 45 ), init_position : tuple [ int , int ] = ( 0 , 0 ) ) A class representing a view controller for a frame stream. This class allows for easy manipulation of the camera and microscope positions, and provides their corresponding views. Attributes Name Type Description Default frame_reader FrameReader The frame reader object. None camera_size tuple[int, int] The size of the camera frame. None micro_size tuple[int, int] The size of the micro frame. None init_position tuple[int, int] The initial position of the view. None frame_reader FrameReader The frame reader object. None camera_size tuple[int, int] The size of the camera view (w, h). None micro_size tuple[int, int] The size of the micro view (w, h). 
None position tuple[int, int] The current position of the center of the view (x, y). None View Source class ViewController ( FrameStream ) : \"\"\" A class representing a view controller for a frame stream . This class allows for easy manipulation of the camera and microscope positions , and provides their corresponding views . Args : frame_reader ( FrameReader ) : The frame reader object . camera_size ( tuple [ int , int ], optional ) : The size of the camera frame . micro_size ( tuple [ int , int ], optional ) : The size of the micro frame . init_position ( tuple [ int , int ], optional ) : The initial position of the view . Attributes : frame_reader ( FrameReader ) : The frame reader object . camera_size ( tuple [ int , int ]) : The size of the camera view ( w , h ). micro_size ( tuple [ int , int ]) : The size of the micro view ( w , h ). position ( tuple [ int , int ]) : The current position of the center of the view ( x , y ). \"\"\" def __init__ ( self , frame_reader : FrameReader , camera_size : tuple [ int , int ] = ( 251 , 251 ), micro_size : tuple [ int , int ] = ( 45 , 45 ), init_position : tuple [ int , int ] = ( 0 , 0 ), ) : super (). __init__ ( frame_reader ) assert camera_size [ 0 ] >= micro_size [ 0 ] assert camera_size [ 1 ] >= micro_size [ 1 ] self . _padding_size : tuple [ int , int ] = ( camera_size [ 0 ] // 2, camera_size[1] // 2) self . _camera_size = camera_size self . _micro_size = micro_size self . _position = init_position self . set_position ( * init_position ) def read ( self ) -> np . ndarray : \"\"\" Read a frame from the frame reader and apply padding . Returns : np . ndarray : The padded frame . \"\"\" frame = super (). read () frame = cv . copyMakeBorder ( src = frame , left = self . _padding_size [ 0 ], right = self . _padding_size [ 0 ], top = self . _padding_size [ 1 ], bottom = self . _padding_size [ 1 ], borderType = cv . BORDER_REPLICATE , ) return frame @property def position ( self ) -> tuple [ int , int ] : \"\"\" Get the current position of the view controller . Returns : tuple [ int , int ] : The current position ( x , y ). \"\"\" return self . _position @property def camera_size ( self ) -> tuple [ int , int ] : \"\"\" Get the size of the camera view . Returns : tuple [ int , int ] : The size of the camera view ( w , h ). \"\"\" return self . _camera_size @property def micro_size ( self ) -> tuple [ int , int ] : \"\"\" Get the size of the micro view . Returns : tuple [ int , int ] : The size of the micro view ( w , h ). \"\"\" return self . _micro_size @property def camera_position ( self ) -> tuple [ int , int , int , int ] : \"\"\" Get the position of the camera view . Returns : tuple [ int , int , int , int ] : The position of the camera view ( x , y , w , h ). \"\"\" w , h = self . camera_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h @property def micro_position ( self ) -> tuple [ int , int , int , int ] : \"\"\" Get the position of the micro view . Returns : tuple [ int , int , int , int ] : The position of the micro view ( x , y , w , h ). \"\"\" w , h = self . micro_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h def set_position ( self , x : int , y : int ) : \"\"\" Set the position of the view controller . Note , that the position is clamped to the frame size . Args : x ( int ) : The x - coordinate of the position . y ( int ) : The y - coordinate of the position . \"\"\" x = np . clip ( x , 0 , self . _frame_reader . 
frame_shape [ 1 ] - 1 ) y = np . clip ( y , 0 , self . _frame_reader . frame_shape [ 0 ] - 1 ) self . _position = ( x , y ) def move_position ( self , dx : int , dy : int ) : \"\"\" Move the position of the view controller by dx and dy . Args : dx ( int ) : The amount to move in the x - direction . dy ( int ) : The amount to move in the y - direction . \"\"\" self . set_position ( self . _position [ 0 ] + dx , self . _position [ 1 ] + dy ) def _calc_view_bbox ( self , w : int , h : int ) -> tuple [ int , int , int , int ] : \"\"\" Calculate the bbox of the view , while taking padding into account . Args : w ( int ) : The width of the view . h ( int ) : The height of the view . Returns : tuple [ int , int , int , int ] : The bounding box of the view ( x , y , w , h ). \"\"\" x = self . _position [ 0 ] + self . _padding_size [ 0 ] - w // 2 y = self . _position [ 1 ] + self . _padding_size [ 1 ] - h // 2 return x , y , w , h def _custom_view ( self , w : int , h : int ) -> np . ndarray : \"\"\" Get a custom view of the frame . Args : w ( int ) : The width of the view . h ( int ) : The height of the view . Returns : np . ndarray : The custom view of the frame . \"\"\" x , y , w , h = self . _calc_view_bbox ( w , h ) frame = self . read () slice = frame [ y : y + w , x : x + h ] return slice def camera_view ( self ) -> np . ndarray : \"\"\" Get the camera view . Returns : np . ndarray : The camera view . \"\"\" return self . _custom_view ( * self . camera_size ) def micro_view ( self ) -> np . ndarray : \"\"\" Get the micro view . Returns : np . ndarray : The micro view . \"\"\" return self . _custom_view ( * self . micro_size ) def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ) : \"\"\" Visualize the world view with bounding boxes . Both the camera and micro views are visualized , along with the center point . Args : line_width ( int ) : The width of the bounding box lines . \"\"\" x_mid , y_mid , _ , _ = self . _calc_view_bbox ( 0 , 0 ) x_cam , y_cam , w_cam , h_cam = self . _calc_view_bbox ( * self . camera_size ) x_mic , y_mic , w_mic , h_mic = self . _calc_view_bbox ( * self . micro_size ) world = self . read () if len ( self . _frame_reader . frame_shape ) == 2 : world = cv . cvtColor ( world , cv . COLOR_GRAY2BGR ) cv . rectangle ( world , ( x_cam , y_cam ), ( x_cam + w_cam , y_cam + h_cam ), ( 0 , 0 , 255 ), line_width ) cv . rectangle ( world , ( x_mic , y_mic ), ( x_mic + w_mic , y_mic + h_mic ), ( 0 , 255 , 0 ), line_width ) cv . circle ( world , ( x_mid , y_mid ), 1 , ( 255 , 0 , 0 ), line_width ) cv . imshow ( \"World View\" , world ) cv . waitKey ( timeout ) Ancestors (in MRO) wtracker.utils.frame_reader.FrameStream Instance variables camera_position Get the position of the camera view. camera_size Get the size of the camera view. index The index of the current frame. micro_position Get the position of the micro view. micro_size Get the size of the micro view. position Get the current position of the view controller. Methods camera_view def camera_view ( self ) -> numpy . ndarray Get the camera view. Returns: Type Description np.ndarray The camera view. View Source def camera_view ( self ) - > np . ndarray : \"\" \" Get the camera view. Returns: np.ndarray: The camera view. \" \"\" return self . _custom_view ( * self . camera_size ) can_read def can_read ( self ) -> 'bool' View Source def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader ) micro_view def micro_view ( self ) -> numpy . ndarray Get the micro view. 
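As a hedged aside: both views are crops of the padded frame, centered on the current position. Assuming an existing ViewController instance named view built with the default sizes shown above, the relationship is:

```python
# Sketch only: `view` is assumed to be an existing ViewController with the
# default camera_size=(251, 251) and micro_size=(45, 45).
cam = view.camera_view()   # camera-sized crop centered on view.position
mic = view.micro_view()    # micro-sized crop with the same center
assert mic.shape[0] <= cam.shape[0] and mic.shape[1] <= cam.shape[1]
```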
Returns: Type Description np.ndarray The micro view. View Source def micro_view(self) -> np.ndarray: \"\"\" Get the micro view. Returns: np.ndarray: The micro view. \"\"\" return self._custom_view(*self.micro_size) move_position def move_position ( self , dx : int , dy : int ) Move the position of the view controller by dx and dy. Parameters: Name Type Description Default dx int The amount to move in the x-direction. None dy int The amount to move in the y-direction. None View Source def move_position(self, dx: int, dy: int): \"\"\" Move the position of the view controller by dx and dy. Args: dx (int): The amount to move in the x-direction. dy (int): The amount to move in the y-direction. \"\"\" self.set_position(self._position[0] + dx, self._position[1] + dy) progress def progress ( self , n : 'int' = 1 ) -> 'bool' Moves the current index forward by the specified number of steps. Parameters: Name Type Description Default n int The number of steps to move forward. None Returns: Type Description bool True if the index was successfully moved forward, False otherwise. View Source def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . _idx + n ) read def read ( self ) -> numpy . ndarray Read a frame from the frame reader and apply padding. Returns: Type Description np.ndarray The padded frame. View Source def read ( self ) -> np . ndarray : \"\"\" Read a frame from the frame reader and apply padding. Returns: np.ndarray: The padded frame. \"\"\" frame = super (). read () frame = cv . copyMakeBorder ( src = frame , left = self . _padding_size [ 0 ], right = self . _padding_size [ 0 ], top = self . _padding_size [ 1 ], bottom = self . _padding_size [ 1 ], borderType = cv . BORDER_REPLICATE , ) return frame reset def reset ( self ) Resets the frame reader to the beginning of the stream. View Source def reset(self): \"\"\" Resets the frame reader to the beginning of the stream. \"\"\" self.seek(-1) seek def seek ( self , idx : 'int' ) -> 'bool' Move the index to the specified position. Parameters: Name Type Description Default idx int The index to seek to. None Returns: Type Description bool True if the index is within the valid range, False otherwise. View Source def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . _idx = idx self . frame = None return self . can_read () set_position def set_position ( self , x : int , y : int ) Set the position of the view controller. Note that the position is clamped to the frame size. Parameters: Name Type Description Default x int The x-coordinate of the position. None y int The y-coordinate of the position. None View Source def set_position(self, x: int, y: int): \"\"\" Set the position of the view controller. Note that the position is clamped to the frame size. Args: x (int): The x-coordinate of the position. y (int): The y-coordinate of the position. \"\"\" x = np.clip(x, 0, self._frame_reader.frame_shape[1] - 1) y = np.clip(y, 0, self._frame_reader.frame_shape[0] - 1) self._position = (x, y) visualize_world def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ) Visualize the world view with bounding boxes. 
Both the camera and micro views are visualized, along with the center point. Parameters: Name Type Description Default line_width int The width of the bounding box lines. None View Source def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ) : \"\" \" Visualize the world view with bounding boxes. Both the camera and micro views are visualized, along with the center point. Args: line_width (int): The width of the bounding box lines. \" \"\" x_mid , y_mid , _ , _ = self . _calc_view_bbox ( 0 , 0 ) x_cam , y_cam , w_cam , h_cam = self . _calc_view_bbox ( * self . camera_size ) x_mic , y_mic , w_mic , h_mic = self . _calc_view_bbox ( * self . micro_size ) world = self . read () if len ( self . _frame_reader . frame_shape ) == 2 : world = cv . cvtColor ( world , cv . COLOR_GRAY2BGR ) cv . rectangle ( world , ( x_cam , y_cam ), ( x_cam + w_cam , y_cam + h_cam ), ( 0 , 0 , 255 ), line_width ) cv . rectangle ( world , ( x_mic , y_mic ), ( x_mic + w_mic , y_mic + h_mic ), ( 0 , 255 , 0 ), line_width ) cv . circle ( world , ( x_mid , y_mid ), 1 , ( 255 , 0 , 0 ), line_width ) cv . imshow ( \"World View\" , world ) cv . waitKey ( timeout )","title":"View Controller"},{"location":"reference/wtracker/sim/view_controller/#module-wtrackersimview_controller","text":"View Source import cv2 as cv import numpy as np from wtracker.utils.frame_reader import FrameReader , FrameStream class ViewController ( FrameStream ): \"\"\" A class representing a view controller for a frame stream. This class allows for easy manipulation of the camera and microscope positions, and provides their corresponding views. Args: frame_reader (FrameReader): The frame reader object. camera_size (tuple[int, int], optional): The size of the camera frame. micro_size (tuple[int, int], optional): The size of the micro frame. init_position (tuple[int, int], optional): The initial position of the view. Attributes: frame_reader (FrameReader): The frame reader object. camera_size (tuple[int, int]): The size of the camera view (w, h). micro_size (tuple[int, int]): The size of the micro view (w, h). position (tuple[int, int]): The current position of the center of the view (x, y). \"\"\" def __init__ ( self , frame_reader : FrameReader , camera_size : tuple [ int , int ] = ( 251 , 251 ), micro_size : tuple [ int , int ] = ( 45 , 45 ), init_position : tuple [ int , int ] = ( 0 , 0 ), ): super () . __init__ ( frame_reader ) assert camera_size [ 0 ] >= micro_size [ 0 ] assert camera_size [ 1 ] >= micro_size [ 1 ] self . _padding_size : tuple [ int , int ] = ( camera_size [ 0 ] // 2 , camera_size [ 1 ] // 2 ) self . _camera_size = camera_size self . _micro_size = micro_size self . _position = init_position self . set_position ( * init_position ) def read ( self ) -> np . ndarray : \"\"\" Read a frame from the frame reader and apply padding. Returns: np.ndarray: The padded frame. \"\"\" frame = super () . read () frame = cv . copyMakeBorder ( src = frame , left = self . _padding_size [ 0 ], right = self . _padding_size [ 0 ], top = self . _padding_size [ 1 ], bottom = self . _padding_size [ 1 ], borderType = cv . BORDER_REPLICATE , ) return frame @property def position ( self ) -> tuple [ int , int ]: \"\"\" Get the current position of the view controller. Returns: tuple[int, int]: The current position (x, y). \"\"\" return self . _position @property def camera_size ( self ) -> tuple [ int , int ]: \"\"\" Get the size of the camera view. Returns: tuple[int, int]: The size of the camera view (w, h). \"\"\" return self . 
_camera_size @property def micro_size ( self ) -> tuple [ int , int ]: \"\"\" Get the size of the micro view. Returns: tuple[int, int]: The size of the micro view (w, h). \"\"\" return self . _micro_size @property def camera_position ( self ) -> tuple [ int , int , int , int ]: \"\"\" Get the position of the camera view. Returns: tuple[int, int, int, int]: The position of the camera view (x, y, w, h). \"\"\" w , h = self . camera_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h @property def micro_position ( self ) -> tuple [ int , int , int , int ]: \"\"\" Get the position of the micro view. Returns: tuple[int, int, int, int]: The position of the micro view (x, y, w, h). \"\"\" w , h = self . micro_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h def set_position ( self , x : int , y : int ): \"\"\" Set the position of the view controller. Note, that the position is clamped to the frame size. Args: x (int): The x-coordinate of the position. y (int): The y-coordinate of the position. \"\"\" x = np . clip ( x , 0 , self . _frame_reader . frame_shape [ 1 ] - 1 ) y = np . clip ( y , 0 , self . _frame_reader . frame_shape [ 0 ] - 1 ) self . _position = ( x , y ) def move_position ( self , dx : int , dy : int ): \"\"\" Move the position of the view controller by dx and dy. Args: dx (int): The amount to move in the x-direction. dy (int): The amount to move in the y-direction. \"\"\" self . set_position ( self . _position [ 0 ] + dx , self . _position [ 1 ] + dy ) def _calc_view_bbox ( self , w : int , h : int ) -> tuple [ int , int , int , int ]: \"\"\" Calculate the bbox of the view, while taking padding into account. Args: w (int): The width of the view. h (int): The height of the view. Returns: tuple[int, int, int, int]: The bounding box of the view (x, y, w, h). \"\"\" x = self . _position [ 0 ] + self . _padding_size [ 0 ] - w // 2 y = self . _position [ 1 ] + self . _padding_size [ 1 ] - h // 2 return x , y , w , h def _custom_view ( self , w : int , h : int ) -> np . ndarray : \"\"\" Get a custom view of the frame. Args: w (int): The width of the view. h (int): The height of the view. Returns: np.ndarray: The custom view of the frame. \"\"\" x , y , w , h = self . _calc_view_bbox ( w , h ) frame = self . read () slice = frame [ y : y + w , x : x + h ] return slice def camera_view ( self ) -> np . ndarray : \"\"\" Get the camera view. Returns: np.ndarray: The camera view. \"\"\" return self . _custom_view ( * self . camera_size ) def micro_view ( self ) -> np . ndarray : \"\"\" Get the micro view. Returns: np.ndarray: The micro view. \"\"\" return self . _custom_view ( * self . micro_size ) def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ): \"\"\" Visualize the world view with bounding boxes. Both the camera and micro views are visualized, along with the center point. Args: line_width (int): The width of the bounding box lines. \"\"\" x_mid , y_mid , _ , _ = self . _calc_view_bbox ( 0 , 0 ) x_cam , y_cam , w_cam , h_cam = self . _calc_view_bbox ( * self . camera_size ) x_mic , y_mic , w_mic , h_mic = self . _calc_view_bbox ( * self . micro_size ) world = self . read () if len ( self . _frame_reader . frame_shape ) == 2 : world = cv . cvtColor ( world , cv . COLOR_GRAY2BGR ) cv . rectangle ( world , ( x_cam , y_cam ), ( x_cam + w_cam , y_cam + h_cam ), ( 0 , 0 , 255 ), line_width ) cv . 
rectangle ( world , ( x_mic , y_mic ), ( x_mic + w_mic , y_mic + h_mic ), ( 0 , 255 , 0 ), line_width ) cv . circle ( world , ( x_mid , y_mid ), 1 , ( 255 , 0 , 0 ), line_width ) cv . imshow ( \"World View\" , world ) cv . waitKey ( timeout )","title":"Module wtracker.sim.view_controller"},{"location":"reference/wtracker/sim/view_controller/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/view_controller/#viewcontroller","text":"class ViewController ( frame_reader : wtracker . utils . frame_reader . FrameReader , camera_size : tuple [ int , int ] = ( 251 , 251 ), micro_size : tuple [ int , int ] = ( 45 , 45 ), init_position : tuple [ int , int ] = ( 0 , 0 ) ) A class representing a view controller for a frame stream. This class allows for easy manipulation of the camera and microscope positions, and provides their corresponding views.","title":"ViewController"},{"location":"reference/wtracker/sim/view_controller/#attributes","text":"Name Type Description Default frame_reader FrameReader The frame reader object. None camera_size tuple[int, int] The size of the camera frame. None micro_size tuple[int, int] The size of the micro frame. None init_position tuple[int, int] The initial position of the view. None frame_reader FrameReader The frame reader object. None camera_size tuple[int, int] The size of the camera view (w, h). None micro_size tuple[int, int] The size of the micro view (w, h). None position tuple[int, int] The current position of the center of the view (x, y). None View Source class ViewController ( FrameStream ) : \"\"\" A class representing a view controller for a frame stream . This class allows for easy manipulation of the camera and microscope positions , and provides their corresponding views . Args : frame_reader ( FrameReader ) : The frame reader object . camera_size ( tuple [ int , int ], optional ) : The size of the camera frame . micro_size ( tuple [ int , int ], optional ) : The size of the micro frame . init_position ( tuple [ int , int ], optional ) : The initial position of the view . Attributes : frame_reader ( FrameReader ) : The frame reader object . camera_size ( tuple [ int , int ]) : The size of the camera view ( w , h ). micro_size ( tuple [ int , int ]) : The size of the micro view ( w , h ). position ( tuple [ int , int ]) : The current position of the center of the view ( x , y ). \"\"\" def __init__ ( self , frame_reader : FrameReader , camera_size : tuple [ int , int ] = ( 251 , 251 ), micro_size : tuple [ int , int ] = ( 45 , 45 ), init_position : tuple [ int , int ] = ( 0 , 0 ), ) : super (). __init__ ( frame_reader ) assert camera_size [ 0 ] >= micro_size [ 0 ] assert camera_size [ 1 ] >= micro_size [ 1 ] self . _padding_size : tuple [ int , int ] = ( camera_size [ 0 ] // 2, camera_size[1] // 2) self . _camera_size = camera_size self . _micro_size = micro_size self . _position = init_position self . set_position ( * init_position ) def read ( self ) -> np . ndarray : \"\"\" Read a frame from the frame reader and apply padding . Returns : np . ndarray : The padded frame . \"\"\" frame = super (). read () frame = cv . copyMakeBorder ( src = frame , left = self . _padding_size [ 0 ], right = self . _padding_size [ 0 ], top = self . _padding_size [ 1 ], bottom = self . _padding_size [ 1 ], borderType = cv . BORDER_REPLICATE , ) return frame @property def position ( self ) -> tuple [ int , int ] : \"\"\" Get the current position of the view controller . Returns : tuple [ int , int ] : The current position ( x , y ). 
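A short worked example of the clamping that set_position applies (see its source further down this page); the frame shape here is a hypothetical 1080x1920 frame, not a library default.

```python
# Worked example of the np.clip clamping used by set_position: coordinates
# outside the frame are pulled back to the nearest valid pixel.
import numpy as np

frame_shape = (1080, 1920)               # hypothetical (height, width)
x, y = 2500, -40                         # requested center, out of bounds
x = np.clip(x, 0, frame_shape[1] - 1)    # clamped to 1919
y = np.clip(y, 0, frame_shape[0] - 1)    # clamped to 0
print(x, y)                              # 1919 0
```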
\"\"\" return self . _position @property def camera_size ( self ) -> tuple [ int , int ] : \"\"\" Get the size of the camera view . Returns : tuple [ int , int ] : The size of the camera view ( w , h ). \"\"\" return self . _camera_size @property def micro_size ( self ) -> tuple [ int , int ] : \"\"\" Get the size of the micro view . Returns : tuple [ int , int ] : The size of the micro view ( w , h ). \"\"\" return self . _micro_size @property def camera_position ( self ) -> tuple [ int , int , int , int ] : \"\"\" Get the position of the camera view . Returns : tuple [ int , int , int , int ] : The position of the camera view ( x , y , w , h ). \"\"\" w , h = self . camera_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h @property def micro_position ( self ) -> tuple [ int , int , int , int ] : \"\"\" Get the position of the micro view . Returns : tuple [ int , int , int , int ] : The position of the micro view ( x , y , w , h ). \"\"\" w , h = self . micro_size x = self . _position [ 0 ] - w // 2 y = self . _position [ 1 ] - h // 2 return x , y , w , h def set_position ( self , x : int , y : int ) : \"\"\" Set the position of the view controller . Note , that the position is clamped to the frame size . Args : x ( int ) : The x - coordinate of the position . y ( int ) : The y - coordinate of the position . \"\"\" x = np . clip ( x , 0 , self . _frame_reader . frame_shape [ 1 ] - 1 ) y = np . clip ( y , 0 , self . _frame_reader . frame_shape [ 0 ] - 1 ) self . _position = ( x , y ) def move_position ( self , dx : int , dy : int ) : \"\"\" Move the position of the view controller by dx and dy . Args : dx ( int ) : The amount to move in the x - direction . dy ( int ) : The amount to move in the y - direction . \"\"\" self . set_position ( self . _position [ 0 ] + dx , self . _position [ 1 ] + dy ) def _calc_view_bbox ( self , w : int , h : int ) -> tuple [ int , int , int , int ] : \"\"\" Calculate the bbox of the view , while taking padding into account . Args : w ( int ) : The width of the view . h ( int ) : The height of the view . Returns : tuple [ int , int , int , int ] : The bounding box of the view ( x , y , w , h ). \"\"\" x = self . _position [ 0 ] + self . _padding_size [ 0 ] - w // 2 y = self . _position [ 1 ] + self . _padding_size [ 1 ] - h // 2 return x , y , w , h def _custom_view ( self , w : int , h : int ) -> np . ndarray : \"\"\" Get a custom view of the frame . Args : w ( int ) : The width of the view . h ( int ) : The height of the view . Returns : np . ndarray : The custom view of the frame . \"\"\" x , y , w , h = self . _calc_view_bbox ( w , h ) frame = self . read () slice = frame [ y : y + w , x : x + h ] return slice def camera_view ( self ) -> np . ndarray : \"\"\" Get the camera view . Returns : np . ndarray : The camera view . \"\"\" return self . _custom_view ( * self . camera_size ) def micro_view ( self ) -> np . ndarray : \"\"\" Get the micro view . Returns : np . ndarray : The micro view . \"\"\" return self . _custom_view ( * self . micro_size ) def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ) : \"\"\" Visualize the world view with bounding boxes . Both the camera and micro views are visualized , along with the center point . Args : line_width ( int ) : The width of the bounding box lines . \"\"\" x_mid , y_mid , _ , _ = self . _calc_view_bbox ( 0 , 0 ) x_cam , y_cam , w_cam , h_cam = self . _calc_view_bbox ( * self . camera_size ) x_mic , y_mic , w_mic , h_mic = self . 
_calc_view_bbox ( * self . micro_size ) world = self . read () if len ( self . _frame_reader . frame_shape ) == 2 : world = cv . cvtColor ( world , cv . COLOR_GRAY2BGR ) cv . rectangle ( world , ( x_cam , y_cam ), ( x_cam + w_cam , y_cam + h_cam ), ( 0 , 0 , 255 ), line_width ) cv . rectangle ( world , ( x_mic , y_mic ), ( x_mic + w_mic , y_mic + h_mic ), ( 0 , 255 , 0 ), line_width ) cv . circle ( world , ( x_mid , y_mid ), 1 , ( 255 , 0 , 0 ), line_width ) cv . imshow ( \"World View\" , world ) cv . waitKey ( timeout )","title":"Attributes"},{"location":"reference/wtracker/sim/view_controller/#ancestors-in-mro","text":"wtracker.utils.frame_reader.FrameStream","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/view_controller/#instance-variables","text":"camera_position Get the position of the camera view. camera_size Get the size of the camera view. index The index of the current frame. micro_position Get the position of the micro view. micro_size Get the size of the micro view. position Get the current position of the view controller.","title":"Instance variables"},{"location":"reference/wtracker/sim/view_controller/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/view_controller/#camera_view","text":"def camera_view ( self ) -> numpy . ndarray Get the camera view. Returns: Type Description np.ndarray The camera view. View Source def camera_view ( self ) - > np . ndarray : \"\" \" Get the camera view. Returns: np.ndarray: The camera view. \" \"\" return self . _custom_view ( * self . camera_size )","title":"camera_view"},{"location":"reference/wtracker/sim/view_controller/#can_read","text":"def can_read ( self ) -> 'bool' View Source def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader )","title":"can_read"},{"location":"reference/wtracker/sim/view_controller/#micro_view","text":"def micro_view ( self ) -> numpy . ndarray Get the micro view. Returns: Type Description np.ndarray The micro view. View Source def micro_view(self) -> np.ndarray: \"\"\" Get the micro view. Returns: np.ndarray: The micro view. \"\"\" return self._custom_view(*self.micro_size)","title":"micro_view"},{"location":"reference/wtracker/sim/view_controller/#move_position","text":"def move_position ( self , dx : int , dy : int ) Move the position of the view controller by dx and dy. Parameters: Name Type Description Default dx int The amount to move in the x-direction. None dy int The amount to move in the y-direction. None View Source def move_position(self, dx: int, dy: int): \"\"\" Move the position of the view controller by dx and dy. Args: dx (int): The amount to move in the x-direction. dy (int): The amount to move in the y-direction. \"\"\" self.set_position(self._position[0] + dx, self._position[1] + dy)","title":"move_position"},{"location":"reference/wtracker/sim/view_controller/#progress","text":"def progress ( self , n : 'int' = 1 ) -> 'bool' Moves the current index forward by the specified number of steps. Parameters: Name Type Description Default n int The number of steps to move forward. None Returns: Type Description bool True if the index was successfully moved forward, False otherwise. View Source def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . 
_idx + n )","title":"progress"},{"location":"reference/wtracker/sim/view_controller/#read","text":"def read ( self ) -> numpy . ndarray Read a frame from the frame reader and apply padding. Returns: Type Description np.ndarray The padded frame. View Source def read ( self ) -> np . ndarray : \"\"\" Read a frame from the frame reader and apply padding. Returns: np.ndarray: The padded frame. \"\"\" frame = super (). read () frame = cv . copyMakeBorder ( src = frame , left = self . _padding_size [ 0 ], right = self . _padding_size [ 0 ], top = self . _padding_size [ 1 ], bottom = self . _padding_size [ 1 ], borderType = cv . BORDER_REPLICATE , ) return frame","title":"read"},{"location":"reference/wtracker/sim/view_controller/#reset","text":"def reset ( self ) Resets the frame reader to the beginning of the steam. View Source def reset(self): \"\"\" Resets the frame reader to the beginning of the steam. \"\"\" self.seek(-1)","title":"reset"},{"location":"reference/wtracker/sim/view_controller/#seek","text":"def seek ( self , idx : 'int' ) -> 'bool' Move the index to the specified position. Parameters: Name Type Description Default idx int The index to seek to. None Returns: Type Description bool True if the index is within the valid range, False otherwise. View Source def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . _idx = idx self . frame = None return self . can_read ()","title":"seek"},{"location":"reference/wtracker/sim/view_controller/#set_position","text":"def set_position ( self , x : int , y : int ) Set the position of the view controller. Note, that the position is clamped to the frame size. Parameters: Name Type Description Default x int The x-coordinate of the position. None y int The y-coordinate of the position. None View Source def set_position(self, x: int, y: int): \"\"\" Set the position of the view controller. Note, that the position is clamped to the frame size. Args: x (int): The x-coordinate of the position. y (int): The y-coordinate of the position. \"\"\" x = np.clip(x, 0, self._frame_reader.frame_shape[1] - 1) y = np.clip(y, 0, self._frame_reader.frame_shape[0] - 1) self._position = (x, y)","title":"set_position"},{"location":"reference/wtracker/sim/view_controller/#visualize_world","text":"def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ) Visualize the world view with bounding boxes. Both the camera and micro views are visualized, along with the center point. Parameters: Name Type Description Default line_width int The width of the bounding box lines. None View Source def visualize_world ( self , line_width : int = 4 , timeout : int = 1 ) : \"\" \" Visualize the world view with bounding boxes. Both the camera and micro views are visualized, along with the center point. Args: line_width (int): The width of the bounding box lines. \" \"\" x_mid , y_mid , _ , _ = self . _calc_view_bbox ( 0 , 0 ) x_cam , y_cam , w_cam , h_cam = self . _calc_view_bbox ( * self . camera_size ) x_mic , y_mic , w_mic , h_mic = self . _calc_view_bbox ( * self . micro_size ) world = self . read () if len ( self . _frame_reader . frame_shape ) == 2 : world = cv . cvtColor ( world , cv . COLOR_GRAY2BGR ) cv . rectangle ( world , ( x_cam , y_cam ), ( x_cam + w_cam , y_cam + h_cam ), ( 0 , 0 , 255 ), line_width ) cv . 
rectangle ( world , ( x_mic , y_mic ), ( x_mic + w_mic , y_mic + h_mic ), ( 0 , 255 , 0 ), line_width ) cv . circle ( world , ( x_mid , y_mid ), 1 , ( 255 , 0 , 0 ), line_width ) cv . imshow ( \"World View\" , world ) cv . waitKey ( timeout )","title":"visualize_world"},{"location":"reference/wtracker/sim/sim_controllers/","text":"Module wtracker.sim.sim_controllers View Source from wtracker.sim.sim_controllers.csv_controller import CsvController from wtracker.sim.sim_controllers.mlp_controllers import MLPController from wtracker.sim.sim_controllers.logging_controller import LogConfig , LoggingController from wtracker.sim.sim_controllers.optimal_controller import OptimalController from wtracker.sim.sim_controllers.polyfit_controller import PolyfitConfig , PolyfitController from wtracker.sim.sim_controllers.yolo_controller import YoloConfig , YoloController Sub-modules wtracker.sim.sim_controllers.csv_controller wtracker.sim.sim_controllers.logging_controller wtracker.sim.sim_controllers.mlp_controllers wtracker.sim.sim_controllers.optimal_controller wtracker.sim.sim_controllers.polyfit_controller wtracker.sim.sim_controllers.yolo_controller","title":"Index"},{"location":"reference/wtracker/sim/sim_controllers/#module-wtrackersimsim_controllers","text":"View Source from wtracker.sim.sim_controllers.csv_controller import CsvController from wtracker.sim.sim_controllers.mlp_controllers import MLPController from wtracker.sim.sim_controllers.logging_controller import LogConfig , LoggingController from wtracker.sim.sim_controllers.optimal_controller import OptimalController from wtracker.sim.sim_controllers.polyfit_controller import PolyfitConfig , PolyfitController from wtracker.sim.sim_controllers.yolo_controller import YoloConfig , YoloController","title":"Module wtracker.sim.sim_controllers"},{"location":"reference/wtracker/sim/sim_controllers/#sub-modules","text":"wtracker.sim.sim_controllers.csv_controller wtracker.sim.sim_controllers.logging_controller wtracker.sim.sim_controllers.mlp_controllers wtracker.sim.sim_controllers.optimal_controller wtracker.sim.sim_controllers.polyfit_controller wtracker.sim.sim_controllers.yolo_controller","title":"Sub-modules"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/","text":"Module wtracker.sim.sim_controllers.csv_controller View Source from collections import deque from typing import Collection import pandas as pd import numpy as np from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import SimController , Simulator from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat class CsvController ( SimController ): def __init__ ( self , timing_config : TimingConfig , csv_path : str ): super () . __init__ ( timing_config ) self . csv_path = csv_path self . _csv_data = pd . read_csv ( self . csv_path , usecols = [ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]) . to_numpy ( dtype = float ) self . _camera_bboxes = deque ( maxlen = timing_config . cycle_frame_num ) def on_sim_start ( self , sim : Simulator ): self . _camera_bboxes . clear () def on_camera_frame ( self , sim : Simulator ): self . _camera_bboxes . append ( sim . view . camera_position ) def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ]) worm_bboxes = np . full (( frame_nums . shape [ 0 ], 4 ), np . 
nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums [ valid_mask ], :] if not relative : return worm_bboxes # TODO: if relative == True then it works only if frame number is within the last cycle. # maybe fix that. cam_bboxes = [ self . _camera_bboxes [ n % self . timing_config . cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [:, 0 ] -= cam_bboxes [:, 0 ] worm_bboxes [:, 1 ] -= cam_bboxes [:, 1 ] return worm_bboxes def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: bbox = self . predict ([ sim . frame_number - self . timing_config . pred_frame_num ]) bbox = bbox [ 0 , :] if not np . isfinite ( bbox ) . all (): return 0 , 0 center = BoxUtils . center ( bbox ) cam_center = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( center [ 0 ] - cam_center [ 0 ]) dy = round ( center [ 1 ] - cam_center [ 1 ]) return dx , dy def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : start = ( sim . cycle_number - 1 ) * self . timing_config . cycle_frame_num end = start + self . timing_config . cycle_frame_num end = min ( end , len ( self . _csv_data )) return self . predict ( np . arange ( start , end )) Classes CsvController class CsvController ( timing_config : wtracker . sim . config . TimingConfig , csv_path : str ) Abstract base class for simulator controllers. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class CsvController ( SimController ) : def __init__ ( self , timing_config : TimingConfig , csv_path : str ) : super (). __init__ ( timing_config ) self . csv_path = csv_path self . _csv_data = pd . read_csv ( self . csv_path , usecols =[ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ). to_numpy ( dtype = float ) self . _camera_bboxes = deque ( maxlen = timing_config . cycle_frame_num ) def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear () def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position ) def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number is within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : bbox = self . predict ( [ sim.frame_number - self.timing_config.pred_frame_num ] ) bbox = bbox [ 0, : ] if not np . isfinite ( bbox ). all () : return 0 , 0 center = BoxUtils . center ( bbox ) cam_center = sim . view . camera_size [ 0 ] / 2 , sim . view . 
camera_size [ 1 ] / 2 dx = round ( center [ 0 ] - cam_center [ 0 ] ) dy = round ( center [ 1 ] - cam_center [ 1 ] ) return dx , dy def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : start = ( sim . cycle_number - 1 ) * self . timing_config . cycle_frame_num end = start + self . timing_config . cycle_frame_num end = min ( end , len ( self . _csv_data )) return self . predict ( np . arange ( start , end )) Ancestors (in MRO) wtracker.sim.simulator.SimController abc.ABC Descendants wtracker.sim.sim_controllers.mlp_controllers.MLPController wtracker.sim.sim_controllers.optimal_controller.OptimalController wtracker.sim.sim_controllers.polyfit_controller.PolyfitController Methods begin_movement_prediction def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass on_camera_frame def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position ) on_cycle_end def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass on_cycle_start def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass on_imaging_end def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass on_imaging_start def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass on_micro_frame def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass on_movement_end def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass on_movement_start def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass on_sim_end def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass on_sim_start def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear () predict def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . 
full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number if within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes provide_movement_vector def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) - > tuple [ int , int ] : bbox = self . predict ([ sim . frame_number - self . timing_config . pred_frame_num ]) bbox = bbox [ 0 , : ] if not np . isfinite ( bbox ) . all () : return 0 , 0 center = BoxUtils . center ( bbox ) cam_center = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( center [ 0 ] - cam_center [ 0 ]) dy = round ( center [ 1 ] - cam_center [ 1 ]) return dx , dy","title":"Csv Controller"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#module-wtrackersimsim_controllerscsv_controller","text":"View Source from collections import deque from typing import Collection import pandas as pd import numpy as np from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import SimController , Simulator from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat class CsvController ( SimController ): def __init__ ( self , timing_config : TimingConfig , csv_path : str ): super () . __init__ ( timing_config ) self . csv_path = csv_path self . _csv_data = pd . read_csv ( self . csv_path , usecols = [ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]) . to_numpy ( dtype = float ) self . _camera_bboxes = deque ( maxlen = timing_config . cycle_frame_num ) def on_sim_start ( self , sim : Simulator ): self . _camera_bboxes . clear () def on_camera_frame ( self , sim : Simulator ): self . _camera_bboxes . append ( sim . view . camera_position ) def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ]) worm_bboxes = np . full (( frame_nums . shape [ 0 ], 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums [ valid_mask ], :] if not relative : return worm_bboxes # TODO: if relative == True then it works only if frame number if within the last cycle. # maybe fix that. cam_bboxes = [ self . _camera_bboxes [ n % self . timing_config . cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [:, 0 ] -= cam_bboxes [:, 0 ] worm_bboxes [:, 1 ] -= cam_bboxes [:, 1 ] return worm_bboxes def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: bbox = self . predict ([ sim . frame_number - self . 
timing_config . pred_frame_num ]) bbox = bbox [ 0 , :] if not np . isfinite ( bbox ) . all (): return 0 , 0 center = BoxUtils . center ( bbox ) cam_center = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( center [ 0 ] - cam_center [ 0 ]) dy = round ( center [ 1 ] - cam_center [ 1 ]) return dx , dy def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : start = ( sim . cycle_number - 1 ) * self . timing_config . cycle_frame_num end = start + self . timing_config . cycle_frame_num end = min ( end , len ( self . _csv_data )) return self . predict ( np . arange ( start , end ))","title":"Module wtracker.sim.sim_controllers.csv_controller"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#csvcontroller","text":"class CsvController ( timing_config : wtracker . sim . config . TimingConfig , csv_path : str ) Abstract base class for simulator controllers.","title":"CsvController"},{"location":"reference/wtracker/sim/sim_controllers/csv_controller/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class CsvController ( SimController ) : def __init__ ( self , timing_config : TimingConfig , csv_path : str ) : super (). __init__ ( timing_config ) self . csv_path = csv_path self . _csv_data = pd . read_csv ( self . csv_path , usecols =[ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ). to_numpy ( dtype = float ) self . _camera_bboxes = deque ( maxlen = timing_config . cycle_frame_num ) def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear () def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position ) def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number if within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : bbox = self . predict ( [ sim.frame_number - self.timing_config.pred_frame_num ] ) bbox = bbox [ 0, : ] if not np . isfinite ( bbox ). all () : return 0 , 0 center = BoxUtils . center ( bbox ) cam_center = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( center [ 0 ] - cam_center [ 0 ] ) dy = round ( center [ 1 ] - cam_center [ 1 ] ) return dx , dy def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : start = ( sim . cycle_number - 1 ) * self . timing_config . cycle_frame_num end = start + self . timing_config . cycle_frame_num end = min ( end , len ( self . _csv_data )) return self . predict ( np . 
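To make the centering arithmetic of provide_movement_vector concrete, here is a worked example with made-up numbers (editor's addition; it assumes XYWH boxes and that BoxUtils.center returns the box midpoint):

# Worm bbox relative to the camera view, as (x, y, w, h).
bbox = (250.0, 180.0, 40.0, 20.0)
center = (bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2)  # (270.0, 190.0)

# Assuming a 640x480 camera view.
cam_center = (640 / 2, 480 / 2)  # (320.0, 240.0)

# The platform moves so that the predicted worm center lands on the camera center.
dx = round(center[0] - cam_center[0])  # -50
dy = round(center[1] - cam_center[1])  # -50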
{"location":"reference/wtracker/sim/sim_controllers/logging_controller/","text":"Module wtracker.sim.sim_controllers.logging_controller View Source
from collections import deque
import numpy as np
from dataclasses import dataclass, field
from copy import deepcopy

from wtracker.sim.simulator import Simulator, SimController
from wtracker.utils.io_utils import ImageSaver, FrameSaver
from wtracker.utils.log_utils import CSVLogger
from wtracker.utils.config_base import ConfigBase
from wtracker.utils.path_utils import join_paths, create_parent_directory
from wtracker.utils.bbox_utils import BoxUtils, BoxFormat


@dataclass
class LogConfig(ConfigBase):
    root_folder: str
    \"\"\"The directory where the logs will be saved into.\"\"\"

    save_mic_view: bool = False
    \"\"\"Whether to save the microscope view of each frame.\"\"\"

    save_cam_view: bool = False
    \"\"\"Whether to save the camera view of each frame.\"\"\"

    save_err_view: bool = True
    \"\"\"Whether to save the camera view of frames in which no prediction was made.\"\"\"

    save_wrm_view: bool = False
    \"\"\"Whether to save the detected worm head of each frame.\"\"\"

    mic_folder_name: str = \"micro\"
    cam_folder_name: str = \"camera\"
    err_folder_name: str = \"errors\"
    wrm_folder_name: str = \"worms\"

    # TODO: WHY DO WE SAVE IN PNG FORMAT AND NOT BMP?
    bbox_file_name: str = \"bboxes.csv\"
    mic_file_name: str = \"mic_{:09d}.png\"
    cam_file_name: str = \"cam_{:09d}.png\"
    wrm_file_name: str = \"wrm_{:09d}.png\"

    mic_file_path: str = field(init=False)
    cam_file_path: str = field(init=False)
    err_file_path: str = field(init=False)
    wrm_file_path: str = field(init=False)
    bbox_file_path: str = field(init=False)

    def __post_init__(self):
        self.mic_file_path = join_paths(self.root_folder, self.mic_folder_name, self.mic_file_name)
        self.cam_file_path = join_paths(self.root_folder, self.cam_folder_name, self.cam_file_name)
        self.err_file_path = join_paths(self.root_folder, self.err_folder_name, self.cam_file_name)
        self.wrm_file_path = join_paths(self.root_folder, self.wrm_folder_name, self.wrm_file_name)
        self.bbox_file_path = join_paths(self.root_folder, self.bbox_file_name)

    def create_dirs(self) -> None:
        create_parent_directory(self.bbox_file_path)
        create_parent_directory(self.mic_file_path)
        create_parent_directory(self.cam_file_path)
        create_parent_directory(self.err_file_path)
        create_parent_directory(self.wrm_file_path)


class LoggingController(SimController):
    def __init__(self, sim_controller: SimController, log_config: LogConfig):
        super().__init__(sim_controller.timing_config)
        self.sim_controller = sim_controller
        self.log_config = log_config
        self._camera_frames = deque(maxlen=self.timing_config.cycle_frame_num)
        self._platform_positions = deque(maxlen=self.timing_config.cycle_frame_num)
        self._camera_bboxes = deque(maxlen=self.timing_config.cycle_frame_num)
        self._micro_bboxes = deque(maxlen=self.timing_config.cycle_frame_num)

    def on_sim_start(self, sim: Simulator):
        self.sim_controller.on_sim_start(sim)
        self._camera_frames.clear()
        self._platform_positions.clear()
        self._camera_bboxes.clear()
        self._micro_bboxes.clear()
        self.log_config.create_dirs()
        self._image_saver = ImageSaver(tqdm=True)
        self._image_saver.start()
        self._frame_saver = FrameSaver(deepcopy(sim.view._frame_reader), tqdm=True)
        self._frame_saver.start()
        self._bbox_logger = CSVLogger(
            self.log_config.bbox_file_path,
            col_names=[
                \"frame\", \"cycle\", \"phase\",
                \"plt_x\", \"plt_y\",
                \"cam_x\", \"cam_y\", \"cam_w\", \"cam_h\",
                \"mic_x\", \"mic_y\", \"mic_w\", \"mic_h\",
                \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\",
            ],
        )

    def on_cycle_start(self, sim: Simulator):
        self.sim_controller.on_cycle_start(sim)

    def on_camera_frame(self, sim: Simulator):
        self.sim_controller.on_camera_frame(sim)
        # log everything
        self._platform_positions.append(sim.position)
        self._camera_bboxes.append(sim.view.camera_position)
        self._micro_bboxes.append(sim.view.micro_position)
        if self.log_config.save_err_view:
            cam_view = sim.camera_view()
            self._camera_frames.append(cam_view)
        if self.log_config.save_cam_view:
            # save camera view
            cam_view = sim.camera_view()
            path = self.log_config.cam_file_path.format(sim.frame_number)
            self._image_saver.schedule_save(cam_view, path)
        if self.log_config.save_mic_view:
            # save micro view
            mic_view = sim.view.micro_view()
            path = self.log_config.mic_file_path.format(sim.frame_number)
            self._image_saver.schedule_save(mic_view, path)

    def _log_cycle(self, sim: Simulator):
        cycle_number = sim.cycle_number - 1
        frame_offset = cycle_number * self.timing_config.cycle_frame_num
        worm_bboxes = self.sim_controller._cycle_predict_all(sim)
        cam_bboxes = np.asanyarray(list(self._camera_bboxes))
        # make worm bbox coordinates absolute
        worm_bboxes[:, 0] += cam_bboxes[:, 0]
        worm_bboxes[:, 1] += cam_bboxes[:, 1]
        # calc the crop dims to get the worm view from the original frame
        (H, W) = sim.experiment_config.orig_resolution
        crop_dims, is_crop_legal = BoxUtils.discretize(worm_bboxes, (H, W), BoxFormat.XYWH)
        for i, worm_bbox in enumerate(worm_bboxes):
            frame_number = frame_offset + i
            # if no prediction was made and we're saving error frames
            if not np.isfinite(worm_bbox).all() and self.log_config.save_err_view:
                err_view = self._camera_frames[i]
                path = self.log_config.err_file_path.format(frame_number)
                self._image_saver.schedule_save(img=err_view, img_name=path)
            # save cropped worm view if crop is legal
            if self.log_config.save_wrm_view and is_crop_legal[i]:
                crop_dim = crop_dims[i]
                path = self.log_config.wrm_file_path.format(frame_number)
                self._frame_saver.schedule_save(img_index=frame_number, crop_dims=crop_dim, img_name=path)
            csv_row = {}
            csv_row[\"plt_x\"], csv_row[\"plt_y\"] = self._platform_positions[i]
            csv_row[\"cam_x\"], csv_row[\"cam_y\"], csv_row[\"cam_w\"], csv_row[\"cam_h\"] = self._camera_bboxes[i]
            csv_row[\"mic_x\"], csv_row[\"mic_y\"], csv_row[\"mic_w\"], csv_row[\"mic_h\"] = self._micro_bboxes[i]
            csv_row[\"cycle\"] = cycle_number
            csv_row[\"frame\"] = frame_number
            csv_row[\"phase\"] = \"imaging\" if i < self.timing_config.imaging_frame_num else \"moving\"
            csv_row[\"wrm_x\"], csv_row[\"wrm_y\"], csv_row[\"wrm_w\"], csv_row[\"wrm_h\"] = worm_bbox
            self._bbox_logger.write(csv_row)
        self._bbox_logger.flush()

    def on_cycle_end(self, sim: Simulator):
        self._log_cycle(sim)
        self.sim_controller.on_cycle_end(sim)
        self._camera_frames.clear()
        self._platform_positions.clear()
        self._camera_bboxes.clear()
        self._micro_bboxes.clear()

    def on_sim_end(self, sim: Simulator):
        self.sim_controller.on_sim_end(sim)
        self._image_saver.close()
        self._frame_saver.close()
        self._bbox_logger.close()

    def on_imaging_start(self, sim: Simulator):
        self.sim_controller.on_imaging_start(sim)

    def on_micro_frame(self, sim: Simulator):
        self.sim_controller.on_micro_frame(sim)

    def on_imaging_end(self, sim: Simulator):
        self.sim_controller.on_imaging_end(sim)

    def on_movement_start(self, sim: Simulator):
        self.sim_controller.on_movement_start(sim)

    def on_movement_end(self, sim: Simulator):
        self.sim_controller.on_movement_end(sim)

    def begin_movement_prediction(self, sim: Simulator) -> None:
        return self.sim_controller.begin_movement_prediction(sim)

    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        return self.sim_controller.provide_movement_vector(sim)

    def _cycle_predict_all(self, sim: Simulator) -> np.ndarray:
        return self.sim_controller._cycle_predict_all(sim)

Classes
LogConfig
class LogConfig(root_folder: str, save_mic_view: bool = False, save_cam_view: bool = False, save_err_view: bool = True, save_wrm_view: bool = False, mic_folder_name: str = 'micro', cam_folder_name: str = 'camera', err_folder_name: str = 'errors', wrm_folder_name: str = 'worms', bbox_file_name: str = 'bboxes.csv', mic_file_name: str = 'mic_{:09d}.png', cam_file_name: str = 'cam_{:09d}.png', wrm_file_name: str = 'wrm_{:09d}.png')
Holds the logging flags, folder layout, and file-name templates; the full output paths are derived in __post_init__ (see the dataclass source above).
Ancestors (in MRO): wtracker.utils.config_base.ConfigBase
Class variables: bbox_file_name, cam_file_name, cam_folder_name, err_folder_name, mic_file_name, mic_folder_name, save_cam_view, save_err_view, save_mic_view, save_wrm_view, wrm_file_name, wrm_folder_name
Static methods:
def load_json(path: str = None) -> T: Load the class from a JSON file. Parameters: path (str): The path to the JSON file; if None, the user is prompted for one. Returns: ConfigBase: The class loaded from the JSON file. View Source:
@classmethod
def load_json(cls: type[T], path: str = None) -> T:
    if path is None:
        path = UserPrompt.open_file(title=f\"Open {cls.__name__} File\", file_types=[(\"json\", \".json\")])
    with open(path, \"r\") as f:
        data = json.load(f)
    obj = cls.__new__(cls)
    obj.__dict__.update(data)
    return obj
def load_pickle(path: str = None) -> T: Load the class from a pickle file. Parameters: path (str): The path to the pickle file; if None, the user is prompted for one. Returns: The class loaded from the pickle file. View Source:
@classmethod
def load_pickle(cls: type[T], path: str = None) -> T:
    if path is None:
        path = UserPrompt.open_file(title=f\"Open {cls.__name__} File\", file_types=[(\"pickle\", \".pkl\")])
    return pickle_load_object(path)
Methods:
def create_dirs(self) -> None: Creates the parent directory of every configured output path (see the dataclass source above).
def save_json(self, path: str = None): Saves the class as a JSON file. Parameters: path (str): The path to the output JSON file; if None, the user is prompted for one. View Source:
def save_json(self, path: str = None):
    if path is None:
        path = UserPrompt.save_file(title=f\"Save {type(self).__name__} As\", file_types=[(\"json\", \".json\")], defaultextension=\".json\")
    with open(path, \"w\") as f:
        json.dump(self.__dict__, f, indent=4)
def save_pickle(self, path: str = None) -> None: Saves the class as a pickle file. Parameters: path (str): The path to the output pickle file; if None, the user is prompted for one. View Source:
def save_pickle(self, path: str = None) -> None:
    if path is None:
        path = UserPrompt.save_file(title=f\"Save {type(self).__name__} As\", file_types=[(\"pickle\", \".pkl\")], defaultextension=\".pkl\")
    pickle_save_object(self, path)
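A short sketch of how LogConfig is typically constructed and persisted (editor's addition; the root folder is a placeholder, and the printed paths assume join_paths behaves like os.path.join on a POSIX system):

from wtracker.sim.sim_controllers.logging_controller import LogConfig

log_config = LogConfig(root_folder="logs/exp1", save_cam_view=True)
log_config.create_dirs()  # creates logs/exp1/, logs/exp1/camera/, logs/exp1/micro/, ...

print(log_config.bbox_file_path)            # logs/exp1/bboxes.csv
print(log_config.cam_file_path.format(42))  # logs/exp1/camera/cam_000000042.png

# Round trip through JSON: save_json dumps __dict__ (including the derived
# paths), and load_json restores it without re-running __post_init__.
log_config.save_json("logs/exp1/log_config.json")
restored = LogConfig.load_json("logs/exp1/log_config.json")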
LoggingController
class LoggingController(sim_controller: wtracker.sim.simulator.SimController, log_config: wtracker.sim.sim_controllers.logging_controller.LogConfig)
A SimController wrapper that delegates every event to the wrapped sim_controller while logging: on_camera_frame records the platform, camera, and micro positions and schedules view saving; on_cycle_end writes one CSV row per frame of the finished cycle; and on_sim_end closes the image savers and the CSV logger (see the module source above).
Attributes: timing_config (TimingConfig): The timing configuration for the simulator. Default: None.
Ancestors (in MRO): wtracker.sim.simulator.SimController, abc.ABC
Methods (all forward to the wrapped controller; implementations appear in the module source above):
def begin_movement_prediction(self, sim: Simulator) -> None: Called when the movement prediction begins.
def on_camera_frame(self, sim: Simulator): Called when a camera frame is captured. Happens every frame.
def on_cycle_end(self, sim: Simulator): Called when a cycle ends.
def on_cycle_start(self, sim: Simulator): Called when a new cycle starts.
def on_imaging_end(self, sim: Simulator): Called when the imaging phase ends.
def on_imaging_start(self, sim: Simulator): Called when the imaging phase starts.
def on_micro_frame(self, sim: Simulator): Called when a micro frame is captured. Happens for every frame during the imaging phase.
def on_movement_end(self, sim: Simulator): Called when the movement phase ends.
def on_movement_start(self, sim: Simulator): Called when the movement phase starts.
def on_sim_end(self, sim: Simulator): Called when the simulation ends.
def on_sim_start(self, sim: Simulator): Called when the simulation starts.
def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]: Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: tuple[int, int]: The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction.","title":"Logging Controller"}
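To log an experiment, a controller is wrapped rather than subclassed. A sketch under the same placeholder paths as above (editor's addition; the exact Simulator wiring is not shown on this page and is marked hypothetical):

from wtracker.sim.sim_controllers.csv_controller import CsvController
from wtracker.sim.sim_controllers.logging_controller import LogConfig, LoggingController

inner = CsvController(timing_config, csv_path="logs/bboxes.csv")  # timing_config as loaded earlier
logger = LoggingController(inner, LogConfig(root_folder="logs/exp1", save_err_view=True))

# Every event hook is forwarded to `inner`; per-frame bboxes end up in
# logs/exp1/bboxes.csv, and error frames (no prediction) in logs/exp1/errors/.
# sim = Simulator(..., controller=logger)  # hypothetical wiring
# sim.run()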
Simulator ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): self.sim_controller.on_micro_frame(sim)","title":"on_micro_frame"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#on_movement_end","text":"def on_movement_end ( self , sim : wtracker . sim . simulator . Simulator ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): self.sim_controller.on_movement_end(sim)","title":"on_movement_end"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#on_movement_start","text":"def on_movement_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): self.sim_controller.on_movement_start(sim)","title":"on_movement_start"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#on_sim_end","text":"def on_sim_end ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): self.sim_controller.on_sim_end(sim) self._image_saver.close() self._frame_saver.close() self._bbox_logger.close()","title":"on_sim_end"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#on_sim_start","text":"def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . sim_controller . on_sim_start ( sim ) self . _camera_frames . clear () self . _platform_positions . clear () self . _camera_bboxes . clear () self . _micro_bboxes . clear () self . log_config . create_dirs () self . _image_saver = ImageSaver ( tqdm = True ) self . _image_saver . start () self . _frame_saver = FrameSaver ( deepcopy ( sim . view . _frame_reader ), tqdm = True ) self . _frame_saver . start () self . _bbox_logger = CSVLogger ( self . log_config . bbox_file_path , col_names = [ \"frame\" , \"cycle\" , \"phase\" , \"plt_x\" , \"plt_y\" , \"cam_x\" , \"cam_y\" , \"cam_w\" , \"cam_h\" , \"mic_x\" , \"mic_y\" , \"mic_w\" , \"mic_h\" , \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" , ], )","title":"on_sim_start"},{"location":"reference/wtracker/sim/sim_controllers/logging_controller/#provide_movement_vector","text":"def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ] : return self . sim_controller . 
provide_movement_vector ( sim )","title":"provide_movement_vector"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/","text":"Module wtracker.sim.sim_controllers.mlp_controllers View Source from typing import Collection import numpy as np from collections import deque from torch import Tensor from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import Simulator from wtracker.sim.sim_controllers.csv_controller import CsvController from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat from wtracker.neural.mlp import WormPredictor from wtracker.neural.config import IOConfig class MLPController ( CsvController ): \"\"\" MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation. Args: timing_config (TimingConfig): The timing configuration for the simulation. csv_path (str): The path to the CSV file containing the simulation data. model (WormPredictor): The WormPredictor model used for predicting worm movement. max_speed (float): max speed of the worm in mm/s, predictions above this will be clipped. \"\"\" def __init__ ( self , timing_config : TimingConfig , csv_path : str , model : WormPredictor , max_speed : float = 0.9 ): super () . __init__ ( timing_config , csv_path ) self . model : WormPredictor = model self . io_config : IOConfig = model . io_config self . model . eval () px_per_mm = self . timing_config . px_per_mm fps = self . timing_config . frames_per_sec max_speed_px_frame = max_speed * ( px_per_mm / fps ) self . max_dist_per_pred = max_speed_px_frame * ( self . io_config . pred_frames [ 0 ]) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # frames for prediction (input to the model) frames_for_pred = np . asanyarray ( self . io_config . input_frames , dtype = int ) frames_for_pred += sim . frame_number - self . timing_config . pred_frame_num cam_center = BoxUtils . center ( np . asanyarray ( sim . view . camera_position )) worm_bboxes = self . predict ( frames_for_pred , relative = False ) . reshape ( 1 , - 1 ) if not np . isfinite ( worm_bboxes ) . all (): return 0 , 0 # relative position of the worm to the camera center, we use the worm x,y instead of center because of how the model and dataset are built rel_x , rel_y = worm_bboxes [ 0 , 0 ] - cam_center [ 0 ], worm_bboxes [ 0 , 1 ] - cam_center [ 1 ] # make coordinates relative to first bbox x = worm_bboxes [ 0 , 0 ] x_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 0 y = worm_bboxes [ 0 , 1 ] y_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 1 worm_bboxes [:, x_mask ] -= x worm_bboxes [:, y_mask ] -= y # predict the movement of the worm via the model pred = self . model . forward ( Tensor ( worm_bboxes )) . flatten () . detach () . numpy () # make sure the prediction is within the limits and apply post-processing steps pred = np . clip ( pred , - self . max_dist_per_pred , self . max_dist_per_pred ) dx = round ( pred [ 0 ] . item () + rel_x ) dy = round ( pred [ 1 ] . item () + rel_y ) # dx = np.clip(dx, -self.max_dist_per_pred, self.max_dist_per_pred) # dy = np.clip(dy, -self.max_dist_per_pred, self.max_dist_per_pred) return ( dx , dy ) def print_model ( self ): print ( self . model ) Classes MLPController class MLPController ( timing_config : wtracker . sim . config . TimingConfig , csv_path : str , model : wtracker . neural . mlp . 
WormPredictor , max_speed : float = 0.9 ) MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the simulation. None csv_path str The path to the CSV file containing the simulation data. None model WormPredictor The WormPredictor model used for predicting worm movement. None max_speed float max speed of the worm in mm/s, predictions above this will be clipped. None View Source class MLPController ( CsvController ): \"\"\" MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation. Args: timing_config (TimingConfig): The timing configuration for the simulation. csv_path (str): The path to the CSV file containing the simulation data. model (WormPredictor): The WormPredictor model used for predicting worm movement. max_speed (float): max speed of the worm in mm/s, predictions above this will be clipped. \"\"\" def __init__ ( self , timing_config : TimingConfig , csv_path : str , model : WormPredictor , max_speed : float = 0.9 ): super (). __init__ ( timing_config , csv_path ) self . model : WormPredictor = model self . io_config : IOConfig = model . io_config self . model . eval () px_per_mm = self . timing_config . px_per_mm fps = self . timing_config . frames_per_sec max_speed_px_frame = max_speed * ( px_per_mm / fps ) self . max_dist_per_pred = max_speed_px_frame * ( self . io_config . pred_frames [ 0 ]) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # frames for prediction ( input to the model ) frames_for_pred = np . asanyarray ( self . io_config . input_frames , dtype = int ) frames_for_pred += sim . frame_number - self . timing_config . pred_frame_num cam_center = BoxUtils . center ( np . asanyarray ( sim . view . camera_position )) worm_bboxes = self . predict ( frames_for_pred , relative = False ). reshape ( 1 , - 1 ) if not np . isfinite ( worm_bboxes ). all (): return 0 , 0 # relative position of the worm to the camera center , we use the worm x , y instead of center because of how the model and dataset are built rel_x , rel_y = worm_bboxes [ 0 , 0 ] - cam_center [ 0 ], worm_bboxes [ 0 , 1 ] - cam_center [ 1 ] # make coordinates relative to first bbox x = worm_bboxes [ 0 , 0 ] x_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 0 y = worm_bboxes [ 0 , 1 ] y_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 1 worm_bboxes [:, x_mask ] -= x worm_bboxes [:, y_mask ] -= y # predict the movement of the worm via the model pred = self . model . forward ( Tensor ( worm_bboxes )). flatten (). detach (). numpy () # make sure the prediction is within the limits and apply post - processing steps pred = np . clip ( pred , - self . max_dist_per_pred , self . max_dist_per_pred ) dx = round ( pred [ 0 ]. item () + rel_x ) dy = round ( pred [ 1 ]. item () + rel_y ) # dx = np . clip ( dx , - self . max_dist_per_pred , self . max_dist_per_pred ) # dy = np . clip ( dy , - self . max_dist_per_pred , self . max_dist_per_pred ) return ( dx , dy ) def print_model ( self ): print ( self . model ) Ancestors (in MRO) wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.simulator.SimController abc.ABC Methods begin_movement_prediction def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. 
View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass on_camera_frame def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position ) on_cycle_end def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass on_cycle_start def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass on_imaging_end def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass on_imaging_start def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass on_micro_frame def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass on_movement_end def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass on_movement_start def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass on_sim_end def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass on_sim_start def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear () predict def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number is within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes print_model def print_model ( self ) View Source def print_model(self): print(self.model) provide_movement_vector def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. 
The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # frames for prediction ( input to the model ) frames_for_pred = np . asanyarray ( self . io_config . input_frames , dtype = int ) frames_for_pred += sim . frame_number - self . timing_config . pred_frame_num cam_center = BoxUtils . center ( np . asanyarray ( sim . view . camera_position )) worm_bboxes = self . predict ( frames_for_pred , relative = False ). reshape ( 1 , - 1 ) if not np . isfinite ( worm_bboxes ). all (): return 0 , 0 # relative position of the worm to the camera center , we use the worm x , y instead of center because of how the model and dataset are built rel_x , rel_y = worm_bboxes [ 0 , 0 ] - cam_center [ 0 ], worm_bboxes [ 0 , 1 ] - cam_center [ 1 ] # make coordinates relative to first bbox x = worm_bboxes [ 0 , 0 ] x_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 0 y = worm_bboxes [ 0 , 1 ] y_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 1 worm_bboxes [:, x_mask ] -= x worm_bboxes [:, y_mask ] -= y # predict the movement of the worm via the model pred = self . model . forward ( Tensor ( worm_bboxes )). flatten (). detach (). numpy () # make sure the prediction is within the limits and apply post - processing steps pred = np . clip ( pred , - self . max_dist_per_pred , self . max_dist_per_pred ) dx = round ( pred [ 0 ]. item () + rel_x ) dy = round ( pred [ 1 ]. item () + rel_y ) # dx = np . clip ( dx , - self . max_dist_per_pred , self . max_dist_per_pred ) # dy = np . clip ( dy , - self . max_dist_per_pred , self . max_dist_per_pred ) return ( dx , dy )","title":"provide_movement_vector"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#module-wtrackersimsim_controllersmlp_controllers","text":"View Source from typing import Collection import numpy as np from collections import deque from torch import Tensor from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import Simulator from wtracker.sim.sim_controllers.csv_controller import CsvController from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat from wtracker.neural.mlp import WormPredictor from wtracker.neural.config import IOConfig class MLPController ( CsvController ): \"\"\" MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation. Args: timing_config (TimingConfig): The timing configuration for the simulation. csv_path (str): The path to the CSV file containing the simulation data. model (WormPredictor): The WormPredictor model used for predicting worm movement. max_speed (float): max speed of the worm in mm/s, predictions above this will be clipped. \"\"\" def __init__ ( self , timing_config : TimingConfig , csv_path : str , model : WormPredictor , max_speed : float = 0.9 ): super () . __init__ ( timing_config , csv_path ) self . model : WormPredictor = model self . io_config : IOConfig = model . io_config self . model . eval () px_per_mm = self . timing_config . px_per_mm fps = self . timing_config . frames_per_sec max_speed_px_frame = max_speed * ( px_per_mm / fps ) self . max_dist_per_pred = max_speed_px_frame * ( self . io_config . 
pred_frames [ 0 ]) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # frames for prediction (input to the model) frames_for_pred = np . asanyarray ( self . io_config . input_frames , dtype = int ) frames_for_pred += sim . frame_number - self . timing_config . pred_frame_num cam_center = BoxUtils . center ( np . asanyarray ( sim . view . camera_position )) worm_bboxes = self . predict ( frames_for_pred , relative = False ) . reshape ( 1 , - 1 ) if not np . isfinite ( worm_bboxes ) . all (): return 0 , 0 # relative position of the worm to the camera center, we use the worm x,y instead of center because of how the model and dataset are built rel_x , rel_y = worm_bboxes [ 0 , 0 ] - cam_center [ 0 ], worm_bboxes [ 0 , 1 ] - cam_center [ 1 ] # make coordinates relative to first bbox x = worm_bboxes [ 0 , 0 ] x_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 0 y = worm_bboxes [ 0 , 1 ] y_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 1 worm_bboxes [:, x_mask ] -= x worm_bboxes [:, y_mask ] -= y # predict the movement of the worm via the model pred = self . model . forward ( Tensor ( worm_bboxes )) . flatten () . detach () . numpy () # make sure the prediction is within the limits and apply post-processing steps pred = np . clip ( pred , - self . max_dist_per_pred , self . max_dist_per_pred ) dx = round ( pred [ 0 ] . item () + rel_x ) dy = round ( pred [ 1 ] . item () + rel_y ) # dx = np.clip(dx, -self.max_dist_per_pred, self.max_dist_per_pred) # dy = np.clip(dy, -self.max_dist_per_pred, self.max_dist_per_pred) return ( dx , dy ) def print_model ( self ): print ( self . model )","title":"Module wtracker.sim.sim_controllers.mlp_controllers"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#mlpcontroller","text":"class MLPController ( timing_config : wtracker . sim . config . TimingConfig , csv_path : str , model : wtracker . neural . mlp . WormPredictor , max_speed : float = 0.9 ) MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation.","title":"MLPController"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the simulation. None csv_path str The path to the CSV file containing the simulation data. None model WormPredictor The WormPredictor model used for predicting worm movement. None max_speed float max speed of the worm in mm/s, predictions above this will be clipped. None View Source class MLPController ( CsvController ): \"\"\" MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation. Args: timing_config (TimingConfig): The timing configuration for the simulation. csv_path (str): The path to the CSV file containing the simulation data. model (WormPredictor): The WormPredictor model used for predicting worm movement. max_speed (float): max speed of the worm in mm/s, predictions above this will be clipped. \"\"\" def __init__ ( self , timing_config : TimingConfig , csv_path : str , model : WormPredictor , max_speed : float = 0.9 ): super (). __init__ ( timing_config , csv_path ) self . model : WormPredictor = model self . io_config : IOConfig = model . io_config self . model . eval () px_per_mm = self . timing_config . px_per_mm fps = self . timing_config . 
frames_per_sec max_speed_px_frame = max_speed * ( px_per_mm / fps ) self . max_dist_per_pred = max_speed_px_frame * ( self . io_config . pred_frames [ 0 ]) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # frames for prediction ( input to the model ) frames_for_pred = np . asanyarray ( self . io_config . input_frames , dtype = int ) frames_for_pred += sim . frame_number - self . timing_config . pred_frame_num cam_center = BoxUtils . center ( np . asanyarray ( sim . view . camera_position )) worm_bboxes = self . predict ( frames_for_pred , relative = False ). reshape ( 1 , - 1 ) if not np . isfinite ( worm_bboxes ). all (): return 0 , 0 # relative position of the worm to the camera center , we use the worm x , y instead of center because of how the model and dataset are built rel_x , rel_y = worm_bboxes [ 0 , 0 ] - cam_center [ 0 ], worm_bboxes [ 0 , 1 ] - cam_center [ 1 ] # make coordinates relative to first bbox x = worm_bboxes [ 0 , 0 ] x_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 0 y = worm_bboxes [ 0 , 1 ] y_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 1 worm_bboxes [:, x_mask ] -= x worm_bboxes [:, y_mask ] -= y # predict the movement of the worm via the model pred = self . model . forward ( Tensor ( worm_bboxes )). flatten (). detach (). numpy () # make sure the prediction is within the limits and apply post - processing steps pred = np . clip ( pred , - self . max_dist_per_pred , self . max_dist_per_pred ) dx = round ( pred [ 0 ]. item () + rel_x ) dy = round ( pred [ 1 ]. item () + rel_y ) # dx = np . clip ( dx , - self . max_dist_per_pred , self . max_dist_per_pred ) # dy = np . clip ( dy , - self . max_dist_per_pred , self . max_dist_per_pred ) return ( dx , dy ) def print_model ( self ): print ( self . model )","title":"Attributes"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#ancestors-in-mro","text":"wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.simulator.SimController abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#begin_movement_prediction","text":"def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass","title":"begin_movement_prediction"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_camera_frame","text":"def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position )","title":"on_camera_frame"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_cycle_end","text":"def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass","title":"on_cycle_end"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_cycle_start","text":"def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. 
\"\"\" pass","title":"on_cycle_start"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_imaging_end","text":"def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass","title":"on_imaging_end"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_imaging_start","text":"def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass","title":"on_imaging_start"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_micro_frame","text":"def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass","title":"on_micro_frame"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_movement_end","text":"def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass","title":"on_movement_end"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_movement_start","text":"def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass","title":"on_movement_start"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_sim_end","text":"def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass","title":"on_sim_end"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#on_sim_start","text":"def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear ()","title":"on_sim_start"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#predict","text":"def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number if within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . 
asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes","title":"predict"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#print_model","text":"def print_model ( self ) View Source def print_model(self): print(self.model)","title":"print_model"},{"location":"reference/wtracker/sim/sim_controllers/mlp_controllers/#provide_movement_vector","text":"def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # frames for prediction ( input to the model ) frames_for_pred = np . asanyarray ( self . io_config . input_frames , dtype = int ) frames_for_pred += sim . frame_number - self . timing_config . pred_frame_num cam_center = BoxUtils . center ( np . asanyarray ( sim . view . camera_position )) worm_bboxes = self . predict ( frames_for_pred , relative = False ). reshape ( 1 , - 1 ) if not np . isfinite ( worm_bboxes ). all (): return 0 , 0 # relative position of the worm to the camera center , we use the worm x , y instead of center because of how the model and dataset are built rel_x , rel_y = worm_bboxes [ 0 , 0 ] - cam_center [ 0 ], worm_bboxes [ 0 , 1 ] - cam_center [ 1 ] # make coordinates relative to first bbox x = worm_bboxes [ 0 , 0 ] x_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 0 y = worm_bboxes [ 0 , 1 ] y_mask = np . arange ( 0 , worm_bboxes . shape [ 1 ]) % 4 == 1 worm_bboxes [:, x_mask ] -= x worm_bboxes [:, y_mask ] -= y # predict the movement of the worm via the model pred = self . model . forward ( Tensor ( worm_bboxes )). flatten (). detach (). numpy () # make sure the prediction is within the limits and apply post - processing steps pred = np . clip ( pred , - self . max_dist_per_pred , self . max_dist_per_pred ) dx = round ( pred [ 0 ]. item () + rel_x ) dy = round ( pred [ 1 ]. item () + rel_y ) # dx = np . clip ( dx , - self . max_dist_per_pred , self . max_dist_per_pred ) # dy = np . clip ( dy , - self . max_dist_per_pred , self . max_dist_per_pred ) return ( dx , dy )","title":"provide_movement_vector"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/","text":"Module wtracker.sim.sim_controllers.optimal_controller View Source import numpy as np from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import Simulator from wtracker.sim.sim_controllers.csv_controller import CsvController class OptimalController ( CsvController ): def __init__ ( self , timing_config : TimingConfig , csv_path : str ): super () . __init__ ( timing_config , csv_path ) self . _csv_centers = np . empty (( len ( self . _csv_data ), 2 ), dtype = self . _csv_data . dtype ) self . _csv_centers [:, 0 ] = self . _csv_data [:, 0 ] + self . _csv_data [:, 2 ] / 2 self . _csv_centers [:, 1 ] = self . _csv_data [:, 1 ] + self . _csv_data [:, 3 ] / 2 def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # extract portion matching next imaging phase next_imaging_start = ( sim . cycle_number + 1 ) * self . timing_config . 
cycle_frame_num next_imaging_end = next_imaging_start + self . timing_config . imaging_frame_num next_imaging = self . _csv_centers [ next_imaging_start : next_imaging_end , :] next_imaging = next_imaging [ np . isfinite ( next_imaging ) . all ( axis = 1 )] if len ( next_imaging ) == 0 : return 0 , 0 x_next , y_next = np . median ( next_imaging , axis = 0 ) cam_x , cam_y , cam_w , cam_h = sim . view . camera_position cam_mid = cam_x + cam_w / 2 , cam_y + cam_h / 2 return round ( x_next - cam_mid [ 0 ]), round ( y_next - cam_mid [ 1 ]) Classes OptimalController class OptimalController ( timing_config : wtracker . sim . config . TimingConfig , csv_path : str ) Abstract base class for simulator controllers. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class OptimalController ( CsvController ): def __init__ ( self , timing_config : TimingConfig , csv_path : str ): super (). __init__ ( timing_config , csv_path ) self . _csv_centers = np . empty (( len ( self . _csv_data ), 2 ), dtype = self . _csv_data . dtype ) self . _csv_centers [:, 0 ] = self . _csv_data [:, 0 ] + self . _csv_data [:, 2 ] / 2 self . _csv_centers [:, 1 ] = self . _csv_data [:, 1 ] + self . _csv_data [:, 3 ] / 2 def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # extract portion matching next imaging phase next_imaging_start = ( sim . cycle_number + 1 ) * self . timing_config . cycle_frame_num next_imaging_end = next_imaging_start + self . timing_config . imaging_frame_num next_imaging = self . _csv_centers [ next_imaging_start : next_imaging_end , :] next_imaging = next_imaging [ np . isfinite ( next_imaging ). all ( axis = 1 )] if len ( next_imaging ) == 0 : return 0 , 0 x_next , y_next = np . median ( next_imaging , axis = 0 ) cam_x , cam_y , cam_w , cam_h = sim . view . camera_position cam_mid = cam_x + cam_w / 2 , cam_y + cam_h / 2 return round ( x_next - cam_mid [ 0 ]), round ( y_next - cam_mid [ 1 ]) Ancestors (in MRO) wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.simulator.SimController abc.ABC Methods begin_movement_prediction def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass on_camera_frame def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position ) on_cycle_end def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass on_cycle_start def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass on_imaging_end def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass on_imaging_start def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. 
\"\"\" pass on_micro_frame def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass on_movement_end def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass on_movement_start def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass on_sim_end def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass on_sim_start def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear () predict def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number if within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes provide_movement_vector def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) - > tuple [ int , int ] : # extract portion matching next imaging phase next_imaging_start = ( sim . cycle_number + 1 ) * self . timing_config . cycle_frame_num next_imaging_end = next_imaging_start + self . timing_config . imaging_frame_num next_imaging = self . _csv_centers [ next_imaging_start : next_imaging_end , : ] next_imaging = next_imaging [ np . isfinite ( next_imaging ) . all ( axis = 1 )] if len ( next_imaging ) == 0 : return 0 , 0 x_next , y_next = np . median ( next_imaging , axis = 0 ) cam_x , cam_y , cam_w , cam_h = sim . view . 
camera_position cam_mid = cam_x + cam_w / 2 , cam_y + cam_h / 2 return round ( x_next - cam_mid [ 0 ]), round ( y_next - cam_mid [ 1 ])","title":"Optimal Controller"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#module-wtrackersimsim_controllersoptimal_controller","text":"View Source import numpy as np from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import Simulator from wtracker.sim.sim_controllers.csv_controller import CsvController class OptimalController ( CsvController ): def __init__ ( self , timing_config : TimingConfig , csv_path : str ): super () . __init__ ( timing_config , csv_path ) self . _csv_centers = np . empty (( len ( self . _csv_data ), 2 ), dtype = self . _csv_data . dtype ) self . _csv_centers [:, 0 ] = self . _csv_data [:, 0 ] + self . _csv_data [:, 2 ] / 2 self . _csv_centers [:, 1 ] = self . _csv_data [:, 1 ] + self . _csv_data [:, 3 ] / 2 def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # extract portion matching next imaging phase next_imaging_start = ( sim . cycle_number + 1 ) * self . timing_config . cycle_frame_num next_imaging_end = next_imaging_start + self . timing_config . imaging_frame_num next_imaging = self . _csv_centers [ next_imaging_start : next_imaging_end , :] next_imaging = next_imaging [ np . isfinite ( next_imaging ) . all ( axis = 1 )] if len ( next_imaging ) == 0 : return 0 , 0 x_next , y_next = np . median ( next_imaging , axis = 0 ) cam_x , cam_y , cam_w , cam_h = sim . view . camera_position cam_mid = cam_x + cam_w / 2 , cam_y + cam_h / 2 return round ( x_next - cam_mid [ 0 ]), round ( y_next - cam_mid [ 1 ])","title":"Module wtracker.sim.sim_controllers.optimal_controller"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#optimalcontroller","text":"class OptimalController ( timing_config : wtracker . sim . config . TimingConfig , csv_path : str ) Abstract base class for simulator controllers.","title":"OptimalController"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class OptimalController ( CsvController ): def __init__ ( self , timing_config : TimingConfig , csv_path : str ): super (). __init__ ( timing_config , csv_path ) self . _csv_centers = np . empty (( len ( self . _csv_data ), 2 ), dtype = self . _csv_data . dtype ) self . _csv_centers [:, 0 ] = self . _csv_data [:, 0 ] + self . _csv_data [:, 2 ] / 2 self . _csv_centers [:, 1 ] = self . _csv_data [:, 1 ] + self . _csv_data [:, 3 ] / 2 def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: # extract portion matching next imaging phase next_imaging_start = ( sim . cycle_number + 1 ) * self . timing_config . cycle_frame_num next_imaging_end = next_imaging_start + self . timing_config . imaging_frame_num next_imaging = self . _csv_centers [ next_imaging_start : next_imaging_end , :] next_imaging = next_imaging [ np . isfinite ( next_imaging ). all ( axis = 1 )] if len ( next_imaging ) == 0 : return 0 , 0 x_next , y_next = np . median ( next_imaging , axis = 0 ) cam_x , cam_y , cam_w , cam_h = sim . view . 
camera_position cam_mid = cam_x + cam_w / 2 , cam_y + cam_h / 2 return round ( x_next - cam_mid [ 0 ]), round ( y_next - cam_mid [ 1 ])","title":"Attributes"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#ancestors-in-mro","text":"wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.simulator.SimController abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#begin_movement_prediction","text":"def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass","title":"begin_movement_prediction"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_camera_frame","text":"def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position )","title":"on_camera_frame"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_cycle_end","text":"def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass","title":"on_cycle_end"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_cycle_start","text":"def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass","title":"on_cycle_start"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_imaging_end","text":"def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass","title":"on_imaging_end"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_imaging_start","text":"def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass","title":"on_imaging_start"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_micro_frame","text":"def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass","title":"on_micro_frame"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_movement_end","text":"def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass","title":"on_movement_end"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_movement_start","text":"def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. 
\"\"\" pass","title":"on_movement_start"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_sim_end","text":"def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass","title":"on_sim_end"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#on_sim_start","text":"def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear ()","title":"on_sim_start"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#predict","text":"def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number if within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes","title":"predict"},{"location":"reference/wtracker/sim/sim_controllers/optimal_controller/#provide_movement_vector","text":"def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) - > tuple [ int , int ] : # extract portion matching next imaging phase next_imaging_start = ( sim . cycle_number + 1 ) * self . timing_config . cycle_frame_num next_imaging_end = next_imaging_start + self . timing_config . imaging_frame_num next_imaging = self . _csv_centers [ next_imaging_start : next_imaging_end , : ] next_imaging = next_imaging [ np . isfinite ( next_imaging ) . all ( axis = 1 )] if len ( next_imaging ) == 0 : return 0 , 0 x_next , y_next = np . median ( next_imaging , axis = 0 ) cam_x , cam_y , cam_w , cam_h = sim . view . 
camera_position cam_mid = cam_x + cam_w / 2 , cam_y + cam_h / 2 return round ( x_next - cam_mid [ 0 ]), round ( y_next - cam_mid [ 1 ])","title":"provide_movement_vector"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/","text":"Module wtracker.sim.sim_controllers.polyfit_controller View Source import numpy as np import pandas as pd from dataclasses import dataclass import numpy.polynomial.polynomial as poly from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import Simulator from wtracker.sim.sim_controllers.csv_controller import CsvController from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat from wtracker.utils.config_base import ConfigBase @dataclass class PolyfitConfig ( ConfigBase ): degree : int \"\"\"The degree of the polynomial, which will be fitted to the worm movement.\"\"\" sample_times : list [ int ] \"\"\"Times at which the worm position is sampled for the polynomial fit. Time 0 denotes the beginning of the current cycle. Negative values are allowed.\"\"\" weights : list [ float ] = None \"\"\"Weights for each position sample for the polynomial fit. If None, all weights are set to 1.0. If the weights are not uniform, weighted polynomial fit is performed, where the residuals of samples with higher weights are more important for the fitting.\"\"\" def __post_init__ ( self ): self . sample_times = sorted ( self . sample_times ) if self . weights is None : self . weights = [ 1.0 for _ in self . sample_times ] assert len ( self . sample_times ) == len ( self . weights ) class PolyfitController ( CsvController ): def __init__ ( self , timing_config : TimingConfig , polyfit_config : PolyfitConfig , csv_path : str , ) -> None : \"\"\" Args: timing_config (TimingConfig): The timing configuration of the simulation. csv_path (str): The path to the csv file with the worm data. polyfit_config (PolyfitConfig): The configuration for the polynomial fit. \"\"\" super () . __init__ ( timing_config , csv_path ) self . polyfit_config = polyfit_config self . _sample_times = np . asanyarray ( polyfit_config . sample_times , dtype = int ) self . _weights = np . asanyarray ( polyfit_config . weights , dtype = float ) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: timing = self . timing_config config = self . polyfit_config bboxes = self . predict ( sim . cycle_number * timing . cycle_frame_num + self . _sample_times , relative = False ) # make all bboxes relative to current camera view camera_bbox = sim . view . camera_position bboxes [:, 0 ] -= camera_bbox [ 0 ] bboxes [:, 1 ] -= camera_bbox [ 1 ] positions = BoxUtils . center ( bboxes ) mask = np . isfinite ( positions ) . all ( axis = 1 ) time = self . _sample_times [ mask ] positions = positions [ mask ] weights = self . _weights [ mask ] if len ( time ) == 0 : return 0 , 0 # predict future x and future y based on the fitted polynomial coeffs = poly . polyfit ( time , positions , deg = config . degree , w = weights ) x_pred , y_pred = poly . polyval ( timing . cycle_frame_num + timing . imaging_frame_num // 2 , coeffs ) # calculate camera correction based on the speed of the worm and current worm position camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( x_pred - camera_mid [ 0 ]) dy = round ( y_pred - camera_mid [ 1 ]) return dx , dy class WeightEvaluator : \"\"\" Class for evaluating the mean absolute error (MAE) of a polynomial fit with given weights. 
Args: csv_paths (list[str]): The paths to the csv files with the worm data. timing_config (TimingConfig): The timing configuration of the simulation. input_time_offsets (np.ndarray): The time offsets for the input positions. These offsets are calculated from the beginning of the current cycle. The begging of the current cycle is considered as time 0. pred_time_offset (int): The time offset for the target position from the beginning of the current cycle. This time offset is calculated from the beginning of the current cycle. The begging of the current cycle is considered as time 0. min_speed (float, optional): The minimum speed of the worm for a cycle to be considered. max_speed (float, optional): The maximum speed of the worm for a cycle to be considered. \"\"\" def __init__ ( self , csv_paths : list [ str ], timing_config : TimingConfig , input_time_offsets : np . ndarray , pred_time_offset : int , min_speed : float = 0 , max_speed : float = np . inf , ): self . csv_paths = csv_paths self . timing_config = timing_config self . pred_time_offset = pred_time_offset self . min_speed = min_speed self . max_speed = max_speed self . input_time_offsets = np . sort ( input_time_offsets ) self . _construct_dataset () def _construct_dataset ( self ) -> None : input_positions = [] target_positions = [] for i , path in enumerate ( self . csv_paths ): bboxes = pd . read_csv ( path , usecols = [ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]) . to_numpy ( dtype = float ) input_pos , target_pos = self . _extract_positions ( bboxes , self . timing_config . cycle_frame_num ) input_positions . append ( input_pos ) target_positions . append ( target_pos ) # print stats init_num_cycles = len ( bboxes ) // self . timing_config . cycle_frame_num final_num_cycles = len ( target_pos ) // 2 removed_percent = round (( init_num_cycles - final_num_cycles ) / init_num_cycles * 100 , 1 ) print ( f \"Log { i } :: Number of evaluation cycles: { final_num_cycles } \" ) print ( f \"Log { i } :: Number of cycles removed: { init_num_cycles - final_num_cycles } ( { removed_percent } %)\" ) self . y_input = np . concatenate ( input_positions , axis = 1 ) self . x_input = self . input_time_offsets . reshape ( - 1 ) self . y_target = np . concatenate ( target_positions , axis = 0 ) self . x_target = np . full_like ( self . y_target , self . pred_time_offset ) def _extract_positions ( self , raw_bboxes : pd . DataFrame , cycle_length : int ) -> tuple [ np . ndarray , np . ndarray ]: N = self . input_time_offsets . shape [ 0 ] cycle_starts = np . arange ( 0 , raw_bboxes . shape [ 0 ], cycle_length , dtype = int ) centers = BoxUtils . center ( raw_bboxes ) # x are times, y are positions # create input and target arrays for the times x_input = np . repeat ( cycle_starts , repeats = N ) + np . tile ( self . input_time_offsets , reps = cycle_starts . shape [ 0 ]) x_input = x_input . reshape ( - 1 , N ) x_target = cycle_starts + self . pred_time_offset # remove input and target cycles with invalid time # i.e. when input time is negative or target time is out of bounds mask = ( x_input >= 0 ) . all ( axis = 1 ) & ( x_target < len ( centers )) x_input = x_input [ mask , :] x_target = x_target [ mask ] # get input and target positions for each cycle y_input = centers [ x_input . flatten (), :] y_input = y_input . reshape ( - 1 , N , 2 ) y_target = centers [ x_target . flatten (), :] y_target = y_target . reshape ( - 1 , 2 ) # remove all cycles with invalid positions input_mask = np . isfinite ( y_input ) . 
all ( axis = ( 1 , 2 )) target_mask = np . isfinite ( y_target ) . all ( axis = 1 ) mask = input_mask & target_mask y_input = y_input [ mask , :, :] y_target = y_target [ mask , :] # remove cycles with average speed below threshold # dist = np.sqrt((y_target[:, 1] - y_input[:, 0, 1]) ** 2 + (y_target[:, 0] - y_input[:, 0, 0]) ** 2) dist = np . linalg . norm ( y_target - y_input [:, 0 , :], axis = 1 ) time = self . pred_time_offset - self . input_time_offsets [ 0 ] speed = dist / time speed_mask = ( speed >= self . min_speed ) & ( speed <= self . max_speed ) y_input = y_input [ speed_mask , :, :] y_target = y_target [ speed_mask , :] # reshape target arrays y_input = y_input . swapaxes ( 0 , 1 ) . reshape ( N , - 1 ) y_target = y_target . reshape ( - 1 ) return y_input , y_target def _polyval ( self , coeffs : np . ndarray , x : np . ndarray ) -> np . ndarray : \"\"\" Evaluate a polynomial at given values. This implementation is way faster than np.polyval for multiple polynomials. Args: coeffs (np.ndarray): Coefficients of the polynomial. Coefficients at increasing order. Should have shape [deg+1, N]. x (np.ndarray): Values at which to evaluate the polynomial. Should have shape [N]. Returns: np.ndarray: The result of evaluating the polynomial at the given values. Shape is [N]. \"\"\" coeffs = coeffs . swapaxes ( 0 , 1 ) van = np . vander ( x , N = coeffs . shape [ 1 ], increasing = True ) return np . sum ( van * coeffs , axis =- 1 ) def eval ( self , weights : np . ndarray , deg : int = 2 ) -> float : \"\"\" Evaluate the mean absolute error (MAE) of the polynomial fit. Args: weights (np.ndarray): The weights used for the polynomial fit. Should have shape [N]. deg (int, optional): The degree of the polynomial fit. Returns: float: The mean absolute error (MAE) of the polynomial fit. \"\"\" coeffs = poly . polyfit ( self . x_input , self . y_input , deg = deg , w = weights ) y_pred = self . _polyval ( coeffs , self . x_target ) mae = np . mean ( np . abs ( self . y_target - y_pred )) return mae Classes PolyfitConfig class PolyfitConfig ( degree : int , sample_times : list [ int ], weights : list [ float ] = None ) PolyfitConfig(degree: int, sample_times: list[int], weights: list[float] = None) View Source @dataclass class PolyfitConfig ( ConfigBase ) : degree : int \"\"\"The degree of the polynomial, which will be fitted to the worm movement.\"\"\" sample_times : list [ int ] \"\"\"Times at which the worm position is be sampled for the polynomial fit. Time 0 denotes the beginning of the current cycle. Negative values are allowed.\"\"\" weights : list [ float ] = None \"\"\"Weights for each position sample for the polynomial fit. If None, all weights are set to 1.0. If the weights are not uniform, weighted polynomial fit is performed, where the residuals of samples with higher weights are more important for the fitting.\"\"\" def __post_init__ ( self ) : self . sample_times = sorted ( self . sample_times ) if self . weights is None : self . weights = [ 1.0 for _ in self.sample_times ] assert len ( self . sample_times ) == len ( self . weights ) Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Class variables weights Static methods load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. 
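Note that `load_json` rebuilds the object through `cls.__new__` followed by a direct `__dict__.update`, so `__init__`/`__post_init__` validation is bypassed on load and runs only when the config is first constructed. A short round-trip sketch (the file name is illustrative):

```python
# Sketch: JSON round-trip of a ConfigBase subclass; the file name is hypothetical.
from wtracker.sim.sim_controllers.polyfit_controller import PolyfitConfig

cfg = PolyfitConfig(degree=2, sample_times=[-30, -15, 0])
cfg.save_json("polyfit_config.json")  # dumps the instance __dict__ as indented JSON

restored = PolyfitConfig.load_json("polyfit_config.json")
assert restored.sample_times == cfg.sample_times  # fields survive the round trip
```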
View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods save_json def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) PolyfitController class PolyfitController ( timing_config : wtracker . sim . config . TimingConfig , polyfit_config : wtracker . sim . sim_controllers . polyfit_controller . PolyfitConfig , csv_path : str ) Abstract base class for simulator controllers. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class PolyfitController ( CsvController ) : def __init__ ( self , timing_config : TimingConfig , polyfit_config : PolyfitConfig , csv_path : str , ) -> None : \"\"\" Args: timing_config (TimingConfig): The timing configuration of the simulation. csv_path (str): The path to the csv file with the worm data. polyfit_config (PolyfitConfig): The configuration for the polynomial fit. \"\"\" super (). __init__ ( timing_config , csv_path ) self . polyfit_config = polyfit_config self . _sample_times = np . asanyarray ( polyfit_config . sample_times , dtype = int ) self . _weights = np . asanyarray ( polyfit_config . weights , dtype = float ) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : timing = self . timing_config config = self . polyfit_config bboxes = self . predict ( sim . cycle_number * timing . cycle_frame_num + self . 
_sample_times , relative = False ) # make all bboxes relative to current camera view camera_bbox = sim . view . camera_position bboxes [ :, 0 ] -= camera_bbox [ 0 ] bboxes [ :, 1 ] -= camera_bbox [ 1 ] positions = BoxUtils . center ( bboxes ) mask = np . isfinite ( positions ). all ( axis = 1 ) time = self . _sample_times [ mask ] positions = positions [ mask ] weights = self . _weights [ mask ] if len ( time ) == 0 : return 0 , 0 # predict future x and future y based on the fitted polynomial coeffs = poly . polyfit ( time , positions , deg = config . degree , w = weights ) x_pred , y_pred = poly . polyval ( timing . cycle_frame_num + timing . imaging_frame_num // 2 , coeffs ) # calculate camera correction based on the speed of the worm and current worm position camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( x_pred - camera_mid [ 0 ] ) dy = round ( y_pred - camera_mid [ 1 ] ) return dx , dy Ancestors (in MRO) wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.simulator.SimController abc.ABC Methods begin_movement_prediction def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass on_camera_frame def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position ) on_cycle_end def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass on_cycle_start def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass on_imaging_end def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass on_imaging_start def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass on_micro_frame def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every frame during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every frame during the imaging phase. \"\"\" pass on_movement_end def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass on_movement_start def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass on_sim_end def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass on_sim_start def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes .
clear () predict def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number is within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes provide_movement_vector def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : timing = self . timing_config config = self . polyfit_config bboxes = self . predict ( sim . cycle_number * timing . cycle_frame_num + self . _sample_times , relative = False ) # make all bboxes relative to current camera view camera_bbox = sim . view . camera_position bboxes [ :, 0 ] -= camera_bbox [ 0 ] bboxes [ :, 1 ] -= camera_bbox [ 1 ] positions = BoxUtils . center ( bboxes ) mask = np . isfinite ( positions ). all ( axis = 1 ) time = self . _sample_times [ mask ] positions = positions [ mask ] weights = self . _weights [ mask ] if len ( time ) == 0 : return 0 , 0 # predict future x and future y based on the fitted polynomial coeffs = poly . polyfit ( time , positions , deg = config . degree , w = weights ) x_pred , y_pred = poly . polyval ( timing . cycle_frame_num + timing . imaging_frame_num // 2 , coeffs ) # calculate camera correction based on the speed of the worm and current worm position camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( x_pred - camera_mid [ 0 ] ) dy = round ( y_pred - camera_mid [ 1 ] ) return dx , dy WeightEvaluator class WeightEvaluator ( csv_paths : list [ str ], timing_config : wtracker . sim . config . TimingConfig , input_time_offsets : numpy . ndarray , pred_time_offset : int , min_speed : float = 0 , max_speed : float = inf ) Class for evaluating the mean absolute error (MAE) of a polynomial fit with given weights. Attributes Name Type Description Default csv_paths list[str] The paths to the csv files with the worm data. None timing_config TimingConfig The timing configuration of the simulation. None input_time_offsets np.ndarray The time offsets for the input positions. These offsets are calculated from the beginning of the current cycle. The beginning of the current cycle is considered as time 0. None pred_time_offset int The time offset for the target position from the beginning of the current cycle. This time offset is calculated from the beginning of the current cycle. The beginning of the current cycle is considered as time 0.
None min_speed float The minimum speed of the worm for a cycle to be considered. None max_speed float The maximum speed of the worm for a cycle to be considered. None View Source class WeightEvaluator : \"\"\" Class for evaluating the mean absolute error (MAE) of a polynomial fit with given weights. Args: csv_paths (list[str]): The paths to the csv files with the worm data. timing_config (TimingConfig): The timing configuration of the simulation. input_time_offsets (np.ndarray): The time offsets for the input positions. These offsets are calculated from the beginning of the current cycle. The begging of the current cycle is considered as time 0. pred_time_offset (int): The time offset for the target position from the beginning of the current cycle. This time offset is calculated from the beginning of the current cycle. The begging of the current cycle is considered as time 0. min_speed (float, optional): The minimum speed of the worm for a cycle to be considered. max_speed (float, optional): The maximum speed of the worm for a cycle to be considered. \"\"\" def __init__ ( self , csv_paths : list [ str ] , timing_config : TimingConfig , input_time_offsets : np . ndarray , pred_time_offset : int , min_speed : float = 0 , max_speed : float = np . inf , ) : self . csv_paths = csv_paths self . timing_config = timing_config self . pred_time_offset = pred_time_offset self . min_speed = min_speed self . max_speed = max_speed self . input_time_offsets = np . sort ( input_time_offsets ) self . _construct_dataset () def _construct_dataset ( self ) -> None : input_positions = [] target_positions = [] for i , path in enumerate ( self . csv_paths ) : bboxes = pd . read_csv ( path , usecols =[ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ). to_numpy ( dtype = float ) input_pos , target_pos = self . _extract_positions ( bboxes , self . timing_config . cycle_frame_num ) input_positions . append ( input_pos ) target_positions . append ( target_pos ) # print stats init_num_cycles = len ( bboxes ) // self . timing_config . cycle_frame_num final_num_cycles = len ( target_pos ) // 2 removed_percent = round (( init_num_cycles - final_num_cycles ) / init_num_cycles * 100 , 1 ) print ( f \"Log {i} :: Number of evaluation cycles: {final_num_cycles}\" ) print ( f \"Log {i} :: Number of cycles removed: {init_num_cycles - final_num_cycles} ({removed_percent} %)\" ) self . y_input = np . concatenate ( input_positions , axis = 1 ) self . x_input = self . input_time_offsets . reshape ( - 1 ) self . y_target = np . concatenate ( target_positions , axis = 0 ) self . x_target = np . full_like ( self . y_target , self . pred_time_offset ) def _extract_positions ( self , raw_bboxes : pd . DataFrame , cycle_length : int ) -> tuple [ np.ndarray, np.ndarray ] : N = self . input_time_offsets . shape [ 0 ] cycle_starts = np . arange ( 0 , raw_bboxes . shape [ 0 ] , cycle_length , dtype = int ) centers = BoxUtils . center ( raw_bboxes ) # x are times , y are positions # create input and target arrays for the times x_input = np . repeat ( cycle_starts , repeats = N ) + np . tile ( self . input_time_offsets , reps = cycle_starts . shape [ 0 ] ) x_input = x_input . reshape ( - 1 , N ) x_target = cycle_starts + self . pred_time_offset # remove input and target cycles with invalid time # i . e . when input time is negative or target time is out of bounds mask = ( x_input >= 0 ). 
all ( axis = 1 ) & ( x_target < len ( centers )) x_input = x_input [ mask, : ] x_target = x_target [ mask ] # get input and target positions for each cycle y_input = centers [ x_input.flatten(), : ] y_input = y_input . reshape ( - 1 , N , 2 ) y_target = centers [ x_target.flatten(), : ] y_target = y_target . reshape ( - 1 , 2 ) # remove all cycles with invalid positions input_mask = np . isfinite ( y_input ). all ( axis = ( 1 , 2 )) target_mask = np . isfinite ( y_target ). all ( axis = 1 ) mask = input_mask & target_mask y_input = y_input [ mask, :, : ] y_target = y_target [ mask, : ] # remove cycles with average speed below threshold # dist = np . sqrt (( y_target [ :, 1 ] - y_input [ :, 0, 1 ] ) ** 2 + ( y_target [ :, 0 ] - y_input [ :, 0, 0 ] ) ** 2 ) dist = np . linalg . norm ( y_target - y_input [ :, 0, : ] , axis = 1 ) time = self . pred_time_offset - self . input_time_offsets [ 0 ] speed = dist / time speed_mask = ( speed >= self . min_speed ) & ( speed <= self . max_speed ) y_input = y_input [ speed_mask, :, : ] y_target = y_target [ speed_mask, : ] # reshape target arrays y_input = y_input . swapaxes ( 0 , 1 ). reshape ( N , - 1 ) y_target = y_target . reshape ( - 1 ) return y_input , y_target def _polyval ( self , coeffs : np . ndarray , x : np . ndarray ) -> np . ndarray : \"\"\" Evaluate a polynomial at given values. This implementation is way faster than np.polyval for multiple polynomials. Args: coeffs (np.ndarray): Coefficients of the polynomial. Coefficients at increasing order. Should have shape [deg+1, N]. x (np.ndarray): Values at which to evaluate the polynomial. Should have shape [N]. Returns: np.ndarray: The result of evaluating the polynomial at the given values. Shape is [N]. \"\"\" coeffs = coeffs . swapaxes ( 0 , 1 ) van = np . vander ( x , N = coeffs . shape [ 1 ] , increasing = True ) return np . sum ( van * coeffs , axis =- 1 ) def eval ( self , weights : np . ndarray , deg : int = 2 ) -> float : \"\"\" Evaluate the mean absolute error (MAE) of the polynomial fit. Args: weights (np.ndarray): The weights used for the polynomial fit. Should have shape [N]. deg (int, optional): The degree of the polynomial fit. Returns: float: The mean absolute error (MAE) of the polynomial fit. \"\"\" coeffs = poly . polyfit ( self . x_input , self . y_input , deg = deg , w = weights ) y_pred = self . _polyval ( coeffs , self . x_target ) mae = np . mean ( np . abs ( self . y_target - y_pred )) return mae Methods eval def eval ( self , weights : numpy . ndarray , deg : int = 2 ) -> float Evaluate the mean absolute error (MAE) of the polynomial fit. Parameters: Name Type Description Default weights np.ndarray The weights used for the polynomial fit. Should have shape [N]. None deg int The degree of the polynomial fit. None Returns: Type Description float The mean absolute error (MAE) of the polynomial fit. View Source def eval ( self , weights : np . ndarray , deg : int = 2 ) -> float : \"\"\" Evaluate the mean absolute error (MAE) of the polynomial fit. Args: weights (np.ndarray): The weights used for the polynomial fit. Should have shape [N]. deg (int, optional): The degree of the polynomial fit. Returns: float: The mean absolute error (MAE) of the polynomial fit. \"\"\" coeffs = poly . polyfit ( self . x_input , self . y_input , deg = deg , w = weights ) y_pred = self . _polyval ( coeffs , self . x_target ) mae = np . mean ( np . abs ( self . 
y_target - y_pred )) return mae","title":"Polyfit Controller"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#module-wtrackersimsim_controllerspolyfit_controller","text":"View Source import numpy as np import pandas as pd from dataclasses import dataclass import numpy.polynomial.polynomial as poly from wtracker.sim.config import TimingConfig from wtracker.sim.simulator import Simulator from wtracker.sim.sim_controllers.csv_controller import CsvController from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat from wtracker.utils.config_base import ConfigBase @dataclass class PolyfitConfig ( ConfigBase ): degree : int \"\"\"The degree of the polynomial, which will be fitted to the worm movement.\"\"\" sample_times : list [ int ] \"\"\"Times at which the worm position is be sampled for the polynomial fit. Time 0 denotes the beginning of the current cycle. Negative values are allowed.\"\"\" weights : list [ float ] = None \"\"\"Weights for each position sample for the polynomial fit. If None, all weights are set to 1.0. If the weights are not uniform, weighted polynomial fit is performed, where the residuals of samples with higher weights are more important for the fitting.\"\"\" def __post_init__ ( self ): self . sample_times = sorted ( self . sample_times ) if self . weights is None : self . weights = [ 1.0 for _ in self . sample_times ] assert len ( self . sample_times ) == len ( self . weights ) class PolyfitController ( CsvController ): def __init__ ( self , timing_config : TimingConfig , polyfit_config : PolyfitConfig , csv_path : str , ) -> None : \"\"\" Args: timing_config (TimingConfig): The timing configuration of the simulation. csv_path (str): The path to the csv file with the worm data. polyfit_config (PolyfitConfig): The configuration for the polynomial fit. \"\"\" super () . __init__ ( timing_config , csv_path ) self . polyfit_config = polyfit_config self . _sample_times = np . asanyarray ( polyfit_config . sample_times , dtype = int ) self . _weights = np . asanyarray ( polyfit_config . weights , dtype = float ) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: timing = self . timing_config config = self . polyfit_config bboxes = self . predict ( sim . cycle_number * timing . cycle_frame_num + self . _sample_times , relative = False ) # make all bboxes relative to current camera view camera_bbox = sim . view . camera_position bboxes [:, 0 ] -= camera_bbox [ 0 ] bboxes [:, 1 ] -= camera_bbox [ 1 ] positions = BoxUtils . center ( bboxes ) mask = np . isfinite ( positions ) . all ( axis = 1 ) time = self . _sample_times [ mask ] positions = positions [ mask ] weights = self . _weights [ mask ] if len ( time ) == 0 : return 0 , 0 # predict future x and future y based on the fitted polynomial coeffs = poly . polyfit ( time , positions , deg = config . degree , w = weights ) x_pred , y_pred = poly . polyval ( timing . cycle_frame_num + timing . imaging_frame_num // 2 , coeffs ) # calculate camera correction based on the speed of the worm and current worm position camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( x_pred - camera_mid [ 0 ]) dy = round ( y_pred - camera_mid [ 1 ]) return dx , dy class WeightEvaluator : \"\"\" Class for evaluating the mean absolute error (MAE) of a polynomial fit with given weights. Args: csv_paths (list[str]): The paths to the csv files with the worm data. timing_config (TimingConfig): The timing configuration of the simulation. 
input_time_offsets (np.ndarray): The time offsets for the input positions. These offsets are calculated from the beginning of the current cycle. The begging of the current cycle is considered as time 0. pred_time_offset (int): The time offset for the target position from the beginning of the current cycle. This time offset is calculated from the beginning of the current cycle. The begging of the current cycle is considered as time 0. min_speed (float, optional): The minimum speed of the worm for a cycle to be considered. max_speed (float, optional): The maximum speed of the worm for a cycle to be considered. \"\"\" def __init__ ( self , csv_paths : list [ str ], timing_config : TimingConfig , input_time_offsets : np . ndarray , pred_time_offset : int , min_speed : float = 0 , max_speed : float = np . inf , ): self . csv_paths = csv_paths self . timing_config = timing_config self . pred_time_offset = pred_time_offset self . min_speed = min_speed self . max_speed = max_speed self . input_time_offsets = np . sort ( input_time_offsets ) self . _construct_dataset () def _construct_dataset ( self ) -> None : input_positions = [] target_positions = [] for i , path in enumerate ( self . csv_paths ): bboxes = pd . read_csv ( path , usecols = [ \"wrm_x\" , \"wrm_y\" , \"wrm_w\" , \"wrm_h\" ]) . to_numpy ( dtype = float ) input_pos , target_pos = self . _extract_positions ( bboxes , self . timing_config . cycle_frame_num ) input_positions . append ( input_pos ) target_positions . append ( target_pos ) # print stats init_num_cycles = len ( bboxes ) // self . timing_config . cycle_frame_num final_num_cycles = len ( target_pos ) // 2 removed_percent = round (( init_num_cycles - final_num_cycles ) / init_num_cycles * 100 , 1 ) print ( f \"Log { i } :: Number of evaluation cycles: { final_num_cycles } \" ) print ( f \"Log { i } :: Number of cycles removed: { init_num_cycles - final_num_cycles } ( { removed_percent } %)\" ) self . y_input = np . concatenate ( input_positions , axis = 1 ) self . x_input = self . input_time_offsets . reshape ( - 1 ) self . y_target = np . concatenate ( target_positions , axis = 0 ) self . x_target = np . full_like ( self . y_target , self . pred_time_offset ) def _extract_positions ( self , raw_bboxes : pd . DataFrame , cycle_length : int ) -> tuple [ np . ndarray , np . ndarray ]: N = self . input_time_offsets . shape [ 0 ] cycle_starts = np . arange ( 0 , raw_bboxes . shape [ 0 ], cycle_length , dtype = int ) centers = BoxUtils . center ( raw_bboxes ) # x are times, y are positions # create input and target arrays for the times x_input = np . repeat ( cycle_starts , repeats = N ) + np . tile ( self . input_time_offsets , reps = cycle_starts . shape [ 0 ]) x_input = x_input . reshape ( - 1 , N ) x_target = cycle_starts + self . pred_time_offset # remove input and target cycles with invalid time # i.e. when input time is negative or target time is out of bounds mask = ( x_input >= 0 ) . all ( axis = 1 ) & ( x_target < len ( centers )) x_input = x_input [ mask , :] x_target = x_target [ mask ] # get input and target positions for each cycle y_input = centers [ x_input . flatten (), :] y_input = y_input . reshape ( - 1 , N , 2 ) y_target = centers [ x_target . flatten (), :] y_target = y_target . reshape ( - 1 , 2 ) # remove all cycles with invalid positions input_mask = np . isfinite ( y_input ) . all ( axis = ( 1 , 2 )) target_mask = np . isfinite ( y_target ) . 
all ( axis = 1 ) mask = input_mask & target_mask y_input = y_input [ mask , :, :] y_target = y_target [ mask , :] # remove cycles with average speed below threshold # dist = np.sqrt((y_target[:, 1] - y_input[:, 0, 1]) ** 2 + (y_target[:, 0] - y_input[:, 0, 0]) ** 2) dist = np . linalg . norm ( y_target - y_input [:, 0 , :], axis = 1 ) time = self . pred_time_offset - self . input_time_offsets [ 0 ] speed = dist / time speed_mask = ( speed >= self . min_speed ) & ( speed <= self . max_speed ) y_input = y_input [ speed_mask , :, :] y_target = y_target [ speed_mask , :] # reshape target arrays y_input = y_input . swapaxes ( 0 , 1 ) . reshape ( N , - 1 ) y_target = y_target . reshape ( - 1 ) return y_input , y_target def _polyval ( self , coeffs : np . ndarray , x : np . ndarray ) -> np . ndarray : \"\"\" Evaluate a polynomial at given values. This implementation is way faster than np.polyval for multiple polynomials. Args: coeffs (np.ndarray): Coefficients of the polynomial. Coefficients at increasing order. Should have shape [deg+1, N]. x (np.ndarray): Values at which to evaluate the polynomial. Should have shape [N]. Returns: np.ndarray: The result of evaluating the polynomial at the given values. Shape is [N]. \"\"\" coeffs = coeffs . swapaxes ( 0 , 1 ) van = np . vander ( x , N = coeffs . shape [ 1 ], increasing = True ) return np . sum ( van * coeffs , axis =- 1 ) def eval ( self , weights : np . ndarray , deg : int = 2 ) -> float : \"\"\" Evaluate the mean absolute error (MAE) of the polynomial fit. Args: weights (np.ndarray): The weights used for the polynomial fit. Should have shape [N]. deg (int, optional): The degree of the polynomial fit. Returns: float: The mean absolute error (MAE) of the polynomial fit. \"\"\" coeffs = poly . polyfit ( self . x_input , self . y_input , deg = deg , w = weights ) y_pred = self . _polyval ( coeffs , self . x_target ) mae = np . mean ( np . abs ( self . y_target - y_pred )) return mae","title":"Module wtracker.sim.sim_controllers.polyfit_controller"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#polyfitconfig","text":"class PolyfitConfig ( degree : int , sample_times : list [ int ], weights : list [ float ] = None ) PolyfitConfig(degree: int, sample_times: list[int], weights: list[float] = None) View Source @dataclass class PolyfitConfig ( ConfigBase ) : degree : int \"\"\"The degree of the polynomial, which will be fitted to the worm movement.\"\"\" sample_times : list [ int ] \"\"\"Times at which the worm position is be sampled for the polynomial fit. Time 0 denotes the beginning of the current cycle. Negative values are allowed.\"\"\" weights : list [ float ] = None \"\"\"Weights for each position sample for the polynomial fit. If None, all weights are set to 1.0. If the weights are not uniform, weighted polynomial fit is performed, where the residuals of samples with higher weights are more important for the fitting.\"\"\" def __post_init__ ( self ) : self . sample_times = sorted ( self . sample_times ) if self . weights is None : self . weights = [ 1.0 for _ in self.sample_times ] assert len ( self . sample_times ) == len ( self . 
weights )","title":"PolyfitConfig"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#ancestors-in-mro","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#class-variables","text":"weights","title":"Class variables"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#load_json","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#load_pickle","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#save_json","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#save_pickle","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . 
save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#polyfitcontroller","text":"class PolyfitController ( timing_config : wtracker . sim . config . TimingConfig , polyfit_config : wtracker . sim . sim_controllers . polyfit_controller . PolyfitConfig , csv_path : str ) Abstract base class for simulator controllers.","title":"PolyfitController"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class PolyfitController ( CsvController ) : def __init__ ( self , timing_config : TimingConfig , polyfit_config : PolyfitConfig , csv_path : str , ) -> None : \"\"\" Args: timing_config (TimingConfig): The timing configuration of the simulation. csv_path (str): The path to the csv file with the worm data. polyfit_config (PolyfitConfig): The configuration for the polynomial fit. \"\"\" super (). __init__ ( timing_config , csv_path ) self . polyfit_config = polyfit_config self . _sample_times = np . asanyarray ( polyfit_config . sample_times , dtype = int ) self . _weights = np . asanyarray ( polyfit_config . weights , dtype = float ) def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : timing = self . timing_config config = self . polyfit_config bboxes = self . predict ( sim . cycle_number * timing . cycle_frame_num + self . _sample_times , relative = False ) # make all bboxes relative to current camera view camera_bbox = sim . view . camera_position bboxes [ :, 0 ] -= camera_bbox [ 0 ] bboxes [ :, 1 ] -= camera_bbox [ 1 ] positions = BoxUtils . center ( bboxes ) mask = np . isfinite ( positions ). all ( axis = 1 ) time = self . _sample_times [ mask ] positions = positions [ mask ] weights = self . _weights [ mask ] if len ( time ) == 0 : return 0 , 0 # predict future x and future y based on the fitted polynomial coeffs = poly . polyfit ( time , positions , deg = config . degree , w = weights ) x_pred , y_pred = poly . polyval ( timing . cycle_frame_num + timing . imaging_frame_num // 2 , coeffs ) # calculate camera correction based on the speed of the worm and current worm position camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( x_pred - camera_mid [ 0 ] ) dy = round ( y_pred - camera_mid [ 1 ] ) return dx , dy","title":"Attributes"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#ancestors-in-mro_1","text":"wtracker.sim.sim_controllers.csv_controller.CsvController wtracker.sim.simulator.SimController abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#begin_movement_prediction","text":"def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass","title":"begin_movement_prediction"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_camera_frame","text":"def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. 
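Most of the lifecycle hooks listed here default to `pass`, so a subclass overrides only the events it needs. A hypothetical sketch of that pattern (`LoggingCsvController` is illustrative, not part of the library):

```python
# Sketch: a custom controller overriding selected lifecycle hooks.
from wtracker.sim.simulator import Simulator
from wtracker.sim.sim_controllers.csv_controller import CsvController

class LoggingCsvController(CsvController):  # hypothetical subclass
    def on_cycle_start(self, sim: Simulator):
        print(f"cycle {sim.cycle_number} started")  # react to the cycle-start event

    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        return 0, 0  # keep the platform still; a real controller returns a (dx, dy) correction
```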
View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_bboxes . append ( sim . view . camera_position )","title":"on_camera_frame"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_cycle_end","text":"def on_cycle_end ( self , sim : 'Simulator' ) Called when a cycle ends. View Source def on_cycle_end(self, sim: Simulator): \"\"\" Called when a cycle ends. \"\"\" pass","title":"on_cycle_end"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_cycle_start","text":"def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass","title":"on_cycle_start"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_imaging_end","text":"def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass","title":"on_imaging_end"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_imaging_start","text":"def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass","title":"on_imaging_start"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_micro_frame","text":"def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass","title":"on_micro_frame"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_movement_end","text":"def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass","title":"on_movement_end"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_movement_start","text":"def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass","title":"on_movement_start"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_sim_end","text":"def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass","title":"on_sim_end"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#on_sim_start","text":"def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_bboxes . clear ()","title":"on_sim_start"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#predict","text":"def predict ( self , frame_nums : Collection [ int ], relative : bool = True ) -> numpy . ndarray View Source def predict ( self , frame_nums : Collection [ int ] , relative : bool = True ) -> np . ndarray : assert len ( frame_nums ) > 0 frame_nums = np . asanyarray ( frame_nums , dtype = int ) valid_mask = ( frame_nums >= 0 ) & ( frame_nums < self . _csv_data . shape [ 0 ] ) worm_bboxes = np . full (( frame_nums . 
shape [ 0 ] , 4 ), np . nan ) worm_bboxes [ valid_mask ] = self . _csv_data [ frame_nums[valid_mask ] , :] if not relative : return worm_bboxes # TODO : if relative == True then it works only if frame number if within the last cycle . # maybe fix that . cam_bboxes = [ self._camera_bboxes[n % self.timing_config.cycle_frame_num ] for n in frame_nums ] cam_bboxes = np . asanyarray ( cam_bboxes , dtype = float ) # make bbox relative to camera view worm_bboxes [ :, 0 ] -= cam_bboxes [ :, 0 ] worm_bboxes [ :, 1 ] -= cam_bboxes [ :, 1 ] return worm_bboxes","title":"predict"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#provide_movement_vector","text":"def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : timing = self . timing_config config = self . polyfit_config bboxes = self . predict ( sim . cycle_number * timing . cycle_frame_num + self . _sample_times , relative = False ) # make all bboxes relative to current camera view camera_bbox = sim . view . camera_position bboxes [ :, 0 ] -= camera_bbox [ 0 ] bboxes [ :, 1 ] -= camera_bbox [ 1 ] positions = BoxUtils . center ( bboxes ) mask = np . isfinite ( positions ). all ( axis = 1 ) time = self . _sample_times [ mask ] positions = positions [ mask ] weights = self . _weights [ mask ] if len ( time ) == 0 : return 0 , 0 # predict future x and future y based on the fitted polynomial coeffs = poly . polyfit ( time , positions , deg = config . degree , w = weights ) x_pred , y_pred = poly . polyval ( timing . cycle_frame_num + timing . imaging_frame_num // 2 , coeffs ) # calculate camera correction based on the speed of the worm and current worm position camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 dx = round ( x_pred - camera_mid [ 0 ] ) dy = round ( y_pred - camera_mid [ 1 ] ) return dx , dy","title":"provide_movement_vector"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#weightevaluator","text":"class WeightEvaluator ( csv_paths : list [ str ], timing_config : wtracker . sim . config . TimingConfig , input_time_offsets : numpy . ndarray , pred_time_offset : int , min_speed : float = 0 , max_speed : float = inf ) Class for evaluating the mean absolute error (MAE) of a polynomial fit with given weights.","title":"WeightEvaluator"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#attributes_1","text":"Name Type Description Default csv_paths list[str] The paths to the csv files with the worm data. None timing_config TimingConfig The timing configuration of the simulation. None input_time_offsets np.ndarray The time offsets for the input positions. These offsets are calculated from the beginning of the current cycle. The begging of the current cycle is considered as time 0. None pred_time_offset int The time offset for the target position from the beginning of the current cycle. This time offset is calculated from the beginning of the current cycle. The begging of the current cycle is considered as time 0. None min_speed float The minimum speed of the worm for a cycle to be considered. 
None max_speed float The maximum speed of the worm for a cycle to be considered. None View Source class WeightEvaluator : \"\"\" Class for evaluating the mean absolute error (MAE) of a polynomial fit with given weights. Args: csv_paths (list[str]): The paths to the csv files with the worm data. timing_config (TimingConfig): The timing configuration of the simulation. input_time_offsets (np.ndarray): The time offsets for the input positions. These offsets are calculated from the beginning of the current cycle. The begging of the current cycle is considered as time 0. pred_time_offset (int): The time offset for the target position from the beginning of the current cycle. This time offset is calculated from the beginning of the current cycle. The begging of the current cycle is considered as time 0. min_speed (float, optional): The minimum speed of the worm for a cycle to be considered. max_speed (float, optional): The maximum speed of the worm for a cycle to be considered. \"\"\" def __init__ ( self , csv_paths : list [ str ] , timing_config : TimingConfig , input_time_offsets : np . ndarray , pred_time_offset : int , min_speed : float = 0 , max_speed : float = np . inf , ) : self . csv_paths = csv_paths self . timing_config = timing_config self . pred_time_offset = pred_time_offset self . min_speed = min_speed self . max_speed = max_speed self . input_time_offsets = np . sort ( input_time_offsets ) self . _construct_dataset () def _construct_dataset ( self ) -> None : input_positions = [] target_positions = [] for i , path in enumerate ( self . csv_paths ) : bboxes = pd . read_csv ( path , usecols =[ \"wrm_x\", \"wrm_y\", \"wrm_w\", \"wrm_h\" ] ). to_numpy ( dtype = float ) input_pos , target_pos = self . _extract_positions ( bboxes , self . timing_config . cycle_frame_num ) input_positions . append ( input_pos ) target_positions . append ( target_pos ) # print stats init_num_cycles = len ( bboxes ) // self . timing_config . cycle_frame_num final_num_cycles = len ( target_pos ) // 2 removed_percent = round (( init_num_cycles - final_num_cycles ) / init_num_cycles * 100 , 1 ) print ( f \"Log {i} :: Number of evaluation cycles: {final_num_cycles}\" ) print ( f \"Log {i} :: Number of cycles removed: {init_num_cycles - final_num_cycles} ({removed_percent} %)\" ) self . y_input = np . concatenate ( input_positions , axis = 1 ) self . x_input = self . input_time_offsets . reshape ( - 1 ) self . y_target = np . concatenate ( target_positions , axis = 0 ) self . x_target = np . full_like ( self . y_target , self . pred_time_offset ) def _extract_positions ( self , raw_bboxes : pd . DataFrame , cycle_length : int ) -> tuple [ np.ndarray, np.ndarray ] : N = self . input_time_offsets . shape [ 0 ] cycle_starts = np . arange ( 0 , raw_bboxes . shape [ 0 ] , cycle_length , dtype = int ) centers = BoxUtils . center ( raw_bboxes ) # x are times , y are positions # create input and target arrays for the times x_input = np . repeat ( cycle_starts , repeats = N ) + np . tile ( self . input_time_offsets , reps = cycle_starts . shape [ 0 ] ) x_input = x_input . reshape ( - 1 , N ) x_target = cycle_starts + self . pred_time_offset # remove input and target cycles with invalid time # i . e . when input time is negative or target time is out of bounds mask = ( x_input >= 0 ). all ( axis = 1 ) & ( x_target < len ( centers )) x_input = x_input [ mask, : ] x_target = x_target [ mask ] # get input and target positions for each cycle y_input = centers [ x_input.flatten(), : ] y_input = y_input . 
reshape ( - 1 , N , 2 ) y_target = centers [ x_target.flatten(), : ] y_target = y_target . reshape ( - 1 , 2 ) # remove all cycles with invalid positions input_mask = np . isfinite ( y_input ). all ( axis = ( 1 , 2 )) target_mask = np . isfinite ( y_target ). all ( axis = 1 ) mask = input_mask & target_mask y_input = y_input [ mask, :, : ] y_target = y_target [ mask, : ] # remove cycles with average speed below threshold # dist = np . sqrt (( y_target [ :, 1 ] - y_input [ :, 0, 1 ] ) ** 2 + ( y_target [ :, 0 ] - y_input [ :, 0, 0 ] ) ** 2 ) dist = np . linalg . norm ( y_target - y_input [ :, 0, : ] , axis = 1 ) time = self . pred_time_offset - self . input_time_offsets [ 0 ] speed = dist / time speed_mask = ( speed >= self . min_speed ) & ( speed <= self . max_speed ) y_input = y_input [ speed_mask, :, : ] y_target = y_target [ speed_mask, : ] # reshape target arrays y_input = y_input . swapaxes ( 0 , 1 ). reshape ( N , - 1 ) y_target = y_target . reshape ( - 1 ) return y_input , y_target def _polyval ( self , coeffs : np . ndarray , x : np . ndarray ) -> np . ndarray : \"\"\" Evaluate a polynomial at given values. This implementation is way faster than np.polyval for multiple polynomials. Args: coeffs (np.ndarray): Coefficients of the polynomial. Coefficients at increasing order. Should have shape [deg+1, N]. x (np.ndarray): Values at which to evaluate the polynomial. Should have shape [N]. Returns: np.ndarray: The result of evaluating the polynomial at the given values. Shape is [N]. \"\"\" coeffs = coeffs . swapaxes ( 0 , 1 ) van = np . vander ( x , N = coeffs . shape [ 1 ] , increasing = True ) return np . sum ( van * coeffs , axis =- 1 ) def eval ( self , weights : np . ndarray , deg : int = 2 ) -> float : \"\"\" Evaluate the mean absolute error (MAE) of the polynomial fit. Args: weights (np.ndarray): The weights used for the polynomial fit. Should have shape [N]. deg (int, optional): The degree of the polynomial fit. Returns: float: The mean absolute error (MAE) of the polynomial fit. \"\"\" coeffs = poly . polyfit ( self . x_input , self . y_input , deg = deg , w = weights ) y_pred = self . _polyval ( coeffs , self . x_target ) mae = np . mean ( np . abs ( self . y_target - y_pred )) return mae","title":"Attributes"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#methods_2","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/polyfit_controller/#eval","text":"def eval ( self , weights : numpy . ndarray , deg : int = 2 ) -> float Evaluate the mean absolute error (MAE) of the polynomial fit. Parameters: Name Type Description Default weights np.ndarray The weights used for the polynomial fit. Should have shape [N]. None deg int The degree of the polynomial fit. None Returns: Type Description float The mean absolute error (MAE) of the polynomial fit. View Source def eval ( self , weights : np . ndarray , deg : int = 2 ) -> float : \"\"\" Evaluate the mean absolute error (MAE) of the polynomial fit. Args: weights (np.ndarray): The weights used for the polynomial fit. Should have shape [N]. deg (int, optional): The degree of the polynomial fit. Returns: float: The mean absolute error (MAE) of the polynomial fit. \"\"\" coeffs = poly . polyfit ( self . x_input , self . y_input , deg = deg , w = weights ) y_pred = self . _polyval ( coeffs , self . x_target ) mae = np . mean ( np . abs ( self . 
y_target - y_pred )) return mae","title":"eval"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/","text":"Module wtracker.sim.sim_controllers.yolo_controller View Source from typing import Collection , Any from dataclasses import dataclass , field import numpy as np import cv2 as cv from collections import deque from ultralytics import YOLO from wtracker.sim.simulator import Simulator , SimController from wtracker.sim.config import TimingConfig from wtracker.utils.config_base import ConfigBase from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat @dataclass class YoloConfig ( ConfigBase ): model_path : str \"\"\"The path to the pretrained YOLO weights file.\"\"\" device : str = \"cpu\" \"\"\"Inference device for YOLO. Can be either 'cpu' or 'cuda'.\"\"\" verbose : bool = False \"\"\"Whether to print verbose output during YOLO inference.\"\"\" pred_kwargs : dict = field ( default_factory = lambda : { \"imgsz\" : 384 , \"conf\" : 0.1 , } ) \"\"\"Additional keyword arguments for the YOLO prediction method.\"\"\" model : YOLO = field ( default = None , init = False , repr = False ) \"\"\"The YOLO model object.\"\"\" def __getstate__ ( self ) -> dict [ str , Any ]: state = self . __dict__ . copy () del state [ \"model\" ] # we dont want to serialize the model return state def load_model ( self ) -> YOLO : if self . model is None : self . model = YOLO ( self . model_path , task = \"detect\" , verbose = self . verbose ) return self . model class YoloController ( SimController ): def __init__ ( self , timing_config : TimingConfig , yolo_config : YoloConfig ): super () . __init__ ( timing_config ) self . yolo_config = yolo_config self . _camera_frames = deque ( maxlen = timing_config . cycle_frame_num ) self . _model = yolo_config . load_model () def on_sim_start ( self , sim : Simulator ): self . _camera_frames . clear () def on_camera_frame ( self , sim : Simulator ): self . _camera_frames . append ( sim . camera_view ()) def on_cycle_end ( self , sim : Simulator ): self . _camera_frames . clear () def predict ( self , frames : Collection [ np . ndarray ]) -> np . ndarray : assert len ( frames ) > 0 # convert grayscale images to BGR because YOLO expects 3-channel images if frames [ 0 ] . ndim == 2 : frames = [ cv . cvtColor ( frame , cv . COLOR_GRAY2BGR ) for frame in frames ] # predict bounding boxes and format results results = self . _model . predict ( source = frames , device = self . yolo_config . device , max_det = 1 , verbose = self . yolo_config . verbose , ** self . yolo_config . pred_kwargs , ) results = [ res . numpy () for res in results ] bboxes = [] for res in results : if len ( res . boxes . xyxy ) == 0 : bboxes . append ( np . full ([ 4 ], np . nan )) else : bbox = BoxConverter . to_xywh ( res . boxes . xyxy [ 0 ], BoxFormat . XYXY ) bboxes . append ( bbox ) return np . stack ( bboxes , axis = 0 ) def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: frame = self . _camera_frames [ - self . timing_config . pred_frame_num ] bbox = self . predict ([ frame ]) bbox = bbox [ 0 ] if not np . isfinite ( bbox ) . all (): return 0 , 0 bbox_mid = bbox [ 0 ] + bbox [ 2 ] / 2 , bbox [ 1 ] + bbox [ 3 ] / 2 camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 return round ( bbox_mid [ 0 ] - camera_mid [ 0 ]), round ( bbox_mid [ 1 ] - camera_mid [ 1 ]) def _cycle_predict_all ( self , sim : Simulator ) -> np . 
ndarray : return self . predict ( self . _camera_frames ) Classes YoloConfig class YoloConfig ( model_path : str , device : str = 'cpu' , verbose : bool = False , pred_kwargs : dict = < factory > ) YoloConfig(model_path: str, device: str = 'cpu', verbose: bool = False, pred_kwargs: dict = ) View Source @ dataclass class YoloConfig ( ConfigBase ): model_path : str \"\"\"The path to the pretrained YOLO weights file.\"\"\" device : str = \"cpu\" \"\"\"Inference device for YOLO. Can be either 'cpu' or 'cuda'.\"\"\" verbose : bool = False \"\"\"Whether to print verbose output during YOLO inference.\"\"\" pred_kwargs : dict = field ( default_factory = lambda : { \"imgsz\" : 384 , \"conf\" : 0.1 , } ) \"\"\"Additional keyword arguments for the YOLO prediction method.\"\"\" model : YOLO = field ( default = None , init = False , repr = False ) \"\"\"The YOLO model object.\"\"\" def __getstate__ ( self ) -> dict [ str , Any ]: state = self . __dict__ . copy () del state [ \"model\" ] # we dont want to serialize the model return state def load_model ( self ) -> YOLO : if self . model is None : self . model = YOLO ( self . model_path , task = \"detect\" , verbose = self . verbose ) return self . model Ancestors (in MRO) wtracker.utils.config_base.ConfigBase Class variables device model verbose Static methods load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods load_model def load_model ( self ) -> ultralytics . models . yolo . model . YOLO View Source def load_model ( self ) -> YOLO : if self . model is None : self . model = YOLO ( self . model_path , task = \"detect\" , verbose = self . verbose ) return self . model save_json def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . 
__dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) YoloController class YoloController ( timing_config : wtracker . sim . config . TimingConfig , yolo_config : wtracker . sim . sim_controllers . yolo_controller . YoloConfig ) Abstract base class for simulator controllers. Attributes Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class YoloController ( SimController ) : def __init__ ( self , timing_config : TimingConfig , yolo_config : YoloConfig ) : super (). __init__ ( timing_config ) self . yolo_config = yolo_config self . _camera_frames = deque ( maxlen = timing_config . cycle_frame_num ) self . _model = yolo_config . load_model () def on_sim_start ( self , sim : Simulator ) : self . _camera_frames . clear () def on_camera_frame ( self , sim : Simulator ) : self . _camera_frames . append ( sim . camera_view ()) def on_cycle_end ( self , sim : Simulator ) : self . _camera_frames . clear () def predict ( self , frames : Collection [ np.ndarray ] ) -> np . ndarray : assert len ( frames ) > 0 # convert grayscale images to BGR because YOLO expects 3 - channel images if frames [ 0 ] . ndim == 2 : frames = [ cv.cvtColor(frame, cv.COLOR_GRAY2BGR) for frame in frames ] # predict bounding boxes and format results results = self . _model . predict ( source = frames , device = self . yolo_config . device , max_det = 1 , verbose = self . yolo_config . verbose , ** self . yolo_config . pred_kwargs , ) results = [ res.numpy() for res in results ] bboxes = [] for res in results : if len ( res . boxes . xyxy ) == 0 : bboxes . append ( np . full ( [ 4 ] , np . nan )) else : bbox = BoxConverter . to_xywh ( res . boxes . xyxy [ 0 ] , BoxFormat . XYXY ) bboxes . append ( bbox ) return np . stack ( bboxes , axis = 0 ) def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : frame = self . _camera_frames [ -self.timing_config.pred_frame_num ] bbox = self . predict ( [ frame ] ) bbox = bbox [ 0 ] if not np . isfinite ( bbox ). all () : return 0 , 0 bbox_mid = bbox [ 0 ] + bbox [ 2 ] / 2 , bbox [ 1 ] + bbox [ 3 ] / 2 camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 return round ( bbox_mid [ 0 ] - camera_mid [ 0 ] ), round ( bbox_mid [ 1 ] - camera_mid [ 1 ] ) def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : return self . predict ( self . _camera_frames ) Ancestors (in MRO) wtracker.sim.simulator.SimController abc.ABC Methods begin_movement_prediction def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass on_camera_frame def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. 
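A minimal wiring sketch for the controller documented above. The file paths here are hypothetical, and it assumes TimingConfig is a ConfigBase subclass that can be restored with load_json; if it is not, construct it directly instead.

from wtracker.sim.config import TimingConfig
from wtracker.sim.sim_controllers.yolo_controller import YoloConfig, YoloController

# Hypothetical paths -- substitute your own files.
timing_config = TimingConfig.load_json("timing_config.json")
yolo_config = YoloConfig(
    model_path="models/worm_detector.pt",  # hypothetical YOLO weights file
    device="cuda",  # or "cpu" if no GPU is available
    pred_kwargs={"imgsz": 384, "conf": 0.1},
)
controller = YoloController(timing_config, yolo_config)
# Pass `controller` to the Simulator; the simulator then drives the
# on_sim_start / on_camera_frame / on_cycle_end hooks documented below.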
View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_frames . append ( sim . camera_view ()) on_cycle_end def on_cycle_end ( self , sim : wtracker . sim . simulator . Simulator ) Called when a cycle ends. View Source def on_cycle_end ( self , sim : Simulator ) : self . _camera_frames . clear () on_cycle_start def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass on_imaging_end def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass on_imaging_start def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass on_micro_frame def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass on_movement_end def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass on_movement_start def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass on_sim_end def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass on_sim_start def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_frames . clear () predict def predict ( self , frames : Collection [ numpy . ndarray ] ) -> numpy . ndarray View Source def predict(self, frames: Collection[np.ndarray]) -> np.ndarray: assert len(frames) > 0 # convert grayscale images to BGR because YOLO expects 3-channel images if frames[0].ndim == 2: frames = [cv.cvtColor(frame, cv.COLOR_GRAY2BGR) for frame in frames] # predict bounding boxes and format results results = self._model.predict( source=frames, device=self.yolo_config.device, max_det=1, verbose=self.yolo_config.verbose, **self.yolo_config.pred_kwargs, ) results = [res.numpy() for res in results] bboxes = [] for res in results: if len(res.boxes.xyxy) == 0: bboxes.append(np.full([4], np.nan)) else: bbox = BoxConverter.to_xywh(res.boxes.xyxy[0], BoxFormat.XYXY) bboxes.append(bbox) return np.stack(bboxes, axis=0) provide_movement_vector def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : frame = self . _camera_frames [ -self.timing_config.pred_frame_num ] bbox = self . predict ( [ frame ] ) bbox = bbox [ 0 ] if not np . isfinite ( bbox ). 
all () : return 0 , 0 bbox_mid = bbox [ 0 ] + bbox [ 2 ] / 2 , bbox [ 1 ] + bbox [ 3 ] / 2 camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 return round ( bbox_mid [ 0 ] - camera_mid [ 0 ] ), round ( bbox_mid [ 1 ] - camera_mid [ 1 ] )","title":"Yolo Controller"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#module-wtrackersimsim_controllersyolo_controller","text":"View Source from typing import Collection , Any from dataclasses import dataclass , field import numpy as np import cv2 as cv from collections import deque from ultralytics import YOLO from wtracker.sim.simulator import Simulator , SimController from wtracker.sim.config import TimingConfig from wtracker.utils.config_base import ConfigBase from wtracker.utils.bbox_utils import BoxUtils , BoxConverter , BoxFormat @dataclass class YoloConfig ( ConfigBase ): model_path : str \"\"\"The path to the pretrained YOLO weights file.\"\"\" device : str = \"cpu\" \"\"\"Inference device for YOLO. Can be either 'cpu' or 'cuda'.\"\"\" verbose : bool = False \"\"\"Whether to print verbose output during YOLO inference.\"\"\" pred_kwargs : dict = field ( default_factory = lambda : { \"imgsz\" : 384 , \"conf\" : 0.1 , } ) \"\"\"Additional keyword arguments for the YOLO prediction method.\"\"\" model : YOLO = field ( default = None , init = False , repr = False ) \"\"\"The YOLO model object.\"\"\" def __getstate__ ( self ) -> dict [ str , Any ]: state = self . __dict__ . copy () del state [ \"model\" ] # we dont want to serialize the model return state def load_model ( self ) -> YOLO : if self . model is None : self . model = YOLO ( self . model_path , task = \"detect\" , verbose = self . verbose ) return self . model class YoloController ( SimController ): def __init__ ( self , timing_config : TimingConfig , yolo_config : YoloConfig ): super () . __init__ ( timing_config ) self . yolo_config = yolo_config self . _camera_frames = deque ( maxlen = timing_config . cycle_frame_num ) self . _model = yolo_config . load_model () def on_sim_start ( self , sim : Simulator ): self . _camera_frames . clear () def on_camera_frame ( self , sim : Simulator ): self . _camera_frames . append ( sim . camera_view ()) def on_cycle_end ( self , sim : Simulator ): self . _camera_frames . clear () def predict ( self , frames : Collection [ np . ndarray ]) -> np . ndarray : assert len ( frames ) > 0 # convert grayscale images to BGR because YOLO expects 3-channel images if frames [ 0 ] . ndim == 2 : frames = [ cv . cvtColor ( frame , cv . COLOR_GRAY2BGR ) for frame in frames ] # predict bounding boxes and format results results = self . _model . predict ( source = frames , device = self . yolo_config . device , max_det = 1 , verbose = self . yolo_config . verbose , ** self . yolo_config . pred_kwargs , ) results = [ res . numpy () for res in results ] bboxes = [] for res in results : if len ( res . boxes . xyxy ) == 0 : bboxes . append ( np . full ([ 4 ], np . nan )) else : bbox = BoxConverter . to_xywh ( res . boxes . xyxy [ 0 ], BoxFormat . XYXY ) bboxes . append ( bbox ) return np . stack ( bboxes , axis = 0 ) def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int , int ]: frame = self . _camera_frames [ - self . timing_config . pred_frame_num ] bbox = self . predict ([ frame ]) bbox = bbox [ 0 ] if not np . isfinite ( bbox ) . 
all (): return 0 , 0 bbox_mid = bbox [ 0 ] + bbox [ 2 ] / 2 , bbox [ 1 ] + bbox [ 3 ] / 2 camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 return round ( bbox_mid [ 0 ] - camera_mid [ 0 ]), round ( bbox_mid [ 1 ] - camera_mid [ 1 ]) def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : return self . predict ( self . _camera_frames )","title":"Module wtracker.sim.sim_controllers.yolo_controller"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#yoloconfig","text":"class YoloConfig ( model_path : str , device : str = 'cpu' , verbose : bool = False , pred_kwargs : dict = < factory > ) YoloConfig(model_path: str, device: str = 'cpu', verbose: bool = False, pred_kwargs: dict = ) View Source @ dataclass class YoloConfig ( ConfigBase ): model_path : str \"\"\"The path to the pretrained YOLO weights file.\"\"\" device : str = \"cpu\" \"\"\"Inference device for YOLO. Can be either 'cpu' or 'cuda'.\"\"\" verbose : bool = False \"\"\"Whether to print verbose output during YOLO inference.\"\"\" pred_kwargs : dict = field ( default_factory = lambda : { \"imgsz\" : 384 , \"conf\" : 0.1 , } ) \"\"\"Additional keyword arguments for the YOLO prediction method.\"\"\" model : YOLO = field ( default = None , init = False , repr = False ) \"\"\"The YOLO model object.\"\"\" def __getstate__ ( self ) -> dict [ str , Any ]: state = self . __dict__ . copy () del state [ \"model\" ] # we dont want to serialize the model return state def load_model ( self ) -> YOLO : if self . model is None : self . model = YOLO ( self . model_path , task = \"detect\" , verbose = self . verbose ) return self . model","title":"YoloConfig"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#ancestors-in-mro","text":"wtracker.utils.config_base.ConfigBase","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#class-variables","text":"device model verbose","title":"Class variables"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#load_json","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#load_pickle","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. 
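As a concrete illustration of this serialization API, a sketch of a JSON round trip (the file and model paths are hypothetical):

from wtracker.sim.sim_controllers.yolo_controller import YoloConfig

cfg = YoloConfig(model_path="models/worm_detector.pt")  # hypothetical path
cfg.save_json("yolo_config.json")  # dumps the dataclass fields as JSON

restored = YoloConfig.load_json("yolo_config.json")
assert restored.model_path == cfg.model_path
# Note: load_json bypasses __init__ and fills __dict__ directly, and
# __getstate__ drops the `model` field for pickling, so the YOLO model
# object itself is never written to disk.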
Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#load_model","text":"def load_model ( self ) -> ultralytics . models . yolo . model . YOLO View Source def load_model ( self ) -> YOLO : if self . model is None : self . model = YOLO ( self . model_path , task = \"detect\" , verbose = self . verbose ) return self . model","title":"load_model"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#save_json","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#save_pickle","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#yolocontroller","text":"class YoloController ( timing_config : wtracker . sim . config . TimingConfig , yolo_config : wtracker . sim . sim_controllers . yolo_controller . YoloConfig ) Abstract base class for simulator controllers.","title":"YoloController"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#attributes","text":"Name Type Description Default timing_config TimingConfig The timing configuration for the simulator. None View Source class YoloController ( SimController ) : def __init__ ( self , timing_config : TimingConfig , yolo_config : YoloConfig ) : super (). __init__ ( timing_config ) self . yolo_config = yolo_config self . _camera_frames = deque ( maxlen = timing_config . cycle_frame_num ) self . _model = yolo_config . load_model () def on_sim_start ( self , sim : Simulator ) : self . _camera_frames . clear () def on_camera_frame ( self , sim : Simulator ) : self . _camera_frames . append ( sim . camera_view ()) def on_cycle_end ( self , sim : Simulator ) : self . _camera_frames . clear () def predict ( self , frames : Collection [ np.ndarray ] ) -> np . ndarray : assert len ( frames ) > 0 # convert grayscale images to BGR because YOLO expects 3 - channel images if frames [ 0 ] . 
ndim == 2 : frames = [ cv.cvtColor(frame, cv.COLOR_GRAY2BGR) for frame in frames ] # predict bounding boxes and format results results = self . _model . predict ( source = frames , device = self . yolo_config . device , max_det = 1 , verbose = self . yolo_config . verbose , ** self . yolo_config . pred_kwargs , ) results = [ res.numpy() for res in results ] bboxes = [] for res in results : if len ( res . boxes . xyxy ) == 0 : bboxes . append ( np . full ( [ 4 ] , np . nan )) else : bbox = BoxConverter . to_xywh ( res . boxes . xyxy [ 0 ] , BoxFormat . XYXY ) bboxes . append ( bbox ) return np . stack ( bboxes , axis = 0 ) def begin_movement_prediction ( self , sim : Simulator ) -> None : pass def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : frame = self . _camera_frames [ -self.timing_config.pred_frame_num ] bbox = self . predict ( [ frame ] ) bbox = bbox [ 0 ] if not np . isfinite ( bbox ). all () : return 0 , 0 bbox_mid = bbox [ 0 ] + bbox [ 2 ] / 2 , bbox [ 1 ] + bbox [ 3 ] / 2 camera_mid = sim . view . camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 return round ( bbox_mid [ 0 ] - camera_mid [ 0 ] ), round ( bbox_mid [ 1 ] - camera_mid [ 1 ] ) def _cycle_predict_all ( self , sim : Simulator ) -> np . ndarray : return self . predict ( self . _camera_frames )","title":"Attributes"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#ancestors-in-mro_1","text":"wtracker.sim.simulator.SimController abc.ABC","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#begin_movement_prediction","text":"def begin_movement_prediction ( self , sim : wtracker . sim . simulator . Simulator ) -> None Called when the movement prediction begins. View Source def begin_movement_prediction ( self , sim : Simulator ) -> None : pass","title":"begin_movement_prediction"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_camera_frame","text":"def on_camera_frame ( self , sim : wtracker . sim . simulator . Simulator ) Called when a camera frame is captured. Happens every frame. View Source def on_camera_frame ( self , sim : Simulator ) : self . _camera_frames . append ( sim . camera_view ())","title":"on_camera_frame"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_cycle_end","text":"def on_cycle_end ( self , sim : wtracker . sim . simulator . Simulator ) Called when a cycle ends. View Source def on_cycle_end ( self , sim : Simulator ) : self . _camera_frames . clear ()","title":"on_cycle_end"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_cycle_start","text":"def on_cycle_start ( self , sim : 'Simulator' ) Called when a new cycle starts. View Source def on_cycle_start(self, sim: Simulator): \"\"\" Called when a new cycle starts. \"\"\" pass","title":"on_cycle_start"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_imaging_end","text":"def on_imaging_end ( self , sim : 'Simulator' ) Called when imaging phase ends. View Source def on_imaging_end(self, sim: Simulator): \"\"\" Called when imaging phase ends. \"\"\" pass","title":"on_imaging_end"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_imaging_start","text":"def on_imaging_start ( self , sim : 'Simulator' ) Called when imaging phase starts. 
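Taken together, the hooks documented in this section define the controller lifecycle; note that on_micro_frame fires for every frame during the imaging phase. Below is a minimal sketch of a custom controller, assuming, as YoloController does, that provide_movement_vector, begin_movement_prediction, and _cycle_predict_all are the members a concrete subclass must supply; the movement logic is a placeholder.

import numpy as np
from wtracker.sim.config import TimingConfig
from wtracker.sim.simulator import Simulator, SimController

class HoldStillController(SimController):
    """Toy controller that never moves the platform."""

    def __init__(self, timing_config: TimingConfig):
        super().__init__(timing_config)
        self._frames_seen = 0

    def on_camera_frame(self, sim: Simulator):
        self._frames_seen += 1  # invoked once per camera frame

    def begin_movement_prediction(self, sim: Simulator) -> None:
        pass  # a real controller would start its prediction work here

    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        return 0, 0  # (dx, dy) in pixels; (0, 0) keeps the platform still

    def _cycle_predict_all(self, sim: Simulator) -> np.ndarray:
        # one bbox per frame of the cycle; NaN rows mean "no detection"
        return np.full((self.timing_config.cycle_frame_num, 4), np.nan)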
View Source def on_imaging_start(self, sim: Simulator): \"\"\" Called when imaging phase starts. \"\"\" pass","title":"on_imaging_start"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_micro_frame","text":"def on_micro_frame ( self , sim : 'Simulator' ) Called when a micro frame is captured. Happens for every during the imaging phase. View Source def on_micro_frame(self, sim: Simulator): \"\"\" Called when a micro frame is captured. Happens for every during the imaging phase. \"\"\" pass","title":"on_micro_frame"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_movement_end","text":"def on_movement_end ( self , sim : 'Simulator' ) Called when movement phase ends. View Source def on_movement_end(self, sim: Simulator): \"\"\" Called when movement phase ends. \"\"\" pass","title":"on_movement_end"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_movement_start","text":"def on_movement_start ( self , sim : 'Simulator' ) Called when movement phase starts. View Source def on_movement_start(self, sim: Simulator): \"\"\" Called when movement phase starts. \"\"\" pass","title":"on_movement_start"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_sim_end","text":"def on_sim_end ( self , sim : 'Simulator' ) Called when the simulation ends. View Source def on_sim_end(self, sim: Simulator): \"\"\" Called when the simulation ends. \"\"\" pass","title":"on_sim_end"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#on_sim_start","text":"def on_sim_start ( self , sim : wtracker . sim . simulator . Simulator ) Called when the simulation starts. View Source def on_sim_start ( self , sim : Simulator ) : self . _camera_frames . clear ()","title":"on_sim_start"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#predict","text":"def predict ( self , frames : Collection [ numpy . ndarray ] ) -> numpy . ndarray View Source def predict(self, frames: Collection[np.ndarray]) -> np.ndarray: assert len(frames) > 0 # convert grayscale images to BGR because YOLO expects 3-channel images if frames[0].ndim == 2: frames = [cv.cvtColor(frame, cv.COLOR_GRAY2BGR) for frame in frames] # predict bounding boxes and format results results = self._model.predict( source=frames, device=self.yolo_config.device, max_det=1, verbose=self.yolo_config.verbose, **self.yolo_config.pred_kwargs, ) results = [res.numpy() for res in results] bboxes = [] for res in results: if len(res.boxes.xyxy) == 0: bboxes.append(np.full([4], np.nan)) else: bbox = BoxConverter.to_xywh(res.boxes.xyxy[0], BoxFormat.XYXY) bboxes.append(bbox) return np.stack(bboxes, axis=0)","title":"predict"},{"location":"reference/wtracker/sim/sim_controllers/yolo_controller/#provide_movement_vector","text":"def provide_movement_vector ( self , sim : wtracker . sim . simulator . Simulator ) -> tuple [ int , int ] Provides the movement vector for the simulator. The platform is moved by the provided vector. Returns: Type Description tuple[int, int] The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. View Source def provide_movement_vector ( self , sim : Simulator ) -> tuple [ int, int ] : frame = self . _camera_frames [ -self.timing_config.pred_frame_num ] bbox = self . predict ( [ frame ] ) bbox = bbox [ 0 ] if not np . isfinite ( bbox ). all () : return 0 , 0 bbox_mid = bbox [ 0 ] + bbox [ 2 ] / 2 , bbox [ 1 ] + bbox [ 3 ] / 2 camera_mid = sim . view . 
camera_size [ 0 ] / 2 , sim . view . camera_size [ 1 ] / 2 return round ( bbox_mid [ 0 ] - camera_mid [ 0 ] ), round ( bbox_mid [ 1 ] - camera_mid [ 1 ] )","title":"provide_movement_vector"},{"location":"reference/wtracker/utils/","text":"Namespace wtracker.utils Sub-modules wtracker.utils.bbox_utils wtracker.utils.config_base wtracker.utils.frame_reader wtracker.utils.gui_utils wtracker.utils.io_utils wtracker.utils.log_utils wtracker.utils.path_utils wtracker.utils.threading_utils","title":"Index"},{"location":"reference/wtracker/utils/#namespace-wtrackerutils","text":"","title":"Namespace wtracker.utils"},{"location":"reference/wtracker/utils/#sub-modules","text":"wtracker.utils.bbox_utils wtracker.utils.config_base wtracker.utils.frame_reader wtracker.utils.gui_utils wtracker.utils.io_utils wtracker.utils.log_utils wtracker.utils.path_utils wtracker.utils.threading_utils","title":"Sub-modules"},{"location":"reference/wtracker/utils/bbox_utils/","text":"Module wtracker.utils.bbox_utils View Source import numpy as np from enum import Enum class BoxFormat ( Enum ): \"\"\" Enumeration representing different box formats. Attributes: XYWH (int): Represents the box format as (x, y, width, height). XYXY (int): Represents the box format as (x1, y1, x2, y2). YOLO (int): Represents the box format as (center_x, center_y, width, height). \"\"\" XYWH = 0 XYXY = 1 YOLO = 2 class BoxUtils : \"\"\" A utility class for working with bounding boxes. \"\"\" @staticmethod def is_bbox ( array : np . ndarray ) -> bool : \"\"\" Check if the given array is a valid bounding box. Args: array (np.ndarray): The array to check. Returns: bool: True if the array is a valid bounding box, False otherwise. \"\"\" return array . shape [ - 1 ] == 4 @staticmethod def unpack ( bbox : np . ndarray ) -> tuple [ np . ndarray , np . ndarray , np . ndarray , np . ndarray ]: \"\"\" Unpack the given bounding box into its individual components. Args: bbox (np.ndarray): The bounding box to unpack. Returns: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: The unpacked components of the bounding box. \"\"\" c1 , c2 , c3 , c4 = np . split ( bbox , bbox . shape [ - 1 ], axis =- 1 ) c1 = np . squeeze ( c1 , axis =- 1 ) c2 = np . squeeze ( c2 , axis =- 1 ) c3 = np . squeeze ( c3 , axis =- 1 ) c4 = np . squeeze ( c4 , axis =- 1 ) return c1 , c2 , c3 , c4 @staticmethod def pack ( c1 : np . ndarray , c2 : np . ndarray , c3 : np . ndarray , c4 : np . ndarray ) -> np . ndarray : \"\"\" Pack the given components into a single bounding box. Args: c1 (np.ndarray): The first component of the bounding box. c2 (np.ndarray): The second component of the bounding box. c3 (np.ndarray): The third component of the bounding box. c4 (np.ndarray): The fourth component of the bounding box. Returns: np.ndarray: The packed bounding box. \"\"\" c1 = np . expand_dims ( c1 , axis =- 1 ) c2 = np . expand_dims ( c2 , axis =- 1 ) c3 = np . expand_dims ( c3 , axis =- 1 ) c4 = np . expand_dims ( c4 , axis =- 1 ) return np . concatenate (( c1 , c2 , c3 , c4 ), axis =- 1 ) @staticmethod def center ( bboxes : np . ndarray , box_format : BoxFormat = BoxFormat . XYWH ) -> np . ndarray : \"\"\" Calculate the center of the bounding boxes. Args: bboxes (np.ndarray): The input bounding boxes. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The center of the bounding boxes, in the format (center_x, center_y). \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYWH ) x , y , w , h = BoxUtils . 
unpack ( bboxes ) center_x = x + w / 2 center_y = y + h / 2 return np . array ([ center_x , center_y ]) . T @staticmethod def round ( bboxes : np . ndarray , box_format : BoxFormat ) -> np . ndarray : \"\"\" Rounds the bounding box coordinates to integers. Args: bboxes (np.ndarray): The bounding box coordinates to convert. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The bounding box coordinates as integers. \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) x1 = np . floor ( x1 ) . astype ( np . int32 , copy = False ) y1 = np . floor ( y1 ) . astype ( np . int32 , copy = False ) x2 = np . ceil ( x2 ) . astype ( np . int32 , copy = False ) y2 = np . ceil ( y2 ) . astype ( np . int32 , copy = False ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) return BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) @staticmethod def discretize ( bboxes : np . ndarray , bounds : tuple [ int , int ], box_format : BoxFormat , ) -> tuple [ np . ndarray , np . ndarray ]: \"\"\" Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Args: bboxes (np.ndarray): The bounding box coordinates to convert. bounds (tuple[int, int]): The bounds to clamp the bounding boxes to, in the format (h, w). box_format (BoxFormat): The format of the input bounding boxes. Returns: tuple[np.ndarray, np.ndarray]: The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. The second element is a boolean mask indicating which input bounding boxes are legal. \"\"\" # zero out all non-finite bounding boxes is_legal = np . isfinite ( bboxes ) . all ( axis = 1 ) bboxes [ ~ is_legal ] = 0 bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) bboxes = BoxUtils . round ( bboxes , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) # clip worm bounding boxes to the size H , W = bounds x1 = np . clip ( x1 , a_min = 0 , a_max = W ) y1 = np . clip ( y1 , a_min = 0 , a_max = H ) x2 = np . clip ( x2 , a_min = 0 , a_max = W ) y2 = np . clip ( y2 , a_min = 0 , a_max = H ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) bboxes = BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) # zero out all bounding boxes with 0 dimension w = x2 - x1 h = y2 - y1 is_legal = ( w > 0.0 ) & ( h > 0.0 ) # zero out all illegal bounding boxes and make sure return types are correct bboxes [ ~ is_legal ] = 0 bboxes = bboxes . astype ( np . int32 , copy = False ) is_legal = is_legal . astype ( bool , copy = False ) return bboxes , is_legal class BoxConverter : \"\"\" Utility class for converting bounding box coordinates between different formats. \"\"\" @staticmethod def change_format ( bbox : np . ndarray , src_format : BoxFormat , dst_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates from one format to another. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. dst_format (BoxFormat): The destination format of the bounding box coordinates. Returns: np.ndarray: The converted bounding box coordinates. 
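A usage sketch for the conversions described here. One caveat worth flagging: in the View Source listings of change_format in this reference, the BoxFormat.YOLO destination branch returns to_xywh(bbox, src_format); to_yolo appears to be the intended dispatch, so the sketch calls to_yolo directly.

import numpy as np
from wtracker.utils.bbox_utils import BoxConverter, BoxFormat

# One box with x=10, y=20, width=30, height=40 (XYWH layout).
xywh = np.array([[10.0, 20.0, 30.0, 40.0]])

xyxy = BoxConverter.change_format(xywh, BoxFormat.XYWH, BoxFormat.XYXY)
print(xyxy)  # [[10. 20. 40. 60.]] -> (x1, y1, x2, y2)

yolo = BoxConverter.to_yolo(xywh, BoxFormat.XYWH)
print(yolo)  # [[25. 40. 30. 40.]] -> (center_x, center_y, w, h)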
Raises: Exception: If the conversion between the specified formats is not supported. \"\"\" if dst_format == BoxFormat . XYXY : return BoxConverter . to_xyxy ( bbox , src_format ) elif dst_format == BoxFormat . XYWH : return BoxConverter . to_xywh ( bbox , src_format ) elif dst_format == BoxFormat . YOLO : return BoxConverter . to_xywh ( bbox , src_format ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xyxy ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYXY format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYXY format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYXY : return bbox elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xywh ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYWH format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYWH format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYWH : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 return BoxUtils . pack ( x1 , y1 , w , h ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 return BoxUtils . pack ( x1 , y1 , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_yolo ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the YOLO format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the YOLO format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . YOLO : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) Classes BoxConverter class BoxConverter ( / , * args , ** kwargs ) Utility class for converting bounding box coordinates between different formats. View Source class BoxConverter : \"\"\" Utility class for converting bounding box coordinates between different formats. \"\"\" @staticmethod def change_format ( bbox : np . 
ndarray , src_format : BoxFormat , dst_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates from one format to another. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. dst_format (BoxFormat): The destination format of the bounding box coordinates. Returns: np.ndarray: The converted bounding box coordinates. Raises: Exception: If the conversion between the specified formats is not supported. \"\"\" if dst_format == BoxFormat . XYXY : return BoxConverter . to_xyxy ( bbox , src_format ) elif dst_format == BoxFormat . XYWH : return BoxConverter . to_xywh ( bbox , src_format ) elif dst_format == BoxFormat . YOLO : return BoxConverter . to_xywh ( bbox , src_format ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xyxy ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYXY format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYXY format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYXY : return bbox elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xywh ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYWH format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYWH format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYWH : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 return BoxUtils . pack ( x1 , y1 , w , h ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 return BoxUtils . pack ( x1 , y1 , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_yolo ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the YOLO format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the YOLO format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . YOLO : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . 
pack ( xm , ym , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) Static methods change_format def change_format ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat , dst_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates from one format to another. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. None src_format BoxFormat The source format of the bounding box coordinates. None dst_format BoxFormat The destination format of the bounding box coordinates. None Returns: Type Description np.ndarray The converted bounding box coordinates. Raises: Type Description Exception If the conversion between the specified formats is not supported. View Source @staticmethod def change_format ( bbox : np . ndarray , src_format : BoxFormat , dst_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates from one format to another. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. dst_format (BoxFormat): The destination format of the bounding box coordinates. Returns: np.ndarray: The converted bounding box coordinates. Raises: Exception: If the conversion between the specified formats is not supported. \"\"\" if dst_format == BoxFormat . XYXY : return BoxConverter . to_xyxy ( bbox , src_format ) elif dst_format == BoxFormat . XYWH : return BoxConverter . to_xywh ( bbox , src_format ) elif dst_format == BoxFormat . YOLO : return BoxConverter . to_xywh ( bbox , src_format ) else : raise Exception ( \"unsupported bbox format conversion.\" ) to_xywh def to_xywh ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates to the XYWH format. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. None src_format BoxFormat The source format of the bounding box coordinates. None Returns: Type Description np.ndarray The bounding box coordinates in the XYWH format. Raises: Type Description Exception If the conversion from the specified source format is not supported. View Source @staticmethod def to_xywh ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYWH format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYWH format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYWH : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 return BoxUtils . pack ( x1 , y1 , w , h ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 return BoxUtils . pack ( x1 , y1 , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) to_xyxy def to_xyxy ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates to the XYXY format. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. 
None src_format BoxFormat The source format of the bounding box coordinates. None Returns: Type Description np.ndarray The bounding box coordinates in the XYXY format. Raises: Type Description Exception If the conversion from the specified source format is not supported. View Source @staticmethod def to_xyxy ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYXY format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYXY format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYXY : return bbox elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) else : raise Exception ( \"unsupported bbox format conversion.\" ) to_yolo def to_yolo ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates to the YOLO format. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. None src_format BoxFormat The source format of the bounding box coordinates. None Returns: Type Description np.ndarray The bounding box coordinates in the YOLO format. Raises: Type Description Exception If the conversion from the specified source format is not supported. View Source @staticmethod def to_yolo ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the YOLO format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the YOLO format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . YOLO : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) BoxFormat class BoxFormat ( / , * args , ** kwargs ) Enumeration representing different box formats. Attributes Name Type Description Default XYWH int Represents the box format as (x, y, width, height). None XYXY int Represents the box format as (x1, y1, x2, y2). None YOLO int Represents the box format as (center_x, center_y, width, height). None View Source class BoxFormat ( Enum ): \"\"\" Enumeration representing different box formats. Attributes: XYWH (int): Represents the box format as (x, y, width, height). XYXY (int): Represents the box format as (x1, y1, x2, y2). YOLO (int): Represents the box format as (center_x, center_y, width, height). 
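Ahead of the BoxUtils reference below, a short sketch of BoxUtils.discretize, which clamps boxes to image bounds and reports which inputs were legal. The printed values follow from the listed behavior; note that illegal inputs are zeroed in place.

import numpy as np
from wtracker.utils.bbox_utils import BoxFormat, BoxUtils

# Two XYXY boxes on a 100x100 image (bounds are given as (h, w)).
bboxes = np.array([
    [-5.0, 10.0, 50.0, 120.0],  # sticks out of the image; gets clipped
    [np.nan, 0.0, 10.0, 10.0],  # non-finite; zeroed out and marked illegal
])
clipped, is_legal = BoxUtils.discretize(bboxes, bounds=(100, 100), box_format=BoxFormat.XYXY)
print(clipped)   # [[  0  10  50 100] [  0   0   0   0]]  (np.int32)
print(is_legal)  # [ True False]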
\"\"\" XYWH = 0 XYXY = 1 YOLO = 2 Ancestors (in MRO) enum.Enum Class variables XYWH XYXY YOLO name value BoxUtils class BoxUtils ( / , * args , ** kwargs ) A utility class for working with bounding boxes. View Source class BoxUtils : \"\"\" A utility class for working with bounding boxes. \"\"\" @ staticmethod def is_bbox ( array : np . ndarray ) -> bool : \"\"\" Check if the given array is a valid bounding box. Args: array (np.ndarray): The array to check. Returns: bool: True if the array is a valid bounding box, False otherwise. \"\"\" return array . shape [ - 1 ] == 4 @ staticmethod def unpack ( bbox : np . ndarray ) -> tuple [ np . ndarray , np . ndarray , np . ndarray , np . ndarray ]: \"\"\" Unpack the given bounding box into its individual components. Args: bbox (np.ndarray): The bounding box to unpack. Returns: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: The unpacked components of the bounding box. \"\"\" c1 , c2 , c3 , c4 = np . split ( bbox , bbox . shape [ - 1 ], axis = - 1 ) c1 = np . squeeze ( c1 , axis = - 1 ) c2 = np . squeeze ( c2 , axis = - 1 ) c3 = np . squeeze ( c3 , axis = - 1 ) c4 = np . squeeze ( c4 , axis = - 1 ) return c1 , c2 , c3 , c4 @ staticmethod def pack ( c1 : np . ndarray , c2 : np . ndarray , c3 : np . ndarray , c4 : np . ndarray ) -> np . ndarray : \"\"\" Pack the given components into a single bounding box. Args: c1 (np.ndarray): The first component of the bounding box. c2 (np.ndarray): The second component of the bounding box. c3 (np.ndarray): The third component of the bounding box. c4 (np.ndarray): The fourth component of the bounding box. Returns: np.ndarray: The packed bounding box. \"\"\" c1 = np . expand_dims ( c1 , axis = - 1 ) c2 = np . expand_dims ( c2 , axis = - 1 ) c3 = np . expand_dims ( c3 , axis = - 1 ) c4 = np . expand_dims ( c4 , axis = - 1 ) return np . concatenate (( c1 , c2 , c3 , c4 ), axis = - 1 ) @ staticmethod def center ( bboxes : np . ndarray , box_format : BoxFormat = BoxFormat . XYWH ) -> np . ndarray : \"\"\" Calculate the center of the bounding boxes. Args: bboxes (np.ndarray): The input bounding boxes. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The center of the bounding boxes, in the format (center_x, center_y). \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYWH ) x , y , w , h = BoxUtils . unpack ( bboxes ) center_x = x + w / 2 center_y = y + h / 2 return np . array ([ center_x , center_y ]). T @ staticmethod def round ( bboxes : np . ndarray , box_format : BoxFormat ) -> np . ndarray : \"\"\" Rounds the bounding box coordinates to integers. Args: bboxes (np.ndarray): The bounding box coordinates to convert. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The bounding box coordinates as integers. \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) x1 = np . floor ( x1 ). astype ( np . int32 , copy = False ) y1 = np . floor ( y1 ). astype ( np . int32 , copy = False ) x2 = np . ceil ( x2 ). astype ( np . int32 , copy = False ) y2 = np . ceil ( y2 ). astype ( np . int32 , copy = False ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) return BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) @ staticmethod def discretize ( bboxes : np . ndarray , bounds : tuple [ int , int ], box_format : BoxFormat , ) -> tuple [ np . ndarray , np . 
ndarray ]: \"\"\" Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Args: bboxes (np.ndarray): The bounding box coordinates to convert. bounds (tuple[int, int]): The bounds to clamp the bounding boxes to, in the format (h, w). box_format (BoxFormat): The format of the input bounding boxes. Returns: tuple[np.ndarray, np.ndarray]: The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. The second element is a boolean mask indicating which input bounding boxes are legal. \"\"\" # zero out all non - finite bounding boxes is_legal = np . isfinite ( bboxes ). all ( axis = 1 ) bboxes [ ~ is_legal ] = 0 bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) bboxes = BoxUtils . round ( bboxes , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) # clip worm bounding boxes to the size H , W = bounds x1 = np . clip ( x1 , a_min = 0 , a_max = W ) y1 = np . clip ( y1 , a_min = 0 , a_max = H ) x2 = np . clip ( x2 , a_min = 0 , a_max = W ) y2 = np . clip ( y2 , a_min = 0 , a_max = H ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) bboxes = BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) # zero out all bounding boxes with 0 dimension w = x2 - x1 h = y2 - y1 is_legal = ( w > 0.0 ) & ( h > 0.0 ) # zero out all illegal bounding boxes and make sure return types are correct bboxes [ ~ is_legal ] = 0 bboxes = bboxes . astype ( np . int32 , copy = False ) is_legal = is_legal . astype ( bool , copy = False ) return bboxes , is_legal Static methods center def center ( bboxes : numpy . ndarray , box_format : wtracker . utils . bbox_utils . BoxFormat = < BoxFormat . XYWH : 0 > ) -> numpy . ndarray Calculate the center of the bounding boxes. Parameters: Name Type Description Default bboxes np.ndarray The input bounding boxes. None box_format BoxFormat The format of the input bounding boxes. None Returns: Type Description np.ndarray The center of the bounding boxes, in the format (center_x, center_y). View Source @staticmethod def center ( bboxes : np . ndarray , box_format : BoxFormat = BoxFormat . XYWH ) -> np . ndarray : \"\"\" Calculate the center of the bounding boxes. Args: bboxes (np.ndarray): The input bounding boxes. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The center of the bounding boxes, in the format (center_x, center_y). \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYWH ) x , y , w , h = BoxUtils . unpack ( bboxes ) center_x = x + w / 2 center_y = y + h / 2 return np . array ( [ center_x, center_y ] ). T discretize def discretize ( bboxes : numpy . ndarray , bounds : tuple [ int , int ], box_format : wtracker . utils . bbox_utils . BoxFormat ) -> tuple [ numpy . ndarray , numpy . ndarray ] Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Parameters: Name Type Description Default bboxes np.ndarray The bounding box coordinates to convert. None bounds tuple[int, int] The bounds to clamp the bounding boxes to, in the format (h, w). 
None box_format BoxFormat The format of the input bounding boxes. None Returns: Type Description tuple[np.ndarray, np.ndarray] The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. The second element is a boolean mask indicating which input bounding boxes are legal. View Source @ staticmethod def discretize ( bboxes : np . ndarray , bounds : tuple [ int , int ], box_format : BoxFormat , ) -> tuple [ np . ndarray , np . ndarray ]: \"\"\" Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Args: bboxes (np.ndarray): The bounding box coordinates to convert. bounds (tuple[int, int]): The bounds to clamp the bounding boxes to, in the format (h, w). box_format (BoxFormat): The format of the input bounding boxes. Returns: tuple[np.ndarray, np.ndarray]: The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. The second element is a boolean mask indicating which input bounding boxes are legal. \"\"\" # zero out all non - finite bounding boxes is_legal = np . isfinite ( bboxes ). all ( axis = 1 ) bboxes [ ~ is_legal ] = 0 bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) bboxes = BoxUtils . round ( bboxes , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) # clip worm bounding boxes to the size H , W = bounds x1 = np . clip ( x1 , a_min = 0 , a_max = W ) y1 = np . clip ( y1 , a_min = 0 , a_max = H ) x2 = np . clip ( x2 , a_min = 0 , a_max = W ) y2 = np . clip ( y2 , a_min = 0 , a_max = H ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) bboxes = BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) # zero out all bounding boxes with 0 dimension w = x2 - x1 h = y2 - y1 is_legal = ( w > 0.0 ) & ( h > 0.0 ) # zero out all illegal bounding boxes and make sure return types are correct bboxes [ ~ is_legal ] = 0 bboxes = bboxes . astype ( np . int32 , copy = False ) is_legal = is_legal . astype ( bool , copy = False ) return bboxes , is_legal is_bbox def is_bbox ( array : numpy . ndarray ) -> bool Check if the given array is a valid bounding box. Parameters: Name Type Description Default array np.ndarray The array to check. None Returns: Type Description bool True if the array is a valid bounding box, False otherwise. View Source @staticmethod def is_bbox ( array : np . ndarray ) -> bool : \"\"\" Check if the given array is a valid bounding box. Args: array (np.ndarray): The array to check. Returns: bool: True if the array is a valid bounding box, False otherwise. \"\"\" return array . shape [ -1 ] == 4 pack def pack ( c1 : numpy . ndarray , c2 : numpy . ndarray , c3 : numpy . ndarray , c4 : numpy . ndarray ) -> numpy . ndarray Pack the given components into a single bounding box. Parameters: Name Type Description Default c1 np.ndarray The first component of the bounding box. None c2 np.ndarray The second component of the bounding box. None c3 np.ndarray The third component of the bounding box. None c4 np.ndarray The fourth component of the bounding box. None Returns: Type Description np.ndarray The packed bounding box. View Source @staticmethod def pack ( c1 : np . 
ndarray , c2 : np . ndarray , c3 : np . ndarray , c4 : np . ndarray ) -> np . ndarray : \"\"\" Pack the given components into a single bounding box. Args: c1 (np.ndarray): The first component of the bounding box. c2 (np.ndarray): The second component of the bounding box. c3 (np.ndarray): The third component of the bounding box. c4 (np.ndarray): The fourth component of the bounding box. Returns: np.ndarray: The packed bounding box. \"\"\" c1 = np . expand_dims ( c1 , axis =- 1 ) c2 = np . expand_dims ( c2 , axis =- 1 ) c3 = np . expand_dims ( c3 , axis =- 1 ) c4 = np . expand_dims ( c4 , axis =- 1 ) return np . concatenate (( c1 , c2 , c3 , c4 ), axis =- 1 ) round def round ( bboxes : numpy . ndarray , box_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Rounds the bounding box coordinates to integers. Parameters: Name Type Description Default bboxes np.ndarray The bounding box coordinates to convert. None box_format BoxFormat The format of the input bounding boxes. None Returns: Type Description np.ndarray The bounding box coordinates as integers. View Source @ staticmethod def round ( bboxes : np . ndarray , box_format : BoxFormat ) -> np . ndarray : \"\"\" Rounds the bounding box coordinates to integers. Args: bboxes (np.ndarray): The bounding box coordinates to convert. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The bounding box coordinates as integers. \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) x1 = np . floor ( x1 ). astype ( np . int32 , copy = False ) y1 = np . floor ( y1 ). astype ( np . int32 , copy = False ) x2 = np . ceil ( x2 ). astype ( np . int32 , copy = False ) y2 = np . ceil ( y2 ). astype ( np . int32 , copy = False ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) return BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) unpack def unpack ( bbox : numpy . ndarray ) -> tuple [ numpy . ndarray , numpy . ndarray , numpy . ndarray , numpy . ndarray ] Unpack the given bounding box into its individual components. Parameters: Name Type Description Default bbox np.ndarray The bounding box to unpack. None Returns: Type Description tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] The unpacked components of the bounding box. View Source @staticmethod def unpack ( bbox : np . ndarray ) -> tuple [ np.ndarray, np.ndarray, np.ndarray, np.ndarray ] : \"\"\" Unpack the given bounding box into its individual components. Args: bbox (np.ndarray): The bounding box to unpack. Returns: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: The unpacked components of the bounding box. \"\"\" c1 , c2 , c3 , c4 = np . split ( bbox , bbox . shape [ -1 ] , axis =- 1 ) c1 = np . squeeze ( c1 , axis =- 1 ) c2 = np . squeeze ( c2 , axis =- 1 ) c3 = np . squeeze ( c3 , axis =- 1 ) c4 = np . squeeze ( c4 , axis =- 1 ) return c1 , c2 , c3 , c4","title":"Bbox Utils"},{"location":"reference/wtracker/utils/bbox_utils/#module-wtrackerutilsbbox_utils","text":"View Source import numpy as np from enum import Enum class BoxFormat ( Enum ): \"\"\" Enumeration representing different box formats. Attributes: XYWH (int): Represents the box format as (x, y, width, height). XYXY (int): Represents the box format as (x1, y1, x2, y2). YOLO (int): Represents the box format as (center_x, center_y, width, height). \"\"\" XYWH = 0 XYXY = 1 YOLO = 2 class BoxUtils : \"\"\" A utility class for working with bounding boxes. 
\"\"\" @staticmethod def is_bbox ( array : np . ndarray ) -> bool : \"\"\" Check if the given array is a valid bounding box. Args: array (np.ndarray): The array to check. Returns: bool: True if the array is a valid bounding box, False otherwise. \"\"\" return array . shape [ - 1 ] == 4 @staticmethod def unpack ( bbox : np . ndarray ) -> tuple [ np . ndarray , np . ndarray , np . ndarray , np . ndarray ]: \"\"\" Unpack the given bounding box into its individual components. Args: bbox (np.ndarray): The bounding box to unpack. Returns: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: The unpacked components of the bounding box. \"\"\" c1 , c2 , c3 , c4 = np . split ( bbox , bbox . shape [ - 1 ], axis =- 1 ) c1 = np . squeeze ( c1 , axis =- 1 ) c2 = np . squeeze ( c2 , axis =- 1 ) c3 = np . squeeze ( c3 , axis =- 1 ) c4 = np . squeeze ( c4 , axis =- 1 ) return c1 , c2 , c3 , c4 @staticmethod def pack ( c1 : np . ndarray , c2 : np . ndarray , c3 : np . ndarray , c4 : np . ndarray ) -> np . ndarray : \"\"\" Pack the given components into a single bounding box. Args: c1 (np.ndarray): The first component of the bounding box. c2 (np.ndarray): The second component of the bounding box. c3 (np.ndarray): The third component of the bounding box. c4 (np.ndarray): The fourth component of the bounding box. Returns: np.ndarray: The packed bounding box. \"\"\" c1 = np . expand_dims ( c1 , axis =- 1 ) c2 = np . expand_dims ( c2 , axis =- 1 ) c3 = np . expand_dims ( c3 , axis =- 1 ) c4 = np . expand_dims ( c4 , axis =- 1 ) return np . concatenate (( c1 , c2 , c3 , c4 ), axis =- 1 ) @staticmethod def center ( bboxes : np . ndarray , box_format : BoxFormat = BoxFormat . XYWH ) -> np . ndarray : \"\"\" Calculate the center of the bounding boxes. Args: bboxes (np.ndarray): The input bounding boxes. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The center of the bounding boxes, in the format (center_x, center_y). \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYWH ) x , y , w , h = BoxUtils . unpack ( bboxes ) center_x = x + w / 2 center_y = y + h / 2 return np . array ([ center_x , center_y ]) . T @staticmethod def round ( bboxes : np . ndarray , box_format : BoxFormat ) -> np . ndarray : \"\"\" Rounds the bounding box coordinates to integers. Args: bboxes (np.ndarray): The bounding box coordinates to convert. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The bounding box coordinates as integers. \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) x1 = np . floor ( x1 ) . astype ( np . int32 , copy = False ) y1 = np . floor ( y1 ) . astype ( np . int32 , copy = False ) x2 = np . ceil ( x2 ) . astype ( np . int32 , copy = False ) y2 = np . ceil ( y2 ) . astype ( np . int32 , copy = False ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) return BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) @staticmethod def discretize ( bboxes : np . ndarray , bounds : tuple [ int , int ], box_format : BoxFormat , ) -> tuple [ np . ndarray , np . ndarray ]: \"\"\" Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Args: bboxes (np.ndarray): The bounding box coordinates to convert. 
bounds (tuple[int, int]): The bounds to clamp the bounding boxes to, in the format (h, w). box_format (BoxFormat): The format of the input bounding boxes. Returns: tuple[np.ndarray, np.ndarray]: The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. The second element is a boolean mask indicating which input bounding boxes are legal. \"\"\" # zero out all non-finite bounding boxes is_legal = np . isfinite ( bboxes ) . all ( axis = 1 ) bboxes [ ~ is_legal ] = 0 bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) bboxes = BoxUtils . round ( bboxes , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) # clip worm bounding boxes to the size H , W = bounds x1 = np . clip ( x1 , a_min = 0 , a_max = W ) y1 = np . clip ( y1 , a_min = 0 , a_max = H ) x2 = np . clip ( x2 , a_min = 0 , a_max = W ) y2 = np . clip ( y2 , a_min = 0 , a_max = H ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) bboxes = BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) # zero out all bounding boxes with 0 dimension w = x2 - x1 h = y2 - y1 is_legal = ( w > 0.0 ) & ( h > 0.0 ) # zero out all illegal bounding boxes and make sure return types are correct bboxes [ ~ is_legal ] = 0 bboxes = bboxes . astype ( np . int32 , copy = False ) is_legal = is_legal . astype ( bool , copy = False ) return bboxes , is_legal class BoxConverter : \"\"\" Utility class for converting bounding box coordinates between different formats. \"\"\" @staticmethod def change_format ( bbox : np . ndarray , src_format : BoxFormat , dst_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates from one format to another. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. dst_format (BoxFormat): The destination format of the bounding box coordinates. Returns: np.ndarray: The converted bounding box coordinates. Raises: Exception: If the conversion between the specified formats is not supported. \"\"\" if dst_format == BoxFormat . XYXY : return BoxConverter . to_xyxy ( bbox , src_format ) elif dst_format == BoxFormat . XYWH : return BoxConverter . to_xywh ( bbox , src_format ) elif dst_format == BoxFormat . YOLO : return BoxConverter . to_yolo ( bbox , src_format ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xyxy ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYXY format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYXY format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYXY : return bbox elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xywh ( bbox : np . ndarray , src_format : BoxFormat ) -> np .
ndarray : \"\"\" Converts the bounding box coordinates to the XYWH format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYWH format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYWH : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 return BoxUtils . pack ( x1 , y1 , w , h ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 return BoxUtils . pack ( x1 , y1 , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_yolo ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the YOLO format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the YOLO format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . YOLO : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" )","title":"Module wtracker.utils.bbox_utils"},{"location":"reference/wtracker/utils/bbox_utils/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/bbox_utils/#boxconverter","text":"class BoxConverter ( / , * args , ** kwargs ) Utility class for converting bounding box coordinates between different formats. View Source class BoxConverter : \"\"\" Utility class for converting bounding box coordinates between different formats. \"\"\" @staticmethod def change_format ( bbox : np . ndarray , src_format : BoxFormat , dst_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates from one format to another. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. dst_format (BoxFormat): The destination format of the bounding box coordinates. Returns: np.ndarray: The converted bounding box coordinates. Raises: Exception: If the conversion between the specified formats is not supported. \"\"\" if dst_format == BoxFormat . XYXY : return BoxConverter . to_xyxy ( bbox , src_format ) elif dst_format == BoxFormat . XYWH : return BoxConverter . to_xywh ( bbox , src_format ) elif dst_format == BoxFormat . YOLO : return BoxConverter . to_xywh ( bbox , src_format ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xyxy ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYXY format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYXY format. 
Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYXY : return bbox elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_xywh ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYWH format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYWH format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYWH : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 return BoxUtils . pack ( x1 , y1 , w , h ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 return BoxUtils . pack ( x1 , y1 , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" ) @staticmethod def to_yolo ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the YOLO format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the YOLO format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . YOLO : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" )","title":"BoxConverter"},{"location":"reference/wtracker/utils/bbox_utils/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/utils/bbox_utils/#change_format","text":"def change_format ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat , dst_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates from one format to another. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. None src_format BoxFormat The source format of the bounding box coordinates. None dst_format BoxFormat The destination format of the bounding box coordinates. None Returns: Type Description np.ndarray The converted bounding box coordinates. Raises: Type Description Exception If the conversion between the specified formats is not supported. View Source @staticmethod def change_format ( bbox : np . ndarray , src_format : BoxFormat , dst_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates from one format to another. Args: bbox (np.ndarray): The bounding box coordinates to be converted. 
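As a quick illustration of change_format (the coordinate values here are made up for the example):
import numpy as np
from wtracker.utils.bbox_utils import BoxConverter, BoxFormat
box_xywh = np.array([10.0, 20.0, 30.0, 40.0])  # (x, y, w, h)
box_xyxy = BoxConverter.change_format(box_xywh, BoxFormat.XYWH, BoxFormat.XYXY)
# box_xyxy -> [10., 20., 40., 60.]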
src_format (BoxFormat): The source format of the bounding box coordinates. dst_format (BoxFormat): The destination format of the bounding box coordinates. Returns: np.ndarray: The converted bounding box coordinates. Raises: Exception: If the conversion between the specified formats is not supported. \"\"\" if dst_format == BoxFormat . XYXY : return BoxConverter . to_xyxy ( bbox , src_format ) elif dst_format == BoxFormat . XYWH : return BoxConverter . to_xywh ( bbox , src_format ) elif dst_format == BoxFormat . YOLO : return BoxConverter . to_yolo ( bbox , src_format ) else : raise Exception ( \"unsupported bbox format conversion.\" )","title":"change_format"},{"location":"reference/wtracker/utils/bbox_utils/#to_xywh","text":"def to_xywh ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates to the XYWH format. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. None src_format BoxFormat The source format of the bounding box coordinates. None Returns: Type Description np.ndarray The bounding box coordinates in the XYWH format. Raises: Type Description Exception If the conversion from the specified source format is not supported. View Source @staticmethod def to_xywh ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYWH format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYWH format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYWH : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 return BoxUtils . pack ( x1 , y1 , w , h ) elif src_format == BoxFormat . YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 return BoxUtils . pack ( x1 , y1 , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" )","title":"to_xywh"},{"location":"reference/wtracker/utils/bbox_utils/#to_xyxy","text":"def to_xyxy ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates to the XYXY format. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. None src_format BoxFormat The source format of the bounding box coordinates. None Returns: Type Description np.ndarray The bounding box coordinates in the XYXY format. Raises: Type Description Exception If the conversion from the specified source format is not supported. View Source @staticmethod def to_xyxy ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the XYXY format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the XYXY format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . XYXY : return bbox elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) elif src_format == BoxFormat .
YOLO : xm , ym , w , h = BoxUtils . unpack ( bbox ) x1 = xm - w / 2 y1 = ym - h / 2 x2 = x1 + w y2 = y1 + h return BoxUtils . pack ( x1 , y1 , x2 , y2 ) else : raise Exception ( \"unsupported bbox format conversion.\" )","title":"to_xyxy"},{"location":"reference/wtracker/utils/bbox_utils/#to_yolo","text":"def to_yolo ( bbox : numpy . ndarray , src_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Converts the bounding box coordinates to the YOLO format. Parameters: Name Type Description Default bbox np.ndarray The bounding box coordinates to be converted. None src_format BoxFormat The source format of the bounding box coordinates. None Returns: Type Description np.ndarray The bounding box coordinates in the YOLO format. Raises: Type Description Exception If the conversion from the specified source format is not supported. View Source @staticmethod def to_yolo ( bbox : np . ndarray , src_format : BoxFormat ) -> np . ndarray : \"\"\" Converts the bounding box coordinates to the YOLO format. Args: bbox (np.ndarray): The bounding box coordinates to be converted. src_format (BoxFormat): The source format of the bounding box coordinates. Returns: np.ndarray: The bounding box coordinates in the YOLO format. Raises: Exception: If the conversion from the specified source format is not supported. \"\"\" if src_format == BoxFormat . YOLO : return bbox elif src_format == BoxFormat . XYXY : x1 , y1 , x2 , y2 = BoxUtils . unpack ( bbox ) w = x2 - x1 h = y2 - y1 xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) elif src_format == BoxFormat . XYWH : x1 , y1 , w , h = BoxUtils . unpack ( bbox ) xm = x1 + w / 2 ym = y1 + h / 2 return BoxUtils . pack ( xm , ym , w , h ) else : raise Exception ( \"unsupported bbox format conversion.\" )","title":"to_yolo"},{"location":"reference/wtracker/utils/bbox_utils/#boxformat","text":"class BoxFormat ( / , * args , ** kwargs ) Enumeration representing different box formats.","title":"BoxFormat"},{"location":"reference/wtracker/utils/bbox_utils/#attributes","text":"Name Type Description Default XYWH int Represents the box format as (x, y, width, height). None XYXY int Represents the box format as (x1, y1, x2, y2). None YOLO int Represents the box format as (center_x, center_y, width, height). None View Source class BoxFormat ( Enum ): \"\"\" Enumeration representing different box formats. Attributes: XYWH (int): Represents the box format as (x, y, width, height). XYXY (int): Represents the box format as (x1, y1, x2, y2). YOLO (int): Represents the box format as (center_x, center_y, width, height). \"\"\" XYWH = 0 XYXY = 1 YOLO = 2","title":"Attributes"},{"location":"reference/wtracker/utils/bbox_utils/#ancestors-in-mro","text":"enum.Enum","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/utils/bbox_utils/#class-variables","text":"XYWH XYXY YOLO name value","title":"Class variables"},{"location":"reference/wtracker/utils/bbox_utils/#boxutils","text":"class BoxUtils ( / , * args , ** kwargs ) A utility class for working with bounding boxes. View Source class BoxUtils : \"\"\" A utility class for working with bounding boxes. \"\"\" @ staticmethod def is_bbox ( array : np . ndarray ) -> bool : \"\"\" Check if the given array is a valid bounding box. Args: array (np.ndarray): The array to check. Returns: bool: True if the array is a valid bounding box, False otherwise. \"\"\" return array . shape [ - 1 ] == 4 @ staticmethod def unpack ( bbox : np . ndarray ) -> tuple [ np . ndarray , np . ndarray , np . ndarray , np . 
ndarray ]: \"\"\" Unpack the given bounding box into its individual components. Args: bbox (np.ndarray): The bounding box to unpack. Returns: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: The unpacked components of the bounding box. \"\"\" c1 , c2 , c3 , c4 = np . split ( bbox , bbox . shape [ - 1 ], axis = - 1 ) c1 = np . squeeze ( c1 , axis = - 1 ) c2 = np . squeeze ( c2 , axis = - 1 ) c3 = np . squeeze ( c3 , axis = - 1 ) c4 = np . squeeze ( c4 , axis = - 1 ) return c1 , c2 , c3 , c4 @ staticmethod def pack ( c1 : np . ndarray , c2 : np . ndarray , c3 : np . ndarray , c4 : np . ndarray ) -> np . ndarray : \"\"\" Pack the given components into a single bounding box. Args: c1 (np.ndarray): The first component of the bounding box. c2 (np.ndarray): The second component of the bounding box. c3 (np.ndarray): The third component of the bounding box. c4 (np.ndarray): The fourth component of the bounding box. Returns: np.ndarray: The packed bounding box. \"\"\" c1 = np . expand_dims ( c1 , axis = - 1 ) c2 = np . expand_dims ( c2 , axis = - 1 ) c3 = np . expand_dims ( c3 , axis = - 1 ) c4 = np . expand_dims ( c4 , axis = - 1 ) return np . concatenate (( c1 , c2 , c3 , c4 ), axis = - 1 ) @ staticmethod def center ( bboxes : np . ndarray , box_format : BoxFormat = BoxFormat . XYWH ) -> np . ndarray : \"\"\" Calculate the center of the bounding boxes. Args: bboxes (np.ndarray): The input bounding boxes. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The center of the bounding boxes, in the format (center_x, center_y). \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYWH ) x , y , w , h = BoxUtils . unpack ( bboxes ) center_x = x + w / 2 center_y = y + h / 2 return np . array ([ center_x , center_y ]). T @ staticmethod def round ( bboxes : np . ndarray , box_format : BoxFormat ) -> np . ndarray : \"\"\" Rounds the bounding box coordinates to integers. Args: bboxes (np.ndarray): The bounding box coordinates to convert. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The bounding box coordinates as integers. \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) x1 = np . floor ( x1 ). astype ( np . int32 , copy = False ) y1 = np . floor ( y1 ). astype ( np . int32 , copy = False ) x2 = np . ceil ( x2 ). astype ( np . int32 , copy = False ) y2 = np . ceil ( y2 ). astype ( np . int32 , copy = False ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) return BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) @ staticmethod def discretize ( bboxes : np . ndarray , bounds : tuple [ int , int ], box_format : BoxFormat , ) -> tuple [ np . ndarray , np . ndarray ]: \"\"\" Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Args: bboxes (np.ndarray): The bounding box coordinates to convert. bounds (tuple[int, int]): The bounds to clamp the bounding boxes to, in the format (h, w). box_format (BoxFormat): The format of the input bounding boxes. Returns: tuple[np.ndarray, np.ndarray]: The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. 
The second element is a boolean mask indicating which input bounding boxes are legal. \"\"\" # zero out all non - finite bounding boxes is_legal = np . isfinite ( bboxes ). all ( axis = 1 ) bboxes [ ~ is_legal ] = 0 bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) bboxes = BoxUtils . round ( bboxes , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) # clip worm bounding boxes to the size H , W = bounds x1 = np . clip ( x1 , a_min = 0 , a_max = W ) y1 = np . clip ( y1 , a_min = 0 , a_max = H ) x2 = np . clip ( x2 , a_min = 0 , a_max = W ) y2 = np . clip ( y2 , a_min = 0 , a_max = H ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) bboxes = BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) # zero out all bounding boxes with 0 dimension w = x2 - x1 h = y2 - y1 is_legal = ( w > 0.0 ) & ( h > 0.0 ) # zero out all illegal bounding boxes and make sure return types are correct bboxes [ ~ is_legal ] = 0 bboxes = bboxes . astype ( np . int32 , copy = False ) is_legal = is_legal . astype ( bool , copy = False ) return bboxes , is_legal","title":"BoxUtils"},{"location":"reference/wtracker/utils/bbox_utils/#static-methods_1","text":"","title":"Static methods"},{"location":"reference/wtracker/utils/bbox_utils/#center","text":"def center ( bboxes : numpy . ndarray , box_format : wtracker . utils . bbox_utils . BoxFormat = < BoxFormat . XYWH : 0 > ) -> numpy . ndarray Calculate the center of the bounding boxes. Parameters: Name Type Description Default bboxes np.ndarray The input bounding boxes. None box_format BoxFormat The format of the input bounding boxes. None Returns: Type Description np.ndarray The center of the bounding boxes, in the format (center_x, center_y). View Source @staticmethod def center ( bboxes : np . ndarray , box_format : BoxFormat = BoxFormat . XYWH ) -> np . ndarray : \"\"\" Calculate the center of the bounding boxes. Args: bboxes (np.ndarray): The input bounding boxes. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The center of the bounding boxes, in the format (center_x, center_y). \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYWH ) x , y , w , h = BoxUtils . unpack ( bboxes ) center_x = x + w / 2 center_y = y + h / 2 return np . array ( [ center_x, center_y ] ). T","title":"center"},{"location":"reference/wtracker/utils/bbox_utils/#discretize","text":"def discretize ( bboxes : numpy . ndarray , bounds : tuple [ int , int ], box_format : wtracker . utils . bbox_utils . BoxFormat ) -> tuple [ numpy . ndarray , numpy . ndarray ] Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Parameters: Name Type Description Default bboxes np.ndarray The bounding box coordinates to convert. None bounds tuple[int, int] The bounds to clamp the bounding boxes to, in the format (h, w). None box_format BoxFormat The format of the input bounding boxes. None Returns: Type Description tuple[np.ndarray, np.ndarray] The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. The second element is a boolean mask indicating which input bounding boxes are legal. View Source @ staticmethod def discretize ( bboxes : np . 
ndarray , bounds : tuple [ int , int ], box_format : BoxFormat , ) -> tuple [ np . ndarray , np . ndarray ]: \"\"\" Converts bounding boxes into integer format and clamps them to the specified bounds. All illegal bounding boxes are zeroed out. This function is especially useful for discretizing the bboxes for image slicing at bbox coordinates. Args: bboxes (np.ndarray): The bounding box coordinates to convert. bounds (tuple[int, int]): The bounds to clamp the bounding boxes to, in the format (h, w). box_format (BoxFormat): The format of the input bounding boxes. Returns: tuple[np.ndarray, np.ndarray]: The discretized bounding boxes and a boolean mask indicating which bounding boxes are legal. The first element are bounding boxes discretized to 'np.int32' format. All illegal bounding boxes are zeroed out. The second element is a boolean mask indicating which input bounding boxes are legal. \"\"\" # zero out all non - finite bounding boxes is_legal = np . isfinite ( bboxes ). all ( axis = 1 ) bboxes [ ~ is_legal ] = 0 bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) bboxes = BoxUtils . round ( bboxes , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) # clip worm bounding boxes to the size H , W = bounds x1 = np . clip ( x1 , a_min = 0 , a_max = W ) y1 = np . clip ( y1 , a_min = 0 , a_max = H ) x2 = np . clip ( x2 , a_min = 0 , a_max = W ) y2 = np . clip ( y2 , a_min = 0 , a_max = H ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) bboxes = BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format ) # zero out all bounding boxes with 0 dimension w = x2 - x1 h = y2 - y1 is_legal = ( w > 0.0 ) & ( h > 0.0 ) # zero out all illegal bounding boxes and make sure return types are correct bboxes [ ~ is_legal ] = 0 bboxes = bboxes . astype ( np . int32 , copy = False ) is_legal = is_legal . astype ( bool , copy = False ) return bboxes , is_legal","title":"discretize"},{"location":"reference/wtracker/utils/bbox_utils/#is_bbox","text":"def is_bbox ( array : numpy . ndarray ) -> bool Check if the given array is a valid bounding box. Parameters: Name Type Description Default array np.ndarray The array to check. None Returns: Type Description bool True if the array is a valid bounding box, False otherwise. View Source @staticmethod def is_bbox ( array : np . ndarray ) -> bool : \"\"\" Check if the given array is a valid bounding box. Args: array (np.ndarray): The array to check. Returns: bool: True if the array is a valid bounding box, False otherwise. \"\"\" return array . shape [ -1 ] == 4","title":"is_bbox"},{"location":"reference/wtracker/utils/bbox_utils/#pack","text":"def pack ( c1 : numpy . ndarray , c2 : numpy . ndarray , c3 : numpy . ndarray , c4 : numpy . ndarray ) -> numpy . ndarray Pack the given components into a single bounding box. Parameters: Name Type Description Default c1 np.ndarray The first component of the bounding box. None c2 np.ndarray The second component of the bounding box. None c3 np.ndarray The third component of the bounding box. None c4 np.ndarray The fourth component of the bounding box. None Returns: Type Description np.ndarray The packed bounding box. View Source @staticmethod def pack ( c1 : np . ndarray , c2 : np . ndarray , c3 : np . ndarray , c4 : np . ndarray ) -> np . ndarray : \"\"\" Pack the given components into a single bounding box. Args: c1 (np.ndarray): The first component of the bounding box. c2 (np.ndarray): The second component of the bounding box. 
c3 (np.ndarray): The third component of the bounding box. c4 (np.ndarray): The fourth component of the bounding box. Returns: np.ndarray: The packed bounding box. \"\"\" c1 = np . expand_dims ( c1 , axis =- 1 ) c2 = np . expand_dims ( c2 , axis =- 1 ) c3 = np . expand_dims ( c3 , axis =- 1 ) c4 = np . expand_dims ( c4 , axis =- 1 ) return np . concatenate (( c1 , c2 , c3 , c4 ), axis =- 1 )","title":"pack"},{"location":"reference/wtracker/utils/bbox_utils/#round","text":"def round ( bboxes : numpy . ndarray , box_format : wtracker . utils . bbox_utils . BoxFormat ) -> numpy . ndarray Rounds the bounding box coordinates to integers. Parameters: Name Type Description Default bboxes np.ndarray The bounding box coordinates to convert. None box_format BoxFormat The format of the input bounding boxes. None Returns: Type Description np.ndarray The bounding box coordinates as integers. View Source @ staticmethod def round ( bboxes : np . ndarray , box_format : BoxFormat ) -> np . ndarray : \"\"\" Rounds the bounding box coordinates to integers. Args: bboxes (np.ndarray): The bounding box coordinates to convert. box_format (BoxFormat): The format of the input bounding boxes. Returns: np.ndarray: The bounding box coordinates as integers. \"\"\" bboxes = BoxConverter . change_format ( bboxes , box_format , BoxFormat . XYXY ) x1 , y1 , x2 , y2 = BoxUtils . unpack ( bboxes ) x1 = np . floor ( x1 ). astype ( np . int32 , copy = False ) y1 = np . floor ( y1 ). astype ( np . int32 , copy = False ) x2 = np . ceil ( x2 ). astype ( np . int32 , copy = False ) y2 = np . ceil ( y2 ). astype ( np . int32 , copy = False ) bboxes = BoxUtils . pack ( x1 , y1 , x2 , y2 ) return BoxConverter . change_format ( bboxes , BoxFormat . XYXY , box_format )","title":"round"},{"location":"reference/wtracker/utils/bbox_utils/#unpack","text":"def unpack ( bbox : numpy . ndarray ) -> tuple [ numpy . ndarray , numpy . ndarray , numpy . ndarray , numpy . ndarray ] Unpack the given bounding box into its individual components. Parameters: Name Type Description Default bbox np.ndarray The bounding box to unpack. None Returns: Type Description tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] The unpacked components of the bounding box. View Source @staticmethod def unpack ( bbox : np . ndarray ) -> tuple [ np.ndarray, np.ndarray, np.ndarray, np.ndarray ] : \"\"\" Unpack the given bounding box into its individual components. Args: bbox (np.ndarray): The bounding box to unpack. Returns: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: The unpacked components of the bounding box. \"\"\" c1 , c2 , c3 , c4 = np . split ( bbox , bbox . shape [ -1 ] , axis =- 1 ) c1 = np . squeeze ( c1 , axis =- 1 ) c2 = np . squeeze ( c2 , axis =- 1 ) c3 = np . squeeze ( c3 , axis =- 1 ) c4 = np . squeeze ( c4 , axis =- 1 ) return c1 , c2 , c3 , c4","title":"unpack"},{"location":"reference/wtracker/utils/config_base/","text":"Module wtracker.utils.config_base View Source from __future__ import annotations from typing import Type , TypeVar from dataclasses import dataclass , fields , MISSING , is_dataclass import json from wtracker.utils.gui_utils import UserPrompt from wtracker.utils.io_utils import pickle_load_object , pickle_save_object T = TypeVar ( \"T\" , bound = \"ConfigBase\" ) @dataclass class ConfigBase : @classmethod def load_json ( cls : type [ T ], path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. 
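A short sketch of the intended JSON round trip; the ExampleConfig class and its fields are hypothetical, defined here only for illustration:
from dataclasses import dataclass
from wtracker.utils.config_base import ConfigBase
@dataclass
class ExampleConfig(ConfigBase):  # hypothetical subclass, not part of the library
    frame_rate: int = 60
    output_dir: str = \"out\"
cfg = ExampleConfig(frame_rate=30)
cfg.save_json(\"example_config.json\")  # an explicit path skips the file-picker dialog
restored = ExampleConfig.load_json(\"example_config.json\")
assert restored.frame_rate == 30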
\"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open { cls . __name__ } File\" , file_types = [( \"json\" , \".json\" )], ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save { type ( self ) . __name__ } As\" , file_types = [( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) @classmethod def load_pickle ( cls : type [ T ], path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open { cls . __name__ } File\" , file_types = [( \"pickle\" , \".pkl\" )], ) return pickle_load_object ( path ) def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save { type ( self ) . __name__ } As\" , file_types = [( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) def print_initialization ( cls , include_default : bool = True , init_fields_only : bool = True ) -> str : \"\"\" Print the initialization of a dataclass as a string \"\"\" if not is_dataclass ( cls ): print ( f \"ERROR:: { cls . __name__ } is not a dataclass\" ) return \"\" print ( f \" { cls . __name__ } (\" ) for field in fields ( cls ): if init_fields_only and field . init is False : continue is_default = not isinstance ( field . default , type ( MISSING )) val = None if include_default and is_default : val = field . default if type ( val ) is str : val = f 'f\" { val } \"' print ( f \" { field . name } = { val } , # { field . type } \" ) print ( \")\" ) Variables T Functions print_initialization def print_initialization ( cls , include_default : 'bool' = True , init_fields_only : 'bool' = True ) -> 'str' Print the initialization of a dataclass as a string View Source def print_initialization ( cls , include_default : bool = True , init_fields_only : bool = True ) -> str : \"\"\" Print the initialization of a dataclass as a string \"\"\" if not is_dataclass ( cls ): print ( f \"ERROR::{cls.__name__} is not a dataclass\" ) return \"\" print ( f \"{cls.__name__}(\" ) for field in fields ( cls ): if init_fields_only and field . init is False : continue is_default = not isinstance ( field . default , type ( MISSING )) val = None if include_default and is_default : val = field . default if type ( val ) is str : val = f ' f \"{val}\" ' print ( f \" {field.name} = {val}, # {field.type}\" ) print ( \")\" ) Classes ConfigBase class ConfigBase ( ) ConfigBase() View Source @dataclass class ConfigBase : @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . 
update ( data ) return obj def save_json ( self , path : str = None ) : \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[ (\"json\", \".json\") ] , defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[ (\"pickle\", \".pkl\") ] , defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) Descendants wtracker.sim.config.TimingConfig wtracker.sim.config.ExperimentConfig wtracker.neural.config.DatasetConfig wtracker.neural.config.TrainConfig wtracker.neural.config.IOConfig wtracker.sim.sim_controllers.logging_controller.LogConfig wtracker.sim.sim_controllers.polyfit_controller.PolyfitConfig wtracker.sim.sim_controllers.yolo_controller.YoloConfig Static methods load_json def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj load_pickle def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) Methods save_json def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . 
__dict__ , f , indent = 4 ) save_pickle def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"Config Base"},{"location":"reference/wtracker/utils/config_base/#module-wtrackerutilsconfig_base","text":"View Source from __future__ import annotations from typing import Type , TypeVar from dataclasses import dataclass , fields , MISSING , is_dataclass import json from wtracker.utils.gui_utils import UserPrompt from wtracker.utils.io_utils import pickle_load_object , pickle_save_object T = TypeVar ( \"T\" , bound = \"ConfigBase\" ) @dataclass class ConfigBase : @classmethod def load_json ( cls : type [ T ], path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open { cls . __name__ } File\" , file_types = [( \"json\" , \".json\" )], ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save { type ( self ) . __name__ } As\" , file_types = [( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) @classmethod def load_pickle ( cls : type [ T ], path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open { cls . __name__ } File\" , file_types = [( \"pickle\" , \".pkl\" )], ) return pickle_load_object ( path ) def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save { type ( self ) . __name__ } As\" , file_types = [( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path ) def print_initialization ( cls , include_default : bool = True , init_fields_only : bool = True ) -> str : \"\"\" Print the initialization of a dataclass as a string \"\"\" if not is_dataclass ( cls ): print ( f \"ERROR:: { cls . __name__ } is not a dataclass\" ) return \"\" print ( f \" { cls . __name__ } (\" ) for field in fields ( cls ): if init_fields_only and field . init is False : continue is_default = not isinstance ( field . default , type ( MISSING )) val = None if include_default and is_default : val = field . default if type ( val ) is str : val = f 'f\" { val } \"' print ( f \" { field . name } = { val } , # { field . 
type } \" ) print ( \")\" )","title":"Module wtracker.utils.config_base"},{"location":"reference/wtracker/utils/config_base/#variables","text":"T","title":"Variables"},{"location":"reference/wtracker/utils/config_base/#functions","text":"","title":"Functions"},{"location":"reference/wtracker/utils/config_base/#print_initialization","text":"def print_initialization ( cls , include_default : 'bool' = True , init_fields_only : 'bool' = True ) -> 'str' Print the initialization of a dataclass as a string View Source def print_initialization ( cls , include_default : bool = True , init_fields_only : bool = True ) -> str : \"\"\" Print the initialization of a dataclass as a string \"\"\" if not is_dataclass ( cls ): print ( f \"ERROR::{cls.__name__} is not a dataclass\" ) return \"\" print ( f \"{cls.__name__}(\" ) for field in fields ( cls ): if init_fields_only and field . init is False : continue is_default = not isinstance ( field . default , type ( MISSING )) val = None if include_default and is_default : val = field . default if type ( val ) is str : val = f ' f \"{val}\" ' print ( f \" {field.name} = {val}, # {field.type}\" ) print ( \")\" )","title":"print_initialization"},{"location":"reference/wtracker/utils/config_base/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/config_base/#configbase","text":"class ConfigBase ( ) ConfigBase() View Source @dataclass class ConfigBase : @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj def save_json ( self , path : str = None ) : \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[ (\"json\", \".json\") ] , defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 ) @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path ) def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . 
save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[ (\"pickle\", \".pkl\") ] , defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"ConfigBase"},{"location":"reference/wtracker/utils/config_base/#descendants","text":"wtracker.sim.config.TimingConfig wtracker.sim.config.ExperimentConfig wtracker.neural.config.DatasetConfig wtracker.neural.config.TrainConfig wtracker.neural.config.IOConfig wtracker.sim.sim_controllers.logging_controller.LogConfig wtracker.sim.sim_controllers.polyfit_controller.PolyfitConfig wtracker.sim.sim_controllers.yolo_controller.YoloConfig","title":"Descendants"},{"location":"reference/wtracker/utils/config_base/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/utils/config_base/#load_json","text":"def load_json ( path : 'str' = None ) -> 'T' Load the class from a JSON file. Parameters: Name Type Description Default path str The path to the JSON file. None Returns: Type Description ConfigBase The class loaded from the JSON file. View Source @classmethod def load_json ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a JSON file. Args: path (str): The path to the JSON file. Returns: ConfigBase: The class loaded from the JSON file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"json\", \".json\") ] , ) with open ( path , \"r\" ) as f : data = json . load ( f ) obj = cls . __new__ ( cls ) obj . __dict__ . update ( data ) return obj","title":"load_json"},{"location":"reference/wtracker/utils/config_base/#load_pickle","text":"def load_pickle ( path : 'str' = None ) -> 'T' Load the class from a pickle file. Parameters: Name Type Description Default path str The path to the pickle file. None Returns: Type Description None The class loaded from the pickle file. View Source @classmethod def load_pickle ( cls : type [ T ] , path : str = None ) -> T : \"\"\" Load the class from a pickle file. Args: path (str): The path to the pickle file. Returns: The class loaded from the pickle file. \"\"\" if path is None : path = UserPrompt . open_file ( title = f \"Open {cls.__name__} File\" , file_types =[ (\"pickle\", \".pkl\") ] , ) return pickle_load_object ( path )","title":"load_pickle"},{"location":"reference/wtracker/utils/config_base/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/utils/config_base/#save_json","text":"def save_json ( self , path : 'str' = None ) Saves the class as JSON file. Parameters: Name Type Description Default path str The path to the output JSON file. None View Source def save_json ( self , path : str = None ): \"\"\" Saves the class as JSON file. Args: path (str): The path to the output JSON file. \"\"\" if path is None : path = UserPrompt . save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"json\" , \".json\" )], defaultextension = \".json\" , ) with open ( path , \"w\" ) as f : json . dump ( self . __dict__ , f , indent = 4 )","title":"save_json"},{"location":"reference/wtracker/utils/config_base/#save_pickle","text":"def save_pickle ( self , path : 'str' = None ) -> 'None' Saves the class as a pickle file. Parameters: Name Type Description Default path str The path to the output pickle file. None View Source def save_pickle ( self , path : str = None ) -> None : \"\"\" Saves the class as a pickle file. Args: path (str): The path to the output pickle file. \"\"\" if path is None : path = UserPrompt . 
save_file ( title = f \"Save {type(self).__name__} As\" , file_types =[( \"pickle\" , \".pkl\" )], defaultextension = \".pkl\" , ) pickle_save_object ( self , path )","title":"save_pickle"},{"location":"reference/wtracker/utils/frame_reader/","text":"Module wtracker.utils.frame_reader View Source from __future__ import annotations import os import glob import numpy as np import cv2 as cv from wtracker.utils.path_utils import join_paths class FrameReader : \"\"\" An class for reading frames from a directory of frame files. Args: root_folder (str): The root folder path where the frame files are located. frame_files (list[str]): A list of frame file names. read_format (int, optional): The format in which the frames should be read. Attributes: root_folder (str): The root folder path where the frame files are located. frame_shape (tuple[int, ...]): The shape of the frame. frame_size (tuple[int, int]): The size of the frame. files (list[str]): The list of file paths. read_format (int): The read format of the frame reader. \"\"\" def __init__ ( self , root_folder : str , frame_files : list [ str ], read_format : int = cv . IMREAD_GRAYSCALE , ): assert os . path . exists ( root_folder ) assert len ( frame_files ) > 0 self . _root_folder = root_folder self . _files = frame_files self . _read_format = read_format self . _frame_shape = self . _extract_frame_shape () def _extract_frame_shape ( self ) -> tuple [ int , ... ]: path = join_paths ( self . root_folder , self . files [ 0 ]) frame = cv . imread ( path , self . _read_format ) return frame . shape @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os . path . isfile ( join_paths ( root_folder , f ))] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os . path . isfile ( join_paths ( root_folder , f ))] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @property def root_folder ( self ) -> str : \"\"\" Returns the root folder path. Returns: str: The root folder path. \"\"\" return self . _root_folder @property def frame_shape ( self ) -> tuple [ int , ... ]: \"\"\" Returns the shape of the frame. Returns: tuple[int, ...]: The shape of the frame, in format (h, w, ...). \"\"\" return self . _frame_shape @property def frame_size ( self ) -> tuple [ int , int ]: \"\"\" Returns the size of the frame. 
Returns: tuple[int, int]: The shape of the frame, in format (h, w). \"\"\" return self . _frame_shape [: 2 ] @property def files ( self ) -> list [ str ]: \"\"\" Returns the list of files associated with the FrameReader object. Returns: list[str]: The list of file paths. \"\"\" return self . _files @property def read_format ( self ) -> int : \"\"\" Returns the read format of the frame reader. Returns: int: The read format. \"\"\" return self . _read_format def __len__ ( self ) -> int : return len ( self . _files ) def __getitem__ ( self , idx : int ) -> np . ndarray : if idx < 0 or idx >= len ( self . _files ): raise IndexError ( \"index out of bounds\" ) path = join_paths ( self . root_folder , self . files [ idx ]) frame = cv . imread ( path , self . _read_format ) return frame . astype ( np . uint8 , copy = False ) def __iter__ ( self ): return FrameStream ( self ) def make_stream ( self ): \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. \"\"\" return FrameStream ( self ) class FrameStream : \"\"\" A class for streaming frames from a FrameReader object. This class serves as an iterator for the FrameReader object. Args: frame_reader (FrameReader): The frame reader object. \"\"\" def __init__ ( self , frame_reader : FrameReader ): self . _frame_reader = frame_reader self . _idx = - 1 self . frame = None @property def index ( self ) -> int : \"\"\" The index of the current frame. \"\"\" return self . _idx def __len__ ( self ): return len ( self . _frame_reader ) def __iter__ ( self ): return self def __next__ ( self ) -> np . ndarray : self . progress () if not self . can_read (): raise StopIteration () frame = self . read () return frame def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader ) def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . _idx = idx self . frame = None return self . can_read () def read ( self ) -> np . ndarray : \"\"\" Read and return the frame at the current index. Raises: IndexError: If the index is out of bounds. Returns: np.ndarray: The frame at the current index. \"\"\" if not self . can_read (): raise IndexError ( \"index out of bounds\" ) if self . frame is None : self . frame = self . _frame_reader [ self . _idx ] return self . frame def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . _idx + n ) def reset ( self ): \"\"\" Resets the frame reader to the beginning of the steam. \"\"\" self . seek ( - 1 ) class DummyReader ( FrameReader ): \"\"\" A dummy frame reader that generates empty frames of a specified resolution. Args: num_frames (int): The number of frames to generate. resolution (tuple[int, int]): The resolution of the frames, in format (h, w). colored (bool, optional): Whether the frames are colored or grayscale. \"\"\" def __init__ ( self , num_frames : int , resolution : tuple [ int , int ], colored : bool = True ): self . colored = colored self . _resolution = resolution shape = ( * resolution , 3 ) if colored else resolution self . _frame = np . full ( shape , fill_value = 255 , dtype = np . 
uint8 ) frames = [ str ( i ) for i in range ( num_frames )] super () . __init__ ( \".\" , frame_files = frames ) def __getitem__ ( self , idx : int ) -> np . ndarray : return self . _frame . copy () def _extract_frame_shape ( self ) -> tuple [ int , ... ]: if self . colored : return ( * self . _resolution , 3 ) return self . _resolution Classes DummyReader class DummyReader ( num_frames : 'int' , resolution : 'tuple[int, int]' , colored : 'bool' = True ) A dummy frame reader that generates empty frames of a specified resolution. Attributes Name Type Description Default num_frames int The number of frames to generate. None resolution tuple[int, int] The resolution of the frames, in format (h, w). None colored bool Whether the frames are colored or grayscale. None View Source class DummyReader ( FrameReader ): \"\"\" A dummy frame reader that generates empty frames of a specified resolution. Args: num_frames (int): The number of frames to generate. resolution (tuple[int, int]): The resolution of the frames, in format (h, w). colored (bool, optional): Whether the frames are colored or grayscale. \"\"\" def __init__ ( self , num_frames : int , resolution : tuple [ int , int ], colored : bool = True ): self . colored = colored self . _resolution = resolution shape = ( * resolution , 3 ) if colored else resolution self . _frame = np . full ( shape , fill_value = 255 , dtype = np . uint8 ) frames = [ str ( i ) for i in range ( num_frames )] super (). __init__ ( \".\" , frame_files = frames ) def __getitem__ ( self , idx : int ) -> np . ndarray : return self . _frame . copy () def _extract_frame_shape ( self ) -> tuple [ int , ... ]: if self . colored : return ( * self . _resolution , 3 ) return self . _resolution Ancestors (in MRO) wtracker.utils.frame_reader.FrameReader Static methods create_from_directory def create_from_directory ( root_folder : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a directory. Parameters: Name Type Description Default root_folder str The root folder containing the frame files. None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) create_from_template def create_from_template ( root_folder : 'str' , name_format : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a file name template. Parameters: Name Type Description Default root_folder str The root folder where the frame files are located. None name_format str The format of the frame file names. None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . 
IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) Instance variables files Returns the list of files associated with the FrameReader object. frame_shape Returns the shape of the frame. frame_size Returns the size of the frame. read_format Returns the read format of the frame reader. root_folder Returns the root folder path. Methods make_stream def make_stream ( self ) Creates and returns a FrameStream object using the current instance of FrameReader. Returns: Type Description FrameStream A FrameStream object. View Source def make_stream(self): \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. \"\"\" return FrameStream(self) FrameReader class FrameReader ( root_folder : 'str' , frame_files : 'list[str]' , read_format : 'int' = 0 ) A class for reading frames from a directory of frame files. Attributes Name Type Description Default root_folder str The root folder path where the frame files are located. None frame_files list[str] A list of frame file names. None read_format int The format in which the frames should be read. None root_folder str The root folder path where the frame files are located. None frame_shape tuple[int, ...] The shape of the frame. None frame_size tuple[int, int] The size of the frame. None files list[str] The list of file paths. None read_format int The read format of the frame reader. None View Source class FrameReader : \"\"\" A class for reading frames from a directory of frame files. Args: root_folder (str): The root folder path where the frame files are located. frame_files (list[str]): A list of frame file names. read_format (int, optional): The format in which the frames should be read. Attributes: root_folder (str): The root folder path where the frame files are located. frame_shape (tuple[int, ...]): The shape of the frame. frame_size (tuple[int, int]): The size of the frame. files (list[str]): The list of file paths. read_format (int): The read format of the frame reader. \"\"\" def __init__ ( self , root_folder : str , frame_files : list [ str ] , read_format : int = cv . IMREAD_GRAYSCALE , ) : assert os . path . exists ( root_folder ) assert len ( frame_files ) > 0 self . _root_folder = root_folder self . _files = frame_files self . _read_format = read_format self . _frame_shape = self . _extract_frame_shape () def _extract_frame_shape ( self ) -> tuple [ int, ... ] : path = join_paths ( self . root_folder , self . files [ 0 ] ) frame = cv . imread ( path , self . _read_format ) return frame . shape @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names.
read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @property def root_folder ( self ) -> str : \"\"\" Returns the root folder path. Returns: str: The root folder path. \"\"\" return self . _root_folder @property def frame_shape ( self ) -> tuple [ int, ... ] : \"\"\" Returns the shape of the frame. Returns: tuple[int, ...]: The shape of the frame, in format (h, w, ...). \"\"\" return self . _frame_shape @property def frame_size ( self ) -> tuple [ int, int ] : \"\"\" Returns the size of the frame. Returns: tuple[int, int]: The shape of the frame, in format (h, w). \"\"\" return self . _frame_shape [ :2 ] @property def files ( self ) -> list [ str ] : \"\"\" Returns the list of files associated with the FrameReader object. Returns: list[str]: The list of file paths. \"\"\" return self . _files @property def read_format ( self ) -> int : \"\"\" Returns the read format of the frame reader. Returns: int: The read format. \"\"\" return self . _read_format def __len__ ( self ) -> int : return len ( self . _files ) def __getitem__ ( self , idx : int ) -> np . ndarray : if idx < 0 or idx >= len ( self . _files ) : raise IndexError ( \"index out of bounds\" ) path = join_paths ( self . root_folder , self . files [ idx ] ) frame = cv . imread ( path , self . _read_format ) return frame . astype ( np . uint8 , copy = False ) def __iter__ ( self ) : return FrameStream ( self ) def make_stream ( self ) : \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. \"\"\" return FrameStream ( self ) Descendants wtracker.utils.frame_reader.DummyReader Static methods create_from_directory def create_from_directory ( root_folder : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a directory. Parameters: Name Type Description Default root_folder str The root folder containing the frame files. None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. 
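For orientation, here is a minimal usage sketch of the two factory methods documented above; the folder name "frames" and the template "frame_{}.png" are hypothetical placeholders for your own data layout, not paths shipped with the library:

from wtracker.utils.frame_reader import FrameReader

# Read every file in a folder as a grayscale frame sequence (the default read format).
reader = FrameReader.create_from_directory("frames")

# Or select only files matching a name template; "{}" is internally replaced
# with the numeric wildcard "[0-9]*" before globbing.
reader = FrameReader.create_from_template("frames", name_format="frame_{}.png")

print(len(reader))        # number of frames found
print(reader.frame_size)  # (h, w), taken from the first frame
first = reader[0]         # np.ndarray of dtype uint8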
\"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) create_from_template def create_from_template ( root_folder : 'str' , name_format : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a file name template. Parameters: Name Type Description Default root_folder str The root folder where the frame files are located. None name_format str The format of the frame file names. None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) Instance variables files Returns the list of files associated with the FrameReader object. frame_shape Returns the shape of the frame. frame_size Returns the size of the frame. read_format Returns the read format of the frame reader. root_folder Returns the root folder path. Methods make_stream def make_stream ( self ) Creates and returns a FrameStream object using the current instance of FrameReader. Returns: Type Description FrameStream A FrameStream object. View Source def make_stream(self): \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. \"\"\" return FrameStream(self) FrameStream class FrameStream ( frame_reader : 'FrameReader' ) A class for streaming frames from a FrameReader object. This class serves as an iterator for the FrameReader object. Attributes Name Type Description Default frame_reader FrameReader The frame reader object. None View Source class FrameStream : \"\"\" A class for streaming frames from a FrameReader object. This class serves as an iterator for the FrameReader object. Args: frame_reader (FrameReader): The frame reader object. \"\"\" def __init__ ( self , frame_reader : FrameReader ) : self . _frame_reader = frame_reader self . _idx = - 1 self . frame = None @property def index ( self ) -> int : \"\"\" The index of the current frame. \"\"\" return self . _idx def __len__ ( self ) : return len ( self . _frame_reader ) def __iter__ ( self ) : return self def __next__ ( self ) -> np . ndarray : self . progress () if not self . can_read () : raise StopIteration () frame = self . read () return frame def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader ) def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . 
_idx = idx self . frame = None return self . can_read () def read ( self ) -> np . ndarray : \"\"\" Read and return the frame at the current index. Raises: IndexError: If the index is out of bounds. Returns: np.ndarray: The frame at the current index. \"\"\" if not self . can_read () : raise IndexError ( \"index out of bounds\" ) if self . frame is None : self . frame = self . _frame_reader [ self._idx ] return self . frame def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . _idx + n ) def reset ( self ) : \"\"\" Resets the frame reader to the beginning of the stream. \"\"\" self . seek ( - 1 ) Descendants wtracker.sim.view_controller.ViewController Instance variables index The index of the current frame. Methods can_read def can_read ( self ) -> 'bool' View Source def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader ) progress def progress ( self , n : 'int' = 1 ) -> 'bool' Moves the current index forward by the specified number of steps. Parameters: Name Type Description Default n int The number of steps to move forward. None Returns: Type Description bool True if the index was successfully moved forward, False otherwise. View Source def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . _idx + n ) read def read ( self ) -> 'np.ndarray' Read and return the frame at the current index. Returns: Type Description np.ndarray The frame at the current index. Raises: Type Description IndexError If the index is out of bounds. View Source def read ( self ) -> np . ndarray : \"\"\" Read and return the frame at the current index. Raises: IndexError: If the index is out of bounds. Returns: np.ndarray: The frame at the current index. \"\"\" if not self . can_read () : raise IndexError ( \"index out of bounds\" ) if self . frame is None : self . frame = self . _frame_reader [ self . _idx ] return self . frame reset def reset ( self ) Resets the frame reader to the beginning of the stream. View Source def reset(self): \"\"\" Resets the frame reader to the beginning of the stream. \"\"\" self.seek(-1) seek def seek ( self , idx : 'int' ) -> 'bool' Move the index to the specified position. Parameters: Name Type Description Default idx int The index to seek to. None Returns: Type Description bool True if the index is within the valid range, False otherwise. View Source def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . _idx = idx self . frame = None return self . can_read ()","title":"Frame Reader"},{"location":"reference/wtracker/utils/frame_reader/#module-wtrackerutilsframe_reader","text":"View Source from __future__ import annotations import os import glob import numpy as np import cv2 as cv from wtracker.utils.path_utils import join_paths class FrameReader : \"\"\" A class for reading frames from a directory of frame files. Args: root_folder (str): The root folder path where the frame files are located.
frame_files (list[str]): A list of frame file names. read_format (int, optional): The format in which the frames should be read. Attributes: root_folder (str): The root folder path where the frame files are located. frame_shape (tuple[int, ...]): The shape of the frame. frame_size (tuple[int, int]): The size of the frame. files (list[str]): The list of file paths. read_format (int): The read format of the frame reader. \"\"\" def __init__ ( self , root_folder : str , frame_files : list [ str ], read_format : int = cv . IMREAD_GRAYSCALE , ): assert os . path . exists ( root_folder ) assert len ( frame_files ) > 0 self . _root_folder = root_folder self . _files = frame_files self . _read_format = read_format self . _frame_shape = self . _extract_frame_shape () def _extract_frame_shape ( self ) -> tuple [ int , ... ]: path = join_paths ( self . root_folder , self . files [ 0 ]) frame = cv . imread ( path , self . _read_format ) return frame . shape @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os . path . isfile ( join_paths ( root_folder , f ))] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os . path . isfile ( join_paths ( root_folder , f ))] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @property def root_folder ( self ) -> str : \"\"\" Returns the root folder path. Returns: str: The root folder path. \"\"\" return self . _root_folder @property def frame_shape ( self ) -> tuple [ int , ... ]: \"\"\" Returns the shape of the frame. Returns: tuple[int, ...]: The shape of the frame, in format (h, w, ...). \"\"\" return self . _frame_shape @property def frame_size ( self ) -> tuple [ int , int ]: \"\"\" Returns the size of the frame. Returns: tuple[int, int]: The shape of the frame, in format (h, w). \"\"\" return self . _frame_shape [: 2 ] @property def files ( self ) -> list [ str ]: \"\"\" Returns the list of files associated with the FrameReader object. Returns: list[str]: The list of file paths. \"\"\" return self . _files @property def read_format ( self ) -> int : \"\"\" Returns the read format of the frame reader. Returns: int: The read format. \"\"\" return self . _read_format def __len__ ( self ) -> int : return len ( self . _files ) def __getitem__ ( self , idx : int ) -> np . ndarray : if idx < 0 or idx >= len ( self . 
_files ): raise IndexError ( \"index out of bounds\" ) path = join_paths ( self . root_folder , self . files [ idx ]) frame = cv . imread ( path , self . _read_format ) return frame . astype ( np . uint8 , copy = False ) def __iter__ ( self ): return FrameStream ( self ) def make_stream ( self ): \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. \"\"\" return FrameStream ( self ) class FrameStream : \"\"\" A class for streaming frames from a FrameReader object. This class serves as an iterator for the FrameReader object. Args: frame_reader (FrameReader): The frame reader object. \"\"\" def __init__ ( self , frame_reader : FrameReader ): self . _frame_reader = frame_reader self . _idx = - 1 self . frame = None @property def index ( self ) -> int : \"\"\" The index of the current frame. \"\"\" return self . _idx def __len__ ( self ): return len ( self . _frame_reader ) def __iter__ ( self ): return self def __next__ ( self ) -> np . ndarray : self . progress () if not self . can_read (): raise StopIteration () frame = self . read () return frame def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader ) def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . _idx = idx self . frame = None return self . can_read () def read ( self ) -> np . ndarray : \"\"\" Read and return the frame at the current index. Raises: IndexError: If the index is out of bounds. Returns: np.ndarray: The frame at the current index. \"\"\" if not self . can_read (): raise IndexError ( \"index out of bounds\" ) if self . frame is None : self . frame = self . _frame_reader [ self . _idx ] return self . frame def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . _idx + n ) def reset ( self ): \"\"\" Resets the frame reader to the beginning of the stream. \"\"\" self . seek ( - 1 ) class DummyReader ( FrameReader ): \"\"\" A dummy frame reader that generates empty frames of a specified resolution. Args: num_frames (int): The number of frames to generate. resolution (tuple[int, int]): The resolution of the frames, in format (h, w). colored (bool, optional): Whether the frames are colored or grayscale. \"\"\" def __init__ ( self , num_frames : int , resolution : tuple [ int , int ], colored : bool = True ): self . colored = colored self . _resolution = resolution shape = ( * resolution , 3 ) if colored else resolution self . _frame = np . full ( shape , fill_value = 255 , dtype = np . uint8 ) frames = [ str ( i ) for i in range ( num_frames )] super () . __init__ ( \".\" , frame_files = frames ) def __getitem__ ( self , idx : int ) -> np . ndarray : return self . _frame . copy () def _extract_frame_shape ( self ) -> tuple [ int , ... ]: if self . colored : return ( * self . _resolution , 3 ) return self .
_resolution","title":"Module wtracker.utils.frame_reader"},{"location":"reference/wtracker/utils/frame_reader/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/frame_reader/#dummyreader","text":"class DummyReader ( num_frames : 'int' , resolution : 'tuple[int, int]' , colored : 'bool' = True ) A dummy frame reader that generates empty frames of a specified resolution.","title":"DummyReader"},{"location":"reference/wtracker/utils/frame_reader/#attributes","text":"Name Type Description Default num_frames int The number of frames to generate. None resolution tuple[int, int] The resolution of the frames, in format (h, w). None colored bool Whether the frames are colored or grayscale. None View Source class DummyReader ( FrameReader ): \"\"\" A dummy frame reader that generates empty frames of a specified resolution. Args: num_frames (int): The number of frames to generate. resolution (tuple[int, int]): The resolution of the frames, in format (h, w). colored (bool, optional): Whether the frames are colored or grayscale. \"\"\" def __init__ ( self , num_frames : int , resolution : tuple [ int , int ], colored : bool = True ): self . colored = colored self . _resolution = resolution shape = ( * resolution , 3 ) if colored else resolution self . _frame = np . full ( shape , fill_value = 255 , dtype = np . uint8 ) frames = [ str ( i ) for i in range ( num_frames )] super (). __init__ ( \".\" , frame_files = frames ) def __getitem__ ( self , idx : int ) -> np . ndarray : return self . _frame . copy () def _extract_frame_shape ( self ) -> tuple [ int , ... ]: if self . colored : return ( * self . _resolution , 3 ) return self . _resolution","title":"Attributes"},{"location":"reference/wtracker/utils/frame_reader/#ancestors-in-mro","text":"wtracker.utils.frame_reader.FrameReader","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/utils/frame_reader/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/utils/frame_reader/#create_from_directory","text":"def create_from_directory ( root_folder : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a directory. Parameters: Name Type Description Default root_folder str The root folder containing the frame files. None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format )","title":"create_from_directory"},{"location":"reference/wtracker/utils/frame_reader/#create_from_template","text":"def create_from_template ( root_folder : 'str' , name_format : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a file name template. Parameters: Name Type Description Default root_folder str The root folder where the frame files are located. None name_format str The format of the frame file names. 
None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format )","title":"create_from_template"},{"location":"reference/wtracker/utils/frame_reader/#instance-variables","text":"files Returns the list of files associated with the FrameReader object. frame_shape Returns the shape of the frame. frame_size Returns the size of the frame. read_format Returns the read format of the frame reader. root_folder Returns the root folder path.","title":"Instance variables"},{"location":"reference/wtracker/utils/frame_reader/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/utils/frame_reader/#make_stream","text":"def make_stream ( self ) Creates and returns a FrameStream object using the current instance of FrameReader. Returns: Type Description FrameStream A FrameStream object. View Source def make_stream(self): \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. \"\"\" return FrameStream(self)","title":"make_stream"},{"location":"reference/wtracker/utils/frame_reader/#framereader","text":"class FrameReader ( root_folder : 'str' , frame_files : 'list[str]' , read_format : 'int' = 0 ) A class for reading frames from a directory of frame files.","title":"FrameReader"},{"location":"reference/wtracker/utils/frame_reader/#attributes_1","text":"Name Type Description Default root_folder str The root folder path where the frame files are located. None frame_files list[str] A list of frame file names. None read_format int The format in which the frames should be read. None root_folder str The root folder path where the frame files are located. None frame_shape tuple[int, ...] The shape of the frame. None frame_size tuple[int, int] The size of the frame. None files list[str] The list of file paths. None read_format int The read format of the frame reader. None View Source class FrameReader : \"\"\" A class for reading frames from a directory of frame files. Args: root_folder (str): The root folder path where the frame files are located. frame_files (list[str]): A list of frame file names. read_format (int, optional): The format in which the frames should be read. Attributes: root_folder (str): The root folder path where the frame files are located. frame_shape (tuple[int, ...]): The shape of the frame. frame_size (tuple[int, int]): The size of the frame. files (list[str]): The list of file paths. read_format (int): The read format of the frame reader. \"\"\" def __init__ ( self , root_folder : str , frame_files : list [ str ] , read_format : int = cv . IMREAD_GRAYSCALE , ) : assert os . path .
exists ( root_folder ) assert len ( frame_files ) > 0 self . _root_folder = root_folder self . _files = frame_files self . _read_format = read_format self . _frame_shape = self . _extract_frame_shape () def _extract_frame_shape ( self ) -> tuple [ int, ... ] : path = join_paths ( self . root_folder , self . files [ 0 ] ) frame = cv . imread ( path , self . _read_format ) return frame . shape @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format ) @property def root_folder ( self ) -> str : \"\"\" Returns the root folder path. Returns: str: The root folder path. \"\"\" return self . _root_folder @property def frame_shape ( self ) -> tuple [ int, ... ] : \"\"\" Returns the shape of the frame. Returns: tuple[int, ...]: The shape of the frame, in format (h, w, ...). \"\"\" return self . _frame_shape @property def frame_size ( self ) -> tuple [ int, int ] : \"\"\" Returns the size of the frame. Returns: tuple[int, int]: The shape of the frame, in format (h, w). \"\"\" return self . _frame_shape [ :2 ] @property def files ( self ) -> list [ str ] : \"\"\" Returns the list of files associated with the FrameReader object. Returns: list[str]: The list of file paths. \"\"\" return self . _files @property def read_format ( self ) -> int : \"\"\" Returns the read format of the frame reader. Returns: int: The read format. \"\"\" return self . _read_format def __len__ ( self ) -> int : return len ( self . _files ) def __getitem__ ( self , idx : int ) -> np . ndarray : if idx < 0 or idx >= len ( self . _files ) : raise IndexError ( \"index out of bounds\" ) path = join_paths ( self . root_folder , self . files [ idx ] ) frame = cv . imread ( path , self . _read_format ) return frame . astype ( np . uint8 , copy = False ) def __iter__ ( self ) : return FrameStream ( self ) def make_stream ( self ) : \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. 
\"\"\" return FrameStream ( self )","title":"Attributes"},{"location":"reference/wtracker/utils/frame_reader/#descendants","text":"wtracker.utils.frame_reader.DummyReader","title":"Descendants"},{"location":"reference/wtracker/utils/frame_reader/#static-methods_1","text":"","title":"Static methods"},{"location":"reference/wtracker/utils/frame_reader/#create_from_directory_1","text":"def create_from_directory ( root_folder : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a directory. Parameters: Name Type Description Default root_folder str The root folder containing the frame files. None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_directory ( root_folder : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a directory. Args: root_folder (str): The root folder containing the frame files. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files in root frame_paths = glob . glob ( \"*.*\" , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format )","title":"create_from_directory"},{"location":"reference/wtracker/utils/frame_reader/#create_from_template_1","text":"def create_from_template ( root_folder : 'str' , name_format : 'str' , read_format : 'int' = 0 ) -> 'FrameReader' Creates a FrameReader object from a file name template. Parameters: Name Type Description Default root_folder str The root folder where the frame files are located. None name_format str The format of the frame file names. None read_format int The format in which the frames should be read. None Returns: Type Description FrameReader The created FrameReader object. View Source @staticmethod def create_from_template ( root_folder : str , name_format : str , read_format : int = cv . IMREAD_GRAYSCALE ) -> FrameReader : \"\"\" Creates a FrameReader object from a file name template. Args: root_folder (str): The root folder where the frame files are located. name_format (str): The format of the frame file names. read_format (int, optional): The format in which the frames should be read. Returns: FrameReader: The created FrameReader object. \"\"\" # get all files matching name format fmt = name_format . format ( \"[0-9]*\" ) frame_paths = glob . glob ( fmt , root_dir = root_folder ) frame_paths = [ f for f in frame_paths if os.path.isfile(join_paths(root_folder, f)) ] frame_paths = sorted ( frame_paths ) return FrameReader ( root_folder , frame_paths , read_format )","title":"create_from_template"},{"location":"reference/wtracker/utils/frame_reader/#instance-variables_1","text":"files Returns the list of files associated with the FrameReader object. frame_shape Returns the shape of the frame. frame_size Returns the size of the frame. read_format Returns the read format of the frame reader. root_folder Returns the root folder path.","title":"Instance variables"},{"location":"reference/wtracker/utils/frame_reader/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/utils/frame_reader/#make_stream_1","text":"def make_stream ( self ) Creates and returns a FrameStream object using the current instance of FrameReader. 
Returns: Type Description FrameStream A FrameStream object. View Source def make_stream(self): \"\"\" Creates and returns a FrameStream object using the current instance of FrameReader. Returns: FrameStream: A FrameStream object. \"\"\" return FrameStream(self)","title":"make_stream"},{"location":"reference/wtracker/utils/frame_reader/#framestream","text":"class FrameStream ( frame_reader : 'FrameReader' ) A class for streaming frames from a FrameReader object. This class serves as an iterator for the FrameReader object.","title":"FrameStream"},{"location":"reference/wtracker/utils/frame_reader/#attributes_2","text":"Name Type Description Default frame_reader FrameReader The frame reader object. None View Source class FrameStream : \"\"\" A class for streaming frames from a FrameReader object. This class serves as an iterator for the FrameReader object. Args: frame_reader (FrameReader): The frame reader object. \"\"\" def __init__ ( self , frame_reader : FrameReader ) : self . _frame_reader = frame_reader self . _idx = - 1 self . frame = None @property def index ( self ) -> int : \"\"\" The index of the current frame. \"\"\" return self . _idx def __len__ ( self ) : return len ( self . _frame_reader ) def __iter__ ( self ) : return self def __next__ ( self ) -> np . ndarray : self . progress () if not self . can_read () : raise StopIteration () frame = self . read () return frame def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader ) def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . _idx = idx self . frame = None return self . can_read () def read ( self ) -> np . ndarray : \"\"\" Read and return the frame at the current index. Raises: IndexError: If the index is out of bounds. Returns: np.ndarray: The frame at the current index. \"\"\" if not self . can_read () : raise IndexError ( \"index out of bounds\" ) if self . frame is None : self . frame = self . _frame_reader [ self._idx ] return self . frame def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . _idx + n ) def reset ( self ) : \"\"\" Resets the frame reader to the beginning of the stream. \"\"\" self . seek ( - 1 )","title":"Attributes"},{"location":"reference/wtracker/utils/frame_reader/#descendants_1","text":"wtracker.sim.view_controller.ViewController","title":"Descendants"},{"location":"reference/wtracker/utils/frame_reader/#instance-variables_2","text":"index The index of the current frame.","title":"Instance variables"},{"location":"reference/wtracker/utils/frame_reader/#methods_2","text":"","title":"Methods"},{"location":"reference/wtracker/utils/frame_reader/#can_read","text":"def can_read ( self ) -> 'bool' View Source def can_read ( self ) -> bool : return self . _idx >= 0 and self . _idx < len ( self . _frame_reader )","title":"can_read"},{"location":"reference/wtracker/utils/frame_reader/#progress","text":"def progress ( self , n : 'int' = 1 ) -> 'bool' Moves the current index forward by the specified number of steps. Parameters: Name Type Description Default n int The number of steps to move forward.
None Returns: Type Description bool True if the index was successfully moved forward, False otherwise. View Source def progress ( self , n : int = 1 ) -> bool : \"\"\" Moves the current index forward by the specified number of steps. Args: n (int): The number of steps to move forward. Returns: bool: True if the index was successfully moved forward, False otherwise. \"\"\" return self . seek ( self . _idx + n )","title":"progress"},{"location":"reference/wtracker/utils/frame_reader/#read","text":"def read ( self ) -> 'np.ndarray' Read and return the frame at the current index. Returns: Type Description np.ndarray The frame at the current index. Raises: Type Description IndexError If the index is out of bounds. View Source def read ( self ) -> np . ndarray : \"\"\" Read and return the frame at the current index. Raises: IndexError: If the index is out of bounds. Returns: np.ndarray: The frame at the current index. \"\"\" if not self . can_read () : raise IndexError ( \"index out of bounds\" ) if self . frame is None : self . frame = self . _frame_reader [ self . _idx ] return self . frame","title":"read"},{"location":"reference/wtracker/utils/frame_reader/#reset","text":"def reset ( self ) Resets the frame reader to the beginning of the stream. View Source def reset(self): \"\"\" Resets the frame reader to the beginning of the stream. \"\"\" self.seek(-1)","title":"reset"},{"location":"reference/wtracker/utils/frame_reader/#seek","text":"def seek ( self , idx : 'int' ) -> 'bool' Move the index to the specified position. Parameters: Name Type Description Default idx int The index to seek to. None Returns: Type Description bool True if the index is within the valid range, False otherwise. View Source def seek ( self , idx : int ) -> bool : \"\"\" Move the index to the specified position. Args: idx (int): The index to seek to. Returns: bool: True if the index is within the valid range, False otherwise. \"\"\" self . _idx = idx self . frame = None return self . can_read ()","title":"seek"},{"location":"reference/wtracker/utils/gui_utils/","text":"Module wtracker.utils.gui_utils View Source import tkinter as tk from tkinter import filedialog class FocusedWindow : def __init__ ( self ): root = tk . Tk () self . root = root self . hide () def __enter__ ( self ) -> tk . Tk : return self . focus () def __exit__ ( self , exc_type , exc_val , exc_tb ): self . hide () def focus ( self ) -> tk . Tk : root = self . root root . eval ( \"tk::PlaceWindow %s center\" % root . winfo_pathname ( root . winfo_id ())) root . deiconify () root . lift () root . attributes ( \"-topmost\" , True ) root . focus_force () root . update () root . after_idle ( root . attributes , \"-topmost\" , False ) return root def hide ( self ) -> tk . Tk : root = self . root root . withdraw () root . overrideredirect ( True ) root . geometry ( \"0x0+0+0\" ) root . update () return root def close ( self ): self . root . destroy () def __del__ ( self ): self . close () class UserPrompt : \"\"\"Class for creating user prompt dialogs.\"\"\" @staticmethod def open_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , multiple : bool = False , ** kwargs , ) -> str | list [ str ]: \"\"\" Opens a file dialog to select one or multiple files. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern).
multiple (bool, optional): Whether to allow multiple file selection. **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str | list[str]: The path of the selected file(s). If multiple is True, a list of paths is returned. Otherwise, a single path is returned. \"\"\" if file_types is None : file_types = [] file_types += [( \"all files\" , \"*.*\" )] with FocusedWindow () as root : if multiple : path = filedialog . askopenfilenames ( parent = root , title = title , filetypes = file_types , ** kwargs , ) return list ( path ) else : return filedialog . askopenfilename ( parent = root , title = title , filetypes = file_types , ** kwargs , ) @staticmethod def save_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , ** kwargs ) -> str : \"\"\" Opens a file dialog to save a file and returns the selected file path. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str: The selected file path. \"\"\" if file_types is None : file_types = [] file_types += [( \"all files\" , \"*.*\" )] with FocusedWindow () as root : return filedialog . asksaveasfilename ( parent = root , title = title , filetypes = file_types , confirmoverwrite = True , ** kwargs , ) @staticmethod def open_directory ( title : str = None , ** kwargs ) -> str : \"\"\" Opens a dialog box to select a directory. Args: title (str, optional): The title of the dialog box. **kwargs: Additional keyword arguments to be passed to the filedialog.askdirectory function. Returns: str: The path of the selected directory. \"\"\" with FocusedWindow () as root : return filedialog . askdirectory ( parent = root , title = title , mustexist = False , ** kwargs , ) Classes FocusedWindow class FocusedWindow ( ) View Source class FocusedWindow : def __init__ ( self ): root = tk . Tk () self . root = root self . hide () def __enter__ ( self ) -> tk . Tk : return self . focus () def __exit__ ( self , exc_type , exc_val , exc_tb ): self . hide () def focus ( self ) -> tk . Tk : root = self . root root . eval ( \"tk::PlaceWindow %s center\" % root . winfo_pathname ( root . winfo_id ())) root . deiconify () root . lift () root . attributes ( \"-topmost\" , True ) root . focus_force () root . update () root . after_idle ( root . attributes , \"-topmost\" , False ) return root def hide ( self ) -> tk . Tk : root = self . root root . withdraw () root . overrideredirect ( True ) root . geometry ( \"0x0+0+0\" ) root . update () return root def close ( self ): self . root . destroy () def __del__ ( self ): self . close () Methods close def close ( self ) View Source def close(self): self.root.destroy() focus def focus ( self ) -> tkinter . Tk View Source def focus ( self ) -> tk . Tk : root = self . root root . eval ( \"tk::PlaceWindow %s center\" % root . winfo_pathname ( root . winfo_id ())) root . deiconify () root . lift () root . attributes ( \"-topmost\" , True ) root . focus_force () root . update () root . after_idle ( root . attributes , \"-topmost\" , False ) return root hide def hide ( self ) -> tkinter . Tk View Source def hide ( self ) -> tk . Tk : root = self . root root . withdraw () root . overrideredirect ( True ) root . geometry ( \"0x0+0+0\" ) root . 
update () return root UserPrompt class UserPrompt ( / , * args , ** kwargs ) Class for creating user prompt dialogs. View Source class UserPrompt : \"\"\"Class for creating user prompt dialogs.\"\"\" @staticmethod def open_file ( title : str = None , file_types : list [ tuple[str, str ] ] = None , multiple : bool = False , ** kwargs , ) -> str | list [ str ] : \"\"\" Opens a file dialog to select one or multiple files. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). multiple (bool, optional): Whether to allow multiple file selection. **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str | list[str]: The path of the selected file(s). If multiple is True, a list of paths is returned. Otherwise, a single path is returned. \"\"\" if file_types is None : file_types = [] file_types += [ (\"all files\", \"*.*\") ] with FocusedWindow () as root : if multiple : path = filedialog . askopenfilenames ( parent = root , title = title , filetypes = file_types , ** kwargs , ) return list ( path ) else : return filedialog . askopenfilename ( parent = root , title = title , filetypes = file_types , ** kwargs , ) @staticmethod def save_file ( title : str = None , file_types : list [ tuple[str, str ] ] = None , ** kwargs ) -> str : \"\"\" Opens a file dialog to save a file and returns the selected file path. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str: The selected file path. \"\"\" if file_types is None : file_types = [] file_types += [ (\"all files\", \"*.*\") ] with FocusedWindow () as root : return filedialog . asksaveasfilename ( parent = root , title = title , filetypes = file_types , confirmoverwrite = True , ** kwargs , ) @staticmethod def open_directory ( title : str = None , ** kwargs ) -> str : \"\"\" Opens a dialog box to select a directory. Args: title (str, optional): The title of the dialog box. **kwargs: Additional keyword arguments to be passed to the filedialog.askdirectory function. Returns: str: The path of the selected directory. \"\"\" with FocusedWindow () as root : return filedialog . askdirectory ( parent = root , title = title , mustexist = False , ** kwargs , ) Static methods open_directory def open_directory ( title : str = None , ** kwargs ) -> str Opens a dialog box to select a directory. Parameters: Name Type Description Default title str The title of the dialog box. None **kwargs None Additional keyword arguments to be passed to the filedialog.askdirectory function. None Returns: Type Description str The path of the selected directory. View Source @staticmethod def open_directory ( title : str = None , ** kwargs ) -> str : \"\"\" Opens a dialog box to select a directory. Args: title (str, optional): The title of the dialog box. **kwargs: Additional keyword arguments to be passed to the filedialog.askdirectory function. Returns: str: The path of the selected directory. \"\"\" with FocusedWindow () as root : return filedialog .
askdirectory ( parent = root , title = title , mustexist = False , ** kwargs , ) open_file def open_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , multiple : bool = False , ** kwargs ) -> str | list [ str ] Opens a file dialog to select one or multiple files. Parameters: Name Type Description Default title str The title of the file dialog window. None file_types list[tuple[str, str]] A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). None multiple bool Whether to allow multiple file selection. None **kwargs None Additional keyword arguments to be passed to the file dialog. None Returns: Type Description str list[str] View Source @staticmethod def open_file ( title : str = None , file_types : list [ tuple[str, str ] ] = None , multiple : bool = False , ** kwargs , ) -> str | list [ str ] : \"\"\" Opens a file dialog to select one or multiple files. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). multiple (bool, optional): Whether to allow multiple file selection. **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str | list[str]: The path of the selected file(s). If multiple is True, a list of paths is returned. Otherwise, a single path is returned. \"\"\" if file_types is None : file_types = [] file_types += [ (\"all files\", \"*.*\") ] with FocusedWindow () as root : if multiple : path = filedialog . askopenfilenames ( parent = root , title = title , filetypes = file_types , ** kwargs , ) return list ( path ) else : return filedialog . askopenfilename ( parent = root , title = title , filetypes = file_types , ** kwargs , ) save_file def save_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , ** kwargs ) -> str Opens a file dialog to save a file and returns the selected file path. Parameters: Name Type Description Default title str The title of the file dialog window. None file_types list[tuple[str, str]] A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). None **kwargs None Additional keyword arguments to be passed to the file dialog. None Returns: Type Description str The selected file path. View Source @ staticmethod def save_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , ** kwargs ) -> str : \"\"\" Opens a file dialog to save a file and returns the selected file path. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str: The selected file path. \"\"\" if file_types is None : file_types = [] file_types += [( \"all files\" , \"*.*\" )] with FocusedWindow () as root : return filedialog . asksaveasfilename ( parent = root , title = title , filetypes = file_types , confirmoverwrite = True , ** kwargs , )","title":"Gui Utils"},{"location":"reference/wtracker/utils/gui_utils/#module-wtrackerutilsgui_utils","text":"View Source import tkinter as tk from tkinter import filedialog class FocusedWindow : def __init__ ( self ): root = tk . Tk () self . root = root self . 
hide () def __enter__ ( self ) -> tk . Tk : return self . focus () def __exit__ ( self , exc_type , exc_val , exc_tb ): self . hide () def focus ( self ) -> tk . Tk : root = self . root root . eval ( \"tk::PlaceWindow %s center\" % root . winfo_pathname ( root . winfo_id ())) root . deiconify () root . lift () root . attributes ( \"-topmost\" , True ) root . focus_force () root . update () root . after_idle ( root . attributes , \"-topmost\" , False ) return root def hide ( self ) -> tk . Tk : root = self . root root . withdraw () root . overrideredirect ( True ) root . geometry ( \"0x0+0+0\" ) root . update () return root def close ( self ): self . root . destroy () def __del__ ( self ): self . close () class UserPrompt : \"\"\"Class for creating a user prompt dialogs.\"\"\" @staticmethod def open_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , multiple : bool = False , ** kwargs , ) -> str | list [ str ]: \"\"\" Opens a file dialog to select one or multiple files. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). multiple (bool, optional): Whether to allow multiple file selection. **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str | list[str]: The path of the selected file(s). If multiple is True, a list of paths is returned. Otherwise, a single path is returned. \"\"\" if file_types is None : file_types = [] file_types += [( \"all files\" , \"*.*\" )] with FocusedWindow () as root : if multiple : path = filedialog . askopenfilenames ( parent = root , title = title , filetypes = file_types , ** kwargs , ) return list ( path ) else : return filedialog . askopenfilename ( parent = root , title = title , filetypes = file_types , ** kwargs , ) @staticmethod def save_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , ** kwargs ) -> str : \"\"\" Opens a file dialog to save a file and returns the selected file path. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str: The selected file path. \"\"\" if file_types is None : file_types = [] file_types += [( \"all files\" , \"*.*\" )] with FocusedWindow () as root : return filedialog . asksaveasfilename ( parent = root , title = title , filetypes = file_types , confirmoverwrite = True , ** kwargs , ) @staticmethod def open_directory ( title : str = None , ** kwargs ) -> str : \"\"\" Opens a dialog box to select a directory. Args: title (str, optional): The title of the dialog box. **kwargs: Additional keyword arguments to be passed to the filedialog.askdirectory function. Returns: str: The path of the selected directory. \"\"\" with FocusedWindow () as root : return filedialog . askdirectory ( parent = root , title = title , mustexist = False , ** kwargs , )","title":"Module wtracker.utils.gui_utils"},{"location":"reference/wtracker/utils/gui_utils/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/gui_utils/#focusedwindow","text":"class FocusedWindow ( ) View Source class FocusedWindow : def __init__ ( self ): root = tk . Tk () self . root = root self . 
hide () def __enter__ ( self ) -> tk . Tk : return self . focus () def __exit__ ( self , exc_type , exc_val , exc_tb ): self . hide () def focus ( self ) -> tk . Tk : root = self . root root . eval ( \"tk::PlaceWindow %s center\" % root . winfo_pathname ( root . winfo_id ())) root . deiconify () root . lift () root . attributes ( \"-topmost\" , True ) root . focus_force () root . update () root . after_idle ( root . attributes , \"-topmost\" , False ) return root def hide ( self ) -> tk . Tk : root = self . root root . withdraw () root . overrideredirect ( True ) root . geometry ( \"0x0+0+0\" ) root . update () return root def close ( self ): self . root . destroy () def __del__ ( self ): self . close ()","title":"FocusedWindow"},{"location":"reference/wtracker/utils/gui_utils/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/utils/gui_utils/#close","text":"def close ( self ) View Source def close(self): self.root.destroy()","title":"close"},{"location":"reference/wtracker/utils/gui_utils/#focus","text":"def focus ( self ) -> tkinter . Tk View Source def focus ( self ) -> tk . Tk : root = self . root root . eval ( \"tk::PlaceWindow %s center\" % root . winfo_pathname ( root . winfo_id ())) root . deiconify () root . lift () root . attributes ( \"-topmost\" , True ) root . focus_force () root . update () root . after_idle ( root . attributes , \"-topmost\" , False ) return root","title":"focus"},{"location":"reference/wtracker/utils/gui_utils/#hide","text":"def hide ( self ) -> tkinter . Tk View Source def hide ( self ) -> tk . Tk : root = self . root root . withdraw () root . overrideredirect ( True ) root . geometry ( \"0x0+0+0\" ) root . update () return root","title":"hide"},{"location":"reference/wtracker/utils/gui_utils/#userprompt","text":"class UserPrompt ( / , * args , ** kwargs ) Class for creating user prompt dialogs. View Source class UserPrompt : \"\"\"Class for creating user prompt dialogs.\"\"\" @staticmethod def open_file ( title : str = None , file_types : list [ tuple[str, str ] ] = None , multiple : bool = False , ** kwargs , ) -> str | list [ str ] : \"\"\" Opens a file dialog to select one or multiple files. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). multiple (bool, optional): Whether to allow multiple file selection. **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str | list[str]: The path of the selected file(s). If multiple is True, a list of paths is returned. Otherwise, a single path is returned. \"\"\" if file_types is None : file_types = [] file_types += [ (\"all files\", \"*.*\") ] with FocusedWindow () as root : if multiple : path = filedialog . askopenfilenames ( parent = root , title = title , filetypes = file_types , ** kwargs , ) return list ( path ) else : return filedialog . askopenfilename ( parent = root , title = title , filetypes = file_types , ** kwargs , ) @staticmethod def save_file ( title : str = None , file_types : list [ tuple[str, str ] ] = None , ** kwargs ) -> str : \"\"\" Opens a file dialog to save a file and returns the selected file path. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern).
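Example (an illustrative sketch added by the editor, not generated from the source): FocusedWindow is meant to be used as a context manager, so the hidden Tk root is focused while a dialog is open and hidden again afterwards.

import tkinter.filedialog as filedialog
from wtracker.utils.gui_utils import FocusedWindow

# __enter__ calls focus(): the root is centered, raised and made topmost,
# so any dialog parented to it opens in front of other windows.
with FocusedWindow() as root:
    chosen = filedialog.askopenfilename(parent=root, title="Pick a file")

print(chosen)  # empty string if the dialog was cancelled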
**kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str: The selected file path. \"\"\" if file_types is None : file_types = [] file_types += [ (\"all files\", \"*.*\") ] with FocusedWindow () as root : return filedialog . asksaveasfilename ( parent = root , title = title , filetypes = file_types , confirmoverwrite = True , ** kwargs , ) @staticmethod def open_directory ( title : str = None , ** kwargs ) -> str : \"\"\" Opens a dialog box to select a directory. Args: title (str, optional): The title of the dialog box. **kwargs: Additional keyword arguments to be passed to the filedialog.askdirectory function. Returns: str: The path of the selected directory. \"\"\" with FocusedWindow () as root : return filedialog . askdirectory ( parent = root , title = title , mustexist = False , ** kwargs , )","title":"UserPrompt"},{"location":"reference/wtracker/utils/gui_utils/#static-methods","text":"","title":"Static methods"},{"location":"reference/wtracker/utils/gui_utils/#open_directory","text":"def open_directory ( title : str = None , ** kwargs ) -> str Opens a dialog box to select a directory. Parameters: Name Type Description Default title str The title of the dialog box. None **kwargs None Additional keyword arguments to be passed to the filedialog.askdirectory function. None Returns: Type Description str The path of the selected directory. View Source @staticmethod def open_directory ( title : str = None , ** kwargs ) -> str : \"\"\" Opens a dialog box to select a directory. Args: title (str, optional): The title of the dialog box. **kwargs: Additional keyword arguments to be passed to the filedialog.askdirectory function. Returns: str: The path of the selected directory. \"\"\" with FocusedWindow () as root : return filedialog . askdirectory ( parent = root , title = title , mustexist = False , ** kwargs , )","title":"open_directory"},{"location":"reference/wtracker/utils/gui_utils/#open_file","text":"def open_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , multiple : bool = False , ** kwargs ) -> str | list [ str ] Opens a file dialog to select one or multiple files. Parameters: Name Type Description Default title str The title of the file dialog window. None file_types list[tuple[str, str]] A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). None multiple bool Whether to allow multiple file selection. None **kwargs None Additional keyword arguments to be passed to the file dialog. None Returns: Type Description str list[str] View Source @staticmethod def open_file ( title : str = None , file_types : list [ tuple[str, str ] ] = None , multiple : bool = False , ** kwargs , ) -> str | list [ str ] : \"\"\" Opens a file dialog to select one or multiple files. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). multiple (bool, optional): Whether to allow multiple file selection. **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str | list[str]: The path of the selected file(s). If multiple is True, a list of paths is returned. Otherwise, a single path is returned. \"\"\" if file_types is None : file_types = [] file_types += [ (\"all files\", \"*.*\") ] with FocusedWindow () as root : if multiple : path = filedialog . 
askopenfilenames ( parent = root , title = title , filetypes = file_types , ** kwargs , ) return list ( path ) else : return filedialog . askopenfilename ( parent = root , title = title , filetypes = file_types , ** kwargs , )","title":"open_file"},{"location":"reference/wtracker/utils/gui_utils/#save_file","text":"def save_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , ** kwargs ) -> str Opens a file dialog to save a file and returns the selected file path. Parameters: Name Type Description Default title str The title of the file dialog window. None file_types list[tuple[str, str]] A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). None **kwargs None Additional keyword arguments to be passed to the file dialog. None Returns: Type Description str The selected file path. View Source @ staticmethod def save_file ( title : str = None , file_types : list [ tuple [ str , str ]] = None , ** kwargs ) -> str : \"\"\" Opens a file dialog to save a file and returns the selected file path. Args: title (str, optional): The title of the file dialog window. file_types (list[tuple[str, str]], optional): A list of file types to filter the displayed files. Each file type is represented as a tuple of the form (description, pattern). **kwargs: Additional keyword arguments to be passed to the file dialog. Returns: str: The selected file path. \"\"\" if file_types is None : file_types = [] file_types += [( \"all files\" , \"*.*\" )] with FocusedWindow () as root : return filedialog . asksaveasfilename ( parent = root , title = title , filetypes = file_types , confirmoverwrite = True , ** kwargs , )","title":"save_file"},{"location":"reference/wtracker/utils/io_utils/","text":"Module wtracker.utils.io_utils View Source import cv2 as cv import numpy as np import pickle import math from wtracker.utils.path_utils import join_paths , create_directory , create_parent_directory from wtracker.utils.frame_reader import FrameReader from wtracker.utils.threading_utils import TaskScheduler class FrameSaver ( TaskScheduler ): \"\"\" A class for saving images from a frame reader to a specified folder. This class utilizes a queue to save images in a separate thread, which allows for non-blocking image saving. Args: frame_reader (FrameReader): The frame reader object from which images will be saved. root_path (str): The root folder path, relative to which all other paths are. maxsize (int, optional): The maximum size of the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments for tqdm. \"\"\" def __init__ ( self , frame_reader : FrameReader , root_path : str = \"\" , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs , ): super () . __init__ ( self . _save_frame , maxsize , tqdm , ** tqdm_kwargs ) self . _frame_reader = frame_reader self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img_index : int , crop_dims : tuple [ float , float , float , float ], img_name : str ): \"\"\" Adds an image to the queue for saving. Args: img_index (int): The index of the image in the frame reader. crop_dims (tuple[float, float, float, float]): The crop dimensions (x, y, w, h) for the image. img_name (str): The name (path) of the image file relative to the root path. \"\"\" super () . 
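Example (an illustrative usage sketch, not generated from the source; all file names and filters are hypothetical): the three UserPrompt helpers wrap the tkinter dialogs behind a FocusedWindow.

from wtracker.utils.gui_utils import UserPrompt

# Select several input videos; an ("all files", "*.*") filter is appended automatically.
videos = UserPrompt.open_file(
    title="Select input videos",
    file_types=[("video files", "*.avi")],
    multiple=True,  # returns list[str]; with multiple=False a single str is returned
)

# Choose where to write results; confirmoverwrite=True asks before overwriting.
out_csv = UserPrompt.save_file(title="Save results as")

# Pick an output folder; mustexist=False allows selecting a folder that does not exist yet.
out_dir = UserPrompt.open_directory(title="Choose output directory")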
schedule_save ( img_index , crop_dims , img_name ) def _save_frame ( self , params : tuple [ int , tuple [ float , float , float , float ], str ]): img_index , crop_dims , img_name = params save_path = join_paths ( self . _root_path , img_name ) img = self . _frame_reader [ img_index ] x , y , w , h = crop_dims img = img [ y : y + h , x : x + w ] success = cv . imwrite ( save_path , img ) if not success : create_parent_directory ( save_path ) if not cv . imwrite ( save_path , img ): raise ValueError ( f \"Failed to save image { save_path } \" ) class ImageSaver ( TaskScheduler ): \"\"\" A class for saving images asynchronously using a task scheduler. Args: root_path (str): The root folder path, relative to which all other paths are. maxsize (int, optional): The maximum size of the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments for tqdm. \"\"\" def __init__ ( self , root_path : str = \"\" , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs , ): super () . __init__ ( self . _save_image , maxsize , tqdm , ** tqdm_kwargs ) self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img : np . ndarray , img_path : str ): \"\"\" Adds an image to the queue for saving. Args: img (np.ndarray): The image to save. img_name (str): The name (path) of the image file relative to the root path. \"\"\" super () . schedule_save ( img , img_path ) def _save_image ( self , params : tuple [ np . ndarray , str ]): img , img_name = params save_path = join_paths ( self . _root_path , img_name ) success = cv . imwrite ( save_path , img ) if not success : create_parent_directory ( save_path ) if not cv . imwrite ( save_path , img ): raise ValueError ( f \"Failed to save image { save_path } \" ) def pickle_load_object ( file_path : str ): \"\"\" Load an object from a pickle file. Args: file_path (str): The path to the pickle file. Returns: The loaded object. Raises: FileNotFoundError: If the file does not exist. ValueError: If there is an error loading the object from the pickle file. \"\"\" try : with open ( file_path , \"rb\" ) as f : return pickle . load ( f ) except FileNotFoundError : raise FileNotFoundError ( f \"file does not exist: { file_path } \" ) except Exception as e : raise ValueError ( f \"error loading object from pickle file: { e } \" ) def pickle_save_object ( obj , file_path : str ): \"\"\" Save an object to a pickle file. Args: obj: The object to be saved. file_path (str): The path to the pickle file. Raises: ValueError: If there is an error saving the object to the pickle file. \"\"\" try : create_parent_directory ( file_path ) with open ( file_path , \"wb\" ) as f : pickle . dump ( obj , f , protocol = pickle . HIGHEST_PROTOCOL ) except Exception as e : raise ValueError ( f \"error saving object to pickle file: { e } \" ) Functions pickle_load_object def pickle_load_object ( file_path : str ) Load an object from a pickle file. Parameters: Name Type Description Default file_path str The path to the pickle file. None Returns: Type Description None The loaded object. Raises: Type Description FileNotFoundError If the file does not exist. ValueError If there is an error loading the object from the pickle file. View Source def pickle_load_object ( file_path : str ): \"\"\" Load an object from a pickle file. Args: file_path (str): The path to the pickle file. Returns: The loaded object. Raises: FileNotFoundError: If the file does not exist. 
ValueError: If there is an error loading the object from the pickle file. \"\"\" try : with open ( file_path , \"rb\" ) as f : return pickle . load ( f ) except FileNotFoundError : raise FileNotFoundError ( f \"file does not exist: {file_path}\" ) except Exception as e : raise ValueError ( f \"error loading object from pickle file: {e}\" ) pickle_save_object def pickle_save_object ( obj , file_path : str ) Save an object to a pickle file. Parameters: Name Type Description Default obj None The object to be saved. None file_path str The path to the pickle file. None Raises: Type Description ValueError If there is an error saving the object to the pickle file. View Source def pickle_save_object(obj, file_path: str): \"\"\" Save an object to a pickle file. Args: obj: The object to be saved. file_path (str): The path to the pickle file. Raises: ValueError: If there is an error saving the object to the pickle file. \"\"\" try: create_parent_directory(file_path) with open(file_path, \"wb\") as f: pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) except Exception as e: raise ValueError(f\"error saving object to pickle file: {e}\") Classes FrameSaver class FrameSaver ( frame_reader : wtracker . utils . frame_reader . FrameReader , root_path : str = '' , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs ) A class for saving images from a frame reader to a specified folder. This class utilizes a queue to save images in a separate thread, which allows for non-blocking image saving. Attributes Name Type Description Default frame_reader FrameReader The frame reader object from which images will be saved. None root_path str The root folder path, relative to which all other paths are. None maxsize int The maximum size of the queue. None tqdm bool Whether to use tqdm for progress tracking. None **tqdm_kwargs None Additional keyword arguments for tqdm. None View Source class FrameSaver ( TaskScheduler ) : \"\"\" A class for saving images from a frame reader to a specified folder . This class utilizes a queue to save images in a separate thread , which allows for non - blocking image saving . Args : frame_reader ( FrameReader ) : The frame reader object from which images will be saved . root_path ( str ) : The root folder path , relative to which all other paths are . maxsize ( int , optional ) : The maximum size of the queue . tqdm ( bool , optional ) : Whether to use tqdm for progress tracking . ** tqdm_kwargs : Additional keyword arguments for tqdm . \"\"\" def __init__ ( self , frame_reader : FrameReader , root_path : str = \"\" , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs , ) : super (). __init__ ( self . _save_frame , maxsize , tqdm , ** tqdm_kwargs ) self . _frame_reader = frame_reader self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img_index : int , crop_dims : tuple [ float , float , float , float ], img_name : str ) : \"\"\" Adds an image to the queue for saving . Args : img_index ( int ) : The index of the image in the frame reader . crop_dims ( tuple [ float , float , float , float ]) : The crop dimensions ( x , y , w , h ) for the image . img_name ( str ) : The name ( path ) of the image file relative to the root path . \"\"\" super (). schedule_save ( img_index , crop_dims , img_name ) def _save_frame ( self , params : tuple [ int , tuple [ float , float , float , float ], str ]) : img_index , crop_dims , img_name = params save_path = join_paths ( self . _root_path , img_name ) img = self . 
_frame_reader [ img_index ] x , y , w , h = crop_dims img = img [ y : y + h , x : x + w ] success = cv . imwrite ( save_path , img ) if not success : create_parent_directory ( save_path ) if not cv . imwrite ( save_path , img ) : raise ValueError ( f \"Failed to save image {save_path}\" ) Ancestors (in MRO) wtracker.utils.threading_utils.TaskScheduler Methods close def close ( self ) Waits for the queue to empty and then closes the worker thread. View Source def close(self): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self._queue.join() self._queue.put(None) self._worker_thread.join() schedule_save def schedule_save ( self , img_index : int , crop_dims : tuple [ float , float , float , float ], img_name : str ) Adds an image to the queue for saving. Parameters: Name Type Description Default img_index int The index of the image in the frame reader. None crop_dims tuple[float, float, float, float] The crop dimensions (x, y, w, h) for the image. None img_name str The name (path) of the image file relative to the root path. None View Source def schedule_save(self, img_index: int, crop_dims: tuple[float, float, float, float], img_name: str): \"\"\" Adds an image to the queue for saving. Args: img_index (int): The index of the image in the frame reader. crop_dims (tuple[float, float, float, float]): The crop dimensions (x, y, w, h) for the image. img_name (str): The name (path) of the image file relative to the root path. \"\"\" super().schedule_save(img_index, crop_dims, img_name) start def start ( self ) Starts the worker thread. View Source def start(self): \"\"\" Starts the worker thread. \"\"\" self._worker_thread.start() ImageSaver class ImageSaver ( root_path : str = '' , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs ) A class for saving images asynchronously using a task scheduler. Attributes Name Type Description Default root_path str The root folder path, relative to which all other paths are. None maxsize int The maximum size of the queue. None tqdm bool Whether to use tqdm for progress tracking. None **tqdm_kwargs None Additional keyword arguments for tqdm. None View Source class ImageSaver ( TaskScheduler ): \"\"\" A class for saving images asynchronously using a task scheduler. Args: root_path (str): The root folder path, relative to which all other paths are. maxsize (int, optional): The maximum size of the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments for tqdm. \"\"\" def __init__ ( self , root_path: str = \"\" , maxsize: int = 100 , tqdm: bool = True , ** tqdm_kwargs , ): super (). __init__ ( self . _save_image , maxsize , tqdm , ** tqdm_kwargs ) self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img: np . ndarray , img_path: str ): \"\"\" Adds an image to the queue for saving. Args: img (np.ndarray): The image to save. img_name (str): The name (path) of the image file relative to the root path. \"\"\" super (). schedule_save ( img , img_path ) def _save_image ( self , params: tuple [ np . ndarray , str ]): img , img_name = params save_path = join_paths ( self . _root_path , img_name ) success = cv . imwrite ( save_path , img ) if not success: create_parent_directory ( save_path ) if not cv . 
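Example (an illustrative sketch; the FrameReader construction and crop values are hypothetical): FrameSaver inherits its worker lifecycle (start, schedule_save, close) from TaskScheduler.

from wtracker.utils.frame_reader import FrameReader
from wtracker.utils.io_utils import FrameSaver

reader = FrameReader("raw_frames")  # hypothetical construction; see the FrameReader reference
saver = FrameSaver(reader, root_path="crops")
saver.start()  # launch the background worker thread

# Queue crops of the first 100 frames; (x, y, w, h) are pixel bounds,
# applied as img[y:y+h, x:x+w] before the crop is written to disk.
for i in range(100):
    saver.schedule_save(img_index=i, crop_dims=(0, 0, 256, 256), img_name=f"crop_{i:04d}.png")

saver.close()  # wait for the queue to drain, then stop the worker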
imwrite ( save_path , img ): raise ValueError ( f \"Failed to save image {save_path}\" ) Ancestors (in MRO) wtracker.utils.threading_utils.TaskScheduler Methods close def close ( self ) Waits for the queue to empty and then closes the worker thread. View Source def close(self): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self._queue.join() self._queue.put(None) self._worker_thread.join() schedule_save def schedule_save ( self , img : numpy . ndarray , img_path : str ) Adds an image to the queue for saving. Parameters: Name Type Description Default img np.ndarray The image to save. None img_path str The name (path) of the image file relative to the root path. None View Source def schedule_save(self, img: np.ndarray, img_path: str): \"\"\" Adds an image to the queue for saving. Args: img (np.ndarray): The image to save. img_path (str): The name (path) of the image file relative to the root path. \"\"\" super().schedule_save(img, img_path) start def start ( self ) Starts the worker thread. View Source def start(self): \"\"\" Starts the worker thread. \"\"\" self._worker_thread.start()","title":"start"},{"location":"reference/wtracker/utils/io_utils/#module-wtrackerutilsio_utils","text":"View Source import cv2 as cv import numpy as np import pickle import math from wtracker.utils.path_utils import join_paths , create_directory , create_parent_directory from wtracker.utils.frame_reader import FrameReader from wtracker.utils.threading_utils import TaskScheduler class FrameSaver ( TaskScheduler ): \"\"\" A class for saving images from a frame reader to a specified folder. This class utilizes a queue to save images in a separate thread, which allows for non-blocking image saving. Args: frame_reader (FrameReader): The frame reader object from which images will be saved. root_path (str): The root folder path, relative to which all other paths are. maxsize (int, optional): The maximum size of the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments for tqdm. \"\"\" def __init__ ( self , frame_reader : FrameReader , root_path : str = \"\" , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs , ): super () . __init__ ( self . _save_frame , maxsize , tqdm , ** tqdm_kwargs ) self . _frame_reader = frame_reader self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img_index : int , crop_dims : tuple [ float , float , float , float ], img_name : str ): \"\"\" Adds an image to the queue for saving. Args: img_index (int): The index of the image in the frame reader. crop_dims (tuple[float, float, float, float]): The crop dimensions (x, y, w, h) for the image. img_name (str): The name (path) of the image file relative to the root path. \"\"\" super () . schedule_save ( img_index , crop_dims , img_name ) def _save_frame ( self , params : tuple [ int , tuple [ float , float , float , float ], str ]): img_index , crop_dims , img_name = params save_path = join_paths ( self . _root_path , img_name ) img = self . _frame_reader [ img_index ] x , y , w , h = crop_dims img = img [ y : y + h , x : x + w ] success = cv . imwrite ( save_path , img ) if not success : create_parent_directory ( save_path ) if not cv . imwrite ( save_path , img ): raise ValueError ( f \"Failed to save image { save_path } \" ) class ImageSaver ( TaskScheduler ): \"\"\" A class for saving images asynchronously using a task scheduler.
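Example (an illustrative sketch; the array here is a dummy placeholder): ImageSaver follows the same start/schedule_save/close lifecycle but queues in-memory arrays directly.

import numpy as np
from wtracker.utils.io_utils import ImageSaver

saver = ImageSaver(root_path="output")
saver.start()

img = np.zeros((64, 64), dtype=np.uint8)  # dummy grayscale image
saver.schedule_save(img, "debug/blank.png")  # missing parent folders are created on a failed first write

saver.close()  # flush pending writes and stop the worker thread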
Args: root_path (str): The root folder path, relative to which all other paths are. maxsize (int, optional): The maximum size of the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments for tqdm. \"\"\" def __init__ ( self , root_path : str = \"\" , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs , ): super () . __init__ ( self . _save_image , maxsize , tqdm , ** tqdm_kwargs ) self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img : np . ndarray , img_path : str ): \"\"\" Adds an image to the queue for saving. Args: img (np.ndarray): The image to save. img_name (str): The name (path) of the image file relative to the root path. \"\"\" super () . schedule_save ( img , img_path ) def _save_image ( self , params : tuple [ np . ndarray , str ]): img , img_name = params save_path = join_paths ( self . _root_path , img_name ) success = cv . imwrite ( save_path , img ) if not success : create_parent_directory ( save_path ) if not cv . imwrite ( save_path , img ): raise ValueError ( f \"Failed to save image { save_path } \" ) def pickle_load_object ( file_path : str ): \"\"\" Load an object from a pickle file. Args: file_path (str): The path to the pickle file. Returns: The loaded object. Raises: FileNotFoundError: If the file does not exist. ValueError: If there is an error loading the object from the pickle file. \"\"\" try : with open ( file_path , \"rb\" ) as f : return pickle . load ( f ) except FileNotFoundError : raise FileNotFoundError ( f \"file does not exist: { file_path } \" ) except Exception as e : raise ValueError ( f \"error loading object from pickle file: { e } \" ) def pickle_save_object ( obj , file_path : str ): \"\"\" Save an object to a pickle file. Args: obj: The object to be saved. file_path (str): The path to the pickle file. Raises: ValueError: If there is an error saving the object to the pickle file. \"\"\" try : create_parent_directory ( file_path ) with open ( file_path , \"wb\" ) as f : pickle . dump ( obj , f , protocol = pickle . HIGHEST_PROTOCOL ) except Exception as e : raise ValueError ( f \"error saving object to pickle file: { e } \" )","title":"Module wtracker.utils.io_utils"},{"location":"reference/wtracker/utils/io_utils/#functions","text":"","title":"Functions"},{"location":"reference/wtracker/utils/io_utils/#pickle_load_object","text":"def pickle_load_object ( file_path : str ) Load an object from a pickle file. Parameters: Name Type Description Default file_path str The path to the pickle file. None Returns: Type Description None The loaded object. Raises: Type Description FileNotFoundError If the file does not exist. ValueError If there is an error loading the object from the pickle file. View Source def pickle_load_object ( file_path : str ): \"\"\" Load an object from a pickle file. Args: file_path (str): The path to the pickle file. Returns: The loaded object. Raises: FileNotFoundError: If the file does not exist. ValueError: If there is an error loading the object from the pickle file. \"\"\" try : with open ( file_path , \"rb\" ) as f : return pickle . load ( f ) except FileNotFoundError : raise FileNotFoundError ( f \"file does not exist: {file_path}\" ) except Exception as e : raise ValueError ( f \"error loading object from pickle file: {e}\" )","title":"pickle_load_object"},{"location":"reference/wtracker/utils/io_utils/#pickle_save_object","text":"def pickle_save_object ( obj , file_path : str ) Save an object to a pickle file. 
Parameters: Name Type Description Default obj None The object to be saved. None file_path str The path to the pickle file. None Raises: Type Description ValueError If there is an error saving the object to the pickle file. View Source def pickle_save_object(obj, file_path: str): \"\"\" Save an object to a pickle file. Args: obj: The object to be saved. file_path (str): The path to the pickle file. Raises: ValueError: If there is an error saving the object to the pickle file. \"\"\" try: create_parent_directory(file_path) with open(file_path, \"wb\") as f: pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) except Exception as e: raise ValueError(f\"error saving object to pickle file: {e}\")","title":"pickle_save_object"},{"location":"reference/wtracker/utils/io_utils/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/io_utils/#framesaver","text":"class FrameSaver ( frame_reader : wtracker . utils . frame_reader . FrameReader , root_path : str = '' , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs ) A class for saving images from a frame reader to a specified folder. This class utilizes a queue to save images in a separate thread, which allows for non-blocking image saving.","title":"FrameSaver"},{"location":"reference/wtracker/utils/io_utils/#attributes","text":"Name Type Description Default frame_reader FrameReader The frame reader object from which images will be saved. None root_path str The root folder path, relative to which all other paths are. None maxsize int The maximum size of the queue. None tqdm bool Whether to use tqdm for progress tracking. None **tqdm_kwargs None Additional keyword arguments for tqdm. None View Source class FrameSaver ( TaskScheduler ) : \"\"\" A class for saving images from a frame reader to a specified folder . This class utilizes a queue to save images in a separate thread , which allows for non - blocking image saving . Args : frame_reader ( FrameReader ) : The frame reader object from which images will be saved . root_path ( str ) : The root folder path , relative to which all other paths are . maxsize ( int , optional ) : The maximum size of the queue . tqdm ( bool , optional ) : Whether to use tqdm for progress tracking . ** tqdm_kwargs : Additional keyword arguments for tqdm . \"\"\" def __init__ ( self , frame_reader : FrameReader , root_path : str = \"\" , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs , ) : super (). __init__ ( self . _save_frame , maxsize , tqdm , ** tqdm_kwargs ) self . _frame_reader = frame_reader self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img_index : int , crop_dims : tuple [ float , float , float , float ], img_name : str ) : \"\"\" Adds an image to the queue for saving . Args : img_index ( int ) : The index of the image in the frame reader . crop_dims ( tuple [ float , float , float , float ]) : The crop dimensions ( x , y , w , h ) for the image . img_name ( str ) : The name ( path ) of the image file relative to the root path . \"\"\" super (). schedule_save ( img_index , crop_dims , img_name ) def _save_frame ( self , params : tuple [ int , tuple [ float , float , float , float ], str ]) : img_index , crop_dims , img_name = params save_path = join_paths ( self . _root_path , img_name ) img = self . _frame_reader [ img_index ] x , y , w , h = crop_dims img = img [ y : y + h , x : x + w ] success = cv . imwrite ( save_path , img ) if not success : create_parent_directory ( save_path ) if not cv . 
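Example (an illustrative round trip; the path and object are hypothetical): pickle_save_object creates missing parent directories before writing, and pickle_load_object maps failures to the documented exceptions.

from wtracker.utils.io_utils import pickle_load_object, pickle_save_object

config = {"fps": 60, "roi": (0, 0, 512, 512)}  # any picklable object
pickle_save_object(config, "cache/config.pkl")

restored = pickle_load_object("cache/config.pkl")  # FileNotFoundError if the file is missing
assert restored == config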
imwrite ( save_path , img ) : raise ValueError ( f \"Failed to save image {save_path}\" )","title":"Attributes"},{"location":"reference/wtracker/utils/io_utils/#ancestors-in-mro","text":"wtracker.utils.threading_utils.TaskScheduler","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/utils/io_utils/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/utils/io_utils/#close","text":"def close ( self ) Waits for the queue to empty and then closes the worker thread. View Source def close(self): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self._queue.join() self._queue.put(None) self._worker_thread.join()","title":"close"},{"location":"reference/wtracker/utils/io_utils/#schedule_save","text":"def schedule_save ( self , img_index : int , crop_dims : tuple [ float , float , float , float ], img_name : str ) Adds an image to the queue for saving. Parameters: Name Type Description Default img_index int The index of the image in the frame reader. None crop_dims tuple[float, float, float, float] The crop dimensions (x, y, w, h) for the image. None img_name str The name (path) of the image file relative to the root path. None View Source def schedule_save(self, img_index: int, crop_dims: tuple[float, float, float, float], img_name: str): \"\"\" Adds an image to the queue for saving. Args: img_index (int): The index of the image in the frame reader. crop_dims (tuple[float, float, float, float]): The crop dimensions (x, y, w, h) for the image. img_name (str): The name (path) of the image file relative to the root path. \"\"\" super().schedule_save(img_index, crop_dims, img_name)","title":"schedule_save"},{"location":"reference/wtracker/utils/io_utils/#start","text":"def start ( self ) Starts the worker thread. View Source def start(self): \"\"\" Starts the worker thread. \"\"\" self._worker_thread.start()","title":"start"},{"location":"reference/wtracker/utils/io_utils/#imagesaver","text":"class ImageSaver ( root_path : str = '' , maxsize : int = 100 , tqdm : bool = True , ** tqdm_kwargs ) A class for saving images asynchronously using a task scheduler.","title":"ImageSaver"},{"location":"reference/wtracker/utils/io_utils/#attributes_1","text":"Name Type Description Default root_path str The root folder path, relative to which all other paths are. None maxsize int The maximum size of the queue. None tqdm bool Whether to use tqdm for progress tracking. None **tqdm_kwargs None Additional keyword arguments for tqdm. None View Source class ImageSaver ( TaskScheduler ): \"\"\" A class for saving images asynchronously using a task scheduler. Args: root_path (str): The root folder path, relative to which all other paths are. maxsize (int, optional): The maximum size of the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments for tqdm. \"\"\" def __init__ ( self , root_path: str = \"\" , maxsize: int = 100 , tqdm: bool = True , ** tqdm_kwargs , ): super (). __init__ ( self . _save_image , maxsize , tqdm , ** tqdm_kwargs ) self . _root_path = root_path create_directory ( root_path ) def schedule_save ( self , img: np . ndarray , img_path: str ): \"\"\" Adds an image to the queue for saving. Args: img (np.ndarray): The image to save. img_name (str): The name (path) of the image file relative to the root path. \"\"\" super (). schedule_save ( img , img_path ) def _save_image ( self , params: tuple [ np . ndarray , str ]): img , img_name = params save_path = join_paths ( self . 
_root_path , img_name ) success = cv . imwrite ( save_path , img ) if not success: create_parent_directory ( save_path ) if not cv . imwrite ( save_path , img ): raise ValueError ( f \"Failed to save image {save_path}\" )","title":"Attributes"},{"location":"reference/wtracker/utils/io_utils/#ancestors-in-mro_1","text":"wtracker.utils.threading_utils.TaskScheduler","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/utils/io_utils/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/utils/io_utils/#close_1","text":"def close ( self ) Waits for the queue to empty and then closes the worker thread. View Source def close(self): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self._queue.join() self._queue.put(None) self._worker_thread.join()","title":"close"},{"location":"reference/wtracker/utils/io_utils/#schedule_save_1","text":"def schedule_save ( self , img : numpy . ndarray , img_path : str ) Adds an image to the queue for saving. Parameters: Name Type Description Default img np.ndarray The image to save. None img_path str The name (path) of the image file relative to the root path. None View Source def schedule_save(self, img: np.ndarray, img_path: str): \"\"\" Adds an image to the queue for saving. Args: img (np.ndarray): The image to save. img_path (str): The name (path) of the image file relative to the root path. \"\"\" super().schedule_save(img, img_path)","title":"schedule_save"},{"location":"reference/wtracker/utils/io_utils/#start_1","text":"def start ( self ) Starts the worker thread. View Source def start(self): \"\"\" Starts the worker thread. \"\"\" self._worker_thread.start()","title":"start"},{"location":"reference/wtracker/utils/log_utils/","text":"Module wtracker.utils.log_utils View Source import csv from typing import Iterable class CSVLogger : \"\"\" A class for logging data to a CSV file. Args: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. mode (str, optional): The file mode to open the CSV file in. Attributes: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. \"\"\" def __init__ ( self , path : str , col_names : list [ str ], mode : str = \"w+\" ): self . path = path self . col_names = col_names self . _file = open ( self . path , mode , newline = \"\" ) self . _writer = csv . DictWriter ( self . _file , self . col_names , escapechar = \",\" ) self . _writer . writeheader () self . flush () def __enter__ ( self ): return self def __exit__ ( self , exc_type , exc_value , traceback ): self . close () def close ( self ): \"\"\" Closes the CSV file. \"\"\" if not self . _file . closed : self . _file . flush () self . _file . close () def _to_dict ( self , items : Iterable ) -> dict : \"\"\" Converts an iterable of items to a dictionary using the column names as keys. Args: items (Iterable): The items to convert to a dictionary. Returns: dict: The dictionary with column names as keys and items as values. \"\"\" return { k : v for k , v in zip ( self . col_names , items )} def write ( self , row : dict | Iterable ): \"\"\" Writes a single row of data to the CSV file. Args: row (dict | Iterable): The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () if not isinstance ( row , dict ): row = self . _to_dict ( row ) self . _writer .
writerow ( row ) def writerows ( self , rows : list [ dict ] | list [ Iterable ]): \"\"\" Writes multiple rows of data to the CSV file. Args: rows (list[dict] | list[Iterable]): The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () assert len ( rows ) > 0 if not isinstance ( rows [ 0 ], dict ): rows = [ self . _to_dict ( row ) for row in rows ] self . _writer . writerows ( rows ) def flush ( self ): \"\"\" Flushes any buffered data to the CSV file. \"\"\" self . _file . flush () Classes CSVLogger class CSVLogger ( path : str , col_names : list [ str ], mode : str = 'w+' ) A class for logging data to a CSV file. Attributes Name Type Description Default path str The path to the CSV file. None col_names list[str] The column names for the CSV file. None mode str The file mode to open the CSV file in. None path str The path to the CSV file. None col_names list[str] The column names for the CSV file. None View Source class CSVLogger : \"\"\" A class for logging data to a CSV file. Args: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. mode (str, optional): The file mode to open the CSV file in. Attributes: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. \"\"\" def __init__ ( self , path : str , col_names : list [ str ] , mode : str = \"w+\" ) : self . path = path self . col_names = col_names self . _file = open ( self . path , mode , newline = \"\" ) self . _writer = csv . DictWriter ( self . _file , self . col_names , escapechar = \",\" ) self . _writer . writeheader () self . flush () def __enter__ ( self ) : return self def __exit__ ( self , exc_type , exc_value , traceback ) : self . close () def close ( self ) : \"\"\" Closes the CSV file. \"\"\" if not self . _file . closed : self . _file . flush () self . _file . close () def _to_dict ( self , items : Iterable ) -> dict : \"\"\" Converts an iterable of items to a dictionary using the column names as keys. Args: items (Iterable): The items to convert to a dictionary. Returns: dict: The dictionary with column names as keys and items as values. \"\"\" return { k : v for k , v in zip ( self . col_names , items ) } def write ( self , row : dict | Iterable ) : \"\"\" Writes a single row of data to the CSV file. Args: row (dict | Iterable): The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () if not isinstance ( row , dict ) : row = self . _to_dict ( row ) self . _writer . writerow ( row ) def writerows ( self , rows : list [ dict ] | list [ Iterable ] ) : \"\"\" Writes multiple rows of data to the CSV file. Args: rows (list[dict] | list[Iterable]): The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () assert len ( rows ) > 0 if not isinstance ( rows [ 0 ] , dict ) : rows = [ self._to_dict(row) for row in rows ] self . _writer . writerows ( rows ) def flush ( self ) : \"\"\" Flushes any buffered data to the CSV file. \"\"\" self . _file . 
flush () Methods close def close ( self ) Closes the CSV file. View Source def close(self): \"\"\" Closes the CSV file. \"\"\" if not self._file.closed: self._file.flush() self._file.close() flush def flush ( self ) Flushes any buffered data to the CSV file. View Source def flush(self): \"\"\" Flushes any buffered data to the CSV file. \"\"\" self._file.flush() write def write ( self , row : Union [ dict , Iterable ] ) Writes a single row of data to the CSV file. Parameters: Name Type Description Default row dict Iterable The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. View Source def write(self, row: dict | Iterable): \"\"\" Writes a single row of data to the CSV file. Args: row (dict | Iterable): The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. \"\"\" assert self._file.writable() if not isinstance(row, dict): row = self._to_dict(row) self._writer.writerow(row) writerows def writerows ( self , rows : list [ dict ] | list [ typing . Iterable ] ) Writes multiple rows of data to the CSV file. Parameters: Name Type Description Default rows list[dict] list[Iterable] The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. View Source def writerows ( self , rows : list [ dict ] | list [ Iterable ] ) : \"\"\" Writes multiple rows of data to the CSV file. Args: rows (list[dict] | list[Iterable]): The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () assert len ( rows ) > 0 if not isinstance ( rows [ 0 ] , dict ) : rows = [ self._to_dict(row) for row in rows ] self . _writer . writerows ( rows )","title":"Log Utils"},{"location":"reference/wtracker/utils/log_utils/#module-wtrackerutilslog_utils","text":"View Source import csv from typing import Iterable class CSVLogger : \"\"\" A class for logging data to a CSV file. Args: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. mode (str, optional): The file mode to open the CSV file in. Attributes: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. \"\"\" def __init__ ( self , path : str , col_names : list [ str ], mode : str = \"w+\" ): self . path = path self . col_names = col_names self . _file = open ( self . path , mode , newline = \"\" ) self . _writer = csv . DictWriter ( self . _file , self . col_names , escapechar = \",\" ) self . _writer . writeheader () self . flush () def __enter__ ( self ): return self def __exit__ ( self , exc_type , exc_value , traceback ): self . close () def close ( self ): \"\"\" Closes the CSV file. \"\"\" if not self . _file . closed : self . _file . flush () self . _file . close () def _to_dict ( self , items : Iterable ) -> dict : \"\"\" Converts an iterable of items to a dictionary using the column names as keys. Args: items (Iterable): The items to convert to a dictionary. Returns: dict: The dictionary with column names as keys and items as values. 
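Example (an illustrative sketch; file and column names are hypothetical): write accepts either a dict keyed by column name or a plain iterable matched to the columns by position, and writerows is the bulk variant.

from wtracker.utils.log_utils import CSVLogger

# Used as a context manager, the logger flushes and closes the file on exit.
with CSVLogger("positions.csv", col_names=["frame", "x", "y"]) as logger:
    logger.write({"frame": 0, "x": 12.5, "y": 7.25})    # dict keyed by column name
    logger.write([1, 13.0, 7.5])                        # iterable matched by position
    logger.writerows([[2, 13.4, 7.9], [3, 13.9, 8.1]])  # bulk write
    logger.flush()                                      # optional mid-run flush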
\"\"\" return { k : v for k , v in zip ( self . col_names , items )} def write ( self , row : dict | Iterable ): \"\"\" Writes a single row of data to the CSV file. Args: row (dict | Iterable): The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () if not isinstance ( row , dict ): row = self . _to_dict ( row ) self . _writer . writerow ( row ) def writerows ( self , rows : list [ dict ] | list [ Iterable ]): \"\"\" Writes multiple rows of data to the CSV file. Args: rows (list[dict] | list[Iterable]): The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () assert len ( rows ) > 0 if not isinstance ( rows [ 0 ], dict ): rows = [ self . _to_dict ( row ) for row in rows ] self . _writer . writerows ( rows ) def flush ( self ): \"\"\" Flushes any buffered data to the CSV file. \"\"\" self . _file . flush ()","title":"Module wtracker.utils.log_utils"},{"location":"reference/wtracker/utils/log_utils/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/log_utils/#csvlogger","text":"class CSVLogger ( path : str , col_names : list [ str ], mode : str = 'w+' ) A class for logging data to a CSV file.","title":"CSVLogger"},{"location":"reference/wtracker/utils/log_utils/#attributes","text":"Name Type Description Default path str The path to the CSV file. None col_names list[str] The column names for the CSV file. None mode str The file mode to open the CSV file in. None path str The path to the CSV file. None col_names list[str] The column names for the CSV file. None View Source class CSVLogger : \"\"\" A class for logging data to a CSV file. Args: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. mode (str, optional): The file mode to open the CSV file in. Attributes: path (str): The path to the CSV file. col_names (list[str]): The column names for the CSV file. \"\"\" def __init__ ( self , path : str , col_names : list [ str ] , mode : str = \"w+\" ) : self . path = path self . col_names = col_names self . _file = open ( self . path , mode , newline = \"\" ) self . _writer = csv . DictWriter ( self . _file , self . col_names , escapechar = \",\" ) self . _writer . writeheader () self . flush () def __enter__ ( self ) : return self def __exit__ ( self , exc_type , exc_value , traceback ) : self . close () def close ( self ) : \"\"\" Closes the CSV file. \"\"\" if not self . _file . closed : self . _file . flush () self . _file . close () def _to_dict ( self , items : Iterable ) -> dict : \"\"\" Converts an iterable of items to a dictionary using the column names as keys. Args: items (Iterable): The items to convert to a dictionary. Returns: dict: The dictionary with column names as keys and items as values. \"\"\" return { k : v for k , v in zip ( self . col_names , items ) } def write ( self , row : dict | Iterable ) : \"\"\" Writes a single row of data to the CSV file. Args: row (dict | Iterable): The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . 
writable () if not isinstance ( row , dict ) : row = self . _to_dict ( row ) self . _writer . writerow ( row ) def writerows ( self , rows : list [ dict ] | list [ Iterable ] ) : \"\"\" Writes multiple rows of data to the CSV file. Args: rows (list[dict] | list[Iterable]): The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () assert len ( rows ) > 0 if not isinstance ( rows [ 0 ] , dict ) : rows = [ self._to_dict(row) for row in rows ] self . _writer . writerows ( rows ) def flush ( self ) : \"\"\" Flushes any buffered data to the CSV file. \"\"\" self . _file . flush ()","title":"Attributes"},{"location":"reference/wtracker/utils/log_utils/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/utils/log_utils/#close","text":"def close ( self ) Closes the CSV file. View Source def close(self): \"\"\" Closes the CSV file. \"\"\" if not self._file.closed: self._file.flush() self._file.close()","title":"close"},{"location":"reference/wtracker/utils/log_utils/#flush","text":"def flush ( self ) Flushes any buffered data to the CSV file. View Source def flush(self): \"\"\" Flushes any buffered data to the CSV file. \"\"\" self._file.flush()","title":"flush"},{"location":"reference/wtracker/utils/log_utils/#write","text":"def write ( self , row : Union [ dict , Iterable ] ) Writes a single row of data to the CSV file. Parameters: Name Type Description Default row dict Iterable The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. View Source def write(self, row: dict | Iterable): \"\"\" Writes a single row of data to the CSV file. Args: row (dict | Iterable): The row of data to write to the CSV file. If a dictionary is provided, the keys should match the column names. If an iterable is provided, the items will be matched with the column names in order. \"\"\" assert self._file.writable() if not isinstance(row, dict): row = self._to_dict(row) self._writer.writerow(row)","title":"write"},{"location":"reference/wtracker/utils/log_utils/#writerows","text":"def writerows ( self , rows : list [ dict ] | list [ typing . Iterable ] ) Writes multiple rows of data to the CSV file. Parameters: Name Type Description Default rows list[dict] list[Iterable] The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. View Source def writerows ( self , rows : list [ dict ] | list [ Iterable ] ) : \"\"\" Writes multiple rows of data to the CSV file. Args: rows (list[dict] | list[Iterable]): The rows of data to write to the CSV file. If a list of dictionaries is provided, the keys should match the column names. If a list of iterables is provided, the items will be matched with the column names in order. \"\"\" assert self . _file . writable () assert len ( rows ) > 0 if not isinstance ( rows [ 0 ] , dict ) : rows = [ self._to_dict(row) for row in rows ] self . _writer . 
writerows ( rows )","title":"writerows"},{"location":"reference/wtracker/utils/path_utils/","text":"Module wtracker.utils.path_utils View Source from __future__ import annotations import os from pathlib import Path , PurePath from typing import Callable , Union import shutil def absolute_path ( file_path : str ) -> str : \"\"\" Get the absolute path of a file. Args: file_path (str): The path of the file. Returns: str: The absolute path of the file. \"\"\" return Path ( file_path ) . resolve () . as_posix () def join_paths ( * path_segments : str ): \"\"\" Join multiple path segments into a single path. Args: *path_segments: Variable number of path segments to be joined. Returns: str: The joined path as a string. Example: >>> join_paths('home', 'yashlat', 'source', 'Bio-Proj', 'data') 'home/yashlat/source/Bio-Proj/data' \"\"\" return PurePath ( * path_segments ) . as_posix () def create_parent_directory ( file_path : str ): \"\"\" Create the parent directory for the given file path if it doesn't exist. Args: file_path (str): The path of the file. Returns: None \"\"\" save_folder = Path ( file_path ) . parent save_folder . mkdir ( parents = True , exist_ok = True ) def create_directory ( dir_path : str ): \"\"\" Create a directory at the specified path if it doesn't already exist. Args: dir_path (str): The path of the directory to be created. Returns: None \"\"\" Path ( dir_path ) . mkdir ( parents = True , exist_ok = True ) def bulk_rename ( dir_path : str , rename_fn : Callable [[ str ], str ]): \"\"\" Rename all files in a directory using the provided renaming function. Args: dir_path (str): The path of the directory containing the files to be renamed. rename_fn (Callable[[str], str]): The function to be used for renaming the files. Returns: None \"\"\" path : Path = Path ( dir_path ) for file_name in path . iterdir (): if file_name . is_dir (): continue new_name = path / rename_fn ( file_name . name ) file_name . rename ( new_name ) class Files : \"\"\" A utility class for working with files in a directory. Args: directory (str): The directory path to scan for files. extension (str, optional): The file extension to filter the files. scan_dirs (bool, optional): Whether to include directories in the results. return_full_path (bool, optional): Whether to return the full path of the files. sorting_key (Callable[[str], Union[int, str]], optional): A function to determine the sorting order of the files. \"\"\" def __init__ ( self , directory : str , extension : str = \"\" , scan_dirs : bool = False , return_full_path : bool = True , sorting_key : Callable [[ str ], Union [ int , str ]] = lambda name : name , ) -> None : self . root = directory self . extension = extension . lower () self . scan_dirs : bool = scan_dirs self . return_full_path = return_full_path self . results : list [ os . DirEntry ] = [] self . sorting_func = sorting_key self . _pos = - 1 self . _scan () def _scan ( self ): self . results = [] self . _pos = - 1 for result in os . scandir ( self . root ): if self . scan_dirs and result . is_dir (): self . results . append ( result ) else : if result . name . lower () . endswith ( self . extension ): self . results . append ( result ) self . results = sorted ( self . results , key = lambda f : self . sorting_func ( f . name )) def __getitem__ ( self , index : int ) -> os . DirEntry : \"\"\" Returns the file at the specified index. Args: index (int): The index of the file. Returns: os.DirEntry: The file at the specified index. \"\"\" return self . 
results [ index ] def __iter__ ( self ) -> Files : \"\"\" Returns an iterator object. Returns: Files: The iterator object. \"\"\" self . _pos = - 1 return self def __next__ ( self ) -> str : \"\"\" Returns the next file name or path in the iteration. Returns: str: The next file name or path. Raises: StopIteration: If there are no more files in the iteration. \"\"\" self . _pos += 1 if self . _pos >= self . __len__ (): raise StopIteration result = self . results [ self . _pos ] if self . return_full_path : return result . path return result . name def __len__ ( self ) -> int : \"\"\" Returns the number of files in the results list. Returns: int: The number of files. \"\"\" return len ( self . results ) def __contains__ ( self , key : str ) -> bool : \"\"\" Checks if a file with the specified name exists in the results list. Args: key (str): The file name to check. Returns: bool: True if the file exists, False otherwise. \"\"\" for res in self . results : if key == res . name : return True return False def get_filename ( self ) -> str : \"\"\" Returns the name of the current file. Returns: str: The name of the current file. \"\"\" return self . results [ self . _pos ] . name def get_path ( self ) -> str : \"\"\" Returns the path of the current file. Returns: str: The path of the current file. \"\"\" return self . results [ self . _pos ] . path def seek ( self , pos : int ) -> str : \"\"\" Moves the iterator to the specified position and returns the file name or path. Args: pos (int): The position to seek to. Returns: str: The file name or path at the specified position. Raises: AssertionError: If the specified position is invalid. \"\"\" assert 0 <= pos < self . __len__ (), \"Invalid position\" self . _pos = pos - 1 return self . __next__ () def copy ( self , dst_root : str ) -> None : \"\"\" Copies the current file to the specified destination directory. Args: dst_root (str): The destination directory path. \"\"\" shutil . copy2 ( self . get_path (), dst = dst_root ) Functions absolute_path def absolute_path ( file_path : 'str' ) -> 'str' Get the absolute path of a file. Parameters: Name Type Description Default file_path str The path of the file. None Returns: Type Description str The absolute path of the file. View Source def absolute_path ( file_path : str ) -> str : \"\"\" Get the absolute path of a file. Args: file_path (str): The path of the file. Returns: str: The absolute path of the file. \"\"\" return Path ( file_path ). resolve (). as_posix () bulk_rename def bulk_rename ( dir_path : 'str' , rename_fn : 'Callable[[str], str]' ) Rename all files in a directory using the provided renaming function. Parameters: Name Type Description Default dir_path str The path of the directory containing the files to be renamed. None rename_fn Callable[[str], str] The function to be used for renaming the files. None Returns: Type Description None None View Source def bulk_rename ( dir_path : str , rename_fn : Callable [ [str ] , str ] ) : \"\"\" Rename all files in a directory using the provided renaming function. Args: dir_path (str): The path of the directory containing the files to be renamed. rename_fn (Callable[[str], str]): The function to be used for renaming the files. Returns: None \"\"\" path : Path = Path ( dir_path ) for file_name in path . iterdir () : if file_name . is_dir () : continue new_name = path / rename_fn ( file_name . name ) file_name . 
rename ( new_name ) create_directory def create_directory ( dir_path : 'str' ) Create a directory at the specified path if it doesn't already exist. Parameters: Name Type Description Default dir_path str The path of the directory to be created. None Returns: Type Description None None View Source def create_directory(dir_path: str): \"\"\" Create a directory at the specified path if it doesn't already exist. Args: dir_path (str): The path of the directory to be created. Returns: None \"\"\" Path(dir_path).mkdir(parents=True, exist_ok=True) create_parent_directory def create_parent_directory ( file_path : 'str' ) Create the parent directory for the given file path if it doesn't exist. Parameters: Name Type Description Default file_path str The path of the file. None Returns: Type Description None None View Source def create_parent_directory(file_path: str): \"\"\" Create the parent directory for the given file path if it doesn't exist. Args: file_path (str): The path of the file. Returns: None \"\"\" save_folder = Path(file_path).parent save_folder.mkdir(parents=True, exist_ok=True) join_paths def join_paths ( * path_segments : 'str' ) Join multiple path segments into a single path. Parameters: Name Type Description Default *path_segments None Variable number of path segments to be joined. None Returns: Type Description str The joined path as a string. View Source def join_paths(*path_segments: str): \"\"\" Join multiple path segments into a single path. Args: *path_segments: Variable number of path segments to be joined. Returns: str: The joined path as a string. Example: >>> join_paths('home', 'yashlat', 'source', 'Bio-Proj', 'data') 'home/yashlat/source/Bio-Proj/data' \"\"\" return PurePath(*path_segments).as_posix() Classes Files class Files ( directory : 'str' , extension : 'str' = '' , scan_dirs : 'bool' = False , return_full_path : 'bool' = True , sorting_key : 'Callable[[str], Union[int, str]]' = < function Files .< lambda > at 0x7f93ee7f2290 > ) A utility class for working with files in a directory. Attributes Name Type Description Default directory str The directory path to scan for files. None extension str The file extension to filter the files. None scan_dirs bool Whether to include directories in the results. None return_full_path bool Whether to return the full path of the files. None sorting_key Callable[[str], Union[int, str]] A function to determine the sorting order of the files. None View Source class Files : \"\"\" A utility class for working with files in a directory. Args: directory (str): The directory path to scan for files. extension (str, optional): The file extension to filter the files. scan_dirs (bool, optional): Whether to include directories in the results. return_full_path (bool, optional): Whether to return the full path of the files. sorting_key (Callable[[str], Union[int, str]], optional): A function to determine the sorting order of the files. \"\"\" def __init__ ( self , directory : str , extension : str = \"\" , scan_dirs : bool = False , return_full_path : bool = True , sorting_key : Callable [ [str ] , Union [ int, str ] ] = lambda name : name , ) -> None : self . root = directory self . extension = extension . lower () self . scan_dirs : bool = scan_dirs self . return_full_path = return_full_path self . results : list [ os.DirEntry ] = [] self . sorting_func = sorting_key self . _pos = - 1 self . _scan () def _scan ( self ) : self . results = [] self . _pos = - 1 for result in os . scandir ( self . root ) : if self . scan_dirs and result . 
is_dir () : self . results . append ( result ) else : if result . name . lower (). endswith ( self . extension ) : self . results . append ( result ) self . results = sorted ( self . results , key = lambda f : self . sorting_func ( f . name )) def __getitem__ ( self , index : int ) -> os . DirEntry : \"\"\" Returns the file at the specified index. Args: index (int): The index of the file. Returns: os.DirEntry: The file at the specified index. \"\"\" return self . results [ index ] def __iter__ ( self ) -> Files : \"\"\" Returns an iterator object. Returns: Files: The iterator object. \"\"\" self . _pos = - 1 return self def __next__ ( self ) -> str : \"\"\" Returns the next file name or path in the iteration. Returns: str: The next file name or path. Raises: StopIteration: If there are no more files in the iteration. \"\"\" self . _pos += 1 if self . _pos >= self . __len__ () : raise StopIteration result = self . results [ self._pos ] if self . return_full_path : return result . path return result . name def __len__ ( self ) -> int : \"\"\" Returns the number of files in the results list. Returns: int: The number of files. \"\"\" return len ( self . results ) def __contains__ ( self , key : str ) -> bool : \"\"\" Checks if a file with the specified name exists in the results list. Args: key (str): The file name to check. Returns: bool: True if the file exists, False otherwise. \"\"\" for res in self . results : if key == res . name : return True return False def get_filename ( self ) -> str : \"\"\" Returns the name of the current file. Returns: str: The name of the current file. \"\"\" return self . results [ self._pos ] . name def get_path ( self ) -> str : \"\"\" Returns the path of the current file. Returns: str: The path of the current file. \"\"\" return self . results [ self._pos ] . path def seek ( self , pos : int ) -> str : \"\"\" Moves the iterator to the specified position and returns the file name or path. Args: pos (int): The position to seek to. Returns: str: The file name or path at the specified position. Raises: AssertionError: If the specified position is invalid. \"\"\" assert 0 <= pos < self . __len__ (), \"Invalid position\" self . _pos = pos - 1 return self . __next__ () def copy ( self , dst_root : str ) -> None : \"\"\" Copies the current file to the specified destination directory. Args: dst_root (str): The destination directory path. \"\"\" shutil . copy2 ( self . get_path (), dst = dst_root ) Methods copy def copy ( self , dst_root : 'str' ) -> 'None' Copies the current file to the specified destination directory. Parameters: Name Type Description Default dst_root str The destination directory path. None View Source def copy ( self , dst_root : str ) -> None : \"\"\" Copies the current file to the specified destination directory. Args: dst_root (str): The destination directory path. \"\"\" shutil . copy2 ( self . get_path (), dst = dst_root ) get_filename def get_filename ( self ) -> 'str' Returns the name of the current file. Returns: Type Description str The name of the current file. View Source def get_filename ( self ) -> str : \"\"\" Returns the name of the current file. Returns: str: The name of the current file. \"\"\" return self . results [ self . _pos ]. name get_path def get_path ( self ) -> 'str' Returns the path of the current file. Returns: Type Description str The path of the current file. View Source def get_path ( self ) -> str : \"\"\" Returns the path of the current file. Returns: str: The path of the current file. \"\"\" return self . 
results [ self . _pos ]. path seek def seek ( self , pos : 'int' ) -> 'str' Moves the iterator to the specified position and returns the file name or path. Parameters: Name Type Description Default pos int The position to seek to. None Returns: Type Description str The file name or path at the specified position. Raises: Type Description AssertionError If the specified position is invalid. View Source def seek ( self , pos : int ) -> str : \"\"\" Moves the iterator to the specified position and returns the file name or path. Args: pos (int): The position to seek to. Returns: str: The file name or path at the specified position. Raises: AssertionError: If the specified position is invalid. \"\"\" assert 0 <= pos < self . __len__ (), \"Invalid position\" self . _pos = pos - 1 return self . __next__ ()","title":"Path Utils"},{"location":"reference/wtracker/utils/path_utils/#module-wtrackerutilspath_utils","text":"View Source from __future__ import annotations import os from pathlib import Path , PurePath from typing import Callable , Union import shutil def absolute_path ( file_path : str ) -> str : \"\"\" Get the absolute path of a file. Args: file_path (str): The path of the file. Returns: str: The absolute path of the file. \"\"\" return Path ( file_path ) . resolve () . as_posix () def join_paths ( * path_segments : str ): \"\"\" Join multiple path segments into a single path. Args: *path_segments: Variable number of path segments to be joined. Returns: str: The joined path as a string. Example: >>> join_paths('home', 'yashlat', 'source', 'Bio-Proj', 'data') 'home/yashlat/source/Bio-Proj/data' \"\"\" return PurePath ( * path_segments ) . as_posix () def create_parent_directory ( file_path : str ): \"\"\" Create the parent directory for the given file path if it doesn't exist. Args: file_path (str): The path of the file. Returns: None \"\"\" save_folder = Path ( file_path ) . parent save_folder . mkdir ( parents = True , exist_ok = True ) def create_directory ( dir_path : str ): \"\"\" Create a directory at the specified path if it doesn't already exist. Args: dir_path (str): The path of the directory to be created. Returns: None \"\"\" Path ( dir_path ) . mkdir ( parents = True , exist_ok = True ) def bulk_rename ( dir_path : str , rename_fn : Callable [[ str ], str ]): \"\"\" Rename all files in a directory using the provided renaming function. Args: dir_path (str): The path of the directory containing the files to be renamed. rename_fn (Callable[[str], str]): The function to be used for renaming the files. Returns: None \"\"\" path : Path = Path ( dir_path ) for file_name in path . iterdir (): if file_name . is_dir (): continue new_name = path / rename_fn ( file_name . name ) file_name . rename ( new_name ) class Files : \"\"\" A utility class for working with files in a directory. Args: directory (str): The directory path to scan for files. extension (str, optional): The file extension to filter the files. scan_dirs (bool, optional): Whether to include directories in the results. return_full_path (bool, optional): Whether to return the full path of the files. sorting_key (Callable[[str], Union[int, str]], optional): A function to determine the sorting order of the files. \"\"\" def __init__ ( self , directory : str , extension : str = \"\" , scan_dirs : bool = False , return_full_path : bool = True , sorting_key : Callable [[ str ], Union [ int , str ]] = lambda name : name , ) -> None : self . root = directory self . extension = extension . lower () self . 
scan_dirs : bool = scan_dirs self . return_full_path = return_full_path self . results : list [ os . DirEntry ] = [] self . sorting_func = sorting_key self . _pos = - 1 self . _scan () def _scan ( self ): self . results = [] self . _pos = - 1 for result in os . scandir ( self . root ): if self . scan_dirs and result . is_dir (): self . results . append ( result ) else : if result . name . lower () . endswith ( self . extension ): self . results . append ( result ) self . results = sorted ( self . results , key = lambda f : self . sorting_func ( f . name )) def __getitem__ ( self , index : int ) -> os . DirEntry : \"\"\" Returns the file at the specified index. Args: index (int): The index of the file. Returns: os.DirEntry: The file at the specified index. \"\"\" return self . results [ index ] def __iter__ ( self ) -> Files : \"\"\" Returns an iterator object. Returns: Files: The iterator object. \"\"\" self . _pos = - 1 return self def __next__ ( self ) -> str : \"\"\" Returns the next file name or path in the iteration. Returns: str: The next file name or path. Raises: StopIteration: If there are no more files in the iteration. \"\"\" self . _pos += 1 if self . _pos >= self . __len__ (): raise StopIteration result = self . results [ self . _pos ] if self . return_full_path : return result . path return result . name def __len__ ( self ) -> int : \"\"\" Returns the number of files in the results list. Returns: int: The number of files. \"\"\" return len ( self . results ) def __contains__ ( self , key : str ) -> bool : \"\"\" Checks if a file with the specified name exists in the results list. Args: key (str): The file name to check. Returns: bool: True if the file exists, False otherwise. \"\"\" for res in self . results : if key == res . name : return True return False def get_filename ( self ) -> str : \"\"\" Returns the name of the current file. Returns: str: The name of the current file. \"\"\" return self . results [ self . _pos ] . name def get_path ( self ) -> str : \"\"\" Returns the path of the current file. Returns: str: The path of the current file. \"\"\" return self . results [ self . _pos ] . path def seek ( self , pos : int ) -> str : \"\"\" Moves the iterator to the specified position and returns the file name or path. Args: pos (int): The position to seek to. Returns: str: The file name or path at the specified position. Raises: AssertionError: If the specified position is invalid. \"\"\" assert 0 <= pos < self . __len__ (), \"Invalid position\" self . _pos = pos - 1 return self . __next__ () def copy ( self , dst_root : str ) -> None : \"\"\" Copies the current file to the specified destination directory. Args: dst_root (str): The destination directory path. \"\"\" shutil . copy2 ( self . get_path (), dst = dst_root )","title":"Module wtracker.utils.path_utils"},{"location":"reference/wtracker/utils/path_utils/#functions","text":"","title":"Functions"},{"location":"reference/wtracker/utils/path_utils/#absolute_path","text":"def absolute_path ( file_path : 'str' ) -> 'str' Get the absolute path of a file. Parameters: Name Type Description Default file_path str The path of the file. None Returns: Type Description str The absolute path of the file. View Source def absolute_path ( file_path : str ) -> str : \"\"\" Get the absolute path of a file. Args: file_path (str): The path of the file. Returns: str: The absolute path of the file. \"\"\" return Path ( file_path ). resolve (). 
as_posix ()","title":"absolute_path"},{"location":"reference/wtracker/utils/path_utils/#bulk_rename","text":"def bulk_rename ( dir_path : 'str' , rename_fn : 'Callable[[str], str]' ) Rename all files in a directory using the provided renaming function. Parameters: Name Type Description Default dir_path str The path of the directory containing the files to be renamed. None rename_fn Callable[[str], str] The function to be used for renaming the files. None Returns: Type Description None None View Source def bulk_rename ( dir_path : str , rename_fn : Callable [ [str ] , str ] ) : \"\"\" Rename all files in a directory using the provided renaming function. Args: dir_path (str): The path of the directory containing the files to be renamed. rename_fn (Callable[[str], str]): The function to be used for renaming the files. Returns: None \"\"\" path : Path = Path ( dir_path ) for file_name in path . iterdir () : if file_name . is_dir () : continue new_name = path / rename_fn ( file_name . name ) file_name . rename ( new_name )","title":"bulk_rename"},{"location":"reference/wtracker/utils/path_utils/#create_directory","text":"def create_directory ( dir_path : 'str' ) Create a directory at the specified path if it doesn't already exist. Parameters: Name Type Description Default dir_path str The path of the directory to be created. None Returns: Type Description None None View Source def create_directory(dir_path: str): \"\"\" Create a directory at the specified path if it doesn't already exist. Args: dir_path (str): The path of the directory to be created. Returns: None \"\"\" Path(dir_path).mkdir(parents=True, exist_ok=True)","title":"create_directory"},{"location":"reference/wtracker/utils/path_utils/#create_parent_directory","text":"def create_parent_directory ( file_path : 'str' ) Create the parent directory for the given file path if it doesn't exist. Parameters: Name Type Description Default file_path str The path of the file. None Returns: Type Description None None View Source def create_parent_directory(file_path: str): \"\"\" Create the parent directory for the given file path if it doesn't exist. Args: file_path (str): The path of the file. Returns: None \"\"\" save_folder = Path(file_path).parent save_folder.mkdir(parents=True, exist_ok=True)","title":"create_parent_directory"},{"location":"reference/wtracker/utils/path_utils/#join_paths","text":"def join_paths ( * path_segments : 'str' ) Join multiple path segments into a single path. Parameters: Name Type Description Default *path_segments None Variable number of path segments to be joined. None Returns: Type Description str The joined path as a string. View Source def join_paths(*path_segments: str): \"\"\" Join multiple path segments into a single path. Args: *path_segments: Variable number of path segments to be joined. Returns: str: The joined path as a string. 
Example: >>> join_paths('home', 'yashlat', 'source', 'Bio-Proj', 'data') 'home/yashlat/source/Bio-Proj/data' \"\"\" return PurePath(*path_segments).as_posix()","title":"join_paths"},{"location":"reference/wtracker/utils/path_utils/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/path_utils/#files","text":"class Files ( directory : 'str' , extension : 'str' = '' , scan_dirs : 'bool' = False , return_full_path : 'bool' = True , sorting_key : 'Callable[[str], Union[int, str]]' = < function Files .< lambda > at 0x7f93ee7f2290 > ) A utility class for working with files in a directory.","title":"Files"},{"location":"reference/wtracker/utils/path_utils/#attributes","text":"Name Type Description Default directory str The directory path to scan for files. None extension str The file extension to filter the files. None scan_dirs bool Whether to include directories in the results. None return_full_path bool Whether to return the full path of the files. None sorting_key Callable[[str], Union[int, str]] A function to determine the sorting order of the files. None View Source class Files : \"\"\" A utility class for working with files in a directory. Args: directory (str): The directory path to scan for files. extension (str, optional): The file extension to filter the files. scan_dirs (bool, optional): Whether to include directories in the results. return_full_path (bool, optional): Whether to return the full path of the files. sorting_key (Callable[[str], Union[int, str]], optional): A function to determine the sorting order of the files. \"\"\" def __init__ ( self , directory : str , extension : str = \"\" , scan_dirs : bool = False , return_full_path : bool = True , sorting_key : Callable [ [str ] , Union [ int, str ] ] = lambda name : name , ) -> None : self . root = directory self . extension = extension . lower () self . scan_dirs : bool = scan_dirs self . return_full_path = return_full_path self . results : list [ os.DirEntry ] = [] self . sorting_func = sorting_key self . _pos = - 1 self . _scan () def _scan ( self ) : self . results = [] self . _pos = - 1 for result in os . scandir ( self . root ) : if self . scan_dirs and result . is_dir () : self . results . append ( result ) else : if result . name . lower (). endswith ( self . extension ) : self . results . append ( result ) self . results = sorted ( self . results , key = lambda f : self . sorting_func ( f . name )) def __getitem__ ( self , index : int ) -> os . DirEntry : \"\"\" Returns the file at the specified index. Args: index (int): The index of the file. Returns: os.DirEntry: The file at the specified index. \"\"\" return self . results [ index ] def __iter__ ( self ) -> Files : \"\"\" Returns an iterator object. Returns: Files: The iterator object. \"\"\" self . _pos = - 1 return self def __next__ ( self ) -> str : \"\"\" Returns the next file name or path in the iteration. Returns: str: The next file name or path. Raises: StopIteration: If there are no more files in the iteration. \"\"\" self . _pos += 1 if self . _pos >= self . __len__ () : raise StopIteration result = self . results [ self._pos ] if self . return_full_path : return result . path return result . name def __len__ ( self ) -> int : \"\"\" Returns the number of files in the results list. Returns: int: The number of files. \"\"\" return len ( self . results ) def __contains__ ( self , key : str ) -> bool : \"\"\" Checks if a file with the specified name exists in the results list. Args: key (str): The file name to check. 
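The free functions above are thin wrappers over `pathlib`; a combined usage sketch, with hypothetical directory and file names:

```python
from wtracker.utils.path_utils import (
    absolute_path,
    bulk_rename,
    create_parent_directory,
    join_paths,
)

# Build a posix-style path and make sure its parent directory exists.
out_file = join_paths("experiments", "run_01", "log.csv")  # 'experiments/run_01/log.csv'
create_parent_directory(out_file)                          # creates 'experiments/run_01'
print(absolute_path(out_file))                             # resolved absolute path

# Prefix every file in the run folder with 'old_' (subdirectories are skipped).
bulk_rename("experiments/run_01", lambda name: f"old_{name}")
```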
Returns: bool: True if the file exists, False otherwise. \"\"\" for res in self . results : if key == res . name : return True return False def get_filename ( self ) -> str : \"\"\" Returns the name of the current file. Returns: str: The name of the current file. \"\"\" return self . results [ self._pos ] . name def get_path ( self ) -> str : \"\"\" Returns the path of the current file. Returns: str: The path of the current file. \"\"\" return self . results [ self._pos ] . path def seek ( self , pos : int ) -> str : \"\"\" Moves the iterator to the specified position and returns the file name or path. Args: pos (int): The position to seek to. Returns: str: The file name or path at the specified position. Raises: AssertionError: If the specified position is invalid. \"\"\" assert 0 <= pos < self . __len__ (), \"Invalid position\" self . _pos = pos - 1 return self . __next__ () def copy ( self , dst_root : str ) -> None : \"\"\" Copies the current file to the specified destination directory. Args: dst_root (str): The destination directory path. \"\"\" shutil . copy2 ( self . get_path (), dst = dst_root )","title":"Attributes"},{"location":"reference/wtracker/utils/path_utils/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/utils/path_utils/#copy","text":"def copy ( self , dst_root : 'str' ) -> 'None' Copies the current file to the specified destination directory. Parameters: Name Type Description Default dst_root str The destination directory path. None View Source def copy ( self , dst_root : str ) -> None : \"\"\" Copies the current file to the specified destination directory. Args: dst_root (str): The destination directory path. \"\"\" shutil . copy2 ( self . get_path (), dst = dst_root )","title":"copy"},{"location":"reference/wtracker/utils/path_utils/#get_filename","text":"def get_filename ( self ) -> 'str' Returns the name of the current file. Returns: Type Description str The name of the current file. View Source def get_filename ( self ) -> str : \"\"\" Returns the name of the current file. Returns: str: The name of the current file. \"\"\" return self . results [ self . _pos ]. name","title":"get_filename"},{"location":"reference/wtracker/utils/path_utils/#get_path","text":"def get_path ( self ) -> 'str' Returns the path of the current file. Returns: Type Description str The path of the current file. View Source def get_path ( self ) -> str : \"\"\" Returns the path of the current file. Returns: str: The path of the current file. \"\"\" return self . results [ self . _pos ]. path","title":"get_path"},{"location":"reference/wtracker/utils/path_utils/#seek","text":"def seek ( self , pos : 'int' ) -> 'str' Moves the iterator to the specified position and returns the file name or path. Parameters: Name Type Description Default pos int The position to seek to. None Returns: Type Description str The file name or path at the specified position. Raises: Type Description AssertionError If the specified position is invalid. View Source def seek ( self , pos : int ) -> str : \"\"\" Moves the iterator to the specified position and returns the file name or path. Args: pos (int): The position to seek to. Returns: str: The file name or path at the specified position. Raises: AssertionError: If the specified position is invalid. \"\"\" assert 0 <= pos < self . __len__ (), \"Invalid position\" self . _pos = pos - 1 return self . 
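Taken together, the methods above make `Files` a sorted, filterable directory iterator with a movable cursor. A usage sketch, assuming a hypothetical folder of numbered `.png` frames:

```python
from wtracker.utils.path_utils import Files

# Hypothetical folder containing frame_0.png, frame_1.png, ...; the sorting_key
# orders entries numerically rather than lexicographically.
frames = Files(
    directory="frames/",
    extension=".png",
    return_full_path=True,
    sorting_key=lambda name: int(name.split("_")[-1].split(".")[0]),
)

print(len(frames))              # number of matching entries
print("frame_0.png" in frames)  # __contains__ checks by file name

for path in frames:             # iteration yields full paths here
    print(path)

frames.seek(0)                  # move the cursor back to the first entry
print(frames.get_filename())    # 'frame_0.png'
frames.copy("backup/")          # copy the current file (destination folder must exist)
```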
__next__ ()","title":"seek"},{"location":"reference/wtracker/utils/threading_utils/","text":"Module wtracker.utils.threading_utils View Source import queue import threading import multiprocessing from typing import Callable from tqdm.auto import tqdm def adjust_num_workers ( num_tasks : int , chunk_size : int , num_workers : int = None ) -> int : \"\"\" Adjust the number of workers based on the number of tasks and chunk size. Args: num_tasks (int): The number of tasks to be processed. chunk_size (int): The size of each processing chunk. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. \"\"\" if num_workers is None : # if None then choose automatically num_workers = min ( multiprocessing . cpu_count () / 2 , num_tasks / ( 2 * chunk_size )) num_workers = round ( num_workers ) use_multiprocessing = num_workers > 0 num_workers = min ( num_workers , num_tasks // chunk_size ) # no point having workers without tasks num_workers = min ( num_workers , multiprocessing . cpu_count ()) # no point having more workers than cpus if num_workers < 0 : # make sure value is valid num_workers = 0 if use_multiprocessing : num_workers = max ( num_workers , 1 ) elif not use_multiprocessing and num_workers == 1 : num_workers = 0 return num_workers class TqdmQueue ( queue . Queue ): \"\"\" A subclass of `queue.Queue` that provides progress tracking using `tqdm`. Args: maxsize (int): The maximum size of the queue (default: 0). **kwargs: Additional keyword arguments to be passed to the tqdm progress bar. Attributes: pbar (tqdm.tqdm): The progress bar object. total (int): The total number of items processed. Example: queue = ProgressQueue(maxsize=10) queue.put(item) queue.task_done() queue.join() \"\"\" def __init__ ( self , maxsize : int = 0 , ** kwargs ): super () . __init__ ( maxsize = maxsize ) self . pbar = tqdm ( total = 1 , ** kwargs ) self . total = 0 # Keep our own total tracker so we can update the Progressbar def task_done ( self ): \"\"\" Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. \"\"\" super () . task_done () self . pbar . update () self . pbar . refresh () # Redraw the progressbar def _put ( self , item ): super () . _put ( item ) self . total += 1 processed = self . pbar . n # Get current progress to re-apply self . pbar . reset ( self . total ) # Reset and update total self . pbar . update ( processed ) # Re-apply progress self . pbar . refresh () # Redraw the progressbar def join ( self ): \"\"\" Blocks until all items in the Queue have been gotten and processed. \"\"\" super () . join () self . pbar . close () class TaskScheduler : \"\"\" This class is used to schedule tasks to be executed by a worker thread. Args: task_func (Callable): The function to be executed by the worker thread. maxsize (int, optional): The maximum number of items that can be in the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments to be passed to the TqdmQueue constructor. \"\"\" def __init__ ( self , task_func : Callable , maxsize : int = 0 , tqdm : bool = True , ** tqdm_kwargs , ): self . _queue = TqdmQueue ( maxsize , ** tqdm_kwargs ) if tqdm else queue . Queue ( maxsize ) self . _worker_thread = threading . Thread ( target = self . _worker , args = ( self . _queue ,)) self . 
_task_func = task_func def start ( self ): \"\"\" Starts the worker thread. \"\"\" self . _worker_thread . start () def __enter__ ( self ): self . start () return self def __exit__ ( self , exc_type , exc_value , traceback ): self . close () def schedule_save ( self , * params ): \"\"\" Schedules a task by putting task parameters into the queue. Args: *params: The parameters to be passed to the task function. \"\"\" self . _queue . put ( item = params , block = True ) def _worker ( self , q : queue . Queue ): while True : params = q . get ( block = True ) # exit if signaled if params is None : break self . _task_func ( params ) q . task_done () def close ( self ): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self . _queue . join () self . _queue . put ( None ) self . _worker_thread . join () Functions adjust_num_workers def adjust_num_workers ( num_tasks : int , chunk_size : int , num_workers : int = None ) -> int Adjust the number of workers based on the number of tasks and chunk size. Parameters: Name Type Description Default num_tasks int The number of tasks to be processed. None chunk_size int The size of each processing chunk. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None View Source def adjust_num_workers(num_tasks: int, chunk_size: int, num_workers: int = None) -> int: \"\"\" Adjust the number of workers based on the number of tasks and chunk size. Args: num_tasks (int): The number of tasks to be processed. chunk_size (int): The size of each processing chunk. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. \"\"\" if num_workers is None: # if None then choose automatically num_workers = min(multiprocessing.cpu_count() / 2, num_tasks / (2 * chunk_size)) num_workers = round(num_workers) use_multiprocessing = num_workers > 0 num_workers = min(num_workers, num_tasks // chunk_size) # no point having workers without tasks num_workers = min(num_workers, multiprocessing.cpu_count()) # no point having more workers than cpus if num_workers < 0: # make sure value is valid num_workers = 0 if use_multiprocessing: num_workers = max(num_workers, 1) elif not use_multiprocessing and num_workers == 1: num_workers = 0 return num_workers Classes TaskScheduler class TaskScheduler ( task_func : Callable , maxsize : int = 0 , tqdm : bool = True , ** tqdm_kwargs ) This class is used to schedule tasks to be executed by a worker thread. Attributes Name Type Description Default task_func Callable The function to be executed by the worker thread. None maxsize int The maximum number of items that can be in the queue. None tqdm bool Whether to use tqdm for progress tracking. None **tqdm_kwargs None Additional keyword arguments to be passed to the TqdmQueue constructor. None View Source class TaskScheduler : \"\"\" This class is used to schedule tasks to be executed by a worker thread. Args: task_func (Callable): The function to be executed by the worker thread. maxsize (int, optional): The maximum number of items that can be in the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments to be passed to the TqdmQueue constructor. \"\"\" def __init__ ( self , task_func : Callable , maxsize : int = 0 , tqdm : bool = True , ** tqdm_kwargs , ): self . _queue = TqdmQueue ( maxsize , ** tqdm_kwargs ) if tqdm else queue . 
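A short numeric walk-through of `adjust_num_workers`, derived from the source above; the printed values assume a machine where `multiprocessing.cpu_count()` returns 8:

```python
from wtracker.utils.threading_utils import adjust_num_workers

# Auto mode: min(cpu_count / 2, num_tasks / (2 * chunk_size)), rounded,
# then capped by num_tasks // chunk_size and by cpu_count.
print(adjust_num_workers(num_tasks=1000, chunk_size=10))  # min(4, 50) -> 4

# Fewer tasks than one chunk: the auto estimate rounds to 0, so the caller
# falls back to single-threaded processing.
print(adjust_num_workers(num_tasks=5, chunk_size=10))     # -> 0

# An explicit num_workers=0 disables multiprocessing outright.
print(adjust_num_workers(num_tasks=100, chunk_size=10, num_workers=0))  # -> 0
```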
Queue ( maxsize ) self . _worker_thread = threading . Thread ( target = self . _worker , args = ( self . _queue ,)) self . _task_func = task_func def start ( self ): \"\"\" Starts the worker thread. \"\"\" self . _worker_thread . start () def __enter__ ( self ): self . start () return self def __exit__ ( self , exc_type , exc_value , traceback ): self . close () def schedule_save ( self , * params ): \"\"\" Schedules a task by putting task parameters into the queue. Args: *params: The parameters to be passed to the task function. \"\"\" self . _queue . put ( item = params , block = True ) def _worker ( self , q : queue . Queue ): while True : params = q . get ( block = True ) # exit if signaled if params is None : break self . _task_func ( params ) q . task_done () def close ( self ): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self . _queue . join () self . _queue . put ( None ) self . _worker_thread . join () Descendants wtracker.utils.io_utils.FrameSaver wtracker.utils.io_utils.ImageSaver Methods close def close ( self ) Waits for the queue to empty and then closes the worker thread. View Source def close(self): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self._queue.join() self._queue.put(None) self._worker_thread.join() schedule_save def schedule_save ( self , * params ) Schedules a task by putting task parameters into the queue. Parameters: Name Type Description Default *params None The parameters to be passed to the task function. None View Source def schedule_save(self, *params): \"\"\" Schedules a task by putting task parameters into the queue. Args: *params: The parameters to be passed to the task function. \"\"\" self._queue.put(item=params, block=True) start def start ( self ) Starts the worker thread. View Source def start(self): \"\"\" Starts the worker thread. \"\"\" self._worker_thread.start() TqdmQueue class TqdmQueue ( maxsize : int = 0 , ** kwargs ) A subclass of queue.Queue that provides progress tracking using tqdm . Attributes Name Type Description Default maxsize int The maximum size of the queue (default: 0). None **kwargs None Additional keyword arguments to be passed to the tqdm progress bar. None pbar tqdm.tqdm The progress bar object. None total int The total number of items processed. None View Source class TqdmQueue ( queue . Queue ) : \" \"\" A subclass of `queue.Queue` that provides progress tracking using `tqdm`. Args: maxsize (int): The maximum size of the queue (default: 0). **kwargs: Additional keyword arguments to be passed to the tqdm progress bar. Attributes: pbar (tqdm.tqdm): The progress bar object. total (int): The total number of items processed. Example: queue = ProgressQueue(maxsize=10) queue.put(item) queue.task_done() queue.join() \"\" \" def __init__ ( self , maxsize : int = 0 , ** kwargs ) : super (). __init__ ( maxsize = maxsize ) self . pbar = tqdm ( total = 1 , ** kwargs ) self . total = 0 # Keep our own total tracker so we can update the Progressbar def task_done ( self ) : \" \"\" Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. \"\" \" super (). task_done () self . pbar . update () self . pbar . refresh () # Redraw the progressbar def _put ( self , item ) : super (). _put ( item ) self . total += 1 processed = self . pbar . n # Get current progress to re-apply self . pbar . reset ( self . total ) # Reset and update total self . pbar . 
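`TaskScheduler` doubles as a context manager: `__enter__` starts the worker thread and `__exit__` drains the queue and joins it. A minimal sketch; the task body is illustrative, and note that the worker hands the scheduled parameters to `task_func` as a single tuple:

```python
from wtracker.utils.threading_utils import TaskScheduler

def save_item(params):
    # params is the tuple passed to schedule_save, e.g. (0, 'item_0').
    index, name = params
    print(f"saving {name} (#{index})")

# desc is forwarded through **tqdm_kwargs to the underlying progress bar.
with TaskScheduler(task_func=save_item, desc="saving") as scheduler:
    for i in range(5):
        scheduler.schedule_save(i, f"item_{i}")
# on exit: queue drained, sentinel enqueued, worker thread joined
```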
update ( processed ) # Re-apply progress self . pbar . refresh () # Redraw the progressbar def join ( self ) : \" \"\" Blocks until all items in the Queue have been gotten and processed. \"\" \" super (). join () self . pbar . close () Ancestors (in MRO) queue.Queue Methods empty def empty ( self ) Return True if the queue is empty, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() == 0 as a direct substitute, but be aware that either approach risks a race condition where a queue can grow before the result of empty() or qsize() can be used. To create code that needs to wait for all queued tasks to be completed, the preferred technique is to use the join() method. View Source def empty(self): '''Return True if the queue is empty, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() == 0 as a direct substitute, but be aware that either approach risks a race condition where a queue can grow before the result of empty() or qsize() can be used. To create code that needs to wait for all queued tasks to be completed, the preferred technique is to use the join() method. ''' with self.mutex: return not self._qsize() full def full ( self ) Return True if the queue is full, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() >= n as a direct substitute, but be aware that either approach risks a race condition where a queue can shrink before the result of full() or qsize() can be used. View Source def full(self): '''Return True if the queue is full, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() >= n as a direct substitute, but be aware that either approach risks a race condition where a queue can shrink before the result of full() or qsize() can be used. ''' with self.mutex: return 0 < self.maxsize <= self._qsize() get def get ( self , block = True , timeout = None ) Remove and return an item from the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until an item is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Empty exception if no item was available within that time. Otherwise ('block' is false), return an item if one is immediately available, else raise the Empty exception ('timeout' is ignored in that case). View Source def get(self, block=True, timeout=None): '''Remove and return an item from the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until an item is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Empty exception if no item was available within that time. Otherwise ('block' is false), return an item if one is immediately available, else raise the Empty exception ('timeout' is ignored in that case). ''' with self.not_empty: if not block: if not self._qsize(): raise Empty elif timeout is None: while not self._qsize(): self.not_empty.wait() elif timeout < 0: raise ValueError(\"'timeout' must be a non-negative number\") else: endtime = time() + timeout while not self._qsize(): remaining = endtime - time() if remaining <= 0.0: raise Empty self.not_empty.wait(remaining) item = self._get() self.not_full.notify() return item get_nowait def get_nowait ( self ) Remove and return an item from the queue without blocking. Only get an item if one is immediately available. Otherwise raise the Empty exception. 
View Source def get_nowait(self): '''Remove and return an item from the queue without blocking. Only get an item if one is immediately available. Otherwise raise the Empty exception. ''' return self.get(block=False) join def join ( self ) Blocks until all items in the Queue have been gotten and processed. View Source def join ( self ) : \"\" \" Blocks until all items in the Queue have been gotten and processed. \"\" \" super().join() self.pbar.close() put def put ( self , item , block = True , timeout = None ) Put an item into the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until a free slot is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Full exception if no free slot was available within that time. Otherwise ('block' is false), put an item on the queue if a free slot is immediately available, else raise the Full exception ('timeout' is ignored in that case). View Source def put(self, item, block=True, timeout=None): '''Put an item into the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until a free slot is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Full exception if no free slot was available within that time. Otherwise ('block' is false), put an item on the queue if a free slot is immediately available, else raise the Full exception ('timeout' is ignored in that case). ''' with self.not_full: if self.maxsize > 0: if not block: if self._qsize() >= self.maxsize: raise Full elif timeout is None: while self._qsize() >= self.maxsize: self.not_full.wait() elif timeout < 0: raise ValueError(\"'timeout' must be a non-negative number\") else: endtime = time() + timeout while self._qsize() >= self.maxsize: remaining = endtime - time() if remaining <= 0.0: raise Full self.not_full.wait(remaining) self._put(item) self.unfinished_tasks += 1 self.not_empty.notify() put_nowait def put_nowait ( self , item ) Put an item into the queue without blocking. Only enqueue the item if a free slot is immediately available. Otherwise raise the Full exception. View Source def put_nowait(self, item): '''Put an item into the queue without blocking. Only enqueue the item if a free slot is immediately available. Otherwise raise the Full exception. ''' return self.put(item, block=False) qsize def qsize ( self ) Return the approximate size of the queue (not reliable!). View Source def qsize(self): '''Return the approximate size of the queue (not reliable!).''' with self.mutex: return self._qsize() task_done def task_done ( self ) Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. View Source def task_done(self): \"\"\" Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. 
\"\"\" super().task_done() self.pbar.update() self.pbar.refresh() # Redraw the progressbar","title":"Threading Utils"},{"location":"reference/wtracker/utils/threading_utils/#module-wtrackerutilsthreading_utils","text":"View Source import queue import threading import multiprocessing from typing import Callable from tqdm.auto import tqdm def adjust_num_workers ( num_tasks : int , chunk_size : int , num_workers : int = None ) -> int : \"\"\" Adjust the number of workers based on the number of tasks and chunk size. Args: num_tasks (int): The number of tasks to be processed. chunk_size (int): The size of each processing chunk. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. \"\"\" if num_workers is None : # if None then choose automatically num_workers = min ( multiprocessing . cpu_count () / 2 , num_tasks / ( 2 * chunk_size )) num_workers = round ( num_workers ) use_multiprocessing = num_workers > 0 num_workers = min ( num_workers , num_tasks // chunk_size ) # no point having workers without tasks num_workers = min ( num_workers , multiprocessing . cpu_count ()) # no point having more workers than cpus if num_workers < 0 : # make sure value is valid num_workers = 0 if use_multiprocessing : num_workers = max ( num_workers , 1 ) elif not use_multiprocessing and num_workers == 1 : num_workers = 0 return num_workers class TqdmQueue ( queue . Queue ): \"\"\" A subclass of `queue.Queue` that provides progress tracking using `tqdm`. Args: maxsize (int): The maximum size of the queue (default: 0). **kwargs: Additional keyword arguments to be passed to the tqdm progress bar. Attributes: pbar (tqdm.tqdm): The progress bar object. total (int): The total number of items processed. Example: queue = ProgressQueue(maxsize=10) queue.put(item) queue.task_done() queue.join() \"\"\" def __init__ ( self , maxsize : int = 0 , ** kwargs ): super () . __init__ ( maxsize = maxsize ) self . pbar = tqdm ( total = 1 , ** kwargs ) self . total = 0 # Keep our own total tracker so we can update the Progressbar def task_done ( self ): \"\"\" Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. \"\"\" super () . task_done () self . pbar . update () self . pbar . refresh () # Redraw the progressbar def _put ( self , item ): super () . _put ( item ) self . total += 1 processed = self . pbar . n # Get current progress to re-apply self . pbar . reset ( self . total ) # Reset and update total self . pbar . update ( processed ) # Re-apply progress self . pbar . refresh () # Redraw the progressbar def join ( self ): \"\"\" Blocks until all items in the Queue have been gotten and processed. \"\"\" super () . join () self . pbar . close () class TaskScheduler : \"\"\" This class is used to schedule tasks to be executed by a worker thread. Args: task_func (Callable): The function to be executed by the worker thread. maxsize (int, optional): The maximum number of items that can be in the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments to be passed to the TqdmQueue constructor. \"\"\" def __init__ ( self , task_func : Callable , maxsize : int = 0 , tqdm : bool = True , ** tqdm_kwargs , ): self . _queue = TqdmQueue ( maxsize , ** tqdm_kwargs ) if tqdm else queue . Queue ( maxsize ) self . _worker_thread = threading . Thread ( target = self . 
_worker , args = ( self . _queue ,)) self . _task_func = task_func def start ( self ): \"\"\" Starts the worker thread. \"\"\" self . _worker_thread . start () def __enter__ ( self ): self . start () return self def __exit__ ( self , exc_type , exc_value , traceback ): self . close () def schedule_save ( self , * params ): \"\"\" Schedules a task by putting task parameters into the queue. Args: *params: The parameters to be passed to the task function. \"\"\" self . _queue . put ( item = params , block = True ) def _worker ( self , q : queue . Queue ): while True : params = q . get ( block = True ) # exit if signaled if params is None : break self . _task_func ( params ) q . task_done () def close ( self ): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self . _queue . join () self . _queue . put ( None ) self . _worker_thread . join ()","title":"Module wtracker.utils.threading_utils"},{"location":"reference/wtracker/utils/threading_utils/#functions","text":"","title":"Functions"},{"location":"reference/wtracker/utils/threading_utils/#adjust_num_workers","text":"def adjust_num_workers ( num_tasks : int , chunk_size : int , num_workers : int = None ) -> int Adjust the number of workers based on the number of tasks and chunk size. Parameters: Name Type Description Default num_tasks int The number of tasks to be processed. None chunk_size int The size of each processing chunk. None num_workers int The number of workers to use for parallel processing. If None, the number of workers is determined automatically. None View Source def adjust_num_workers(num_tasks: int, chunk_size: int, num_workers: int = None) -> int: \"\"\" Adjust the number of workers based on the number of tasks and chunk size. Args: num_tasks (int): The number of tasks to be processed. chunk_size (int): The size of each processing chunk. num_workers (int, optional): The number of workers to use for parallel processing. If None, the number of workers is determined automatically. \"\"\" if num_workers is None: # if None then choose automatically num_workers = min(multiprocessing.cpu_count() / 2, num_tasks / (2 * chunk_size)) num_workers = round(num_workers) use_multiprocessing = num_workers > 0 num_workers = min(num_workers, num_tasks // chunk_size) # no point having workers without tasks num_workers = min(num_workers, multiprocessing.cpu_count()) # no point having more workers than cpus if num_workers < 0: # make sure value is valid num_workers = 0 if use_multiprocessing: num_workers = max(num_workers, 1) elif not use_multiprocessing and num_workers == 1: num_workers = 0 return num_workers","title":"adjust_num_workers"},{"location":"reference/wtracker/utils/threading_utils/#classes","text":"","title":"Classes"},{"location":"reference/wtracker/utils/threading_utils/#taskscheduler","text":"class TaskScheduler ( task_func : Callable , maxsize : int = 0 , tqdm : bool = True , ** tqdm_kwargs ) This class is used to schedule tasks to be executed by a worker thread.","title":"TaskScheduler"},{"location":"reference/wtracker/utils/threading_utils/#attributes","text":"Name Type Description Default task_func Callable The function to be executed by the worker thread. None maxsize int The maximum number of items that can be in the queue. None tqdm bool Whether to use tqdm for progress tracking. None **tqdm_kwargs None Additional keyword arguments to be passed to the TqdmQueue constructor. None View Source class TaskScheduler : \"\"\" This class is used to schedule tasks to be executed by a worker thread. 
Args: task_func (Callable): The function to be executed by the worker thread. maxsize (int, optional): The maximum number of items that can be in the queue. tqdm (bool, optional): Whether to use tqdm for progress tracking. **tqdm_kwargs: Additional keyword arguments to be passed to the TqdmQueue constructor. \"\"\" def __init__ ( self , task_func : Callable , maxsize : int = 0 , tqdm : bool = True , ** tqdm_kwargs , ): self . _queue = TqdmQueue ( maxsize , ** tqdm_kwargs ) if tqdm else queue . Queue ( maxsize ) self . _worker_thread = threading . Thread ( target = self . _worker , args = ( self . _queue ,)) self . _task_func = task_func def start ( self ): \"\"\" Starts the worker thread. \"\"\" self . _worker_thread . start () def __enter__ ( self ): self . start () return self def __exit__ ( self , exc_type , exc_value , traceback ): self . close () def schedule_save ( self , * params ): \"\"\" Schedules a task by putting task parameters into the queue. Args: *params: The parameters to be passed to the task function. \"\"\" self . _queue . put ( item = params , block = True ) def _worker ( self , q : queue . Queue ): while True : params = q . get ( block = True ) # exit if signaled if params is None : break self . _task_func ( params ) q . task_done () def close ( self ): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self . _queue . join () self . _queue . put ( None ) self . _worker_thread . join ()","title":"Attributes"},{"location":"reference/wtracker/utils/threading_utils/#descendants","text":"wtracker.utils.io_utils.FrameSaver wtracker.utils.io_utils.ImageSaver","title":"Descendants"},{"location":"reference/wtracker/utils/threading_utils/#methods","text":"","title":"Methods"},{"location":"reference/wtracker/utils/threading_utils/#close","text":"def close ( self ) Waits for the queue to empty and then closes the worker thread. View Source def close(self): \"\"\" Waits for the queue to empty and then closes the worker thread. \"\"\" self._queue.join() self._queue.put(None) self._worker_thread.join()","title":"close"},{"location":"reference/wtracker/utils/threading_utils/#schedule_save","text":"def schedule_save ( self , * params ) Schedules a task by putting task parameters into the queue. Parameters: Name Type Description Default *params None The parameters to be passed to the task function. None View Source def schedule_save(self, *params): \"\"\" Schedules a task by putting task parameters into the queue. Args: *params: The parameters to be passed to the task function. \"\"\" self._queue.put(item=params, block=True)","title":"schedule_save"},{"location":"reference/wtracker/utils/threading_utils/#start","text":"def start ( self ) Starts the worker thread. View Source def start(self): \"\"\" Starts the worker thread. \"\"\" self._worker_thread.start()","title":"start"},{"location":"reference/wtracker/utils/threading_utils/#tqdmqueue","text":"class TqdmQueue ( maxsize : int = 0 , ** kwargs ) A subclass of queue.Queue that provides progress tracking using tqdm .","title":"TqdmQueue"},{"location":"reference/wtracker/utils/threading_utils/#attributes_1","text":"Name Type Description Default maxsize int The maximum size of the queue (default: 0). None **kwargs None Additional keyword arguments to be passed to the tqdm progress bar. None pbar tqdm.tqdm The progress bar object. None total int The total number of items processed. None View Source class TqdmQueue ( queue . 
Queue ) : \" \"\" A subclass of `queue.Queue` that provides progress tracking using `tqdm`. Args: maxsize (int): The maximum size of the queue (default: 0). **kwargs: Additional keyword arguments to be passed to the tqdm progress bar. Attributes: pbar (tqdm.tqdm): The progress bar object. total (int): The total number of items processed. Example: queue = ProgressQueue(maxsize=10) queue.put(item) queue.task_done() queue.join() \"\" \" def __init__ ( self , maxsize : int = 0 , ** kwargs ) : super (). __init__ ( maxsize = maxsize ) self . pbar = tqdm ( total = 1 , ** kwargs ) self . total = 0 # Keep our own total tracker so we can update the Progressbar def task_done ( self ) : \" \"\" Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. \"\" \" super (). task_done () self . pbar . update () self . pbar . refresh () # Redraw the progressbar def _put ( self , item ) : super (). _put ( item ) self . total += 1 processed = self . pbar . n # Get current progress to re-apply self . pbar . reset ( self . total ) # Reset and update total self . pbar . update ( processed ) # Re-apply progress self . pbar . refresh () # Redraw the progressbar def join ( self ) : \" \"\" Blocks until all items in the Queue have been gotten and processed. \"\" \" super (). join () self . pbar . close ()","title":"Attributes"},{"location":"reference/wtracker/utils/threading_utils/#ancestors-in-mro","text":"queue.Queue","title":"Ancestors (in MRO)"},{"location":"reference/wtracker/utils/threading_utils/#methods_1","text":"","title":"Methods"},{"location":"reference/wtracker/utils/threading_utils/#empty","text":"def empty ( self ) Return True if the queue is empty, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() == 0 as a direct substitute, but be aware that either approach risks a race condition where a queue can grow before the result of empty() or qsize() can be used. To create code that needs to wait for all queued tasks to be completed, the preferred technique is to use the join() method. View Source def empty(self): '''Return True if the queue is empty, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() == 0 as a direct substitute, but be aware that either approach risks a race condition where a queue can grow before the result of empty() or qsize() can be used. To create code that needs to wait for all queued tasks to be completed, the preferred technique is to use the join() method. ''' with self.mutex: return not self._qsize()","title":"empty"},{"location":"reference/wtracker/utils/threading_utils/#full","text":"def full ( self ) Return True if the queue is full, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() >= n as a direct substitute, but be aware that either approach risks a race condition where a queue can shrink before the result of full() or qsize() can be used. View Source def full(self): '''Return True if the queue is full, False otherwise (not reliable!). This method is likely to be removed at some point. Use qsize() >= n as a direct substitute, but be aware that either approach risks a race condition where a queue can shrink before the result of full() or qsize() can be used. 
''' with self.mutex: return 0 < self.maxsize <= self._qsize()","title":"full"},{"location":"reference/wtracker/utils/threading_utils/#get","text":"def get ( self , block = True , timeout = None ) Remove and return an item from the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until an item is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Empty exception if no item was available within that time. Otherwise ('block' is false), return an item if one is immediately available, else raise the Empty exception ('timeout' is ignored in that case). View Source def get(self, block=True, timeout=None): '''Remove and return an item from the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until an item is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Empty exception if no item was available within that time. Otherwise ('block' is false), return an item if one is immediately available, else raise the Empty exception ('timeout' is ignored in that case). ''' with self.not_empty: if not block: if not self._qsize(): raise Empty elif timeout is None: while not self._qsize(): self.not_empty.wait() elif timeout < 0: raise ValueError(\"'timeout' must be a non-negative number\") else: endtime = time() + timeout while not self._qsize(): remaining = endtime - time() if remaining <= 0.0: raise Empty self.not_empty.wait(remaining) item = self._get() self.not_full.notify() return item","title":"get"},{"location":"reference/wtracker/utils/threading_utils/#get_nowait","text":"def get_nowait ( self ) Remove and return an item from the queue without blocking. Only get an item if one is immediately available. Otherwise raise the Empty exception. View Source def get_nowait(self): '''Remove and return an item from the queue without blocking. Only get an item if one is immediately available. Otherwise raise the Empty exception. ''' return self.get(block=False)","title":"get_nowait"},{"location":"reference/wtracker/utils/threading_utils/#join","text":"def join ( self ) Blocks until all items in the Queue have been gotten and processed. View Source def join ( self ) : \"\"\" Blocks until all items in the Queue have been gotten and processed. \"\"\" super().join() self.pbar.close()","title":"join"},{"location":"reference/wtracker/utils/threading_utils/#put","text":"def put ( self , item , block = True , timeout = None ) Put an item into the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until a free slot is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Full exception if no free slot was available within that time. Otherwise ('block' is false), put an item on the queue if a free slot is immediately available, else raise the Full exception ('timeout' is ignored in that case).
''' with self.not_full: if self.maxsize > 0: if not block: if self._qsize() >= self.maxsize: raise Full elif timeout is None: while self._qsize() >= self.maxsize: self.not_full.wait() elif timeout < 0: raise ValueError(\"'timeout' must be a non-negative number\") else: endtime = time() + timeout while self._qsize() >= self.maxsize: remaining = endtime - time() if remaining <= 0.0: raise Full self.not_full.wait(remaining) self._put(item) self.unfinished_tasks += 1 self.not_empty.notify()","title":"put"},{"location":"reference/wtracker/utils/threading_utils/#put_nowait","text":"def put_nowait ( self , item ) Put an item into the queue without blocking. Only enqueue the item if a free slot is immediately available. Otherwise raise the Full exception. View Source def put_nowait(self, item): '''Put an item into the queue without blocking. Only enqueue the item if a free slot is immediately available. Otherwise raise the Full exception. ''' return self.put(item, block=False)","title":"put_nowait"},{"location":"reference/wtracker/utils/threading_utils/#qsize","text":"def qsize ( self ) Return the approximate size of the queue (not reliable!). View Source def qsize(self): '''Return the approximate size of the queue (not reliable!).''' with self.mutex: return self._qsize()","title":"qsize"},{"location":"reference/wtracker/utils/threading_utils/#task_done","text":"def task_done ( self ) Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. View Source def task_done(self): \"\"\" Mark the task as done and update the progress bar. This method should be called when a task is completed. It updates the progress bar to reflect the completion of the task. \"\"\" super().task_done() self.pbar.update() self.pbar.refresh() # Redraw the progressbar","title":"task_done"}]} \ No newline at end of file diff --git a/sitemap.xml.gz b/sitemap.xml.gz index 33e5ba0..ed4d58e 100644 Binary files a/sitemap.xml.gz and b/sitemap.xml.gz differ
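
For orientation, the `threading_utils` entries in the regenerated search index above describe two pieces: a queue-backed worker thread (constructor `(task_func, maxsize=0, tqdm=True, **tqdm_kwargs)` with `start`, `schedule_save`, and `close` methods, usable as a context manager) and `TqdmQueue`, a `queue.Queue` subclass whose `task_done`/`join` drive a tqdm progress bar. The sketch below shows how they fit together. It is a minimal illustration, not code from the repository: the worker class name `TaskScheduler` is an assumption (only its members appear in this hunk), and the file names are hypothetical; `TqdmQueue` and the signatures are taken from the index.

    from wtracker.utils.threading_utils import TaskScheduler  # class name assumed; see note above

    def save_item(params: tuple):
        # The worker thread hands each scheduled parameter tuple to the task
        # function as a single argument (the indexed _worker body calls
        # self._task_func(params)), so the task function unpacks it itself.
        path, text = params
        with open(path, "w") as f:
            f.write(text)

    # __enter__ starts the worker thread; __exit__ calls close(), which waits
    # for the queue to drain, enqueues the None sentinel that stops _worker,
    # and joins the thread. desc is forwarded through **tqdm_kwargs to tqdm.
    with TaskScheduler(task_func=save_item, maxsize=0, tqdm=True, desc="saving") as scheduler:
        scheduler.schedule_save("out_a.txt", "first result")   # puts ("out_a.txt", "first result") on the queue
        scheduler.schedule_save("out_b.txt", "second result")
    # Past this point every scheduled write has completed, and the progress
    # bar (advanced by TqdmQueue.task_done) has been closed by join().

Note the design choice this implies: `schedule_save(*params)` packs its positional arguments into one tuple, so task functions receive a single tuple rather than separate parameters; a task function with a multi-parameter signature would only work if `_worker` unpacked with `self._task_func(*params)` instead.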