diff --git a/gopool.go b/gopool.go index 977c214..254e8b4 100644 --- a/gopool.go +++ b/gopool.go @@ -1,8 +1,8 @@ package gopool import ( - "sync" - "time" + "sync" + "time" ) // task represents a function that will be executed by a worker. @@ -11,151 +11,151 @@ type task func() (interface{}, error) // goPool represents a pool of workers. type goPool struct { - workers []*worker - workerStack []int - maxWorkers int - // Set by WithMinWorkers(), used to adjust the number of workers. Default equals to maxWorkers. - minWorkers int - // tasks are added to this channel first, then dispatched to workers. Default buffer size is 1 million. - taskQueue chan task - // Set by WithRetryCount(), used to retry a task when it fails. Default is 0. - retryCount int - lock sync.Locker - cond *sync.Cond - // Set by WithTimeout(), used to set a timeout for a task. Default is 0, which means no timeout. - timeout time.Duration - // Set by WithResultCallback(), used to handle the result of a task. Default is nil. - resultCallback func(interface{}) - // Set by WithErrorCallback(), used to handle the error of a task. Default is nil. - errorCallback func(error) - // adjustInterval is the interval to adjust the number of workers. Default is 1 second. - adjustInterval time.Duration + workers []*worker + workerStack []int + maxWorkers int + // Set by WithMinWorkers(), used to adjust the number of workers. Default equals to maxWorkers. + minWorkers int + // tasks are added to this channel first, then dispatched to workers. Default buffer size is 1 million. + taskQueue chan task + // Set by WithRetryCount(), used to retry a task when it fails. Default is 0. + retryCount int + lock sync.Locker + cond *sync.Cond + // Set by WithTimeout(), used to set a timeout for a task. Default is 0, which means no timeout. + timeout time.Duration + // Set by WithResultCallback(), used to handle the result of a task. Default is nil. + resultCallback func(interface{}) + // Set by WithErrorCallback(), used to handle the error of a task. Default is nil. + errorCallback func(error) + // adjustInterval is the interval to adjust the number of workers. Default is 1 second. + adjustInterval time.Duration } // NewGoPool creates a new pool of workers. func NewGoPool(maxWorkers int, opts ...Option) *goPool { - pool := &goPool{ - maxWorkers: maxWorkers, - // Set minWorkers to maxWorkers by default - minWorkers: maxWorkers, - workers: make([]*worker, maxWorkers), - workerStack: make([]int, maxWorkers), - taskQueue: make(chan task, 1e6), - retryCount: 0, - lock: new(sync.Mutex), - timeout: 0, - adjustInterval: 1 * time.Second, - } - // Apply options - for _, opt := range opts { - opt(pool) - } - if pool.cond == nil { - pool.cond = sync.NewCond(pool.lock) - } - // Create workers with the minimum number. Don't use pushWorker() here. - for i := 0; i < pool.minWorkers; i++ { - worker := newWorker() - pool.workers[i] = worker - pool.workerStack[i] = i - worker.start(pool, i) - } - go pool.adjustWorkers() - go pool.dispatch() - return pool + pool := &goPool{ + maxWorkers: maxWorkers, + // Set minWorkers to maxWorkers by default + minWorkers: maxWorkers, + workers: make([]*worker, maxWorkers), + workerStack: make([]int, maxWorkers), + taskQueue: make(chan task, 1e6), + retryCount: 0, + lock: new(sync.Mutex), + timeout: 0, + adjustInterval: 1 * time.Second, + } + // Apply options + for _, opt := range opts { + opt(pool) + } + if pool.cond == nil { + pool.cond = sync.NewCond(pool.lock) + } + // Create workers with the minimum number. Don't use pushWorker() here. + for i := 0; i < pool.minWorkers; i++ { + worker := newWorker() + pool.workers[i] = worker + pool.workerStack[i] = i + worker.start(pool, i) + } + go pool.adjustWorkers() + go pool.dispatch() + return pool } // AddTask adds a task to the pool. func (p *goPool) AddTask(t task) { - p.taskQueue <- t + p.taskQueue <- t } // Wait waits for all tasks to be dispatched. func (p *goPool) Wait() { - for len(p.taskQueue) > 0 { - time.Sleep(100 * time.Millisecond) - } + for len(p.taskQueue) > 0 { + time.Sleep(100 * time.Millisecond) + } } // Release stops all workers and releases resources. -func (p *goPool) Release() { - close(p.taskQueue) - p.cond.L.Lock() - for len(p.workerStack) != p.minWorkers { - p.cond.Wait() - } - p.cond.L.Unlock() - for _, worker := range p.workers { - close(worker.taskQueue) - } - p.workers = nil - p.workerStack = nil +func (p *goPool) Release() { + close(p.taskQueue) + p.cond.L.Lock() + for len(p.workerStack) != p.minWorkers { + p.cond.Wait() + } + p.cond.L.Unlock() + for _, worker := range p.workers { + close(worker.taskQueue) + } + p.workers = nil + p.workerStack = nil } func (p *goPool) popWorker() int { - p.lock.Lock() - workerIndex := p.workerStack[len(p.workerStack)-1] - p.workerStack = p.workerStack[:len(p.workerStack)-1] - p.lock.Unlock() - return workerIndex + p.lock.Lock() + workerIndex := p.workerStack[len(p.workerStack)-1] + p.workerStack = p.workerStack[:len(p.workerStack)-1] + p.lock.Unlock() + return workerIndex } func (p *goPool) pushWorker(workerIndex int) { - p.lock.Lock() - p.workerStack = append(p.workerStack, workerIndex) - p.lock.Unlock() - p.cond.Signal() + p.lock.Lock() + p.workerStack = append(p.workerStack, workerIndex) + p.lock.Unlock() + p.cond.Signal() } // adjustWorkers adjusts the number of workers according to the number of tasks in the queue. func (p *goPool) adjustWorkers() { - ticker := time.NewTicker(p.adjustInterval) - defer ticker.Stop() + ticker := time.NewTicker(p.adjustInterval) + defer ticker.Stop() - for range ticker.C { - p.cond.L.Lock() - if len(p.taskQueue) > len(p.workerStack)*3/4 && len(p.workerStack) < p.maxWorkers { - // Double the number of workers until it reaches the maximum - newWorkers := min(len(p.workerStack)*2, p.maxWorkers) - len(p.workerStack) - for i := 0; i < newWorkers; i++ { - worker := newWorker() - p.workers = append(p.workers, worker) - p.workerStack = append(p.workerStack, len(p.workers)-1) - worker.start(p, len(p.workers)-1) - } - } else if len(p.taskQueue) == 0 && len(p.workerStack) > p.minWorkers { - // Halve the number of workers until it reaches the minimum - removeWorkers := max((len(p.workerStack)-p.minWorkers)/2, p.minWorkers) - p.workers = p.workers[:len(p.workers)-removeWorkers] - p.workerStack = p.workerStack[:len(p.workerStack)-removeWorkers] - } - p.cond.L.Unlock() - } + for range ticker.C { + p.cond.L.Lock() + if len(p.taskQueue) > len(p.workerStack)*3/4 && len(p.workerStack) < p.maxWorkers { + // Double the number of workers until it reaches the maximum + newWorkers := min(len(p.workerStack)*2, p.maxWorkers) - len(p.workerStack) + for i := 0; i < newWorkers; i++ { + worker := newWorker() + p.workers = append(p.workers, worker) + p.workerStack = append(p.workerStack, len(p.workers)-1) + worker.start(p, len(p.workers)-1) + } + } else if len(p.taskQueue) == 0 && len(p.workerStack) > p.minWorkers { + // Halve the number of workers until it reaches the minimum + removeWorkers := max((len(p.workerStack)-p.minWorkers)/2, p.minWorkers) + p.workers = p.workers[:len(p.workers)-removeWorkers] + p.workerStack = p.workerStack[:len(p.workerStack)-removeWorkers] + } + p.cond.L.Unlock() + } } // dispatch dispatches tasks to workers. func (p *goPool) dispatch() { - for t := range p.taskQueue { - p.cond.L.Lock() - for len(p.workerStack) == 0 { - p.cond.Wait() - } - p.cond.L.Unlock() - workerIndex := p.popWorker() - p.workers[workerIndex].taskQueue <- t - } + for t := range p.taskQueue { + p.cond.L.Lock() + for len(p.workerStack) == 0 { + p.cond.Wait() + } + p.cond.L.Unlock() + workerIndex := p.popWorker() + p.workers[workerIndex].taskQueue <- t + } } func min(a, b int) int { - if a < b { - return a - } - return b + if a < b { + return a + } + return b } func max(a, b int) int { - if a > b { - return a - } - return b + if a > b { + return a + } + return b } diff --git a/gopool_test.go b/gopool_test.go index a5e8533..65e2252 100644 --- a/gopool_test.go +++ b/gopool_test.go @@ -1,10 +1,10 @@ package gopool import ( + "errors" "sync" "testing" "time" - "errors" "github.com/daniel-hutao/spinlock" ) @@ -93,35 +93,35 @@ func BenchmarkGoroutines(b *testing.B) { } func TestGoPoolWithError(t *testing.T) { - var errTaskError = errors.New("task error") - pool := NewGoPool(100, WithErrorCallback(func(err error) { - if err != errTaskError { - t.Errorf("Expected error %v, but got %v", errTaskError, err) - } - })) + var errTaskError = errors.New("task error") + pool := NewGoPool(100, WithErrorCallback(func(err error) { + if err != errTaskError { + t.Errorf("Expected error %v, but got %v", errTaskError, err) + } + })) defer pool.Release() - for i := 0; i< 1000; i++ { - pool.AddTask(func() (interface{}, error) { - return nil, errTaskError - }) - } - pool.Wait() + for i := 0; i < 1000; i++ { + pool.AddTask(func() (interface{}, error) { + return nil, errTaskError + }) + } + pool.Wait() } func TestGoPoolWithResult(t *testing.T) { - var expectedResult = "task result" - pool := NewGoPool(100, WithResultCallback(func(result interface{}) { - if result != expectedResult { - t.Errorf("Expected result %v, but got %v", expectedResult, result) - } - })) + var expectedResult = "task result" + pool := NewGoPool(100, WithResultCallback(func(result interface{}) { + if result != expectedResult { + t.Errorf("Expected result %v, but got %v", expectedResult, result) + } + })) defer pool.Release() - for i := 0; i< 1000; i++ { - pool.AddTask(func() (interface{}, error) { - return expectedResult, nil - }) - } - pool.Wait() + for i := 0; i < 1000; i++ { + pool.AddTask(func() (interface{}, error) { + return expectedResult, nil + }) + } + pool.Wait() } diff --git a/option.go b/option.go index 3170601..b86b64d 100644 --- a/option.go +++ b/option.go @@ -1,8 +1,8 @@ package gopool import ( - "sync" - "time" + "sync" + "time" ) // Option represents an option for the pool. @@ -10,43 +10,43 @@ type Option func(*goPool) // WithLock sets the lock for the pool. func WithLock(lock sync.Locker) Option { - return func(p *goPool) { - p.lock = lock - p.cond = sync.NewCond(p.lock) - } + return func(p *goPool) { + p.lock = lock + p.cond = sync.NewCond(p.lock) + } } // WithMinWorkers sets the minimum number of workers for the pool. func WithMinWorkers(minWorkers int) Option { - return func(p *goPool) { - p.minWorkers = minWorkers - } + return func(p *goPool) { + p.minWorkers = minWorkers + } } // WithTimeout sets the timeout for the pool. func WithTimeout(timeout time.Duration) Option { - return func(p *goPool) { - p.timeout = timeout - } + return func(p *goPool) { + p.timeout = timeout + } } // WithResultCallback sets the result callback for the pool. func WithResultCallback(callback func(interface{})) Option { - return func(p *goPool) { - p.resultCallback = callback - } + return func(p *goPool) { + p.resultCallback = callback + } } // WithErrorCallback sets the error callback for the pool. func WithErrorCallback(callback func(error)) Option { - return func(p *goPool) { - p.errorCallback = callback - } + return func(p *goPool) { + p.errorCallback = callback + } } // WithRetryCount sets the retry count for the pool. func WithRetryCount(retryCount int) Option { - return func(p *goPool) { - p.retryCount = retryCount - } + return func(p *goPool) { + p.retryCount = retryCount + } } diff --git a/worker.go b/worker.go index 4e86b4a..c4b0179 100644 --- a/worker.go +++ b/worker.go @@ -1,89 +1,88 @@ package gopool import ( - "context" - "fmt" + "context" + "fmt" ) // worker represents a worker in the pool. type worker struct { - taskQueue chan task + taskQueue chan task } func newWorker() *worker { - return &worker{ - taskQueue: make(chan task, 1), - } + return &worker{ + taskQueue: make(chan task, 1), + } } // start starts the worker in a separate goroutine. // The worker will run tasks from its taskQueue until the taskQueue is closed. // For the length of the taskQueue is 1, the worker will be pushed back to the pool after executing 1 task. func (w *worker) start(pool *goPool, workerIndex int) { - go func() { - for t := range w.taskQueue { - if t != nil { - result, err := w.executeTask(t, pool) - w.handleResult(result, err, pool) - } - pool.pushWorker(workerIndex) - } - }() + go func() { + for t := range w.taskQueue { + if t != nil { + result, err := w.executeTask(t, pool) + w.handleResult(result, err, pool) + } + pool.pushWorker(workerIndex) + } + }() } // executeTask executes a task and returns the result and error. // If the task fails, it will be retried according to the retryCount of the pool. func (w *worker) executeTask(t task, pool *goPool) (result interface{}, err error) { - for i := 0; i <= pool.retryCount; i++ { - if pool.timeout > 0 { - result, err = w.executeTaskWithTimeout(t, pool) - } else { - result, err = w.executeTaskWithoutTimeout(t, pool) - } - if err == nil || i == pool.retryCount { - return result, err - } - } - return + for i := 0; i <= pool.retryCount; i++ { + if pool.timeout > 0 { + result, err = w.executeTaskWithTimeout(t, pool) + } else { + result, err = w.executeTaskWithoutTimeout(t, pool) + } + if err == nil || i == pool.retryCount { + return result, err + } + } + return } // executeTaskWithTimeout executes a task with a timeout and returns the result and error. func (w *worker) executeTaskWithTimeout(t task, pool *goPool) (result interface{}, err error) { - // Create a context with timeout - ctx, cancel := context.WithTimeout(context.Background(), pool.timeout) - defer cancel() + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), pool.timeout) + defer cancel() - // Create a channel to receive the result of the task - done := make(chan struct{}) + // Create a channel to receive the result of the task + done := make(chan struct{}) - // Run the task in a separate goroutine - go func() { - result, err = t() - close(done) - }() + // Run the task in a separate goroutine + go func() { + result, err = t() + close(done) + }() - // Wait for the task to finish or for the context to timeout - select { - case <-done: - // The task finished successfully - return result, err - case <-ctx.Done(): - // The context timed out, the task took too long - return nil, fmt.Errorf("Task timed out") - } + // Wait for the task to finish or for the context to timeout + select { + case <-done: + // The task finished successfully + return result, err + case <-ctx.Done(): + // The context timed out, the task took too long + return nil, fmt.Errorf("Task timed out") + } } func (w *worker) executeTaskWithoutTimeout(t task, pool *goPool) (result interface{}, err error) { - // If timeout is not set or is zero, just run the task - return t() + // If timeout is not set or is zero, just run the task + return t() } - // handleResult handles the result of a task. func (w *worker) handleResult(result interface{}, err error, pool *goPool) { - if err != nil && pool.errorCallback != nil { - pool.errorCallback(err) - } else if pool.resultCallback != nil { - pool.resultCallback(result) - } + if err != nil && pool.errorCallback != nil { + pool.errorCallback(err) + } else if pool.resultCallback != nil { + pool.resultCallback(result) + } }