Skip to content

Commit

Permalink
Various fixes for multiple value columns
Browse files Browse the repository at this point in the history
  • Loading branch information
suprjinx committed Apr 2, 2024
1 parent 9f286b2 commit 5ca50e6
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 124 deletions.
8 changes: 4 additions & 4 deletions pkg/api/aim2/api/response/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func NewGetRunInfoResponse(run *models.Run) *GetRunInfoResponse {

params := make(GetRunInfoParamsPartial, len(run.Params)+1)
for _, p := range run.Params {
params[p.Key] = p.ValueTyped()
params[p.Key] = p.ValueAny()
}
tags := make(GetRunInfoParamsPartial, len(run.Tags))
for _, t := range run.Tags {
Expand Down Expand Up @@ -508,10 +508,10 @@ func NewRunsSearchCSVResponse(ctx *fiber.Ctx, runs []models.Run, excludeTraces,
if !excludeParams {
for _, param := range run.Params {
if _, ok := paramData[param.Key]; ok {
paramData[param.Key][run.ID] = param.Value
paramData[param.Key][run.ID] = param.ValueString()
} else {
paramKeys = append(paramKeys, param.Key)
paramData[param.Key] = map[string]string{run.ID: param.Value}
paramData[param.Key] = map[string]string{run.ID: param.ValueString()}
}
}
for _, tag := range run.Tags {
Expand Down Expand Up @@ -679,7 +679,7 @@ func NewRunsSearchStreamResponse(
if !excludeParams {
params := make(fiber.Map, len(r.Params)+1)
for _, p := range r.Params {
params[p.Key] = p.Value
params[p.Key] = p.ValueAny()
}
tags := make(map[string]string, len(r.Tags))
for _, t := range r.Tags {
Expand Down
32 changes: 20 additions & 12 deletions pkg/api/aim2/dao/models/param.go
Original file line number Diff line number Diff line change
@@ -1,34 +1,42 @@
package models

import "fmt"
import (
"fmt"
)

// Param represents model to work with `params` table.
type Param struct {
Key string `gorm:"type:varchar(250);not null;primaryKey"`
Value string `gorm:"type:varchar(500);not null"`
ValueStr *string `gorm:"type:varchar(500)"`
ValueInt *int64 `gorm:"type:bigint"`
ValueFloat *float64 `gorm:"type:float"`
RunID string `gorm:"column:run_uuid;not null;primaryKey;index"`
}

// ValueString returns the value held by this Param as a string
// Value returns the value held by this Param as a string.
func (p Param) ValueString() string {
if p.ValueInt != nil {
switch {
case p.ValueInt != nil:
return fmt.Sprintf("%v", *p.ValueInt)
} else if p.ValueFloat != nil {
case p.ValueFloat != nil:
return fmt.Sprintf("%v", *p.ValueFloat)
} else {
return p.Value
case p.ValueStr != nil:
return *p.ValueStr
default:
return ""
}
}

// ValueAny returns the value held by this Param as any
// ValueAny returns the value held by this Param as any with underlying type.
func (p Param) ValueAny() any {
if p.ValueInt != nil {
switch {
case p.ValueInt != nil:
return *p.ValueInt
} else if p.ValueFloat != nil {
case p.ValueFloat != nil:
return *p.ValueFloat
} else {
return p.Value
case p.ValueStr != nil:
return *p.ValueStr
default:
return nil
}
}
83 changes: 0 additions & 83 deletions pkg/api/aim2/dao/models/param_benchmark.go

This file was deleted.

4 changes: 2 additions & 2 deletions pkg/api/aim2/dao/models/param_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/assert"
)

func TestValueTyped(t *testing.T) {
func TestValueAny(t *testing.T) {
tests := []struct {
name string
param Param
Expand Down Expand Up @@ -36,7 +36,7 @@ func TestValueTyped(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
valueTyped := tt.param.ValueTyped()
valueTyped := tt.param.ValueAny()
assert.Equal(t, tt.want, valueTyped)
assert.IsType(t, tt.want, valueTyped)
})
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/mlflow/api/response/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func NewRunPartialResponse(run *models.Run) *RunPartialResponse {
for n, p := range run.Params {
params[n] = RunParamPartialResponse{
Key: p.Key,
Value: p.ValueTyped(),
Value: p.ValueAny(),
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/api/mlflow/dao/convertors/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func convertParam(runID, key string, value any) *models.Param {
}
switch v := value.(type) {
case string:
param.Value = v
param.ValueStr = &v
case int64:
param.ValueInt = &v
case float64:
Expand Down
27 changes: 14 additions & 13 deletions pkg/api/mlflow/dao/convertors/log_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/api/request"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/common"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
)

Expand All @@ -19,7 +20,7 @@ func TestConvertLogParamRequestToDBModel_Ok(t *testing.T) {
}
result := ConvertLogParamRequestToDBModel("run_id", &req)
assert.Equal(t, "key", result.Key)
assert.Equal(t, "value", result.Value)
assert.Equal(t, "value", *result.ValueStr)
assert.Equal(t, "run_id", result.RunID)
}

Expand Down Expand Up @@ -62,9 +63,9 @@ func TestConvertLogBatchRequestToDBModel_Ok(t *testing.T) {
},
expectedParams: []models.Param{
{
RunID: "run_id",
Key: "key",
Value: "value",
RunID: "run_id",
Key: "key",
ValueStr: common.GetPointer[string]("value"),
},
},
expectedMetrics: []models.Metric{
Expand Down Expand Up @@ -109,9 +110,9 @@ func TestConvertLogBatchRequestToDBModel_Ok(t *testing.T) {
},
expectedParams: []models.Param{
{
RunID: "run_id",
Key: "key",
Value: "value",
RunID: "run_id",
Key: "key",
ValueStr: common.GetPointer[string]("value"),
},
},
expectedMetrics: []models.Metric{
Expand Down Expand Up @@ -157,9 +158,9 @@ func TestConvertLogBatchRequestToDBModel_Ok(t *testing.T) {
},
expectedParams: []models.Param{
{
RunID: "run_id",
Key: "key",
Value: "value",
RunID: "run_id",
Key: "key",
ValueStr: common.GetPointer[string]("value"),
},
},
expectedMetrics: []models.Metric{
Expand Down Expand Up @@ -204,9 +205,9 @@ func TestConvertLogBatchRequestToDBModel_Ok(t *testing.T) {
},
expectedParams: []models.Param{
{
RunID: "run_id",
Key: "key",
Value: "value",
RunID: "run_id",
Key: "key",
ValueStr: common.GetPointer[string]("value"),
},
},
expectedMetrics: []models.Metric{
Expand Down
33 changes: 25 additions & 8 deletions pkg/api/mlflow/dao/models/param.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,36 @@ import "fmt"
// Param represents model to work with `params` table.
type Param struct {
Key string `gorm:"type:varchar(250);not null;primaryKey"`
Value string `gorm:"type:varchar(500);not null"`
ValueStr *string `gorm:"type:varchar(500)"`
ValueInt *int64 `gorm:"type:bigint"`
ValueFloat *float64 `gorm:"type:float"`
RunID string `gorm:"column:run_uuid;not null;primaryKey;index"`
}

// ValueString returns the value held by this Param as a string
// Value returns the value held by this Param as a string.
func (p Param) ValueString() string {
if p.ValueInt != nil {
return fmt.Sprintf("%v", p.ValueInt)
} else if p.ValueFloat != nil {
return fmt.Sprintf("%v", p.ValueFloat)
} else {
return p.Value
switch {
case p.ValueInt != nil:
return fmt.Sprintf("%v", *p.ValueInt)
case p.ValueFloat != nil:
return fmt.Sprintf("%v", *p.ValueFloat)
case p.ValueStr != nil:
return *p.ValueStr
default:
return ""
}
}

// ValueAny returns the value held by this Param as any with underlying type.
func (p Param) ValueAny() any {
switch {
case p.ValueInt != nil:
return *p.ValueInt
case p.ValueFloat != nil:
return *p.ValueFloat
case p.ValueStr != nil:
return *p.ValueStr
default:
return nil
}
}

0 comments on commit 5ca50e6

Please sign in to comment.