Skip to content

Commit

Permalink
feat: add total count to search responses
Browse files Browse the repository at this point in the history
  • Loading branch information
eemmiillyy committed Oct 31, 2023
1 parent 48cb976 commit d08fa21
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 87 deletions.
5 changes: 5 additions & 0 deletions .changeset/tidy-readers-dance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@xata.io/client': minor
---

Add support for totalCount on search responses
22 changes: 14 additions & 8 deletions packages/client/src/schema/repository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import {
TransactionOperation
} from '../api/schemas';
import { XataPluginOptions } from '../plugins';
import { SearchXataRecord } from '../search';
import { SearchXataRecord, TotalCount } from '../search';
import { Boosters } from '../search/boosters';
import { TargetColumn } from '../search/target';
import { chunk, compact, isDefined, isNumber, isObject, isString, promiseMap } from '../util/lang';
Expand Down Expand Up @@ -748,7 +748,7 @@ export abstract class Repository<Record extends XataRecord> extends Query<
page?: SearchPageConfig;
target?: TargetColumn<Record>[];
}
): Promise<SearchXataRecord<SelectedPick<Record, ['*']>>[]>;
): Promise<{ records: SearchXataRecord<SelectedPick<Record, ['*']>>[] } & TotalCount>;

/**
* Search for vectors in the table.
Expand Down Expand Up @@ -777,7 +777,7 @@ export abstract class Repository<Record extends XataRecord> extends Query<
size?: number;
filter?: Filter<Record>;
}
): Promise<SearchXataRecord<SelectedPick<Record, ['*']>>[]>;
): Promise<{ records: SearchXataRecord<SelectedPick<Record, ['*']>>[] } & TotalCount>;

/**
* Aggregates records in the table.
Expand Down Expand Up @@ -1761,7 +1761,7 @@ export class RestRepository<Record extends XataRecord>
} = {}
) {
return this.#trace('search', async () => {
const { records } = await searchTable({
const { records, totalCount } = await searchTable({
pathParams: {
workspace: '{workspaceId}',
dbBranchName: '{dbBranch}',
Expand All @@ -1784,7 +1784,10 @@ export class RestRepository<Record extends XataRecord>
const schemaTables = await this.#getSchemaTables();

// TODO - Column selection not supported by search endpoint yet
return records.map((item) => initObject(this.#db, schemaTables, this.#table, item, ['*'])) as any;
return {
records: records.map((item) => initObject(this.#db, schemaTables, this.#table, item, ['*'])) as any,
totalCount
};
});
}

Expand All @@ -1798,9 +1801,9 @@ export class RestRepository<Record extends XataRecord>
filter?: Filter<Record> | undefined;
}
| undefined
): Promise<SearchXataRecord<SelectedPick<Record, ['*']>>[]> {
): Promise<{ records: SearchXataRecord<SelectedPick<Record, ['*']>>[] } & TotalCount> {
return this.#trace('vectorSearch', async () => {
const { records } = await vectorSearchTable({
const { records, totalCount } = await vectorSearchTable({
pathParams: {
workspace: '{workspaceId}',
dbBranchName: '{dbBranch}',
Expand All @@ -1820,7 +1823,10 @@ export class RestRepository<Record extends XataRecord>
const schemaTables = await this.#getSchemaTables();

// TODO - Column selection not supported by search endpoint yet
return records.map((item) => initObject(this.#db, schemaTables, this.#table, item, ['*'])) as any;
return {
records: records.map((item) => initObject(this.#db, schemaTables, this.#table, item, ['*'])),
totalCount
} as any;
});
}

Expand Down
79 changes: 45 additions & 34 deletions packages/client/src/search/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { getBranchDetails, searchBranch } from '../api';
import { Responses, getBranchDetails, searchBranch } from '../api';
import { FuzzinessExpression, HighlightExpression, PrefixExpression, SearchPageConfig, Table } from '../api/schemas';
import { XataPlugin, XataPluginOptions } from '../plugins';
import { SchemaPluginResult } from '../schema';
Expand Down Expand Up @@ -28,32 +28,38 @@ export type SearchOptions<Schemas extends Record<string, BaseData>, Tables exten
page?: SearchPageConfig;
};

export type TotalCount = Pick<Responses.SearchResponse, 'totalCount'>;

export type SearchPluginResult<Schemas extends Record<string, BaseData>> = {
all: <Tables extends StringKeys<Schemas>>(
query: string,
options?: SearchOptions<Schemas, Tables>
) => Promise<
Values<{
TotalCount & {
records: Values<{
[Model in ExtractTables<
Schemas,
Tables,
GetArrayInnerType<NonNullable<NonNullable<typeof options>['tables']>>
>]: {
table: Model;
record: Awaited<SearchXataRecord<SelectedPick<Schemas[Model] & XataRecord, ['*']>>>;
};
}>[];
}
>;
byTable: <Tables extends StringKeys<Schemas>>(
query: string,
options?: SearchOptions<Schemas, Tables>
) => Promise<
TotalCount & {
[Model in ExtractTables<
Schemas,
Tables,
GetArrayInnerType<NonNullable<NonNullable<typeof options>['tables']>>
>]: {
table: Model;
record: Awaited<SearchXataRecord<SelectedPick<Schemas[Model] & XataRecord, ['*']>>>;
};
}>[]
>]?: Awaited<SearchXataRecord<SelectedPick<Schemas[Model] & XataRecord, ['*']>>[]>;
}
>;
byTable: <Tables extends StringKeys<Schemas>>(
query: string,
options?: SearchOptions<Schemas, Tables>
) => Promise<{
[Model in ExtractTables<
Schemas,
Tables,
GetArrayInnerType<NonNullable<NonNullable<typeof options>['tables']>>
>]?: Awaited<SearchXataRecord<SelectedPick<Schemas[Model] & XataRecord, ['*']>>[]>;
}>;
};

export class SearchPlugin<Schemas extends Record<string, XataRecord>> extends XataPlugin {
Expand All @@ -67,32 +73,37 @@ export class SearchPlugin<Schemas extends Record<string, XataRecord>> extends Xa
build(pluginOptions: XataPluginOptions): SearchPluginResult<Schemas> {
return {
all: async <Tables extends StringKeys<Schemas>>(query: string, options: SearchOptions<Schemas, Tables> = {}) => {
const records = await this.#search(query, options, pluginOptions);
const { records, totalCount } = await this.#search(query, options, pluginOptions);
const schemaTables = await this.#getSchemaTables(pluginOptions);

return records.map((record) => {
const { table = 'orphan' } = record.xata;

// TODO: Search endpoint doesn't support column selection
return { table, record: initObject(this.db, schemaTables, table, record, ['*']) } as any;
});
return {
totalCount,
records: records.map((record) => {
const { table = 'orphan' } = record.xata;
// TODO: Search endpoint doesn't support column selection
return { table, record: initObject(this.db, schemaTables, table, record, ['*']) } as any;
})
};
},
byTable: async <Tables extends StringKeys<Schemas>>(
query: string,
options: SearchOptions<Schemas, Tables> = {}
) => {
const records = await this.#search(query, options, pluginOptions);
const { records, totalCount } = await this.#search(query, options, pluginOptions);
const schemaTables = await this.#getSchemaTables(pluginOptions);

return records.reduce((acc, record) => {
const { table = 'orphan' } = record.xata;
return records.reduce(
(acc, record) => {
const { table = 'orphan' } = record.xata;

const items = acc[table] ?? [];
// TODO: Search endpoint doesn't support column selection
const item = initObject(this.db, schemaTables, table, record, ['*']);
const items = acc[table] ?? [];
// TODO: Search endpoint doesn't support column selection
const item = initObject(this.db, schemaTables, table, record, ['*']);

return { ...acc, [table]: [...items, item] };
}, {} as any);
return { ...acc, [table]: [...items, item] };
},
{ totalCount } as any
);
}
};
}
Expand All @@ -104,14 +115,14 @@ export class SearchPlugin<Schemas extends Record<string, XataRecord>> extends Xa
) {
const { tables, fuzziness, highlight, prefix, page } = options ?? {};

const { records } = await searchBranch({
const { records, totalCount } = await searchBranch({
pathParams: { workspace: '{workspaceId}', dbBranchName: '{dbBranch}', region: '{region}' },
// @ts-ignore https://github.com/xataio/client-ts/issues/313
body: { tables, query, fuzziness, prefix, highlight, page },
...pluginOptions
});

return records;
return { records, totalCount };
}

async #getSchemaTables(pluginOptions: XataPluginOptions): Promise<Table[]> {
Expand Down
96 changes: 59 additions & 37 deletions test/integration/search.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,32 +60,35 @@ afterEach(async (ctx) => {

describe('search', () => {
test('search in table', async () => {
const owners = await xata.db.users.search('Owner');
expect(owners.length).toBeGreaterThan(0);

expect(owners.length).toBe(2);
expect(owners[0].id).toBeDefined();
expect(owners[0].full_name?.includes('Owner')).toBeTruthy();
expect(owners[0].read).toBeDefined();
expect(owners[0].getMetadata().score).toBeDefined();
expect(owners[0].getMetadata().table).toBe('users');
const { records, totalCount } = await xata.db.users.search('Owner');
expect(totalCount).toBe(2);
expect(records.length).toBeGreaterThan(0);

expect(records.length).toBe(2);
expect(records[0].id).toBeDefined();
expect(records[0].full_name?.includes('Owner')).toBeTruthy();
expect(records[0].read).toBeDefined();
expect(records[0].getMetadata().score).toBeDefined();
expect(records[0].getMetadata().table).toBe('users');
});

test('search in table with filtering', async () => {
const owners = await xata.db.users.search('Owner', {
const { records, totalCount } = await xata.db.users.search('Owner', {
filter: { full_name: 'Owner of team animals' }
});

expect(owners.length).toBe(1);
expect(owners[0].id).toBeDefined();
expect(owners[0].full_name?.includes('Owner of team animals')).toBeTruthy();
expect(owners[0].read).toBeDefined();
expect(owners[0].getMetadata().score).toBeDefined();
expect(totalCount).toBe(1);
expect(records.length).toBe(1);
expect(records[0].id).toBeDefined();
expect(records[0].full_name?.includes('Owner of team animals')).toBeTruthy();
expect(records[0].read).toBeDefined();
expect(records[0].getMetadata().score).toBeDefined();
});

test('search by tables with multiple tables', async () => {
const { users = [], teams = [] } = await xata.search.byTable('fruits', { tables: ['teams', 'users'] });
const { users = [], teams = [], totalCount } = await xata.search.byTable('fruits', { tables: ['teams', 'users'] });

expect(totalCount).toBeGreaterThan(0);
expect(users.length).toBeGreaterThan(0);
expect(teams.length).toBeGreaterThan(0);

Expand All @@ -101,8 +104,9 @@ describe('search', () => {
});

test('search by table with all tables', async () => {
const { users = [], teams = [] } = await xata.search.byTable('fruits');
const { users = [], teams = [], totalCount } = await xata.search.byTable('fruits');

expect(totalCount).toBeGreaterThan(0);
expect(users.length).toBeGreaterThan(0);
expect(teams.length).toBeGreaterThan(0);

Expand All @@ -118,9 +122,11 @@ describe('search', () => {
});

test('search all with multiple tables', async () => {
const results = await xata.search.all('fruits', { tables: ['teams', 'users'] });
const { records, totalCount } = await xata.search.all('fruits', { tables: ['teams', 'users'] });
expect(records).toBeDefined();

for (const result of results) {
expect(totalCount).toBeGreaterThan(0);
for (const result of records) {
if (result.table === 'teams') {
expect(result.record.id).toBeDefined();
expect(result.record.read).toBeDefined();
Expand All @@ -138,9 +144,11 @@ describe('search', () => {
});

test('search all with one table', async () => {
const results = await xata.search.all('fruits', { tables: ['teams'] });
const { records, totalCount } = await xata.search.all('fruits', { tables: ['teams'] });
expect(records).toBeDefined();

for (const result of results) {
expect(totalCount).toBeGreaterThan(0);
for (const result of records) {
expect(result.record.id).toBeDefined();
expect(result.record.read).toBeDefined();
expect(result.record.name?.includes('fruits')).toBeTruthy();
Expand All @@ -152,9 +160,11 @@ describe('search', () => {
});

test('search all with all tables', async () => {
const results = await xata.search.all('fruits');
const { records, totalCount } = await xata.search.all('fruits');
expect(records).toBeDefined();

for (const result of results) {
expect(totalCount).toBeGreaterThan(0);
for (const result of records) {
if (result.table === 'teams') {
expect(result.record.id).toBeDefined();
expect(result.record.read).toBeDefined();
Expand All @@ -175,26 +185,33 @@ describe('search', () => {
});

test('search all with filters', async () => {
const results = await xata.search.all('fruits', {
const { records, totalCount } = await xata.search.all('fruits', {
tables: [{ table: 'teams', filter: { name: 'Team fruits' } }]
});
expect(records).toBeDefined();

expect(results.length).toBe(1);
expect(results[0].table).toBe('teams');
expect(totalCount).toBe(1);
expect(records.length).toBe(1);
expect(records[0].table).toBe('teams');

if (results[0].table === 'teams') {
expect(results[0].record.id).toBeDefined();
expect(results[0].record.read).toBeDefined();
expect(results[0].record.name?.includes('fruits')).toBeTruthy();
expect(results[0].record.getMetadata().score).toBeDefined();
if (records[0].table === 'teams') {
expect(records[0].record.id).toBeDefined();
expect(records[0].record.read).toBeDefined();
expect(records[0].record.name?.includes('fruits')).toBeTruthy();
expect(records[0].record.getMetadata().score).toBeDefined();
}
});

test('search with page and offset', async () => {
const owners = await xata.db.users.search('Owner');
const page1 = await xata.db.users.search('Owner', { page: { size: 1 } });
const page2 = await xata.db.users.search('Owner', { page: { size: 1, offset: 1 } });
const { records: owners, totalCount } = await xata.db.users.search('Owner');
const { records: page1, totalCount: page1Count } = await xata.db.users.search('Owner', { page: { size: 1 } });
const { records: page2, totalCount: page2Count } = await xata.db.users.search('Owner', {
page: { size: 1, offset: 1 }
});

expect(totalCount).toBe(2);
expect(page1Count).toBe(2);
expect(page2Count).toBe(2);
expect(page1.length).toBe(1);
expect(page2.length).toBe(1);

Expand All @@ -205,10 +222,15 @@ describe('search', () => {
});

test('global search with page and offset', async () => {
const { users: owners = [] } = await xata.search.byTable('Owner');
const { users: page1 = [] } = await xata.search.byTable('Owner', { page: { size: 1 } });
const { users: page2 = [] } = await xata.search.byTable('Owner', { page: { size: 1, offset: 1 } });
const { users: owners = [], totalCount } = await xata.search.byTable('Owner');
const { users: page1 = [], totalCount: page1Count } = await xata.search.byTable('Owner', { page: { size: 1 } });
const { users: page2 = [], totalCount: page2Count } = await xata.search.byTable('Owner', {
page: { size: 1, offset: 1 }
});

expect(totalCount).toBe(2);
expect(page1Count).toBe(2);
expect(page2Count).toBe(2);
expect(page1.length).toBe(1);
expect(page2.length).toBe(1);

Expand Down
Loading

0 comments on commit d08fa21

Please sign in to comment.