Skip to content

Commit

Permalink
Merge pull request #35 from roblox-csharp/macros
Browse files Browse the repository at this point in the history
Add more List macros and fix Lambda expressions not returning
  • Loading branch information
R-unic authored Jan 12, 2025
2 parents 7aeabaa + 537fccc commit 83292e1
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 5 deletions.
177 changes: 174 additions & 3 deletions RobloxCS.Luau/Macros.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.CodeAnalysis.CSharp.Syntax;
using RobloxCS.Luau;
using RobloxCS.Shared;
using System.Xml.Linq;

namespace RobloxCS.Macros;

Expand Down Expand Up @@ -129,7 +130,7 @@ public class Macro(SemanticModel semanticModel)

/// <summary>Takes a C# member access and expands the macro into a Luau expression</summary>
/// <returns>The expanded expression of the macro, or null if no macro was applied</returns>
public Expression? MemberAccess(Func<SyntaxNode, Node?> visit, MemberAccessExpressionSyntax memberAccess)
public Node? MemberAccess(Func<SyntaxNode, Node?> visit, MemberAccessExpressionSyntax memberAccess)
{
var expressionType = _semanticModel.GetTypeInfo(memberAccess.Expression).Type;
{
Expand Down Expand Up @@ -265,12 +266,12 @@ private bool ObjectMethod(Func<SyntaxNode, Node?> visit, MemberAccessExpressionS
return expanded != null;
}

// TODO: Replace (function() end)() calls
/// <summary>Macros <see cref="List"/> methods</summary>
private static bool ListMethod(Func<SyntaxNode, Node?> visit, MemberAccessExpressionSyntax memberAccess,
InvocationExpressionSyntax invocation, out Expression? expanded)
InvocationExpressionSyntax invocation, out Node? expanded)
{
expanded = null;
Console.WriteLine(memberAccess.Expression is IdentifierNameSyntax);
switch (memberAccess.Name.Identifier.Text) {
case "Add": {
var arguments = (ArgumentList)visit(invocation.ArgumentList)!;
Expand All @@ -294,6 +295,176 @@ private static bool ListMethod(Func<SyntaxNode, Node?> visit, MemberAccessExpres
expanded = new Call(new QualifiedName(new IdentifierName("table"), new IdentifierName("clear")), new ArgumentList([new Argument(self)]));
break;
}
case "Exists": {
var self = (Expression)visit(memberAccess.Expression)!;
var FilterFunc = (ArgumentList)visit(invocation.ArgumentList)!;
var key = new IdentifierName("_");
var value = new IdentifierName("_v");

expanded = new Call(new Parenthesized(new AnonymousFunction(new([]), body: new Block([
new Variable(new IdentifierName("_FilterFunc"), true, FilterFunc.Arguments.First()),
new For([key, value], self, new Block([
new If(new Call(new IdentifierName("_FilterFunc"), new ArgumentList([new Argument(value)])), new Block([
new Return(AstUtility.True())
])),
])),
new Return(AstUtility.False())
]))), new([]));
break;
}
case "Find": {
var self = (Expression)visit(memberAccess.Expression)!;
var filterFunc = (ArgumentList)visit(invocation.ArgumentList)!;
var key = new IdentifierName("_");
var value = new IdentifierName("_v");

expanded = new Call(new Parenthesized(new AnonymousFunction(new([]), body: new Block([
new Variable(new IdentifierName("_FilterFunc"), true, filterFunc.Arguments.First()),
new For([key, value], self, new Block([
new If(new Call(new IdentifierName("_FilterFunc"), new ArgumentList([new Argument(value)])), new Block([
new Return(value)
])),
])),
new Return(AstUtility.Nil())
]))), new([]));
break;
}
case "FindLast": {
var self = (Expression)visit(memberAccess.Expression)!;
var filterFunc = (ArgumentList)visit(invocation.ArgumentList)!;
var key = new IdentifierName("_");
var value = new IdentifierName("_v");

expanded = new Call(new Parenthesized(new AnonymousFunction(new([]), body: new Block([
new Variable(new IdentifierName("_FilterFunc"), true, filterFunc.Arguments.First()),
new Variable(new IdentifierName("_Return"), true),
new For([key, value], self, new Block([
new If(new Call(new IdentifierName("_FilterFunc"), new ArgumentList([new Argument(value)])), new Block([
new ExpressionStatement(new Assignment(new IdentifierName("_Return"), value))
]))
])),
new Return(new IdentifierName("_Return"))
]))), new([]));
break;
}
case "FindAll": {
var self = (Expression)visit(memberAccess.Expression)!;
var filterFunc = (ArgumentList)visit(invocation.ArgumentList)!;
var key = new IdentifierName("_");
var value = new IdentifierName("_v");

expanded = new Call(new Parenthesized(new AnonymousFunction(new([]), body: new Block([
new Variable(new IdentifierName("_FilterFunc"), true, filterFunc.Arguments.First()),
new Variable(new IdentifierName("_Filtered"), true, new TableInitializer()),
new For([key, value], self, new Block([
new If(new Call(new IdentifierName("_FilterFunc"), new ArgumentList([new Argument(value)])), new Block([
new ExpressionStatement(new Call(new MemberAccess(new IdentifierName("table"), new IdentifierName("insert")), new ArgumentList([new Argument(new IdentifierName("_Filtered")), new Argument(value)])))
])),
])),
new Return(new IdentifierName("_Filtered"))
]))), new([]));
break;
}
case "AddRange": {
var self = (Expression)visit(memberAccess.Expression)!;
var table = (ArgumentList)visit(invocation.ArgumentList)!;
var key = new IdentifierName("_");
var value = new IdentifierName("_v");

expanded = new Block([
new For([key, value], table.Arguments.First(), new Block([
new ExpressionStatement(new Call(new MemberAccess(new IdentifierName("table"), new IdentifierName("insert")), new ArgumentList([new Argument(self), new Argument(value)])))
])),
]);
break;
}
case "ConvertAll": {
var self = (Expression)visit(memberAccess.Expression)!;
var convertFunc = (ArgumentList)visit(invocation.ArgumentList)!;
var key = new IdentifierName("_");
var value = new IdentifierName("_v");

expanded = new Call(new Parenthesized(new AnonymousFunction(new([]), body: new Block([
new Variable(new IdentifierName("_ConvertFunc"), true, convertFunc.Arguments.First()),
new Variable(new IdentifierName("_Converted"), true, new Call(new MemberAccess(new IdentifierName("table"), new IdentifierName("insert")), new ArgumentList([new (new UnaryOperator("#", self))]))),
new For([key, value], self, new Block([
new ExpressionStatement(new Call(new MemberAccess(new IdentifierName("table"), new IdentifierName("insert")), new ArgumentList([new Argument(new IdentifierName("_Converted")), new Argument(new Call(new IdentifierName("_ConvertFunc"), new([new(value)])))])))
])),
new Return(new IdentifierName("_Converted"))
]))), new([]));
break;
}
case "FindIndex": {
var self = (Expression)visit(memberAccess.Expression)!;
var filterFunc = (ArgumentList)visit(invocation.ArgumentList)!;
var key = new IdentifierName("_k");
var value = new IdentifierName("_v");

// TODO: Add the other 2 overloads
if (invocation.ArgumentList.Arguments.Count == 1)
expanded = new Call(new Parenthesized(new AnonymousFunction(new([]), body: new Block([
new Variable(new IdentifierName("_FilterFunc"), true, filterFunc.Arguments.First()),
new For([key, value], self, new Block([
new If(new Call(new IdentifierName("_FilterFunc"), new ArgumentList([new Argument(value)])), new Block([
new Return(key)
]))
])),
new Return(AstUtility.Nil())
]))), new([]));
break;
}
case "IndexOf": {
var self = (Expression)visit(memberAccess.Expression)!;
var arguments = (ArgumentList)visit(invocation.ArgumentList)!;
var key = new IdentifierName("_k");
var value = new IdentifierName("_v");
var shouldCreateVariable = invocation.ArgumentList.Arguments.First().Expression is not LiteralExpressionSyntax;

List <Statement> block = [
new For([key, value], self, new Block([
new If(new BinaryOperator(shouldCreateVariable ? new IdentifierName("_val") : arguments.Arguments.First().Expression, "==", value), new Block([
new Return(key)
])),
])),
new Return(AstUtility.Nil())
];

if (shouldCreateVariable)
block.Insert(0, new Variable(new("_val"), true, arguments.Arguments.First()));

expanded = new Call(new Parenthesized(new AnonymousFunction(new([]), body: new Block(block))), new([]));
break;
}
case "Insert": {
var self = (Expression)visit(memberAccess.Expression)!;
var arguments = (ArgumentList)visit(invocation.ArgumentList)!;

expanded = new Call(new MemberAccess(new IdentifierName("table"), new IdentifierName("insert")), new ArgumentList([new(self), arguments.Arguments.First(), arguments.Arguments.Last()]));
break;
}
case "Remove": {
var self = (Expression)visit(memberAccess.Expression)!;
var arguments = (ArgumentList)visit(invocation.ArgumentList)!;
var key = new IdentifierName("_k");
var value = new IdentifierName("_v");
var shouldCreateVariable = invocation.ArgumentList.Arguments.First().Expression is not LiteralExpressionSyntax;

List<Statement> block = [
new For([key, value], self, new Block([
new If(new BinaryOperator(shouldCreateVariable ? new IdentifierName("_val") : arguments.Arguments.First().Expression, "==", value), new Block([
new ExpressionStatement(new Call(new MemberAccess(new IdentifierName("table"), new IdentifierName("remove")), new ArgumentList([new(self), new(key)]))),
new Break()
])),
])),
new Return(AstUtility.Nil())
];

if (shouldCreateVariable)
block.Insert(0, new Variable(new("_val"), true, arguments.Arguments.First()));

expanded = new Block(block);
break;
}
}

expanded?.MarkExpanded(MacroKind.ListMethod);
Expand Down
6 changes: 4 additions & 2 deletions RobloxCS/LuauGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ public override Luau.AnonymousFunction VisitParenthesizedLambdaExpression(Parent
var returnType = new Luau.TypeRef(returnTypeName?.ToString() ?? "nil");
var parameterList = Visit<Luau.ParameterList?>(node.ParameterList) ?? new Luau.ParameterList([]);
var body = node.ExpressionBody != null ?
new Luau.Block([new Luau.ExpressionStatement(Visit<Luau.Expression>(node.ExpressionBody))])
new Luau.Block([new Luau.Return(Visit<Luau.Expression>(node.ExpressionBody))])
: Visit<Luau.Block?>(node.Block);

return new Luau.AnonymousFunction(parameterList, returnType, body);
Expand Down Expand Up @@ -993,7 +993,9 @@ public override Luau.Function VisitLocalFunctionStatement(LocalFunctionStatement
public override Luau.Parameter VisitParameter(ParameterSyntax node)
{
var name = Luau.AstUtility.CreateSimpleName<Luau.IdentifierName>(node, registerIdentifier: true);
var returnType = Luau.AstUtility.CreateTypeRef(Visit<Luau.Name>(node.Type).ToString());
Luau.TypeRef returnType = null;

Check warning on line 996 in RobloxCS/LuauGenerator.cs

View workflow job for this annotation

GitHub Actions / publish

Converting null literal or possible null value to non-nullable type.

Check warning on line 996 in RobloxCS/LuauGenerator.cs

View workflow job for this annotation

GitHub Actions / publish

Converting null literal or possible null value to non-nullable type.

Check warning on line 996 in RobloxCS/LuauGenerator.cs

View workflow job for this annotation

GitHub Actions / test

Converting null literal or possible null value to non-nullable type.
if (node.Type != null)
Luau.AstUtility.CreateTypeRef(Visit<Luau.Name>(node.Type).ToString());
var initializer = Visit<Luau.Expression?>(node.Default);
var isParams = HasSyntax(node.Modifiers, SyntaxKind.ParamsKeyword);

Expand Down

0 comments on commit 83292e1

Please sign in to comment.