Skip to content

Commit

Permalink
Cleanup for mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
cssedev committed Nov 11, 2024
1 parent 8582f09 commit 8dd8b4f
Showing 1 changed file with 56 additions and 50 deletions.
106 changes: 56 additions & 50 deletions can/io/mf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
the ASAM MDF standard (see https://www.asam.net/standards/detail/mdf/)
"""

import abc
import logging
from datetime import datetime
from hashlib import md5
from io import BufferedIOBase, BytesIO
from pathlib import Path
from typing import Any, BinaryIO, Generator, Iterable, Optional, Union, cast
from typing import Any, BinaryIO, Dict, Generator, Iterator, List, Optional, Union, cast

from ..message import Message
from ..typechecking import StringPathLike
Expand Down Expand Up @@ -70,6 +71,8 @@
)
except ImportError:
asammdf = None
MDF4 = None
Signal = None


CAN_MSG_EXT = 0x80000000
Expand Down Expand Up @@ -266,60 +269,63 @@ def on_message_received(self, msg: Message) -> None:
self._rtr_buffer = np.zeros(1, dtype=RTR_DTYPE)


class MF4Reader(BinaryIOMessageReader):
class FrameIterator(object, metaclass=abc.ABCMeta):
"""
Iterator of CAN messages from a MF4 logging file.
The MF4Reader only supports MF4 files with CAN bus logging.
Iterator helper class for common handling among CAN DataFrames, ErrorFrames and RemoteFrames.
"""

# NOTE: Readout based on the bus logging code from asammdf GUI
# Number of records to request for each asammdf call
_chunk_size = 1000

class FrameIterator(object):
"""
Iterator helper class for common handling among CAN DataFrames, ErrorFrames and RemoteFrames.
"""
def __init__(self, mdf: MDF4, group_index: int, start_timestamp: float, name: str):
self._mdf = mdf
self._group_index = group_index
self._start_timestamp = start_timestamp
self._name = name

# Number of records to request for each asammdf call
_chunk_size = 1000
# Extract names
channel_group: ChannelGroup = self._mdf.groups[self._group_index]

def __init__(
self, mdf: MDF, group_index: int, start_timestamp: float, name: str
):
self._mdf = mdf
self._group_index = group_index
self._start_timestamp = start_timestamp
self._name = name
self._channel_names = []

# Extract names
channel_group: ChannelGroup = self._mdf.groups[self._group_index]
for channel in channel_group.channels:
if str(channel.name).startswith(f"{self._name}."):
self._channel_names.append(channel.name)

self._channel_names = []
return

for channel in channel_group.channels:
if str(channel.name).startswith(f"{self._name}."):
self._channel_names.append(channel.name)
def _get_data(self, current_offset: int) -> Signal:
# NOTE: asammdf suggests using select instead of get. Select seem to miss converting some channels which
# get does convert as expected.
data_raw = self._mdf.get(
self._name,
self._group_index,
record_offset=current_offset,
record_count=self._chunk_size,
raw=False,
)

return
return data_raw

def _get_data(self, current_offset: int) -> asammdf.Signal:
# NOTE: asammdf suggests using select instead of get. Select seem to miss converting some channels which
# get does convert as expected.
data_raw = self._mdf.get(
self._name,
self._group_index,
record_offset=current_offset,
record_count=self._chunk_size,
raw=False,
)
@abc.abstractmethod
def __iter__(self) -> Generator[Message, None, None]:
pass

return data_raw
pass

pass

class MF4Reader(BinaryIOMessageReader):
"""
Iterator of CAN messages from a MF4 logging file.
The MF4Reader only supports MF4 files with CAN bus logging.
"""

# NOTE: Readout based on the bus logging code from asammdf GUI

class CANDataFrameIterator(FrameIterator):

def __init__(self, mdf: MDF, group_index: int, start_timestamp: float):
def __init__(self, mdf: MDF4, group_index: int, start_timestamp: float):
super().__init__(mdf, group_index, start_timestamp, "CAN_DataFrame")

return
Expand All @@ -336,7 +342,7 @@ def __iter__(self) -> Generator[Message, None, None]:
for i in range(len(data)):
data_length = int(data["CAN_DataFrame.DataLength"][i])

kv = {
kv: Dict[str, Any] = {
"timestamp": float(data.timestamps[i]) + self._start_timestamp,
"arbitration_id": int(data["CAN_DataFrame.ID"][i]) & 0x1FFFFFFF,
"data": data["CAN_DataFrame.DataBytes"][i][
Expand Down Expand Up @@ -365,7 +371,7 @@ def __iter__(self) -> Generator[Message, None, None]:

class CANErrorFrameIterator(FrameIterator):

def __init__(self, mdf: MDF, group_index: int, start_timestamp: float):
def __init__(self, mdf: MDF4, group_index: int, start_timestamp: float):
super().__init__(mdf, group_index, start_timestamp, "CAN_ErrorFrame")

return
Expand All @@ -380,7 +386,7 @@ def __iter__(self) -> Generator[Message, None, None]:
names = data.samples[0].dtype.names

for i in range(len(data)):
kv = {
kv: Dict[str, Any] = {
"timestamp": float(data.timestamps[i]) + self._start_timestamp,
"is_error_frame": True,
}
Expand Down Expand Up @@ -422,7 +428,7 @@ def __iter__(self) -> Generator[Message, None, None]:

class CANRemoteFrameIterator(FrameIterator):

def __init__(self, mdf: MDF, group_index: int, start_timestamp: float):
def __init__(self, mdf: MDF4, group_index: int, start_timestamp: float):
super().__init__(mdf, group_index, start_timestamp, "CAN_RemoteFrame")

return
Expand All @@ -437,7 +443,7 @@ def __iter__(self) -> Generator[Message, None, None]:
names = data.samples[0].dtype.names

for i in range(len(data)):
kv = {
kv: Dict[str, Any] = {
"timestamp": float(data.timestamps[i]) + self._start_timestamp,
"arbitration_id": int(data["CAN_RemoteFrame.ID"][i])
& 0x1FFFFFFF,
Expand Down Expand Up @@ -476,20 +482,20 @@ def __init__(

super().__init__(file, mode="rb")

self._mdf: MDF
self._mdf: MDF4
if isinstance(file, BufferedIOBase):
self._mdf = MDF(BytesIO(file.read()))
self._mdf = cast(MDF4, MDF(BytesIO(file.read())))
else:
self._mdf = MDF(file)
self._mdf = cast(MDF4, MDF(file))

self._start_timestamp = self._mdf.header.start_time.timestamp()

def __iter__(self) -> Iterable[Message]:
def __iter__(self) -> Iterator[Message]:
import heapq

# To handle messages split over multiple channel groups, create a single iterator per channel group and merge
# these iterators into a single iterator using heapq.
iterators = []
iterators: List[FrameIterator] = []
for group_index, group in enumerate(self._mdf.groups):
channel_group: ChannelGroup = group.channel_group

Expand Down Expand Up @@ -536,7 +542,7 @@ def __iter__(self) -> Iterable[Message]:
continue

# Create merged iterator over all the groups, using the timestamps as comparison key
return heapq.merge(*iterators, key=lambda x: x.timestamp)
return iter(heapq.merge(*iterators, key=lambda x: x.timestamp))

def stop(self) -> None:
self._mdf.close()
Expand Down

0 comments on commit 8dd8b4f

Please sign in to comment.