diff --git a/packages/core/test/test_bed_spec.ts b/packages/core/test/test_bed_spec.ts index 61b809e2e3..8fc4d85295 100644 --- a/packages/core/test/test_bed_spec.ts +++ b/packages/core/test/test_bed_spec.ts @@ -6,7 +6,7 @@ * found in the LICENSE file at https://angular.io/license */ -import {Component, Directive, ErrorHandler, Inject, Injectable, InjectionToken, Input, NgModule, Optional, Pipe, ɵsetClassMetadata as setClassMetadata, ɵɵdefineComponent as defineComponent, ɵɵdefineNgModule as defineNgModule, ɵɵtext as text} from '@angular/core'; +import {Component, Directive, ErrorHandler, Inject, Injectable, InjectionToken, Input, ModuleWithProviders, NgModule, Optional, Pipe, ɵsetClassMetadata as setClassMetadata, ɵɵdefineComponent as defineComponent, ɵɵdefineNgModule as defineNgModule, ɵɵtext as text} from '@angular/core'; import {TestBed, getTestBed} from '@angular/core/testing/src/test_bed'; import {By} from '@angular/platform-browser'; import {expect} from '@angular/platform-browser/testing/src/matchers'; @@ -229,6 +229,64 @@ describe('TestBed', () => { expect(hello.nativeElement).toHaveText('Hello injected World !'); }); + it('should allow overriding a provider defined via ModuleWithProviders (using TestBed.overrideProvider)', + () => { + const serviceOverride = { + get() { return 'override'; }, + }; + + @Injectable({providedIn: 'root'}) + class MyService { + get() { return 'original'; } + } + + @NgModule({}) + class MyModule { + static forRoot(): ModuleWithProviders { + return { + ngModule: MyModule, + providers: [MyService], + }; + } + } + TestBed.overrideProvider(MyService, {useValue: serviceOverride}); + TestBed.configureTestingModule({ + imports: [MyModule.forRoot()], + }); + + const service = TestBed.get(MyService); + expect(service.get()).toEqual('override'); + }); + + it('should allow overriding a provider defined via ModuleWithProviders (using TestBed.configureTestingModule)', + () => { + const serviceOverride = { + get() { return 'override'; }, + }; + + @Injectable({providedIn: 'root'}) + class MyService { + get() { return 'original'; } + } + + @NgModule({}) + class MyModule { + static forRoot(): ModuleWithProviders { + return { + ngModule: MyModule, + providers: [MyService], + }; + } + } + TestBed.configureTestingModule({ + imports: [MyModule.forRoot()], + providers: [{provide: MyService, useValue: serviceOverride}], + }); + + const service = TestBed.get(MyService); + expect(service.get()).toEqual('override'); + }); + it('allow to override multi provider', () => { const MY_TOKEN = new InjectionToken('MyProvider'); class MyProvider {} diff --git a/packages/core/testing/src/r3_test_bed_compiler.ts b/packages/core/testing/src/r3_test_bed_compiler.ts index 5943e89717..7b710cecd5 100644 --- a/packages/core/testing/src/r3_test_bed_compiler.ts +++ b/packages/core/testing/src/r3_test_bed_compiler.ts @@ -7,9 +7,10 @@ */ import {ResourceLoader} from '@angular/compiler'; -import {ApplicationInitStatus, COMPILER_OPTIONS, Compiler, Component, Directive, Injector, LOCALE_ID, ModuleWithComponentFactories, NgModule, NgModuleFactory, NgZone, Pipe, PlatformRef, Provider, Type, ɵDEFAULT_LOCALE_ID as DEFAULT_LOCALE_ID, ɵDirectiveDef as DirectiveDef, ɵNG_COMPONENT_DEF as NG_COMPONENT_DEF, ɵNG_DIRECTIVE_DEF as NG_DIRECTIVE_DEF, ɵNG_INJECTOR_DEF as NG_INJECTOR_DEF, ɵNG_MODULE_DEF as NG_MODULE_DEF, ɵNG_PIPE_DEF as NG_PIPE_DEF, ɵNgModuleFactory as R3NgModuleFactory, ɵNgModuleTransitiveScopes as NgModuleTransitiveScopes, ɵNgModuleType as NgModuleType, ɵRender3ComponentFactory as ComponentFactory, ɵRender3NgModuleRef as NgModuleRef, ɵcompileComponent as compileComponent, ɵcompileDirective as compileDirective, ɵcompileNgModuleDefs as compileNgModuleDefs, ɵcompilePipe as compilePipe, ɵgetInjectableDef as getInjectableDef, ɵpatchComponentDefWithScope as patchComponentDefWithScope, ɵsetLocaleId as setLocaleId, ɵtransitiveScopesFor as transitiveScopesFor, ɵɵInjectableDef as InjectableDef} from '@angular/core'; +import {ApplicationInitStatus, COMPILER_OPTIONS, Compiler, Component, Directive, Injector, LOCALE_ID, ModuleWithComponentFactories, ModuleWithProviders, NgModule, NgModuleFactory, NgZone, Pipe, PlatformRef, Provider, Type, ɵDEFAULT_LOCALE_ID as DEFAULT_LOCALE_ID, ɵDirectiveDef as DirectiveDef, ɵNG_COMPONENT_DEF as NG_COMPONENT_DEF, ɵNG_DIRECTIVE_DEF as NG_DIRECTIVE_DEF, ɵNG_INJECTOR_DEF as NG_INJECTOR_DEF, ɵNG_MODULE_DEF as NG_MODULE_DEF, ɵNG_PIPE_DEF as NG_PIPE_DEF, ɵNgModuleFactory as R3NgModuleFactory, ɵNgModuleTransitiveScopes as NgModuleTransitiveScopes, ɵNgModuleType as NgModuleType, ɵRender3ComponentFactory as ComponentFactory, ɵRender3NgModuleRef as NgModuleRef, ɵcompileComponent as compileComponent, ɵcompileDirective as compileDirective, ɵcompileNgModuleDefs as compileNgModuleDefs, ɵcompilePipe as compilePipe, ɵgetInjectableDef as getInjectableDef, ɵpatchComponentDefWithScope as patchComponentDefWithScope, ɵsetLocaleId as setLocaleId, ɵtransitiveScopesFor as transitiveScopesFor, ɵɵInjectableDef as InjectableDef} from '@angular/core'; import {clearResolutionOfComponentResourcesQueue, isComponentDefPendingResolution, resolveComponentResources, restoreComponentResolutionQueue} from '../../src/metadata/resource_loading'; + import {MetadataOverride} from './metadata_override'; import {ComponentResolver, DirectiveResolver, NgModuleResolver, PipeResolver, Resolver} from './resolvers'; import {TestModuleMetadata} from './test_bed_common'; @@ -359,11 +360,19 @@ export class R3TestBedCompiler { const injectorDef: any = (moduleType as any)[NG_INJECTOR_DEF]; if (this.providerOverridesByToken.size > 0) { - if (this.hasProviderOverrides(injectorDef.providers)) { + // Extract the list of providers from ModuleWithProviders, so we can define the final list of + // providers that might have overrides. + // Note: second `flatten` operation is needed to convert an array of providers + // (e.g. `[[], []]`) into one flat list, also eliminating empty arrays. + const providersFromModules = flatten(flatten( + injectorDef.imports, (imported: NgModuleType| ModuleWithProviders) => + isModuleWithProviders(imported) ? imported.providers : [])); + const providers = [...providersFromModules, ...injectorDef.providers]; + if (this.hasProviderOverrides(providers)) { this.maybeStoreNgDef(NG_INJECTOR_DEF, moduleType); this.storeFieldOfDefOnType(moduleType, NG_INJECTOR_DEF, 'providers'); - injectorDef.providers = this.getOverriddenProviders(injectorDef.providers); + injectorDef.providers = this.getOverriddenProviders(providers); } // Apply provider overrides to imported modules recursively @@ -695,6 +704,10 @@ function isMultiProvider(provider: Provider) { return !!getProviderField(provider, 'multi'); } +function isModuleWithProviders(value: any): value is ModuleWithProviders { + return value.hasOwnProperty('ngModule'); +} + function forEachRight(values: T[], fn: (value: T, idx: number) => void): void { for (let idx = values.length - 1; idx >= 0; idx--) { fn(values[idx], idx);