parent
c1d673e804
commit
b763611cad
|
@ -125,7 +125,7 @@ bool SubgraphMatcher::SubgraphMatcherImpl::matchAttributes(const Node* n1, Node*
|
|||
"' did not match:\n", *n1, *n2);
|
||||
return false;
|
||||
}
|
||||
std::vector<long int> n1is, n2is;
|
||||
std::vector<int64_t> n1is, n2is;
|
||||
std::vector<double> n1fs, n2fs;
|
||||
switch (n1->kindOf(attr_name)) {
|
||||
case AttributeKind::s:
|
||||
|
|
|
@ -67,7 +67,7 @@ static bool matchClsHead(const Match& match, const std::unordered_map<std::strin
|
|||
if (!is_kind(const_node, "onnx::Constant")) return false;
|
||||
auto ival = const_node->t(Symbol::attr("value"));
|
||||
if (ival.dim() != 0) return false;
|
||||
auto ival_dataptr = ival.data_ptr<long>();
|
||||
auto ival_dataptr = ival.data_ptr<int64_t>();
|
||||
if (ival_dataptr[0] != 0) return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -65,7 +65,7 @@ bool FuseSelectAssign(Node* node, std::unordered_map<std::string, Tensor>& param
|
|||
Node* shape = values_map[vmap["reshape_1_shape"]]->node();
|
||||
auto shape_val = shape->t(Symbol::attr("value"));
|
||||
if (shape_val.dim() != 1) return false;
|
||||
if (shape_val.data_ptr<long>()[0] != -1) return false;
|
||||
if (shape_val.data_ptr<int64_t>()[0] != -1) return false;
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -83,7 +83,7 @@ bool FuseSelectAssign(Node* node, std::unordered_map<std::string, Tensor>& param
|
|||
Node* gather_inds = values_map[vmap["gather_inds_2"]]->node();
|
||||
auto inds_val = gather_inds->t(Symbol::attr("value"));
|
||||
if (inds_val.dim() != 0) return false;
|
||||
if (inds_val.data_ptr<long>()[0] != 0) return false;
|
||||
if (inds_val.data_ptr<int64_t>()[0] != 0) return false;
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -92,7 +92,7 @@ bool FuseSelectAssign(Node* node, std::unordered_map<std::string, Tensor>& param
|
|||
auto start_name = slice->inputs()[1]->debugName();
|
||||
auto start_val = params[start_name];
|
||||
if (start_val.dim() != 1) return false;
|
||||
if (start_val.data_ptr<long>()[0] != 0) return false;
|
||||
if (start_val.data_ptr<int64_t>()[0] != 0) return false;
|
||||
}
|
||||
|
||||
// create new node
|
||||
|
|
|
@ -18,7 +18,7 @@ using torch::jit::Value;
|
|||
void MergeShapeConcate(Node* node) {
|
||||
auto inputs = node->inputs();
|
||||
|
||||
std::vector<long> gather_value;
|
||||
std::vector<int64_t> gather_value;
|
||||
Value* shape_from = nullptr;
|
||||
|
||||
std::vector<Node*> node_to_remove{node};
|
||||
|
@ -54,7 +54,7 @@ void MergeShapeConcate(Node* node) {
|
|||
if (!is_kind(constant_node, "onnx::Constant")) return;
|
||||
|
||||
auto gather_indices_val = constant_node->t(Symbol::attr("value"));
|
||||
long* data_ptr = gather_indices_val.data_ptr<long>();
|
||||
int64_t* data_ptr = gather_indices_val.data_ptr<int64_t>();
|
||||
if (gather_indices_val.dim() == 0) {
|
||||
gather_value.push_back(data_ptr[0]);
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue