Skip to content

Commit

Permalink
Merge pull request #825 from neo4j-contrib/824-fix-test-and-documenta…
Browse files Browse the repository at this point in the history
…tion-for-vector-index

FIx tests and doc for vector index
  • Loading branch information
mariusconjeaud authored Aug 14, 2024
2 parents 0194f3a + d78da5e commit 60a0296
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 39 deletions.
5 changes: 4 additions & 1 deletion doc/source/schema_management.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ Full example: ::
name = StringProperty(
index=True,
fulltext_index=FulltextIndex(analyzer='english', eventually_consistent=True)
)
name_embedding = ArrayProperty(
FloatProperty(),
vector_index=VectorIndex(dimensions=512, similarity_function='euclidean')
)

Expand All @@ -83,7 +86,7 @@ The following constraints are supported:
- ``unique_index=True``: This will create a uniqueness constraint on the property. Available for both nodes and relationships (Neo4j version 5.7 or higher).

.. note::
The uniquess constraint of Neo4j is not supported as such, but using ``required=True`` on a property serves the same purpose.
The uniqueness constraint of Neo4j is not supported as such, but using ``required=True`` on a property serves the same purpose.


Extracting the schema from a database
Expand Down
36 changes: 20 additions & 16 deletions test/async_/test_label_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from neo4j.exceptions import ClientError

from neomodel import (
ArrayProperty,
AsyncRelationshipTo,
AsyncStructuredNode,
AsyncStructuredRel,
FloatProperty,
FulltextIndex,
StringProperty,
UniqueIdProperty,
Expand Down Expand Up @@ -317,16 +319,17 @@ async def test_vector_index():
pytest.skip("Not supported before 5.15")

class VectorIndexNode(AsyncStructuredNode):
name = StringProperty(
vector_index=VectorIndex(dimensions=256, similarity_function="euclidean")
embedding = ArrayProperty(
FloatProperty(),
vector_index=VectorIndex(dimensions=256, similarity_function="euclidean"),
)

await adb.install_labels(VectorIndexNode)
indexes = await adb.list_indexes()
index_names = [index["name"] for index in indexes]
assert "vector_index_VectorIndexNode_name" in index_names
assert "vector_index_VectorIndexNode_embedding" in index_names

await adb.cypher_query("DROP INDEX vector_index_VectorIndexNode_name")
await adb.cypher_query("DROP INDEX vector_index_VectorIndexNode_embedding")


@mark_async_test
Expand All @@ -338,11 +341,11 @@ async def test_vector_index_conflict():

with patch("sys.stdout", new=stream):
await adb.cypher_query(
"CREATE VECTOR INDEX FOR (n:VectorIndexNodeConflict) ON n.name OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}"
"CREATE VECTOR INDEX FOR (n:VectorIndexNodeConflict) ON n.embedding OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}"
)

class VectorIndexNodeConflict(AsyncStructuredNode):
name = StringProperty(vector_index=VectorIndex())
embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex())

await adb.install_labels(VectorIndexNodeConflict, quiet=False)

Expand All @@ -361,7 +364,7 @@ async def test_vector_index_not_supported():
):

class VectorIndexNodeOld(AsyncStructuredNode):
name = StringProperty(vector_index=VectorIndex())
embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex())

await adb.install_labels(VectorIndexNodeOld)

Expand All @@ -372,8 +375,9 @@ async def test_rel_vector_index():
pytest.skip("Not supported before 5.18")

class VectorIndexRel(AsyncStructuredRel):
name = StringProperty(
vector_index=VectorIndex(dimensions=256, similarity_function="euclidean")
embedding = ArrayProperty(
FloatProperty(),
vector_index=VectorIndex(dimensions=256, similarity_function="euclidean"),
)

class VectorIndexRelNode(AsyncStructuredNode):
Expand All @@ -384,9 +388,9 @@ class VectorIndexRelNode(AsyncStructuredNode):
await adb.install_labels(VectorIndexRelNode)
indexes = await adb.list_indexes()
index_names = [index["name"] for index in indexes]
assert "vector_index_VECTOR_INDEX_REL_name" in index_names
assert "vector_index_VECTOR_INDEX_REL_embedding" in index_names

await adb.cypher_query("DROP INDEX vector_index_VECTOR_INDEX_REL_name")
await adb.cypher_query("DROP INDEX vector_index_VECTOR_INDEX_REL_embedding")


@mark_async_test
Expand All @@ -398,11 +402,11 @@ async def test_rel_vector_index_conflict():

with patch("sys.stdout", new=stream):
await adb.cypher_query(
"CREATE VECTOR INDEX FOR ()-[r:VECTOR_INDEX_REL_CONFLICT]-() ON r.name OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}"
"CREATE VECTOR INDEX FOR ()-[r:VECTOR_INDEX_REL_CONFLICT]-() ON r.embedding OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}"
)

class VectorIndexRelConflict(AsyncStructuredRel):
name = StringProperty(vector_index=VectorIndex())
embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex())

class VectorIndexRelConflictNode(AsyncStructuredNode):
has_rel = AsyncRelationshipTo(
Expand All @@ -428,7 +432,7 @@ async def test_rel_vector_index_not_supported():
):

class VectorIndexRelOld(AsyncStructuredRel):
name = StringProperty(vector_index=VectorIndex())
embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex())

class VectorIndexRelOldNode(AsyncStructuredNode):
has_rel = AsyncRelationshipTo(
Expand Down Expand Up @@ -522,7 +526,7 @@ class UnauthorizedFulltextNode(AsyncStructuredNode):
with await adb.impersonate(unauthorized_user):

class UnauthorizedVectorNode(AsyncStructuredNode):
name = StringProperty(vector_index=VectorIndex())
embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex())

await adb.install_labels(UnauthorizedVectorNode)

Expand Down Expand Up @@ -572,7 +576,7 @@ class UnauthorizedFulltextRelNode(AsyncStructuredNode):
with await adb.impersonate(unauthorized_user):

class UnauthorizedVectorRel(AsyncStructuredRel):
name = StringProperty(vector_index=VectorIndex())
embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex())

class UnauthorizedVectorRelNode(AsyncStructuredNode):
has_rel = AsyncRelationshipTo(
Expand Down
5 changes: 3 additions & 2 deletions test/async_/test_match_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ async def test_double_traverse():

results = [node async for node in qb._execute()]
assert len(results) == 2
assert results[0].name == "Decafe"
assert results[1].name == "Nescafe plus"
names = [n.name for n in results]
assert "Decafe" in names
assert "Nescafe plus" in names


@mark_async_test
Expand Down
39 changes: 21 additions & 18 deletions test/sync_/test_label_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from neo4j.exceptions import ClientError

from neomodel import (
ArrayProperty,
FloatProperty,
FulltextIndex,
RelationshipTo,
StringProperty,
Expand All @@ -26,8 +28,7 @@ class NodeWithConstraint(StructuredNode):
name = StringProperty(unique_index=True)


class NodeWithRelationship(StructuredNode):
...
class NodeWithRelationship(StructuredNode): ...


class IndexedRelationship(StructuredRel):
Expand Down Expand Up @@ -317,16 +318,17 @@ def test_vector_index():
pytest.skip("Not supported before 5.15")

class VectorIndexNode(StructuredNode):
name = StringProperty(
vector_index=VectorIndex(dimensions=256, similarity_function="euclidean")
embedding = ArrayProperty(
FloatProperty(),
vector_index=VectorIndex(dimensions=256, similarity_function="euclidean"),
)

db.install_labels(VectorIndexNode)
indexes = db.list_indexes()
index_names = [index["name"] for index in indexes]
assert "vector_index_VectorIndexNode_name" in index_names
assert "vector_index_VectorIndexNode_embedding" in index_names

db.cypher_query("DROP INDEX vector_index_VectorIndexNode_name")
db.cypher_query("DROP INDEX vector_index_VectorIndexNode_embedding")


@mark_sync_test
Expand All @@ -338,11 +340,11 @@ def test_vector_index_conflict():

with patch("sys.stdout", new=stream):
db.cypher_query(
"CREATE VECTOR INDEX FOR (n:VectorIndexNodeConflict) ON n.name OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}"
"CREATE VECTOR INDEX FOR (n:VectorIndexNodeConflict) ON n.embedding OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}"
)

class VectorIndexNodeConflict(StructuredNode):
name = StringProperty(vector_index=VectorIndex())
embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex())

db.install_labels(VectorIndexNodeConflict, quiet=False)

Expand All @@ -361,7 +363,7 @@ def test_vector_index_not_supported():
):

class VectorIndexNodeOld(StructuredNode):
name = StringProperty(vector_index=VectorIndex())
embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex())

db.install_labels(VectorIndexNodeOld)

Expand All @@ -372,8 +374,9 @@ def test_rel_vector_index():
pytest.skip("Not supported before 5.18")

class VectorIndexRel(StructuredRel):
name = StringProperty(
vector_index=VectorIndex(dimensions=256, similarity_function="euclidean")
embedding = ArrayProperty(
FloatProperty(),
vector_index=VectorIndex(dimensions=256, similarity_function="euclidean"),
)

class VectorIndexRelNode(StructuredNode):
Expand All @@ -384,9 +387,9 @@ class VectorIndexRelNode(StructuredNode):
db.install_labels(VectorIndexRelNode)
indexes = db.list_indexes()
index_names = [index["name"] for index in indexes]
assert "vector_index_VECTOR_INDEX_REL_name" in index_names
assert "vector_index_VECTOR_INDEX_REL_embedding" in index_names

db.cypher_query("DROP INDEX vector_index_VECTOR_INDEX_REL_name")
db.cypher_query("DROP INDEX vector_index_VECTOR_INDEX_REL_embedding")


@mark_sync_test
Expand All @@ -398,11 +401,11 @@ def test_rel_vector_index_conflict():

with patch("sys.stdout", new=stream):
db.cypher_query(
"CREATE VECTOR INDEX FOR ()-[r:VECTOR_INDEX_REL_CONFLICT]-() ON r.name OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}"
"CREATE VECTOR INDEX FOR ()-[r:VECTOR_INDEX_REL_CONFLICT]-() ON r.embedding OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}"
)

class VectorIndexRelConflict(StructuredRel):
name = StringProperty(vector_index=VectorIndex())
embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex())

class VectorIndexRelConflictNode(StructuredNode):
has_rel = RelationshipTo(
Expand All @@ -428,7 +431,7 @@ def test_rel_vector_index_not_supported():
):

class VectorIndexRelOld(StructuredRel):
name = StringProperty(vector_index=VectorIndex())
embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex())

class VectorIndexRelOldNode(StructuredNode):
has_rel = RelationshipTo(
Expand Down Expand Up @@ -520,7 +523,7 @@ class UnauthorizedFulltextNode(StructuredNode):
with db.impersonate(unauthorized_user):

class UnauthorizedVectorNode(StructuredNode):
name = StringProperty(vector_index=VectorIndex())
embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex())

db.install_labels(UnauthorizedVectorNode)

Expand Down Expand Up @@ -570,7 +573,7 @@ class UnauthorizedFulltextRelNode(StructuredNode):
with db.impersonate(unauthorized_user):

class UnauthorizedVectorRel(StructuredRel):
name = StringProperty(vector_index=VectorIndex())
embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex())

class UnauthorizedVectorRelNode(StructuredNode):
has_rel = RelationshipTo(
Expand Down
5 changes: 3 additions & 2 deletions test/sync_/test_match_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,9 @@ def test_double_traverse():

results = [node for node in qb._execute()]
assert len(results) == 2
assert results[0].name == "Decafe"
assert results[1].name == "Nescafe plus"
names = [n.name for n in results]
assert "Decafe" in names
assert "Nescafe plus" in names


@mark_sync_test
Expand Down

0 comments on commit 60a0296

Please sign in to comment.