Skip to content

Commit

Permalink
queries for summary and image cursor
Browse files Browse the repository at this point in the history
  • Loading branch information
suprjinx committed Jul 5, 2024
1 parent 9ea0438 commit 76950c8
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 60 deletions.
39 changes: 14 additions & 25 deletions pkg/api/aim/api/response/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ func NewStreamMetricsResponse(ctx *fiber.Ctx, rows *sql.Rows, totalRuns int64,
//
//nolint:gocyclo
func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, totalRuns int64,
result repositories.SearchResultMap, req request.SearchArtifactsRequest,
result repositories.ImageSearchSummary, req request.SearchArtifactsRequest,
) {
ctx.Context().SetBodyStreamWriter(func(w *bufio.Writer) {
//nolint:errcheck
Expand All @@ -487,38 +487,32 @@ func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, totalRuns int64,

if err := func() error {
var (
runID string
runData fiber.Map
values []float64
iters []float64
epochs []float64
timestamps []float64
progress int
runID string
step int
index int
runData fiber.Map
)
reportProgress := func(cur int64) error {
if !req.ReportProgress {
return nil
}
err := encoding.EncodeTree(w, fiber.Map{
fmt.Sprintf("progress_%d", progress): []int64{cur, totalRuns},
fmt.Sprintf("progress_%d", cur): []int64{cur, totalRuns},
})
if err != nil {
return err
}
progress++
cur++
return w.Flush()
}
addImage := func(img models.Artifact) {
if runData == nil {
runData = fiber.Map{
"ranges": fiber.Map{
"record_range_total": []int{0, 0},
"record_range_used": []int{0, 0},
"index_range_total": []int{0, 0},
"index_range_used": []int{0, 0},
},
"params": fiber.Map{
"images_per_step": 0,
"record_range_total": []int{0, result.TotalSteps(img.RunID)},
"record_range_used": []int{0, step},
"index_range_total": []int{0, result.StepImageCount(img.RunID, step)},
"index_range_used": []int{0, index},
},
"traces": []fiber.Map{},
}
Expand All @@ -534,7 +528,7 @@ func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, totalRuns int64,
}); err != nil {
return err
}
if err := reportProgress(totalRuns - 1); err != nil {
if err := reportProgress(); err != nil {

Check failure on line 531 in pkg/api/aim/api/response/run.go

View workflow job for this annotation

GitHub Actions / Build (darwin/amd64)

not enough arguments in call to reportProgress

Check failure on line 531 in pkg/api/aim/api/response/run.go

View workflow job for this annotation

GitHub Actions / Go Unit Tests

not enough arguments in call to reportProgress

Check failure on line 531 in pkg/api/aim/api/response/run.go

View workflow job for this annotation

GitHub Actions / Build (darwin/arm64)

not enough arguments in call to reportProgress

Check failure on line 531 in pkg/api/aim/api/response/run.go

View workflow job for this annotation

GitHub Actions / Build (linux/amd64)

not enough arguments in call to reportProgress

Check failure on line 531 in pkg/api/aim/api/response/run.go

View workflow job for this annotation

GitHub Actions / Python Integration Tests (aim)

not enough arguments in call to reportProgress

Check failure on line 531 in pkg/api/aim/api/response/run.go

View workflow job for this annotation

GitHub Actions / Python Integration Tests (mlflow)

not enough arguments in call to reportProgress

Check failure on line 531 in pkg/api/aim/api/response/run.go

View workflow job for this annotation

GitHub Actions / Build (windows/amd64)

not enough arguments in call to reportProgress

Check failure on line 531 in pkg/api/aim/api/response/run.go

View workflow job for this annotation

GitHub Actions / Python Integration Tests (fml_client)

not enough arguments in call to reportProgress
return err
}
return w.Flush()
Expand All @@ -545,17 +539,12 @@ func NewStreamArtifactsResponse(ctx *fiber.Ctx, rows *sql.Rows, totalRuns int64,
return err
}

// flush after each change in runID
// (assumes order by runID)
if image.RunID != runID {
if err := flushImages(); err != nil {
return err
}

if err := encoding.EncodeTree(w, fiber.Map{
image.RunID: result[image.RunID].Info,
}); err != nil {
return err
}

runID = image.RunID
runData = nil
}
Expand Down
80 changes: 49 additions & 31 deletions pkg/api/aim/dao/repositories/artifact.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,27 @@ import (
"github.com/rotisserie/eris"
)

// ImageSearchStepInfo is a search summary for a single Run step: it carries
// the run identifier, the step number and the number of images counted for
// that (run, step) pair.
type ImageSearchStepInfo struct {
	RunUUID  string `gorm:"column:run_uuid"`
	Step     int    `gorm:"column:step"`
	ImgCount int    `gorm:"column:count"`
}

// ImageSearchSummary is a search summary for whole runs. It maps a run ID to
// the per-step summaries collected for that run.
type ImageSearchSummary map[string][]ImageSearchStepInfo

// TotalSteps returns how many step summaries belong to the runID.
// An unknown runID yields 0.
func (r ImageSearchSummary) TotalSteps(runID string) int {
	return len(r[runID])
}

// StepImageCount returns how many images belong to the runID at the given step.
//
// The entry is located by its Step value rather than by slice position: the
// per-run slice holds one element per step that actually has images, so its
// positions do not have to line up with step numbers (steps may be sparse or
// arrive in any order). Indexing by position would panic for an unknown run
// or a step number beyond the slice length. When no entry matches, the count
// is 0.
func (r ImageSearchSummary) StepImageCount(runID string, step int) int {
	for _, info := range r[runID] {
		if info.Step == step {
			return info.ImgCount
		}
	}
	return 0
}

// ArtifactRepositoryProvider provides an interface to work with `artifact` entity.
type ArtifactRepositoryProvider interface {
repositories.BaseRepositoryProvider
Expand All @@ -22,7 +43,7 @@ type ArtifactRepositoryProvider interface {
namespaceID uint,
timeZoneOffset int,
req request.SearchArtifactsRequest,
) (*sql.Rows, int64, SearchResultMap, error)
) (*sql.Rows, int64, ImageSearchSummary, error)
}

// ArtifactRepository repository to work with `artifact` entity.
Expand All @@ -43,7 +64,7 @@ func (r ArtifactRepository) Search(
namespaceID uint,
timeZoneOffset int,
req request.SearchArtifactsRequest,
) (*sql.Rows, int64, SearchResultMap, error) {
) (*sql.Rows, int64, ImageSearchSummary, error) {
qp := query.QueryParser{
Default: query.DefaultExpression{
Contains: "run.archived",
Expand All @@ -66,42 +87,39 @@ func (r ArtifactRepository) Search(
return nil, 0, nil, eris.Wrap(err, "error counting metrics")
}

var runs []models.Run
if tx := r.GetDB().WithContext(ctx).
InnerJoins(
"Experiment",
r.GetDB().WithContext(ctx).Select(
"ID", "Name",
).Where(&models.Experiment{NamespaceID: namespaceID}),
).
Preload("Params").
Preload("Tags").
Where("run_uuid IN (?)", pq.Filter(r.GetDB().WithContext(ctx).
Select("runs.run_uuid").
Table("runs").
Joins(
"INNER JOIN experiments ON experiments.experiment_id = runs.experiment_id AND experiments.namespace_id = ?",
namespaceID,
),
runIDs := []string{}
if tx := pq.Filter(r.GetDB().WithContext(ctx).
Select("runs.run_uuid").
Table("runs").
Joins(
"INNER JOIN experiments ON experiments.experiment_id = runs.experiment_id AND experiments.namespace_id = ?",
namespaceID,
)).
Order("runs.row_num DESC").
Find(&runs); tx.Error != nil {
return nil, 0, nil, eris.Wrap(err, "error searching artifacts")
Find(&runIDs); tx.Error != nil {
return nil, 0, nil, eris.Wrap(err, "error finding runs for artifact search")
}

runIDs := []string{}
for _, run := range runs {
runIDs = append(runIDs, run.ID)
resultSummaries := []ImageSearchStepInfo{}
if tx := r.GetDB().WithContext(ctx).
Raw(`SELECT run_uuid, step, count(id)
FROM artifacts
WHERE run_uuid IN (?)
GROUP BY run_uuid, step;`, runIDs).
Find(&resultSummaries); tx.Error != nil {
return nil, 0, nil, eris.Wrap(err, "error find result summary for artifact search")
}

runImages := make(ImageSearchSummary, len(runIDs))
for _, rslt := range resultSummaries {
runImages[rslt.RunUUID] = append(runImages[rslt.RunUUID], rslt)
}
result := make(SearchResultMap, len(runs))

tx := r.GetDB().WithContext(ctx).
Select(`row_number() over (order by run_uuid, step, created_at) as row_num, *`).
Table("artifacts").
Where("run_uuid IN ?", runIDs).
Order("metrics.run_uuid").
Order("metrics.step").
Order("metrics.created_at")
Order("run_uuid").
Order("step").
Order("created_at")

rows, err := tx.Rows()
if err != nil {
Expand All @@ -111,5 +129,5 @@ func (r ArtifactRepository) Search(
return nil, 0, nil, eris.Wrap(err, "error getting artifacts rows cursor")
}

return rows, int64(len(runs)), result, nil
return rows, int64(len(runIDs)), runImages, nil
}
8 changes: 4 additions & 4 deletions pkg/api/aim/services/run/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,12 @@ func (s Service) SearchMetrics(
// SearchArtifacts returns the list of artifacts (images) by provided search criteria.
func (s Service) SearchArtifacts(
ctx context.Context, namespaceID uint, timeZoneOffset int, req request.SearchArtifactsRequest,
) (*sql.Rows, int64, repositories.SearchResultMap, error) {
rows, total, searchResult, err := s.artifactRepository.Search(ctx, namespaceID, timeZoneOffset, req)
) (*sql.Rows, int64, repositories.ImageSearchSummary, error) {
rows, total, result, err := s.artifactRepository.Search(ctx, namespaceID, timeZoneOffset, req)
if err != nil {
return nil, 0, nil, api.NewInternalError("error searching runs: %s", err)
return nil, 0, nil, api.NewInternalError("error searching artifacts: %s", err)
}
return rows, total, searchResult, nil
return rows, total, result, nil
}

// SearchAlignedMetrics returns the list of aligned metrics.
Expand Down

0 comments on commit 76950c8

Please sign in to comment.