diff --git a/pkg/provider/post_test.go b/pkg/provider/post_test.go index 2f0dc8a..4cce5cc 100644 --- a/pkg/provider/post_test.go +++ b/pkg/provider/post_test.go @@ -251,7 +251,7 @@ func TestSSO_verifyPostSignature(t *testing.T) { t.Run(tt.name, func(t *testing.T) { spConfig := &serviceprovider.Config{Metadata: []byte(tt.args.spMetadata)} - sp, err := serviceprovider.NewServiceProvider("test", spConfig, "") + sp, err := serviceprovider.NewServiceProvider("test", spConfig, func(s string) string { return "" }) if err != nil { t.Errorf("verifyPostSignature() got = %v, wanted to create service provider instance", err) return diff --git a/pkg/provider/redirect_test.go b/pkg/provider/redirect_test.go index 87c689c..30f21b4 100644 --- a/pkg/provider/redirect_test.go +++ b/pkg/provider/redirect_test.go @@ -283,7 +283,7 @@ func TestRedirect_verifyRedirectSignature(t *testing.T) { t.Run(tt.name, func(t *testing.T) { spConfig := &serviceprovider.Config{Metadata: []byte(tt.args.spMetadata)} - sp, err := serviceprovider.NewServiceProvider("test", spConfig, "") + sp, err := serviceprovider.NewServiceProvider("test", spConfig, func(s string) string { return "" }) if err != nil { t.Errorf("verifyRedirectSignature() got = %v, wanted to create service provider instance", err) return diff --git a/pkg/provider/serviceprovider/serviceprovider.go b/pkg/provider/serviceprovider/serviceprovider.go index 83d9888..bd1697e 100644 --- a/pkg/provider/serviceprovider/serviceprovider.go +++ b/pkg/provider/serviceprovider/serviceprovider.go @@ -21,7 +21,7 @@ type ServiceProvider struct { ID string Metadata *md.EntityDescriptorType signerPublicKey interface{} - defaultLoginURL string + loginURL func(string) string } func (sp *ServiceProvider) GetEntityID() string { @@ -29,10 +29,10 @@ func (sp *ServiceProvider) GetEntityID() string { } func (sp *ServiceProvider) LoginURL(id string) string { - return sp.defaultLoginURL + id + return sp.loginURL(id) } -func NewServiceProvider(id string, config *Config, defaultLoginURL string) (*ServiceProvider, error) { +func NewServiceProvider(id string, config *Config, loginURL func(string) string) (*ServiceProvider, error) { metadata, err := xml.ParseMetadataXmlIntoStruct(config.Metadata) if err != nil { return nil, err @@ -54,7 +54,7 @@ func NewServiceProvider(id string, config *Config, defaultLoginURL string) (*Ser ID: id, Metadata: metadata, signerPublicKey: signerPublicKey, - defaultLoginURL: defaultLoginURL, + loginURL: loginURL, }, nil } diff --git a/pkg/provider/sso_test.go b/pkg/provider/sso_test.go index 2b3d599..8417857 100644 --- a/pkg/provider/sso_test.go +++ b/pkg/provider/sso_test.go @@ -628,7 +628,7 @@ func TestSSO_ssoHandleFunc(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { endpoint := NewEndpoint(tt.args.metadataEndpoint) - spInst, err := serviceprovider.NewServiceProvider(tt.args.sp.entityID, &serviceprovider.Config{Metadata: []byte(tt.args.sp.metadata)}, "") + spInst, err := serviceprovider.NewServiceProvider(tt.args.sp.entityID, &serviceprovider.Config{Metadata: []byte(tt.args.sp.metadata)}, func(s string) string { return "" }) if err != nil { t.Errorf("error while creating service provider") return