[Enhance] Do not use deprecated `.type().scalarType()` (#3195)

`at::Tensor::type()` method is deprecated and `tensor.scalar_type()` is a recommended way to replace `tensor.type().scalarType()` invocation
pull/3122/merge
Nikita Shulga 2024-11-03 22:45:49 -08:00 committed by GitHub
parent 139325726b
commit 71437a361c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -35,7 +35,7 @@ struct TorchGPU : public tv::GPU {
template <typename scalar_t>
void check_torch_dtype(const torch::Tensor &tensor) {
switch (tensor.type().scalarType()) {
switch (tensor.scalar_type()) {
case at::ScalarType::Double: {
auto val = std::is_same<std::remove_const_t<scalar_t>, double>::value;
TV_ASSERT_RT_ERR(val, "error");