Skip to content

Commit

Permalink
Merge pull request #260 from topoteretes/COG-505-data-dataset-model-c…
Browse files Browse the repository at this point in the history
…hanges

Cog 505 data dataset model changes
  • Loading branch information
dexters1 authored Dec 6, 2024
2 parents 348610e + d7fa9f3 commit 8415279
Show file tree
Hide file tree
Showing 21 changed files with 344 additions and 34 deletions.
69 changes: 69 additions & 0 deletions .github/workflows/test_deduplication.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
name: test | deduplication

on:
workflow_dispatch:
pull_request:
branches:
- main
types: [labeled, synchronize]


concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true

env:
RUNTIME__LOG_LEVEL: ERROR

jobs:
get_docs_changes:
name: docs changes
uses: ./.github/workflows/get_docs_changes.yml

run_deduplication_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
defaults:
run:
shell: bash
services:
postgres:
image: pgvector/pgvector:pg17
env:
POSTGRES_USER: cognee
POSTGRES_PASSWORD: cognee
POSTGRES_DB: cognee_db
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432

steps:
- name: Check out
uses: actions/checkout@master

- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11.x'

- name: Install Poetry
uses: snok/[email protected]
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true

- name: Install dependencies
run: poetry install -E postgres --no-interaction

- name: Run deduplication test
env:
ENV: 'dev'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: poetry run python ./cognee/tests/test_deduplication.py
8 changes: 4 additions & 4 deletions cognee/api/v1/add/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,25 @@ async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_nam

# data is text
else:
file_path = save_data_to_file(data, dataset_name)
file_path = save_data_to_file(data)
return await add([file_path], dataset_name)

if hasattr(data, "file"):
file_path = save_data_to_file(data.file, dataset_name, filename = data.filename)
file_path = save_data_to_file(data.file, filename = data.filename)
return await add([file_path], dataset_name)

# data is a list of file paths or texts
file_paths = []

for data_item in data:
if hasattr(data_item, "file"):
file_paths.append(save_data_to_file(data_item, dataset_name, filename = data_item.filename))
file_paths.append(save_data_to_file(data_item, filename = data_item.filename))
elif isinstance(data_item, str) and (
data_item.startswith("/") or data_item.startswith("file://")
):
file_paths.append(data_item)
elif isinstance(data_item, str):
file_paths.append(save_data_to_file(data_item, dataset_name))
file_paths.append(save_data_to_file(data_item))

if len(file_paths) > 0:
return await add_files(file_paths, dataset_name, user)
Expand Down
7 changes: 7 additions & 0 deletions cognee/infrastructure/files/utils/get_file_metadata.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
from typing import BinaryIO, TypedDict
import hashlib
from .guess_file_type import guess_file_type
from cognee.shared.utils import get_file_content_hash


class FileMetadata(TypedDict):
name: str
file_path: str
mime_type: str
extension: str
content_hash: str

def get_file_metadata(file: BinaryIO) -> FileMetadata:
"""Get metadata from a file"""
file.seek(0)
content_hash = get_file_content_hash(file)
file.seek(0)

file_type = guess_file_type(file)

file_path = file.name
Expand All @@ -21,4 +27,5 @@ def get_file_metadata(file: BinaryIO) -> FileMetadata:
file_path = file_path,
mime_type = file_type.mime,
extension = file_type.extension,
content_hash = content_hash,
)
3 changes: 2 additions & 1 deletion cognee/modules/data/models/Data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from datetime import datetime, timezone
from typing import List
from uuid import uuid4

from sqlalchemy import UUID, Column, DateTime, String
from sqlalchemy.orm import Mapped, relationship

Expand All @@ -19,6 +18,8 @@ class Data(Base):
extension = Column(String)
mime_type = Column(String)
raw_data_location = Column(String)
owner_id = Column(UUID, index=True)
content_hash = Column(String)
created_at = Column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
)
Expand Down
1 change: 0 additions & 1 deletion cognee/modules/data/operations/write_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
import re
import warnings
from typing import Any
from uuid import UUID
from sqlalchemy import select
from typing import Any, BinaryIO, Union
Expand Down
2 changes: 1 addition & 1 deletion cognee/modules/ingestion/data_types/BinaryData.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, data: BinaryIO, name: str = None):
def get_identifier(self):
metadata = self.get_metadata()

return self.name + "." + metadata["extension"]
return metadata["content_hash"]

def get_metadata(self):
self.ensure_metadata()
Expand Down
10 changes: 7 additions & 3 deletions cognee/modules/ingestion/identify.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from uuid import uuid5, NAMESPACE_OID
from .data_types import IngestionData

def identify(data: IngestionData) -> str:
data_id: str = data.get_identifier()
from cognee.modules.users.models import User

return uuid5(NAMESPACE_OID, data_id)

def identify(data: IngestionData, user: User) -> str:
data_content_hash: str = data.get_identifier()

# return UUID hash of file contents + owner id
return uuid5(NAMESPACE_OID, f"{data_content_hash}{user.id}")
19 changes: 11 additions & 8 deletions cognee/modules/ingestion/save_data_to_file.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
import string
import random
import os.path
import hashlib
from typing import BinaryIO, Union
from cognee.base_config import get_base_config
from cognee.infrastructure.files.storage import LocalStorage
from .classify import classify

def save_data_to_file(data: Union[str, BinaryIO], dataset_name: str, filename: str = None):
def save_data_to_file(data: Union[str, BinaryIO], filename: str = None):
base_config = get_base_config()
data_directory_path = base_config.data_root_directory

classified_data = classify(data, filename)

storage_path = data_directory_path + "/" + dataset_name.replace(".", "/")
storage_path = os.path.join(data_directory_path, "data")
LocalStorage.ensure_directory_exists(storage_path)

file_metadata = classified_data.get_metadata()
if "name" not in file_metadata or file_metadata["name"] is None:
letters = string.ascii_lowercase
random_string = "".join(random.choice(letters) for _ in range(32))
file_metadata["name"] = "text_" + random_string + ".txt"
data_contents = classified_data.get_data().encode('utf-8')
hash_contents = hashlib.md5(data_contents).hexdigest()
file_metadata["name"] = "text_" + hash_contents + ".txt"
file_name = file_metadata["name"]
LocalStorage(storage_path).store(file_name, classified_data.get_data())

# Don't save file if it already exists
if not os.path.isfile(os.path.join(storage_path, file_name)):
LocalStorage(storage_path).store(file_name, classified_data.get_data())

return "file://" + storage_path + "/" + file_name
9 changes: 9 additions & 0 deletions cognee/shared/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
Custom exceptions for the Cognee API.
This module defines a set of exceptions for handling various shared utility errors
"""

from .exceptions import (
IngestionError,
)
11 changes: 11 additions & 0 deletions cognee/shared/exceptions/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from cognee.exceptions import CogneeApiError
from fastapi import status

class IngestionError(CogneeApiError):
def __init__(
self,
message: str = "Failed to load data.",
name: str = "IngestionError",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
):
super().__init__(message, name, status_code)
28 changes: 28 additions & 0 deletions cognee/shared/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
""" This module contains utility functions for the cognee. """
import os
from typing import BinaryIO, Union

import requests
import hashlib
from datetime import datetime, timezone
import graphistry
import networkx as nx
Expand All @@ -16,6 +19,8 @@
from uuid import uuid4
import pathlib

from cognee.shared.exceptions import IngestionError

# Analytics Proxy Url, currently hosted by Vercel
proxy_url = "https://test.prometh.ai"

Expand Down Expand Up @@ -70,6 +75,29 @@ def num_tokens_from_string(string: str, encoding_name: str) -> int:
num_tokens = len(encoding.encode(string))
return num_tokens

def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
h = hashlib.md5()

try:
if isinstance(file_obj, str):
with open(file_obj, 'rb') as file:
while True:
# Reading is buffered, so we can read smaller chunks.
chunk = file.read(h.block_size)
if not chunk:
break
h.update(chunk)
else:
while True:
# Reading is buffered, so we can read smaller chunks.
chunk = file_obj.read(h.block_size)
if not chunk:
break
h.update(chunk)

return h.hexdigest()
except IOError as e:
raise IngestionError(message=f"Failed to load data from {file}: {e}")

def trim_text_to_max_tokens(text: str, max_tokens: int, encoding_name: str) -> str:
"""
Expand Down
37 changes: 27 additions & 10 deletions cognee/tasks/ingestion/ingest_data_with_metadata.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any
from typing import Any, List

import dlt
import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.methods import create_dataset
from cognee.modules.data.operations.delete_metadata import delete_metadata
from cognee.modules.data.models.DatasetData import DatasetData
from cognee.modules.users.models import User
from cognee.modules.users.permissions.methods import give_permission_on_document
from cognee.shared.utils import send_telemetry
Expand All @@ -23,19 +23,21 @@ async def ingest_data_with_metadata(data: Any, dataset_name: str, user: User):
destination = destination,
)

@dlt.resource(standalone=True, merge_key="id")
async def data_resources(file_paths: str):
@dlt.resource(standalone=True, primary_key="id", merge_key="id")
async def data_resources(file_paths: List[str], user: User):
for file_path in file_paths:
with open(file_path.replace("file://", ""), mode="rb") as file:
classified_data = ingestion.classify(file)
data_id = ingestion.identify(classified_data)
data_id = ingestion.identify(classified_data, user)
file_metadata = classified_data.get_metadata()
yield {
"id": data_id,
"name": file_metadata["name"],
"file_path": file_metadata["file_path"],
"extension": file_metadata["extension"],
"mime_type": file_metadata["mime_type"],
"content_hash": file_metadata["content_hash"],
"owner_id": str(user.id),
}

async def data_storing(data: Any, dataset_name: str, user: User):
Expand All @@ -57,7 +59,8 @@ async def data_storing(data: Any, dataset_name: str, user: User):
with open(file_path.replace("file://", ""), mode = "rb") as file:
classified_data = ingestion.classify(file)

data_id = ingestion.identify(classified_data)
# data_id is the hash of file contents + owner id to avoid duplicate data
data_id = ingestion.identify(classified_data, user)

file_metadata = classified_data.get_metadata()

Expand All @@ -70,6 +73,7 @@ async def data_storing(data: Any, dataset_name: str, user: User):
async with db_engine.get_async_session() as session:
dataset = await create_dataset(dataset_name, user.id, session)

# Check to see if data should be updated
data_point = (
await session.execute(select(Data).filter(Data.id == data_id))
).scalar_one_or_none()
Expand All @@ -79,17 +83,29 @@ async def data_storing(data: Any, dataset_name: str, user: User):
data_point.raw_data_location = file_metadata["file_path"]
data_point.extension = file_metadata["extension"]
data_point.mime_type = file_metadata["mime_type"]
data_point.owner_id = user.id
data_point.content_hash = file_metadata["content_hash"]
await session.merge(data_point)
else:
data_point = Data(
id = data_id,
name = file_metadata["name"],
raw_data_location = file_metadata["file_path"],
extension = file_metadata["extension"],
mime_type = file_metadata["mime_type"]
mime_type = file_metadata["mime_type"],
owner_id = user.id,
content_hash = file_metadata["content_hash"],
)

# Check if data is already in dataset
dataset_data = (
await session.execute(select(DatasetData).filter(DatasetData.data_id == data_id,
DatasetData.dataset_id == dataset.id))
).scalar_one_or_none()
# If data is not present in dataset add it
if dataset_data is None:
dataset.data.append(data_point)

await session.commit()
await write_metadata(data_item, data_point.id, file_metadata)

Expand All @@ -109,16 +125,17 @@ async def data_storing(data: Any, dataset_name: str, user: User):
# To use sqlite with dlt dataset_name must be set to "main".
# Sqlite doesn't support schemas
run_info = pipeline.run(
data_resources(file_paths),
data_resources(file_paths, user),
table_name="file_metadata",
dataset_name="main",
write_disposition="merge",
)
else:
# Data should be stored in the same schema to allow deduplication
run_info = pipeline.run(
data_resources(file_paths),
data_resources(file_paths, user),
table_name="file_metadata",
dataset_name=dataset_name,
dataset_name="public",
write_disposition="merge",
)

Expand Down
Loading

0 comments on commit 8415279

Please sign in to comment.