From d08fa219965196c62f5906373ba535ac5bd858de Mon Sep 17 00:00:00 2001 From: Emily Date: Tue, 31 Oct 2023 11:31:14 +0100 Subject: [PATCH] feat: add total count to search responses --- .changeset/tidy-readers-dance.md | 5 ++ packages/client/src/schema/repository.ts | 22 ++++-- packages/client/src/search/index.ts | 79 ++++++++++--------- test/integration/search.test.ts | 96 +++++++++++++++--------- test/integration/vectorSearch.test.ts | 26 +++++-- 5 files changed, 141 insertions(+), 87 deletions(-) create mode 100644 .changeset/tidy-readers-dance.md diff --git a/.changeset/tidy-readers-dance.md b/.changeset/tidy-readers-dance.md new file mode 100644 index 000000000..8a8cbc255 --- /dev/null +++ b/.changeset/tidy-readers-dance.md @@ -0,0 +1,5 @@ +--- +'@xata.io/client': minor +--- + +Add support for totalCount on search responses diff --git a/packages/client/src/schema/repository.ts b/packages/client/src/schema/repository.ts index be956d6c3..6788a9981 100644 --- a/packages/client/src/schema/repository.ts +++ b/packages/client/src/schema/repository.ts @@ -27,7 +27,7 @@ import { TransactionOperation } from '../api/schemas'; import { XataPluginOptions } from '../plugins'; -import { SearchXataRecord } from '../search'; +import { SearchXataRecord, TotalCount } from '../search'; import { Boosters } from '../search/boosters'; import { TargetColumn } from '../search/target'; import { chunk, compact, isDefined, isNumber, isObject, isString, promiseMap } from '../util/lang'; @@ -748,7 +748,7 @@ export abstract class Repository extends Query< page?: SearchPageConfig; target?: TargetColumn[]; } - ): Promise>[]>; + ): Promise<{ records: SearchXataRecord>[] } & TotalCount>; /** * Search for vectors in the table. @@ -777,7 +777,7 @@ export abstract class Repository extends Query< size?: number; filter?: Filter; } - ): Promise>[]>; + ): Promise<{ records: SearchXataRecord>[] } & TotalCount>; /** * Aggregates records in the table. @@ -1761,7 +1761,7 @@ export class RestRepository } = {} ) { return this.#trace('search', async () => { - const { records } = await searchTable({ + const { records, totalCount } = await searchTable({ pathParams: { workspace: '{workspaceId}', dbBranchName: '{dbBranch}', @@ -1784,7 +1784,10 @@ export class RestRepository const schemaTables = await this.#getSchemaTables(); // TODO - Column selection not supported by search endpoint yet - return records.map((item) => initObject(this.#db, schemaTables, this.#table, item, ['*'])) as any; + return { + records: records.map((item) => initObject(this.#db, schemaTables, this.#table, item, ['*'])) as any, + totalCount + }; }); } @@ -1798,9 +1801,9 @@ export class RestRepository filter?: Filter | undefined; } | undefined - ): Promise>[]> { + ): Promise<{ records: SearchXataRecord>[] } & TotalCount> { return this.#trace('vectorSearch', async () => { - const { records } = await vectorSearchTable({ + const { records, totalCount } = await vectorSearchTable({ pathParams: { workspace: '{workspaceId}', dbBranchName: '{dbBranch}', @@ -1820,7 +1823,10 @@ export class RestRepository const schemaTables = await this.#getSchemaTables(); // TODO - Column selection not supported by search endpoint yet - return records.map((item) => initObject(this.#db, schemaTables, this.#table, item, ['*'])) as any; + return { + records: records.map((item) => initObject(this.#db, schemaTables, this.#table, item, ['*'])), + totalCount + } as any; }); } diff --git a/packages/client/src/search/index.ts b/packages/client/src/search/index.ts index 4f05c030c..10a7dbd02 100644 --- a/packages/client/src/search/index.ts +++ b/packages/client/src/search/index.ts @@ -1,4 +1,4 @@ -import { getBranchDetails, searchBranch } from '../api'; +import { Responses, getBranchDetails, searchBranch } from '../api'; import { FuzzinessExpression, HighlightExpression, PrefixExpression, SearchPageConfig, Table } from '../api/schemas'; import { XataPlugin, XataPluginOptions } from '../plugins'; import { SchemaPluginResult } from '../schema'; @@ -28,32 +28,38 @@ export type SearchOptions, Tables exten page?: SearchPageConfig; }; +export type TotalCount = Pick; + export type SearchPluginResult> = { all: >( query: string, options?: SearchOptions ) => Promise< - Values<{ + TotalCount & { + records: Values<{ + [Model in ExtractTables< + Schemas, + Tables, + GetArrayInnerType['tables']>> + >]: { + table: Model; + record: Awaited>>; + }; + }>[]; + } + >; + byTable: >( + query: string, + options?: SearchOptions + ) => Promise< + TotalCount & { [Model in ExtractTables< Schemas, Tables, GetArrayInnerType['tables']>> - >]: { - table: Model; - record: Awaited>>; - }; - }>[] + >]?: Awaited>[]>; + } >; - byTable: >( - query: string, - options?: SearchOptions - ) => Promise<{ - [Model in ExtractTables< - Schemas, - Tables, - GetArrayInnerType['tables']>> - >]?: Awaited>[]>; - }>; }; export class SearchPlugin> extends XataPlugin { @@ -67,32 +73,37 @@ export class SearchPlugin> extends Xa build(pluginOptions: XataPluginOptions): SearchPluginResult { return { all: async >(query: string, options: SearchOptions = {}) => { - const records = await this.#search(query, options, pluginOptions); + const { records, totalCount } = await this.#search(query, options, pluginOptions); const schemaTables = await this.#getSchemaTables(pluginOptions); - return records.map((record) => { - const { table = 'orphan' } = record.xata; - - // TODO: Search endpoint doesn't support column selection - return { table, record: initObject(this.db, schemaTables, table, record, ['*']) } as any; - }); + return { + totalCount, + records: records.map((record) => { + const { table = 'orphan' } = record.xata; + // TODO: Search endpoint doesn't support column selection + return { table, record: initObject(this.db, schemaTables, table, record, ['*']) } as any; + }) + }; }, byTable: async >( query: string, options: SearchOptions = {} ) => { - const records = await this.#search(query, options, pluginOptions); + const { records, totalCount } = await this.#search(query, options, pluginOptions); const schemaTables = await this.#getSchemaTables(pluginOptions); - return records.reduce((acc, record) => { - const { table = 'orphan' } = record.xata; + return records.reduce( + (acc, record) => { + const { table = 'orphan' } = record.xata; - const items = acc[table] ?? []; - // TODO: Search endpoint doesn't support column selection - const item = initObject(this.db, schemaTables, table, record, ['*']); + const items = acc[table] ?? []; + // TODO: Search endpoint doesn't support column selection + const item = initObject(this.db, schemaTables, table, record, ['*']); - return { ...acc, [table]: [...items, item] }; - }, {} as any); + return { ...acc, [table]: [...items, item] }; + }, + { totalCount } as any + ); } }; } @@ -104,14 +115,14 @@ export class SearchPlugin> extends Xa ) { const { tables, fuzziness, highlight, prefix, page } = options ?? {}; - const { records } = await searchBranch({ + const { records, totalCount } = await searchBranch({ pathParams: { workspace: '{workspaceId}', dbBranchName: '{dbBranch}', region: '{region}' }, // @ts-ignore https://github.com/xataio/client-ts/issues/313 body: { tables, query, fuzziness, prefix, highlight, page }, ...pluginOptions }); - return records; + return { records, totalCount }; } async #getSchemaTables(pluginOptions: XataPluginOptions): Promise { diff --git a/test/integration/search.test.ts b/test/integration/search.test.ts index 35a73d6aa..d04cb0c6a 100644 --- a/test/integration/search.test.ts +++ b/test/integration/search.test.ts @@ -60,32 +60,35 @@ afterEach(async (ctx) => { describe('search', () => { test('search in table', async () => { - const owners = await xata.db.users.search('Owner'); - expect(owners.length).toBeGreaterThan(0); - - expect(owners.length).toBe(2); - expect(owners[0].id).toBeDefined(); - expect(owners[0].full_name?.includes('Owner')).toBeTruthy(); - expect(owners[0].read).toBeDefined(); - expect(owners[0].getMetadata().score).toBeDefined(); - expect(owners[0].getMetadata().table).toBe('users'); + const { records, totalCount } = await xata.db.users.search('Owner'); + expect(totalCount).toBe(2); + expect(records.length).toBeGreaterThan(0); + + expect(records.length).toBe(2); + expect(records[0].id).toBeDefined(); + expect(records[0].full_name?.includes('Owner')).toBeTruthy(); + expect(records[0].read).toBeDefined(); + expect(records[0].getMetadata().score).toBeDefined(); + expect(records[0].getMetadata().table).toBe('users'); }); test('search in table with filtering', async () => { - const owners = await xata.db.users.search('Owner', { + const { records, totalCount } = await xata.db.users.search('Owner', { filter: { full_name: 'Owner of team animals' } }); - expect(owners.length).toBe(1); - expect(owners[0].id).toBeDefined(); - expect(owners[0].full_name?.includes('Owner of team animals')).toBeTruthy(); - expect(owners[0].read).toBeDefined(); - expect(owners[0].getMetadata().score).toBeDefined(); + expect(totalCount).toBe(1); + expect(records.length).toBe(1); + expect(records[0].id).toBeDefined(); + expect(records[0].full_name?.includes('Owner of team animals')).toBeTruthy(); + expect(records[0].read).toBeDefined(); + expect(records[0].getMetadata().score).toBeDefined(); }); test('search by tables with multiple tables', async () => { - const { users = [], teams = [] } = await xata.search.byTable('fruits', { tables: ['teams', 'users'] }); + const { users = [], teams = [], totalCount } = await xata.search.byTable('fruits', { tables: ['teams', 'users'] }); + expect(totalCount).toBeGreaterThan(0); expect(users.length).toBeGreaterThan(0); expect(teams.length).toBeGreaterThan(0); @@ -101,8 +104,9 @@ describe('search', () => { }); test('search by table with all tables', async () => { - const { users = [], teams = [] } = await xata.search.byTable('fruits'); + const { users = [], teams = [], totalCount } = await xata.search.byTable('fruits'); + expect(totalCount).toBeGreaterThan(0); expect(users.length).toBeGreaterThan(0); expect(teams.length).toBeGreaterThan(0); @@ -118,9 +122,11 @@ describe('search', () => { }); test('search all with multiple tables', async () => { - const results = await xata.search.all('fruits', { tables: ['teams', 'users'] }); + const { records, totalCount } = await xata.search.all('fruits', { tables: ['teams', 'users'] }); + expect(records).toBeDefined(); - for (const result of results) { + expect(totalCount).toBeGreaterThan(0); + for (const result of records) { if (result.table === 'teams') { expect(result.record.id).toBeDefined(); expect(result.record.read).toBeDefined(); @@ -138,9 +144,11 @@ describe('search', () => { }); test('search all with one table', async () => { - const results = await xata.search.all('fruits', { tables: ['teams'] }); + const { records, totalCount } = await xata.search.all('fruits', { tables: ['teams'] }); + expect(records).toBeDefined(); - for (const result of results) { + expect(totalCount).toBeGreaterThan(0); + for (const result of records) { expect(result.record.id).toBeDefined(); expect(result.record.read).toBeDefined(); expect(result.record.name?.includes('fruits')).toBeTruthy(); @@ -152,9 +160,11 @@ describe('search', () => { }); test('search all with all tables', async () => { - const results = await xata.search.all('fruits'); + const { records, totalCount } = await xata.search.all('fruits'); + expect(records).toBeDefined(); - for (const result of results) { + expect(totalCount).toBeGreaterThan(0); + for (const result of records) { if (result.table === 'teams') { expect(result.record.id).toBeDefined(); expect(result.record.read).toBeDefined(); @@ -175,26 +185,33 @@ describe('search', () => { }); test('search all with filters', async () => { - const results = await xata.search.all('fruits', { + const { records, totalCount } = await xata.search.all('fruits', { tables: [{ table: 'teams', filter: { name: 'Team fruits' } }] }); + expect(records).toBeDefined(); - expect(results.length).toBe(1); - expect(results[0].table).toBe('teams'); + expect(totalCount).toBe(1); + expect(records.length).toBe(1); + expect(records[0].table).toBe('teams'); - if (results[0].table === 'teams') { - expect(results[0].record.id).toBeDefined(); - expect(results[0].record.read).toBeDefined(); - expect(results[0].record.name?.includes('fruits')).toBeTruthy(); - expect(results[0].record.getMetadata().score).toBeDefined(); + if (records[0].table === 'teams') { + expect(records[0].record.id).toBeDefined(); + expect(records[0].record.read).toBeDefined(); + expect(records[0].record.name?.includes('fruits')).toBeTruthy(); + expect(records[0].record.getMetadata().score).toBeDefined(); } }); test('search with page and offset', async () => { - const owners = await xata.db.users.search('Owner'); - const page1 = await xata.db.users.search('Owner', { page: { size: 1 } }); - const page2 = await xata.db.users.search('Owner', { page: { size: 1, offset: 1 } }); + const { records: owners, totalCount } = await xata.db.users.search('Owner'); + const { records: page1, totalCount: page1Count } = await xata.db.users.search('Owner', { page: { size: 1 } }); + const { records: page2, totalCount: page2Count } = await xata.db.users.search('Owner', { + page: { size: 1, offset: 1 } + }); + expect(totalCount).toBe(2); + expect(page1Count).toBe(2); + expect(page2Count).toBe(2); expect(page1.length).toBe(1); expect(page2.length).toBe(1); @@ -205,10 +222,15 @@ describe('search', () => { }); test('global search with page and offset', async () => { - const { users: owners = [] } = await xata.search.byTable('Owner'); - const { users: page1 = [] } = await xata.search.byTable('Owner', { page: { size: 1 } }); - const { users: page2 = [] } = await xata.search.byTable('Owner', { page: { size: 1, offset: 1 } }); + const { users: owners = [], totalCount } = await xata.search.byTable('Owner'); + const { users: page1 = [], totalCount: page1Count } = await xata.search.byTable('Owner', { page: { size: 1 } }); + const { users: page2 = [], totalCount: page2Count } = await xata.search.byTable('Owner', { + page: { size: 1, offset: 1 } + }); + expect(totalCount).toBe(2); + expect(page1Count).toBe(2); + expect(page2Count).toBe(2); expect(page1.length).toBe(1); expect(page2.length).toBe(1); diff --git a/test/integration/vectorSearch.test.ts b/test/integration/vectorSearch.test.ts index be337422c..45217146a 100644 --- a/test/integration/vectorSearch.test.ts +++ b/test/integration/vectorSearch.test.ts @@ -39,59 +39,69 @@ afterEach(async (ctx) => { describe('search', () => { test('search 1 2 3 4', async () => { - const results = await xata.db.users.vectorSearch('vector', [1, 2, 3, 4]); + const { records: results, totalCount } = await xata.db.users.vectorSearch('vector', [1, 2, 3, 4]); + expect(totalCount).toEqual(4); expect(results.map((r) => r.full_name)).toEqual(['r4', 'r1', 'r2', 'r3']); }); test('search 0.4 0.3 0.2 0.1', async () => { - const results = await xata.db.users.vectorSearch('vector', [0.4, 0.3, 0.2, 0.1]); + const { records: results, totalCount } = await xata.db.users.vectorSearch('vector', [0.4, 0.3, 0.2, 0.1]); + expect(totalCount).toEqual(4); expect(results.map((r) => r.full_name)).toEqual(['r2', 'r3', 'r4', 'r1']); }); test('with size', async () => { - const results = await xata.db.users.vectorSearch('vector', [1, 2, 3, 4], { size: 2 }); + const { records: results, totalCount } = await xata.db.users.vectorSearch('vector', [1, 2, 3, 4], { size: 2 }); + expect(totalCount).toEqual(4); expect(results.map((r) => r.full_name)).toEqual(['r4', 'r1']); }); test('with filter', async () => { - const results = await xata.db.users.vectorSearch('vector', [1, 2, 3, 4], { + const { records: results, totalCount } = await xata.db.users.vectorSearch('vector', [1, 2, 3, 4], { filter: { full_name: { $any: ['r3', 'r4'] } } }); + expect(totalCount).toEqual(2); expect(results.map((r) => r.full_name)).toEqual(['r4', 'r3']); }); test('euclidean', async () => { - const results = await xata.db.users.vectorSearch('vector', [1, 2, 3, 4], { similarityFunction: 'l1' }); + const { records: results, totalCount } = await xata.db.users.vectorSearch('vector', [1, 2, 3, 4], { + similarityFunction: 'l1' + }); + expect(totalCount).toEqual(4); expect(results.map((r) => r.full_name)).toEqual(['r4', 'r2', 'r1', 'r3']); }); test('larger size', async () => { - const results = await xata.db.users.vectorSearch('vector', [1, 2, 3, 4], { size: 100 }); + const { records: results, totalCount } = await xata.db.users.vectorSearch('vector', [1, 2, 3, 4], { size: 100 }); + expect(totalCount).toEqual(4); expect(results.map((r) => r.full_name)).toEqual(['r4', 'r1', 'r2', 'r3']); }); test('with filter and size', async () => { - const results = await xata.db.users.vectorSearch('vector', [1, 2, 3, 4], { + const { records: results, totalCount } = await xata.db.users.vectorSearch('vector', [1, 2, 3, 4], { filter: { full_name: { $any: ['r3', 'r4'] } }, size: 1 }); + expect(totalCount).toEqual(2); expect(results.map((r) => r.full_name)).toEqual(['r4']); }); test('with filter and size and spaceFunction', async () => { - const results = await xata.db.users.vectorSearch('vector', [1, 2, 3, 4], { + const { records: results, totalCount } = await xata.db.users.vectorSearch('vector', [1, 2, 3, 4], { filter: { full_name: { $any: ['r3', 'r4'] } }, size: 1, similarityFunction: 'l1' }); + expect(totalCount).toEqual(2); expect(results.map((r) => r.full_name)).toEqual(['r4']); }); });