diff --git a/lfw_eval.py b/lfw_eval.py index 16d70bc..def0b4b 100755 --- a/lfw_eval.py +++ b/lfw_eval.py @@ -85,36 +85,38 @@ def find_best_threshold(thresholds, predicts): landmark[l[0]] = [int(k) for k in l[1:]] with open('data/pairs.txt') as f: - pairs_lines = f.readlines()[1:] - -for i in range(6000): - p = pairs_lines[i].replace('\n','').split('\t') - - if 3==len(p): - sameflag = 1 - name1 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[1])) - name2 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[2])) - if 4==len(p): - sameflag = 0 - name1 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[1])) - name2 = p[2]+'/'+p[2]+'_'+'{:04}.jpg'.format(int(p[3])) - - img1 = alignment(cv2.imdecode(np.frombuffer(zfile.read(name1),np.uint8),1),landmark[name1]) - img2 = alignment(cv2.imdecode(np.frombuffer(zfile.read(name2),np.uint8),1),landmark[name2]) - - imglist = [img1,cv2.flip(img1,1),img2,cv2.flip(img2,1)] - for i in range(len(imglist)): - imglist[i] = imglist[i].transpose(2, 0, 1).reshape((1,3,112,96)) - imglist[i] = (imglist[i]-127.5)/128.0 - - img = np.vstack(imglist) - img = Variable(torch.from_numpy(img).float(),volatile=True).cuda() - output = net(img) - f = output.data - f1,f2 = f[0],f[2] - cosdistance = f1.dot(f2)/(f1.norm()*f2.norm()+1e-5) - predicts.append('{}\t{}\t{}\t{}\n'.format(name1,name2,cosdistance,sameflag)) - + _ = next(f, None) # skip header + + for line_no, line in enumerate(f, start=1): + p = line.replace('\n','').split('\t') + + if 3==len(p): + sameflag = 1 + name1 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[1])) + name2 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[2])) + if 4==len(p): + sameflag = 0 + name1 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[1])) + name2 = p[2]+'/'+p[2]+'_'+'{:04}.jpg'.format(int(p[3])) + + img1 = alignment(cv2.imdecode(np.frombuffer(zfile.read(name1),np.uint8),1),landmark[name1]) + img2 = alignment(cv2.imdecode(np.frombuffer(zfile.read(name2),np.uint8),1),landmark[name2]) + + imglist = [img1,cv2.flip(img1,1),img2,cv2.flip(img2,1)] + for i, image in enumerate(imglist): + image = image.transpose(2, 0, 1).reshape((1,3,112,96)) + imglist[i] = (image - 127.5) / 128.0 + + img = np.vstack(imglist) + img = Variable(torch.from_numpy(img).float(),volatile=True).cuda() + output = net(img) + f = output.data + f1,f2 = f[0],f[2] + cosdistance = f1.dot(f2)/(f1.norm()*f2.norm()+1e-5) + predicts.append('{}\t{}\t{}\t{}\n'.format(name1,name2,cosdistance,sameflag)) + + if line_no >= 6000: # break as soon as 6000 lines have been processed. + break accuracy = [] thd = []