diff --git a/src/Abc.Zebus.MessageDsl.Generator/Generator/MessageDslGenerator.cs b/src/Abc.Zebus.MessageDsl.Generator/Generator/MessageDslGenerator.cs index 7a81fb9..ff8f48c 100644 --- a/src/Abc.Zebus.MessageDsl.Generator/Generator/MessageDslGenerator.cs +++ b/src/Abc.Zebus.MessageDsl.Generator/Generator/MessageDslGenerator.cs @@ -36,36 +36,35 @@ public void Initialize(IncrementalGeneratorInitializationContext context) private static SourceGenerationResult GenerateCode(SourceGenerationInput input, CancellationToken cancellationToken) { - var (file, fileNamespace, relativePath) = input; + var result = new SourceGenerationResult(input); - var fileContents = file.GetText(cancellationToken)?.ToString(); + var fileContents = input.AdditionalText.GetText(cancellationToken)?.ToString(); if (fileContents is null) - return SourceGenerationResult.Error(Diagnostic.Create(MessageDslDiagnostics.CouldNotReadFileContents, Location.Create(file.Path, default, default))); + { + result.AddDiagnostic(Diagnostic.Create(MessageDslDiagnostics.CouldNotReadFileContents, Location.Create(input.AdditionalText.Path, default, default))); + return result; + } try { - var contracts = ParsedContracts.Parse(fileContents, fileNamespace!.Trim()); + var contracts = ParsedContracts.Parse(fileContents, input.FileNamespace?.Trim()); - if (!contracts.IsValid) + foreach (var error in contracts.Errors) { - var diagnostics = new List(); - foreach (var error in contracts.Errors) - { - var location = Location.Create(file.Path, default, error.ToLinePositionSpan()); - diagnostics.Add(Diagnostic.Create(MessageDslDiagnostics.MessageDslError, location, error.Message)); - } - - return SourceGenerationResult.Error(diagnostics.ToArray()); + var location = Location.Create(input.AdditionalText.Path, default, error.ToLinePositionSpan()); + result.AddDiagnostic(Diagnostic.Create(MessageDslDiagnostics.MessageDslError, location, error.Message)); } var output = CSharpGenerator.Generate(contracts); + result.GeneratedSource = output; - return SourceGenerationResult.Success(file, output, relativePath); + return result; } catch (Exception ex) when (ex is not OperationCanceledException) { - return SourceGenerationResult.Error(Diagnostic.Create(MessageDslDiagnostics.UnexpectedError, Location.None, ex.ToString())); + result.AddDiagnostic(Diagnostic.Create(MessageDslDiagnostics.UnexpectedError, Location.None, ex.ToString())); + return result; } } @@ -76,36 +75,26 @@ private static void GenerateFiles(SourceProductionContext context, ImmutableArra foreach (var diagnostic in result.Diagnostics) context.ReportDiagnostic(diagnostic); - if (result.AdditionalText == null || result.GeneratedSource == null) + if (result.GeneratedSource == null) continue; var hintName = _sanitizePathRegex.Replace(result.RelativePath ?? result.AdditionalText.Path, "_") + ".g.cs"; - context.AddSource(hintName, result.GeneratedSource); } } private record SourceGenerationInput(AdditionalText AdditionalText, string? FileNamespace, string? RelativePath); - public class SourceGenerationResult + private class SourceGenerationResult(SourceGenerationInput input) { - public IList Diagnostics { get; } - public string? GeneratedSource { get; set; } - public AdditionalText? AdditionalText { get; } - public string? RelativePath { get; } + private readonly List _diagnostics = new(); - private SourceGenerationResult(IList diagnostics, string? generatedSource, AdditionalText? additionalText, string? relativePath) - { - Diagnostics = diagnostics; - GeneratedSource = generatedSource; - AdditionalText = additionalText; - RelativePath = relativePath; - } - - public static SourceGenerationResult Error(params Diagnostic[] diagnostics) - => new(diagnostics, null, null, null); + public AdditionalText AdditionalText { get; } = input.AdditionalText; + public IReadOnlyList Diagnostics => _diagnostics; + public string? GeneratedSource { get; set; } + public string? RelativePath { get; } = input.RelativePath; - public static SourceGenerationResult Success(AdditionalText? additionalText, string generatedSource, string? relativePath) - => new(Array.Empty(), generatedSource, additionalText, relativePath); + public void AddDiagnostic(Diagnostic diagnostic) + => _diagnostics.Add(diagnostic); } } diff --git a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ParsedContractsTests.cs b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ParsedContractsTests.cs index e59a8e9..c8c51d5 100644 --- a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ParsedContractsTests.cs +++ b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ParsedContractsTests.cs @@ -692,15 +692,52 @@ public void should_detect_misplaced_optional_parameters() ParseValid("Foo(int a, int b = 42, int c = 10);"); } + [Test] + public void should_generate_ast_for_incomplete_code() + { + var contracts = ParseInvalid("Foo(int bar"); + var message = contracts.Messages.ExpectedSingle(); + var param = message.Parameters.ExpectedSingle(); + param.Type.NetType.ShouldEqual("int"); + param.Name.ShouldEqual("bar"); + } + + [Test] + [TestCase("Foo")] + [TestCase("Foo;")] + [TestCase("Foo(")] + [TestCase("Foo(;")] + [TestCase("Foo(int")] + [TestCase("Foo(int;")] + [TestCase("Foo(int first")] + [TestCase("Foo(int first;")] + [TestCase("Foo(int first,")] + [TestCase("Foo(int first,;")] + [TestCase("Foo(int first, string")] + [TestCase("Foo(int first, string;")] + [TestCase("Foo(int first, string second")] + [TestCase("Foo(int first, string second;")] + public void should_split_on_incomplete_message(string firstLine) + { + var contracts = ParseInvalid( + $""" + {firstLine} + Bar(int something, int somethingElse) + Baz(int something, int somethingElse) + """ + ); + + contracts.Messages.Count.ShouldBeGreaterThan(1); + contracts.Messages[0].Name.ShouldEqual("Foo"); + contracts.Messages[1].Name.ShouldEqualOneOf("Bar", "Baz"); + } + private static ParsedContracts ParseValid(string definitionText) { var contracts = Parse(definitionText); if (contracts.Errors.Any()) - { - SyntaxDebugHelper.DumpParseTree(contracts); Assert.Fail("There are unexpected errors"); - } return contracts; } @@ -710,10 +747,7 @@ private static ParsedContracts ParseInvalid(string definitionText) var contracts = Parse(definitionText); if (!contracts.Errors.Any()) - { - SyntaxDebugHelper.DumpParseTree(contracts); Assert.Fail("Errors were expected"); - } return contracts; } @@ -726,6 +760,9 @@ private static ParsedContracts Parse(string definitionText) foreach (var error in contracts.Errors) Console.WriteLine("ERROR: {0}", error); + foreach (var message in contracts.Messages) + Console.WriteLine($"MESSAGE: {message}"); + return contracts; } diff --git a/src/Abc.Zebus.MessageDsl.Tests/TestTools/AssertionExtensions.cs b/src/Abc.Zebus.MessageDsl.Tests/TestTools/AssertionExtensions.cs index eab5591..be1aa1f 100644 --- a/src/Abc.Zebus.MessageDsl.Tests/TestTools/AssertionExtensions.cs +++ b/src/Abc.Zebus.MessageDsl.Tests/TestTools/AssertionExtensions.cs @@ -47,6 +47,9 @@ public static void ShouldBeEmpty(this object actual) public static void ShouldEqual(this T? actual, T? expected) => Assert.That(actual, Is.EqualTo(expected)); + public static void ShouldEqualOneOf(this T? actual, params T[] expected) + => Assert.That(actual, Is.AnyOf(expected)); + public static void ShouldBeGreaterThan(this int actual, int value) => Assert.That(actual, Is.GreaterThan(value)); diff --git a/src/Abc.Zebus.MessageDsl/Analysis/AstCreationVisitor.cs b/src/Abc.Zebus.MessageDsl/Analysis/AstCreationVisitor.cs index 2e14e84..23c44fb 100644 --- a/src/Abc.Zebus.MessageDsl/Analysis/AstCreationVisitor.cs +++ b/src/Abc.Zebus.MessageDsl/Analysis/AstCreationVisitor.cs @@ -6,6 +6,7 @@ using Abc.Zebus.MessageDsl.Dsl; using Abc.Zebus.MessageDsl.Support; using Antlr4.Runtime; +using Antlr4.Runtime.Tree; using static Abc.Zebus.MessageDsl.Dsl.MessageContractsParser; namespace Abc.Zebus.MessageDsl.Analysis; @@ -436,4 +437,7 @@ private void ProcessAttributes(AttributeSet? attributeSet, AttributesContext? co private string GetId(IdContext? context) => context?.GetValidatedId(_contracts) ?? string.Empty; + + private new AstNode? Visit(IParseTree? tree) + => tree is not null ? base.Visit(tree) : null; } diff --git a/src/Abc.Zebus.MessageDsl/Ast/ParsedContracts.cs b/src/Abc.Zebus.MessageDsl/Ast/ParsedContracts.cs index ac048eb..e009136 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/ParsedContracts.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/ParsedContracts.cs @@ -12,7 +12,7 @@ public class ParsedContracts public IList Enums { get; } = new List(); public ContractOptions Options { get; } = new(); public ICollection Errors { get; } - public string Namespace { get; set; } = string.Empty; + public string? Namespace { get; set; } public bool ExplicitNamespace { get; internal set; } public ICollection ImportedNamespaces { get; } = new HashSet(); @@ -49,7 +49,11 @@ public static ParsedContracts CreateParseTree(string definitionText) var tokenStream = new CommonTokenStream(lexer); - var parser = new MessageContractsParser(tokenStream); + var parser = new MessageContractsParser(tokenStream) + { + ErrorHandler = new MessageContractsErrorStrategy() + }; + parser.RemoveErrorListeners(); parser.AddErrorListener(errorListener); @@ -58,18 +62,15 @@ public static ParsedContracts CreateParseTree(string definitionText) return new ParsedContracts(tokenStream, parseTree, errorListener.Errors); } - public static ParsedContracts Parse(string definitionText, string defaultNamespace) + public static ParsedContracts Parse(string definitionText, string? defaultNamespace) { var result = CreateParseTree(definitionText); if (!result.ExplicitNamespace) result.Namespace = defaultNamespace; - if (result.Errors.Count == 0) - { - new AstCreationVisitor(result).VisitCompileUnit(result.ParseTree); - result.Process(); - } + new AstCreationVisitor(result).VisitCompileUnit(result.ParseTree); + result.Process(); return result; } diff --git a/src/Abc.Zebus.MessageDsl/Dsl/MessageContracts.g4.parser.cs b/src/Abc.Zebus.MessageDsl/Dsl/MessageContracts.g4.parser.cs index 927d8c3..5bd1102 100644 --- a/src/Abc.Zebus.MessageDsl/Dsl/MessageContracts.g4.parser.cs +++ b/src/Abc.Zebus.MessageDsl/Dsl/MessageContracts.g4.parser.cs @@ -27,7 +27,7 @@ private bool IsAtStartOfPragma() return true; } - private bool IsAtImplicitSeparator() + internal bool IsAtImplicitSeparator() { var prevToken = _input.Lt(-1); var nextToken = _input.Lt(1); diff --git a/src/Abc.Zebus.MessageDsl/Dsl/MessageContractsErrorStrategy.cs b/src/Abc.Zebus.MessageDsl/Dsl/MessageContractsErrorStrategy.cs new file mode 100644 index 0000000..1cd6331 --- /dev/null +++ b/src/Abc.Zebus.MessageDsl/Dsl/MessageContractsErrorStrategy.cs @@ -0,0 +1,69 @@ +using Antlr4.Runtime; + +namespace Abc.Zebus.MessageDsl.Dsl; + +internal class MessageContractsErrorStrategy : DefaultErrorStrategy +{ + public override void Sync(Parser recognizer) + { + if (InErrorRecoveryMode(recognizer)) + return; + + base.Sync(recognizer); + + if (InErrorRecoveryMode(recognizer) && IsInDefinitionRule(recognizer)) + throw new GoToNextDefinitionException(recognizer); + } + + public override void ReportError(Parser recognizer, RecognitionException e) + { + if (e is GoToNextDefinitionException) + return; + + base.ReportError(recognizer, e); + + if (IsInDefinitionRule(recognizer)) + throw new GoToNextDefinitionException(recognizer); + } + + public override void Recover(Parser recognizer, RecognitionException e) + { + if (e is GoToNextDefinitionException) + { + if (recognizer.Context.RuleIndex != MessageContractsParser.RULE_definition) + throw e; + + while (true) + { + // Consume tokens until we reach a possible separator between definitions + if (recognizer.InputStream.La(1) is MessageContractsParser.SEP or MessageContractsParser.Eof + || ((MessageContractsParser)recognizer).IsAtImplicitSeparator()) + { + EndErrorCondition(recognizer); + return; + } + + recognizer.Consume(); + } + } + + base.Recover(recognizer, e); + } + + private static bool IsInDefinitionRule(Parser recognizer) + { + RuleContext ctx = recognizer.Context; + + while (ctx != null) + { + if (ctx.RuleIndex == MessageContractsParser.RULE_definition) + return true; + + ctx = ctx.Parent; + } + + return false; + } + + private class GoToNextDefinitionException(Parser recognizer) : RecognitionException(recognizer, recognizer.InputStream, recognizer.Context); +} diff --git a/src/Abc.Zebus.MessageDsl/Generator/CSharpGenerator.cs b/src/Abc.Zebus.MessageDsl/Generator/CSharpGenerator.cs index 24bb846..8254806 100644 --- a/src/Abc.Zebus.MessageDsl/Generator/CSharpGenerator.cs +++ b/src/Abc.Zebus.MessageDsl/Generator/CSharpGenerator.cs @@ -41,7 +41,7 @@ private string Generate() var hasNamespace = !string.IsNullOrEmpty(Contracts.Namespace); if (hasNamespace) - Writer.WriteLine("namespace {0}", Identifier(Contracts.Namespace)); + Writer.WriteLine("namespace {0}", Identifier(Contracts.Namespace!)); using (hasNamespace ? Block() : null) {