Advertisement
NLinker

Haskell implementation of Python's yield

Nov 14th, 2016
215
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. module Lib
  2.     ( someFunc
  3.     ) where
  4.  
import Control.Applicative (Alternative (..))
import Control.Monad
import Control.Monad.Logic
import Control.Monad.Logic.Class
import Control.Monad.Trans
  9.  
  10. someFunc :: IO ()
  11. someFunc = putStrLn "someFunc"
  12.  
  13. newtype EitherT e m a = EitherT{runEitherT :: m (Either e a)}
  14.  
  15. instance Monad m => Monad (EitherT e m) where
  16.     return  = EitherT . return . Right
  17.     EitherT m >>= f = EitherT $ m >>= check
  18.       where
  19.       check (Right a) = runEitherT $ f a
  20.       check (Left  e) = return $ Left e
  21.  
  22. instance MonadPlus m => MonadPlus (EitherT e m) where
  23.     mzero = EitherT mzero
  24.     mplus (EitherT m1) (EitherT m2) = EitherT $ m1 `mplus` m2
  25.  
  26. instance MonadTrans (EitherT e) where
  27.     lift m = EitherT $ m >>= return . Right
  28.  
  29. instance MonadIO m => MonadIO (EitherT e m) where
  30.     liftIO = lift . liftIO
  31.  
  32. raise :: Monad m => e -> EitherT e m a
  33. raise = EitherT . return . Left
  34.  
  35. yield :: MonadPlus m => e -> EitherT e m ()
  36. yield x = raise x `mplus` return ()
  37.  
  38. -- We start with the in-order traversal example
  39.  
  40. -- A variant of catchError when we don't care about the
  41. -- return type, and the normal return of an expression is mapped
  42. -- to mzero. This is common for the normal return from a generator
  43. catchError' :: MonadPlus m => EitherT e m () -> m e
  44. catchError' (EitherT m) = m >>= check
  45.   where
  46.   check (Left x)  = return x
  47.   check (Right x) = mzero
  48.  
  49. -- Lifting iter to the EitherT-transformed LogicT
  50. -- We propagate the exceptions
  51. iterE :: (Monad m, MonadLogic (t m), MonadPlus (t m)) =>
  52.   Maybe Int -> EitherT e (t m) () -> EitherT e (t m) ()
  53. iterE n (EitherT m) = EitherT $ msplit m >>= check n
  54.   where
  55.   check _ Nothing = return (Right ())
  56.   check (Just n) _ | n <= 1  = return (Right ())
  57.   check n (Just (Right _,t)) = next n t
  58.   check n (Just (Left e,t))  = return (Left e) `mplus` next n t
  59.   next n t = runEitherT $ iterE (liftM pred n) (EitherT t)
  60.  
  61. -- A version of bagofN that doesn't care about the result of
  62. -- the computation (which is unit). No need to accumulate it in a list
  63. -- iter n m = bagofN n m >> return ()
  64. -- the following is an optimized implementation of the above
  65. iter :: (Monad m, MonadLogic (t m), MonadPlus (t m)) => Maybe Int -> t m () -> t m ()
  66. iter n m = msplit m >>= check n
  67.   where
  68.   check _ Nothing = return ()
  69.   check (Just n) _ | n <= 1 = return ()
  70.   check n (Just (_,t)) = iter (liftM pred n) t
  71.  
  72.  
  73. type Label = Int
  74. data Tree = Leaf | Node Label Tree Tree deriving Show
  75.  
  76. make_full_tree :: Int -> Tree
  77. make_full_tree = loop 1
  78.  where
  79.  loop label 0 = Leaf
  80.  loop label n = Node label (loop (2*label) (pred n)) (loop (2*label+1) (pred n))
  81.  
  82. tree1 = make_full_tree 3
  83.  
  84. -- This time, we implement Python code idiomatically
  85. in_order2 :: (MonadIO m, MonadPlus m) => Tree -> EitherT Label m ()
  86. in_order2 Leaf = return ()
  87. in_order2 (Node label left right) = do
  88.     in_order2 left
  89.     liftIO . putStrLn $ "traversing: " ++ show label
  90.     yield label
  91.     in_order2 right
  92.  
  93. in_order2_r :: IO ()
  94. in_order2_r = observe $ iter Nothing $ do
  95.   i <- catchError' (in_order2 tree1)
  96.  liftIO . putStrLn $ "Generated: " ++ show i
  97.  
  98. -- Stopping the generator earlier: request only two generated values
  99. -- The trace shows that we stop the traversal after consuming
  100. -- the needed two values.
  101. -- We indeed traverse on-demand.
  102. in_order2_r' :: IO ()
  103. in_order2_r' = observe $ iter (Just 2) $ do
  104.  i <- catchError' (in_order2 tree1)
  105.   liftIO . putStrLn $ "Generated: " ++ show i
  106.  
  107.  
  108. -- The post-order traversal example:
  109. -- traverse a tree post-order and print out the sum of the current
  110. -- label and the labels in the left and the right branches.
  111. -- Now the generator has to return a useful value.
  112.  
  113. post_order :: MonadPlus m => Tree -> EitherT Label m Label
  114. post_order Leaf = return 0
  115. post_order (Node label left right) = do
  116.   sum_left  <- post_order left
  117.   sum_right <- post_order right
  118.   let sum = sum_left + sum_right + label
  119.   yield sum
  120.   return sum
  121.  
  122. post_order_r :: IO ()
  123. post_order_r = observe $ iter Nothing $
  124.            catchError (post_order tree1) >>= liftIO . print
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement