Skip to content

Commit

Permalink
Update dim_order type
Browse files Browse the repository at this point in the history
Differential Revision: D68041866

Pull Request resolved: #7610
  • Loading branch information
lucylq authored Jan 14, 2025
1 parent 83c0da5 commit b412ddc
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 5 deletions.
2 changes: 1 addition & 1 deletion exir/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class Tensor:
scalar_type: ScalarType
storage_offset: int
sizes: List[int]
dim_order: List[bytes]
dim_order: List[int]
requires_grad: bool
layout: int
data_buffer_idx: int
Expand Down
2 changes: 1 addition & 1 deletion exir/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]:
return tuple(typing.cast(Tuple[bytes], sorted_dims))


def stride_from_dim_order(sizes: List[int], dim_order: List[bytes]) -> List[int]:
def stride_from_dim_order(sizes: List[int], dim_order: List[int]) -> List[int]:
"""
Converts dim order to stride using sizes
e.g. if sizes = (2, 3, 4) and dim_order = (0, 1, 2) then strides = (12, 4, 1)
Expand Down
4 changes: 1 addition & 3 deletions exir/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
import typing
from typing import List

import torch

Expand Down Expand Up @@ -46,7 +44,7 @@ def get_test_program() -> Program:
scalar_type=ScalarType.FLOAT,
storage_offset=0,
sizes=[2, 2],
dim_order=typing.cast(List[bytes], [0, 1]),
dim_order=[0, 1],
requires_grad=False,
layout=0,
data_buffer_idx=0,
Expand Down

0 comments on commit b412ddc

Please sign in to comment.