-
Notifications
You must be signed in to change notification settings - Fork 0
/
split_users.py
48 lines (38 loc) · 1.72 KB
/
split_users.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 7 12:38:44 2022
@author: savvina
"""
#%%
import pandas as pd
import numpy as np
import random as rd
# Seed both the NumPy and the stdlib global random generators with the same
# value so that any random draws made later in the script are reproducible.
my_seed = 0
np.random.seed(my_seed)
rd.seed(my_seed)
#%%
def sort_user_dist(user_dist, pop_count, user_hist, pop_fraq, pop_item_fraq, by="pop_fraq"):
    """Collect per-user statistics into one table ordered by a chosen column.

    ``user_dist`` is first put in index order and wrapped in a one-column
    DataFrame whose column is renamed to ``"count"``. The remaining arguments
    are then attached as the columns ``pop_count``, ``user_hist``,
    ``pop_fraq`` and ``pop_item_fraq``, in that order. Finally the rows are
    sorted ascending by the column named ``by``.

    Note: pandas aligns Series arguments on the index, while plain lists or
    arrays are attached positionally. Those must therefore already follow
    the sorted index of ``user_dist``. Presumably the index holds user ids;
    confirm against callers.

    Returns
    -------
    pandas.DataFrame
        Columns ``count``, ``pop_count``, ``user_hist``, ``pop_fraq``,
        ``pop_item_fraq``, with rows ordered by ``by``.
    """
    table = pd.DataFrame(data=user_dist.sort_index())
    table.columns = ["count"]
    # assign() attaches the columns in keyword order, exactly like
    # successive ``table[name] = value`` assignments would.
    table = table.assign(
        pop_count=pop_count,
        user_hist=user_hist,
        pop_fraq=pop_fraq,
        pop_item_fraq=pop_item_fraq,
    )
    return table.sort_values(by=[by])
def split(user_dist_sorted, top_fraction):
    """Split an ordered table of users into low / medium / high groups.

    With ``n`` rows, the groups are:

    - low: the first ``int(top_fraction * n)`` rows;
    - high: the rows from ``int((1 - top_fraction) * n)`` onward;
    - medium: the rows in between.

    Both cut points are truncated toward zero. As a result the high group
    can be one row larger than the low group, e.g. when
    ``top_fraction * n`` is not a whole number. These are the same cut
    points as before and are kept unchanged.

    Parameters
    ----------
    user_dist_sorted : pandas.DataFrame / Series, or array-like
        Rows already ordered by the grouping criterion.
    top_fraction : float
        Fraction of rows put in each of the low and high groups. It must lie
        in ``[0, 0.5]`` so that the groups do not overlap.

    Returns
    -------
    tuple
        ``(low, med, high)``: contiguous, non-overlapping slices of the input.

    Raises
    ------
    ValueError
        If ``top_fraction`` is outside ``[0, 0.5]`` (or is NaN).
    """
    if not 0 <= top_fraction <= 0.5:
        raise ValueError(
            f"top_fraction must be between 0 and 0.5, got {top_fraction!r}"
        )
    n_rows = len(user_dist_sorted)
    low_end = int(top_fraction * n_rows)
    high_start = int((1 - top_fraction) * n_rows)
    if hasattr(user_dist_sorted, "iloc"):
        # Slice pandas objects by position. np.split on a DataFrame goes
        # through the deprecated DataFrame.swapaxes. Its internal obj[a:b]
        # slicing is also label-based, not positional, for a float index.
        rows = user_dist_sorted.iloc
        return rows[:low_end], rows[low_end:high_start], rows[high_start:]
    # Non-pandas inputs (e.g. ndarrays) keep the original np.split behaviour.
    low, med, high = np.split(user_dist_sorted, [low_end, high_start])
    return low, med, high
def read(low_user_file, medium_user_file, high_user_file, *, mainstreaminess="M_global_R_APC"):
    """Load the three user groups from CSV and print summary statistics.

    Each file is read with ``pandas.read_csv`` and indexed by its ``user_id``
    column. The function then prints to stdout:

    - the total number of users across the three groups;
    - for each group, the mean of the ``mainstreaminess`` column.

    Parameters
    ----------
    low_user_file, medium_user_file, high_user_file
        Anything ``pandas.read_csv`` accepts (path or buffer) for the low,
        medium and high groups. Each must contain a ``user_id`` column and
        the ``mainstreaminess`` column.
    mainstreaminess : str, keyword-only
        Name of the column whose per-group mean is reported. Defaults to
        ``"M_global_R_APC"``, the column that was previously hard-coded.

    Returns
    -------
    tuple
        ``(no_users, low_users, medium_users, high_users)``, where the three
        frames are indexed by ``user_id``.
    """
    low_users, medium_users, high_users = (
        pd.read_csv(path).set_index("user_id")
        for path in (low_user_file, medium_user_file, high_user_file)
    )
    no_users = len(low_users) + len(medium_users) + len(high_users)
    print(f"No. of users: {no_users!s}")
    # Labels must stay "low"/"med"/"high": they are part of the printed text.
    for label, users in (("low", low_users), ("med", medium_users), ("high", high_users)):
        print(f"Average mainstreaminess per user for {label}: {users[mainstreaminess].mean()!s}")
    return no_users, low_users, medium_users, high_users