diff --git a/packages/core/test/test_bed_spec.ts b/packages/core/test/test_bed_spec.ts index d9da302e95..2e14c67da8 100644 --- a/packages/core/test/test_bed_spec.ts +++ b/packages/core/test/test_bed_spec.ts @@ -229,6 +229,27 @@ describe('TestBed', () => { expect(hello.nativeElement).toHaveText('Hello injected World !'); }); + it('allow to override multi provider', () => { + const MY_TOKEN = new InjectionToken('MyProvider'); + class MyProvider {} + + @Component({selector: 'my-comp', template: ``}) + class MyComp { + constructor(@Inject(MY_TOKEN) public myProviders: MyProvider[]) {} + } + + TestBed.configureTestingModule({ + declarations: [MyComp], + providers: [{provide: MY_TOKEN, useValue: {value: 'old provider'}, multi: true}] + }); + + const multiOverride = {useValue: [{value: 'new provider'}], multi: true}; + TestBed.overrideProvider(MY_TOKEN, multiOverride as any); + + const fixture = TestBed.createComponent(MyComp); + expect(fixture.componentInstance.myProviders).toEqual([{value: 'new provider'}]); + }); + it('should resolve components that are extended by other components', () => { // SimpleApp uses SimpleCmp in its template, which is extended by InheritedCmp const simpleApp = TestBed.createComponent(SimpleApp); diff --git a/packages/core/testing/src/r3_test_bed_compiler.ts b/packages/core/testing/src/r3_test_bed_compiler.ts index aac9429371..3b518e2095 100644 --- a/packages/core/testing/src/r3_test_bed_compiler.ts +++ b/packages/core/testing/src/r3_test_bed_compiler.ts @@ -73,7 +73,7 @@ export class R3TestBedCompiler { private providerOverrides: Provider[] = []; private rootProviderOverrides: Provider[] = []; - private providerOverridesByToken = new Map(); + private providerOverridesByToken = new Map(); private moduleProvidersOverridden = new Set>(); private testModuleType: NgModuleType; @@ -142,11 +142,17 @@ export class R3TestBedCompiler { this.pendingPipes.add(pipe); } - overrideProvider(token: any, provider: {useFactory?: Function, useValue?: any, deps?: any[]}): - void { + overrideProvider( + token: any, + provider: {useFactory?: Function, useValue?: any, deps?: any[], multi?: boolean}): void { const providerDef = provider.useFactory ? - {provide: token, useFactory: provider.useFactory, deps: provider.deps || []} : - {provide: token, useValue: provider.useValue}; + { + provide: token, + useFactory: provider.useFactory, + deps: provider.deps || [], + multi: provider.multi, + } : + {provide: token, useValue: provider.useValue, multi: provider.multi}; let injectableDef: InjectableDef|null; const isRoot = @@ -155,10 +161,8 @@ export class R3TestBedCompiler { const overridesBucket = isRoot ? this.rootProviderOverrides : this.providerOverrides; overridesBucket.push(providerDef); - // Keep all overrides grouped by token as well for fast lookups using token - const overridesForToken = this.providerOverridesByToken.get(token) || []; - overridesForToken.push(providerDef); - this.providerOverridesByToken.set(token, overridesForToken); + // Keep overrides grouped by token as well for fast lookups using token + this.providerOverridesByToken.set(token, providerDef); } overrideTemplateUsingTestingModule(type: Type, template: string): void { @@ -349,10 +353,7 @@ export class R3TestBedCompiler { this.maybeStoreNgDef(NG_INJECTOR_DEF, moduleType); this.storeFieldOfDefOnType(moduleType, NG_INJECTOR_DEF, 'providers'); - injectorDef.providers = [ - ...injectorDef.providers, // - ...this.getProviderOverrides(injectorDef.providers) - ]; + injectorDef.providers = this.getOverriddenProviders(injectorDef.providers); } // Apply provider overrides to imported modules recursively @@ -561,11 +562,9 @@ export class R3TestBedCompiler { } // get overrides for a specific provider (if any) - private getSingleProviderOverrides(provider: Provider&{provide?: any}): Provider[] { - const token = provider && typeof provider === 'object' && provider.hasOwnProperty('provide') ? - provider.provide : - provider; - return this.providerOverridesByToken.get(token) || []; + private getSingleProviderOverrides(provider: Provider): Provider|null { + const token = getProviderToken(provider); + return this.providerOverridesByToken.get(token) || null; } private getProviderOverrides(providers?: Provider[]): Provider[] { @@ -575,8 +574,47 @@ export class R3TestBedCompiler { // provider. The outer flatten() then flattens the produced overrides array. If this is not // done, the array can contain other empty arrays (e.g. `[[], []]`) which leak into the // providers array and contaminate any error messages that might be generated. - return flatten( - flatten(providers, (provider: Provider) => this.getSingleProviderOverrides(provider))); + return flatten(flatten( + providers, (provider: Provider) => this.getSingleProviderOverrides(provider) || [])); + } + + private getOverriddenProviders(providers?: Provider[]): Provider[] { + if (!providers || !providers.length || this.providerOverridesByToken.size === 0) return []; + + const overrides = this.getProviderOverrides(providers); + const hasMultiProviderOverrides = overrides.some(isMultiProvider); + const overriddenProviders = [...providers, ...overrides]; + + // No additional processing is required in case we have no multi providers to override + if (!hasMultiProviderOverrides) { + return overriddenProviders; + } + + const final: Provider[] = []; + const seenMultiProviders = new Set(); + + // We iterate through the list of providers in reverse order to make sure multi provider + // overrides take precedence over the values defined in provider list. We also fiter out all + // multi providers that have overrides, keeping overridden values only. + forEachRight(overriddenProviders, (provider: any) => { + const token: any = getProviderToken(provider); + if (isMultiProvider(provider) && this.providerOverridesByToken.has(token)) { + if (!seenMultiProviders.has(token)) { + seenMultiProviders.add(token); + if (provider && provider.useValue && Array.isArray(provider.useValue)) { + forEachRight(provider.useValue, (value: any) => { + // Unwrap provider override array into individual providers in final set. + final.unshift({provide: token, useValue: value, multi: true}); + }); + } else { + final.unshift(provider); + } + } + } else { + final.unshift(provider); + } + }); + return final; } private hasProviderOverrides(providers?: Provider[]): boolean { @@ -589,10 +627,7 @@ export class R3TestBedCompiler { this.maybeStoreNgDef(field, declaration); const resolver = def.providersResolver; - const processProvidersFn = (providers: Provider[]) => { - const overrides = this.getProviderOverrides(providers); - return [...providers, ...overrides]; - }; + const processProvidersFn = (providers: Provider[]) => this.getOverriddenProviders(providers); this.storeFieldOfDefOnType(declaration, field, 'providersResolver'); def.providersResolver = (ngDef: DirectiveDef) => resolver(ngDef, processProvidersFn); } @@ -628,6 +663,24 @@ function flatten(values: any[], mapFn?: (value: T) => any): T[] { return out; } +function getProviderField(provider: Provider, field: string) { + return provider && typeof provider === 'object' && (provider as any)[field]; +} + +function getProviderToken(provider: Provider) { + return getProviderField(provider, 'provide') || provider; +} + +function isMultiProvider(provider: Provider) { + return !!getProviderField(provider, 'multi'); +} + +function forEachRight(values: T[], fn: (value: T, idx: number) => void): void { + for (let idx = values.length - 1; idx >= 0; idx--) { + fn(values[idx], idx); + } +} + class R3TestCompiler implements Compiler { constructor(private testBed: R3TestBedCompiler) {}