feat: 初步支持函数重载

This commit is contained in:
zhangxun 2025-08-06 01:13:17 +08:00
parent 6b56e65bce
commit 6ea8f88b0a
8 changed files with 216 additions and 28 deletions

View File

@ -72,7 +72,8 @@ public class CallGenerator implements InstructionGenerator<CallInstruction> {
String fn = ins.getFunctionName(); String fn = ins.getFunctionName();
// 特殊处理 syscall 调用 // 特殊处理 syscall 调用
if ("syscall".equals(fn) || fn.endsWith(".syscall")) { String tname = fn.substring(0, fn.lastIndexOf(':'));
if ("syscall".equals(tname) || tname.endsWith(".syscall")) {
generateSyscall(ins, out, slotMap, fn); generateSyscall(ins, out, slotMap, fn);
return; return;
} }

View File

@ -14,6 +14,9 @@ import org.jcnc.snow.compiler.parser.ast.base.ExpressionNode;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.StringJoiner;
import static org.jcnc.snow.compiler.ir.utils.ExpressionUtils.looksLikeFloat;
/** /**
* {@code ExpressionBuilder} 表达式 IR 构建器 * {@code ExpressionBuilder} 表达式 IR 构建器
@ -377,17 +380,67 @@ public record ExpressionBuilder(IRContext ctx) {
// 1. 递归生成所有参数的寄存器 // 1. 递归生成所有参数的寄存器
List<IRVirtualRegister> argv = call.arguments().stream().map(this::build).toList(); List<IRVirtualRegister> argv = call.arguments().stream().map(this::build).toList();
// TODO: 注释
StringJoiner sj = new StringJoiner("_");
for (ExpressionNode param : call.arguments()) {
switch (param) {
case NumberLiteralNode n -> {
String value = n.value();
char suffix = value.isEmpty() ? '\0'
: Character.toLowerCase(value.charAt(value.length() - 1));
switch (suffix) {
case 'b' -> sj.add("byte");
case 's' -> sj.add("short");
case 'l' -> sj.add("long");
case 'f' -> sj.add("float");
case 'd' -> sj.add("double");
default -> {
if (looksLikeFloat(value)) {
sj.add("double");
} else {
sj.add("int");
}
}
}
;
}
case BoolLiteralNode _ -> sj.add("bool");
case StringLiteralNode _ -> sj.add("string");
case IdentifierNode id -> {
String type = ctx.getScope().lookupType(id.name());
sj.add(type);
}
// case CallExpressionNode ce -> {
// }
case null, default -> throw new IllegalArgumentException("(内部错误) 不支持的参数表达式: " + param);
}
}
// 2. 规范化被调用方法名区分成员方法与普通函数 // 2. 规范化被调用方法名区分成员方法与普通函数
String callee = switch (call.callee()) { String callee = switch (call.callee()) {
// 成员方法调用 obj.method() // 成员方法调用 obj.method()
case MemberExpressionNode m when m.object() instanceof IdentifierNode id -> id.name() + "." + m.member(); case MemberExpressionNode m when m.object() instanceof IdentifierNode id -> {
String qualifiedName = id.name() + "." + m.member();
if (sj.length() > 0) {
qualifiedName = qualifiedName + ":" + sj;
}
yield qualifiedName;
}
// 普通函数调用或处理命名空间前缀如当前方法名为 namespace.func // 普通函数调用或处理命名空间前缀如当前方法名为 namespace.func
case IdentifierNode id -> { case IdentifierNode id -> {
String current = ctx.getFunction().name(); String current = ctx.getFunction().name();
int dot = current.lastIndexOf('.'); int dot = current.lastIndexOf('.');
if (dot > 0)
yield current.substring(0, dot) + "." + id.name(); // 同命名空间内调用 String qualifiedName = dot > 0
yield id.name(); // 全局函数调用 ? current.substring(0, dot) + "." + id.name()
: id.name();
if (sj.length() > 0) {
qualifiedName = qualifiedName + ":" + sj;
}
yield qualifiedName; // 全局函数调用
} }
// 其它类型不支持 // 其它类型不支持
default -> throw new IllegalStateException( default -> throw new IllegalStateException(

View File

@ -4,12 +4,14 @@ import org.jcnc.snow.compiler.ir.core.IRFunction;
import org.jcnc.snow.compiler.ir.core.IRProgram; import org.jcnc.snow.compiler.ir.core.IRProgram;
import org.jcnc.snow.compiler.parser.ast.FunctionNode; import org.jcnc.snow.compiler.parser.ast.FunctionNode;
import org.jcnc.snow.compiler.parser.ast.ModuleNode; import org.jcnc.snow.compiler.parser.ast.ModuleNode;
import org.jcnc.snow.compiler.parser.ast.ParameterNode;
import org.jcnc.snow.compiler.parser.ast.base.Node; import org.jcnc.snow.compiler.parser.ast.base.Node;
import org.jcnc.snow.compiler.parser.ast.base.NodeContext; import org.jcnc.snow.compiler.parser.ast.base.NodeContext;
import org.jcnc.snow.compiler.parser.ast.base.StatementNode; import org.jcnc.snow.compiler.parser.ast.base.StatementNode;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.StringJoiner;
/** /**
* IRProgramBuilder 负责将 AST 根节点(如模块函数顶层语句)转换为可执行的 IRProgram 实例 * IRProgramBuilder 负责将 AST 根节点(如模块函数顶层语句)转换为可执行的 IRProgram 实例
@ -66,6 +68,15 @@ public final class IRProgramBuilder {
private IRFunction buildFunctionWithGlobals(ModuleNode moduleNode, FunctionNode functionNode) { private IRFunction buildFunctionWithGlobals(ModuleNode moduleNode, FunctionNode functionNode) {
// 拼接模块名和函数名生成全限定名 // 拼接模块名和函数名生成全限定名
String qualifiedName = moduleNode.name() + "." + functionNode.name(); String qualifiedName = moduleNode.name() + "." + functionNode.name();
StringJoiner sj = new StringJoiner("_");
for (ParameterNode param : functionNode.parameters()) {
sj.add(param.type());
}
if (sj.length() > 0) {
qualifiedName = qualifiedName + ":" + sj;
}
// 若无全局声明仅重命名后直接构建 // 若无全局声明仅重命名后直接构建
if (moduleNode.globals() == null || moduleNode.globals().isEmpty()) { if (moduleNode.globals() == null || moduleNode.globals().isEmpty()) {
return buildFunction(renameFunction(functionNode, qualifiedName)); return buildFunction(renameFunction(functionNode, qualifiedName));

View File

@ -77,8 +77,15 @@ public class CallExpressionAnalyzer implements ExpressionAnalyzer<CallExpression
return BuiltinType.INT; return BuiltinType.INT;
} }
// 分析所有实参并获取类型
List<Type> args = new ArrayList<>();
for (ExpressionNode arg : call.arguments()) {
args.add(ctx.getRegistry().getExpressionAnalyzer(arg)
.analyze(ctx, mi, fn, locals, arg));
}
// 查找目标函数签名先在当前模块/显式模块查找 // 查找目标函数签名先在当前模块/显式模块查找
FunctionType ft = target.getFunctions().get(functionName); FunctionType ft = target.retrieveFunction(functionName, args);
// 未找到则报错 // 未找到则报错
if (ft == null) { if (ft == null) {
@ -88,13 +95,6 @@ public class CallExpressionAnalyzer implements ExpressionAnalyzer<CallExpression
return BuiltinType.INT; return BuiltinType.INT;
} }
// 分析所有实参并获取类型
List<Type> args = new ArrayList<>();
for (ExpressionNode arg : call.arguments()) {
args.add(ctx.getRegistry().getExpressionAnalyzer(arg)
.analyze(ctx, mi, fn, locals, arg));
}
// 参数数量检查 // 参数数量检查
if (args.size() != ft.paramTypes().size()) { if (args.size() != ft.paramTypes().size()) {
ctx.getErrors().add(new SemanticError(call, ctx.getErrors().add(new SemanticError(call,

View File

@ -1,6 +1,7 @@
package org.jcnc.snow.compiler.semantic.analyzers.statement; package org.jcnc.snow.compiler.semantic.analyzers.statement;
import org.jcnc.snow.compiler.parser.ast.FunctionNode; import org.jcnc.snow.compiler.parser.ast.FunctionNode;
import org.jcnc.snow.compiler.parser.ast.ParameterNode;
import org.jcnc.snow.compiler.parser.ast.ReturnNode; import org.jcnc.snow.compiler.parser.ast.ReturnNode;
import org.jcnc.snow.compiler.semantic.analyzers.base.StatementAnalyzer; import org.jcnc.snow.compiler.semantic.analyzers.base.StatementAnalyzer;
import org.jcnc.snow.compiler.semantic.core.Context; import org.jcnc.snow.compiler.semantic.core.Context;
@ -11,6 +12,9 @@ import org.jcnc.snow.compiler.semantic.type.BuiltinType;
import org.jcnc.snow.compiler.semantic.type.FunctionType; import org.jcnc.snow.compiler.semantic.type.FunctionType;
import org.jcnc.snow.compiler.semantic.type.Type; import org.jcnc.snow.compiler.semantic.type.Type;
import java.util.ArrayList;
import java.util.List;
/** /**
* {@code ReturnAnalyzer} 是用于分析 {@link ReturnNode} 返回语句的语义分析器 * {@code ReturnAnalyzer} 是用于分析 {@link ReturnNode} 返回语句的语义分析器
* <p> * <p>
@ -42,12 +46,32 @@ public class ReturnAnalyzer implements StatementAnalyzer<ReturnNode> {
ctx.log("检查 return"); ctx.log("检查 return");
// 获取当前函数的定义信息 // 获取当前函数的重载列表
FunctionType expected = ctx.getModules() var overloading = ctx.getModules()
.get(mi.getName()) .get(mi.getName())
.getFunctions() .getFunctions()
.get(fn.name()); .get(fn.name());
List<Type> params = new ArrayList<>();
for (ParameterNode pn : fn.parameters()) {
params.add(locals.resolve(pn.name()).type());
}
// fn 对应的函数签名
FunctionType expected = overloading.stream()
.filter(ft -> ft.paramTypes().equals(params))
.findFirst()
.orElse(null);
if (expected == null) {
ctx.getErrors().add(new SemanticError(
fn,
"不存在的函数签名: " + params
));
ctx.log("不存在的函数签名: " + params);
return;
}
// 情况 1: 存在返回表达式需进行类型检查 // 情况 1: 存在返回表达式需进行类型检查
ret.getExpression().ifPresentOrElse(exp -> { ret.getExpression().ifPresentOrElse(exp -> {
var exprAnalyzer = ctx.getRegistry().getExpressionAnalyzer(exp); var exprAnalyzer = ctx.getRegistry().getExpressionAnalyzer(exp);

View File

@ -62,14 +62,54 @@ public final class BuiltinTypeRegistry {
/* ---------- 注册标准库 os ---------- */ /* ---------- 注册标准库 os ---------- */
ModuleInfo utils = new ModuleInfo("os"); ModuleInfo utils = new ModuleInfo("os");
// syscall(string, int): void 供标准库内部使用的调用接口 /* syscall(string, type): void 的一组重载 */
utils.getFunctions().put( List<FunctionType> overloading = new ArrayList<>();
"syscall", {
new FunctionType( // syscall(string, byte): void
Arrays.asList(BuiltinType.STRING, BuiltinType.INT), overloading.add(new FunctionType(
BuiltinType.VOID Arrays.asList(BuiltinType.STRING, BuiltinType.BYTE),
) BuiltinType.VOID
); ));
// syscall(string, short): void
overloading.add(new FunctionType(
Arrays.asList(BuiltinType.STRING, BuiltinType.SHORT),
BuiltinType.VOID
));
// syscall(string, int): void
overloading.add(new FunctionType(
Arrays.asList(BuiltinType.STRING, BuiltinType.INT),
BuiltinType.VOID
));
// syscall(string, long): void
overloading.add(new FunctionType(
Arrays.asList(BuiltinType.STRING, BuiltinType.LONG),
BuiltinType.VOID
));
// syscall(string, float): void
overloading.add(new FunctionType(
Arrays.asList(BuiltinType.STRING, BuiltinType.FLOAT),
BuiltinType.VOID
));
// syscall(string, double): void
overloading.add(new FunctionType(
Arrays.asList(BuiltinType.STRING, BuiltinType.DOUBLE),
BuiltinType.VOID
));
// syscall(string, boolean): void
overloading.add(new FunctionType(
Arrays.asList(BuiltinType.STRING, BuiltinType.BOOLEAN),
BuiltinType.VOID
));
}
// syscall(string, `type`): void 供标准库内部使用的调用接口
utils.getFunctions().put("syscall", overloading);
// 注册 BuiltinUtils 到上下文的模块表若已存在则不重复添加 // 注册 BuiltinUtils 到上下文的模块表若已存在则不重复添加
ctx.getModules().putIfAbsent("os", utils); ctx.getModules().putIfAbsent("os", utils);

View File

@ -1,6 +1,7 @@
package org.jcnc.snow.compiler.semantic.core; package org.jcnc.snow.compiler.semantic.core;
import org.jcnc.snow.compiler.semantic.type.FunctionType; import org.jcnc.snow.compiler.semantic.type.FunctionType;
import org.jcnc.snow.compiler.semantic.type.Type;
import java.util.*; import java.util.*;
@ -25,8 +26,8 @@ public class ModuleInfo {
/** 该模块显式导入的模块名集合(用于跨模块访问符号) */ /** 该模块显式导入的模块名集合(用于跨模块访问符号) */
private final Set<String> imports = new HashSet<>(); private final Set<String> imports = new HashSet<>();
/** 该模块中定义的函数名 → 函数类型映射 */ /** 该模块中定义的函数名 → 函数类型列表映射 */
private final Map<String, FunctionType> functions = new HashMap<>(); private final Map<String, List<FunctionType>> functions = new HashMap<>();
/** /**
* 构造模块信息对象 * 构造模块信息对象
@ -65,8 +66,57 @@ public class ModuleInfo {
* *
* @return 模块内函数定义映射表 * @return 模块内函数定义映射表
*/ */
public Map<String, FunctionType> getFunctions() { public Map<String, List<FunctionType>> getFunctions() {
return functions; return functions;
} }
/**
* 添加一个新的函数类型
*
* @param name 函数名
* @param ft 函数类型
*/
public void addFunction(String name, FunctionType ft) {
if (!functions.containsKey(name)) {
functions.put(name, new ArrayList<>());
}
functions.get(name).add(ft);
}
/**
* 判断函数类型是否存在
*
* @param name 函数名
* @param ft 函数类型
* @return 函数类型存在则返回 true否则 false
*/
public boolean existsFunction(String name, FunctionType ft) {
if (!functions.containsKey(name)) {
return false;
}
return functions.get(name).stream().anyMatch((t) -> t.equals(ft));
}
/**
* 检索函数类型
* 通过函数名和参数列表唯一确定一个函数类型
*
* @param name 函数名
* @param params 参数列表
* @return 唯一确定的函数类型不存在则返回 null
*/
public FunctionType retrieveFunction(String name, List<Type> params) {
List<FunctionType> fts = functions.get(name);
if (fts != null) {
for (FunctionType ft : fts) {
if (ft.paramTypes().equals(params)) {
return ft;
}
}
}
return null;
}
} }

View File

@ -78,8 +78,17 @@ public record SignatureRegistrar(Context ctx) {
Type ret = Optional.ofNullable(ctx.parseType(fn.returnType())) Type ret = Optional.ofNullable(ctx.parseType(fn.returnType()))
.orElse(BuiltinType.VOID); .orElse(BuiltinType.VOID);
// 注册函数签名 FunctionType ft = new FunctionType(params, ret);
mi.getFunctions().put(fn.name(), new FunctionType(params, ret)); if (mi.existsFunction(fn.name(), ft)) {
ctx.errors().add(new SemanticError(
fn,
"有歧义的函数: " + fn.name()
));
}
else {
// 注册函数签名
mi.addFunction(fn.name(), ft);
}
} }
} }
} }