Advertisement
NLinker

Pattern matching vs visitor

Apr 6th, 2016
294
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Scala 4.19 KB | None | 0 0
  1.  
  2. package object algebraic {
  3.  
  4.   // normal AST
  5.   trait Expr
  6.   case class Const(x: Int) extends Expr
  7.   case class Var(n: String) extends Expr
  8.   case class Sum(a: Expr, b: Expr) extends Expr
  9.  
  10.   def eval(e: Expr, vars: Map[String, Int]): Int = {
  11.     e match {
  12.       case Const(x) => x
  13.       case Sum(a, b) => eval(a, vars) + eval(b, vars)
  14.       case Var(n) => vars(n)
  15.     }
  16.   }
  17.  
  18.   def format(e: Expr): String = {
  19.     e match {
  20.       case Const(x) => x.toString
  21.       case Sum(a, b) => s"(${format(a)} + ${format(b)})"
  22.       case Var(n) => n
  23.     }
  24.   }
  25.  
  26.   // convert (((5 + a) + (6 + b)) + 0) to (5 + (a + (6 + (b + 0))))
  27.   def flatten(e: Expr): Expr = {
  28.     e match {
  29.       case v@Var(_) => v
  30.       case c@Const(_) => c
  31.       case Sum(a@Var(_), b) => Sum(a, flatten(b))
  32.       case Sum(a@Const(_), b) => Sum(a, flatten(b))
  33.       case Sum(Sum(a, b), c) => flatten(Sum(a, Sum(b, c)))
  34.     }
  35.   }
  36.  
  37.   def simplify(e: Expr): Expr = {
  38.     def inner(e: Expr): Expr = {
  39.       e match {
  40.         case Sum(Var(a), Sum(Var(b), c)) =>
  41.           if (a < b) Sum(Var(a), Sum(Var(b), inner(c)))
  42.           else Sum(Var(b), Sum(Var(a), inner(c)))
  43.         case Sum(Const(a), Sum(Var(b), c)) => Sum(Var(b), inner(Sum(Const(a), c)))
  44.         case Sum(Var(a), Sum(Const(b), c)) => Sum(Var(a), inner(Sum(Const(b), c)))
  45.         case Sum(Const(a), Sum(Const(b), c)) => inner(Sum(Const(a + b), c))
  46.         case Sum(Const(a), Const(b)) => Const(a + b)
  47.         case Sum(v@Var(_a), c@Const(_b)) => Sum(v, c)
  48.         case Sum(c@Const(_a), v@Var(_b)) => Sum(v, c)
  49.         case Sum(va@Var(a), vb@Var(b)) => if (a < b) Sum(va, vb) else Sum(vb, va)
  50.         case v@Var(_) => v
  51.         case c@Const(_) => c
  52.       }
  53.     }
  54.     inner(flatten(e))
  55.   }
  56. }
  57.  
  58. package object visitor {
  59.  
  60.   // AST with visitor enabled
  61.   trait Expr {
  62.     def accept[A](op: Op[A]): A
  63.   }
  64.   case class Const(x: Int) extends Expr {
  65.     def accept[A](op: Op[A]): A = op.apply(this)
  66.   }
  67.   case class Var(n: String) extends Expr {
  68.     def accept[A](op: Op[A]): A = op.apply(this)
  69.   }
  70.   case class Sum(a: Expr, b: Expr) extends Expr {
  71.     def accept[A](op: Op[A]): A = op.apply(this)
  72.   }
  73.  
  74.   trait Op[A] {
  75.     def apply(c: Const): A
  76.     def apply(v: Var): A
  77.     def apply(s: Sum): A
  78.   }
  79.  
  80.   class Eval(vars: Map[String, Int]) extends Op[Int] {
  81.     override def apply(c: Const): Int = c.x
  82.     override def apply(v: Var): Int = vars(v.n)
  83.     override def apply(s: Sum): Int = s.a.accept(this) + s.b.accept(this)
  84.   }
  85.  
  86.   class Format extends Op[String] {
  87.     override def apply(c: Const): String = c.x.toString
  88.     override def apply(v: Var): String = v.n
  89.     override def apply(s: Sum): String =
  90.       s"(${s.a.accept(this)} + ${s.b.accept(this)})"
  91.   }
  92.  
  93.   class Flatten extends Op[Expr] {
  94.     override def apply(c: Const): Expr = c
  95.     override def apply(v: Var): Expr = v
  96.     override def apply(s: Sum): Expr = {
  97.       s.a match {
  98.         case _: Const => Sum(s.a, s.b.accept(this))
  99.         case _: Var => Sum(s.a, s.b.accept(this))
  100.         case t: Sum => Sum(t.a, Sum(t.b, s.b)).accept(this)
  101.       }
  102.     }
  103.   }
  104.  
  105.   class Simplify extends Op[Expr] {
  106.     var ex: Expr = _
  107.     override def apply(c: Const): Expr = c
  108.     override def apply(v: Var): Expr = v
  109.     override def apply(s: Sum): Expr = {
  110.       if (ex == null) {
  111.         ex = s.accept(new Flatten)
  112.       }
  113.       ex match {
  114.         case a: Const => s.b match {
  115.           case b: Const => Const(a.x + b.x)
  116.           case b: Var => Sum(b, a)
  117.           case b: Sum => ???
  118.           }
  119.         case a: Var => s.b match {
  120.           case b: Const => Sum(a, b)
  121.           case b: Var => if (a.n < b.n) Sum(a, b) else Sum(b, a)
  122.           case b: Sum => ???
  123.         }
  124.         case a: Sum => s.b match {
  125.           case b: Const => ???
  126.           case b: Var => ???
  127.           case b: Sum => ???
  128.         }
  129.       }
  130.     }
  131.   }
  132.  
  133.   def eval(expr: Expr, vars: Map[String, Int]): Int = expr.accept(new Eval(vars))
  134.  
  135.   def format(expr: Expr): String = expr.accept(new Format)
  136.  
  137.   def flatten(expr: Expr): Expr = expr.accept(new Flatten)
  138.  
  139.   def simplify(expr: Expr): Expr = expr.accept(new Simplify)
  140. }
  141.  
  142. object Test {
  143.  
  144.   import algebraic.{Const => ConstA, Var => VarA, Sum => SumA, eval => evalA, format => formatA, flatten => flattenA, simplify => simplifyA}
  145.   import visitor.{Const => ConstV, Var => VarV, Sum => SumV, eval => evalV, format => formatV, flatten => flattenV, simplify => simplifyV}
  146.  
  147.   def main(args: Array[String]): Unit = {
  148.     val m = Map("a" -> 1, "b" -> 2)
  149.  
  150.     val exa = SumA(SumA(SumA(ConstA(5), VarA("a")), SumA(ConstA(6), VarA("b"))), ConstA(0))
  151.     println(s"algebraic eval on $m: ${evalA(exa, m)}")
  152.     println(s"algebraic format: ${formatA(exa)}")
  153.     println(s"algebraic flatten: ${formatA(flattenA(exa))}")
  154.     println(s"algebraic simplify: ${formatA(simplifyA(exa))}")
  155.     println()
  156.  
  157.     val exv = SumV(SumV(SumV(ConstV(5), VarV("a")), SumV(ConstV(6), VarV("b"))), ConstV(0))
  158.     println(s"visitor eval on $m: ${evalV(exv, m)}")
  159.     println(s"visitor format: ${formatV(exv)}")
  160.     println(s"visitor flatten: ${formatV(flattenV(exv))}")
  161.     println(s"visitor simplify: ${formatV(simplifyV(exv))}")
  162.     println()
  163.  
  164.     //    val ex = Sum(
  165.     //      Sum(Sum(Const(1),Const(2)),Sum(Const(3),Const(4))),
  166.     //      Sum(Sum(Const(5),Const(6)),Sum(Const(7),Const(8)))
  167.     //    )
  168.   }
  169. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement