diff --git a/Source/Machine.Specifications.Tests/ExampleSpecifications.cs b/Source/Machine.Specifications.Tests/ExampleSpecifications.cs index 3a93965ce..1a3ad8101 100644 --- a/Source/Machine.Specifications.Tests/ExampleSpecifications.cs +++ b/Source/Machine.Specifications.Tests/ExampleSpecifications.cs @@ -242,4 +242,32 @@ public void Reset() { } } + + public class OuterNonGenericContext + { + public class Nested + { } + } + + public class OuterGenericContext + { + public class NestedBase : OuterGenericContext + {} + + public class Nested : NestedBase + {} + } + + public class ContextInheritingFromNestedGeneric : OuterGenericContext.Nested, IFakeContext + { + public static bool ItInvoked; + + It should_be_invoked = () => + ItInvoked = true; + + public void Reset() + { + ItInvoked = false; + } + } } diff --git a/Source/Machine.Specifications.Tests/Model/ContextTests.cs b/Source/Machine.Specifications.Tests/Model/ContextTests.cs index d43ec9356..1f53a5d4e 100644 --- a/Source/Machine.Specifications.Tests/Model/ContextTests.cs +++ b/Source/Machine.Specifications.Tests/Model/ContextTests.cs @@ -129,4 +129,22 @@ public void ShouldCleanup() ContextWithSingleSpecification.CleanupInvoked.Should().BeTrue(); } } + + [TestFixture] + public class InheritingFromNestedGenericTests : With + { + IEnumerable results; + + public override void BeforeEachTest() + { + base.BeforeEachTest(); + results = Run(context); + } + + [Test] + public void ShouldInvokeIt() + { + ContextInheritingFromNestedGeneric.ItInvoked.Should().BeTrue(); + } + } } diff --git a/Source/Machine.Specifications/Factories/ContextFactory.cs b/Source/Machine.Specifications/Factories/ContextFactory.cs index 57cfa08d6..3f86493a9 100644 --- a/Source/Machine.Specifications/Factories/ContextFactory.cs +++ b/Source/Machine.Specifications/Factories/ContextFactory.cs @@ -193,7 +193,32 @@ static void CollectDetailsOf(Type target, Func instanceResolver, ICol CollectDetailsOf(target.BaseType, () => instance, items, ensureMaximumOfOne, attributeFullName); } - CollectDetailsOf(target.DeclaringType, () => Activator.CreateInstance(target.DeclaringType), items, ensureMaximumOfOne, attributeFullName); + CollectDetailsOf(target.DeclaringType, () => CreateDeclaringTypeInstance(target), + items, ensureMaximumOfOne, attributeFullName); + } + + static object CreateDeclaringTypeInstance(Type target) + { + var instantiatingType = target.DeclaringType; + object declaringTypeInstance = null; + if (instantiatingType != typeof(object)) + { + if (instantiatingType.IsGenericType) + { + var genericTypeDefinition = target.GetGenericTypeDefinition(); + if (genericTypeDefinition == null) + { + throw new InvalidOperationException(string.Format("Unable to get generic type definition for {0}", target)); + } + var genericArgs = target.GetGenericArguments().ToArray(); + if (genericArgs.Length != 0) + { + instantiatingType = genericTypeDefinition.MakeGenericType(genericArgs); + } + } + declaringTypeInstance = Activator.CreateInstance(instantiatingType); + } + return declaringTypeInstance; } static bool IsStatic(Type target)