diff --git a/main.go b/main.go index 77aa25bc..ff246121 100644 --- a/main.go +++ b/main.go @@ -200,7 +200,7 @@ func infoers() map[string]productinfo.ProductInfoer { switch p { case recommender.Ec2: - infoer, err = ec2.NewEc2Infoer(ec2.NewPricing(ec2.NewConfig()), viper.GetString(prometheusAddressFlag), viper.GetString(prometheusQueryFlag)) + infoer, err = ec2.NewEc2Infoer(viper.GetString(prometheusAddressFlag), viper.GetString(prometheusQueryFlag)) case recommender.Gce: infoer, err = gce.NewGceInfoer(viper.GetString(gceApiKeyFlag), viper.GetString(gceProjectIdFlag)) case recommender.Azure: diff --git a/productinfo/ec2/productinfo_ec2.go b/productinfo/ec2/productinfo_ec2.go index de1b684f..3158067d 100644 --- a/productinfo/ec2/productinfo_ec2.go +++ b/productinfo/ec2/productinfo_ec2.go @@ -37,26 +37,34 @@ type PricingSource interface { // Ec2Infoer encapsulates the data and operations needed to access external resources type Ec2Infoer struct { - pricing PricingSource - session *session.Session - prometheus v1.API - promQuery string + pricingSvc PricingSource + session *session.Session + prometheus v1.API + promQuery string + ec2Describer func(region string) Ec2Describer +} + +// Ec2Describer interface for operations describing EC2 artifacts. (a subset of the Ec2 cli operations iused by this app) +type Ec2Describer interface { + DescribeAvailabilityZones(input *ec2.DescribeAvailabilityZonesInput) (*ec2.DescribeAvailabilityZonesOutput, error) + DescribeSpotPriceHistoryPages(input *ec2.DescribeSpotPriceHistoryInput, fn func(*ec2.DescribeSpotPriceHistoryOutput, bool) bool) error } // NewEc2Infoer creates a new instance of the infoer -func NewEc2Infoer(pricing PricingSource, prom string, pq string) (*Ec2Infoer, error) { +func NewEc2Infoer(promAddr string, pq string) (*Ec2Infoer, error) { s, err := session.NewSession() + if err != nil { log.WithError(err).Error("Error creating AWS session") return nil, err } var promApi v1.API - if prom == "" { + if promAddr == "" { log.Warn("Prometheus API address is not set, fallback to direct API access.") promApi = nil } else { promClient, err := api.NewClient(api.Config{ - Address: prom, + Address: promAddr, }) if err != nil { log.WithError(err).Warn("Error creating Prometheus client, fallback to direct API access.") @@ -66,29 +74,21 @@ func NewEc2Infoer(pricing PricingSource, prom string, pq string) (*Ec2Infoer, er } } + const defaultPricingRegion = "us-east-1" return &Ec2Infoer{ - pricing: pricing, + pricingSvc: pricing.New(s, aws.NewConfig().WithRegion(defaultPricingRegion)), session: s, prometheus: promApi, promQuery: pq, + ec2Describer: func(region string) Ec2Describer { + return ec2.New(s, aws.NewConfig().WithRegion(region)) + }, }, nil } -// NewPricing creates a new PricingSource with the given configuration -func NewPricing(cfg *aws.Config) PricingSource { - - s, err := session.NewSession(cfg) - if err != nil { - log.Fatalf("could not create session. error: [%s]", err.Error()) - } - - pr := pricing.New(s, cfg) - return pr -} - // NewConfig creates a new Config instance and returns a pointer to it func NewConfig() *aws.Config { - return &aws.Config{Region: aws.String("us-east-1")} + return aws.NewConfig() } // Initialize is not needed on EC2 because price info is changing frequently @@ -99,7 +99,7 @@ func (e *Ec2Infoer) Initialize() (map[string]map[string]productinfo.Price, error // GetAttributeValues gets the AttributeValues for the given attribute name // Delegates to the underlying PricingSource instance and unifies (transforms) the response func (e *Ec2Infoer) GetAttributeValues(attribute string) (productinfo.AttrValues, error) { - apiValues, err := e.pricing.GetAttributeValues(e.newAttributeValuesInput(attribute)) + apiValues, err := e.pricingSvc.GetAttributeValues(e.newAttributeValuesInput(attribute)) if err != nil { return nil, err } @@ -126,7 +126,7 @@ func (e *Ec2Infoer) GetProducts(regionId string, attrKey string, attrValue produ var vms []productinfo.VmInfo log.Debugf("Getting available instance types from AWS API. [region=%s, %s=%s]", regionId, attrKey, attrValue.StrValue) - products, err := e.pricing.GetProducts(e.newGetProductsInput(regionId, attrKey, attrValue)) + products, err := e.pricingSvc.GetProducts(e.newGetProductsInput(regionId, attrKey, attrValue)) if err != nil { return nil, err @@ -328,14 +328,14 @@ func (e *Ec2Infoer) GetRegions() (map[string]string, error) { // GetZones returns the availability zones in a region func (e *Ec2Infoer) GetZones(region string) ([]string, error) { + var zones []string - ec2Svc := ec2.New(e.session, &aws.Config{Region: aws.String(region)}) - azs, err := ec2Svc.DescribeAvailabilityZones(&ec2.DescribeAvailabilityZonesInput{}) + azs, err := e.ec2Describer(region).DescribeAvailabilityZones(&ec2.DescribeAvailabilityZonesInput{}) if err != nil { return nil, err } for _, az := range azs.AvailabilityZones { - if *az.State == "available" { + if *az.State == ec2.AvailabilityZoneStateAvailable { zones = append(zones, *az.ZoneName) } } @@ -377,8 +377,7 @@ func (e *Ec2Infoer) getSpotPricesFromPrometheus(region string) (map[string]produ func (e *Ec2Infoer) getCurrentSpotPrices(region string) (map[string]productinfo.SpotPriceInfo, error) { priceInfo := make(map[string]productinfo.SpotPriceInfo) - ec2Svc := ec2.New(e.session, &aws.Config{Region: aws.String(region)}) - err := ec2Svc.DescribeSpotPriceHistoryPages(&ec2.DescribeSpotPriceHistoryInput{ + err := e.ec2Describer(region).DescribeSpotPriceHistoryPages(&ec2.DescribeSpotPriceHistoryInput{ StartTime: aws.Time(time.Now()), ProductDescriptions: []*string{aws.String("Linux/UNIX")}, }, func(history *ec2.DescribeSpotPriceHistoryOutput, lastPage bool) bool { diff --git a/productinfo/ec2/productinfo_ec2_test.go b/productinfo/ec2/productinfo_ec2_test.go index d0ea58de..e79cfe1f 100644 --- a/productinfo/ec2/productinfo_ec2_test.go +++ b/productinfo/ec2/productinfo_ec2_test.go @@ -6,16 +6,18 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/pricing" "github.com/banzaicloud/telescopes/productinfo" "github.com/stretchr/testify/assert" ) -type DummyPricingSource struct { +//testStruct helps to mock external calls +type testStruct struct { TcId int } -func (dps *DummyPricingSource) GetAttributeValues(input *pricing.GetAttributeValuesInput) (*pricing.GetAttributeValuesOutput, error) { +func (dps *testStruct) GetAttributeValues(input *pricing.GetAttributeValuesInput) (*pricing.GetAttributeValuesOutput, error) { // example json sequence //{ @@ -63,7 +65,7 @@ func (dps *DummyPricingSource) GetAttributeValues(input *pricing.GetAttributeVal return nil, nil } -func (dps *DummyPricingSource) GetProducts(input *pricing.GetProductsInput) (*pricing.GetProductsOutput, error) { +func (dps *testStruct) GetProducts(input *pricing.GetProductsInput) (*pricing.GetProductsOutput, error) { switch dps.TcId { case 4: return &pricing.GetProductsOutput{ @@ -71,7 +73,7 @@ func (dps *DummyPricingSource) GetProducts(input *pricing.GetProductsInput) (*pr { "product": map[string]interface{}{ "attributes": map[string]interface{}{ - "instanceType": "db.t2.small", + "instanceType": ec2.InstanceTypeT2Small, Cpu: "1", Memory: "2", "networkPerformance": "Low to Moderate", @@ -94,7 +96,7 @@ func (dps *DummyPricingSource) GetProducts(input *pricing.GetProductsInput) (*pr { "product": map[string]interface{}{ "attributes": map[string]interface{}{ - "instanceType": "db.t2.small", + "instanceType": ec2.InstanceTypeT2Small, Cpu: "1", Memory: "2", }}, @@ -113,7 +115,7 @@ func (dps *DummyPricingSource) GetProducts(input *pricing.GetProductsInput) (*pr { "product": map[string]interface{}{ "attributes": map[string]interface{}{ - "instanceType": "db.t2.small", + "instanceType": ec2.InstanceTypeT2Small, Cpu: "1", }}}, }, @@ -124,7 +126,7 @@ func (dps *DummyPricingSource) GetProducts(input *pricing.GetProductsInput) (*pr { "product": map[string]interface{}{ "attributes": map[string]interface{}{ - "instanceType": "db.t2.small", + "instanceType": ec2.InstanceTypeT2Small, }}}, }, }, nil @@ -143,10 +145,67 @@ func (dps *DummyPricingSource) GetProducts(input *pricing.GetProductsInput) (*pr } // strPointer gets the pointer to the passed string -func (dps *DummyPricingSource) strPointer(str string) *string { +func (dps *testStruct) strPointer(str string) *string { return &str } +func (dps *testStruct) DescribeAvailabilityZones(input *ec2.DescribeAvailabilityZonesInput) (*ec2.DescribeAvailabilityZonesOutput, error) { + if dps.TcId == 10 { + return nil, errors.New("could not get information about zones") + } + return &ec2.DescribeAvailabilityZonesOutput{ + AvailabilityZones: []*ec2.AvailabilityZone{ + { + State: aws.String(ec2.AvailabilityZoneStateAvailable), + RegionName: aws.String("eu-central-1"), + ZoneName: aws.String("eu-central-1a"), + }, + { + State: aws.String("available"), + RegionName: aws.String("eu-central-1"), + ZoneName: aws.String("eu-central-1b"), + }, + }, + }, nil +} + +func (dps *testStruct) DescribeSpotPriceHistoryPages(input *ec2.DescribeSpotPriceHistoryInput, fn func(*ec2.DescribeSpotPriceHistoryOutput, bool) bool) error { + if dps.TcId == 11 { + return errors.New("invalid") + } + return nil +} + +func TestNewEc2Infoer(t *testing.T) { + tests := []struct { + name string + prom string + check func(info *Ec2Infoer, err error) + }{ + { + name: "create Ec2Infoer - Prometheus API address is not set", + prom: "", + check: func(info *Ec2Infoer, err error) { + assert.Nil(t, err, "the error should be nil") + assert.NotNil(t, info, "the Ec2Infoer should not be nil") + }, + }, + { + name: "create Ec2Infoer - Prometheus API address is set", + prom: "PromAPIAddress", + check: func(info *Ec2Infoer, err error) { + assert.Nil(t, err, "the error should be nil") + assert.NotNil(t, info, "the Ec2Infoer should not be nil") + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.check(NewEc2Infoer(test.prom, "")) + }) + } +} + func TestEc2Infoer_GetAttributeValues(t *testing.T) { tests := []struct { name string @@ -156,7 +215,7 @@ func TestEc2Infoer_GetAttributeValues(t *testing.T) { }{ { name: "successfully retrieve attributes", - pricingService: &DummyPricingSource{TcId: 1}, + pricingService: &testStruct{TcId: 1}, check: func(values productinfo.AttrValues, err error) { assert.Equal(t, 3, len(values), "invalid number of values returned") assert.Nil(t, err, "should not get error") @@ -164,7 +223,7 @@ func TestEc2Infoer_GetAttributeValues(t *testing.T) { }, { name: "error - invalid values zeroed out", - pricingService: &DummyPricingSource{TcId: 2}, + pricingService: &testStruct{TcId: 2}, check: func(values productinfo.AttrValues, err error) { assert.Equal(t, values[0].StrValue, "invalid float 256 GiB", "the invalid value is not the first element") assert.Equal(t, values[0].Value, float64(0), "the invalid value is not zeroed out") @@ -173,7 +232,7 @@ func TestEc2Infoer_GetAttributeValues(t *testing.T) { }, { name: "error - error when retrieving values", - pricingService: &DummyPricingSource{TcId: 3}, + pricingService: &testStruct{TcId: 3}, check: func(values productinfo.AttrValues, err error) { assert.Equal(t, "failed to retrieve values", err.Error()) }, @@ -181,7 +240,9 @@ func TestEc2Infoer_GetAttributeValues(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - productInfoer, err := NewEc2Infoer(test.pricingService, "", "") + productInfoer, err := NewEc2Infoer("", "") + // override pricingSvc + productInfoer.pricingSvc = test.pricingService if err != nil { t.Fatalf("failed to create productinfoer; [%s]", err.Error()) } @@ -192,6 +253,31 @@ func TestEc2Infoer_GetAttributeValues(t *testing.T) { } } +func TestEc2Infoer_GetRegions(t *testing.T) { + tests := []struct { + name string + check func(regionId map[string]string, err error) + }{ + { + name: "receive all regions", + check: func(regionId map[string]string, err error) { + assert.Contains(t, regionId, "us-west-1") + assert.Nil(t, err, "the error should be nil") + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + productInfoer, err := NewEc2Infoer("", "") + if err != nil { + t.Fatalf("failed to create productinfoer; [%s]", err.Error()) + } + regions, err := productInfoer.GetRegions() + test.check(regions, err) + }) + } +} + func TestEc2Infoer_GetProducts(t *testing.T) { tests := []struct { name string @@ -206,10 +292,10 @@ func TestEc2Infoer_GetProducts(t *testing.T) { regionId: "eu-central-1", attrKey: Cpu, attrValue: productinfo.AttrValue{Value: float64(2), StrValue: productinfo.Cpu}, - pricingService: &DummyPricingSource{TcId: 4}, + pricingService: &testStruct{TcId: 4}, check: func(vm []productinfo.VmInfo, err error) { assert.Nil(t, err, "the error should be nil") - assert.Equal(t, []productinfo.VmInfo{{Type: "db.t2.small", OnDemandPrice: 5, SpotPrice: productinfo.SpotPriceInfo(nil), Cpus: 1, Mem: 2, Gpus: 0, NtwPerf: "Low to Moderate"}}, vm) + assert.Equal(t, []productinfo.VmInfo{{Type: "t2.small", OnDemandPrice: 5, SpotPrice: productinfo.SpotPriceInfo(nil), Cpus: 1, Mem: 2, Gpus: 0, NtwPerf: "Low to Moderate"}}, vm) }, }, { @@ -217,7 +303,7 @@ func TestEc2Infoer_GetProducts(t *testing.T) { regionId: "eu-central-1", attrKey: Cpu, attrValue: productinfo.AttrValue{Value: float64(2), StrValue: productinfo.Cpu}, - pricingService: &DummyPricingSource{TcId: 5}, + pricingService: &testStruct{TcId: 5}, check: func(vm []productinfo.VmInfo, err error) { assert.EqualError(t, err, "failed to retrieve values") assert.Nil(t, vm, "the vm should be nil") @@ -228,7 +314,7 @@ func TestEc2Infoer_GetProducts(t *testing.T) { regionId: "eu-central-1", attrKey: Cpu, attrValue: productinfo.AttrValue{Value: float64(2), StrValue: productinfo.Cpu}, - pricingService: &DummyPricingSource{TcId: 6}, + pricingService: &testStruct{TcId: 6}, check: func(vm []productinfo.VmInfo, err error) { assert.Nil(t, err, "the error should be nil") assert.Nil(t, vm, "the vm should be nil") @@ -239,7 +325,7 @@ func TestEc2Infoer_GetProducts(t *testing.T) { regionId: "eu-central-1", attrKey: Cpu, attrValue: productinfo.AttrValue{Value: float64(2), StrValue: productinfo.Cpu}, - pricingService: &DummyPricingSource{TcId: 7}, + pricingService: &testStruct{TcId: 7}, check: func(vm []productinfo.VmInfo, err error) { assert.Nil(t, err, "the error should be nil") assert.Nil(t, vm, "the vm should be nil") @@ -250,7 +336,7 @@ func TestEc2Infoer_GetProducts(t *testing.T) { regionId: "eu-central-1", attrKey: Cpu, attrValue: productinfo.AttrValue{Value: float64(2), StrValue: productinfo.Cpu}, - pricingService: &DummyPricingSource{TcId: 8}, + pricingService: &testStruct{TcId: 8}, check: func(vm []productinfo.VmInfo, err error) { assert.Nil(t, err, "the error should be nil") assert.Nil(t, vm, "the vm should be nil") @@ -261,7 +347,7 @@ func TestEc2Infoer_GetProducts(t *testing.T) { regionId: "eu-central-1", attrKey: Cpu, attrValue: productinfo.AttrValue{Value: float64(2), StrValue: productinfo.Cpu}, - pricingService: &DummyPricingSource{TcId: 9}, + pricingService: &testStruct{TcId: 9}, check: func(vm []productinfo.VmInfo, err error) { assert.Nil(t, err, "the error should be nil") assert.Nil(t, vm, "the vm should be nil") @@ -270,7 +356,9 @@ func TestEc2Infoer_GetProducts(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - productInfoer, err := NewEc2Infoer(test.pricingService, "", "") + productInfoer, err := NewEc2Infoer("", "") + // override pricingSvc + productInfoer.pricingSvc = test.pricingService if err != nil { t.Fatalf("failed to create productinfoer; [%s]", err.Error()) } @@ -290,7 +378,7 @@ func TestEc2Infoer_GetRegion(t *testing.T) { { name: "returns data of a known region", id: "eu-west-3", - pricingService: &DummyPricingSource{}, + pricingService: &testStruct{}, check: func(region *endpoints.Region) { assert.Equal(t, region.Description(), "EU (Paris)") assert.Equal(t, region.ID(), "eu-west-3") @@ -299,7 +387,7 @@ func TestEc2Infoer_GetRegion(t *testing.T) { { name: "get an unknown region", id: "unknownRegion", - pricingService: &DummyPricingSource{}, + pricingService: &testStruct{}, check: func(region *endpoints.Region) { assert.Nil(t, region, "the region should be nil") }, @@ -307,7 +395,7 @@ func TestEc2Infoer_GetRegion(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - productInfoer, err := NewEc2Infoer(test.pricingService, "", "") + productInfoer, err := NewEc2Infoer("", "") if err != nil { t.Fatalf("failed to create productinfoer; [%s]", err.Error()) } @@ -317,42 +405,134 @@ func TestEc2Infoer_GetRegion(t *testing.T) { } } -func TestNewConfig(t *testing.T) { +func TestEc2Infoer_getCurrentSpotPrices(t *testing.T) { tests := []struct { - name string - check func(config *aws.Config) + name string + region string + ec2CliMock func(region string) Ec2Describer + check func(data map[string]productinfo.SpotPriceInfo, err error) }{ { - name: "success - create a new config instance", - check: func(config *aws.Config) { - assert.NotNil(t, config, "the config should not be nil") + name: "successful - get current spot prices", + region: "dummyRegion", + ec2CliMock: func(region string) Ec2Describer { + return &testStruct{} + }, + check: func(data map[string]productinfo.SpotPriceInfo, err error) { + assert.Equal(t, map[string]productinfo.SpotPriceInfo{}, data) + assert.Nil(t, err, "the error should be nil") + }, + }, + { + name: "error - could not get spot price history pages", + region: "dummyRegion", + ec2CliMock: func(region string) Ec2Describer { + return &testStruct{11} + }, + check: func(data map[string]productinfo.SpotPriceInfo, err error) { + assert.Nil(t, data, "the data should be nil") + assert.EqualError(t, err, "invalid") }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - test.check(NewConfig()) + productInfoer, err := NewEc2Infoer("", "") + // override ec2cli + productInfoer.ec2Describer = test.ec2CliMock + if err != nil { + t.Fatalf("failed to create productinfoer; [%s]", err.Error()) + } + + test.check(productInfoer.getCurrentSpotPrices(test.region)) }) } } -func TestNewPricing(t *testing.T) { +func TestEc2Infoer_GetCurrentPrices(t *testing.T) { tests := []struct { - name string - cfg *aws.Config - check func(source PricingSource) + name string + region string + ec2CliMock func(region string) Ec2Describer + check func(price map[string]productinfo.Price, err error) + }{ + { + name: "success - known region", + region: "eu-central-1", + ec2CliMock: func(region string) Ec2Describer { + return &testStruct{} + }, + check: func(price map[string]productinfo.Price, err error) { + assert.Nil(t, err, "the error should be nil") + assert.Equal(t, 0, len(price)) + }, + }, + { + name: "error - unknown region", + region: "dummyRegion", + ec2CliMock: func(region string) Ec2Describer { + return &testStruct{11} + }, + check: func(price map[string]productinfo.Price, err error) { + assert.Nil(t, price, "the zones should be nil") + assert.EqualError(t, err, "invalid") + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + productInfoer, err := NewEc2Infoer("PromAPIAddress", "") + // override ec2cli + productInfoer.ec2Describer = test.ec2CliMock + if err != nil { + t.Fatalf("failed to create productinfoer; [%s]", err.Error()) + } + + test.check(productInfoer.GetCurrentPrices(test.region)) + }) + } +} + +func TestEc2Infoer_GetZones(t *testing.T) { + tests := []struct { + name string + region string + ec2CliMock func(region string) Ec2Describer + check func(zones []string, err error) }{ { - name: "success - create a new PricingSource", - cfg: NewConfig(), - check: func(source PricingSource) { - assert.NotNil(t, source, "the source should not be nil") + name: "success - known region", + region: "eu-central-1", + ec2CliMock: func(region string) Ec2Describer { + return &testStruct{} + }, + check: func(zones []string, err error) { + assert.Nil(t, err, "the error should be nil") + assert.Equal(t, []string{"eu-central-1a", "eu-central-1b"}, zones) + assert.Equal(t, 2, len(zones)) + }, + }, + { + name: "error - unknown region", + region: "dummyRegion", + ec2CliMock: func(region string) Ec2Describer { + return &testStruct{TcId: 10} + }, + check: func(zones []string, err error) { + assert.Nil(t, zones, "the zones should be nil") + assert.EqualError(t, err, "could not get information about zones") }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - test.check(NewPricing(test.cfg)) + productInfoer, err := NewEc2Infoer("PromAPIAddress", "") + // override ec2cli + productInfoer.ec2Describer = test.ec2CliMock + if err != nil { + t.Fatalf("failed to create productinfoer; [%s]", err.Error()) + } + test.check(productInfoer.GetZones(test.region)) }) } } @@ -377,7 +557,7 @@ func TestPriceData_GetDataForKey(t *testing.T) { awsData: aws.JSONValue{ "product": map[string]interface{}{ "attributes": map[string]interface{}{ - "instanceType": "db.t2.small", + "instanceType": ec2.InstanceTypeT2Small, Cpu: "1", Memory: "2", "gpu": "5", @@ -404,7 +584,7 @@ func TestPriceData_GetDataForKey(t *testing.T) { price: data, check: func(s string, err error) { assert.Nil(t, err, "the error should be nil") - assert.Equal(t, "db.t2.small", s) + assert.Equal(t, "t2.small", s) }, }, { @@ -520,7 +700,7 @@ func TestPriceData_GetOnDemandPrice(t *testing.T) { awsData: aws.JSONValue{ "product": map[string]interface{}{ "attributes": map[string]interface{}{ - "instanceType": "db.t2.small", + "instanceType": ec2.InstanceTypeT2Small, Cpu: "1", Memory: "2", "gpu": "5",