-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #18 from denBruneBarone/decision_tree
Decision tree
- Loading branch information
Showing
13 changed files
with
338 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
notebook | ||
pandas | ||
pyarrow | ||
matplotlib | ||
notebook==7.1.1 | ||
pandas==2.2.1 | ||
pyarrow==15.0.1 | ||
matplotlib==3.8.3 |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,14 @@ | ||
class ModelConfig: | ||
input_size = 1000 | ||
embedding_dim = 100 | ||
hidden_size = 128 | ||
num_classes = 3 | ||
class HPConfig:
    """Fixed hyperparameters for the decision-tree regressor."""
    criterion = 'friedman_mse'  # same as the paper
    max_depth = 10              # the modelling paper uses 7
    max_features = None         # same as the paper
    max_leaf_nodes = 500        # the modelling paper uses 10
|
||
class GridSearchConfig:
    """Hyperparameter grid explored by GridSearchCV over DecisionTreeRegressor."""
    # BUG FIX: scikit-learn >= 1.2 removed the legacy 'mse'/'mae' criterion
    # names (the project pins scikit-learn==1.4.1.post1); passing them makes
    # every fit raise InvalidParameterError. The current names are
    # 'squared_error' and 'absolute_error'.
    param_grid = {
        'criterion': ['squared_error', 'friedman_mse', 'absolute_error'],
        'max_depth': [2, 3, 4, 5, 6, 7, 8],
        'max_features': [None, 'sqrt', 'log2'],
        'max_leaf_nodes': [2, 3, 4, 5, 6, 7, 8, 9, 10]
    }
|
||
class TrainingConfig:
    """Hyperparameters for the model training loop."""
    num_epochs = 10       # full passes over the training data
    batch_size = 32       # samples per optimization step
    learning_rate = 0.05  # optimizer step size
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import pandas as pd | ||
from sklearn.model_selection import GridSearchCV, KFold | ||
from sklearn.tree import DecisionTreeRegressor | ||
from machine_learning.config import GridSearchConfig | ||
from machine_learning.prepare_for_training import organize_data | ||
|
||
def grid_search(array_of_df):
    """Run a cross-validated grid search for the best decision-tree regressor.

    Parameters
    ----------
    array_of_df : list of pandas.DataFrame
        One formatted DataFrame per flight; passed to ``organize_data``.

    Returns
    -------
    tuple
        ``(best_params, best_score, best_regressor)`` where ``best_score`` is
        the mean cross-validated score of the best parameter combination and
        ``best_regressor`` is a DecisionTreeRegressor fitted on all data.
    """
    estimator = DecisionTreeRegressor()
    # n_splits: number of subsets the train-val data is split into for
    # cross validation.
    cv = KFold(n_splits=5, shuffle=True, random_state=42)

    # BUG FIX: 'friedman_mse' is a tree *splitting criterion*, not a valid
    # GridSearchCV scoring string; using it raises at fit time. Use the
    # standard regression scorer instead (higher is better, hence "neg").
    # Also renamed the local so it no longer shadows this function's name.
    searcher = GridSearchCV(estimator=estimator,
                            param_grid=GridSearchConfig.param_grid,
                            cv=cv, scoring='neg_mean_squared_error')

    flight_dict_list = organize_data(array_of_df)

    # Extract features and target variable from flight_dict_list.
    X_train_list = [flight['data'] for flight in flight_dict_list]
    y_train_list = [flight['power'] for flight in flight_dict_list]

    # Convert lists of DataFrames/Series into a single DataFrame and Series.
    X_train = pd.concat(X_train_list, ignore_index=True)
    y_train = pd.concat(y_train_list, ignore_index=True)

    # Perform grid search.
    searcher.fit(X_train, y_train)

    best_params = searcher.best_params_
    best_score = searcher.best_score_

    # GridSearchCV already refits the best estimator on the full training
    # data (refit=True by default) — no need for a second manual fit.
    best_regressor = searcher.best_estimator_

    return best_params, best_score, best_regressor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import os | ||
from machine_learning.pre_processing import pre_process_and_split_data | ||
from machine_learning.prepare_for_training import format_data | ||
from machine_learning.training import training_and_evaluating | ||
|
||
|
||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../")) | ||
flights_processed = os.path.join(PROJECT_ROOT, "data/datasets/rodrigues/flights_processed.csv") | ||
|
||
|
||
def train():
    """End-to-end pipeline: pre-process, split, format, then train/evaluate."""
    # pre_processing
    print("Pre-processing data...")
    csv_path = os.path.join(PROJECT_ROOT, "data/datasets/rodrigues/flights_processed.csv")

    # organizing
    print("Splitting data...")
    train_data, test_data = pre_process_and_split_data(csv_path)
    print("Formatting data...")
    train_data = format_data(train_data)
    test_data = format_data(test_data)

    # training
    print("Training...")
    training_and_evaluating(train_data, test_data)


if __name__ == "__main__":
    train()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,25 @@ | ||
def load_data () : | ||
pass | ||
import pandas as pd | ||
from sklearn.model_selection import train_test_split | ||
|
||
def extract_flights () : | ||
pass | ||
|
||
def load_data(file_path):
    """Read the raw flights CSV at *file_path* into a DataFrame.

    low_memory=False reads the whole file in one pass so column dtypes
    are inferred consistently.
    """
    return pd.read_csv(file_path, sep=',', low_memory=False)
|
||
def split_data () : | ||
pass | ||
|
||
def extract_flights(df):
    """Split the combined DataFrame into one DataFrame per 'flight' value."""
    grouped = df.groupby('flight')
    return [flight_df for _, flight_df in grouped]
|
||
|
||
def split_data(df, train_size=0.8, random_state=42):
    """Partition whole flights (never rows of one flight) into train/test.

    Splitting at the flight level prevents samples of a single flight from
    leaking across the train/test boundary.
    """
    flights = extract_flights(df)
    held_out_fraction = 1 - train_size
    train_data, test_data = train_test_split(
        flights, test_size=held_out_fraction, random_state=random_state)
    return train_data, test_data
|
||
|
||
def pre_process_and_split_data(file_path):
    """Load the CSV at *file_path* and return (train_flights, test_flights)."""
    return split_data(load_data(file_path))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,69 @@ | ||
def target_variable_processing () : | ||
pass | ||
import numpy as np | ||
from torch.utils.data import Dataset | ||
from data_processing.energy_consumption.trapeziod_integration import add_power_to_df | ||
from sklearn.preprocessing import StandardScaler | ||
import pandas as pd | ||
|
||
def pre_process_flights () : | ||
pass | ||
|
||
def organize_data () : | ||
pass | ||
|
||
class TrainingDataset:
    """Wraps a list of per-flight DataFrames and serves normalized samples.

    A single StandardScaler is fitted over *all* flights at construction
    time, so every flight is normalized with the same statistics.
    """

    # Model input feature columns (shared by fit_scaler and __getitem__).
    _FEATURE_COLUMNS = [
        'time', 'wind_speed', 'wind_angle',
        'position_x', 'position_y', 'position_z',
        'orientation_x', 'orientation_y', 'orientation_z', 'orientation_w',
        'velocity_x', 'velocity_y', 'velocity_z',
        'angular_x', 'angular_y', 'angular_z',
        'linear_acceleration_x', 'linear_acceleration_y', 'linear_acceleration_z',
        'payload']

    # Target columns predicted by the model.
    _TARGET_COLUMNS = ['battery_current', 'battery_voltage']

    def __init__(self, data):
        self.data = data
        self.scaler = StandardScaler()
        # Build and fit the scaler once at initialisation.
        self.fit_scaler()

    def fit_scaler(self):
        """Fit the scaler on all given data, not just a single DataFrame."""
        # Concatenate all DataFrames in self.data into a single DataFrame.
        combined = pd.concat(self.data, ignore_index=True)
        feature_matrix = combined[self._FEATURE_COLUMNS].values
        # Fit the scaler on the entire training dataset.
        self.scaler.fit(feature_matrix)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        flight = self.data[index]

        # Input features, normalized with the scaler from initialisation.
        raw_input = flight[self._FEATURE_COLUMNS].values
        normalized_input = self.scaler.transform(raw_input)

        # Output/target features.
        target_array = flight[self._TARGET_COLUMNS].values

        return normalized_input, target_array
||
def format_data(array_of_df):
    """Prepare each flight DataFrame for training.

    For every flight: shift positions so the flight starts at the origin,
    drop columns not used as features, and append the computed power target.
    Returns a new list; note the position shift mutates the input frames
    in place before the column drop produces a copy.
    """
    formatted = []
    for flight_df in array_of_df:
        # Make positions relative to the flight's first sample.
        for axis in ('position_x', 'position_y', 'position_z'):
            flight_df[axis] = flight_df[axis] - flight_df[axis].iloc[0]

        flight_df = flight_df.drop(
            columns=['flight', 'speed', 'altitude', 'date', 'time_day', 'route'])
        flight_df = add_power_to_df(flight_df)
        formatted.append(flight_df)
    return formatted
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
torch==2.2.1 | ||
pandas==2.2.1 | ||
scikit-learn==1.4.1.post1 | ||
numpy |
Oops, something went wrong.