forked from rai-project/go-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrace.go
111 lines (100 loc) · 3.24 KB
/
trace.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package pytorch
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/c3sr/tracer"
opentracing "github.com/opentracing/opentracing-go"
)
// TraceEvent is a single profiler record decoded from a JSON trace.
//
// The JSON keys "name", "ph", "ts", "dur", "pid" and "tid" use Chrome
// trace-event style naming; the other keys carry extra per-layer data.
// Timestamp and Duration are microseconds: NewTrace scales them by 1000 and
// adds them to a Unix-nanosecond start time. The fields tagged `json:"-"` are
// ignored by encoding/json and are populated by NewTrace instead.
type TraceEvent struct {
	Name            string  `json:"name,omitempty"`
	Phase           string  `json:"ph,omitempty"`
	Timestamp       float64 `json:"ts,omitempty"`
	Duration        float64 `json:"dur,omitempty"`
	ProcessID       string  `json:"pid,omitempty"`
	ThreadID        int64   `json:"tid,omitempty"`
	Shape           string  `json:"shape,omitempty"`
	AllocatedMemory int64   `json:"allocated_memory,omitempty"` // NOTE(review): units not visible here — confirm (bytes?)
	PeakMemory      int64   `json:"peak_memory,omitempty"`
	Index           int64   `json:"layer_sequence_index,omitempty"`

	// Derived values (see NewTrace): absolute bounds in Unix nanoseconds,
	// the same bounds as time.Time, and the event's position in the decoded
	// JSON array, which TraceEvents.Less uses as the final tie-breaker.
	Start     int64     `json:"-"`
	End       int64     `json:"-"`
	StartTime time.Time `json:"-"`
	EndTime   time.Time `json:"-"`
	Seq       int64     `json:"-"`
}

// ID returns a key of the form "<Name>/<ThreadID>" for the event.
func (t TraceEvent) ID() string {
	// Sprint inserts no spaces here because every adjacent operand pair
	// includes a string.
	return fmt.Sprint(t.Name, "/", t.ThreadID)
}
// TraceEvents is a list of events that implements sort.Interface.
//
// Events are ordered by ascending Start. When two events start together, the
// one that ends later (the longer one) sorts first, so an enclosing event
// precedes events nested inside it. Any remaining tie goes to the higher Seq.
type TraceEvents []TraceEvent

// Len reports the number of events.
func (t TraceEvents) Len() int { return len(t) }

// Swap exchanges the events at positions i and j.
func (t TraceEvents) Swap(i, j int) {
	tmp := t[i]
	t[i] = t[j]
	t[j] = tmp
}

// Less reports whether event i sorts before event j under the ordering
// described on TraceEvents.
func (t TraceEvents) Less(i, j int) bool {
	a, b := &t[i], &t[j]
	switch {
	case a.Start != b.Start:
		return a.Start < b.Start
	case a.End != b.End:
		return a.End > b.End
	default:
		return a.Seq > b.Seq
	}
}
// Trace is a decoded set of profiler events plus the instant (StartTime)
// that NewTrace used as the origin for the events' timestamps.
//
// Trace implements sort.Interface by delegating to its TraceEvents.
type Trace struct {
	StartTime   time.Time
	TraceEvents TraceEvents
}

// Len reports the number of events in the trace.
func (t Trace) Len() int { return len(t.TraceEvents) }

// Swap exchanges two events. The receiver is a value, but TraceEvents is a
// slice, so the exchange is visible through the caller's Trace as well.
func (t Trace) Swap(i, j int) {
	events := t.TraceEvents
	events[i], events[j] = events[j], events[i]
}

// Less orders events exactly as TraceEvents.Less does.
func (t Trace) Less(i, j int) bool { return t.TraceEvents.Less(i, j) }
// NewTrace decodes data, a JSON array of trace events, and anchors every
// event in absolute time.
//
// startTime is the trace origin in Unix nanoseconds. Each event's "ts" and
// "dur" values (microseconds) are scaled to nanoseconds and offset from
// startTime to fill Start/End and StartTime/EndTime. Seq records the event's
// index in the array so that sorting can break ties deterministically.
// A JSON "null" yields a Trace with no events. Decoding errors from
// encoding/json are returned unchanged, with a nil *Trace.
func NewTrace(data string, startTime int64) (*Trace, error) {
	trace := new(Trace)
	if err := json.Unmarshal([]byte(data), &trace.TraceEvents); err != nil {
		return nil, err
	}
	trace.StartTime = time.Unix(0, startTime)
	for i := range trace.TraceEvents {
		// Work through a pointer so each event is updated in place without
		// copying the whole struct.
		ev := &trace.TraceEvents[i]
		ev.Start = startTime + roundTraceNanos(ev.Timestamp*1000)
		ev.StartTime = time.Unix(0, ev.Start)
		ev.End = startTime + roundTraceNanos(ev.Timestamp*1000+ev.Duration*1000)
		ev.EndTime = time.Unix(0, ev.End)
		ev.Seq = int64(i)
	}
	return trace, nil
}

// roundTraceNanos converts a fractional nanosecond count to int64 by rounding
// to the nearest integer (halves away from zero). Plain int64() truncation is
// wrong here: e.g. 1.005µs is stored as 1.00499999…, so 1.005*1000 truncates
// to 1004ns instead of 1005ns.
func roundTraceNanos(ns float64) int64 {
	if ns < 0 {
		return -int64(-ns + 0.5)
	}
	return int64(ns + 0.5)
}
// Publish reports the event to the tracer as a finished span at level lvl.
//
// The span is named after the event, starts at event.StartTime, finishes at
// event.EndTime, and carries the event's profiler fields as tags.
//
// opts are forwarded to tracer.StartSpanFromContext ahead of the event's own
// start time and tags. Caller-supplied options (for example a parent
// reference) therefore take effect, while the per-event start time and tag
// values win if the caller sets the same ones.
//
// If no span can be created, the failure is logged and nil is returned. As
// written, the method never returns a non-nil error.
func (event *TraceEvent) Publish(ctx context.Context, lvl tracer.Level, opts ...opentracing.StartSpanOption) error {
	tags := opentracing.Tags{
		"phase":                event.Phase,
		"process_id":           event.ProcessID,
		"thread_id":            event.ThreadID,
		"layer_sequence_index": event.Index,
		"shape":                event.Shape,
		"allocated_memory":     event.AllocatedMemory,
		"peak_memory":          event.PeakMemory,
	}

	// Caller options first, then the event's own, so per-event values are
	// applied last and take precedence.
	spanOpts := make([]opentracing.StartSpanOption, 0, len(opts)+2)
	spanOpts = append(spanOpts, opts...)
	spanOpts = append(spanOpts, opentracing.StartTime(event.StartTime), tags)

	// The derived context is not needed: each event is published as an
	// independent span rather than as a parent of later events.
	s, _ := tracer.StartSpanFromContext(ctx, lvl, event.Name, spanOpts...)
	if s == nil {
		log.WithField("event_name", event.Name).
			WithField("tags", tags).
			Error("failed to create span from context")
		return nil
	}
	s.FinishWithOptions(opentracing.FinishOptions{
		FinishTime: event.EndTime,
	})
	return nil
}
// Publish publishes every event of the trace, in slice order, via
// TraceEvent.Publish with the same ctx, lvl and opts. It stops at, and
// returns, the first non-nil error; otherwise it returns nil.
func (t *Trace) Publish(ctx context.Context, lvl tracer.Level, opts ...opentracing.StartSpanOption) error {
	events := t.TraceEvents
	for i := range events {
		// Index the slice directly instead of copying each TraceEvent into a
		// loop variable.
		if err := events[i].Publish(ctx, lvl, opts...); err != nil {
			return err
		}
	}
	return nil
}