Skip to content

Commit

Permalink
fix: resolve issue with ordering of middleware being applied (#189)
Browse files Browse the repository at this point in the history
When setting up the nestjs modules with MikroORM there is a chance it can throw Using global EntityManager instance methods for context specific actions is disallowed. when interacting with the EM within another middleware.

This fix applies to both single and multi database set-ups.
  • Loading branch information
tsangste authored Jan 23, 2025
1 parent b734867 commit 98171c4
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 77 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ More information about [enableShutdownHooks](https://docs.nestjs.com/fundamental

## Multiple Database Connections

You can define multiple database connections by registering multiple `MikroOrmModule` and setting their `contextName`. If you want to use middleware request context you must disable automatic middleware and register `MikroOrmModule` with `forMiddleware()` or use NestJS `Injection Scope`
You can define multiple database connections by registering multiple `MikroOrmModule`'s, each with a unique `contextName`. You will need to disable the automatic request context middleware by setting `registerRequestContext` to `false`, as it wouldn't work with this approach - note that this needs to be part of all your `MikroOrmModule`s with non-default `contextName`. To have the same automatic request context behaviour, you must register `MikroOrmModule` with `forMiddleware()` instead:

```typescript
@Module({
Expand Down
12 changes: 6 additions & 6 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@
},
"peerDependencies": {
"@mikro-orm/core": "^6.0.0 || ^6.0.0-dev.0",
"@nestjs/common": "^10.0.0 || ^11.0.0",
"@nestjs/core": "^10.0.0 || ^11.0.0"
"@nestjs/common": "^10.0.0 || ^11.0.5",
"@nestjs/core": "^10.0.0 || ^11.0.5"
},
"devDependencies": {
"@mikro-orm/core": "^6.2.7",
"@mikro-orm/sqlite": "^6.2.7",
"@nestjs/common": "^11.0.0",
"@nestjs/core": "^11.0.0",
"@nestjs/platform-express": "^11.0.0",
"@nestjs/testing": "^11.0.0",
"@nestjs/common": "^11.0.5",
"@nestjs/core": "^11.0.5",
"@nestjs/platform-express": "^11.0.5",
"@nestjs/testing": "^11.0.5",
"@types/jest": "^29.5.12",
"@types/node": "^22.0.0",
"@types/supertest": "^6.0.2",
Expand Down
14 changes: 12 additions & 2 deletions src/mikro-orm-core.module.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import { Configuration, ConfigurationLoader, EntityManager, MikroORM, type Dictionary } from '@mikro-orm/core';
import { Global, Inject, Module, RequestMethod, type DynamicModule, type MiddlewareConsumer, type OnApplicationShutdown, type Type } from '@nestjs/common';
import {
Global,
Inject,
Module,
RequestMethod,
type DynamicModule,
type MiddlewareConsumer,
type NestModule,
type OnApplicationShutdown,
type Type,
} from '@nestjs/common';
import { ModuleRef } from '@nestjs/core';

import { forRoutesPath } from './middleware.helper';
Expand Down Expand Up @@ -31,7 +41,7 @@ const PACKAGES = {

@Global()
@Module({})
export class MikroOrmCoreModule implements OnApplicationShutdown {
export class MikroOrmCoreModule implements NestModule, OnApplicationShutdown {

constructor(@Inject(MIKRO_ORM_MODULE_OPTIONS)
private readonly options: MikroOrmModuleOptions,
Expand Down
9 changes: 3 additions & 6 deletions src/mikro-orm-middleware.module.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Global, Inject, Module, RequestMethod, type MiddlewareConsumer } from '@nestjs/common';
import { Global, Inject, Module, RequestMethod, type MiddlewareConsumer, type NestModule } from '@nestjs/common';

import type { MikroORM } from '@mikro-orm/core';
import { forRoutesPath } from './middleware.helper';
Expand All @@ -8,15 +8,12 @@ import { MikroOrmMiddlewareModuleOptions } from './typings';

@Global()
@Module({})
export class MikroOrmMiddlewareModule {
export class MikroOrmMiddlewareModule implements NestModule {

constructor(@Inject(MIKRO_ORM_MODULE_OPTIONS)
private readonly options: MikroOrmMiddlewareModuleOptions) { }

static forMiddleware(options?: MikroOrmMiddlewareModuleOptions) {
// Work around due to nestjs not supporting the ability to register multiple types
// https://github.com/nestjs/nest/issues/770
// https://github.com/nestjs/nest/issues/4786#issuecomment-755032258 - workaround suggestion
static forRoot(options?: MikroOrmMiddlewareModuleOptions) {
const inject = CONTEXT_NAMES.map(name => getMikroORMToken(name));
return {
module: MikroOrmMiddlewareModule,
Expand Down
23 changes: 7 additions & 16 deletions src/mikro-orm.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import { MikroOrmCoreModule } from './mikro-orm-core.module';
import { MikroOrmMiddlewareModule } from './mikro-orm-middleware.module';
import { MikroOrmEntitiesStorage } from './mikro-orm.entities.storage';
import { createMikroOrmRepositoryProviders } from './mikro-orm.providers';
import type {
import {
EntityName,
MikroOrmMiddlewareModuleOptions,
MikroOrmModuleAsyncOptions,
MikroOrmModuleFeatureOptions,
MikroOrmModuleSyncOptions,
MikroOrmMiddlewareModuleOptions,
} from './typings';

@Module({})
Expand All @@ -23,18 +23,12 @@ export class MikroOrmModule {
MikroOrmEntitiesStorage.clear(contextName);
}

static forRoot(options?: MikroOrmModuleSyncOptions): DynamicModule {
return {
module: MikroOrmModule,
imports: [MikroOrmCoreModule.forRoot(options)],
};
static forRoot(options?: MikroOrmModuleSyncOptions): DynamicModule | Promise<DynamicModule> {
return MikroOrmCoreModule.forRoot(options);
}

static forRootAsync(options: MikroOrmModuleAsyncOptions): DynamicModule {
return {
module: MikroOrmModule,
imports: [MikroOrmCoreModule.forRootAsync(options)],
};
static forRootAsync(options: MikroOrmModuleAsyncOptions): DynamicModule | Promise<DynamicModule> {
return MikroOrmCoreModule.forRootAsync(options);
}

static forFeature(options: EntityName<AnyEntity>[] | MikroOrmModuleFeatureOptions, contextName?: string): DynamicModule {
Expand All @@ -56,10 +50,7 @@ export class MikroOrmModule {
}

static forMiddleware(options?: MikroOrmMiddlewareModuleOptions): DynamicModule {
return {
module: MikroOrmModule,
imports: [MikroOrmMiddlewareModule.forMiddleware(options)],
};
return MikroOrmMiddlewareModule.forRoot(options);
}

}
3 changes: 2 additions & 1 deletion tests/entities/foo.entity.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { PrimaryKey, Entity } from '@mikro-orm/core';
import { PrimaryKey, Entity, Filter } from '@mikro-orm/core';

@Entity()
@Filter({ name: 'id', cond: args => ({ id: args.id }) })
export class Foo {

@PrimaryKey()
Expand Down
84 changes: 61 additions & 23 deletions tests/mikro-orm.middleware.test.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import type { Options } from '@mikro-orm/core';
import { MikroORM } from '@mikro-orm/core';
import { EntityManager, MikroORM, type Options } from '@mikro-orm/core';
import { SqliteDriver } from '@mikro-orm/sqlite';
import type { INestApplication } from '@nestjs/common';
import {
Controller,
Get,
Module,
type INestApplication,
Injectable,
type MiddlewareConsumer,
type NestMiddleware,
type NestModule,
} from '@nestjs/common';
import type { TestingModule } from '@nestjs/testing';
import { Test } from '@nestjs/testing';
import { Test, type TestingModule } from '@nestjs/testing';
import request from 'supertest';
import { InjectMikroORM, MikroOrmModule } from '../src';
import { InjectEntityManager, InjectMikroORM, MikroOrmModule } from '../src';
import { Bar } from './entities/bar.entity';
import { Foo } from './entities/foo.entity';

Expand All @@ -21,54 +23,90 @@ const testOptions: Options = {
entities: ['entities'],
};

@Controller()
class TestController {
@Controller('/foo')
class FooController {

constructor(
@InjectMikroORM('database1') private database1: MikroORM,
@InjectMikroORM('database2') private database2: MikroORM,
) {}
constructor(@InjectMikroORM('database-multi-foo') private database1: MikroORM) {}

@Get('foo')
@Get()
foo() {
return this.database1.em !== this.database1.em.getContext();
}

@Get('bar')
}

@Controller('/bar')
class BarController {

constructor(@InjectMikroORM('database-multi-bar') private database2: MikroORM) {}

@Get()
bar() {
return this.database2.em !== this.database2.em.getContext();
}

}

@Injectable()
export class TestMiddleware implements NestMiddleware {

constructor(@InjectEntityManager('database-multi-foo') private readonly em: EntityManager) {}

use(req: unknown, res: unknown, next: (...args: any[]) => void) {
this.em.setFilterParams('id', { id: '1' });

return next();
}

}

@Module({
imports: [MikroOrmModule.forFeature([Foo], 'database-multi-foo')],
controllers: [FooController],
})
class FooModule implements NestModule {

configure(consumer: MiddlewareConsumer): void {
consumer
.apply(TestMiddleware)
.forRoutes('/');
}

}

@Module({
imports: [MikroOrmModule.forFeature([Bar], 'database-multi-bar')],
controllers: [BarController],
})
class BarModule {}

@Module({
imports: [
MikroOrmModule.forRootAsync({
contextName: 'database1',
contextName: 'database-multi-foo',
useFactory: () => ({
registerRequestContext: false,
...testOptions,
}),
}),
MikroOrmModule.forRoot({
contextName: 'database2',
contextName: 'database-multi-bar',
registerRequestContext: false,
...testOptions,
}),
MikroOrmModule.forMiddleware(),
MikroOrmModule.forFeature([Foo], 'database1'),
MikroOrmModule.forFeature([Bar], 'database2'),
FooModule,
BarModule,
],
controllers: [TestController],
})
class TestModule {}
class TestMultiModule {}

describe('Middleware executes request context for all MikroORM registered', () => {
describe('Multiple Middleware executes request context for all MikroORM registered', () => {
let app: INestApplication;

beforeAll(async () => {
const moduleFixture: TestingModule = await Test.createTestingModule({
imports: [TestModule],
imports: [TestMultiModule],
}).compile();

app = moduleFixture.createNestApplication();
Expand All @@ -81,7 +119,7 @@ describe('Middleware executes request context for all MikroORM registered', () =
});

it(`forRoutes(/bar) should return 'true'`, () => {
return request(app.getHttpServer()).get('/foo').expect(200, 'true');
return request(app.getHttpServer()).get('/bar').expect(200, 'true');
});

afterAll(async () => {
Expand Down
94 changes: 94 additions & 0 deletions tests/mikro-orm.module-middleware.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import { EntityManager, MikroORM, type Options } from '@mikro-orm/core';
import { SqliteDriver } from '@mikro-orm/sqlite';
import {
Controller,
Get,
Module,
type INestApplication,
Injectable,
type MiddlewareConsumer,
type NestMiddleware,
type NestModule,
} from '@nestjs/common';
import { Test, type TestingModule } from '@nestjs/testing';
import request from 'supertest';
import { MikroOrmModule } from '../src';
import { Foo } from './entities/foo.entity';

const testOptions: Options = {
dbName: ':memory:',
driver: SqliteDriver,
baseDir: __dirname,
entities: ['entities'],
};

@Controller('/foo')
class FooController {

constructor(private database1: MikroORM) {}

@Get()
foo() {
return this.database1.em !== this.database1.em.getContext();
}

}

@Injectable()
export class TestMiddleware implements NestMiddleware {

constructor(private readonly em: EntityManager) {}

use(req: unknown, res: unknown, next: (...args: any[]) => void) {
this.em.setFilterParams('id', { id: '1' });

return next();
}

}

@Module({
imports: [MikroOrmModule.forFeature([Foo])],
controllers: [FooController],
})
class FooModule implements NestModule {

configure(consumer: MiddlewareConsumer): void {
consumer
.apply(TestMiddleware)
.forRoutes('/');
}

}

@Module({
imports: [
MikroOrmModule.forRootAsync({
useFactory: () => testOptions,
}),
FooModule,
],
})
class TestModule {}

describe('Middleware executes request context', () => {
let app: INestApplication;

beforeAll(async () => {
const moduleFixture: TestingModule = await Test.createTestingModule({
imports: [TestModule],
}).compile();

app = moduleFixture.createNestApplication();

await app.init();
});

it(`forRoutes(/foo) should return 'true'`, () => {
return request(app.getHttpServer()).get('/foo').expect(200, 'true');
});

afterAll(async () => {
await app.close();
});
});
Loading

0 comments on commit 98171c4

Please sign in to comment.