Skip to content

Commit

Permalink
RDBC-855 Check if pydantic objects break storing documents
Browse files Browse the repository at this point in the history
RDBC-856 Delete convert_to_snake_case occurances #1/?
  • Loading branch information
poissoncorp committed Jun 13, 2024
1 parent 40b1baf commit b1c1357
Show file tree
Hide file tree
Showing 15 changed files with 95 additions and 80 deletions.
2 changes: 1 addition & 1 deletion ravendb/documents/commands/batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def set_response(self, response: str, from_cache: bool) -> None:
"Got None response from the server after doing a batch, something is very wrong."
" Probably a garbled response."
)
self.result = Utils.initialize_object(json.loads(response), self._result_class, True)
self.result = BatchCommandResult.from_json(json.loads(response))


class ClusterWideBatchCommand(SingleNodeBatchCommand):
Expand Down
4 changes: 1 addition & 3 deletions ravendb/documents/operations/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def wait_for_completion(self) -> None:
raise OperationCancelledException()
elif operation_status == "Faulted":
result = status.get("Result")
exception_result: OperationExceptionResult = Utils.initialize_object(
result, OperationExceptionResult, True
)
exception_result = OperationExceptionResult.from_json(result)
schema = ExceptionDispatcher.ExceptionSchema(
self.__request_executor.url, exception_result.type, exception_result.message, exception_result.error
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,10 @@ def before_query_invoke(self, before_query_event_args: BeforeQueryEventArgs):
def documents_by_id(self):
return self._documents_by_id

@property
def included_documents_by_id(self):
return self._included_documents_by_id

@property
def deleted_entities(self):
return self._deleted_entities
Expand Down
4 changes: 2 additions & 2 deletions ravendb/documents/session/operations/load_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ def __get_document(self, object_type: Type[_T], key: str) -> _T:
if self._session.is_deleted(key):
return Utils.get_default_value(object_type)

doc = self._session._documents_by_id.get(key)
doc = self._session.documents_by_id.get(key)
if doc is not None:
return self._session.track_entity_document_info(object_type, doc)

doc = self._session._included_documents_by_id.get(key)
doc = self._session.included_documents_by_id.get(key)
if doc is not None:
return self._session.track_entity_document_info(object_type, doc)

Expand Down
6 changes: 1 addition & 5 deletions ravendb/http/request_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,11 +428,7 @@ def __run(errors: list):

topology = Topology(
self._topology_etag,
(
self.topology_nodes
if self.topology_nodes
else list(map(lambda url_val: ServerNode(url_val, self._database_name, "!"), initial_urls))
),
(self.topology_nodes or [ServerNode(url, self._database_name, "!") for url in initial_urls]),
)

self._node_selector = NodeSelector(topology, self._thread_pool_executor)
Expand Down
29 changes: 19 additions & 10 deletions ravendb/http/server_node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING, Any, Dict

if TYPE_CHECKING:
from ravendb.http.topology import ClusterTopology
Expand All @@ -26,8 +26,17 @@ def __init__(
self.database = database
self.cluster_tag = cluster_tag
self.server_role = server_role
self.__last_server_version_check = 0
self.__last_server_version: str = None
self._last_server_version_check = 0
self._last_server_version: Optional[str] = None

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "ServerNode":
return cls(
json_dict["Url"],
json_dict["Database"],
json_dict["ClusterTag"],
ServerNode.Role(json_dict["ServerRole"]) if "ServerRole" in json_dict else None,
)

def __eq__(self, other) -> bool:
if self == other:
Expand All @@ -45,7 +54,7 @@ def __hash__(self) -> int:

@property
def last_server_version(self) -> str:
return self.__last_server_version
return self._last_server_version

@classmethod
def create_from(cls, topology: "ClusterTopology"):
Expand All @@ -64,16 +73,16 @@ def create_from(cls, topology: "ClusterTopology"):
return nodes

def should_update_server_version(self) -> bool:
if self.last_server_version is None or self.__last_server_version_check > 100:
if self.last_server_version is None or self._last_server_version_check > 100:
return True

self.__last_server_version_check += 1
self._last_server_version_check += 1
return False

def update_server_version(self, server_version: str):
self.__last_server_version = server_version
self.__last_server_version_check = 0
self._last_server_version = server_version
self._last_server_version_check = 0

def discard_server_version(self) -> None:
self.__last_server_version_check = None
self.__last_server_version_check = 0
self._last_server_version_check = None
self._last_server_version_check = 0
6 changes: 5 additions & 1 deletion ravendb/http/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Any
from typing import Union, List, Dict

from ravendb.exceptions.exceptions import (
Expand All @@ -25,6 +25,10 @@ def __init__(self, etag: int, nodes: List[ServerNode]):
self.etag = etag
self.nodes = nodes

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "Topology":
return cls(json_dict["Etag"], [ServerNode.from_json(node_json_dict) for node_json_dict in json_dict["Nodes"]])


class ClusterTopology:
def __init__(self):
Expand Down
12 changes: 8 additions & 4 deletions ravendb/json/result.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import List, Any, Dict
from typing import List, Any, Dict, Optional


class BatchCommandResult:
def __init__(self, results, transaction_index):
self.results: [None, list] = results
self.transaction_index: [None, int] = transaction_index
def __init__(self, results: Optional[List[Dict]], transaction_index: Optional[int]):
self.results = results
self.transaction_index = transaction_index

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "BatchCommandResult":
return cls(json_dict["Results"], json_dict["TransactionIndex"] if "TransactionIndex" in json_dict else None)


class JsonArrayResult:
Expand Down
10 changes: 2 additions & 8 deletions ravendb/serverwide/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ravendb.tools.utils import Utils


class GetDatabaseTopologyCommand(RavenCommand):
class GetDatabaseTopologyCommand(RavenCommand[Topology]):
def __init__(self, debug_tag: Optional[str] = None, application_identifier: Optional[uuid.UUID] = None):
super().__init__(Topology)
self.__debug_tag = debug_tag
Expand All @@ -33,13 +33,7 @@ def create_request(self, node: ServerNode) -> requests.Request:
def set_response(self, response: str, from_cache: bool) -> None:
if response is None:
return

# todo: that's pretty bad way to do that, replace with initialization function that take nested object types
self.result: Topology = Utils.initialize_object(json.loads(response), self._result_class, True)
node_list = []
for node in self.result.nodes:
node_list.append(Utils.initialize_object(node, ServerNode, True))
self.result.nodes = node_list
self.result = Topology.from_json(json.loads(response))


class GetClusterTopologyCommand(RavenCommand[ClusterTopologyResponse]):
Expand Down
22 changes: 22 additions & 0 deletions ravendb/tests/issue_tests/test_RDBC_855.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from datetime import datetime

from pydantic import BaseModel

from ravendb.tests.test_base import TestBase


class User(BaseModel):
name: str = None
birthday: datetime = None
Id: str = None


class TestRDBC855(TestBase):
def test_storing_pydantic_objects(self):
with self.store.open_session() as session:
session.store(User(name="Josh", birthday=datetime(1999, 1, 1), Id="users/51"))
session.save_changes()

with self.store.open_session() as session:
user = session.load("users/51", User)
self.assertEqual("Josh", user.name)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self):
super(MyCounterIndex, self).__init__()
self.map = (
"counters.Companies.HeartRate.Select(counter => new {\n"
" heartBeat = counter.Value,\n"
" heart_beat = counter.Value,\n"
" name = counter.Name,\n"
" user = counter.DocumentId\n"
"})"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ravendb import ServerNode
from ravendb.serverwide.commands import GetDatabaseTopologyCommand
from ravendb.tests.test_base import TestBase

Expand All @@ -18,4 +19,4 @@ def test_get_topology(self):
self.assertEqual(server_node.url, self.store.urls[0])
self.assertEqual(server_node.database, self.store.database)
self.assertEqual(server_node.cluster_tag, "A")
self.assertEqual(server_node.server_role, "Member")
self.assertEqual(server_node.server_role, ServerNode.Role.MEMBER)
4 changes: 3 additions & 1 deletion ravendb/tools/custom_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def parse_json(json_string, object_type, mappers, convert_to_snake_case=False):
try:
obj = object_type(**obj)
except TypeError:
initialize_dict, set_needed = Utils.make_initialize_dict(obj, object_type.__init__, convert_to_snake_case)
initialize_dict, set_needed = Utils.create_initialize_kwargs(
obj, object_type.__init__, convert_to_snake_case
)
o = object_type(**initialize_dict)
if set_needed:
for key, value in obj.items():
Expand Down
15 changes: 7 additions & 8 deletions ravendb/tools/projection.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
def create_entity_with_mapper(dict_obj, mapper, object_type, convert_to_snake_case=None):
"""
This method will create an entity from dict_obj and mapper
In case convert_to_snake_case is empty will convert dict_obj keys to snake_case
convert_to_snake_case can be dictionary with special words you can change ex. From -> from_date
"""
from typing import Type, TypeVar

_T = TypeVar("_T")


def create_entity_with_mapper(dict_obj, mapper, object_type: Type[_T]):
from ravendb.tools.utils import Utils

def parse_dict_rec(data):
Expand All @@ -14,7 +14,6 @@ def parse_dict_rec(data):
data[i] = Utils.initialize_object(
parse_dict_rec(data[i])[0],
object_type,
convert_to_snake_case,
)
except TypeError:
return data, False
Expand All @@ -39,5 +38,5 @@ def parse_dict_rec(data):
first_parsed, need_to_parse = parse_dict_rec(dict_obj)
# After create a complete dict for our object we need to create the object with the object_type
if first_parsed is not None and need_to_parse:
return Utils.initialize_object(first_parsed, object_type, convert_to_snake_case)
return Utils.initialize_object(first_parsed, object_type)
return first_parsed
52 changes: 17 additions & 35 deletions ravendb/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import enum
import time
from typing import Optional, Dict, Generic, Tuple, TypeVar, Collection, List, Union, Type, TYPE_CHECKING
from typing import Optional, Dict, Generic, Tuple, TypeVar, Collection, List, Union, Type, TYPE_CHECKING, Any, Callable

from ravendb.primitives import constants
from ravendb.exceptions import exceptions
Expand Down Expand Up @@ -448,22 +448,22 @@ def fill_with_nested_object_types(entity: object, nested_object_types: Dict[str,
return entity

@staticmethod
def initialize_object(obj: dict, object_type: Type[_T], convert_to_snake_case: bool = None) -> _T:
initialize_dict, set_needed = Utils.make_initialize_dict(obj, object_type.__init__, convert_to_snake_case)
def initialize_object(json_dict: Dict[str, Any], object_type: Type[_T]) -> _T:
initialize_dict, should_set_object_fields = Utils.create_initialize_kwargs(json_dict, object_type.__init__)
try:
o = object_type(**initialize_dict)
entity = object_type(**initialize_dict)
except Exception as e:
if "Id" not in initialize_dict:
initialize_dict["Id"] = None
o = object_type(**initialize_dict)
entity = object_type(**initialize_dict)
else:
raise TypeError(
f"Couldn't initialize object of type '{object_type.__name__}' using dict '{obj}'"
f"Couldn't initialize object of type '{object_type.__name__}' using dict '{json_dict}'"
) from e
if set_needed:
for key, value in obj.items():
setattr(o, key, value)
return o
if should_set_object_fields:
for key, value in json_dict.items():
setattr(entity, key, value)
return entity

@staticmethod
def get_field_names(object_type: Type[_T]) -> List[str]:
Expand Down Expand Up @@ -492,11 +492,11 @@ def convert_json_dict_to_object(
return _DynamicStructure(**json_dict)

if nested_object_types is None:
return Utils.initialize_object(json_dict, object_type, True)
return Utils.initialize_object(json_dict, object_type)

entity = _DynamicStructure(**json_dict)
entity.__class__ = object_type
entity = Utils.initialize_object(json_dict, object_type, True)
entity = Utils.initialize_object(json_dict, object_type)
if nested_object_types:
Utils.fill_with_nested_object_types(entity, nested_object_types)
Utils.deep_convert_to_snake_case(entity)
Expand Down Expand Up @@ -599,31 +599,13 @@ def convert_to_entity(
return entity, metadata, original_document

@staticmethod
def make_initialize_dict(document, entity_init, convert_to_snake_case=None):
"""
This method will create an entity from document
In case convert_to_snake_case will convert document keys to snake_case
convert_to_snake_case can be dictionary with special words you can change ex. From -> from_date
"""
if convert_to_snake_case:
convert_to_snake_case = {} if convert_to_snake_case is True else convert_to_snake_case
try:
converted_document = {}
for key in document:
converted_key = convert_to_snake_case.get(key, key)
converted_document[converted_key if key == "Id" else Utils.convert_to_snake_case(converted_key)] = (
document[key]
)
document = converted_document
except:
pass

if entity_init is None:
return document

def create_initialize_kwargs(
document: Dict[str, Any],
object_init_method: Callable[[Dict[str, Any]], None],
) -> Dict[str, Any]:
set_needed = False
entity_initialize_dict = {}
args, __, keywords, defaults, _, _, _ = inspect.getfullargspec(entity_init)
args, __, keywords, defaults, _, _, _ = inspect.getfullargspec(object_init_method)
if (len(args) - 1) > len(document):
remainder = len(args)
if defaults:
Expand Down

0 comments on commit b1c1357

Please sign in to comment.