// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace SDL.SourceGeneration
{
    /// <summary>
    /// Emits "friendly" <c>public static</c> overloads for the native methods collected by
    /// <see cref="UnfriendlyMethodFinder"/>. Depending on each method's required changes, an overload:
    /// takes <c>const char*</c> parameters as the UTF-8 string struct named by <c>Helper.Utf8StringStructName</c>
    /// (pinned with <c>fixed</c> before the native call), returns <c>string?</c> via <c>SDL3.PtrToStringUTF8</c>
    /// instead of the native return type, and/or drops the leading <c>Helper.UnsafePrefix</c> from its name.
    /// </summary>
    [Generator]
    public class FriendlyOverloadGenerator : ISourceGenerator
    {
        public void Initialize(GeneratorInitializationContext context)
        {
            context.RegisterForSyntaxNotifications(() => new UnfriendlyMethodFinder());
        }

        // `<auto-generated/>` marks the output as generated code, for which the compiler's nullable context is
        // disabled by default; `#nullable enable` is therefore required for the emitted `string?` return types.
        private const string file_header = @"// <auto-generated/>
#nullable enable

using System;

";

        /// <summary>
        /// Adds one generated source file per entry in <c>UnfriendlyMethodFinder.Methods</c>. Each file declares an
        /// <c>unsafe partial</c> class in namespace <c>SDL</c>, named after the class that declares the entry's first
        /// native method, containing one friendly overload per collected method.
        /// </summary>
        public void Execute(GeneratorExecutionContext context)
        {
            var finder = (UnfriendlyMethodFinder)context.SyntaxReceiver!;

            foreach (var kvp in finder.Methods)
            {
                // An empty entry only means there is nothing to emit for this file; keep generating the others.
                if (kvp.Value.Count == 0)
                    continue;

                string filename = kvp.Key;
                var foundMethods = kvp.Value;
                string className = ClassNameFromMethod(foundMethods.First().NativeMethod);

                var classDeclaration = SyntaxFactory.ClassDeclaration(className)
                                                    .WithModifiers(
                                                        SyntaxFactory.TokenList(
                                                            SyntaxFactory.Token(SyntaxKind.UnsafeKeyword),
                                                            SyntaxFactory.Token(SyntaxKind.PartialKeyword)))
                                                    .WithMembers(SyntaxFactory.List(foundMethods.Select(makeFriendlyMethod)));

                var namespaceDeclaration = SyntaxFactory.NamespaceDeclaration(SyntaxFactory.IdentifierName("SDL"))
                                                        .WithMembers(SyntaxFactory.SingletonList<MemberDeclarationSyntax>(classDeclaration))
                                                        .NormalizeWhitespace();

                var result = new StringBuilder();
                result.Append(file_header);
                result.Append(namespaceDeclaration);

                context.AddSource(filename, result.ToString());
            }
        }

        /// <summary>
        /// Returns the identifier of the class that directly declares <paramref name="methodNode"/>,
        /// or <c>"SDL3"</c> when the method's parent node is not a class declaration.
        /// </summary>
        private static string ClassNameFromMethod(MethodDeclarationSyntax methodNode)
        {
            if (methodNode.Parent is ClassDeclarationSyntax classDeclaration)
                return classDeclaration.Identifier.Text;

            return "SDL3"; // fallback!
        }

        /// <summary>
        /// Builds the <c>public static</c> friendly overload for <paramref name="gm"/>: return type, name,
        /// parameter list and body are derived from the native method according to <c>gm.RequiredChanges</c>.
        /// </summary>
        private static MemberDeclarationSyntax makeFriendlyMethod(GeneratedMethod gm)
        {
            var returnType = gm.RequiredChanges.HasFlag(Changes.ChangeReturnTypeToString)
                ? SyntaxFactory.ParseTypeName("string?")
                : gm.NativeMethod.ReturnType;

            var identifier = gm.RequiredChanges.HasFlag(Changes.TrimUnsafeFromName)
                ? SyntaxFactory.Identifier(trimUnsafePrefix(gm.NativeMethod.Identifier.ValueText))
                : gm.NativeMethod.Identifier;

            return SyntaxFactory.MethodDeclaration(returnType, identifier)
                                .WithModifiers(
                                    SyntaxFactory.TokenList(
                                        SyntaxFactory.Token(SyntaxKind.PublicKeyword),
                                        SyntaxFactory.Token(SyntaxKind.StaticKeyword)))
                                .WithParameterList(
                                    SyntaxFactory.ParameterList(
                                        SyntaxFactory.SeparatedList(transformParams(gm))))
                                .WithBody(
                                    SyntaxFactory.Block(
                                        makeMethodBody(gm)));
        }

        /// <summary>
        /// Removes a single leading <c>Helper.UnsafePrefix</c> from <paramref name="name"/>.
        /// Unlike <see cref="string.Replace(string, string)"/>, occurrences elsewhere in the name are left intact.
        /// </summary>
        private static string trimUnsafePrefix(string name)
        {
            return name.StartsWith(Helper.UnsafePrefix, StringComparison.Ordinal)
                ? name.Substring(Helper.UnsafePrefix.Length)
                : name;
        }

        /// <summary>
        /// Yields the friendly overload's parameters: <c>const char*</c> parameters are retyped to the UTF-8
        /// string struct, and every parameter has its attribute lists (e.g. <c>[NativeTypeName]</c>) removed.
        /// </summary>
        private static IEnumerable<ParameterSyntax> transformParams(GeneratedMethod gm)
        {
            foreach (var param in gm.NativeMethod.ParameterList.Parameters)
            {
                if (param.IsTypeConstCharPtr())
                {
                    Debug.Assert(gm.RequiredChanges.HasFlag(Changes.ChangeParamsToUtf8String));
                    yield return param.WithType(SyntaxFactory.ParseTypeName(Helper.Utf8StringStructName))
                                      .WithAttributeLists(SyntaxFactory.List<AttributeListSyntax>());
                }
                else
                {
                    yield return param.WithAttributeLists(SyntaxFactory.List<AttributeListSyntax>()); // remove [NativeTypeName]
                }
            }
        }

        // Suffix of the local `byte*` that each pinned `const char*` parameter is bound to (e.g. `title` -> `titlePtr`).
        private const string pointer_suffix = "Ptr";

        /// <summary>
        /// Wraps the call/return statement in one <c>fixed (byte* {param}Ptr = {param})</c> statement per
        /// <c>const char*</c> parameter. Parameters are visited in reverse so the first parameter's
        /// <c>fixed</c> ends up outermost.
        /// </summary>
        private static StatementSyntax makeMethodBody(GeneratedMethod gm)
        {
            var expr = makeReturn(gm);

            foreach (var param in gm.NativeMethod.ParameterList.Parameters.Where(p => p.IsTypeConstCharPtr()).Reverse())
            {
                Debug.Assert(gm.RequiredChanges.HasFlag(Changes.ChangeParamsToUtf8String));

                expr = SyntaxFactory.FixedStatement(
                    SyntaxFactory.VariableDeclaration(
                        SyntaxFactory.ParseTypeName("byte*"),
                        SyntaxFactory.SingletonSeparatedList(
                            SyntaxFactory.VariableDeclarator(
                                             param.Identifier.ValueText + pointer_suffix)
                                         .WithInitializer(
                                             SyntaxFactory.EqualsValueClause(
                                                 SyntaxFactory.IdentifierName(param.Identifier))))),
                    expr);
            }

            return expr;
        }

        /// <summary>
        /// Builds the innermost statement of the overload. When the return type changes to string, the native
        /// call is wrapped in <c>SDL3.PtrToStringUTF8(call, free)</c>, where <c>free</c> is <c>true</c> exactly when
        /// <c>Changes.FreeReturnedPointer</c> is set. A native <c>void</c> return produces an expression statement;
        /// anything else produces a <c>return</c> statement.
        /// </summary>
        private static StatementSyntax makeReturn(GeneratedMethod gm)
        {
            ExpressionSyntax expr;

            if (gm.RequiredChanges.HasFlag(Changes.ChangeReturnTypeToString))
            {
                expr = SyntaxFactory.InvocationExpression(
                                        SyntaxFactory.MemberAccessExpression(
                                            SyntaxKind.SimpleMemberAccessExpression,
                                            SyntaxFactory.IdentifierName("SDL3"),
                                            SyntaxFactory.IdentifierName("PtrToStringUTF8")))
                                    .WithArguments(new[]
                                    {
                                        SyntaxFactory.Argument(makeFunctionCall(gm)),
                                        SyntaxFactory.Argument(SyntaxFactory.LiteralExpression(
                                            gm.RequiredChanges.HasFlag(Changes.FreeReturnedPointer)
                                                ? SyntaxKind.TrueLiteralExpression
                                                : SyntaxKind.FalseLiteralExpression))
                                    });
            }
            else
            {
                expr = makeFunctionCall(gm);
            }

            if (gm.NativeMethod.ReturnType.IsVoid())
                return SyntaxFactory.ExpressionStatement(expr);

            return SyntaxFactory.ReturnStatement(expr);
        }

        /// <summary>
        /// Builds an invocation of the native method by its original (untrimmed) identifier.
        /// </summary>
        private static InvocationExpressionSyntax makeFunctionCall(GeneratedMethod gm)
        {
            return SyntaxFactory.InvocationExpression(
                                    SyntaxFactory.IdentifierName(gm.NativeMethod.Identifier))
                                .WithArguments(makeArguments(gm));
        }

        /// <summary>
        /// Yields the native call's arguments in parameter order: the pinned <c>{param}Ptr</c> local for each
        /// <c>const char*</c> parameter, and the parameter itself for everything else.
        /// </summary>
        private static IEnumerable<ArgumentSyntax> makeArguments(GeneratedMethod gm)
        {
            foreach (var param in gm.NativeMethod.ParameterList.Parameters)
            {
                if (param.IsTypeConstCharPtr())
                {
                    Debug.Assert(gm.RequiredChanges.HasFlag(Changes.ChangeParamsToUtf8String));
                    yield return SyntaxFactory.Argument(SyntaxFactory.IdentifierName(param.Identifier.ValueText + pointer_suffix));
                }
                else
                {
                    yield return SyntaxFactory.Argument(SyntaxFactory.IdentifierName(param.Identifier));
                }
            }
        }
    }
}