Skip to content

Commit

Permalink
Generate code for inputs with errors
Browse files Browse the repository at this point in the history
  • Loading branch information
ltrzesniewski committed Jan 10, 2024
1 parent 17ecd30 commit 66f2ad2
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 50 deletions.
57 changes: 23 additions & 34 deletions src/Abc.Zebus.MessageDsl.Generator/Generator/MessageDslGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,36 +36,35 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

private static SourceGenerationResult GenerateCode(SourceGenerationInput input, CancellationToken cancellationToken)
{
var (file, fileNamespace, relativePath) = input;
var result = new SourceGenerationResult(input);

var fileContents = file.GetText(cancellationToken)?.ToString();
var fileContents = input.AdditionalText.GetText(cancellationToken)?.ToString();

if (fileContents is null)
return SourceGenerationResult.Error(Diagnostic.Create(MessageDslDiagnostics.CouldNotReadFileContents, Location.Create(file.Path, default, default)));
{
result.AddDiagnostic(Diagnostic.Create(MessageDslDiagnostics.CouldNotReadFileContents, Location.Create(input.AdditionalText.Path, default, default)));
return result;
}

try
{
var contracts = ParsedContracts.Parse(fileContents, fileNamespace!.Trim());
var contracts = ParsedContracts.Parse(fileContents, input.FileNamespace?.Trim());

if (!contracts.IsValid)
foreach (var error in contracts.Errors)
{
var diagnostics = new List<Diagnostic>();
foreach (var error in contracts.Errors)
{
var location = Location.Create(file.Path, default, error.ToLinePositionSpan());
diagnostics.Add(Diagnostic.Create(MessageDslDiagnostics.MessageDslError, location, error.Message));
}

return SourceGenerationResult.Error(diagnostics.ToArray());
var location = Location.Create(input.AdditionalText.Path, default, error.ToLinePositionSpan());
result.AddDiagnostic(Diagnostic.Create(MessageDslDiagnostics.MessageDslError, location, error.Message));
}

var output = CSharpGenerator.Generate(contracts);
result.GeneratedSource = output;

return SourceGenerationResult.Success(file, output, relativePath);
return result;
}
catch (Exception ex) when (ex is not OperationCanceledException)
{
return SourceGenerationResult.Error(Diagnostic.Create(MessageDslDiagnostics.UnexpectedError, Location.None, ex.ToString()));
result.AddDiagnostic(Diagnostic.Create(MessageDslDiagnostics.UnexpectedError, Location.None, ex.ToString()));
return result;
}
}

Expand All @@ -76,36 +75,26 @@ private static void GenerateFiles(SourceProductionContext context, ImmutableArra
foreach (var diagnostic in result.Diagnostics)
context.ReportDiagnostic(diagnostic);

if (result.AdditionalText == null || result.GeneratedSource == null)
if (result.GeneratedSource == null)
continue;

var hintName = _sanitizePathRegex.Replace(result.RelativePath ?? result.AdditionalText.Path, "_") + ".g.cs";

context.AddSource(hintName, result.GeneratedSource);
}
}

private record SourceGenerationInput(AdditionalText AdditionalText, string? FileNamespace, string? RelativePath);

public class SourceGenerationResult
private class SourceGenerationResult(SourceGenerationInput input)
{
public IList<Diagnostic> Diagnostics { get; }
public string? GeneratedSource { get; set; }
public AdditionalText? AdditionalText { get; }
public string? RelativePath { get; }
private readonly List<Diagnostic> _diagnostics = new();

private SourceGenerationResult(IList<Diagnostic> diagnostics, string? generatedSource, AdditionalText? additionalText, string? relativePath)
{
Diagnostics = diagnostics;
GeneratedSource = generatedSource;
AdditionalText = additionalText;
RelativePath = relativePath;
}

public static SourceGenerationResult Error(params Diagnostic[] diagnostics)
=> new(diagnostics, null, null, null);
public AdditionalText AdditionalText { get; } = input.AdditionalText;
public IReadOnlyList<Diagnostic> Diagnostics => _diagnostics;
public string? GeneratedSource { get; set; }
public string? RelativePath { get; } = input.RelativePath;

public static SourceGenerationResult Success(AdditionalText? additionalText, string generatedSource, string? relativePath)
=> new(Array.Empty<Diagnostic>(), generatedSource, additionalText, relativePath);
public void AddDiagnostic(Diagnostic diagnostic)
=> _diagnostics.Add(diagnostic);
}
}
49 changes: 43 additions & 6 deletions src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ParsedContractsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -692,15 +692,52 @@ public void should_detect_misplaced_optional_parameters()
ParseValid("Foo(int a, int b = 42, int c = 10);");
}

[Test]
public void should_generate_ast_for_incomplete_code()
{
var contracts = ParseInvalid("Foo(int bar");
var message = contracts.Messages.ExpectedSingle();
var param = message.Parameters.ExpectedSingle();
param.Type.NetType.ShouldEqual("int");
param.Name.ShouldEqual("bar");
}

[Test]
[TestCase("Foo")]
[TestCase("Foo;")]
[TestCase("Foo(")]
[TestCase("Foo(;")]
[TestCase("Foo(int")]
[TestCase("Foo(int;")]
[TestCase("Foo(int first")]
[TestCase("Foo(int first;")]
[TestCase("Foo(int first,")]
[TestCase("Foo(int first,;")]
[TestCase("Foo(int first, string")]
[TestCase("Foo(int first, string;")]
[TestCase("Foo(int first, string second")]
[TestCase("Foo(int first, string second;")]
public void should_split_on_incomplete_message(string firstLine)
{
var contracts = ParseInvalid(
$"""
{firstLine}
Bar(int something, int somethingElse)
Baz(int something, int somethingElse)
"""
);

contracts.Messages.Count.ShouldBeGreaterThan(1);
contracts.Messages[0].Name.ShouldEqual("Foo");
contracts.Messages[1].Name.ShouldEqualOneOf("Bar", "Baz");
}

private static ParsedContracts ParseValid(string definitionText)
{
var contracts = Parse(definitionText);

if (contracts.Errors.Any())
{
SyntaxDebugHelper.DumpParseTree(contracts);
Assert.Fail("There are unexpected errors");
}

return contracts;
}
Expand All @@ -710,10 +747,7 @@ private static ParsedContracts ParseInvalid(string definitionText)
var contracts = Parse(definitionText);

if (!contracts.Errors.Any())
{
SyntaxDebugHelper.DumpParseTree(contracts);
Assert.Fail("Errors were expected");
}

return contracts;
}
Expand All @@ -726,6 +760,9 @@ private static ParsedContracts Parse(string definitionText)
foreach (var error in contracts.Errors)
Console.WriteLine("ERROR: {0}", error);

foreach (var message in contracts.Messages)
Console.WriteLine($"MESSAGE: {message}");

return contracts;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ public static void ShouldBeEmpty(this object actual)
public static void ShouldEqual<T>(this T? actual, T? expected)
=> Assert.That(actual, Is.EqualTo(expected));

public static void ShouldEqualOneOf<T>(this T? actual, params T[] expected)
=> Assert.That(actual, Is.AnyOf(expected));

public static void ShouldBeGreaterThan(this int actual, int value)
=> Assert.That(actual, Is.GreaterThan(value));

Expand Down
4 changes: 4 additions & 0 deletions src/Abc.Zebus.MessageDsl/Analysis/AstCreationVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Abc.Zebus.MessageDsl.Dsl;
using Abc.Zebus.MessageDsl.Support;
using Antlr4.Runtime;
using Antlr4.Runtime.Tree;
using static Abc.Zebus.MessageDsl.Dsl.MessageContractsParser;

namespace Abc.Zebus.MessageDsl.Analysis;
Expand Down Expand Up @@ -436,4 +437,7 @@ private void ProcessAttributes(AttributeSet? attributeSet, AttributesContext? co

private string GetId(IdContext? context)
=> context?.GetValidatedId(_contracts) ?? string.Empty;

private new AstNode? Visit(IParseTree? tree)
=> tree is not null ? base.Visit(tree) : null;
}
17 changes: 9 additions & 8 deletions src/Abc.Zebus.MessageDsl/Ast/ParsedContracts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public class ParsedContracts
public IList<EnumDefinition> Enums { get; } = new List<EnumDefinition>();
public ContractOptions Options { get; } = new();
public ICollection<SyntaxError> Errors { get; }
public string Namespace { get; set; } = string.Empty;
public string? Namespace { get; set; }
public bool ExplicitNamespace { get; internal set; }
public ICollection<string> ImportedNamespaces { get; } = new HashSet<string>();

Expand Down Expand Up @@ -49,7 +49,11 @@ public static ParsedContracts CreateParseTree(string definitionText)

var tokenStream = new CommonTokenStream(lexer);

var parser = new MessageContractsParser(tokenStream);
var parser = new MessageContractsParser(tokenStream)
{
ErrorHandler = new MessageContractsErrorStrategy()
};

parser.RemoveErrorListeners();
parser.AddErrorListener(errorListener);

Expand All @@ -58,18 +62,15 @@ public static ParsedContracts CreateParseTree(string definitionText)
return new ParsedContracts(tokenStream, parseTree, errorListener.Errors);
}

public static ParsedContracts Parse(string definitionText, string defaultNamespace)
public static ParsedContracts Parse(string definitionText, string? defaultNamespace)
{
var result = CreateParseTree(definitionText);

if (!result.ExplicitNamespace)
result.Namespace = defaultNamespace;

if (result.Errors.Count == 0)
{
new AstCreationVisitor(result).VisitCompileUnit(result.ParseTree);
result.Process();
}
new AstCreationVisitor(result).VisitCompileUnit(result.ParseTree);
result.Process();

return result;
}
Expand Down
2 changes: 1 addition & 1 deletion src/Abc.Zebus.MessageDsl/Dsl/MessageContracts.g4.parser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ private bool IsAtStartOfPragma()
return true;
}

private bool IsAtImplicitSeparator()
internal bool IsAtImplicitSeparator()
{
var prevToken = _input.Lt(-1);
var nextToken = _input.Lt(1);
Expand Down
69 changes: 69 additions & 0 deletions src/Abc.Zebus.MessageDsl/Dsl/MessageContractsErrorStrategy.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
using Antlr4.Runtime;

namespace Abc.Zebus.MessageDsl.Dsl;

internal class MessageContractsErrorStrategy : DefaultErrorStrategy
{
public override void Sync(Parser recognizer)
{
if (InErrorRecoveryMode(recognizer))
return;

base.Sync(recognizer);

if (InErrorRecoveryMode(recognizer) && IsInDefinitionRule(recognizer))
throw new GoToNextDefinitionException(recognizer);
}

public override void ReportError(Parser recognizer, RecognitionException e)
{
if (e is GoToNextDefinitionException)
return;

base.ReportError(recognizer, e);

if (IsInDefinitionRule(recognizer))
throw new GoToNextDefinitionException(recognizer);
}

public override void Recover(Parser recognizer, RecognitionException e)
{
if (e is GoToNextDefinitionException)
{
if (recognizer.Context.RuleIndex != MessageContractsParser.RULE_definition)
throw e;

while (true)
{
// Consume tokens until we reach a possible separator between definitions
if (recognizer.InputStream.La(1) is MessageContractsParser.SEP or MessageContractsParser.Eof
|| ((MessageContractsParser)recognizer).IsAtImplicitSeparator())
{
EndErrorCondition(recognizer);
return;
}

recognizer.Consume();
}
}

base.Recover(recognizer, e);
}

private static bool IsInDefinitionRule(Parser recognizer)
{
RuleContext ctx = recognizer.Context;

while (ctx != null)
{
if (ctx.RuleIndex == MessageContractsParser.RULE_definition)
return true;

ctx = ctx.Parent;
}

return false;
}

private class GoToNextDefinitionException(Parser recognizer) : RecognitionException(recognizer, recognizer.InputStream, recognizer.Context);
}
2 changes: 1 addition & 1 deletion src/Abc.Zebus.MessageDsl/Generator/CSharpGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ private string Generate()

var hasNamespace = !string.IsNullOrEmpty(Contracts.Namespace);
if (hasNamespace)
Writer.WriteLine("namespace {0}", Identifier(Contracts.Namespace));
Writer.WriteLine("namespace {0}", Identifier(Contracts.Namespace!));

using (hasNamespace ? Block() : null)
{
Expand Down

0 comments on commit 66f2ad2

Please sign in to comment.