Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add selecting json function into find method #379

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions dataset/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __init__(
self.engine = create_engine(url, **engine_kwargs)
self.is_postgres = self.engine.dialect.name == "postgresql"
self.is_sqlite = self.engine.dialect.name == "sqlite"
self._server_version_info = None
self._is_support_json = None

def _enable_sqlite_wal_mode(dbapi_con, con_record):
# reference:
Expand Down Expand Up @@ -105,6 +107,27 @@ def in_transaction(self):
return False
return len(self.local.tx) > 0

@property
def server_version_info(self):
"""Return database version number"""
if self._server_version_info is None:
tables = self.tables # connect DB
self._server_version_info = self.engine.dialect.server_version_info
return self._server_version_info

@property
def is_support_json(self):
"""Check if this database version support JSON column"""
if self._is_support_json is None:
support_versions = {"mysql": 5.7, "postgresql": 9.2, "sqlite": 3.9, "mssql": 13.0, "oracle": 12.1}
db_name = self.engine.dialect.name
version = float(str(self.server_version_info[0]) + "." + str(self.server_version_info[1]))
if support_versions.get(db_name) and version >= support_versions[db_name]:
self._is_support_json = True
else:
self._is_support_json = False
return self._is_support_json

def _flush_tables(self):
"""Clear the table metadata after transaction rollbacks."""
for table in self._tables.values():
Expand Down
68 changes: 66 additions & 2 deletions dataset/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
import threading
from banal import ensure_list
from decimal import Decimal

from sqlalchemy import func, select, false
from sqlalchemy.sql import and_, expression
Expand Down Expand Up @@ -418,7 +419,37 @@ def _generate_clause(self, column, op, value):
return self.table.c[column].like(value + "%")
return false()

def _args_to_clause(self, args, clauses=()):
def _generate_json_clause(self, column, key, ops_value):
element_type = {
int: "as_integer",
float: "as_float",
bool: "as_boolean",
Decimal: "as_numeric",
str: "as_string",
list: "as_json",
dict: "as_json"
}
for op, value in ops_value.items():
col = getattr(self.table.c[column][key], element_type[type(value)])()
if op in (">", "gt"):
yield col > value
elif op in ("<", "lt"):
yield col < value
elif op in (">=", "gte"):
yield col >= value
elif op in ("<=", "lte"):
yield col <= value
elif op in ("=", "==", "is"):
yield col == value
elif op in ("!=", "<>", "not"):
yield col != value
elif op in ("between", ".."):
start, end = value
yield self._generate_json_clause(column, key, {">=": start, "<": end})
else:
yield false()

def _args_to_clause(self, args, clauses=(), json_=None):
clauses = list(clauses)
for column, value in args.items():
column = self._get_column_name(column)
Expand All @@ -431,6 +462,13 @@ def _args_to_clause(self, args, clauses=()):
clauses.append(self._generate_clause(column, op, op_value))
else:
clauses.append(self._generate_clause(column, "=", value))
if json_:
if not self.db.is_support_json:
raise NotImplementedError("Current database not support json column!")
for column, value in json_.items():
column = self._get_column_name(column)
for key, ops_value in value.items():
clauses.extend(self._generate_json_clause(column, key, ops_value))
return and_(*clauses)

def _args_to_order_by(self, order_by):
Expand Down Expand Up @@ -597,6 +635,31 @@ def find(self, *_clauses, **kwargs):
# return all rows sorted by multiple columns (descending by year)
results = table.find(order_by=['country', '-year'])

Using ``_json``::

# Notice: selected key type depends on giving value type,
# like if given integer but stored type is float will be automatically transformed to integer.
# Support operations: >(gt), <(lt), >=(gte), <=(lte), =(==,is), !=(<>, not), between("..")
# id json_column
# 0 {"key":-0.5}
# 1 {"key":0.5}
# 2 {"key":1.5}
results = table.find(_json={'json_column':{'key':{'>=': 0.0, '<':1.0}}}) # id = [1]
results = table.find(_json={'json_column':{'key':{'>=': 0, '<':1}}}) # int(-0.5)==0, id = [0,1]

# id json_column
# 0 [0,1,2]
# 1 [0,0.5,1]
# 2 [0]
# find rows by index
results = table.find(_json={'json_column':{1:{'>=': 0.0, '<':1.0}}}) # id = [1]

# id json_column
# 0 {"key1":{"key2":-1}}
# 1 {"key1":{"key2":0.5}}
# find rows by path
results = table.find(_json={'json_column':{('key1','key2'):{'between':[0.0,1.0]}}}) # id = [1]

You can also submit filters based on criteria other than equality,
see :ref:`advanced_filters` for details.

Expand All @@ -612,11 +675,12 @@ def find(self, *_clauses, **kwargs):
order_by = kwargs.pop("order_by", None)
_streamed = kwargs.pop("_streamed", False)
_step = kwargs.pop("_step", QUERY_STEP)
_json = kwargs.pop("_json", None)
if _step is False or _step == 0:
_step = None

order_by = self._args_to_order_by(order_by)
args = self._args_to_clause(kwargs, clauses=_clauses)
args = self._args_to_clause(kwargs, clauses=_clauses, json_=_json)
query = self.table.select(whereclause=args, limit=_limit, offset=_offset)
if len(order_by):
query = query.order_by(*order_by)
Expand Down