Skip to content

Commit

Permalink
Fix tensor shapes for elementwise binary operations with broadcasting (
Browse files Browse the repository at this point in the history
…#1234)

* Fix shapes in keras

* remove extra lines

* Add python<3.12 requirement to fix CI errors

* Add python<=3.11 requirement to fix CI errors

* Tweak around requirement to fix CI errors

* Restore flexflow.yml

* Restore pytorch-gpu.yml

---------

Co-authored-by: Zhihao Jia <[email protected]>
  • Loading branch information
soumyac1999 and jiazhihao authored Dec 1, 2023
1 parent 457b5f2 commit 5501cf8
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
33 changes: 30 additions & 3 deletions python/flexflow/keras/layers/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,16 @@ def __init__(self, **kwargs):
def _calculate_inout_shape(self, input_tensors):
assert len(input_tensors) == 2, "check input_tensors"
self.input_shape = input_tensors[0].batch_shape
self.output_shape = input_tensors[0].batch_shape
self.output_shape = list(input_tensors[0].batch_shape)
for i, d in enumerate(input_tensors[1].batch_shape):
if self.output_shape[i] != d:
if self.output_shape[i] == 1 or d == 1:
self.output_shape[i] *= d
else:
raise AssertionError(
f"Tensor with shape {input_tensors[0].batch_shape} and "
f"{input_tensors[1].batch_shape} cannot be added")
self.output_shape = tuple(self.output_shape)
fflogger.debug("add output %s" %( str(self.output_shape)))

def subtract(input_tensors):
Expand All @@ -114,7 +123,16 @@ def __init__(self, **kwargs):
def _calculate_inout_shape(self, input_tensors):
assert len(input_tensors) == 2, "check input_tensors"
self.input_shape = input_tensors[0].batch_shape
self.output_shape = input_tensors[0].batch_shape
self.output_shape = list(input_tensors[0].batch_shape)
for i, d in enumerate(input_tensors[1].batch_shape):
if self.output_shape[i] != d:
if self.output_shape[i] == 1 or d == 1:
self.output_shape[i] *= d
else:
raise AssertionError(
f"Tensor with shape {input_tensors[0].batch_shape} and "
f"{input_tensors[1].batch_shape} cannot be subtracted")
self.output_shape = tuple(self.output_shape)
fflogger.debug("subtract output %s" %( str(self.output_shape)))

def multiply(input_tensors):
Expand All @@ -127,7 +145,16 @@ def __init__(self, **kwargs):
def _calculate_inout_shape(self, input_tensors):
assert len(input_tensors) == 2, "check input_tensors"
self.input_shape = input_tensors[0].batch_shape
self.output_shape = input_tensors[0].batch_shape
self.output_shape = list(input_tensors[0].batch_shape)
for i, d in enumerate(input_tensors[1].batch_shape):
if self.output_shape[i] != d:
if self.output_shape[i] == 1 or d == 1:
self.output_shape[i] *= d
else:
raise AssertionError(
f"Tensor with shape {input_tensors[0].batch_shape} and "
f"{input_tensors[1].batch_shape} cannot be multiplied")
self.output_shape = tuple(self.output_shape)
fflogger.debug("multiply output %s" %( str(self.output_shape)))

class Maximum(_Merge):
Expand Down
15 changes: 14 additions & 1 deletion src/ops/element_binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,21 @@ Tensor FFModel::binary(OperatorType op,
}
// Assert type match after broadcast
assert(ele->inputs[0]->data_type == ele->inputs[1]->data_type);

int numdim = in1->num_dims;
int dims[MAX_TENSOR_DIM];
for (int i = 0; i < numdim; i++) {
if (in1->dims[i] == 1) {
dims[i] = in2->dims[i];
} else if (in2->dims[i] == 1) {
dims[i] = in1->dims[i];
} else {
dims[i] = in1->dims[i];
}
}

ele->outputs[0] = create_tensor_legion_ordering(
in1->num_dims, in1->dims, ele->data_type, ele, 0, true /*create_grad*/);
in1->num_dims, dims, ele->data_type, ele, 0, true /*create_grad*/);
ele->add_int_property("inplace_a", inplace_a);
layers.push_back(ele);
return ele->outputs[0];
Expand Down

0 comments on commit 5501cf8

Please sign in to comment.