Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package org.stella.typecheck
- import org.syntax.stella.Absyn.*
/**
 * Type checker for (core) Stella programs.
 *
 * Every failed check is reported by throwing [Exception] with a human-readable message.
 * Progress is traced to stdout with `println`.
 */
object TypeCheck {
    /**
     * Type checks the declarations of [program] in source order.
     *
     * All declarations share one global [TypeContext], so a function may call functions that
     * were declared before it. Type aliases are only reported, not processed, and other
     * declaration kinds are currently ignored.
     *
     * @throws Exception if any function declaration is ill-typed.
     */
    @Throws(Exception::class)
    fun typecheckProgram(program: Program) {
        val context = TypeContext()
        when (program) {
            is AProgram -> {
                // A plain loop: the declarations are visited only for their side effects.
                for (decl in program.listdecl_) {
                    when (decl) {
                        is DeclFun -> {
                            println("Declared function ${decl.stellaident_}")
                            typecheckFunction(decl, context)
                        }
                        is DeclTypeAlias -> {
                            println("Declared type alias (not yet processed)")
                        }
                        else -> {
                            // Other declaration kinds are not checked yet.
                        }
                    }
                }
            }
        }
    }

    /**
     * Type checks the top-level function [decl] and, on success, registers its signature in
     * [context] (both as a function and as a variable) so that later declarations can use it.
     *
     * The parameters are bound in a child scope of [context]. They are visible only inside the
     * function body and do not leak into later declarations.
     *
     * @throws Exception if a parameter declaration is unknown, the function has no declared
     *   return type, or the body's type differs from the declared return type.
     */
    fun typecheckFunction(decl: DeclFun, context: TypeContext) {
        val bodyScope = context.extend()
        val paramTypes = ListType()
        for (param in decl.listparamdecl_) {
            when (param) {
                is AParamDecl -> {
                    bodyScope.addVariable(param.stellaident_, param.type_)
                    paramTypes.add(param.type_)
                }
                else -> throw Exception("Unknown parameter declaration")
            }
        }

        // Bind the Java field to a local so the smart cast is guaranteed to be stable.
        val returnType = when (val declared = decl.returntype_) {
            is SomeReturnType -> declared.type_
            is NoReturnType -> throw Exception("Function must have a return type")
            else -> throw Exception("Unknown return type")
        }

        val bodyType = typecheckExpr(decl.expr_, bodyScope)
        if (bodyType != returnType) {
            throw Exception("Function ${decl.stellaident_} returns $bodyType, but expected $returnType")
        }

        val functionType = TypeFun(paramTypes, returnType)
        context.addFunction(decl.stellaident_, functionType)
        // Also expose the function as a variable so that `Var` references resolve to it.
        context.addVariable(decl.stellaident_, functionType)
        println("Added function ${decl.stellaident_} with type $functionType")
    }

    /**
     * Infers the type of [expr] under [context].
     *
     * Types are compared with `==`/`!=`, which relies on the structural `equals` of the
     * generated Absyn classes. [context] is never mutated: binders (lambda parameters) are
     * added to a child scope.
     *
     * @return the inferred type of [expr].
     * @throws Exception if [expr] is ill-typed or is an unsupported expression form.
     */
    fun typecheckExpr(expr: Expr, context: TypeContext): Type {
        println("Typechecking expression: ${expr.javaClass.simpleName}")
        return when (expr) {
            is ConstTrue, is ConstFalse -> TypeBool()
            is ConstInt -> TypeNat()
            is Succ -> {
                if (typecheckExpr(expr.expr_, context) != TypeNat()) {
                    throw Exception("Argument to Succ must be of type Nat")
                }
                TypeNat()
            }
            is IsZero -> {
                if (typecheckExpr(expr.expr_, context) != TypeNat()) {
                    throw Exception("Argument to IsZero must be of type Nat")
                }
                TypeBool()
            }
            is NatRec -> {
                // Stella rule: Nat::rec(n, z, s) with n : Nat, z : T, s : fn(Nat) -> fn(T) -> T,
                // and the whole expression has type T.
                if (typecheckExpr(expr.expr_1, context) != TypeNat()) {
                    throw Exception("Argument to NatRec must be of type Nat")
                }
                // The base case determines the accumulator type T, which need not be Nat.
                val accType = typecheckExpr(expr.expr_2, context)
                val recType = typecheckExpr(expr.expr_3, context)
                println("Type of recursive case: $recType")
                if (recType !is TypeFun) {
                    throw Exception("Recursive case of NatRec must be a function, but got ${recType.javaClass.simpleName}")
                }
                val expectedStep = functionOf(TypeNat(), functionOf(accType, accType))
                if (recType != expectedStep) {
                    throw Exception(
                        "Recursive case of NatRec must be a function of type Nat -> ($accType -> $accType), but got $recType"
                    )
                }
                accType
            }
            is If -> {
                if (typecheckExpr(expr.expr_1, context) !is TypeBool) {
                    throw Exception("Condition in if-expression must be of type Bool")
                }
                val thenType = typecheckExpr(expr.expr_2, context)
                val elseType = typecheckExpr(expr.expr_3, context)
                if (thenType != elseType) {
                    throw Exception("Branches of if-expression must have the same type")
                }
                thenType
            }
            is Var -> {
                val varName = expr.stellaident_
                context.getVariableType(varName) ?: throw Exception("Variable '$varName' is not declared")
            }
            is Abstraction -> {
                if (expr.listparamdecl_.size != 1) {
                    throw Exception("First-class functions must accept exactly one parameter")
                }
                val paramDecl = expr.listparamdecl_[0] as AParamDecl
                // Bind the parameter in a child scope: it may shadow an outer binding
                // but must not overwrite it once the lambda body has been checked.
                val bodyScope = context.extend()
                bodyScope.addVariable(paramDecl.stellaident_, paramDecl.type_)
                val bodyType = typecheckExpr(expr.expr_, bodyScope)
                functionOf(paramDecl.type_, bodyType)
            }
            is Application -> {
                val funcType = typecheckExpr(expr.expr_, context) as? TypeFun
                    ?: throw Exception("Expression must be a function")
                if (expr.listexpr_.size != 1) {
                    throw Exception("Application must provide exactly one argument")
                }
                // Reject partial application of multi-parameter functions (and avoid
                // indexing into an empty parameter list).
                if (funcType.listtype_.size != 1) {
                    throw Exception("Function expects ${funcType.listtype_.size} arguments, but 1 was provided")
                }
                val argType = typecheckExpr(expr.listexpr_[0], context)
                if (argType != funcType.listtype_[0]) {
                    throw Exception("Argument type does not match function parameter type")
                }
                funcType.type_
            }
            else -> throw Exception("Unsupported expression type: ${expr.javaClass.simpleName}")
        }
    }

    /** Builds the single-parameter function type `fn(paramType) -> resultType`. */
    private fun functionOf(paramType: Type, resultType: Type): TypeFun {
        val params = ListType()
        params.add(paramType)
        return TypeFun(params, resultType)
    }

    /**
     * Lexically scoped mapping from identifiers to their types.
     *
     * Lookups fall back to the [parent] scope when a name is not bound locally. Bindings added
     * to a child scope (see [extend]) can shadow the parent's, but are never visible to it.
     */
    class TypeContext(private val parent: TypeContext? = null) {
        /** Variables bound directly in this scope (ancestor bindings are not copied here). */
        val variables = mutableMapOf<String, Type>()

        /** Function signatures bound directly in this scope. */
        val functions = mutableMapOf<String, TypeFun>()

        /** Binds [name] to [type] in this scope, replacing any previous local binding. */
        fun addVariable(name: String, type: Type) {
            variables[name] = type
        }

        /** Returns the type of [name] from the innermost scope that binds it, or `null`. */
        fun getVariableType(name: String): Type? = variables[name] ?: parent?.getVariableType(name)

        /** Records the signature of the function [name] in this scope. */
        fun addFunction(name: String, type: TypeFun) {
            functions[name] = type
        }

        /** Returns the signature of [name] from the innermost scope that declares it, or `null`. */
        fun getFunctionType(name: String): TypeFun? = functions[name] ?: parent?.getFunctionType(name)

        /** Opens a nested scope whose lookups fall back to this one. */
        fun extend(): TypeContext = TypeContext(this)
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement