Skip to content

Commit

Permalink
Fix pooling when the operator is None and the input is 1d (#343)
Browse files Browse the repository at this point in the history
  • Loading branch information
ermanok authored Jun 17, 2024
1 parent 0f3dd3a commit 792d51b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion izer/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ class MakefileMapping(MutableMapping):
'SRCS +=', etc. is done "on the fly" through this mapping object.
The key:value rules are as follows:
* The key is lowercase, and any '.' is replaced is '_'. (revelant for JSON parsing)
* The key is lowercase, and any '.' is replaced is '_'. (relevant for JSON parsing)
* The value is a tuple with 2 items:
* index 0: The template string
Expand Down
6 changes: 4 additions & 2 deletions izer/izer.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,8 @@ def main():
f'{pixelcount}) does not match the '
f'sum of all pixels ({pixels}) in the layer\'s `in_sequences`.')
input_dim[ll] = conf_input_dim[ll]
if operator[ll] != op.CONV1D:

if operator[ll] != op.CONV1D and (operator[ll] != op.NONE or input_dim[ll][1] != 1):
if pool_stride[ll][0] != pool_stride[ll][1]:
eprint(f'{layer_pfx(ll)}{op.string(operator[ll])} does not support '
f'non-square pooling stride (currently set to '
Expand All @@ -512,6 +513,7 @@ def main():
(input_dim[ll][1] + pool_stride[ll][1] - pool[ll][1]
- pool_dilation[ll][1] + 1) // pool_stride[ll][1]]
else:
pool[ll][1] = 1
pooled_size = [(input_dim[ll][0] + pool_stride[ll][0] - pool[ll][0]
- pool_dilation[ll][0] + 1) // pool_stride[ll][0],
1]
Expand All @@ -521,7 +523,7 @@ def main():
eprint(f'{layer_pfx(ll)}Pooling or zero-padding results in a zero data '
f'dimension (input {input_dim[ll]}, result {pooled_dim[ll]}).')

if operator[ll] != op.CONV1D:
if operator[ll] != op.CONV1D and (operator[ll] != op.NONE or input_dim[ll][1] != 1):
if stride[ll][0] != stride[ll][1]:
eprint(f'{layer_pfx(ll)}{op.string(operator[ll])} does not support '
f'non-square stride (currently set to {stride[ll][0]}x{stride[ll][1]}).')
Expand Down

0 comments on commit 792d51b

Please sign in to comment.