diff --git a/Stylet/StyletIoC/ICreator.cs b/Stylet/StyletIoC/ICreator.cs index f32f203..8f94cb6 100644 --- a/Stylet/StyletIoC/ICreator.cs +++ b/Stylet/StyletIoC/ICreator.cs @@ -82,10 +82,7 @@ namespace StyletIoC private string KeyForParameter(ParameterInfo parameter) { - var attributes = parameter.GetCustomAttributes(typeof(InjectAttribute)); - if (attributes == null) - return null; - var attribute = (InjectAttribute)attributes.FirstOrDefault(); + var attribute = parameter.GetCustomAttributes(typeof(InjectAttribute)).FirstOrDefault() as InjectAttribute; return attribute == null ? null : attribute.Key; } diff --git a/Stylet/StyletIoC/IRegistration.cs b/Stylet/StyletIoC/IRegistration.cs index 6490bf6..281c96e 100644 --- a/Stylet/StyletIoC/IRegistration.cs +++ b/Stylet/StyletIoC/IRegistration.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -133,9 +134,10 @@ namespace StyletIoC } } - internal class PerContainerRegistrations : RegistrationBase + internal class PerContainerRegistration : RegistrationBase { private readonly IRegistrationContext parentContext; + private readonly string key; private readonly object instanceFactoryLock = new object(); private Func instanceFactory; private object instance; @@ -143,10 +145,11 @@ namespace StyletIoC private static readonly MethodInfo getMethod = typeof(IContainer).GetMethod("Get", new[] { typeof(Type), typeof(string) }); - public PerContainerRegistrations(IRegistrationContext parentContext, ICreator creator, Func instanceFactory = null) + public PerContainerRegistration(IRegistrationContext parentContext, ICreator creator, string key, Func instanceFactory = null) : base(creator) { this.parentContext = parentContext; + this.key = key; this.instanceFactory = instanceFactory; this.parentContext.Disposing += (o, e) => @@ -177,34 +180,28 @@ namespace StyletIoC protected override Func GetGeneratorInternal() { // If the context is our parent context, then everything's fine and we can return our instance - // If not, we need to call Get on the current context, and a different instance of us will be invoked again by that + // If not, well, this should never happen. When we're cloned to the new context, we set ourselves up with the new parent return ctx => { - if (ctx != this.parentContext) - { - return ctx.Get(this.Type); - } - else - { - if (this.disposed) - throw new ObjectDisposedException(String.Format("ChildContainer registration for type {0}", this.Type.Description())); + Debug.Assert(ctx == this.parentContext); + if (this.disposed) + throw new ObjectDisposedException(String.Format("ChildContainer registration for type {0}", this.Type.Description())); - if (this.instance != null) - return this.instance; - - this.EnsureInstanceFactoryCreated(); - - var instance = this.instanceFactory(ctx); - Interlocked.CompareExchange(ref this.instance, instance, null); + if (this.instance != null) return this.instance; - } + + this.EnsureInstanceFactoryCreated(); + + var instance = this.instanceFactory(ctx); + Interlocked.CompareExchange(ref this.instance, instance, null); + return this.instance; }; } public override Expression GetInstanceExpression(ParameterExpression registrationContext) { // Always synthesize into a method call onto the current context - var call = Expression.Call(registrationContext, getMethod, Expression.Constant(this.Type)); + var call = Expression.Call(registrationContext, getMethod, Expression.Constant(this.Type), Expression.Constant(this.key, typeof(string))); var cast = Expression.Convert(call, this.Type); return cast; } @@ -213,7 +210,7 @@ namespace StyletIoC { // Ensure the factory's created, and pass it down. This means the work of compiling the creation expression is done once, ever this.EnsureInstanceFactoryCreated(); - return new PerContainerRegistrations(context, this.creator, this.instanceFactory); + return new PerContainerRegistration(context, this.creator, this.key, this.instanceFactory); } } @@ -346,7 +343,7 @@ namespace StyletIoC public IRegistration CloneToContext(IRegistrationContext context) { - return this; + throw new InvalidOperationException("should not be cloned"); } } } diff --git a/Stylet/StyletIoC/StyletIoCBuilder.cs b/Stylet/StyletIoC/StyletIoCBuilder.cs index 8cf66c0..ed5841f 100644 --- a/Stylet/StyletIoC/StyletIoCBuilder.cs +++ b/Stylet/StyletIoC/StyletIoCBuilder.cs @@ -100,7 +100,7 @@ namespace StyletIoC public static void InPerContainerScope(this IInScope builder) { - builder.WithRegistrationFactory((ctx, creator, key) => new PerContainerRegistrations(ctx, creator)); + builder.WithRegistrationFactory((ctx, creator, key) => new PerContainerRegistration(ctx, creator, key)); } } diff --git a/StyletUnitTests/StyletIoC/StyletIoCChildContainerTests.cs b/StyletUnitTests/StyletIoC/StyletIoCChildContainerTests.cs index 8bda754..0255efa 100644 --- a/StyletUnitTests/StyletIoC/StyletIoCChildContainerTests.cs +++ b/StyletUnitTests/StyletIoC/StyletIoCChildContainerTests.cs @@ -134,7 +134,7 @@ namespace StyletUnitTests } [Test] - public void CreatingSameBindingOnParentAndChildCausesMultipleRegistrations() + public void CreatingSameBindingOnParentAndChildCausesMultipleRegistrations_1() { var builder = new StyletIoCBuilder(); builder.Bind().To(); @@ -146,8 +146,8 @@ namespace StyletUnitTests var r = child.GetAll(); - Assert.AreEqual(2, child.GetAll().Count()); Assert.AreEqual(1, parent.GetAll().Count()); + Assert.AreEqual(2, child.GetAll().Count()); } [Test] @@ -264,6 +264,20 @@ namespace StyletUnitTests Assert.AreNotEqual(parent.Get(), child.Get()); } + [Test] + public void KeyedChildContainerScopeHasOneInstancePerScope() + { + var builder = new StyletIoCBuilder(); + builder.Bind().ToSelf().WithKey("foo").InPerContainerScope(); + var parent = builder.BuildContainer(); + + var child = parent.CreateChildBuilder().BuildContainer(); + + Assert.AreEqual(parent.Get("foo"), parent.Get("foo")); + Assert.AreEqual(child.Get("foo"), child.Get("foo")); + Assert.AreNotEqual(parent.Get("foo"), child.Get("foo")); + } + [Test] public void ChildContainerScopeDisposalDisposesCorrectThing() { @@ -281,5 +295,48 @@ namespace StyletUnitTests Assert.True(childs.Disposed); Assert.False(parents.Disposed); } + + [Test] + public void UsingPerContainerRegistrationAfterDisposalPromptsException() + { + var builder = new StyletIoCBuilder(); + builder.Bind().ToSelf().InPerContainerScope(); + var ioc = builder.BuildContainer(); + + ioc.Dispose(); + Assert.Throws(() => ioc.Get()); + } + + [Test] + public void FuncFactoryFetchesInstanceFromCorrectChild() + { + var builder = new StyletIoCBuilder(); + builder.Bind().ToSelf().InPerContainerScope(); + var parent = builder.BuildContainer(); + + var child = parent.CreateChildBuilder().BuildContainer(); + + var funcFromParent = parent.Get>(); + var funcFromChild = child.Get>(); + + Assert.AreEqual(parent.Get(), funcFromParent()); + Assert.AreEqual(child.Get(), funcFromChild()); + } + + [Test] + public void FuncFactoryWithKeyFetchesInstanceFromCorrectChild() + { + var builder = new StyletIoCBuilder(); + builder.Bind().ToSelf().WithKey("foo").InPerContainerScope(); + var parent = builder.BuildContainer(); + + var child = parent.CreateChildBuilder().BuildContainer(); + + var funcFromParent = parent.Get>(); + var funcFromChild = child.Get>(); + + Assert.AreEqual(parent.Get("foo"), funcFromParent("foo")); + Assert.AreEqual(child.Get("foo"), funcFromChild("foo")); + } } }