From 4d37e4cd3d7d135b089873fdeb7c5fd0bd3402a4 Mon Sep 17 00:00:00 2001 From: Osama Adam Date: Thu, 4 Apr 2024 23:25:35 +0200 Subject: [PATCH] conditionally formatting the policies endpoint url based on the current keycloak server version --- client.go | 87 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- go.mod | 1 + go.sum | 2 ++ 3 files changed, 87 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index ccff45fc..a93419d4 100644 --- a/client.go +++ b/client.go @@ -17,6 +17,7 @@ import ( "github.com/opentracing/opentracing-go" "github.com/pkg/errors" "github.com/segmentio/ksuid" + "golang.org/x/mod/semver" "github.com/Nerzal/gocloak/v13/pkg/jwx" ) @@ -36,6 +37,7 @@ type GoCloak struct { logoutEndpoint string openIDConnect string attackDetection string + version string } } @@ -48,6 +50,53 @@ func makeURL(path ...string) string { return strings.Join(path, urlSeparator) } +// Compares the provided version against the current version of the Keycloak server. +// Current version is fetched from the serverinfo if not already set. +// +// Returns: +// +// -1 if the provided version is lower than the server version +// +// 0 if the provided version is equal to the server version +// +// 1 if the provided version is higher than the server version +func (g *GoCloak) compareVersions(v, token string, ctx context.Context) (int, error) { + curVersion := g.Config.version + if curVersion == "" { + curV, err := g.getServerVersion(ctx, token); + if err != nil { + return 0, err + } + + curVersion = curV + } + + curVersion = "v" + g.Config.version + if (v[0] != 'v') { + v = "v" + v + } + + return semver.Compare(curVersion, v), nil +} + +// Get the server version from the serverinfo endpoint. +// If the version is already set, it will return the cached version. +// Otherwise, it will fetch the version from the serverinfo endpoint and cache it. +func (g *GoCloak) getServerVersion(ctx context.Context, token string) (string, error) { + if g.Config.version != "" { + return g.Config.version, nil + } + + serverInfo, err := g.GetServerInfo(ctx, token) + if err != nil { + return "", err + } + + g.Config.version = *(serverInfo.SystemInfo.Version) + + return g.Config.version, nil +} + // GetRequest returns a request for calling endpoints. func (g *GoCloak) GetRequest(ctx context.Context) *resty.Request { var err HTTPErrorResponse @@ -3485,8 +3534,14 @@ func (g *GoCloak) GetPolicies(ctx context.Context, token, realm, idOfClient stri return nil, errors.Wrap(err, errMessage) } + compResult, err := g.compareVersions("20.0.0", token, ctx) + if err != nil { + return nil, err + } + shouldAddType := compResult != 1 + path := []string{"clients", idOfClient, "authz", "resource-server", "policy"} - if !NilOrEmpty(params.Type) { + if !NilOrEmpty(params.Type) && shouldAddType { path = append(path, *params.Type) } @@ -3511,11 +3566,23 @@ func (g *GoCloak) CreatePolicy(ctx context.Context, token, realm, idOfClient str return nil, errors.New("type of a policy required") } + compResult, err := g.compareVersions("20.0.0", token, ctx) + if err != nil { + return nil, err + } + shouldAddType := compResult != 1 + + path := []string{"clients", idOfClient, "authz", "resource-server", "policy"} + + if shouldAddType { + path = append(path, *policy.Type) + } + var result PolicyRepresentation resp, err := g.GetRequestWithBearerAuth(ctx, token). SetResult(&result). SetBody(policy). - Post(g.getAdminRealmURL(realm, "clients", idOfClient, "authz", "resource-server", "policy", *(policy.Type))) + Post(g.getAdminRealmURL(realm, path...)) if err := checkForError(resp, err, errMessage); err != nil { return nil, err @@ -3532,9 +3599,23 @@ func (g *GoCloak) UpdatePolicy(ctx context.Context, token, realm, idOfClient str return errors.New("ID of a policy required") } + compResult, err := g.compareVersions("20.0.0", token, ctx) + if err != nil { + return err + } + shouldAddType := compResult != 1 + + path := []string{"clients", idOfClient, "authz", "resource-server", "policy"} + + if shouldAddType { + path = append(path, *policy.Type) + } + + path = append(path, *(policy.ID)) + resp, err := g.GetRequestWithBearerAuth(ctx, token). SetBody(policy). - Put(g.getAdminRealmURL(realm, "clients", idOfClient, "authz", "resource-server", "policy", *(policy.Type), *(policy.ID))) + Put(g.getAdminRealmURL(realm, path...)) return checkForError(resp, err, errMessage) } diff --git a/go.mod b/go.mod index cde01123..0a817032 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/segmentio/ksuid v1.0.4 github.com/stretchr/testify v1.8.2 golang.org/x/crypto v0.17.0 + golang.org/x/mod v0.16.0 ) require ( diff --git a/go.sum b/go.sum index 86853e21..bd55e2d6 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,8 @@ github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= +golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=