-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
115 lines (98 loc) · 3.09 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from computer_vision.main import poster_analysis
from gender.gender_feature import *
import pandas as pd
import numpy as np
import pickle
from analysis.NLP_PCA import add_NLP_cols
from sklearn.preprocessing import StandardScaler
# Sample input used when this file is run as a script: the first 31 rows of
# the TMDB dump.
df = pd.read_json("data/tmdb_data.json").head(31)

# NOTE(security): unpickling can execute arbitrary code; only load this file
# if it comes from a trusted source.
with open("xgb_reg.pkl", "rb") as f:
    clf = pickle.load(f)
print(clf)

N_PCA_NLP = 30  # Number of vectors in the PCA
# Per-year multiplier used by preprocessing() to express budget/revenue in
# 2022 money.
YEARLY_INFLATION = 1.036
# Genre names seen so far; filled in by preprocessing().
genres_global = []
def preprocessing(df):
    """Build the feature frame that is passed to the model.

    The frame is run through ``poster_analysis``, ``genderAnalysis`` and
    ``add_NLP_cols`` (with ``N_PCA_NLP``), then gets one boolean
    ``Is_<genre>`` column per entry of ``genres_global``, release month and
    year, availability flags for revenue and budget, budget and revenue
    multiplied by ``YEARLY_INFLATION ** (2022 - year)``, and finally has its
    raw/text columns dropped and its numeric columns standard-scaled.

    Args:
        df: DataFrame of movie rows. It must provide every column that is
            read, dropped or scaled below, except those added by the three
            helper calls (presumably the gender/poster columns -- TODO
            confirm against those helpers).

    Returns:
        The transformed DataFrame.

    Side effects:
        Appends newly seen genre names to the module-level ``genres_global``
        list and prints the column index after the gender analysis step.
    """
    merged_df = poster_analysis(df)
    merged_df = genderAnalysis(merged_df)
    print(merged_df.columns)
    merged_df = add_NLP_cols(merged_df, N_PCA_NLP)

    def parse_genres(x):
        """Return the genre names of one ``genres`` cell, registering new ones."""
        # A missing cell can be None or any float NaN; an identity test
        # against np.nan only recognises that one specific object.
        if x is None or (isinstance(x, float) and np.isnan(x)):
            return []
        genres = []
        for d in x:
            name = d["name"]
            genres.append(name)
            if name not in genres_global:
                genres_global.append(name)
        return genres

    merged_df["Genres"] = merged_df["genres"].apply(parse_genres)
    # genres_global may also hold genres from earlier calls, so the set of
    # Is_* columns depends on everything processed so far in this process.
    for genre in genres_global:
        merged_df[f"Is_{genre}"] = merged_df["Genres"].apply(lambda x: genre in x)

    # Parse the dates once; month and year are both derived from it.
    release_date = pd.to_datetime(merged_df["release_date"])
    merged_df["release_month"] = release_date.dt.month
    # `series is None` compares the Series object itself and is always False;
    # the per-row test needs isna().
    # NOTE(review): if the pickled model was trained on the old constant
    # column, confirm this feature is meant to vary.
    merged_df["collection"] = merged_df["belongs_to_collection"].isna()
    merged_df["revenue_is_available"] = merged_df["revenue"] != 0
    merged_df["budget is available"] = merged_df["budget"] != 0
    merged_df["year"] = release_date.dt.year.astype(int)

    # Same factor for both money columns: compound inflation up to 2022.
    inflation_factor = YEARLY_INFLATION ** (2022 - merged_df["year"])
    merged_df["budget"] = merged_df["budget"] * inflation_factor
    merged_df["revenue"] = merged_df["revenue"] * inflation_factor

    columns_to_remove = [
        "title",
        "adult",
        "imdb_id",
        "overview",
        "backdrop_path",
        "genres",
        "Genres",
        "belongs_to_collection",
        "homepage",
        "original_language",
        "original_title",
        "poster_path",
        "status",
        "video",
        "spoken_languages",
        "tagline",
        "release_date",
        "directors",
        "writers",
        "cast",
        "id",
    ]
    columns_to_maybe_add_back = ["production_companies", "production_countries"]
    merged_df = merged_df.drop(columns=columns_to_remove + columns_to_maybe_add_back)

    columns_to_scale = [
        "year",
        "budget",
        "popularity",
        "revenue",
        "runtime",
        "vote_average",
        "vote_count",
        "release_month",
        "directors_male",
        "directors_female",
        "writers_male",
        "writers_female",
        "cast_male",
        "cast_female",
        "nb_women",
        "nb_men",
        "area_women",
        "area_men",
    ]
    # NOTE(review): the scaler is fitted on the rows passed in, so the scaled
    # values depend on the batch rather than on fixed statistics.
    scaler = StandardScaler()
    merged_df[columns_to_scale] = scaler.fit_transform(merged_df[columns_to_scale])
    return merged_df
# Requires: imdb_id
def predict_bechdel(df):
    """Preprocess ``df`` and return ``clf.predict`` on the resulting features.

    Before predicting, prints the feature columns, a per-column flag telling
    whether any value is missing, and the column dtypes.
    """
    features = preprocessing(df)
    print(features.columns)
    print(features.isna().any())
    print(features.dtypes)
    return clf.predict(features)
# Script entry point: run the prediction on the module-level sample frame.
# The returned predictions are not used here.
if __name__ == "__main__":
    predict_bechdel(df)