add support for custom MTTKRP kernel
solomonik committed Dec 8, 2019
1 parent 4feba59 commit e862f7c
Showing 5 changed files with 54 additions and 61 deletions.
31 changes: 14 additions & 17 deletions als.py
@@ -18,33 +18,30 @@ def __init__(self, f1, f2, omega, string, use_MTTKRP):
self.string = string
self.use_MTTKRP = use_MTTKRP

def mul(self, idx, sk):
def MTTKRP_TTTP(self, sk, out):
if self.use_MTTKRP:
if self.string=="U":
out = ctf.tensor([self.omega.shape[0], self.f1.shape[1]])
ctf.MTTKRP(ctf.TTTP(self.omega, [sk, self.f1, self.f2]),[out,self.f1,self.f2],0)
elif self.string=="V":
out = ctf.tensor([self.omega.shape[1], self.f1.shape[1]])
ctf.MTTKRP(ctf.TTTP(self.omega, [self.f1, sk, self.f2]),[self.f1,out,self.f2],1)
elif self.string=="W":
out = ctf.tensor([self.omega.shape[2], self.f1.shape[1]])
ctf.MTTKRP(ctf.TTTP(self.omega, [self.f1, self.f2, sk]),[self.f1,self.f2,out],2)
else:
return ValueError("Invalid string for MTTKRP_mul")
return out.i(idx)
print("Invalid string for implicit MTTKRP_TTTP")
else:
idx = "ir"
if self.string=="U":
return self.f1.i("J"+idx[1]) \
*self.f2.i("K"+idx[1]) \
*ctf.TTTP(self.omega, [sk, self.f1, self.f2]).i(idx[0]+"JK")
out.i(idx) << self.f1.i("J"+idx[1]) \
*self.f2.i("K"+idx[1]) \
*ctf.TTTP(self.omega, [sk, self.f1, self.f2]).i(idx[0]+"JK")
if self.string=="V":
return self.f1.i("I"+idx[1]) \
*self.f2.i("K"+idx[1]) \
*ctf.TTTP(self.omega, [self.f1, sk, self.f2]).i("I"+idx[0]+"K")
out.i(idx) << self.f1.i("I"+idx[1]) \
*self.f2.i("K"+idx[1]) \
*ctf.TTTP(self.omega, [self.f1, sk, self.f2]).i("I"+idx[0]+"K")
if self.string=="W":
return self.f1.i("I"+idx[1]) \
*self.f2.i("J"+idx[1]) \
*ctf.TTTP(self.omega, [self.f1, self.f2, sk]).i("IJ"+idx[0])
out.i(idx) << self.f1.i("I"+idx[1]) \
*self.f2.i("J"+idx[1]) \
*ctf.TTTP(self.omega, [self.f1, self.f2, sk]).i("IJ"+idx[0])

def CG(A,b,x0,r,regParam,I,is_implicit=False):

@@ -53,7 +50,7 @@ def CG(A,b,x0,r,regParam,I,is_implicit=False):

Ax0 = ctf.tensor((I,r))
if is_implicit:
Ax0.i("ir") << A.mul("ir",x0)
A.MTTKRP_TTTP(x0,Ax0)
else:
Ax0.i("ir") << A.i("irl")*x0.i("il")
Ax0 += regParam*x0
@@ -66,7 +63,7 @@ def CG(A,b,x0,r,regParam,I,is_implicit=False):
t_cg_bmvec.start()
t0 = time.time()
if is_implicit:
Ask.i("ir") << A.mul("ir",sk)
A.MTTKRP_TTTP(sk,Ask)
else:
Ask.i("ir") << A.i("irl")*sk.i("il")
t1 = time.time()
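For reference, the implicit matrix-vector product that MTTKRP_TTTP supplies to CG can be written densely as follows. This is a minimal NumPy sketch of the string == "U" case only; the helper name and dense formulation are illustrative, with just the shapes and index pattern taken from als.py above, and CG adds regParam*x on top of this product.

import numpy as np

def mttkrp_tttp_U(omega, f1, f2, sk):
    # TTTP: restrict the rank-r reconstruction built from sk, f1, f2 to the observed entries
    #   T_ijk = omega_ijk * sum_s sk[i,s] * f1[j,s] * f2[k,s]
    T = omega * np.einsum('is,js,ks->ijk', sk, f1, f2)
    # MTTKRP along mode 0:
    #   out_ir = sum_jk T_ijk * f1[j,r] * f2[k,r]
    return np.einsum('ijk,jr,kr->ir', T, f1, f2)
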
4 changes: 2 additions & 2 deletions arg_defs.py
@@ -133,11 +133,11 @@ def add_general_arguments(parser):
metavar='int',
help='whether to use function tensor as test problem (default: 0, i.e. False, use explicit low CP-rank sampled tensor)')
parser.add_argument(
'--use-CCD-TTTP',
'--use-MTTKRP',
type=int,
default=1,
metavar='int',
help='whether to use TTTP for CCD contractions (default: 1, i.e. Yes)')
help='whether to use special MTTKRP kernel (default: 1, i.e. Yes)')
parser.add_argument(
'--tensor-file',
type=str,
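The renamed flag is a plain integer switch. A self-contained sketch of how it is parsed and consumed (parser construction and all other options from arg_defs.py are elided; the parse_args argument below is only an example):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-MTTKRP', type=int, default=1, metavar='int',
                    help='whether to use special MTTKRP kernel (default: 1, i.e. Yes)')
args = parser.parse_args(['--use-MTTKRP', '0'])  # e.g. disable the custom kernel
use_MTTKRP = args.use_MTTKRP  # argparse maps --use-MTTKRP to args.use_MTTKRP
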
32 changes: 19 additions & 13 deletions ccd.py
@@ -28,7 +28,7 @@ def get_objective(T,U,V,W,omega,regParam):
return [objective, RMSE]


def run_CCD(T,U,V,W,omega,regParam,num_iter,time_limit,objective_frequency,use_TTTP=True):
def run_CCD(T,U,V,W,omega,regParam,num_iter,time_limit,objective_frequency,use_MTTKRP=True):
U_vec_list = []
V_vec_list = []
W_vec_list = []
@@ -90,17 +90,19 @@ def run_CCD(T,U,V,W,omega,regParam,num_iter,time_limit,objective_frequency,use_T
print('updating U[:,{}]'.format(f))

t0 = time.time()
if use_TTTP:
if use_MTTKRP:
alphas = ctf.tensor(R.shape[0])
ctf.einsum('ijk -> i', ctf.TTTP(R, [None, V_vec_list[f], W_vec_list[f]]),out=alphas)
#ctf.einsum('ijk -> i', ctf.TTTP(R, [None, V_vec_list[f], W_vec_list[f]]),out=alphas)
ctf.MTTKRP(R, [alphas, V_vec_list[f], W_vec_list[f]], 0)
else:
alphas = ctf.einsum('ijk, j, k -> i', R, V_vec_list[f], W_vec_list[f])

t1 = time.time()

if use_TTTP:
if use_MTTKRP:
betas = ctf.tensor(R.shape[0])
ctf.einsum('ijk -> i', ctf.TTTP(omega, [None, V_vec_list[f]*V_vec_list[f], W_vec_list[f]*W_vec_list[f]]),out=betas)
#ctf.einsum('ijk -> i', ctf.TTTP(omega, [None, V_vec_list[f]*V_vec_list[f], W_vec_list[f]*W_vec_list[f]]),out=betas)
ctf.MTTKRP(omega, [betas, V_vec_list[f]*V_vec_list[f], W_vec_list[f]*W_vec_list[f]], 0)
else:
betas = ctf.einsum('ijk, j, j, k, k -> i', omega, V_vec_list[f], V_vec_list[f], W_vec_list[f], W_vec_list[f])

@@ -117,15 +119,17 @@ def run_CCD(T,U,V,W,omega,regParam,num_iter,time_limit,objective_frequency,use_T
# update V[:,f]
if glob_comm.rank() == 0 and status_prints == True:
print('updating V[:,{}]'.format(f))
if use_TTTP:
if use_MTTKRP:
alphas = ctf.tensor(R.shape[1])
ctf.einsum('ijk -> j', ctf.TTTP(R, [U_vec_list[f], None, W_vec_list[f]]),out=alphas)
#ctf.einsum('ijk -> j', ctf.TTTP(R, [U_vec_list[f], None, W_vec_list[f]]),out=alphas)
ctf.MTTKRP(R, [U_vec_list[f], alphas, W_vec_list[f]], 1)
else:
alphas = ctf.einsum('ijk, i, k -> j', R, U_vec_list[f], W_vec_list[f])

if use_TTTP:
if use_MTTKRP:
betas = ctf.tensor(R.shape[1])
ctf.einsum('ijk -> j', ctf.TTTP(omega, [U_vec_list[f]*U_vec_list[f], None, W_vec_list[f]*W_vec_list[f]]),out=betas)
#ctf.einsum('ijk -> j', ctf.TTTP(omega, [U_vec_list[f]*U_vec_list[f], None, W_vec_list[f]*W_vec_list[f]]),out=betas)
ctf.MTTKRP(omega, [U_vec_list[f]*U_vec_list[f], betas, W_vec_list[f]*W_vec_list[f]], 1)
else:
betas = ctf.einsum('ijk, i, i, k, k -> j', omega, U_vec_list[f], U_vec_list[f], W_vec_list[f], W_vec_list[f])

@@ -135,15 +139,17 @@ def run_CCD(T,U,V,W,omega,regParam,num_iter,time_limit,objective_frequency,use_T

if glob_comm.rank() == 0 and status_prints == True:
print('updating W[:,{}]'.format(f))
if use_TTTP:
if use_MTTKRP:
alphas = ctf.tensor(R.shape[2])
ctf.einsum('ijk -> k', ctf.TTTP(R, [U_vec_list[f], V_vec_list[f], None]),out=alphas)
#ctf.einsum('ijk -> k', ctf.TTTP(R, [U_vec_list[f], V_vec_list[f], None]),out=alphas)
ctf.MTTKRP(R, [U_vec_list[f], W_vec_list[f], alphas], 2)
else:
alphas = ctf.einsum('ijk, i, j -> k', R, U_vec_list[f], V_vec_list[f])

if use_TTTP:
if use_MTTKRP:
betas = ctf.tensor(R.shape[2])
ctf.einsum('ijk -> k', ctf.TTTP(omega, [U_vec_list[f]*U_vec_list[f], V_vec_list[f]*V_vec_list[f], None]),out=betas)
#ctf.einsum('ijk -> k', ctf.TTTP(omega, [U_vec_list[f]*U_vec_list[f], V_vec_list[f]*V_vec_list[f], None]),out=betas)
ctf.MTTKRP(omega, [U_vec_list[f]*U_vec_list[f], V_vec_list[f]*V_vec_list[f], betas], 2)
else:
betas = ctf.einsum('ijk, i, i, j, j -> k', omega, U_vec_list[f], U_vec_list[f], V_vec_list[f], V_vec_list[f])

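The ctf.MTTKRP calls above replace the commented-out einsum-over-TTTP reductions. For the mode-0 (U) column update they compute, in dense NumPy terms, the quantities below. This is a sketch with an illustrative helper name; the rank index is trivial because CCD works on one column at a time, so the "factor matrices" passed to MTTKRP are vectors.

import numpy as np

def ccd_mode0_coeffs(R, omega, v, w):
    # alphas_i = sum_jk R_ijk * v_j * w_k
    alphas = np.einsum('ijk,j,k->i', R, v, w)
    # betas_i  = sum_jk omega_ijk * v_j**2 * w_k**2
    betas = np.einsum('ijk,j,k->i', omega, v * v, w * w)
    return alphas, betas

The mode-1 (V) and mode-2 (W) variants differ only in which index survives the reduction.
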
13 changes: 7 additions & 6 deletions combined_test.py
@@ -111,7 +111,7 @@ def create_function_tensor(I, J, K, sp_frac, use_sp_rep):
sample_frac_SGD = args.sample_frac_SGD
use_func_tsr = args.function_tensor
tensor_file = args.tensor_file
use_CCD_TTTP = args.use_CCD_TTTP
use_MTTKRP = args.use_MTTKRP


if use_func_tsr == True:
@@ -134,6 +134,7 @@ def create_function_tensor(I, J, K, sp_frac, use_sp_rep):
print("Dense tensor shape is",T.shape)

print("Computing tensor completion with CP rank",R)
print("use_MTTKRP set to",use_MTTKRP)

omega = getOmega(T)
U = ctf.random.random((I, R))
@@ -152,7 +153,7 @@ def create_function_tensor(I, J, K, sp_frac, use_sp_rep):
V_copy = ctf.copy(V)
W_copy = ctf.copy(W)

getALS_CG(T,U_copy,V_copy,W_copy,reg_ALS,omega,I,J,K,R,block_size_ALS_imp,numiter_ALS_imp,err_thresh,time_limit,True)
getALS_CG(T,U_copy,V_copy,W_copy,reg_ALS,omega,I,J,K,R,block_size_ALS_imp,numiter_ALS_imp,err_thresh,time_limit,True,use_MTTKRP)

if numiter_ALS_exp > 0:
if ctf.comm().rank() == 0:
@@ -162,18 +163,18 @@ def create_function_tensor(I, J, K, sp_frac, use_sp_rep):
V_copy = ctf.copy(V)
W_copy = ctf.copy(W)

getALS_CG(T,U_copy,V_copy,W_copy,reg_ALS,omega,I,J,K,R,block_size_ALS_exp,numiter_ALS_exp,err_thresh,time_limit,False)
getALS_CG(T,U_copy,V_copy,W_copy,reg_ALS,omega,I,J,K,R,block_size_ALS_exp,numiter_ALS_exp,err_thresh,time_limit,False,use_MTTKRP)


if numiter_CCD > 0:
if ctf.comm().rank() == 0:
print("Performing up to",numiter_CCD,"iterations, or reaching time limit of",time_limit,"seconds of CCD with use of TTTP for contractions set to use_CCD_TTTP =",use_CCD_TTTP)
print("Performing up to",numiter_CCD,"iterations, or reaching time limit of",time_limit,"seconds of CCD")
print("CCD regularization parameter is",reg_CCD)
U_copy = ctf.copy(U)
V_copy = ctf.copy(V)
W_copy = ctf.copy(W)

run_CCD(T,U_copy,V_copy,W_copy,omega,reg_CCD,numiter_CCD,time_limit,objfreq_CCD,use_CCD_TTTP)
run_CCD(T,U_copy,V_copy,W_copy,omega,reg_CCD,numiter_CCD,time_limit,objfreq_CCD,use_MTTKRP)


if numiter_SGD > 0:
@@ -184,7 +185,7 @@ def create_function_tensor(I, J, K, sp_frac, use_sp_rep):
V_copy = ctf.copy(V)
W_copy = ctf.copy(W)

sparse_SGD(T, U_copy, V_copy, W_copy, reg_SGD, omega, I, J, K, R, learn_rate, sample_frac_SGD, numiter_SGD, err_thresh, time_limit, objfreq_SGD)
sparse_SGD(T, U_copy, V_copy, W_copy, reg_SGD, omega, I, J, K, R, learn_rate, sample_frac_SGD, numiter_SGD, err_thresh, time_limit, objfreq_SGD,use_MTTKRP)



35 changes: 12 additions & 23 deletions sgd.py
@@ -11,7 +11,7 @@

INDEX_STRING = "ijklmnopq"

def sparse_update(T, factors, Lambda, sizes, rank, stepSize, sample_rate, times):
def sparse_update(T, factors, Lambda, sizes, rank, stepSize, sample_rate, times, use_MTTKRP):
starting_time = time.time()
t_go = ctf.timer("SGD_getOmega")
t_go.start()
@@ -22,29 +22,30 @@ def sparse_update(T, factors, Lambda, sizes, rank, stepSize, sample_rate, times)
R = ctf.tensor(copy=T) #ctf.tensor(tuple(sizes), sp = True)
times[2] += time.time() - starting_time
for i in range(dimension):
tup_list = [factors[i].i(indexes[i] + "r") for i in range(dimension)]
#R.i(indexes) << T.i(indexes) - omega.i(indexes) * reduce(lambda x, y: x * y, tup_list)
starting_time = time.time()
R.i(indexes) << -1.* ctf.TTTP(omega, factors).i(indexes)
times[3] += time.time() - starting_time
starting_time = time.time()
#H = ctf.tensor(tuple((sizes[:i] + sizes[i + 1:] + [rank])))
times[4] += time.time() - starting_time
starting_time = time.time()
#H.i(indexes[:i] + indexes[i + 1:] + "r") <<
Hterm = reduce(lambda x, y: x * y, tup_list[:i] + tup_list[i + 1:])
times[5] += time.time() - starting_time
starting_time = time.time()
t_ctr = ctf.timer("SGD_main_contraction")
t_ctr.start()
(1- stepSize * 2 * Lambda * sample_rate)*factors[i].i(indexes[i] + "r") << stepSize * Hterm * R.i(indexes)
if use_MTTKRP:
new_fi = (1- stepSize * 2 * Lambda * sample_rate)*factors[i]
ctf.MTTKRP(R, factors, i)
stepSize*factors[i].i("ir") << new_fi.i("ir")
else:
tup_list = [factors[i].i(indexes[i] + "r") for i in range(dimension)]
Hterm = reduce(lambda x, y: x * y, tup_list[:i] + tup_list[i + 1:])
(1- stepSize * 2 * Lambda * sample_rate)*factors[i].i(indexes[i] + "r") << stepSize * Hterm * R.i(indexes)
t_ctr.stop()
times[6] += time.time() - starting_time
if i < dimension - 1:
R = ctf.tensor(copy=T)
#return ctf.vecnorm(R) + (sum([ctf.vecnorm(f) for f in factors])) * Lambda

def sparse_SGD(T, U, V, W, Lambda, omega, I, J, K, r, stepSize, sample_rate, num_iter, errThresh, time_limit, work_cycle):
def sparse_SGD(T, U, V, W, Lambda, omega, I, J, K, r, stepSize, sample_rate, num_iter, errThresh, time_limit, work_cycle, use_MTTKRP):
times = [0 for i in range(7)]

iteration_count = 0
@@ -58,11 +59,9 @@ def sparse_SGD(T, U, V, W, Lambda, omega, I, J, K, r, stepSize, sample_rate, num
starting_time = time.time()
dtime = 0
R.i("ijk") << T.i("ijk") - ctf.TTTP(omega, [U, V, W]).i("ijk")
# R.i("ijk") << T.i("ijk") - U.i("iu") * V.i("ju") * W.i("ku") * omega.i("ijk")
curr_err_norm = ctf.vecnorm(R) + (ctf.vecnorm(U) + ctf.vecnorm(V) + ctf.vecnorm(W)) * Lambda
times[0] += time.time() - starting_time
norm = [curr_err_norm]
#work_cycle = 1 #int(1.0 / sample_rate)
step = stepSize * 0.5
t_obj_calc = 0.

@@ -73,16 +72,15 @@ def sparse_SGD(T, U, V, W, Lambda, omega, I, J, K, r, stepSize, sample_rate, num
sampled_T.sample(sample_rate)
times[1] += time.time() - starting_time

sparse_update(sampled_T, [U, V, W], Lambda, [I, J, K], r, stepSize * 0.5 + step, sample_rate, times)
step *= 0.99
sparse_update(sampled_T, [U, V, W], Lambda, [I, J, K], r, stepSize * 0.5 + step, sample_rate, times, use_MTTKRP)
#step *= 0.99
sampled_T.set_zero()

if iteration_count % work_cycle == 0:
duration = time.time() - start_time - t_obj_calc
t_b_obj = time.time()
total_count += 1
R.set_zero()
#R.i("ijk") << T.i("ijk") - U.i("iu") * V.i("ju") * W.i("ku") * omega.i("ijk")
R.i("ijk") << T.i("ijk") - ctf.TTTP(omega, [U, V, W]).i("ijk")
diff_norm = ctf.vecnorm(R)
RMSE = diff_norm/(nnz_tot**.5)
@@ -91,11 +89,6 @@ def sparse_SGD(T, U, V, W, Lambda, omega, I, J, K, r, stepSize, sample_rate, num
print('Objective after',duration,'seconds (',iteration_count,'iterations) is: {}'.format(next_err_norm))
print('RMSE after',duration,'seconds (',iteration_count,'iterations) is: {}'.format(RMSE))
t_obj_calc += time.time() - t_b_obj
#print(curr_err_norm, next_err_norm, diff_norm)
#if ctf.comm().rank() == 0:
# x = time.time() - start_time
# print(diff_norm, x, total_count, x/total_count)
# # print(times)

if abs(curr_err_norm - next_err_norm) < errThresh:
break
@@ -107,10 +100,6 @@ def sparse_SGD(T, U, V, W, Lambda, omega, I, J, K, r, stepSize, sample_rate, num
if ctf.comm().rank() == 0:
print('SGD amortized seconds per sweep: {}'.format(duration/(iteration_count*sample_rate)))
print("Time/SGD iteration: {}".format(duration/iteration_count))
#curr_err_norm = ctf.vecnorm(R) + (ctf.vecnorm(U) + ctf.vecnorm(V) + ctf.vecnorm(W)) * Lambda
#R.i("ijk") << T.i("ijk") - ctf.TTTP(omega, [U, V, W]).i("ijk")
#if ctf.comm().rank() == 0:
# print('Objective after',duration,'seconds is: {}'.format(curr_err_norm))
return norm

def getOmega(T):
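In the use_MTTKRP branch of sparse_update, ctf.MTTKRP overwrites factors[i] with the MTTKRP of the sampled residual, and the weighted assignment folds that back into the shrunk factor (assuming CTF's beta*A.i(..) << B.i(..) accumulates as A = beta*A + B). A dense NumPy sketch of the resulting mode-0 update, with an illustrative helper name and shapes and constants following sparse_update above:

import numpy as np

def sgd_update_mode0(R, U, V, W, Lambda, stepSize, sample_rate):
    # MTTKRP of the sampled residual along mode 0:
    #   G_ir = sum_jk R_ijk * V[j,r] * W[k,r]
    G = np.einsum('ijk,jr,kr->ir', R, V, W)
    # shrink the current factor, then take the step on the sampled gradient
    return (1 - stepSize * 2 * Lambda * sample_rate) * U + stepSize * G
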
