Advertisement
EBobkunov

compiler type checker

Jan 30th, 2025 (edited)
44
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Kotlin 7.71 KB | None | 0 0
  1. package org.stella.typecheck
  2.  
  3. import org.syntax.stella.Absyn.*
  4.  
  5. object TypeCheck {
  6.     @Throws(Exception::class)
  7.     fun typecheckProgram(program: Program) {
  8.         val context = TypeContext()
  9.         when (program) {
  10.             is AProgram -> program.listdecl_.map {
  11.                 when (it) {
  12.                     is DeclFun -> {
  13.                         println("Declared function ${it.stellaident_}")
  14.                         typecheckFunction(it, context)
  15.                     }
  16.  
  17.                     is DeclTypeAlias -> {
  18.                         println("Declared type alias (not yet processed)")
  19.                     }
  20.                 }
  21.             }
  22.  
  23.         }
  24.     }
  25.  
  26.     fun typecheckFunction(decl: DeclFun, context: TypeContext) {
  27.         // Add parameters to the context first
  28.         for (param in decl.listparamdecl_) {
  29.             when (param) {
  30.                 is AParamDecl -> {
  31.                     // Add the parameter to the context with its type
  32.                     context.addVariable(param.stellaident_, param.type_)
  33.                 }
  34.  
  35.                 else -> throw Exception("Unknown parameter declaration")
  36.             }
  37.         }
  38.  
  39.         // Now typecheck the body of the function
  40.         val paramTypesListType = ListType().apply {
  41.             addAll(decl.listparamdecl_.map {
  42.                 when (it) {
  43.                     is AParamDecl -> it.type_
  44.                     else -> throw Exception("Unknown parameter declaration")
  45.                 }
  46.             })
  47.         }
  48.  
  49.         val returnType = when (decl.returntype_) {
  50.             is SomeReturnType -> decl.returntype_.type_
  51.             is NoReturnType -> throw Exception("Function must have a return type")
  52.             else -> throw Exception("Unknown return type")
  53.         }
  54.  
  55.         val bodyType = typecheckExpr(decl.expr_, context)
  56.         if (bodyType != returnType) {
  57.             throw Exception("Function ${decl.stellaident_} returns $bodyType, but expected $returnType")
  58.         }
  59.  
  60.         val functionType = TypeFun(paramTypesListType, returnType) // Create the function type
  61.         context.addFunction(decl.stellaident_, functionType)
  62.  
  63.         // Add the function to the context as a variable too
  64.         context.addVariable(decl.stellaident_, functionType)
  65.  
  66.         println("Added function ${decl.stellaident_} with type $functionType")
  67.     }
  68.  
  69.  
  70.     fun typecheckExpr(expr: Expr, context: TypeContext): Type {
  71.         println("Typechecking expression: ${expr.javaClass.simpleName}")
  72.         return when (expr) {
  73.             is ConstTrue -> TypeBool() // ConstTrue is of type TypeBool
  74.             is ConstFalse -> TypeBool() // ConstFalse is of type TypeBool
  75.             is TypeNat -> TypeNat()     // TypeNat is of type TypeNat
  76.             is ConstInt -> TypeNat()
  77.  
  78.             is Succ -> {
  79.                 val argType = typecheckExpr(expr.expr_, context)
  80.                 if (argType != TypeNat()) {
  81.                     throw Exception("Argument to Succ must be of type Nat")
  82.                 }
  83.                 TypeNat() // Succ returns a natural number
  84.             }
  85.  
  86.             is IsZero -> {
  87.                 val argType = typecheckExpr(expr.expr_, context)
  88.                 if (argType != TypeNat()) {
  89.                     throw Exception("Argument to IsZero must be of type Nat")
  90.                 }
  91.                 TypeBool() // IsZero returns a boolean
  92.             }
  93.  
  94.             is NatRec -> {
  95.                 // Check the argument type (expr_1)
  96.                 val argType = typecheckExpr(expr.expr_1, context)
  97.                 if (argType != TypeNat()) {
  98.                     throw Exception("Argument to NatRec must be of type Nat")
  99.                 }
  100.  
  101.                 // Check the base case (expr_2) type
  102.                 val baseType = typecheckExpr(expr.expr_2, context)
  103.                 if (baseType != TypeNat()) {
  104.                     throw Exception("Base case of NatRec must be of type Nat")
  105.                 }
  106.  
  107.                 // Check the recursive case (expr_3) type
  108.                 val recType = typecheckExpr(expr.expr_3, context)
  109.                 println("Type of recursive case: $recType")
  110.  
  111.                 // Ensure the recursive case is a function of type Nat -> Nat
  112.                 if (recType !is TypeFun) {
  113.                     throw Exception("Recursive case of NatRec must be a function, but got ${recType.javaClass.simpleName}")
  114.                 }
  115.  
  116.                 // Check that the function takes a Nat as an argument and returns a Nat
  117.                 if (recType.listtype_[0] != TypeNat() || recType.type_ != TypeNat()) {
  118.                     throw Exception("Recursive case of NatRec must be a function of type Nat -> Nat, but got ${recType.listtype_[0]} -> ${recType.type_}")
  119.                 }
  120.  
  121.                 // If everything is correct, the type of the whole NatRec expression is Nat
  122.                 TypeNat()
  123.             }
  124.  
  125.  
  126.  
  127.  
  128.             is If -> {
  129.                 val condType = typecheckExpr(expr.expr_1, context)
  130.                 if (condType !is TypeBool) {
  131.                     throw Exception("Condition in if-expression must be of type Bool")
  132.                 }
  133.  
  134.                 val thenType = typecheckExpr(expr.expr_2, context)
  135.                 val elseType = typecheckExpr(expr.expr_3, context)
  136.  
  137.                 if (thenType != elseType) {
  138.                     throw Exception("Branches of if-expression must have the same type")
  139.                 }
  140.  
  141.                 thenType
  142.             }
  143.  
  144.             is Var -> {
  145.                 val varName = expr.stellaident_
  146.                 val varType = context.getVariableType(varName) ?: throw Exception("Variable '$varName' is not declared")
  147.                 varType
  148.             }
  149.  
  150.             is Abstraction -> {
  151.                 if (expr.listparamdecl_.size != 1) {
  152.                     throw Exception("First-class functions must accept exactly one parameter")
  153.                 }
  154.  
  155.                 // Add the parameter to the context
  156.                 val paramDecl = expr.listparamdecl_[0] as AParamDecl
  157.                 context.addVariable(paramDecl.stellaident_, paramDecl.type_)
  158.  
  159.                 val bodyType = typecheckExpr(expr.expr_, context)
  160.  
  161.                 // Construct a ListType and return TypeFun
  162.                 val paramTypeList = ListType()
  163.                 paramTypeList.add(paramDecl.type_)
  164.  
  165.                 TypeFun(paramTypeList, bodyType) // Return type is a function type
  166.             }
  167.  
  168.  
  169.             is Application -> {
  170.                 val funcType =
  171.                     typecheckExpr(expr.expr_, context) as? TypeFun ?: throw Exception("Expression must be a function")
  172.  
  173.                 if (expr.listexpr_.size != 1) {
  174.                     throw Exception("Application must provide exactly one argument")
  175.                 }
  176.  
  177.                 val argType = typecheckExpr(expr.listexpr_[0], context)
  178.  
  179.                 if (argType != funcType.listtype_[0]) {
  180.                     throw Exception("Argument type does not match function parameter type")
  181.                 }
  182.  
  183.                 funcType.type_ // Return the return type of the function
  184.             }
  185.  
  186.             else -> throw Exception("Unsupported expression type: ${expr.javaClass.simpleName}")
  187.         }
  188.     }
  189.  
  190.  
  191.     class TypeContext {
  192.         val variables = mutableMapOf<String, Type>()   // Variable name -> its type
  193.         val functions = mutableMapOf<String, TypeFun>() // Function name -> its signature
  194.  
  195.         fun addVariable(name: String, type: Type) {
  196.             variables[name] = type
  197.         }
  198.  
  199.         fun getVariableType(name: String): Type? = variables[name]
  200.         fun addFunction(name: String, type: TypeFun) {
  201.             functions[name] = type
  202.         }
  203.  
  204.         fun getFunctionType(name: String): TypeFun? = functions[name]
  205.     }
  206.  
  207.  
  208. }
  209.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement