Skip to content

Commit

Permalink
uses model_validate and improves signature and comments for cache module
Browse files Browse the repository at this point in the history
  • Loading branch information
WolfgangFahl committed Mar 17, 2024
1 parent 2463539 commit 37fb4ca
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 15 deletions.
69 changes: 58 additions & 11 deletions ceurws/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
@author: wf
"""
from lodstorage.sparql import SPARQL
from lodstorage.query import QueryManager
from sqlmodel import Session, create_engine, select
from ngwidgets.profiler import Profiler
from typing import Any,Dict,List,Type

class SqlDB:
"""
Expand All @@ -17,18 +19,36 @@ def __init__(self, sqlite_file_path: str,debug:bool=False):
connect_args = {"check_same_thread": False}
self.engine = create_engine(sqlite_url, echo=debug, connect_args=connect_args)

def get_session(self):
# Provide a session for database operations
def get_session(self) -> Session:
"""
Provide a session for database operations.
Returns:
Session: A SQLAlchemy Session object bound to the engine for database operations.
"""
return Session(bind=self.engine)

class Cached:
"""
Manage cached entities.
"""

def __init__(self, clazz, sparql, sql_db, query_name: str, debug:bool=False):
def __init__(self,
clazz: Type[Any],
sparql: SPARQL,
sql_db: str,
query_name: str,
debug: bool = False):
"""
Initializes the Manager with the given endpoint, cache name, and query name.
Initializes the Manager with class reference, SPARQL endpoint URL, SQL database connection string,
query name, and an optional debug flag.
Args:
clazz (Type[Any]): The class reference for the type of objects managed by this manager.
sparql (SPARQL): a SPARQL endpoint.
sql_db (str): The connection string for the SQL database.
query_name (str): The name of the query to be executed.
debug (bool, optional): Flag to enable debug mode. Defaults to False.
"""
self.clazz = clazz
self.sparql = sparql
Expand All @@ -38,19 +58,29 @@ def __init__(self, clazz, sparql, sql_db, query_name: str, debug:bool=False):
# Ensure the table for the class exists
clazz.metadata.create_all(self.sql_db.engine)

def fetch_or_query(self, qm: QueryManager):
def fetch_or_query(self, qm, force_query=False):
"""
Fetches data from the local cache if available; otherwise, queries via SPARQL and caches the results.
Fetches data from the local cache if available.
If the data is not in the cache or if force_query is True,
it queries via SPARQL and caches the results.
Args:
qm (QueryManager): The query manager object used for making SPARQL queries.
force_query (bool, optional): A flag to force querying via SPARQL even if the data exists in the local cache. Defaults to False.
"""
if self.check_local_cache():
if not force_query and self.check_local_cache():
self.fetch_from_local()
else:
self.get_lod(qm)
self.store()


def check_local_cache(self) -> bool:
"""
Checks if there is data in the local cache (SQL database).
Returns:
bool: True if there is at least one record in the local SQL cache table
"""
with self.sql_db.get_session() as session:
result = session.exec(select(self.clazz)).first()
Expand All @@ -68,24 +98,41 @@ def fetch_from_local(self):
print(f"Loaded {len(self.entities)} records from local cache")
profiler.time()

def get_lod(self, qm: QueryManager):
def get_lod(self, qm: QueryManager) -> List[Dict]:
"""
Fetches data using the SPARQL query.
Fetches data using the SPARQL query specified by my query_name.
Args:
qm (QueryManager): The query manager object used for making SPARQL queries.
Returns:
List[Dict]: A list of dictionaries representing the data fetched.
"""
profiler = Profiler(f"fetch {self.query_name} from SPARQL endpoint {self.sparql.url}", profile=self.debug)
query = qm.queriesByName[self.query_name]
self.lod = self.sparql.queryAsListOfDicts(query.query)
profiler.time()
if self.debug:
print(f"Found {len(self.lod)} records for {self.query_name}")
return self.lod

def store(self):
def store(self)->List[Any]:
"""
Stores the fetched data into the local SQL database.
Returns:
List[Any]: A list of entity instances that were stored in the database.
"""
profiler = Profiler(f"store {self.query_name}", profile=self.debug)
self.entities = [self.clazz.parse_obj(record) for record in self.lod]
self.entities = []
for record in self.lod:
entity=self.clazz.model_validate(record)
self.entities.append(entity)
with self.sql_db.get_session() as session:
session.add_all(self.entities)
session.commit()
if self.debug:
print(f"Stored {len(self.entities)} records in local cache")
profiler.time()
return self.entities
9 changes: 5 additions & 4 deletions tests/test_dblp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ def test_dblp_caches(self):
from ceurws.models.dblp2 import Paper,Scholar, Proceeding, Authorship, Editorship

caches=[
Cached(Editorship, self.sparql, sql_db=self.sql_db, query_name="CEUR-WS-Editorship", debug=self.debug), # Assuming this query exists
Cached(Authorship, self.sparql, sql_db=self.sql_db, query_name="CEUR-WS-Authorship", debug=self.debug), # Assuming this query exists
Cached(Proceeding,self.sparql,sql_db=self.sql_db,query_name="CEUR-WS all Volumes",debug=self.debug),
Cached(Scholar,self.sparql,sql_db=self.sql_db,query_name="CEUR-WS-Scholars",debug=self.debug),
Cached(Paper,self.sparql,sql_db=self.sql_db,query_name="CEUR-WS-Papers",debug=self.debug)
Cached(Paper,self.sparql,sql_db=self.sql_db,query_name="CEUR-WS-Papers",debug=self.debug),
Cached(Editorship, self.sparql, sql_db=self.sql_db, query_name="CEUR-WS-Editorship", debug=self.debug),
Cached(Authorship, self.sparql, sql_db=self.sql_db, query_name="CEUR-WS-Authorship", debug=self.debug)
]
force_query=True
for cache in caches:
cache.fetch_or_query(self.qm)
cache.fetch_or_query(self.qm,force_query=force_query)
#paper_cache.get_lod(self.qm)
#paper_cache.store()

0 comments on commit 37fb4ca

Please sign in to comment.