Skip to content

Commit

Permalink
Update to response summary and test
Browse files Browse the repository at this point in the history
  • Loading branch information
suprjinx committed Jul 30, 2024
1 parent 52608b5 commit 82bac3a
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 65 deletions.
26 changes: 13 additions & 13 deletions pkg/api/aim/api/response/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ func NewStreamMetricsResponse(ctx *fiber.Ctx, rows *sql.Rows, totalRuns int64,
//
//nolint:gocyclo
func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, runs map[string]models.Run,
result repositories.ArtifactSearchSummary, req request.SearchArtifactsRequest,
summary repositories.ArtifactSearchSummary, req request.SearchArtifactsRequest,
) {
ctx.Context().SetBodyStreamWriter(func(w *bufio.Writer) {
//nolint:errcheck
Expand Down Expand Up @@ -547,18 +547,18 @@ func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, runs map[string]
return w.Flush()
}
addImage := func(img models.Artifact, run models.Run) {
imagesPerStep := result.StepImageCount(img.RunID, img.Name, 0)
totalSteps := result.TotalSteps(img.RunID, img.Name)
maxIndex := summary.MaxIndex(img.RunID, img.Name)
maxStep := summary.MaxStep(img.RunID, img.Name)
if runData == nil {
runData = fiber.Map{
"ranges": fiber.Map{
"record_range_total": []int{0, totalSteps},
"record_range_total": []int{0, maxStep},
"record_range_used": []int{0, int(img.Step)},
"index_range_total": []int{0, imagesPerStep},
"index_range_total": []int{0, maxIndex},
"index_range_used": []int{0, int(img.Index)},
},
"params": fiber.Map{
"images_per_step": imagesPerStep,
"images_per_step": maxIndex,
},
"props": renderProps(run),
}
Expand All @@ -575,18 +575,13 @@ func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, runs map[string]
}
traceValues, ok := trace["values"].([][]fiber.Map)
if !ok {
stepsSlice := make([][]fiber.Map, totalSteps)
stepsSlice := make([][]fiber.Map, maxStep+1)
traceValues = stepsSlice
}

stepImages := traceValues[img.Step]
if stepImages == nil {
stepImages = []fiber.Map{}
}

iters, ok := trace["iters"].([]int64)
if !ok {
iters = make([]int64, totalSteps)
iters = make([]int64, maxStep+1)
}
value := fiber.Map{
"blob_uri": img.BlobURI,
Expand All @@ -598,6 +593,11 @@ func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, runs map[string]
"index": img.Index,
"step": img.Step,
}

stepImages := traceValues[img.Step]
if stepImages == nil {
stepImages = []fiber.Map{}
}
stepImages = append(stepImages, value)
traceValues[img.Step] = stepImages
iters[img.Step] = img.Iter // TODO maybe not correct
Expand Down
41 changes: 34 additions & 7 deletions pkg/api/aim/dao/repositories/artifact.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,34 @@ type ArtifactSearchStepInfo struct {
Name string `gorm:"column:name"`
Step int `gorm:"column:step"`
ImgCount int `gorm:"column:img_count"`
MaxIndex int `gorm:"column:max_index"`
}

// ArtifactSearchSummary is a search summary for run and name.
type ArtifactSearchSummary map[string]map[string][]ArtifactSearchStepInfo

// TotalSteps figures out how many steps belong to the runID.
func (r ArtifactSearchSummary) TotalSteps(runID, name string) int {
return len(r[runID][name])
// MaxStep figures out the max step belonging to the runID and sequence name.
func (r ArtifactSearchSummary) MaxStep(runID, name string) int {
runSequence := r[runID][name]
maxStep := 0
for _, step := range runSequence {
if step.Step > maxStep {
maxStep = step.Step
}
}
return maxStep
}

// MaxIndex figures out the maximum index for the runID and sequence name.
func (r ArtifactSearchSummary) MaxIndex(runID, name string) int {
runSequence := r[runID][name]
maxIndex := 0
for _, step := range runSequence {
if step.MaxIndex > maxIndex {
maxIndex = step.MaxIndex
}
}
return maxIndex
}

// StepImageCount figures out how many steps belong to the runID and step.
Expand Down Expand Up @@ -114,11 +134,11 @@ func (r ArtifactRepository) Search(
// collect some summary data for progress indicator
stepInfo := []ArtifactSearchStepInfo{}
if tx := r.GetDB().WithContext(ctx).
Raw(`SELECT run_uuid, name, step, count(id) as img_count
Raw(`SELECT run_uuid, name, step, count(id) as img_count, max("index") as max_index
FROM artifacts
WHERE run_uuid IN (?)
AND step BETWEEN ? AND ?
AND "index" BETWEEN ? AND ?
AND "index" BETWEEN ? AND ?
GROUP BY run_uuid, name, step;`,
runIDs,
req.RecordRangeMin(),
Expand Down Expand Up @@ -163,9 +183,16 @@ func (r ArtifactRepository) Search(
) rows USING (id)
WHERE run_uuid IN ?
AND row_num % ? = 0
AND step BETWEEN ? AND ?
AND "index" BETWEEN ? AND ?
ORDER BY run_uuid, name, step
`, runIDs, interval)

`,
runIDs,
interval,
req.RecordRangeMin(),
req.RecordRangeMax(),
req.IndexRangeMin(),
req.IndexRangeMax())
rows, err := tx.Rows()
if err != nil {
return nil, nil, nil, eris.Wrap(err, "error searching artifacts")
Expand Down
122 changes: 77 additions & 45 deletions tests/integration/golang/aim/run/search_artifacts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,22 @@ func (s *SearchArtifactsTestSuite) Test_Ok() {
})
s.Require().Nil(err)
for i := 0; i < 5; i++ {
_, err = s.ArtifactFixtures.CreateArtifact(context.Background(), &models.Artifact{
ID: uuid.New(),
Name: "some-name",
RunID: run1.ID,
BlobURI: "path/filename.png",
Step: int64(i),
Iter: 1,
Index: 1,
Caption: "caption1",
Format: "png",
Width: 100,
Height: 100,
})
s.Require().Nil(err)
for j := 0; j < 5; j++ {
_, err = s.ArtifactFixtures.CreateArtifact(context.Background(), &models.Artifact{
ID: uuid.New(),
Name: "some-name",
RunID: run1.ID,
BlobURI: "path/filename.png",
Step: int64(i),
Iter: 1,
Index: int64(j),
Caption: "caption1",
Format: "png",
Width: 100,
Height: 100,
})
s.Require().Nil(err)
}
}

run2, err := s.RunFixtures.CreateRun(context.Background(), &models.Run{
Expand All @@ -92,46 +94,76 @@ func (s *SearchArtifactsTestSuite) Test_Ok() {
})
s.Require().Nil(err)
for i := 0; i < 5; i++ {
_, err = s.ArtifactFixtures.CreateArtifact(context.Background(), &models.Artifact{
ID: uuid.New(),
Name: "other-name",
RunID: run2.ID,
BlobURI: "path/filename.png",
Step: int64(i),
Iter: 1,
Index: 1,
Caption: "caption2",
Format: "png",
Width: 100,
Height: 100,
})
s.Require().Nil(err)
for j := 0; j < 5; j++ {
_, err = s.ArtifactFixtures.CreateArtifact(context.Background(), &models.Artifact{
ID: uuid.New(),
Name: "other-name",
RunID: run2.ID,
BlobURI: "path/filename.png",
Step: int64(i),
Iter: 1,
Index: int64(j),
Caption: "caption2",
Format: "png",
Width: 100,
Height: 100,
})
s.Require().Nil(err)
}
}

tests := []struct {
name string
request request.SearchArtifactsRequest
includedRuns []*models.Run
excludedRuns []*models.Run
expectedRecordRange int64
expectedIndexRange int64
name string
request request.SearchArtifactsRequest
includedRuns []*models.Run
excludedRuns []*models.Run
expectedRecordRangeMax int64
expectedIndexRangeMax int64
expectedTraceLength int
}{
{
name: "SearchArtifact",
request: request.SearchArtifactsRequest{},
includedRuns: []*models.Run{run1, run2},
expectedRecordRange: 5,
expectedIndexRange: 1,
name: "SearchArtifact",
request: request.SearchArtifactsRequest{},
includedRuns: []*models.Run{run1, run2},
expectedRecordRangeMax: 4,
expectedIndexRangeMax: 4,
},
{
name: "SearchArtifactWithNameQuery",
request: request.SearchArtifactsRequest{
Query: `((images.name == "some-name"))`,
},
includedRuns: []*models.Run{run1},
excludedRuns: []*models.Run{run2},
expectedRecordRange: 5,
expectedIndexRange: 1,
includedRuns: []*models.Run{run1},
excludedRuns: []*models.Run{run2},
expectedRecordRangeMax: 4,
expectedIndexRangeMax: 4,
},
{
name: "SearchArtifactWithRecordRange",
request: request.SearchArtifactsRequest{
RecordRange: "0:2",
},
includedRuns: []*models.Run{run1, run2},
expectedRecordRangeMax: 2,
expectedIndexRangeMax: 4,
},
{
name: "SearchArtifactWithIndexRange",
request: request.SearchArtifactsRequest{
IndexRange: "0:2",
},
includedRuns: []*models.Run{run1, run2},
expectedRecordRangeMax: 4,
expectedIndexRangeMax: 2,
},
{
name: "SearchArtifactWithRecordDensity",
request: request.SearchArtifactsRequest{
RecordDensity: 2,
},
includedRuns: []*models.Run{run1, run2},
expectedRecordRangeMax: 4,
expectedIndexRangeMax: 4,
},
}
for _, tt := range tests {
Expand All @@ -158,12 +190,12 @@ func (s *SearchArtifactsTestSuite) Test_Ok() {
valuesIndex := 0
rangesPrefix := fmt.Sprintf("%v.ranges", run.ID)
recordRangeKey := rangesPrefix + ".record_range_total.1"
s.Equal(tt.expectedRecordRange, decodedData[recordRangeKey])
s.Equal(tt.expectedRecordRangeMax, decodedData[recordRangeKey])
propsPrefix := fmt.Sprintf("%v.props", run.ID)
artifactLocation := propsPrefix + ".experiment.artifact_location"
s.Equal(experiment.ArtifactLocation, decodedData[artifactLocation])
indexRangeKey := rangesPrefix + ".index_range_total.1"
s.Equal(tt.expectedIndexRange, decodedData[indexRangeKey])
s.Equal(tt.expectedIndexRangeMax, decodedData[indexRangeKey])
tracesPrefix := fmt.Sprintf("%v.traces.%d", run.ID, traceIndex)
valuesPrefix := fmt.Sprintf(".values.%d.%d", valuesIndex, imgIndex)
blobUriKey := tracesPrefix + valuesPrefix + ".blob_uri"
Expand Down

0 comments on commit 82bac3a

Please sign in to comment.