Skip to content

Commit

Permalink
More lua scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
Joker666 committed Oct 11, 2024
1 parent af0192e commit b5ff9e9
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 136 deletions.
118 changes: 3 additions & 115 deletions internal/rdb/rdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,42 +177,6 @@ func (r *RDB) Dequeue(ctx context.Context, queueNames ...string) (msg *base.Task
return nil, time.Time{}, errors.E(op, errors.NotFound, errors.ErrNoProcessableTask)
}

// KEYS[1] -> asynq:{<queueName>}:active
// KEYS[2] -> asynq:{<queueName>}:lease
// KEYS[3] -> asynq:{<queueName>}:t:<task_id>
// KEYS[4] -> asynq:{<queueName>}:processed:<yyyy-mm-dd>
// KEYS[5] -> asynq:{<queueName>}:processed
// KEYS[6] -> unique key
// -------
// ARGV[1] -> task ID
// ARGV[2] -> stats expiration timestamp
// ARGV[3] -> max int64 value
var doneUniqueCmd = redis.NewScript(`
if redis.call("LREM", KEYS[1], 0, ARGV[1]) == 0 then
return redis.error_reply("NOT FOUND")
end
if redis.call("ZREM", KEYS[2], ARGV[1]) == 0 then
return redis.error_reply("NOT FOUND")
end
if redis.call("DEL", KEYS[3]) == 0 then
return redis.error_reply("NOT FOUND")
end
local n = redis.call("INCR", KEYS[4])
if tonumber(n) == 1 then
redis.call("EXPIREAT", KEYS[4], ARGV[2])
end
local total = redis.call("GET", KEYS[5])
if tonumber(total) == tonumber(ARGV[3]) then
redis.call("SET", KEYS[5], 1)
else
redis.call("INCR", KEYS[5])
end
if redis.call("GET", KEYS[6]) == ARGV[1] then
redis.call("DEL", KEYS[6])
end
return redis.status_reply("OK")
`)

// Done removes the task from active queue and deletes the task.
// It removes a uniqueness lock acquired by the task, if any.
func (r *RDB) Done(ctx context.Context, msg *base.TaskMessage) error {
Expand All @@ -234,87 +198,11 @@ func (r *RDB) Done(ctx context.Context, msg *base.TaskMessage) error {
// Note: We cannot pass empty unique key when running this script in redis-cluster.
if len(msg.UniqueKey) > 0 {
keys = append(keys, msg.UniqueKey)
return r.runScript(ctx, op, doneUniqueCmd, keys, argv...)
return r.runScript(ctx, op, script.DoneUniqueCmd, keys, argv...)
}
return r.runScript(ctx, op, script.DoneCmd, keys, argv...)
}

// KEYS[1] -> asynq:{<queueName>}:active
// KEYS[2] -> asynq:{<queueName>}:lease
// KEYS[3] -> asynq:{<queueName>}:completed
// KEYS[4] -> asynq:{<queueName>}:t:<task_id>
// KEYS[5] -> asynq:{<queueName>}:processed:<yyyy-mm-dd>
// KEYS[6] -> asynq:{<queueName>}:processed
//
// ARGV[1] -> task ID
// ARGV[2] -> stats expiration timestamp
// ARGV[3] -> task expiration time in unix time
// ARGV[4] -> task message data
// ARGV[5] -> max int64 value
var markAsCompleteCmd = redis.NewScript(`
if redis.call("LREM", KEYS[1], 0, ARGV[1]) == 0 then
return redis.error_reply("NOT FOUND")
end
if redis.call("ZREM", KEYS[2], ARGV[1]) == 0 then
return redis.error_reply("NOT FOUND")
end
if redis.call("ZADD", KEYS[3], ARGV[3], ARGV[1]) ~= 1 then
return redis.error_reply("INTERNAL")
end
redis.call("HSET", KEYS[4], "msg", ARGV[4], "state", "completed")
local n = redis.call("INCR", KEYS[5])
if tonumber(n) == 1 then
redis.call("EXPIREAT", KEYS[5], ARGV[2])
end
local total = redis.call("GET", KEYS[6])
if tonumber(total) == tonumber(ARGV[5]) then
redis.call("SET", KEYS[6], 1)
else
redis.call("INCR", KEYS[6])
end
return redis.status_reply("OK")
`)

// KEYS[1] -> asynq:{<queueName>}:active
// KEYS[2] -> asynq:{<queueName>}:lease
// KEYS[3] -> asynq:{<queueName>}:completed
// KEYS[4] -> asynq:{<queueName>}:t:<task_id>
// KEYS[5] -> asynq:{<queueName>}:processed:<yyyy-mm-dd>
// KEYS[6] -> asynq:{<queueName>}:processed
// KEYS[7] -> asynq:{<queueName>}:unique:{<checksum>}
//
// ARGV[1] -> task ID
// ARGV[2] -> stats expiration timestamp
// ARGV[3] -> task expiration time in unix time
// ARGV[4] -> task message data
// ARGV[5] -> max int64 value
var markAsCompleteUniqueCmd = redis.NewScript(`
if redis.call("LREM", KEYS[1], 0, ARGV[1]) == 0 then
return redis.error_reply("NOT FOUND")
end
if redis.call("ZREM", KEYS[2], ARGV[1]) == 0 then
return redis.error_reply("NOT FOUND")
end
if redis.call("ZADD", KEYS[3], ARGV[3], ARGV[1]) ~= 1 then
return redis.error_reply("INTERNAL")
end
redis.call("HSET", KEYS[4], "msg", ARGV[4], "state", "completed")
local n = redis.call("INCR", KEYS[5])
if tonumber(n) == 1 then
redis.call("EXPIREAT", KEYS[5], ARGV[2])
end
local total = redis.call("GET", KEYS[6])
if tonumber(total) == tonumber(ARGV[5]) then
redis.call("SET", KEYS[6], 1)
else
redis.call("INCR", KEYS[6])
end
if redis.call("GET", KEYS[7]) == ARGV[1] then
redis.call("DEL", KEYS[7])
end
return redis.status_reply("OK")
`)

// MarkAsComplete removes the task from active queue to mark the task as completed.
// It removes a uniqueness lock acquired by the task, if any.
func (r *RDB) MarkAsComplete(ctx context.Context, msg *base.TaskMessage) error {
Expand Down Expand Up @@ -344,9 +232,9 @@ func (r *RDB) MarkAsComplete(ctx context.Context, msg *base.TaskMessage) error {
// Note: We cannot pass empty unique key when running this script in redis-cluster.
if len(msg.UniqueKey) > 0 {
keys = append(keys, msg.UniqueKey)
return r.runScript(ctx, op, markAsCompleteUniqueCmd, keys, argv...)
return r.runScript(ctx, op, script.MarkAsCompleteUniqueCmd, keys, argv...)
}
return r.runScript(ctx, op, markAsCompleteCmd, keys, argv...)
return r.runScript(ctx, op, script.MarkAsCompleteCmd, keys, argv...)
}

// KEYS[1] -> asynq:{<queueName>}:active
Expand Down
33 changes: 33 additions & 0 deletions internal/script/done_unique.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
-- KEYS[1] -> asynq:{<queueName>}:active
-- KEYS[2] -> asynq:{<queueName>}:lease
-- KEYS[3] -> asynq:{<queueName>}:t:<task_id>
-- KEYS[4] -> asynq:{<queueName>}:processed:<yyyy-mm-dd>
-- KEYS[5] -> asynq:{<queueName>}:processed
-- KEYS[6] -> unique key
-- -------
-- ARGV[1] -> task ID
-- ARGV[2] -> stats expiration timestamp
-- ARGV[3] -> max int64 value
if redis.call("LREM", KEYS[1], 0, ARGV[1]) == 0 then
return redis.error_reply("NOT FOUND")
end
if redis.call("ZREM", KEYS[2], ARGV[1]) == 0 then
return redis.error_reply("NOT FOUND")
end
if redis.call("DEL", KEYS[3]) == 0 then
return redis.error_reply("NOT FOUND")
end
local n = redis.call("INCR", KEYS[4])
if tonumber(n) == 1 then
redis.call("EXPIREAT", KEYS[4], ARGV[2])
end
local total = redis.call("GET", KEYS[5])
if tonumber(total) == tonumber(ARGV[3]) then
redis.call("SET", KEYS[5], 1)
else
redis.call("INCR", KEYS[5])
end
if redis.call("GET", KEYS[6]) == ARGV[1] then
redis.call("DEL", KEYS[6])
end
return redis.status_reply("OK")
File renamed without changes.
33 changes: 33 additions & 0 deletions internal/script/mark_as_completed.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
-- KEYS[1] -> asynq:{<queueName>}:active
-- KEYS[2] -> asynq:{<queueName>}:lease
-- KEYS[3] -> asynq:{<queueName>}:completed
-- KEYS[4] -> asynq:{<queueName>}:t:<task_id>
-- KEYS[5] -> asynq:{<queueName>}:processed:<yyyy-mm-dd>
-- KEYS[6] -> asynq:{<queueName>}:processed
--
-- ARGV[1] -> task ID
-- ARGV[2] -> stats expiration timestamp
-- ARGV[3] -> task expiration time in unix time
-- ARGV[4] -> task message data
-- ARGV[5] -> max int64 value
if redis.call("LREM", KEYS[1], 0, ARGV[1]) == 0 then
return redis.error_reply("NOT FOUND")
end
if redis.call("ZREM", KEYS[2], ARGV[1]) == 0 then
return redis.error_reply("NOT FOUND")
end
if redis.call("ZADD", KEYS[3], ARGV[3], ARGV[1]) ~= 1 then
return redis.error_reply("INTERNAL")
end
redis.call("HSET", KEYS[4], "msg", ARGV[4], "state", "completed")
local n = redis.call("INCR", KEYS[5])
if tonumber(n) == 1 then
redis.call("EXPIREAT", KEYS[5], ARGV[2])
end
local total = redis.call("GET", KEYS[6])
if tonumber(total) == tonumber(ARGV[5]) then
redis.call("SET", KEYS[6], 1)
else
redis.call("INCR", KEYS[6])
end
return redis.status_reply("OK")
37 changes: 37 additions & 0 deletions internal/script/mark_as_completed_unique.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
-- KEYS[1] -> asynq:{<queueName>}:active
-- KEYS[2] -> asynq:{<queueName>}:lease
-- KEYS[3] -> asynq:{<queueName>}:completed
-- KEYS[4] -> asynq:{<queueName>}:t:<task_id>
-- KEYS[5] -> asynq:{<queueName>}:processed:<yyyy-mm-dd>
-- KEYS[6] -> asynq:{<queueName>}:processed
-- KEYS[7] -> asynq:{<queueName>}:unique:{<checksum>}
--
-- ARGV[1] -> task ID
-- ARGV[2] -> stats expiration timestamp
-- ARGV[3] -> task expiration time in unix time
-- ARGV[4] -> task message data
-- ARGV[5] -> max int64 value
if redis.call("LREM", KEYS[1], 0, ARGV[1]) == 0 then
return redis.error_reply("NOT FOUND")
end
if redis.call("ZREM", KEYS[2], ARGV[1]) == 0 then
return redis.error_reply("NOT FOUND")
end
if redis.call("ZADD", KEYS[3], ARGV[3], ARGV[1]) ~= 1 then
return redis.error_reply("INTERNAL")
end
redis.call("HSET", KEYS[4], "msg", ARGV[4], "state", "completed")
local n = redis.call("INCR", KEYS[5])
if tonumber(n) == 1 then
redis.call("EXPIREAT", KEYS[5], ARGV[2])
end
local total = redis.call("GET", KEYS[6])
if tonumber(total) == tonumber(ARGV[5]) then
redis.call("SET", KEYS[6], 1)
else
redis.call("INCR", KEYS[6])
end
if redis.call("GET", KEYS[7]) == ARGV[1] then
redis.call("DEL", KEYS[7])
end
return redis.status_reply("OK")
87 changes: 66 additions & 21 deletions internal/script/script.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,60 +3,105 @@ package script
import (
"embed"
"fmt"
"sync"

"github.com/redis/go-redis/v9"
)

//go:embed *.lua
var luaScripts embed.FS

func loadLuaScript(name string) (string, error) {
var (
scriptCache = make(map[string]*redis.Script)
scriptCacheLock sync.RWMutex
)

func loadLuaScript(name string) (*redis.Script, error) {
scriptCacheLock.RLock()
script, ok := scriptCache[name]
scriptCacheLock.RUnlock()

if ok {
return script, nil
}

scriptCacheLock.Lock()
defer scriptCacheLock.Unlock()

// Double-check in case another goroutine has loaded the script
if script, ok := scriptCache[name]; ok {
return script, nil
}

content, err := luaScripts.ReadFile(fmt.Sprintf("%s.lua", name))
if err != nil {
return "", fmt.Errorf("failed to read Lua script %s: %w", name, err)
return nil, fmt.Errorf("failed to read Lua script %s: %w", name, err)
}
return string(content), nil

script = redis.NewScript(string(content))
scriptCache[name] = script
return script, nil
}

var (
EnqueueCmd *redis.Script
EnqueueUniqueCmd *redis.Script
DequeueCmd *redis.Script
DoneCmd *redis.Script
EnqueueCmd *redis.Script
EnqueueUniqueCmd *redis.Script
DequeueCmd *redis.Script
DoneCmd *redis.Script
DoneUniqueCmd *redis.Script
MarkAsCompleteCmd *redis.Script
MarkAsCompleteUniqueCmd *redis.Script
)

const (
enqueueCmd = "enqueue"
enqueueUniqueCmd = "enqueueUnique"
dequeueCmd = "dequeue"
doneCmd = "done"
enqueueCmd = "enqueue"
enqueueUniqueCmd = "enqueueUnique"
dequeueCmd = "dequeue"
doneCmd = "done"
doneUniqueCmd = "doneUnique"
markAsCompleteCmd = "markAsComplete"
markAsCompleteUniqueCmd = "markAsCompleteUnique"
)

// Use this function to initialize your Redis scripts
func init() {
enqueueLua, err := loadLuaScript(enqueueCmd)
var err error
EnqueueCmd, err = loadLuaScript(enqueueCmd)
if err != nil {
panic(err)
}

enqueueUniqueLua, err := loadLuaScript(enqueueUniqueCmd)
EnqueueUniqueCmd, err = loadLuaScript(enqueueUniqueCmd)
if err != nil {
panic(err)
}

dequeueLua, err := loadLuaScript(dequeueCmd)
DequeueCmd, err = loadLuaScript(dequeueCmd)
if err != nil {
panic(err)
}

doneLua, err := loadLuaScript(doneCmd)
DoneCmd, err = loadLuaScript(doneCmd)
if err != nil {
panic(err)
}

// Initialize Redis scripts here
EnqueueCmd = redis.NewScript(enqueueLua)
EnqueueUniqueCmd = redis.NewScript(enqueueUniqueLua)
DequeueCmd = redis.NewScript(dequeueLua)
DoneCmd = redis.NewScript(doneLua)
DoneUniqueCmd, err = loadLuaScript(doneUniqueCmd)
if err != nil {
panic(err)
}

MarkAsCompleteCmd, err = loadLuaScript(markAsCompleteCmd)
if err != nil {
panic(err)
}

MarkAsCompleteUniqueCmd, err = loadLuaScript(markAsCompleteUniqueCmd)
if err != nil {
panic(err)
}

MarkAsCompleteUniqueCmd, err = loadLuaScript(markAsCompleteCmd)
if err != nil {
panic(err)
}
}

0 comments on commit b5ff9e9

Please sign in to comment.