-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathcudnnRNN_clip.go
67 lines (57 loc) · 1.84 KB
/
cudnnRNN_clip.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
package gocudnn
/*
#include <cudnn.h>
*/
import "C"
// RNNClipMode is a flag selecting the clip mode of an RNN descriptor.
// Its underlying type is the cuDNN enum cudnnRNNClipMode_t.
type RNNClipMode C.cudnnRNNClipMode_t

// c returns r as the cuDNN C enum value, for passing to cgo calls.
func (r RNNClipMode) c() C.cudnnRNNClipMode_t { return C.cudnnRNNClipMode_t(r) }
// None sets r to RNNClipMode(C.CUDNN_RNN_CLIP_NONE) and returns the new value.
func (r *RNNClipMode) None() RNNClipMode {
	*r = RNNClipMode(C.CUDNN_RNN_CLIP_NONE)
	return *r
}
// MinMax sets r to RNNClipMode(C.CUDNN_RNN_CLIP_MINMAX) and returns the new value.
func (r *RNNClipMode) MinMax() RNNClipMode {
	*r = RNNClipMode(C.CUDNN_RNN_CLIP_MINMAX)
	return *r
}
// String implements fmt.Stringer. It returns "RNNClipMode" immediately
// followed by the mode name ("MinMax" or "None"). Any other value yields
// "RNNClipModeUnsupported Flag".
func (r RNNClipMode) String() string {
	name := "Unsupported Flag"
	// Compare against the constants directly instead of calling the
	// mutating setter methods on a scratch copy.
	switch r {
	case RNNClipMode(C.CUDNN_RNN_CLIP_MINMAX):
		name = "MinMax"
	case RNNClipMode(C.CUDNN_RNN_CLIP_NONE):
		name = "None"
	}
	return "RNNClipMode" + name
}
// SetClip stores the clip settings in the RNN descriptor r by calling
// cudnnRNNSetClip.
//
// mode selects the clip mode, nanprop the NaN-propagation option, and
// lclip/rclip are the clip values passed through to cuDNN. When h.w is
// non-nil the call is submitted through h.w.Work; otherwise it runs directly.
func (r *RNND) SetClip(h *Handle, mode RNNClipMode, nanprop NANProp, lclip, rclip float64) error {
	set := func() error {
		st := C.cudnnRNNSetClip(h.x, r.descriptor, mode.c(), nanprop.c(), C.double(lclip), C.double(rclip))
		return Status(st).error("(r *RNND) SetClip")
	}
	if h.w == nil {
		return set()
	}
	return h.w.Work(set)
}
//GetClip returns the clip settings for the descriptor
func (r *RNND) GetClip(h *Handle) (mode RNNClipMode, nanprop NANProp, lclip, rclip float64, err error) {
var (
m C.cudnnRNNClipMode_t
nan C.cudnnNanPropagation_t
lt C.double
rt C.double
)
if h.w != nil {
err = h.w.Work(func() error {
return Status(C.cudnnRNNGetClip(h.x, r.descriptor, &m, &nan, <, &rt)).error("SetClip")
})
} else {
err = Status(C.cudnnRNNGetClip(h.x, r.descriptor, &m, &nan, <, &rt)).error("SetClip")
}
mode = RNNClipMode(m)
nanprop = NANProp(nan)
lclip = float64(lt)
rclip = float64(rt)
return mode, nanprop, lclip, rclip, err
}