[fine] Check the return type of functions

A function returns what it says it does, the check is that the body
returns the right value.
This commit is contained in:
John Doty 2024-01-15 07:46:20 -08:00
parent 257a7e64c2
commit d893002ec2
5 changed files with 70 additions and 14 deletions

View file

@ -657,6 +657,7 @@ impl<'a> Semantics<'a> {
TreeKind::Identifier => self.type_of_identifier(t, tree),
TreeKind::FunctionDecl => self.type_of_function_decl(tree),
TreeKind::ReturnType => self.type_of_return_type(tree),
_ => self.internal_compiler_error(Some(t), "asking for a nonsense type"),
};
@ -1021,6 +1022,11 @@ impl<'a> Semantics<'a> {
Some(Type::Function(parameter_types, return_type))
}
fn type_of_return_type(&self, tree: &Tree) -> Option<Type> {
assert_eq!(tree.kind, TreeKind::ReturnType);
Some(self.type_of(tree.nth_tree(1)?)) // type expression
}
fn type_of_if_statement(&self, tree: &Tree) -> Option<Type> {
Some(self.type_of(tree.nth_tree(0)?))
}
@ -1101,9 +1107,7 @@ pub fn check(s: &Semantics) {
match tree.kind {
TreeKind::Error => {} // already reported
TreeKind::File => {}
TreeKind::FunctionDecl => {
let _ = s.environment_of(t);
}
TreeKind::FunctionDecl => check_function_decl(s, t, tree),
TreeKind::ParamList => {}
TreeKind::Parameter => {}
TreeKind::TypeExpression => {
@ -1138,6 +1142,38 @@ pub fn check(s: &Semantics) {
}
}
fn check_function_decl(s: &Semantics, t: TreeRef, tree: &Tree) {
assert_eq!(tree.kind, TreeKind::FunctionDecl);
let _ = s.environment_of(t);
let return_type_tree = tree.child_of_kind(s.syntax_tree, TreeKind::ReturnType);
let return_type = return_type_tree
.map(|t| s.type_of(t))
.unwrap_or(Type::Nothing);
if let Some(body) = tree.child_of_kind(s.syntax_tree, TreeKind::Block) {
let body_type = s.type_of(body);
if !body_type.compatible_with(&return_type) {
// Just work very hard to get an appropriate error span.
let (start, end) = return_type_tree
.map(|t| {
let rtt = &s.syntax_tree[t];
(rtt.start_pos, rtt.end_pos)
})
.unwrap_or_else(|| {
let start = tree.start_pos;
let end_tok = tree
.nth_token(1)
.unwrap_or_else(|| tree.nth_token(0).unwrap());
let end_pos = end_tok.start + end_tok.as_str().len();
(start, end_pos)
});
s.report_error_span(start, end, format!("the body of this function yields a value of type `{body_type}`, but callers expect this function to produce a `{return_type}`"));
}
}
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -1,8 +1,8 @@
fun foo(x: f64) {
fun foo(x: f64) -> f64 {
x + 7
}
fun test() {
fun test() -> f64 {
foo(1)
}
@ -22,6 +22,10 @@ fun test() {
// | TypeExpression
// | Identifier:'"f64"'
// | RightParen:'")"'
// | ReturnType
// | Arrow:'"->"'
// | TypeExpression
// | Identifier:'"f64"'
// | Block
// | LeftBrace:'"{"'
// | ExpressionStatement
@ -38,6 +42,10 @@ fun test() {
// | ParamList
// | LeftParen:'"("'
// | RightParen:'")"'
// | ReturnType
// | Arrow:'"->"'
// | TypeExpression
// | Identifier:'"f64"'
// | Block
// | LeftBrace:'"{"'
// | ExpressionStatement

View file

@ -1,4 +1,4 @@
fun test() {
fun test() -> f64 {
1 * 2 + -3 * 4
}
@ -13,6 +13,10 @@ fun test() {
// | ParamList
// | LeftParen:'"("'
// | RightParen:'")"'
// | ReturnType
// | Arrow:'"->"'
// | TypeExpression
// | Identifier:'"f64"'
// | Block
// | LeftBrace:'"{"'
// | ExpressionStatement

View file

@ -1,4 +1,4 @@
fun test() {
fun test() -> bool {
true and false or false and !true
}
@ -38,6 +38,10 @@ fun test() {
// | ParamList
// | LeftParen:'"("'
// | RightParen:'")"'
// | ReturnType
// | Arrow:'"->"'
// | TypeExpression
// | Identifier:'"bool"'
// | Block
// | LeftBrace:'"{"'
// | ExpressionStatement

View file

@ -1,22 +1,22 @@
fun test() {
fun test() -> f64 {
if true { "discarded"; 23 } else { 45 }
}
// @no-errors
// Here come some type probes!
// (type of the condition)
// @type: 20 bool
// @type: 27 bool
//
// (the discarded expression)
// @type: 27 string
// @type: 34 string
//
// (the "then" clause)
// @type: 40 f64
// @type: 43 f64
// @type: 47 f64
// @type: 50 f64
//
// (the "else" clause)
// @type: 52 f64
// @type: 55 f64
// @type: 59 f64
// @type: 62 f64
//
// @concrete:
// | File
@ -26,6 +26,10 @@ fun test() {
// | ParamList
// | LeftParen:'"("'
// | RightParen:'")"'
// | ReturnType
// | Arrow:'"->"'
// | TypeExpression
// | Identifier:'"f64"'
// | Block
// | LeftBrace:'"{"'
// | IfStatement