// SDL3-CS/SDL3-CS.SourceGeneration/FriendlyOverloadGenerator.cs
//
// 174 lines
// 7.0 KiB
// C#

// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace SDL3.SourceGeneration
{
    /// <summary>
    /// Roslyn source generator that adds "friendly" overloads next to the native SDL methods collected by
    /// <see cref="UnfriendlyMethodFinder"/>. Each generated overload forwards to the native method and, depending
    /// on the method's <c>RequiredChanges</c>:
    /// exposes parameters recognised by <c>IsTypeConstCharPtr()</c> as <c>ReadOnlySpan&lt;byte&gt;</c>,
    /// returns <c>string?</c> (converted with <c>PtrToStringUTF8</c>) instead of the native return type,
    /// and drops the leading <c>Helper.UnsafePrefix</c> from the method name.
    /// </summary>
    [Generator]
    public class FriendlyOverloadGenerator : ISourceGenerator
    {
        /// <summary>
        /// Registers <see cref="UnfriendlyMethodFinder"/> as the syntax receiver that collects the methods to wrap.
        /// </summary>
        public void Initialize(GeneratorInitializationContext context)
        {
            context.RegisterForSyntaxNotifications(() => new UnfriendlyMethodFinder());
        }

        /// <summary>
        /// Preamble written at the top of every generated source, before the generated namespace declaration.
        /// </summary>
        private const string file_header = @"// <auto-generated/>
#nullable enable
using System;
";

        /// <summary>
        /// Emits one generated source per entry collected by the finder. The entry key is used as the hint name
        /// passed to <c>AddSource</c>. The generated text is <see cref="file_header"/> followed by
        /// <c>namespace SDL { unsafe partial class SDL3 { ... } }</c>, whose members are the friendly overloads
        /// built from the entry's methods.
        /// </summary>
        public void Execute(GeneratorExecutionContext context)
        {
            // Initialize always registers an UnfriendlyMethodFinder, so the receiver is expected to be one here.
            var finder = (UnfriendlyMethodFinder)context.SyntaxReceiver!;

            foreach (var kvp in finder.Methods)
            {
                string filename = kvp.Key;
                var foundMethods = kvp.Value;

                var result = new StringBuilder();
                result.Append(file_header);
                result.Append(
                    SyntaxFactory.NamespaceDeclaration(
                                     SyntaxFactory.IdentifierName("SDL"))
                                 .WithMembers(
                                     SyntaxFactory.SingletonList<MemberDeclarationSyntax>(
                                         SyntaxFactory.ClassDeclaration("SDL3")
                                                      .WithModifiers(
                                                          SyntaxFactory.TokenList(
                                                              SyntaxFactory.Token(SyntaxKind.UnsafeKeyword),
                                                              SyntaxFactory.Token(SyntaxKind.PartialKeyword)))
                                                      .WithMembers(SyntaxFactory.List(foundMethods.Select(makeFriendlyMethod)))))
                                 .NormalizeWhitespace());

                context.AddSource(filename, result.ToString());
            }
        }

        /// <summary>
        /// Builds the <c>public static</c> friendly overload for one native method.
        /// </summary>
        /// <param name="gm">The native method together with the set of changes the overload must apply.</param>
        /// <returns>The generated method declaration.</returns>
        private static MemberDeclarationSyntax makeFriendlyMethod(GeneratedMethod gm)
        {
            var returnType = gm.RequiredChanges.HasFlag(Changes.ChangeReturnTypeToString)
                ? SyntaxFactory.ParseTypeName("string?")
                : gm.NativeMethod.ReturnType;

            var identifier = gm.RequiredChanges.HasFlag(Changes.TrimUnsafeFromName)
                ? SyntaxFactory.Identifier(trimUnsafePrefix(gm.NativeMethod.Identifier.ValueText))
                : gm.NativeMethod.Identifier;

            return SyntaxFactory.MethodDeclaration(returnType, identifier)
                                .WithModifiers(
                                    SyntaxFactory.TokenList(
                                        SyntaxFactory.Token(SyntaxKind.PublicKeyword),
                                        SyntaxFactory.Token(SyntaxKind.StaticKeyword)))
                                .WithParameterList(
                                    SyntaxFactory.ParameterList(
                                        SyntaxFactory.SeparatedList(transformParams(gm))))
                                .WithBody(
                                    SyntaxFactory.Block(
                                        makeMethodBody(gm)));
        }

        /// <summary>
        /// Removes a leading <c>Helper.UnsafePrefix</c> from <paramref name="name"/>.
        /// Only the prefix is stripped; the same text appearing later in the name is left untouched
        /// (the previous <c>string.Replace</c> removed every occurrence, and would throw for an empty prefix).
        /// </summary>
        /// <param name="name">The native method name.</param>
        /// <returns><paramref name="name"/> without the leading prefix, or unchanged if it does not start with it.</returns>
        /// <remarks>
        /// NOTE(review): assumes the finder only sets <c>Changes.TrimUnsafeFromName</c> for names that start with
        /// the prefix, which is what the flag name implies. Confirm against <c>UnfriendlyMethodFinder</c>.
        /// </remarks>
        private static string trimUnsafePrefix(string name)
        {
            string prefix = Helper.UnsafePrefix;

            return name.StartsWith(prefix, System.StringComparison.Ordinal)
                ? name.Substring(prefix.Length)
                : name;
        }

        /// <summary>
        /// Produces the overload's parameter list from the native one. Parameters recognised by
        /// <c>IsTypeConstCharPtr()</c> become <c>ReadOnlySpan&lt;byte&gt;</c>. Every parameter has its attribute
        /// lists removed, which drops e.g. <c>[NativeTypeName]</c>. Names and other syntax are kept as-is.
        /// </summary>
        private static IEnumerable<ParameterSyntax> transformParams(GeneratedMethod gm)
        {
            foreach (var param in gm.NativeMethod.ParameterList.Parameters)
            {
                if (param.IsTypeConstCharPtr())
                {
                    Debug.Assert(gm.RequiredChanges.HasFlag(Changes.ChangeParamsToReadOnlySpan));
                    yield return param.WithType(SyntaxFactory.ParseTypeName("ReadOnlySpan<byte>"))
                                      .WithAttributeLists(SyntaxFactory.List<AttributeListSyntax>());
                }
                else
                {
                    yield return param.WithAttributeLists(SyntaxFactory.List<AttributeListSyntax>()); // remove [NativeTypeName]
                }
            }
        }

        /// <summary>
        /// Suffix appended to a span parameter's name to form the name of its pinned <c>byte*</c> local.
        /// Shared by <see cref="makeMethodBody"/> (declares the local) and <see cref="makeArguments"/> (passes it).
        /// </summary>
        private const string pointer_suffix = "Ptr";

        /// <summary>
        /// Builds the overload body: the native call from <see cref="makeReturn"/>, nested inside one
        /// <c>fixed (byte* &lt;name&gt;Ptr = &lt;name&gt;)</c> statement per span parameter.
        /// </summary>
        private static StatementSyntax makeMethodBody(GeneratedMethod gm)
        {
            var body = makeReturn(gm);

            // Wrap from the last span parameter outwards, so the first parameter ends up as the outermost `fixed`.
            foreach (var param in gm.NativeMethod.ParameterList.Parameters.Where(p => p.IsTypeConstCharPtr()).Reverse())
            {
                Debug.Assert(gm.RequiredChanges.HasFlag(Changes.ChangeParamsToReadOnlySpan));
                body = SyntaxFactory.FixedStatement(
                    SyntaxFactory.VariableDeclaration(
                        SyntaxFactory.ParseTypeName("byte*"),
                        SyntaxFactory.SingletonSeparatedList(
                            SyntaxFactory.VariableDeclarator(
                                             param.Identifier.ValueText + pointer_suffix)
                                         .WithInitializer(
                                             SyntaxFactory.EqualsValueClause(
                                                 SyntaxFactory.IdentifierName(param.Identifier))))),
                    body);
            }

            return body;
        }

        /// <summary>
        /// Builds the innermost statement, which calls the native method. When
        /// <c>Changes.ChangeReturnTypeToString</c> is set, the call is wrapped as <c>PtrToStringUTF8(call, free)</c>.
        /// <c>free</c> is the literal <c>true</c> when <c>Changes.FreeReturnedPointer</c> is set, and <c>false</c>
        /// otherwise; presumably this tells the helper to free the native string. Yields an expression statement
        /// when the native return type is void, and a <c>return</c> statement otherwise.
        /// </summary>
        private static StatementSyntax makeReturn(GeneratedMethod gm)
        {
            ExpressionSyntax expr;

            if (gm.RequiredChanges.HasFlag(Changes.ChangeReturnTypeToString))
            {
                expr = SyntaxFactory.InvocationExpression(
                                        SyntaxFactory.IdentifierName("PtrToStringUTF8"))
                                    .WithArguments(new[]
                                    {
                                        SyntaxFactory.Argument(makeFunctionCall(gm)),
                                        SyntaxFactory.Argument(SyntaxFactory.LiteralExpression(
                                            gm.RequiredChanges.HasFlag(Changes.FreeReturnedPointer)
                                                ? SyntaxKind.TrueLiteralExpression
                                                : SyntaxKind.FalseLiteralExpression))
                                    });
            }
            else
            {
                expr = makeFunctionCall(gm);
            }

            if (gm.NativeMethod.ReturnType.IsVoid())
                return SyntaxFactory.ExpressionStatement(expr);

            return SyntaxFactory.ReturnStatement(expr);
        }

        /// <summary>
        /// Builds the invocation of the original native method, using its unmodified identifier.
        /// </summary>
        private static InvocationExpressionSyntax makeFunctionCall(GeneratedMethod gm)
        {
            return SyntaxFactory.InvocationExpression(
                                    SyntaxFactory.IdentifierName(gm.NativeMethod.Identifier))
                                .WithArguments(makeArguments(gm));
        }

        /// <summary>
        /// Builds the native call's arguments in parameter order. Span parameters pass their pinned pointer local
        /// (name + <see cref="pointer_suffix"/>), and all other parameters are passed through by name.
        /// </summary>
        private static IEnumerable<ArgumentSyntax> makeArguments(GeneratedMethod gm)
        {
            foreach (var param in gm.NativeMethod.ParameterList.Parameters)
            {
                if (param.IsTypeConstCharPtr())
                {
                    Debug.Assert(gm.RequiredChanges.HasFlag(Changes.ChangeParamsToReadOnlySpan));
                    yield return SyntaxFactory.Argument(SyntaxFactory.IdentifierName(param.Identifier.ValueText + pointer_suffix));
                }
                else
                {
                    yield return SyntaxFactory.Argument(SyntaxFactory.IdentifierName(param.Identifier));
                }
            }
        }
    }
}