codemod tensor.type().is_cuda(), tensor.type().is_sparse() (pytorch#13590)

Summary:
Follow-up to pytorch#12841

Changed these call sites so they no longer require type dispatch:
tensor.type().is_cuda() -> tensor.is_cuda()
tensor.type().is_sparse() -> tensor.is_sparse()
isVariableType(tensor.type()) -> tensor.is_variable()

This probably does not affect performance much in most cases, but it is nice to have.
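As a rough sketch (not part of this diff), the rewritten call-site pattern looks like the following; the helper device_of is made up for illustration:

    #include <ATen/ATen.h>

    // Before this codemod a call site would ask the dispatch Type object:
    //   if (t.type().is_cuda())   { ... }
    //   if (t.type().is_sparse()) { ... }
    // Afterwards the property is read directly off the tensor:
    int64_t device_of(const at::Tensor& t) {
      if (t.is_cuda()) {        // no round trip through t.type()
        return t.get_device();  // CUDA device index
      }
      return -1;                // CPU (and sparse CPU) tensors report -1
    }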
Pull Request resolved: pytorch#13590

Reviewed By: ezyang

Differential Revision: D12929301

Pulled By: zou3519

fbshipit-source-id: 8ac5c6200c579dd7a44fb4ee58fc9bb170feb1d7
zou3519 authored and facebook-github-bot committed Nov 7, 2018
1 parent e70321e commit e60a7c2
Showing 23 changed files with 44 additions and 44 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/DLConvertor.cpp
@@ -152,7 +152,7 @@ DLManagedTensor* toDLPack(const Tensor& src) {
atDLMTensor->tensor.deleter = &deleter;
atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
int64_t device_id = 0;
if (src.type().is_cuda()) {
if (src.is_cuda()) {
device_id = src.get_device();
}
atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src.type(), device_id);
2 changes: 1 addition & 1 deletion aten/src/ATen/function_wrapper.py
@@ -180,7 +180,7 @@ def TypedDict(name, attrs, total=True): # type: ignore
""")

SPARSE_CHECK = CodeTemplate("""\
if(${check_name}.type().is_sparse()) {
if(${check_name}.is_sparse()) {
return static_cast<const TypeExtendedInterface*>(this)->${api_name}(${sparse_actuals});
}""")

22 changes: 11 additions & 11 deletions aten/src/ATen/native/Convolution.cpp
@@ -108,7 +108,7 @@ auto ConvParams::use_cudnn(const at::Tensor& input) const -> bool {
if (!detail::getCUDAHooks().compiledWithCuDNN()) {
return false;
}
if (!input.type().is_cuda() || !cudnn_enabled) {
if (!input.is_cuda() || !cudnn_enabled) {
return false;
}
if (deterministic && is_dilated()) {
Expand All @@ -125,7 +125,7 @@ auto ConvParams::use_miopen(const at::Tensor& input) const -> bool {

return ((input.type().scalarType() == at::kFloat) || (input.type().scalarType() == at::kHalf))
&& detail::getCUDAHooks().compiledWithMIOpen()
&& input.type().is_cuda()
&& input.is_cuda()
&& input.dim() <= MIOPEN_DIM_MAX
&& !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1
&& !transposed
@@ -150,7 +150,7 @@ auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool {
// a depthwise multiplier)
auto ConvParams::is_depthwise(
const at::Tensor& input, const at::Tensor& weight) const -> bool {
return input.type().is_cuda() &&
return input.is_cuda() &&
!transposed &&
input.ndimension() == 4 &&
input.size(1) == groups &&
@@ -450,7 +450,7 @@ at::Tensor _convolution_nogroup(
input, weight, kernel_size, bias,
stride, padding);
}
} else if (dim == 5 && (input.type().is_cuda() || dilated)) {
} else if (dim == 5 && (input.is_cuda() || dilated)) {
return at::thnn_conv_dilated3d(
input, weight, kernel_size, bias,
stride, padding, dilation);
@@ -498,14 +498,14 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
// Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb
Tensor ggO;
if (ggI.defined()) {
if (weight.type().is_cuda()) {
if (weight.is_cuda()) {
weight = weight.contiguous();
}
ggO = at::_convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled);
}

if (ggW.defined()) {
if (ggW.type().is_cuda()) {
if (ggW.is_cuda()) {
ggW = ggW.contiguous();
}
auto ggW_term = at::_convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled);
@@ -553,7 +553,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
Tensor gWt;
// Compute conv
if (groups == 1) {
if (gOt.type().is_cuda()) {
if (gOt.is_cuda()) {
gOt = gOt.contiguous();
}

@@ -569,7 +569,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
for (int g = 0; g < groups; ++g) {
auto ggIt_g = subvariable(ggIt, 0, groups, g);
auto gOt_g = subvariable(gOt, 0, groups, g);
if (gOt_g.type().is_cuda()) {
if (gOt_g.is_cuda()) {
gOt_g = gOt_g.contiguous();
}

@@ -609,7 +609,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
gi_conv_params.transposed = !params.transposed;

if (params.transposed) {
if (gO.type().is_cuda()) {
if (gO.is_cuda()) {
gO = gO.contiguous();
}
gI = at::_convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled);
@@ -662,7 +662,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(

Tensor gIt;
if (params.groups == 1) {
if (gOt.type().is_cuda()) {
if (gOt.is_cuda()) {
gOt = gOt.contiguous();
}

@@ -672,7 +672,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
for (int g = 0; g < groups; ++g) {
auto ggWt_g = subvariable(ggWt, 1, groups, g);
auto gOt_g = subvariable(gOt, 0, groups, g);
if (gOt_g.type().is_cuda()) {
if (gOt_g.is_cuda()) {
gOt_g = gOt_g.contiguous();
}

4 changes: 2 additions & 2 deletions aten/src/ATen/native/Normalization.cpp
@@ -243,7 +243,7 @@ Tensor batch_norm(
}

bool use_cudnn = false;
use_cudnn = (input.type().is_cuda()
use_cudnn = (input.is_cuda()
&& (input.type().scalarType() != at::kHalf
|| weight.type().scalarType() == at::kFloat)
&& weight.defined() && bias.defined()
@@ -262,7 +262,7 @@ Tensor batch_norm(
training, momentum, eps));
}

bool use_miopen = (input.type().is_cuda()
bool use_miopen = (input.is_cuda()
&& input.dim() <= MIOPEN_DIM_MAX
&& input.type().scalarType() != at::kDouble
&& (input.type().scalarType() == weight.type().scalarType())
2 changes: 1 addition & 1 deletion aten/src/ATen/native/ReduceOps.cpp
@@ -431,7 +431,7 @@ Tensor& norm_out(Tensor &result, const Tensor &self, Scalar p, int64_t dim, bool
}

Tensor _norm(const Tensor &self, Scalar p) {
if (self.type().is_sparse()) {
if (self.is_sparse()) {
return at::native_norm(self, p);
} else {
AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
4 changes: 2 additions & 2 deletions aten/src/ATen/native/TensorFactories.cpp
@@ -151,7 +151,7 @@ Tensor empty_like(const Tensor& self) {
}

Tensor empty_like(const Tensor& self, const TensorOptions& options) {
if (options.layout() == kSparse && self.type().is_sparse()) {
if (options.layout() == kSparse && self.is_sparse()) {
auto res = at::empty({0}, options); // to be resized
res.sparse_resize_and_clear_(self.sizes(), self.sparse_dim(), self.dense_dim());
return res;
@@ -523,7 +523,7 @@ Tensor zeros_like(const Tensor& self) {
}

Tensor zeros_like(const Tensor& self, const TensorOptions& options) {
if (options.layout() == kSparse && self.type().is_sparse()) {
if (options.layout() == kSparse && self.is_sparse()) {
auto res = at::empty({0}, options); // to be resized
res.sparse_resize_and_clear_(self.sizes(), self.sparse_dim(), self.dense_dim());
return res;
2 changes: 1 addition & 1 deletion aten/src/ATen/native/TensorShape.cpp
@@ -289,7 +289,7 @@ Tensor repeat(const Tensor& self, IntList repeats) {
}

Tensor reshape(const Tensor& self, IntList proposed_shape) {
if (self.type().is_sparse()) {
if (self.is_sparse()) {
AT_ERROR("reshape is not implemented for sparse tensors");
}
auto shape = infer_size(proposed_shape, self.numel());
4 changes: 2 additions & 2 deletions aten/src/ATen/native/TypeProperties.cpp
@@ -6,7 +6,7 @@
namespace at { namespace native {

bool is_cuda(const Tensor& self) {
return self.type().is_cuda();
return self.is_cuda();
}

bool is_distributed(const Tensor& self) {
@@ -31,7 +31,7 @@ bool is_signed(const Tensor &self) {
}

bool is_sparse(const Tensor& self) {
return self.type().is_sparse();
return self.is_sparse();
}

Tensor type_as(const Tensor& self, const Tensor& other) {
2 changes: 1 addition & 1 deletion aten/src/ATen/native/WeightNorm.cpp
@@ -50,7 +50,7 @@ Tensor _weight_norm
auto v = v_in.contiguous();
auto g = g_in.contiguous();

bool can_use_fused = v.type().is_cuda() && (dim == 0 || dim == v.dim() - 1);
bool can_use_fused = v.is_cuda() && (dim == 0 || dim == v.dim() - 1);

if (can_use_fused) {
// weight_norm does not have a derivative defined for it, so this will route back through
4 changes: 2 additions & 2 deletions torch/csrc/autograd/VariableTypeManual.cpp
@@ -182,7 +182,7 @@ const Variable & VariableType::checked_cast_variable(const Tensor & t, const cha
if (!t.defined()) {
AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor for argument #", pos, " '", name, "'");
}
if (!isVariableType(t.type())) {
if (!t.is_variable()) {
AT_ERROR("Expected object of type Variable but found type ", t.type().toString(), " for argument #", pos, " '", name, "'");
}
return as_variable_ref(t);
@@ -192,7 +192,7 @@ Variable & VariableType::checked_cast_variable(Tensor & t, const char * name, in
if (!t.defined()) {
AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor for argument #", pos, " '", name, "'");
}
if (!isVariableType(t.type())) {
if (!t.is_variable()) {
AT_ERROR("Expected object of type Variable but found type ", t.type().toString(), " for argument #", pos, " '", name, "'");
}
return as_variable_ref(t);
4 changes: 2 additions & 2 deletions torch/csrc/autograd/functions/accumulate_grad.cpp
@@ -42,7 +42,7 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
if (!grad.defined()) {
// under following condition, we can avoid clone()
if (!GradMode::is_enabled()
&& !new_grad.type().is_sparse()
&& !new_grad.is_sparse()
&& new_grad.is_contiguous()
&& new_grad.use_count() == 1) {
// first check it is in first-order grad only mode
@@ -60,7 +60,7 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
// the users. Thanks to this case we can avoid changing the grad tensor,
// a thing never promised and documented, but used in some hacks seen
// on the internet.
if (grad_variable.type().is_sparse() && !new_grad.type().is_sparse()) {
if (grad_variable.is_sparse() && !new_grad.is_sparse()) {
grad_variable.data() = new_grad.data() + grad_variable.data();
} else {
grad_variable.data() += new_grad.data();
4 changes: 2 additions & 2 deletions torch/csrc/autograd/input_buffer.cpp
@@ -22,7 +22,7 @@ void InputBuffer::add(size_t pos, Variable var) {
} else {
at::DeviceGuard device_guard(var);
// ATen doesn't route sparse additions correctly...
if (old_var.type().is_sparse()) {
if (old_var.is_sparse()) {
buffer[pos] = var + old_var;
} else {
buffer[pos] = old_var + var;
@@ -32,7 +32,7 @@

auto InputBuffer::device() const -> int {
for (auto& var : buffer) {
if (var.defined() && var.type().is_cuda()) {
if (var.defined() && var.is_cuda()) {
return var.get_device();
}
}
4 changes: 2 additions & 2 deletions torch/csrc/autograd/python_hook.cpp
@@ -169,11 +169,11 @@ static void check_single_result(PyObject* _original, PyObject* _result, PyObject
throw std::runtime_error(ss.str());
}

if (original.type().is_cuda() != result.type().is_cuda()) {
if (original.is_cuda() != result.is_cuda()) {
std::stringstream ss;
auto name = hook_name(hook);
ss << "hook '" << name << "' has changed the type of value";
if (original.type().is_cuda()) {
if (original.is_cuda()) {
ss << " (was CUDA tensor got CPU tensor)";
} else {
ss << " (was CPU tensor got CUDA tensor)";
2 changes: 1 addition & 1 deletion torch/csrc/autograd/python_variable.cpp
@@ -247,7 +247,7 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad)

THPUtils_assertRet(-1, grad.type() == var.type() || gradIsSparse,
"assigned grad has data of a different type");
if (var.type().is_cuda()) {
if (var.is_cuda()) {
THPUtils_assertRet(-1, grad.get_device() == var.get_device(),
"assigned grad has data located on a different device");
}
2 changes: 1 addition & 1 deletion torch/csrc/cuda/comm.cpp
@@ -177,7 +177,7 @@ at::Tensor gather(
std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
for (const auto& tensor : tensors) {
AT_CHECK(
tensor.type().is_cuda(), "Gather expects all inputs to have CUDA type");
tensor.is_cuda(), "Gather expects all inputs to have CUDA type");
AT_ASSERT(tensor.ndimension() == static_cast<int64_t>(expected_size.size()));
expected_size[dim] = tensor.size(dim);
for (size_t dimension = 0; dimension < expected_size.size(); ++dimension) {
4 changes: 2 additions & 2 deletions torch/csrc/cuda/nccl.cpp
@@ -133,8 +133,8 @@ void _check_inputs(
auto input = inputs[i];
auto output = outputs[i];

if (!(input.type().is_cuda() && !input.type().is_sparse() &&
output.type().is_cuda() && !output.type().is_sparse())) {
if (!(input.is_cuda() && !input.is_sparse() &&
output.is_cuda() && !output.is_sparse())) {
throw std::runtime_error(
"input and output elements have to be cuda dense Tensors");
}
4 changes: 2 additions & 2 deletions torch/csrc/jit/argument_spec.h
@@ -81,7 +81,7 @@ struct ArgumentSpec {
if ((arg.defined_ = t.defined())) {
arg.requires_grad_ = with_grad && autograd::Variable(t).requires_grad();
arg.dim_ = t.dim();
arg.device_ = t.type().is_cuda() ? t.get_device() : -1;
arg.device_ = t.is_cuda() ? t.get_device() : -1;
arg.type_ = static_cast<unsigned>(t.type().scalarType());
}

@@ -203,7 +203,7 @@ struct CompleteArgumentSpec {
pod.defined = t.defined();
if (pod.defined) {
pod.type = static_cast<int>(t.type().scalarType());
pod.device = (!t.type().is_cuda()) ? -1 : t.get_device();
pod.device = (!t.is_cuda()) ? -1 : t.get_device();
pod.requires_grad = with_grad && autograd::as_variable_ref(t).requires_grad();
total_dims += t.ndimension();
auto sizes = t.sizes();
2 changes: 1 addition & 1 deletion torch/csrc/jit/python_arg_flatten.h
@@ -16,7 +16,7 @@ struct IODescriptor {
VariableMetadata(const autograd::Variable& var)
: sizes(var.sizes().vec())
, type(var.type().scalarType())
, device(var.type().is_cuda() ? var.get_device() : -1)
, device(var.is_cuda() ? var.get_device() : -1)
, requires_grad(var.requires_grad()) {}

bool operator==(const VariableMetadata& o) const {
2 changes: 1 addition & 1 deletion torch/csrc/jit/type.h
@@ -324,7 +324,7 @@ struct TORCH_API TensorType : public Type {
protected:
TensorType(const at::Tensor& tensor, TypeKind kind=TypeKind::TensorType)
: TensorType(tensor.type().scalarType(),
tensor.type().is_cuda() ? tensor.get_device() : -1,
tensor.is_cuda() ? tensor.get_device() : -1,
tensor.dim(),
tensor.is_variable() && tensor.requires_grad(),
kind) {}
2 changes: 1 addition & 1 deletion torch/csrc/nn/type_checks.h
@@ -28,7 +28,7 @@ static inline int get_device(PyObject* args) {
PyObject* arg = PyTuple_GET_ITEM(args, i);
if (THPVariable_Check(arg)) {
auto& tensor = THPVariable_UnpackData(arg);
if (tensor.type().is_cuda()) {
if (tensor.is_cuda()) {
return tensor.get_device();
}
}
2 changes: 1 addition & 1 deletion torch/csrc/tensor/python_tensor.cpp
@@ -388,7 +388,7 @@ at::Type& get_default_tensor_type() {
}

Device getDevice(const at::Tensor& tensor) {
if (tensor.type().is_cuda()) {
if (tensor.is_cuda()) {
return at::Device(at::DeviceType::CUDA, tensor.get_device());
}
return at::Device(at::DeviceType::CPU);
4 changes: 2 additions & 2 deletions torch/lib/THD/base/data_channels/DataChannelNccl.cpp
@@ -345,8 +345,8 @@ bool DataChannelNccl::_tensorCheckHelper(

for (size_t i = 0; i < input.size(); ++i) {
// Check to make sure it's a GPU dense tensor
if (!(input[i].type().is_cuda() && !input[i].type().is_sparse() &&
output[i].type().is_cuda() && !output[i].type().is_sparse())) {
if (!(input[i].is_cuda() && !input[i].is_sparse() &&
output[i].is_cuda() && !output[i].is_sparse())) {
throw std::runtime_error(
"Only CUDA dense tensor is supported for NCCL "
"collective operations");
4 changes: 2 additions & 2 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -290,8 +290,8 @@ void ProcessGroupNCCL::tensorCheckHelper(

for (size_t i = 0; i < input.size(); ++i) {
// Check to make sure it's a GPU dense tensor
if (!(input[i].type().is_cuda() && !input[i].type().is_sparse() &&
output[i].type().is_cuda() && !output[i].type().is_sparse())) {
if (!(input[i].is_cuda() && !input[i].is_sparse() &&
output[i].is_cuda() && !output[i].is_sparse())) {
throw std::runtime_error(
"Only CUDA dense tensor is supported for NCCL "
"collective operations");
