Remove unnecessary type dispatches from Variable::Impl ctor (pytorch#13630)

Summary:
This should improve the performance of wrapping a tensor in a Variable, since the Variable::Impl constructor no longer dispatches through the tensor's Type object to recover its metadata (a sketch of the affected call path follows the commit metadata below).
Pull Request resolved: pytorch#13630

Reviewed By: ezyang

Differential Revision: D12944960

Pulled By: zou3519

fbshipit-source-id: 89fa78a563e46a747d851a90ffd1b5cf3cd2d0d7
zou3519 authored and facebook-github-bot committed Nov 7, 2018
1 parent 2ae8e46 commit e70321e
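
For context, a minimal sketch of the call path this commit speeds up: wrapping an existing at::Tensor in a Variable. It assumes the 2018-era C++ autograd API (torch::autograd::make_variable); exact headers and signatures may differ.

#include <ATen/ATen.h>
#include <torch/csrc/autograd/variable.h>

int main() {
  at::Tensor t = at::ones({2, 3});
  // make_variable constructs a Variable::Impl around the tensor. Before this
  // change, the Impl constructor called data.type() several times, each a
  // dispatch through the Type registry, just to recover the type id, dtype,
  // and allocator.
  auto v = torch::autograd::make_variable(t, /*requires_grad=*/false);
  return 0;
}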
Showing 2 changed files with 2 additions and 2 deletions.
torch/csrc/autograd/variable.cpp (2 changes: 1 addition & 1 deletion)
@@ -22,7 +22,7 @@
 namespace torch {
 namespace autograd {
 Variable::Impl::Impl(at::Tensor data, bool requires_grad, Edge gradient_edge)
-    : TensorImpl(data.type().type_id(), data.type().typeMeta(), data.type().allocator(), /* is variable */ true),
+    : TensorImpl(data.type_id(), data.dtype(), /*allocator=*/nullptr, /* is variable */ true),
       data_(std::move(data)),
       grad_fn_(std::move(gradient_edge.function)),
       requires_grad_(false),
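
To illustrate the change above: Tensor::type() resolves a Type object on every call, while type_id() and dtype() are plain accessors on the tensor's TensorImpl. A rough comparison, using only the calls that appear in the diff (2018-era ATen API):

#include <ATen/ATen.h>

void inspect(const at::Tensor& data) {
  // Old path used by the constructor: each call goes through the Type object.
  auto id_via_type   = data.type().type_id();
  auto meta_via_type = data.type().typeMeta();

  // New path: the same metadata read directly off the tensor.
  auto id_direct   = data.type_id();
  auto meta_direct = data.dtype();

  (void)id_via_type; (void)meta_via_type;
  (void)id_direct;   (void)meta_direct;
}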
torch/csrc/autograd/variable.h (2 changes: 1 addition & 1 deletion)
@@ -319,7 +319,7 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
   /// variables.
   void set_requires_grad(bool requires_grad) override {
     AT_CHECK(
-        !requires_grad || at::isFloatingType(type().scalarType()),
+        !requires_grad || at::isFloatingType(at::typeMetaToScalarType(dtype())),
         "Only Tensors of floating point dtype can require gradients");
     requires_grad_ = requires_grad;
   }
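
The AT_CHECK above is what rejects requires_grad on non-floating-point data. A hedged sketch of the observable behavior, again assuming the 2018-era autograd API (make_variable plus Variable::set_requires_grad); exact signatures may differ.

#include <ATen/ATen.h>
#include <torch/csrc/autograd/variable.h>
#include <iostream>

int main() {
  // Floating-point data: enabling gradients passes the check.
  auto v_float = torch::autograd::make_variable(at::ones({2, 2}));
  v_float.set_requires_grad(true);

  // Integral data: requesting gradients should trip the check and throw.
  auto v_long = torch::autograd::make_variable(at::ones({2, 2}, at::kLong));
  try {
    v_long.set_requires_grad(true);
  } catch (const std::exception& e) {
    // Expected message: "Only Tensors of floating point dtype can require gradients"
    std::cout << e.what() << std::endl;
  }
  return 0;
}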
