diff --git a/.gitignore b/.gitignore index 71367bc..af8ef04 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,4 @@ packages/ installer/Output *.gjq +*.tmp diff --git a/Stylet/StyletIoC.cs b/Stylet/StyletIoC.cs index 1642a4a..506ec03 100644 --- a/Stylet/StyletIoC.cs +++ b/Stylet/StyletIoC.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.Linq; using System.Linq.Expressions; using System.Reflection; +using System.Reflection.Emit; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -33,6 +34,7 @@ namespace Stylet void To(string key = null) where TImplementation : class; void To(Type implementationType, string key = null); void ToFactory(Func factory, string key = null) where TImplementation : class; + void ToAbstractFactory(string key = null); void ToAllImplementations(string key = null, params Assembly[] assembly); } @@ -45,12 +47,23 @@ namespace Stylet { #region Main Class + public static readonly string FactoryAssemblyName = "StyletIoCFactory"; + private ConcurrentDictionary registrations = new ConcurrentDictionary(); private ConcurrentDictionary getAllRegistrations = new ConcurrentDictionary(); + // The list object is used for locking it private ConcurrentDictionary> unboundGenerics = new ConcurrentDictionary>(); + private ModuleBuilder factoryBuilder; + private ConcurrentDictionary factories = new ConcurrentDictionary(); + private bool compilationStarted; + public StyletIoC() + { + this.BindSingleton().ToFactory(c => this); + } + public void AutoBind(params Assembly[] assemblies) { if (assemblies == null || assemblies.Length == 0) @@ -318,6 +331,105 @@ namespace Stylet } } + private Type GetFactoryForType(Type serviceType) + { + if (!serviceType.IsInterface) + throw new StyletIoCCreateFactoryException(String.Format("Unable to create a factory implementing type {0}, as it isn't an interface", serviceType.Name)); + + // Have we built it already? + Type factoryType; + if (this.factories.TryGetValue(serviceType, out factoryType)) + return factoryType; + + if (this.factoryBuilder == null) + { + var assemblyName = new AssemblyName(FactoryAssemblyName); + var assemblyBuilder = AppDomain.CurrentDomain.DefineDynamicAssembly(assemblyName, AssemblyBuilderAccess.Run); + var moduleBuilder = assemblyBuilder.DefineDynamicModule("StyletIoCFactoryModule"); + Interlocked.CompareExchange(ref this.factoryBuilder, moduleBuilder, null); + } + + // If the service is 'ISomethingFactory', call out new class 'SomethingFactory' + var typeBuilder = this.factoryBuilder.DefineType(serviceType.Name.Substring(1), TypeAttributes.Public); + typeBuilder.AddInterfaceImplementation(serviceType); + + // Define a field which holds a reference to this ioc container + var containerField = typeBuilder.DefineField("container", typeof(IKernel), FieldAttributes.Private); + + // Add a constructor which takes one argument - the container - and sets the field + // public Name(IKernel container) + // { + // this.container = container; + // } + var ctorBuilder = typeBuilder.DefineConstructor(MethodAttributes.Public, CallingConventions.Standard, new Type[] { typeof(IKernel) }); + var ilGenerator = ctorBuilder.GetILGenerator(); + // Load 'this' and the IOC container onto the stack + ilGenerator.Emit(OpCodes.Ldarg_0); + ilGenerator.Emit(OpCodes.Ldarg_1); + // Store the IOC container in this.container + ilGenerator.Emit(OpCodes.Stfld, containerField); + ilGenerator.Emit(OpCodes.Ret); + + // These are needed by all methods, so get them now + // IKernel.Get(Type, string) + var containerGetMethod = typeof(IKernel).GetMethod("Get", new Type[] { typeof(Type), typeof(string) }); + // Type.GetTypeFromHandler(RuntimeTypeHandle) + var typeFromHandleMethod = typeof(Type).GetMethod("GetTypeFromHandle"); + + // Go through each method, emmitting an implementation for each + foreach (var methodInfo in serviceType.GetMethods()) + { + var parameters = methodInfo.GetParameters(); + if (!(parameters.Length == 0 || (parameters.Length == 1 && parameters[0].ParameterType == typeof(string)))) + throw new StyletIoCCreateFactoryException("Can only implement methods with zero arguments, or a single string argument"); + + if (methodInfo.ReturnType == typeof(void)) + throw new StyletIoCCreateFactoryException("Can only implement methods which return something"); + + var methodBuilder = typeBuilder.DefineMethod(methodInfo.Name, MethodAttributes.Public | MethodAttributes.Virtual, methodInfo.ReturnType, parameters.Select(x => x.ParameterType).ToArray()); + var methodIlGenerator = methodBuilder.GetILGenerator(); + // Load 'this' onto stack + // Stack: [this] + methodIlGenerator.Emit(OpCodes.Ldarg_0); + // Load value of 'container' field of 'this' onto stack + // Stack: [this.container] + methodIlGenerator.Emit(OpCodes.Ldfld, containerField); + // New local variable which represents type to load + LocalBuilder lb = methodIlGenerator.DeclareLocal(methodInfo.ReturnType); + // Load this onto the stack. This is a RuntimeTypeHandle + // Stack: [this.container, runtimeTypeHandleOfReturnType] + methodIlGenerator.Emit(OpCodes.Ldtoken, lb.LocalType); + // Invoke Type.GetTypeFromHandle with this + // This is equivalent to calling typeof(T) + // Stack: [this.container, typeof(returnType)] + methodIlGenerator.Emit(OpCodes.Call, typeFromHandleMethod); + // Load the given key (if it's a parameter), or null if it isn't, onto the stack + // Stack: [this.container, typeof(returnType), key] + if (parameters.Length == 0) + methodIlGenerator.Emit(OpCodes.Ldnull); // Load null as the key + else + methodIlGenerator.Emit(OpCodes.Ldarg_1); // Load the given string as the key + // Call container.Get(type, key) + // Stack: [returnedInstance] + methodIlGenerator.Emit(OpCodes.Callvirt, containerGetMethod); + methodIlGenerator.Emit(OpCodes.Ret); + + typeBuilder.DefineMethodOverride(methodBuilder, methodInfo); + } + + Type constructedType; + try + { + constructedType = typeBuilder.CreateType(); + } + catch (TypeLoadException e) + { + throw new StyletIoCCreateFactoryException(String.Format("Unable to create factory type for interface {0}. Ensure that the interface is public, or add [assembly: InternalsVisibleTo(StyletIoC.FactoryAssemblyName)] to your AssemblyInfo.cs", serviceType.Name), e); + } + var actualType = this.factories.GetOrAdd(serviceType, constructedType); + return actualType; + } + #endregion #region BindTo @@ -370,6 +482,12 @@ namespace Stylet this.AddRegistration(creator, implementationType, key); } + public void ToAbstractFactory(string key = null) + { + var factoryType = this.service.GetFactoryForType(this.serviceType); + this.To(factoryType, key); + } + public void ToAllImplementations(string key = null, params Assembly[] assemblies) { if (assemblies == null || assemblies.Length == 0) @@ -550,7 +668,7 @@ namespace Stylet var init = Expression.ListInit(list, container.GetRegistrations(new TypeKey(this.Type.GenericTypeArguments[0], this.Key), false).GetAll().Select(x => x.GetInstanceExpression(container))); this.expression = init; - return init; + return this.expression; } } @@ -660,7 +778,10 @@ namespace Stylet private string KeyForParameter(ParameterInfo parameter) { - var attribute = (InjectAttribute)parameter.GetCustomAttributes(typeof(InjectAttribute)).FirstOrDefault(); + var attributes = parameter.GetCustomAttributes(typeof(InjectAttribute)); + if (attributes == null) + return null; + var attribute = (InjectAttribute)attributes.FirstOrDefault(); return attribute == null ? null : attribute.Key; } @@ -846,6 +967,12 @@ namespace Stylet public StyletIoCFindConstructorException(string message, Exception innerException) : base(message, innerException) { } } + public class StyletIoCCreateFactoryException : StyletIoCException + { + public StyletIoCCreateFactoryException(string message) : base(message) { } + public StyletIoCCreateFactoryException(string message, Exception innerException) : base(message, innerException) { } + } + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Constructor | AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)] public sealed class InjectAttribute : Attribute { diff --git a/StyletUnitTests/StyletIoCFactoryTests.cs b/StyletUnitTests/StyletIoCFactoryTests.cs new file mode 100644 index 0000000..a840928 --- /dev/null +++ b/StyletUnitTests/StyletIoCFactoryTests.cs @@ -0,0 +1,93 @@ +using NUnit.Framework; +using Stylet; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace StyletUnitTests +{ + [TestFixture] + public class StyletIoCFactoryTests + { + public interface I1 { } + public class C1 : I1 { } + + public interface I1Factory + { + I1 GetI1(); + } + + public interface I1Factory2 + { + I1 GetI1(string key = null); + } + + interface IPrivateFactory + { + } + + public interface IFactoryWithBadMethod + { + C1 MethodWithArgs(bool arg); + } + + public interface IFactoryWithVoidMethod + { + void Method(); + } + + [Test] + public void CreatesImplementationWithoutKey() + { + var ioc = new StyletIoC(); + ioc.Bind().To(); + ioc.Bind().ToAbstractFactory(); + + var factory = ioc.Get(); + var result = factory.GetI1(); + Assert.IsInstanceOf(result); + } + + [Test] + public void CreatesImplementationWithKey() + { + var ioc = new StyletIoC(); + ioc.Bind().To("key"); + ioc.Bind().ToAbstractFactory(); + + var factory = ioc.Get(); + var result = factory.GetI1("key"); + Assert.IsInstanceOf(result); + } + + [Test] + public void ThrowsIfServiceTypeIsNotInterface() + { + var ioc = new StyletIoC(); + Assert.Throws(() => ioc.Bind().ToAbstractFactory()); + } + + [Test] + public void ThrowsIfInterfaceNotPublic() + { + var ioc = new StyletIoC(); + Assert.Throws(() => ioc.Bind().ToAbstractFactory()); + } + + [Test] + public void ThrowsIfMethodHasArgumentOtherThanString() + { + var ioc = new StyletIoC(); + Assert.Throws(() => ioc.Bind().ToAbstractFactory()); + } + + [Test] + public void ThrowsIfMethodReturningVoid() + { + var ioc = new StyletIoC(); + Assert.Throws(() => ioc.Bind().ToAbstractFactory()); + } + } +} diff --git a/StyletUnitTests/StyletUnitTests.csproj b/StyletUnitTests/StyletUnitTests.csproj index 334b2f5..18bb32e 100644 --- a/StyletUnitTests/StyletUnitTests.csproj +++ b/StyletUnitTests/StyletUnitTests.csproj @@ -49,6 +49,7 @@ +