{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Network.Simple.SerializedCommunication (server, getMessage, putMessage, getCnx, writeMessageTo, stateAction, connInfos, failWith, logMessage, getTime, ioAction, NetworkProtocol, Failure(..)) where

import Control.Applicative
import Control.Exception (bracket)
import qualified Data.Foldable as F
import Data.Monoid

import Control.Concurrent.STM
import Control.Monad.Catch (catchIOError)
import Control.Monad.Operational
import Control.Monad.State.Strict
import Data.Binary
import Data.Binary.Get
import Data.ByteString.Lazy (ByteString)
import qualified Data.ByteString.Lazy as BSL
import qualified Data.IntMap.Strict as M
import Data.Thyme.Clock
import Network.Simple.TCP

-- | A protocol program: an operational 'ProgramT' whose primitive steps are
-- 'ProtoInstructions', running on top of a pure @'State' s@ monad.  The
-- interpreter ('interpretProgram') advances it with @runState@ one step at a
-- time and hands back the final state.
type NetworkProtocol s = ProgramT (ProtoInstructions s) (State s)

-- | Handle for a live connection: handed out by 'getCnx' and accepted by
-- 'writeMessageTo'.  The constructor is not exported, so callers cannot forge
-- one from a raw 'Int'.
newtype CnxIdentifier
    = CnxIdentifier Int
    deriving (Eq, Ord, Num, Show)

-- | Primitive steps of a 'NetworkProtocol' program; @s@ is the type of the
-- program's pure state.  Each step is carried out by 'interpretProgram' on
-- behalf of a single connection.
data ProtoInstructions s a where
    -- | Decode the next value received on the connection.
    GetMessage     :: Binary a => ProtoInstructions s a
    -- | Encode a value and send it on the connection.
    PutMessage     :: Binary a => a -> ProtoInstructions s ()
    -- | Encode a value and send it to another connection, looked up by identifier.
    WriteMessageTo :: Binary a => CnxIdentifier -> a -> ProtoInstructions s ()
    -- | Identifier of the connection running the program.
    GetCnx         :: ProtoInstructions s CnxIdentifier
    -- | Address the connection was accepted with.
    ConnInfos      :: ProtoInstructions s SockAddr
    -- | Read the current time.
    GetTime        :: ProtoInstructions s UTCTime
    -- | Print a line to standard output.
    Log            :: String -> ProtoInstructions s ()
    -- | Run an arbitrary 'IO' action.
    IoAction       :: IO a -> ProtoInstructions s a
    -- | Run an 'STM' transaction, then run the protocol program it returns.
    StateStuff     :: STM (NetworkProtocol s a) -> ProtoInstructions s a
    -- | Stop the program with a 'Failure'.
    FailWith       :: Failure -> ProtoInstructions s a

-- | Run an arbitrary 'IO' action as one protocol step.
ioAction :: IO a -> NetworkProtocol s a
ioAction action = singleton (IoAction action)

-- | Wait for, and decode, the next value sent on this connection.
getMessage :: Binary a => NetworkProtocol s a
getMessage = singleton GetMessage

-- | The current time, as read by 'getCurrentTime'.
getTime :: NetworkProtocol s UTCTime
getTime = singleton GetTime

-- | Encode a value and send it on this connection.
putMessage :: Binary a => a -> NetworkProtocol s ()
putMessage msg = singleton (PutMessage msg)

-- | The identifier of this connection; other connections can pass it to
-- 'writeMessageTo'.
getCnx :: NetworkProtocol s CnxIdentifier
getCnx = singleton GetCnx

-- | Encode a value and send it to the connection registered under the given
-- identifier.  Nothing is sent when no such connection is registered.
writeMessageTo :: Binary a => CnxIdentifier -> a -> NetworkProtocol s ()
writeMessageTo target msg = singleton $ WriteMessageTo target msg

-- | Run an 'STM' transaction atomically, then run the protocol program it
-- produced.
stateAction :: STM (NetworkProtocol s a) -> NetworkProtocol s a
stateAction transaction = singleton (StateStuff transaction)

-- | The address this connection was accepted with (the peer's address).
connInfos :: NetworkProtocol s SockAddr
connInfos = singleton ConnInfos

-- | Stop the program with the given 'Failure'; later steps are not run.
failWith :: Failure -> NetworkProtocol s a
failWith failure = singleton (FailWith failure)

-- | Print a line on standard output.
logMessage :: String -> NetworkProtocol s ()
logMessage line = singleton (Log line)

-- | Why a 'NetworkProtocol' program stopped without producing a result.
data Failure
    = ConnectionAbort        -- ^ The connection was aborted.
    | IOFail IOError         -- ^ An 'IOError' caught by the interpreter.
    | ParseError String      -- ^ The decoder rejected the received bytes; carries its message.
    | ProtocolError String   -- ^ Not produced by the interpreter itself; presumably
                             --   raised by programs through 'failWith'.
    deriving (Show)

-- | Run an incremental decoder to completion, reading from @sock@ (at most
-- 1024 bytes per 'recv') whenever the decoder asks for more input.
--
-- On success, returns the decoded value together with the bytes received
-- past its end, so they can be fed to the next decoder.  Otherwise returns
--
-- * 'ParseError' with the decoder's message when it rejects the input,
-- * 'ConnectionAbort' when the peer closes the connection before a complete
--   value has been received,
-- * 'IOFail' when reading from the socket raises an 'IOError'.
getmsg :: Decoder a -> Socket -> IO (Either Failure (ByteString, a))
getmsg decoder sock = case decoder of
    Fail _ _ rr -> return (Left (ParseError rr))
    Done rmng _ a -> return (Right (BSL.fromChunks [rmng], a))
    Partial k -> do
        -- Report read errors as values, like send errors, instead of letting
        -- them escape as exceptions.
        received <- catchIOError (Right <$> recv sock 1024) (return . Left . IOFail)
        case received of
            Left rr -> return (Left rr)
            Right (Just chunk) -> getmsg (k (Just chunk)) sock
            -- The peer closed the connection.  Signal end of input so the
            -- decoder can still finish if it already has enough bytes; any
            -- other outcome means the message was cut short.  This also
            -- guarantees 'recv' is never called again on a closed connection.
            Right Nothing -> case k Nothing of
                Done rmng _ a -> return (Right (BSL.fromChunks [rmng], a))
                _ -> return (Left ConnectionAbort)

-- | Run a 'NetworkProtocol' program against one connected socket.
--
-- * @leftovers@: bytes already received but not yet consumed; they are
--   decoded before anything new is read from @sock@.
-- * @stt@: the initial pure protocol state.
-- * @identifier@: the value returned by 'getCnx'.
-- * @getSock@: resolves an identifier to a socket for 'writeMessageTo'.
-- * @sockaddr@: the value returned by 'connInfos'.
--
-- Returns the final state together with the program's result, or the
-- 'Failure' that stopped it.  'IOError's raised while sending are reported as
-- 'IOFail'; other exceptions (for instance from an 'ioAction') are not caught
-- here.
interpretProgram :: ByteString -> s -> CnxIdentifier -> (CnxIdentifier -> IO (Maybe Socket)) -> SockAddr -> Socket -> NetworkProtocol s a -> IO (s, Either Failure a)
interpretProgram leftovers stt identifier getSock sockaddr sock program = do
    (stt', _, result) <- runInstructions identifier getSock sockaddr sock leftovers stt program
    return (stt', result)

-- | Worker behind 'interpretProgram'.  It also returns the bytes left
-- unconsumed after the last decoded message, so that a program run through
-- 'StateStuff' can hand its buffered input back to the program that ran it.
runInstructions :: CnxIdentifier -> (CnxIdentifier -> IO (Maybe Socket)) -> SockAddr -> Socket -> ByteString -> s -> NetworkProtocol s a -> IO (s, ByteString, Either Failure a)
runInstructions identifier getSock sockaddr sock = go
  where
    -- Explicit signature: 'StateStuff' recurses at a different result type.
    go :: ByteString -> t -> NetworkProtocol t b -> IO (t, ByteString, Either Failure b)
    go leftovers stt program =
        case runState (viewT program) stt of
            (Return a, s) -> return (s, leftovers, Right a)
            (instr :>>= f, nstt) ->
                let -- Continue with the same buffered input.
                    next = go leftovers nstt . f
                    -- Stop with the state reached before this instruction.
                    stop rr = return (nstt, leftovers, Left rr)
                    -- Guard only the send itself.  Wrapping the continuation
                    -- would keep one handler installed per send for the rest of
                    -- the program and misreport unrelated later IOErrors.
                    sendThen action = catchIOError (Right <$> action) (return . Left . IOFail) >>= either stop next
                in  case instr of
                        IoAction x -> x >>= next
                        GetTime -> getCurrentTime >>= next
                        Log msg -> putStrLn msg >> next ()
                        GetMessage -> do
                            -- Feed the buffered bytes first; only then read from the socket.
                            em <- getmsg (runGetIncremental Data.Binary.get `pushChunks` leftovers) sock
                            case em of
                                Left rr -> stop rr
                                Right (leftovers', x) -> go leftovers' nstt (f x)
                        PutMessage msg -> sendThen (sendLazy sock (encode msg))
                        WriteMessageTo i msg -> do
                            msock <- getSock i
                            -- An unknown identifier is silently skipped; a send
                            -- error on the target stops *this* program with IOFail.
                            -- NOTE(review): nothing here serializes concurrent
                            -- writes to the same target socket.
                            sendThen (F.forM_ msock (\ms -> sendLazy ms (encode msg)))
                        FailWith rr -> stop rr
                        GetCnx -> next identifier
                        StateStuff transaction -> do
                            g <- atomically transaction
                            -- Thread both the state and the buffered input
                            -- through the nested program and back.
                            (nstt', leftovers', x) <- go leftovers nstt g
                            case x of
                                Left rr -> return (nstt', leftovers', Left rr)
                                Right v -> go leftovers' nstt' (f v)
                        ConnInfos -> next sockaddr

-- | Serve one accepted connection.
--
-- The socket is registered in @trelationmap@ under a fresh identifier taken
-- from @tnextid@.  The protocol then runs from the initial state @i@.
-- Afterwards the socket is unregistered and closed, even if the protocol
-- throws.  Finally the peer address and the outcome are printed.
runProgram :: s -> NetworkProtocol s () -> TVar Int -> TVar (M.IntMap Socket) -> (Socket, SockAddr) -> IO ()
runProgram i program tnextid trelationmap (sock, sockaddr) = do
    putStrLn ("Connection from " <> show sockaddr)
    (_, r) <- bracket register unregister $ \identifier ->
        interpretProgram BSL.empty i (CnxIdentifier identifier) lookupSock sockaddr sock program
    print (sockaddr, r)
  where
    -- Identifiers come from a monotonically increasing counter rather than
    -- "largest live key + 1".  An identifier still held by some program is
    -- therefore never reassigned to a new, unrelated connection.
    register = atomically $ do
        identifier <- readTVar tnextid
        modifyTVar' tnextid (+ 1)
        modifyTVar' trelationmap (M.insert identifier sock)
        return identifier
    -- Unregister before closing so other connections stop resolving this
    -- identifier to a socket that is about to be closed.
    unregister identifier = do
        atomically $ modifyTVar' trelationmap (M.delete identifier)
        closeSock sock
    lookupSock (CnxIdentifier q) = M.lookup q <$> readTVarIO trelationmap

-- | Accept connections on @port@ (bound with 'HostAny') and run @proto@ on
-- each one, every connection starting from the same initial state @i@.
-- Connections can reach each other with 'writeMessageTo', using the
-- identifiers handed out by 'getCnx'.
server :: s -> Int -> NetworkProtocol s () -> IO ()
server i port proto = do
    nextId <- newTVarIO 0
    identifiers <- newTVarIO M.empty
    serve HostAny (show port) (runProgram i proto nextId identifiers)

