Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Implement msgspec encoding #2541

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
def mypy(session: nox.Session) -> None:
"""Check types with mypy."""
args = session.posargs or ["singer_sdk"]
session.install(".[faker,jwt,parquet,s3,testing]")
session.install(".[faker,jwt,msgspec,parquet,s3,testing]")
session.install(*typing_dependencies)
session.run("mypy", *args)
if not session.posargs:
Expand All @@ -63,6 +63,7 @@ def tests(session: nox.Session) -> None:
extras = [
"faker",
"jwt",
"msgspec",
"parquet",
"s3",
]
Expand Down Expand Up @@ -92,7 +93,7 @@ def tests(session: nox.Session) -> None:
@nox.session(python=main_python_version)
def benches(session: nox.Session) -> None:
"""Run benchmarks."""
session.install(".[jwt,s3]")
session.install(".[jwt,msgspec,s3]")
session.install(*test_dependencies)
session.run(
"pytest",
Expand All @@ -105,7 +106,7 @@ def benches(session: nox.Session) -> None:
@nox.session(name="deps", python=main_python_version)
def dependencies(session: nox.Session) -> None:
"""Check issues with dependencies."""
session.install(".[docs,faker,jwt,parquet,s3,ssh,testing]")
session.install(".[docs,faker,jwt,msgspec,parquet,s3,ssh,testing]")
session.install("deptry")
session.run("deptry", "singer_sdk", *session.posargs)

Expand All @@ -115,7 +116,7 @@ def update_snapshots(session: nox.Session) -> None:
"""Update pytest snapshots."""
args = session.posargs or ["-m", "snapshot"]

session.install(".[faker,jwt,parquet]")
session.install(".[faker,jwt,msgspec,parquet]")
session.install(*test_dependencies)
session.run("pytest", "--snapshot-update", *args)

Expand Down
57 changes: 56 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ ssh = [
"paramiko>=3.3.0",
]

# msgspec extras
msgspec = [
"msgspec>=0.19.0",
]

[project.urls]
Homepage = "https://sdk.meltano.com/en/latest/"
Repository = "https://github.com/meltano/sdk"
Expand Down
3 changes: 2 additions & 1 deletion samples/sample_tap_countries/countries_tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
CountriesStream,
)
from singer_sdk import Stream, Tap
from singer_sdk._singerlib.encoding._msgspec import MsgSpecWriter # noqa: PLC2701
from singer_sdk.typing import PropertiesList


class SampleTapCountries(Tap):
class SampleTapCountries(MsgSpecWriter, Tap):
"""Sample tap for Countries GraphQL API."""

name: str = "sample-tap-countries"
Expand Down
3 changes: 2 additions & 1 deletion samples/sample_tap_sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from singer_sdk import SQLConnector, SQLStream, SQLTap
from singer_sdk import typing as th
from singer_sdk._singerlib.encoding._msgspec import MsgSpecWriter # noqa: PLC2701

DB_PATH_CONFIG = "path_to_db"

Expand Down Expand Up @@ -39,7 +40,7 @@ class SQLiteStream(SQLStream):
STATE_MSG_FREQUENCY = 10


class SQLiteTap(SQLTap):
class SQLiteTap(MsgSpecWriter, SQLTap):
"""The Tap class for SQLite."""

name = "tap-sqlite-sample"
Expand Down
110 changes: 110 additions & 0 deletions singer_sdk/_singerlib/encoding/_msgspec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from __future__ import annotations

import datetime
import decimal
import logging
import sys
import typing as t

import msgspec

from singer_sdk._singerlib.exceptions import InvalidInputLine

from ._base import GenericSingerReader, GenericSingerWriter
from ._simple import Message

logger = logging.getLogger(__name__)


def enc_hook(obj: t.Any) -> t.Any: # noqa: ANN401
"""Encoding type helper for non native types.

Args:
obj: the item to be encoded

Returns:
The object converted to the appropriate type, default is str
"""
return obj.isoformat(sep="T") if isinstance(obj, datetime.datetime) else str(obj)


def dec_hook(type: type, obj: t.Any) -> t.Any: # noqa: ARG001, A002, ANN401
"""Decoding type helper for non native types.

Args:
type: the type given
obj: the item to be decoded

Returns:
The object converted to the appropriate type, default is str.
"""
return str(obj)


encoder = msgspec.json.Encoder(enc_hook=enc_hook, decimal_format="number")
decoder = msgspec.json.Decoder(dec_hook=dec_hook, float_hook=decimal.Decimal)
_jsonl_msg_buffer = bytearray(64)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚨 issue (security): The global bytearray buffer could cause thread-safety issues and the fixed size might be insufficient for larger messages

Consider using a thread-local buffer or creating a new buffer for each message. Also, the buffer size should either be dynamic or much larger to handle varying message sizes safely.



def serialize_jsonl(obj: object, **kwargs: t.Any) -> bytes: # noqa: ARG001
"""Serialize a dictionary into a line of jsonl.

Args:
obj: A Python object usually a dict.
**kwargs: Optional key word arguments.

Returns:
A bytes of serialized json.
"""
encoder.encode_into(obj, _jsonl_msg_buffer)
_jsonl_msg_buffer.extend(b"\n")
return _jsonl_msg_buffer


class MsgSpecReader(GenericSingerReader[str]):
"""Base class for all plugins reading Singer messages as strings from stdin."""

default_input = sys.stdin

def deserialize_json(self, line: str) -> dict: # noqa: PLR6301
"""Deserialize a line of json.

Args:
line: A single line of json.

Returns:
A dictionary of the deserialized json.

Raises:
InvalidInputLine: If the line cannot be parsed
"""
try:
return decoder.decode(line) # type: ignore[no-any-return]
except msgspec.DecodeError as exc:
logger.exception("Unable to parse:\n%s", line)
msg = f"Unable to parse line as JSON: {line}"
raise InvalidInputLine(msg) from exc


class MsgSpecWriter(GenericSingerWriter[bytes, Message]):
"""Interface for all plugins writing Singer messages to stdout."""

def serialize_message(self, message: Message) -> bytes: # noqa: PLR6301
"""Serialize a dictionary into a line of json.

Args:
message: A Singer message object.

Returns:
A string of serialized json.
"""
return serialize_jsonl(message.to_dict())

def write_message(self, message: Message) -> None:
"""Write a message to stdout.

Args:
message: The message to write.
"""
sys.stdout.buffer.write(self.format_message(message))
sys.stdout.flush()
30 changes: 20 additions & 10 deletions singer_sdk/testing/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,12 @@ def get_standard_target_tests(
return []


def tap_sync_test(tap: Tap) -> tuple[io.StringIO, io.StringIO]:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider creating a type alias to simplify the function signatures.

The type annotations have become overly verbose while not adding proportional value. Create a type alias to simplify the signatures while maintaining type safety:

# At the top of the file
SingerIO = io.TextIOWrapper[io.BytesIO]

# Then simplify function signatures like:
def tap_sync_test(
    tap: Tap,
) -> tuple[SingerIO, SingerIO]:
    stdout_buf = SingerIO(io.BytesIO(), encoding="utf-8")
    ...

def tap_to_target_sync_test(
    tap: Tap,
    target: Target,
) -> tuple[SingerIO, SingerIO, SingerIO, SingerIO]:
    ...

This maintains the same type safety while improving readability and making future type changes easier to maintain.

def tap_sync_test(
tap: Tap,
) -> tuple[
io.TextIOWrapper[io.BytesIO],
io.TextIOWrapper[io.BytesIO],
]:
"""Invokes a Tap object and return STDOUT and STDERR results in StringIO buffers.

Args:
Expand All @@ -120,8 +125,8 @@ def tap_sync_test(tap: Tap) -> tuple[io.StringIO, io.StringIO]:
Returns:
A 2-item tuple with StringIO buffers from the Tap's output: (stdout, stderr)
"""
stdout_buf = io.StringIO()
stderr_buf = io.StringIO()
stdout_buf = io.TextIOWrapper(io.BytesIO(), encoding="utf-8")
stderr_buf = io.TextIOWrapper(io.BytesIO(), encoding="utf-8")
with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf):
tap.sync_all()
stdout_buf.seek(0)
Expand Down Expand Up @@ -171,10 +176,10 @@ def _select_all(catalog_dict: dict) -> dict:

def target_sync_test(
target: Target,
input: io.StringIO | None, # noqa: A002
input: t.IO[str] | None, # noqa: A002
*,
finalize: bool = True,
) -> tuple[io.StringIO, io.StringIO]:
) -> tuple[io.TextIOWrapper[io.BytesIO], io.TextIOWrapper[io.BytesIO]]:
"""Invoke the target with the provided input.

Args:
Expand All @@ -186,8 +191,8 @@ def target_sync_test(
Returns:
A 2-item tuple with StringIO buffers from the Target's output: (stdout, stderr)
"""
stdout_buf = io.StringIO()
stderr_buf = io.StringIO()
stdout_buf = io.TextIOWrapper(io.BytesIO(), encoding="utf-8")
stderr_buf = io.TextIOWrapper(io.BytesIO(), encoding="utf-8")

with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf):
if input is not None:
Expand All @@ -203,7 +208,12 @@ def target_sync_test(
def tap_to_target_sync_test(
tap: Tap,
target: Target,
) -> tuple[io.StringIO, io.StringIO, io.StringIO, io.StringIO]:
) -> tuple[
io.TextIOWrapper[io.BytesIO],
io.TextIOWrapper[io.BytesIO],
io.TextIOWrapper[io.BytesIO],
io.TextIOWrapper[io.BytesIO],
]:
"""Test and end-to-end sink from the tap to the target.

Note: This method buffers all output from the tap in memory and should not be
Expand Down Expand Up @@ -236,15 +246,15 @@ def sync_end_to_end(tap: Tap, target: Target, *mappers: InlineMapper) -> None:
mappers: Zero or more inline mapper to apply in between the tap and target, in
order.
"""
buf = io.StringIO()
buf = io.TextIOWrapper(io.BytesIO(), encoding="utf-8")
with redirect_stdout(buf):
tap.sync_all()

buf.seek(0)
mapper_output = buf

for mapper in mappers:
buf = io.StringIO()
buf = io.TextIOWrapper(io.BytesIO(), encoding="utf-8")
with redirect_stdout(buf):
mapper.listen(mapper_output)

Expand Down
10 changes: 5 additions & 5 deletions singer_sdk/testing/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ def _execute_sync(self) -> tuple[str, str]:
Returns:
A 2-item tuple with StringIO buffers from the Tap's output: (stdout, stderr)
"""
stdout_buf = io.StringIO()
stderr_buf = io.StringIO()
stdout_buf = io.TextIOWrapper(io.BytesIO(), encoding="utf-8")
stderr_buf = io.TextIOWrapper(io.BytesIO(), encoding="utf-8")
with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf):
self.run_sync_dry_run()
stdout_buf.seek(0)
Expand Down Expand Up @@ -281,7 +281,7 @@ def _execute_sync( # noqa: PLR6301
target_input: t.IO[str],
*,
finalize: bool = True,
) -> tuple[io.StringIO, io.StringIO]:
) -> tuple[io.TextIOWrapper[io.BytesIO], io.TextIOWrapper[io.BytesIO]]:
"""Invoke the target with the provided input.

Args:
Expand All @@ -294,8 +294,8 @@ def _execute_sync( # noqa: PLR6301
A 2-item tuple with StringIO buffers from the Target's output:
(stdout, stderr)
"""
stdout_buf = io.StringIO()
stderr_buf = io.StringIO()
stdout_buf = io.TextIOWrapper(io.BytesIO(), encoding="utf-8")
stderr_buf = io.TextIOWrapper(io.BytesIO(), encoding="utf-8")

with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf):
if target_input is not None:
Expand Down
1 change: 1 addition & 0 deletions tests/_singerlib/encoding/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from __future__ import annotations
Loading
Loading