Skip to content

Commit

Permalink
testing in for loop, check type
Browse files Browse the repository at this point in the history
  • Loading branch information
JanRiedelsheimer committed Nov 19, 2024
1 parent 2d18d29 commit caefa95
Showing 1 changed file with 37 additions and 33 deletions.
70 changes: 37 additions & 33 deletions http_submission/sample_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
class HTTPScanpathModel(MySimpleScanpathModel):
def __init__(self, url):
self.url = url
self.log_density_url = url + "/conditional_log_density"
self.type_url = url + "/type"




@property
def log_density_url(self):
return self.url + "/conditional_log_density"

@property
def type_url(self):
return self.url + "/type"

def conditional_log_density(self, stimulus, x_hist, y_hist, t_hist, attributes=None, out=None):

# build request
pil_image = Image.fromarray(stimulus)
image_bytes = BytesIO()
Expand All @@ -48,44 +50,46 @@ def _convert_attribute(attribute):

return np.array(response.json()['log_density'])

def type(self):
response = requests.get(f"{self.type_url}")
return np.array(response.json())
def check_type(self):
response = requests.get(f"{self.type_url}").json()
if not response['type'] == 'ScanpathModel':
raise ValueError(f"invalid Model type: {response['type']}. Expected 'ScanpathModel'")
if not response['version'] in ['v1.0.0']:
raise ValueError(f"invalid Model type: {response['version']}. Expected 'v1.0.0'")


if __name__ == "__main__":
http_model = HTTPScanpathModel("http://localhost:4000")
type = http_model.type()
http_model.check_type()

# for testing
model = MySimpleScanpathModel()

# get MIT1003 dataset
stimuli, fixations = pysaliency.get_mit1003(location='pysaliency_datasets')
# fixation_index = 32185
fixation_index = 2
# density_list = []
# version_list = []
# for fixation_index in range(1000):

# get server response for one stimulus
server_density = http_model.conditional_log_density(
stimulus=stimuli.stimuli[fixations.n[fixation_index]],
x_hist=fixations.x_hist[fixation_index],
y_hist=fixations.y_hist[fixation_index],
t_hist=fixations.t_hist[fixation_index]
)
model_density = model.conditional_log_density(
stimulus=stimuli.stimuli[fixations.n[fixation_index]],
x_hist=fixations.x_hist[fixation_index],
y_hist=fixations.y_hist[fixation_index],
t_hist=fixations.t_hist[fixation_index]

)
# get server type
print(type)
eval_fixations = fixations[fixations.scanpath_history_length > 0]
server_density_list = []
model_density_list = []
for fixation_index in range(10):
# get server response for one stimulus
server_density = http_model.conditional_log_density(
stimulus=stimuli.stimuli[eval_fixations.n[fixation_index]],
x_hist=eval_fixations.x_hist[fixation_index],
y_hist=eval_fixations.y_hist[fixation_index],
t_hist=eval_fixations.t_hist[fixation_index]
)
# get model response
model_density = model.conditional_log_density(
stimulus=stimuli.stimuli[eval_fixations.n[fixation_index]],
x_hist=eval_fixations.x_hist[fixation_index],
y_hist=eval_fixations.y_hist[fixation_index],
t_hist=eval_fixations.t_hist[fixation_index]
)

server_density_list.append(server_density)
model_density_list.append(model_density)

# Testing

test = np.testing.assert_allclose(server_density, model_density)
test = np.testing.assert_allclose(server_density_list, model_density_list)
print(test)

0 comments on commit caefa95

Please sign in to comment.