Skip to content

Commit

Permalink
[SVD] Return np.ndarray when output_type="np" (huggingface#6507)
Browse files Browse the repository at this point in the history
[SVD] Fix output_type="np"
  • Loading branch information
yondonfu authored Jan 16, 2024
1 parent 181280b commit 8842bca
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):

outputs.append(batch_output)

if output_type == "np":
return np.stack(outputs)

return outputs


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,23 @@ def test_inference_batch_single_identical(
def test_inference_batch_consistent(self):
pass

def test_np_output_type(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()

pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
inputs["output_type"] = "np"
output = pipe(**inputs).frames
self.assertTrue(isinstance(output, np.ndarray))
self.assertEqual(len(output.shape), 5)

def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
Expand Down

0 comments on commit 8842bca

Please sign in to comment.