-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathLSTMUtils.hs
36 lines (31 loc) · 1.21 KB
/
LSTMUtils.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnicodeSyntax #-}
module LSTMUtils where
import TypedFlow
import TypedFlow.Python
-- | Transform the first component of a two-element heterogeneous tensor
-- vector, passing the second component through untouched.
--
-- The transformation may change the shape of the first component
-- (from @s1@ to @s@); the second component keeps its shape @s'@.
onFST :: (Tensor s1 t -> Tensor s t) -> HTV t '[s1, s'] -> HTV t '[s, s']
onFST transform pair = case pair of
  VecPair first second -> VecPair (transform first) second
-- | Build an LSTM cell regularised with dropout on its input and on its
-- recurrent state.
--
-- The returned cell is the composition (with '.-.') of two stages:
--
--   * a 'timeDistribute'd dropout applied to the @[x]@-shaped input, and
--   * the 'lstm' cell, built from parameters obtained via
--     'parameterDefault' under the given name, wrapped with 'onStates' so
--     that a second, separately created dropout is applied (through
--     'onFST') to the first of its two @[n]@-shaped states only.
--
-- Both dropouts are created from the same 'DropProb'.
--
-- NOTE(review): the first state is presumably the hidden state @h@ (the
-- pattern names in 'onFST' suggest an @h@/@c@ pair); confirm against 'lstm'.
mkLSTM :: ∀ n x. KnownNat x => KnownNat n =>
  String -> DropProb -> Gen (RnnCell Float32 '[ '[n], '[n]] (Tensor '[x] Float32) (Tensor '[n] Float32))
mkLSTM pName dropProb = do
  -- The order of these actions matches the original; Gen effects are sequenced.
  lstmParams       <- parameterDefault pName
  inputDropout     <- mkDropout dropProb
  recurrentDropout <- mkDropout dropProb
  let inputStage = timeDistribute inputDropout
      cellStage  = onStates (onFST recurrentDropout) (lstm lstmParams)
  return (inputStage .-. cellStage)
-- | Build a GRU cell regularised with dropout on its input and on its
-- recurrent state.
--
-- The returned cell is the composition (with '.-.') of two stages:
--
--   * a 'timeDistribute'd dropout applied to the @[x]@-shaped input, and
--   * the 'gru' cell, built from parameters obtained via
--     'parameterDefault' under the given name, wrapped with 'onStates' so
--     that a dropout created by 'mkDropouts' is applied to its state
--     vector (a single @[n]@-shaped component here).
--
-- Both dropouts are created from the same 'DropProb'.
--
-- NOTE(review): 'mkDropouts' presumably builds one dropout per state
-- component; confirm against TypedFlow.
mkGRU :: ∀ n x. KnownNat x => KnownNat n =>
  String -> DropProb -> Gen (RnnCell Float32 '[ '[n] ] (Tensor '[x] Float32) (Tensor '[n] Float32))
mkGRU pName dropProb = do
  -- The order of these actions matches the original; Gen effects are sequenced.
  gruParams     <- parameterDefault pName
  inputDropout  <- mkDropout dropProb
  stateDropouts <- mkDropouts dropProb
  let inputStage = timeDistribute inputDropout
      cellStage  = onStates stateDropouts (gru gruParams)
  return (inputStage .-. cellStage)