Skip to content

Commit

Permalink
Merge pull request #4 from piotrkowalczuk/feature/refresh-token
Browse files Browse the repository at this point in the history
refresh token
  • Loading branch information
piotrkowalczuk authored Mar 19, 2017
2 parents 5a06369 + ae5ff6c commit 8452a72
Show file tree
Hide file tree
Showing 17 changed files with 518 additions and 292 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ RUN apk --no-cache add curl

COPY ./bin /usr/local/bin/
COPY ./scripts/docker-entrypoint.sh /
COPY ./scripts/docker-healthcheck.sh /

EXPOSE 8080 8081

Expand Down
12 changes: 0 additions & 12 deletions internal/cache/cache.go

This file was deleted.

4 changes: 3 additions & 1 deletion mnemosyne.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ func (m *mnemosyne) Start(ctx context.Context, subjectID, subjectClient string,

// Abandon implements Mnemosyne interface.
func (m *mnemosyne) Abandon(ctx context.Context, token string) error {
_, err := m.client.Abandon(ctx, &mnemosynerpc.AbandonRequest{AccessToken: token})
_, err := m.client.Abandon(ctx, &mnemosynerpc.AbandonRequest{
AccessToken: token,
})

return err
}
Expand Down
1 change: 1 addition & 0 deletions mnemosyned/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ func (d *Daemon) Close() (err error) {
if d.debugListener != nil {
err = d.debugListener.Close()
}

return
}

Expand Down
19 changes: 12 additions & 7 deletions mnemosyned/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,13 @@ func (h *handler) start(ctx context.Context, req *mnemosynerpc.StartRequest) (*m
}

h.logger = log.NewContext(h.logger).With("subject_id", req.Session.SubjectId)

ses, err := h.storage.Start(ctx, req.Session.AccessToken, req.Session.SubjectId, req.Session.SubjectClient, req.Session.Bag)
ses, err := h.storage.Start(ctx,
req.Session.AccessToken,
req.Session.RefreshToken,
req.Session.SubjectId,
req.Session.SubjectClient,
req.Session.Bag,
)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -149,12 +154,12 @@ func (h *handler) delete(ctx context.Context, req *mnemosynerpc.DeleteRequest) (
var (
expireAtFrom, expireAtTo *time.Time
)
if req.AccessToken == "" && req.ExpireAtFrom == nil && req.ExpireAtTo == nil {

if req.AccessToken == "" && req.RefreshToken == "" && req.ExpireAtFrom == nil && req.ExpireAtTo == nil {
return 0, grpc.Errorf(codes.InvalidArgument, "none of expected arguments was provided")
}
if req.AccessToken != "" {
h.logger = log.NewContext(h.logger).With("access_token", req.AccessToken)
}
h.logger = log.NewContext(h.logger).With("access_token", req.AccessToken, "refresh_token", req.RefreshToken)

if req.ExpireAtFrom != nil {
eaf, err := ptypes.Timestamp(req.ExpireAtFrom)
if err != nil {
Expand All @@ -172,7 +177,7 @@ func (h *handler) delete(ctx context.Context, req *mnemosynerpc.DeleteRequest) (
h.logger = log.NewContext(h.logger).With("expire_at_to", eat)
}

affected, err := h.storage.Delete(ctx, req.AccessToken, expireAtFrom, expireAtTo)
affected, err := h.storage.Delete(ctx, req.AccessToken, req.RefreshToken, expireAtFrom, expireAtTo)
if err != nil {
return 0, err
}
Expand Down
28 changes: 14 additions & 14 deletions mnemosyned/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,22 @@ func (_m *mockStorage) TearDown() error {
return r0
}

// Start provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4
func (_m *mockStorage) Start(_a0 context.Context, _a1 string, _a2 string, _a3 string, _a4 map[string]string) (*mnemosynerpc.Session, error) {
ret := _m.Called(_a0, _a1, _a2, _a3, _a4)
// Start provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4, _a5
func (_m *mockStorage) Start(_a0 context.Context, _a1 string, _a2 string, _a3 string, _a4 string, _a5 map[string]string) (*mnemosynerpc.Session, error) {
ret := _m.Called(_a0, _a1, _a2, _a3, _a4, _a5)

var r0 *mnemosynerpc.Session
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, map[string]string) *mnemosynerpc.Session); ok {
r0 = rf(_a0, _a1, _a2, _a3, _a4)
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string, map[string]string) *mnemosynerpc.Session); ok {
r0 = rf(_a0, _a1, _a2, _a3, _a4, _a5)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*mnemosynerpc.Session)
}
}

var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, string, string, map[string]string) error); ok {
r1 = rf(_a0, _a1, _a2, _a3, _a4)
if rf, ok := ret.Get(1).(func(context.Context, string, string, string, string, map[string]string) error); ok {
r1 = rf(_a0, _a1, _a2, _a3, _a4, _a5)
} else {
r1 = ret.Error(1)
}
Expand Down Expand Up @@ -178,20 +178,20 @@ func (_m *mockStorage) Exists(_a0 context.Context, _a1 string) (bool, error) {
return r0, r1
}

// Delete provides a mock function with given fields: _a0, _a1, _a2, _a3
func (_m *mockStorage) Delete(_a0 context.Context, _a1 string, _a2 *time.Time, _a3 *time.Time) (int64, error) {
ret := _m.Called(_a0, _a1, _a2, _a3)
// Delete provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4
func (_m *mockStorage) Delete(_a0 context.Context, _a1 string, _a2 string, _a3 *time.Time, _a4 *time.Time) (int64, error) {
ret := _m.Called(_a0, _a1, _a2, _a3, _a4)

var r0 int64
if rf, ok := ret.Get(0).(func(context.Context, string, *time.Time, *time.Time) int64); ok {
r0 = rf(_a0, _a1, _a2, _a3)
if rf, ok := ret.Get(0).(func(context.Context, string, string, *time.Time, *time.Time) int64); ok {
r0 = rf(_a0, _a1, _a2, _a3, _a4)
} else {
r0 = ret.Get(0).(int64)
}

var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, *time.Time, *time.Time) error); ok {
r1 = rf(_a0, _a1, _a2, _a3)
if rf, ok := ret.Get(1).(func(context.Context, string, string, *time.Time, *time.Time) error); ok {
r1 = rf(_a0, _a1, _a2, _a3, _a4)
} else {
r1 = ret.Error(1)
}
Expand Down
117 changes: 69 additions & 48 deletions mnemosyned/postgres_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (

"context"

"bytes"

"github.com/golang/protobuf/ptypes"
"github.com/piotrkowalczuk/mnemosyne/mnemosynerpc"
"github.com/prometheus/client_golang/prometheus"
Expand All @@ -29,22 +31,23 @@ func newPostgresStorage(tb, schema string, db *sql.DB, m *monitoring, ttl time.D
schema: schema,
ttl: ttl,
monitor: m,
querySave: `INSERT INTO ` + schema + ` .` + tb + ` (access_token, subject_id, subject_client, bag)
VALUES ($1, $2, $3, $4)
querySave: `INSERT INTO ` + schema + ` .` + tb + ` (access_token, refresh_token, subject_id, subject_client, bag)
VALUES ($1, $2, $3, $4, $5)
RETURNING expire_at`,
queryGet: fmt.Sprintf(`UPDATE `+schema+` .`+tb+`
SET expire_at = (NOW() + '%d seconds')
WHERE access_token = $1
RETURNING subject_id, subject_client, bag, expire_at`, int64(ttl.Seconds())),
RETURNING refresh_token, subject_id, subject_client, bag, expire_at`, int64(ttl.Seconds())),
queryExists: `SELECT EXISTS(SELECT 1 FROM ` + schema + ` .` + tb + ` WHERE access_token = $1)`,
queryAbandon: `DELETE FROM ` + schema + ` .` + tb + ` WHERE access_token = $1`,
}
}

// Start implements storage interface.
func (ps *postgresStorage) Start(ctx context.Context, at, sid, sc string, b map[string]string) (*mnemosynerpc.Session, error) {
func (ps *postgresStorage) Start(ctx context.Context, accessToken, refreshToken, sid, sc string, b map[string]string) (*mnemosynerpc.Session, error) {
ent := &sessionEntity{
AccessToken: at,
AccessToken: accessToken,
RefreshToken: refreshToken,
SubjectID: sid,
SubjectClient: sc,
Bag: bag(b),
Expand All @@ -57,17 +60,18 @@ func (ps *postgresStorage) Start(ctx context.Context, at, sid, sc string, b map[
return ent.session()
}

func (ps *postgresStorage) save(ctx context.Context, entity *sessionEntity) (err error) {
func (ps *postgresStorage) save(ctx context.Context, ent *sessionEntity) (err error) {
labels := prometheus.Labels{"query": "save"}
err = ps.db.QueryRowContext(
ctx,
ps.querySave,
entity.AccessToken,
entity.SubjectID,
entity.SubjectClient,
entity.Bag,
ent.AccessToken,
ent.RefreshToken,
ent.SubjectID,
ent.SubjectClient,
ent.Bag,
).Scan(
&entity.ExpireAt,
&ent.ExpireAt,
)
ps.incQueries(labels)
if err != nil {
Expand All @@ -82,6 +86,7 @@ func (ps *postgresStorage) Get(ctx context.Context, accessToken string) (*mnemos
labels := prometheus.Labels{"query": "get"}

err := ps.db.QueryRowContext(ctx, ps.queryGet, accessToken).Scan(
&entity.RefreshToken,
&entity.SubjectID,
&entity.SubjectClient,
&entity.Bag,
Expand All @@ -101,7 +106,8 @@ func (ps *postgresStorage) Get(ctx context.Context, accessToken string) (*mnemos
return nil, err
}
return &mnemosynerpc.Session{
AccessToken: string(accessToken),
AccessToken: accessToken,
RefreshToken: entity.RefreshToken,
SubjectId: entity.SubjectID,
SubjectClient: entity.SubjectClient,
Bag: entity.Bag,
Expand All @@ -116,7 +122,7 @@ func (ps *postgresStorage) List(ctx context.Context, offset, limit int64, expire
}

args := []interface{}{offset, limit}
query := "SELECT access_token, subject_id, subject_client, bag, expire_at FROM " + ps.schema + "." + ps.table + " "
query := "SELECT access_token, refresh_token, subject_id, subject_client, bag, expire_at FROM " + ps.schema + "." + ps.table + " "
if expiredAtFrom != nil || expiredAtTo != nil {
query += " WHERE "
}
Expand Down Expand Up @@ -145,29 +151,31 @@ func (ps *postgresStorage) List(ctx context.Context, offset, limit int64, expire

sessions := make([]*mnemosynerpc.Session, 0, limit)
for rows.Next() {
var entity sessionEntity
var ent sessionEntity

err = rows.Scan(
&entity.AccessToken,
&entity.SubjectID,
&entity.SubjectClient,
&entity.Bag,
&entity.ExpireAt,
&ent.AccessToken,
&ent.RefreshToken,
&ent.SubjectID,
&ent.SubjectClient,
&ent.Bag,
&ent.ExpireAt,
)
if err != nil {
ps.incError(labels)
return nil, err
}

expireAt, err := ptypes.TimestampProto(entity.ExpireAt)
expireAt, err := ptypes.TimestampProto(ent.ExpireAt)
if err != nil {
return nil, err
}
sessions = append(sessions, &mnemosynerpc.Session{
AccessToken: entity.AccessToken,
SubjectId: entity.SubjectID,
SubjectClient: entity.SubjectClient,
Bag: entity.Bag,
AccessToken: ent.AccessToken,
RefreshToken: ent.RefreshToken,
SubjectId: ent.SubjectID,
SubjectClient: ent.SubjectClient,
Bag: ent.Bag,
ExpireAt: expireAt,
})
}
Expand Down Expand Up @@ -275,13 +283,12 @@ func (ps *postgresStorage) SetValue(ctx context.Context, accessToken string, key
}

// Delete implements storage interface.
func (ps *postgresStorage) Delete(ctx context.Context, accessToken string, expiredAtFrom, expiredAtTo *time.Time) (int64, error) {
if accessToken == "" && expiredAtFrom == nil && expiredAtTo == nil {
return 0, errors.New("session cannot be deleted, no where parameter provided")
func (ps *postgresStorage) Delete(ctx context.Context, accessToken, refreshToken string, expiredAtFrom, expiredAtTo *time.Time) (int64, error) {
where, args := ps.where(accessToken, refreshToken, expiredAtFrom, expiredAtTo)
if where.Len() == 0 {
return 0, fmt.Errorf("session cannot be deleted, no where parameter provided: %s", where.String())
}

where, args := ps.where(accessToken, expiredAtFrom, expiredAtTo)
query := "DELETE FROM " + ps.schema + "." + ps.table + " WHERE " + where
query := "DELETE FROM " + ps.schema + "." + ps.table + " WHERE " + where.String()
labels := prometheus.Labels{"query": "delete"}

result, err := ps.db.Exec(query, args...)
Expand All @@ -300,15 +307,21 @@ func (ps *postgresStorage) Setup() error {
CREATE SCHEMA IF NOT EXISTS %s;
CREATE TABLE IF NOT EXISTS %s.%s (
access_token BYTEA PRIMARY KEY,
refresh_token BYTEA,
subject_id TEXT NOT NULL,
subject_client TEXT,
bag bytea NOT NULL,
expire_at TIMESTAMPTZ NOT NULL DEFAULT (NOW() + '%d seconds')
);
CREATE INDEX ON %s.%s (refresh_token);
CREATE INDEX ON %s.%s (subject_id);
CREATE INDEX ON %s.%s (expire_at DESC);
`, ps.schema, ps.schema, ps.table, int64(ps.ttl.Seconds()), ps.schema, ps.table, ps.schema, ps.table)
`, ps.schema, ps.schema, ps.table, int64(ps.ttl.Seconds()),
ps.schema, ps.table,
ps.schema, ps.table,
ps.schema, ps.table,
)
_, err := ps.db.Exec(query)

return err
Expand All @@ -333,29 +346,36 @@ func (ps *postgresStorage) incError(field prometheus.Labels) {
}
}

func (ps *postgresStorage) where(accessToken string, expiredAtFrom, expiredAtTo *time.Time) (string, []interface{}) {
func (ps *postgresStorage) where(accessToken, refreshToken string, expiredAtFrom, expiredAtTo *time.Time) (*bytes.Buffer, []interface{}) {
var count int
buf := bytes.NewBuffer(nil)
args := make([]interface{}, 0, 4)

switch {
case accessToken != "" && expiredAtFrom == nil && expiredAtTo == nil:
return " access_token = $1", []interface{}{accessToken}
case accessToken == "" && expiredAtFrom != nil && expiredAtTo == nil:
return " expire_at > $1", []interface{}{expiredAtFrom}
case accessToken == "" && expiredAtFrom == nil && expiredAtTo != nil:
return " expire_at < $1", []interface{}{expiredAtTo}
case accessToken != "" && expiredAtFrom != nil && expiredAtTo == nil:
return " access_token = $1 AND expire_at > $2", []interface{}{accessToken, expiredAtFrom}
case accessToken != "" && expiredAtFrom == nil && expiredAtTo != nil:
return " access_token = $1 AND expire_at < $2", []interface{}{accessToken, expiredAtTo}
case accessToken == "" && expiredAtFrom != nil && expiredAtTo != nil:
return " expire_at > $1 AND expire_at < $2", []interface{}{expiredAtFrom, expiredAtTo}
case accessToken != "" && expiredAtFrom != nil && expiredAtTo != nil:
return " access_token = $1 AND expire_at > $2 AND expire_at < $3", []interface{}{accessToken, expiredAtFrom, expiredAtTo}
default:
return " ", nil
case accessToken != "":
count++
fmt.Fprintf(buf, " access_token = $%d", count)
args = append(args, accessToken)
case refreshToken != "":
count++
fmt.Fprintf(buf, " refresh_token = $%d", count)
args = append(args, refreshToken)
case expiredAtFrom != nil:
count++
fmt.Fprintf(buf, " expire_at > $%d", count)
args = append(args, expiredAtFrom)
case expiredAtTo != nil:
count++
fmt.Fprintf(buf, " expire_at < $%d", count)
args = append(args, expiredAtTo)
}

return buf, args
}

type sessionEntity struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
SubjectID string `json:"subjectId"`
SubjectClient string `json:"subjectClient"`
Bag bag `json:"bag"`
Expand All @@ -369,6 +389,7 @@ func (se *sessionEntity) session() (*mnemosynerpc.Session, error) {
}
return &mnemosynerpc.Session{
AccessToken: se.AccessToken,
RefreshToken: se.RefreshToken,
SubjectId: se.SubjectID,
SubjectClient: se.SubjectClient,
Bag: se.Bag,
Expand Down
4 changes: 2 additions & 2 deletions mnemosyned/session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const (
var (
errSessionNotFound = grpc.Errorf(codes.NotFound, "mnemosyned: session not found")
errMissingAccessToken = grpc.Errorf(codes.InvalidArgument, "mnemosyned: missing access token")
errMissingSubjectID = grpc.Errorf(codes.InvalidArgument, "mnemosyned: missing subject id")
errMissingSubjectID = grpc.Errorf(codes.InvalidArgument, "mnemosyned: missing subject accessToken")
errMissingSession = grpc.Errorf(codes.InvalidArgument, "mnemosyned: missing session")
)

Expand Down Expand Up @@ -290,7 +290,7 @@ InfLoop:
case <-time.After(sm.ttc):
t := time.Now()
sklog.Debug(logger, "session cleanup start", "start_at", t.Format(time.RFC3339))
affected, err := sm.storage.Delete(context.Background(), "", nil, &t)
affected, err := sm.storage.Delete(context.Background(), "", "", nil, &t)
if err != nil {
if sm.monitor.enabled {
sm.monitor.cleanup.errors.Inc()
Expand Down
Loading

0 comments on commit 8452a72

Please sign in to comment.