-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathcudnnStatus.go
78 lines (69 loc) · 2.75 KB
/
cudnnStatus.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
package gocudnn
/*
#include <cudnn.h>
*/
import "C"
import (
"errors"
"strings"
)
// Status is a Go-side copy of cuDNN's cudnnStatus_t return code.
type Status C.cudnnStatus_t

// StatusSuccess is the zero Status, the value that Status.error treats as
// "no error". No other status codes are exported as constants for now;
// non-success statuses are surfaced through the error-returning methods.
const StatusSuccess = Status(0)
// String returns a human-readable description of status. The text comes
// from cuDNN's cudnnGetErrorString and is prefixed with "Cudnn Status: ".
func (status Status) String() string {
	const prefix = "Cudnn Status: "
	raw := C.cudnnStatus_t(status)
	text := C.GoString(C.cudnnGetErrorString(raw))
	return prefix + text
}
// error converts status into a Go error. It returns nil when status equals
// CUDNN_STATUS_SUCCESS. Otherwise the error message is comment, a colon, and
// then status.String().
func (status Status) error(comment string) error {
	if C.cudnnStatus_t(status) != C.CUDNN_STATUS_SUCCESS {
		return errors.New(comment + ":" + status.String())
	}
	return nil
}
// c returns status as the raw C cudnnStatus_t type, for use when calling into
// or comparing against the cuDNN C API.
func (status Status) c() C.cudnnStatus_t {
	raw := C.cudnnStatus_t(status)
	return raw
}
// Error is the exported form of the unexported error method. It returns nil
// for a successful status. Otherwise it returns an error whose message starts
// with "::Exported Error Function:: " followed by comment.
//
// Because Error takes a parameter, Status does not implement Go's built-in
// error interface.
func (status Status) Error(comment string) error {
	const tag = "::Exported Error Function:: "
	tagged := tag + comment
	return status.error(tagged)
}
// WrapErrorWithStatus recovers a Status from an error. It searches e.Error()
// for a cuDNN status name such as "CUDNN_STATUS_BAD_PARAM" and returns the
// matching Status with a nil error. Status.String embeds the text of
// cudnnGetErrorString, which is expected to contain these names.
//
// A nil e yields CUDNN_STATUS_SUCCESS and a nil error. If no known status name
// is found, the returned Status is CUDNN_STATUS_RUNTIME_FP_OVERFLOW and the
// returned error is non-nil. Callers must check the error before trusting the
// Status.
func WrapErrorWithStatus(e error) (Status, error) {
	if e == nil {
		return Status(C.CUDNN_STATUS_SUCCESS), nil
	}

	// Names are searched in order, and the first one found in the message wins.
	known := [...]struct {
		name   string
		status Status
	}{
		{"CUDNN_STATUS_NOT_INITIALIZED", Status(C.CUDNN_STATUS_NOT_INITIALIZED)},
		{"CUDNN_STATUS_ALLOC_FAILED", Status(C.CUDNN_STATUS_ALLOC_FAILED)},
		{"CUDNN_STATUS_BAD_PARAM", Status(C.CUDNN_STATUS_BAD_PARAM)},
		{"CUDNN_STATUS_ARCH_MISMATCH", Status(C.CUDNN_STATUS_ARCH_MISMATCH)},
		{"CUDNN_STATUS_MAPPING_ERROR", Status(C.CUDNN_STATUS_MAPPING_ERROR)},
		{"CUDNN_STATUS_EXECUTION_FAILED", Status(C.CUDNN_STATUS_EXECUTION_FAILED)},
		{"CUDNN_STATUS_INTERNAL_ERROR", Status(C.CUDNN_STATUS_INTERNAL_ERROR)},
		{"CUDNN_STATUS_NOT_SUPPORTED", Status(C.CUDNN_STATUS_NOT_SUPPORTED)},
		{"CUDNN_STATUS_LICENSE_ERROR", Status(C.CUDNN_STATUS_LICENSE_ERROR)},
		{"CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING", Status(C.CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING)},
		{"CUDNN_STATUS_RUNTIME_IN_PROGRESS", Status(C.CUDNN_STATUS_RUNTIME_IN_PROGRESS)},
		{"CUDNN_STATUS_RUNTIME_FP_OVERFLOW", Status(C.CUDNN_STATUS_RUNTIME_FP_OVERFLOW)},
		// INVALID_VALUE is a real cuDNN status that was not recognized before.
		// It is listed last so that messages which also mention one of the
		// names above still map to the same Status as before.
		{"CUDNN_STATUS_INVALID_VALUE", Status(C.CUDNN_STATUS_INVALID_VALUE)},
	}

	msg := e.Error()
	for _, k := range known {
		if strings.Contains(msg, k.name) {
			return k.status, nil
		}
	}
	return Status(C.CUDNN_STATUS_RUNTIME_FP_OVERFLOW), errors.New("Unsupported error")
}