From 44b1948d0bd7f446f0f2107cf3bd0d32fe5b109d Mon Sep 17 00:00:00 2001 From: Geoff Wilson Date: Fri, 12 Jul 2024 15:13:02 -0400 Subject: [PATCH] Add images names to project/params --- pkg/api/aim/api/response/project.go | 8 +++++- pkg/api/aim/dao/models/project.go | 1 + pkg/api/aim/dao/repositories/artifact.go | 31 ++++++++++++++++++++++++ pkg/api/aim/services/project/service.go | 12 +++++++++ pkg/server/server.go | 1 + 5 files changed, 52 insertions(+), 1 deletion(-) diff --git a/pkg/api/aim/api/response/project.go b/pkg/api/aim/api/response/project.go index 401181d78..02d3ff44f 100644 --- a/pkg/api/aim/api/response/project.go +++ b/pkg/api/aim/api/response/project.go @@ -101,6 +101,12 @@ func NewProjectParamsResponse(projectParams *models.ProjectParams, } } + // process images + images := make(fiber.Map, len(projectParams.Images)) + for _, imageName := range projectParams.Images { + images[imageName] = []fiber.Map{} + } + rsp := ProjectParamsResponse{} if !excludeParams { rsp.Params = ¶ms @@ -118,7 +124,7 @@ func NewProjectParamsResponse(projectParams *models.ProjectParams, for _, s := range sequences { switch s { case "images": - rsp.Images = &fiber.Map{"images": fiber.Map{}} + rsp.Images = &images case "texts": rsp.Texts = &fiber.Map{} case "figures": diff --git a/pkg/api/aim/dao/models/project.go b/pkg/api/aim/dao/models/project.go index c73d71053..9df291b09 100644 --- a/pkg/api/aim/dao/models/project.go +++ b/pkg/api/aim/dao/models/project.go @@ -14,4 +14,5 @@ type ProjectParams struct { Metrics []LatestMetric TagKeys []string ParamKeys []string + Images []string } diff --git a/pkg/api/aim/dao/repositories/artifact.go b/pkg/api/aim/dao/repositories/artifact.go index 09a7af26e..83a804881 100644 --- a/pkg/api/aim/dao/repositories/artifact.go +++ b/pkg/api/aim/dao/repositories/artifact.go @@ -44,6 +44,9 @@ type ArtifactRepositoryProvider interface { timeZoneOffset int, req request.SearchArtifactsRequest, ) (*sql.Rows, int64, ArtifactSearchSummary, error) + GetArtifactNamesByExperiments( + ctx context.Context, namespaceID uint, experiments []int, + ) ([]string, error) } // ArtifactRepository repository to work with `artifact` entity. @@ -129,3 +132,31 @@ func (r ArtifactRepository) Search( return rows, int64(len(runIDs)), resultSummary, nil } + +// GetArtifactNamesByExperiments will find image names in the selected experiments. +func (r ArtifactRepository) GetArtifactNamesByExperiments( + ctx context.Context, namespaceID uint, experiments []int, +) ([]string, error) { + runIDs := []string{} + if err := r.GetDB().WithContext(ctx). + Table("runs"). + Joins(`INNER JOIN experiments + ON experiments.experiment_id = runs.experiment_id + AND experiments.namespace_id = ? + AND experiments.id IN ?`, + namespaceID, experiments, + ). + Find(&runIDs).Error; err != nil { + return nil, eris.Wrap(err, "error finding runs for artifacts") + } + + imageNames := []string{} + if err := r.GetDB().WithContext(ctx). + Distinct("name"). + Table("artifacts"). + Where("run_uuid IN ?", runIDs). + Find(&imageNames).Error; err != nil { + return nil, eris.Wrap(err, "error finding runs for artifact search") + } + return imageNames, nil +} diff --git a/pkg/api/aim/services/project/service.go b/pkg/api/aim/services/project/service.go index 344303b1e..0397859dc 100644 --- a/pkg/api/aim/services/project/service.go +++ b/pkg/api/aim/services/project/service.go @@ -18,6 +18,7 @@ type Service struct { paramRepository repositories.ParamRepositoryProvider metricRepository repositories.MetricRepositoryProvider experimentRepository repositories.ExperimentRepositoryProvider + artifactRepository repositories.ArtifactRepositoryProvider liveUpdatesEnabled bool } @@ -28,6 +29,7 @@ func NewService( paramRepository repositories.ParamRepositoryProvider, metricRepository repositories.MetricRepositoryProvider, experimentRepository repositories.ExperimentRepositoryProvider, + artifactRepository repositories.ArtifactRepositoryProvider, liveUpdatesEnabled bool, ) *Service { return &Service{ @@ -113,5 +115,15 @@ func (s Service) GetProjectParams( } projectParams.Metrics = metrics } + if slices.Contains(req.Sequences, "images") { + // fetch images available for requested Experiments. + images, err := s.artifactRepository.GetArtifactNamesByExperiments( + ctx, namespaceID, req.Experiments, + ) + if err != nil { + return nil, api.NewInternalError("error getting images: %s", err) + } + projectParams.Images = images + } return &projectParams, nil } diff --git a/pkg/server/server.go b/pkg/server/server.go index 21bad64c3..4317fc2cb 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -290,6 +290,7 @@ func createApp( aimRepositories.NewParamRepository(db.GormDB()), aimRepositories.NewMetricRepository(db.GormDB()), aimRepositories.NewExperimentRepository(db.GormDB()), + aimRepositories.NewArtifactRepository(db.GormDB()), config.LiveUpdatesEnabled, ), aimDashboardService.NewService(