[Fix] Fix setup on non-linux-x64 (#811)

* fix setup

* replace long to int64_t
pull/830/head
q.yao 2022-07-27 20:33:52 +08:00 committed by GitHub
parent c1d673e804
commit b763611cad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 7 additions and 7 deletions

View File

@ -125,7 +125,7 @@ bool SubgraphMatcher::SubgraphMatcherImpl::matchAttributes(const Node* n1, Node*
"' did not match:\n", *n1, *n2);
return false;
}
std::vector<long int> n1is, n2is;
std::vector<int64_t> n1is, n2is;
std::vector<double> n1fs, n2fs;
switch (n1->kindOf(attr_name)) {
case AttributeKind::s:

View File

@ -67,7 +67,7 @@ static bool matchClsHead(const Match& match, const std::unordered_map<std::strin
if (!is_kind(const_node, "onnx::Constant")) return false;
auto ival = const_node->t(Symbol::attr("value"));
if (ival.dim() != 0) return false;
auto ival_dataptr = ival.data_ptr<long>();
auto ival_dataptr = ival.data_ptr<int64_t>();
if (ival_dataptr[0] != 0) return false;
}

View File

@ -65,7 +65,7 @@ bool FuseSelectAssign(Node* node, std::unordered_map<std::string, Tensor>& param
Node* shape = values_map[vmap["reshape_1_shape"]]->node();
auto shape_val = shape->t(Symbol::attr("value"));
if (shape_val.dim() != 1) return false;
if (shape_val.data_ptr<long>()[0] != -1) return false;
if (shape_val.data_ptr<int64_t>()[0] != -1) return false;
}
{
@ -83,7 +83,7 @@ bool FuseSelectAssign(Node* node, std::unordered_map<std::string, Tensor>& param
Node* gather_inds = values_map[vmap["gather_inds_2"]]->node();
auto inds_val = gather_inds->t(Symbol::attr("value"));
if (inds_val.dim() != 0) return false;
if (inds_val.data_ptr<long>()[0] != 0) return false;
if (inds_val.data_ptr<int64_t>()[0] != 0) return false;
}
{
@ -92,7 +92,7 @@ bool FuseSelectAssign(Node* node, std::unordered_map<std::string, Tensor>& param
auto start_name = slice->inputs()[1]->debugName();
auto start_val = params[start_name];
if (start_val.dim() != 1) return false;
if (start_val.data_ptr<long>()[0] != 0) return false;
if (start_val.data_ptr<int64_t>()[0] != 0) return false;
}
// create new node

View File

@ -18,7 +18,7 @@ using torch::jit::Value;
void MergeShapeConcate(Node* node) {
auto inputs = node->inputs();
std::vector<long> gather_value;
std::vector<int64_t> gather_value;
Value* shape_from = nullptr;
std::vector<Node*> node_to_remove{node};
@ -54,7 +54,7 @@ void MergeShapeConcate(Node* node) {
if (!is_kind(constant_node, "onnx::Constant")) return;
auto gather_indices_val = constant_node->t(Symbol::attr("value"));
long* data_ptr = gather_indices_val.data_ptr<long>();
int64_t* data_ptr = gather_indices_val.data_ptr<int64_t>();
if (gather_indices_val.dim() == 0) {
gather_value.push_back(data_ptr[0]);
} else {