Skip to content

Commit

Permalink
fix(schema): reconnect closed connection
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein committed Jan 9, 2025
1 parent fc8f940 commit 8e9b78a
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 13 deletions.
9 changes: 9 additions & 0 deletions src/datachain/data_storage/db_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ def execute(
conn: Optional[Any] = None,
) -> Iterator[tuple[Any, ...]]: ...

def get_table(self, name: str) -> "Table":
    """Return the table registered under *name*, reflecting it if needed.

    A table already present in ``self.metadata.tables`` is returned as-is.
    Otherwise it is reflected from the database through ``self.engine``
    (``autoload_with``), which registers it on ``self.metadata``, and the
    registered instance is returned.

    Args:
        name: Name of the table to look up.

    Returns:
        The table object stored in ``self.metadata.tables`` for *name*.
    """
    table = self.metadata.tables.get(name)
    if table is None:
        sa.Table(name, self.metadata, autoload_with=self.engine)
        # ^^^ This table may not be correctly initialised on some dialects
        # Grab it from metadata instead.
        table = self.metadata.tables[name]
    return table

@abstractmethod
def executemany(
self, query, params, cursor: Optional[Any] = None
Expand Down
14 changes: 4 additions & 10 deletions src/datachain/data_storage/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from datachain.sql.types import Int, SQLType, UInt64

if TYPE_CHECKING:
from sqlalchemy import Engine
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.base import (
ColumnCollection,
Expand All @@ -25,6 +24,8 @@
)
from sqlalchemy.sql.elements import ColumnElement

from datachain.data_storage.db_engine import DatabaseEngine


DEFAULT_DELIMITER = "__"

Expand Down Expand Up @@ -150,14 +151,12 @@ class DataTable:
def __init__(
self,
name: str,
engine: "Engine",
metadata: Optional["sa.MetaData"] = None,
engine: "DatabaseEngine",
column_types: Optional[dict[str, SQLType]] = None,
object_name: str = "file",
):
self.name: str = name
self.engine = engine
self.metadata: sa.MetaData = metadata if metadata is not None else sa.MetaData()
self.column_types: dict[str, SQLType] = column_types or {}
self.object_name = object_name

Expand Down Expand Up @@ -211,12 +210,7 @@ def new_table(
return sa.Table(name, metadata, *columns)

def get_table(self) -> "sa.Table":
table = self.metadata.tables.get(self.name)
if table is None:
sa.Table(self.name, self.metadata, autoload_with=self.engine)
# ^^^ This table may not be correctly initialised on some dialects
# Grab it from metadata instead.
table = self.metadata.tables[self.name]
table = self.engine.get_table(self.name)

column_types = self.column_types | {c.name: c.type for c in self.sys_columns()}
# adjusting types for custom columns to be instances of SQLType if possible
Expand Down
6 changes: 6 additions & 0 deletions src/datachain/data_storage/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,12 @@ def _reconnect(self) -> None:
self.db_file = db_file
self.is_closed = False

def get_table(self, name: str) -> Table:
    """Return the table called *name*, reconnecting first if closed.

    When ``self.is_closed`` is set, ``self._reconnect()`` is called before
    delegating the lookup to the parent class implementation.

    Args:
        name: Name of the table to look up.

    Returns:
        Whatever the parent class ``get_table`` returns for *name*.
    """
    if self.is_closed:
        # Reconnect in case of being closed previously.
        self._reconnect()
    return super().get_table(name)

@retry_sqlite_locks
def execute(
self,
Expand Down
5 changes: 2 additions & 3 deletions src/datachain/data_storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,7 @@ def dataset_rows(
table_name = self.dataset_table_name(dataset.name, version)
return self.schema.dataset_row_cls(
table_name,
self.db.engine,
self.db.metadata,
self.db,
dataset.get_schema(version),
object_name=object_name,
)
Expand Down Expand Up @@ -220,7 +219,7 @@ def dataset_select_paginated(
num_yielded = 0

# Ensure we're using a thread-local connection
with self.clone() as wh:
with self.clone(use_new_connection=True) as wh:
while True:
if limit is not None:
limit -= num_yielded
Expand Down
15 changes: 15 additions & 0 deletions tests/func/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,3 +1660,18 @@ def test_to_from_jsonl_remote(cloud_test_catalog_upload):
df1 = dc_from.select("jsonl.first_name", "jsonl.age", "jsonl.city").to_pandas()
df1 = df1["jsonl"]
assert df_equal(df1, df)


def test_datachain_functional_after_exceptions(test_session):
    """Check that a chain can be re-run after its UDF raised an exception."""

    def func(key: str) -> str:
        # Always fails, so every execution of the chain ends in an error.
        raise Exception("Test Error!")

    keys = ["a", "b", "c"]
    values = [3, 1, 2]
    dc = DataChain.from_values(key=keys, val=values, session=test_session)
    # Run a few times: sessions close and clean up DB connections on
    # errors, so we need to make sure that the chain reconnects if needed.
    for _ in range(4):
        with pytest.raises(Exception, match="Test Error!"):
            dc.map(res=func).exec()

0 comments on commit 8e9b78a

Please sign in to comment.