-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathcudnnTensorStringer.go
151 lines (131 loc) · 3.94 KB
/
cudnnTensorStringer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package gocudnn
import (
"errors"
"fmt"
"github.com/dereklstinson/gocudnn/cudart"
"github.com/dereklstinson/gocudnn/gocu"
"github.com/dereklstinson/half"
"github.com/dereklstinson/cutil"
)
// memstringer implements fmt.Stringer for tensor memory referenced by a
// cutil.Pointer. Its String method copies the memory with cudart.Memcpy and
// formats it according to the descriptor.
type memstringer struct {
	// td is queried for format, data type, dims, strides and size in bytes.
	td *TensorD
	// t is the source memory passed to cudart.Memcpy.
	t cutil.Pointer
	// kind is the MemcpyKind passed to cudart.Memcpy (GetStringer sets it
	// to kind.Default()).
	kind cudart.MemcpyKind
}
// String copies the tensor memory to the host and renders it as text.
//
// The descriptor is queried on every call, so the output follows its current
// settings. Float and half data types are supported; half values are widened
// with half.ToFloat32 before printing. An NCHW format is printed channel by
// channel, and any other format is printed with the NHWC layout. Because
// fmt.Stringer cannot return an error, failures are reported as a message
// inside the "Tensor Data: { ... }" envelope.
func (m *memstringer) String() string {
	// wrap applies the envelope shared by every outcome, success or failure.
	wrap := func(body string) string {
		return fmt.Sprintf("Tensor Data: {\n%v\n}\n", body)
	}
	frmt, dtype, dims, stride, err := m.td.Get()
	if err != nil {
		return wrap("Err in getting hidden Tensor Descriptor")
	}
	// Both formatters index dims[0..3] and stride[0..3] unconditionally.
	if len(dims) != 4 || len(stride) != 4 {
		return wrap("Unsupported: tensor must be 4d")
	}
	sib, err := m.td.GetSizeInBytes()
	if err != nil {
		return wrap("Err in getting sib for hidden Tensor Descriptor")
	}
	fflg := frmt
	dflg := dtype
	// Host buffers are sized from sib, the number of bytes Memcpy writes,
	// rather than from the element volume. With padded strides sib is larger
	// than volume*elementSize, and a volume-sized slice would be overrun.
	var host []float32
	switch dtype {
	case dflg.Float():
		data := make([]float32, (sib+3)/4)
		hptr, err := gocu.MakeGoMem(data)
		if err != nil {
			return wrap("Err in wrapping data::" + err.Error())
		}
		if err = cudart.Memcpy(hptr, m.t, sib, m.kind); err != nil {
			return wrap("Err in copy to host ::" + err.Error())
		}
		host = data
	case dflg.Half():
		data := make([]half.Float16, (sib+1)/2)
		hptr, err := gocu.MakeGoMem(data)
		if err != nil {
			return wrap("Err in wrapping data::" + err.Error())
		}
		if err = cudart.Memcpy(hptr, m.t, sib, m.kind); err != nil {
			return wrap("Err in copy to host ::" + err.Error())
		}
		host = half.ToFloat32(data)
	default:
		return wrap("Unsupported DataType")
	}
	if frmt == fflg.NCHW() {
		return wrap(nchwtensorstringformated(dims, stride, host))
	}
	return wrap(nhwctensorstringformated(dims, stride, host))
}
// GetStringer returns a fmt.Stringer that prints CUDA-allocated memory
// formatted as NHWC or NCHW. The stringer only prints the data.
//
// Only 4d tensors with a float or half data type in NCHW or NHWC format are
// supported. Anything else is rejected here with an error, instead of failing
// later when the stringer is used. The copy uses the MemcpyKind returned by
// kind.Default().
func GetStringer(tD *TensorD, t cutil.Pointer) (fmt.Stringer, error) {
	if tD == nil {
		return nil, errors.New(" GetStringer(tD *TensorD, t cutil.Pointer): nil TensorD")
	}
	frmt, dtype, dims, _, err := tD.Get()
	if err != nil {
		return nil, err
	}
	fflg := frmt
	if !(frmt == fflg.NCHW() || frmt == fflg.NHWC()) {
		return nil, errors.New(" GetStringer(tD *TensorD, t cutil.Pointer): Unsuported Format")
	}
	dflg := dtype
	if !(dtype == dflg.Float() || dtype == dflg.Half()) {
		return nil, errors.New(" GetStringer(tD *TensorD, t cutil.Pointer): Unsupported Type")
	}
	// The formatters used by String index exactly four dimensions.
	if len(dims) != 4 {
		return nil, errors.New(" GetStringer(tD *TensorD, t cutil.Pointer): Unsupported Dims, tensor must be 4d")
	}
	var kind cudart.MemcpyKind
	return &memstringer{
		td:   tD,
		t:    t,
		kind: kind.Default(),
	}, nil
}
// nhwctensorstringformated renders a 4d float32 tensor using the NHWC layout.
//
// dims and strides must each have at least four entries, in the same axis
// order. Element (i,j,k,l) is read from
// data[i*strides[0]+j*strides[1]+k*strides[2]+l*strides[3]].
//
// Output layout:
//   - Each batch i is wrapped in "Batch[i]{ ... }".
//   - Each j produces one line of "(j,k)[ ... ], " groups.
//   - Each group holds the innermost-axis values printed with five decimals.
//   - Values >= 0 get an extra leading space so they line up with negatives.
//
// The result always starts with "\n".
func nhwctensorstringformated(dims, strides []int32, data []float32) string {
	// Appending to one byte slice keeps the build linear; repeated string
	// concatenation copied the entire accumulated output for every value.
	out := []byte{'\n'}
	for i := int32(0); i < dims[0]; i++ {
		out = append(out, fmt.Sprintf("Batch[%v]{\n", i)...)
		for j := int32(0); j < dims[1]; j++ {
			for k := int32(0); k < dims[2]; k++ {
				out = append(out, fmt.Sprintf("(%v,%v)[ ", j, k)...)
				rowOff := i*strides[0] + j*strides[1] + k*strides[2]
				for l := int32(0); l < dims[3]; l++ {
					val := data[rowOff+l*strides[3]]
					if val >= 0 {
						out = append(out, fmt.Sprintf(" %.5f ", val)...)
					} else {
						out = append(out, fmt.Sprintf("%.5f ", val)...)
					}
				}
				out = append(out, "], "...)
			}
			out = append(out, '\n')
		}
		out = append(out, "}\n"...)
	}
	return string(out)
}
// nchwtensorstringformated renders a 4d float32 tensor using the NCHW layout.
//
// dims and strides must each have at least four entries, in the same axis
// order. Element (i,j,k,l) is read from
// data[i*strides[0]+j*strides[1]+k*strides[2]+l*strides[3]].
//
// Output layout:
//   - Each batch i is wrapped in "Batch[i]{ ... }".
//   - Each channel j is wrapped in a tab-indented "Channel[j]{ ... }".
//   - Each k produces one double-tab-indented line of values printed with
//     four decimals.
//   - Values >= 0 get an extra leading space so they line up with negatives.
//
// The result always starts with "\n".
func nchwtensorstringformated(dims, strides []int32, data []float32) string {
	// Appending to one byte slice keeps the build linear; repeated string
	// concatenation copied the entire accumulated output for every value.
	out := []byte{'\n'}
	for i := int32(0); i < dims[0]; i++ {
		out = append(out, fmt.Sprintf("Batch[%v]{\n", i)...)
		for j := int32(0); j < dims[1]; j++ {
			out = append(out, fmt.Sprintf("\tChannel[%v]{\n", j)...)
			for k := int32(0); k < dims[2]; k++ {
				out = append(out, "\t\t"...)
				rowOff := i*strides[0] + j*strides[1] + k*strides[2]
				for l := int32(0); l < dims[3]; l++ {
					val := data[rowOff+l*strides[3]]
					if val >= 0 {
						out = append(out, fmt.Sprintf(" %.4f ", val)...)
					} else {
						out = append(out, fmt.Sprintf("%.4f ", val)...)
					}
				}
				out = append(out, '\n')
			}
			out = append(out, "\t}\n"...)
		}
		out = append(out, "}\n"...)
	}
	return string(out)
}