diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..814b600
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+.idea/*
+/vendor
+log
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..fab3551
--- /dev/null
+++ b/README.md
@@ -0,0 +1,121 @@
+# **redis-semaphore**
+
+
Implements a semaphore using redis commands. The semaphore is blocking, not polling, and has a fair queue serving processes on a first-come, first-serve basis.
+
Implementation based on Redis BLPOP ability to block execution until queue is not empty or timeout reached.
+
+### **Redis Client**
+
+redis-semaphore requires redis client provided by the user.
+
It is not dependant on specific redis version, and can accept any implementation that satisfies its Redis interface.
+
Implementations of `go_redis` && `redis.v5` clients are already given in this repository for your convenience.
+
Providing nil object will result in validation error.
+
+### **Redis Keys**
+
+redis-semaphore uses 4 keys to maintain semaphore lifecycle:
+
+1. **name** - derived by the binding key given by user. Semaphores are separated in Redis by their names
+
+2. **version** - In case of future possible updates & fixes, version key enables to differentiating between old and updated clients
+
+3. **available resources queue name** - represents the queue name in redis holding list of free locks to use
+
+4. **locked resources set name** - represents key in redis which under it all used locks and their expiration time will be stored
+
+### **Num Connections**
+
+due to the blocking nature of `blpop` command, note that it's very important to set size of redis connections pool that is higher
+than number of expected concurrent locks at worst case.
+
Exhausting all redis connections will result in a deadlock.
+
+### **Options**
+
+#### **Logging**
+
+redis-semaphore provides logging mechanism to enable monitoring in case needed.
+
It is not dependant on specific log tool, and can accept any implementation that satisfies its Logger interface.
+
Note that Logger interface should support 3 types of log levels:
+
+ - a. `Error` (0) - show only non blocking errors (errors that will not terminate semaphore process)
+ - b. `Info` (1) - log only critical information (lock/unlock succeeded/failed, etc)
+ - c. `Debug` (2) - verbose, include internal steps
+
+
Implementation of `logrus` client is already given in this repository for your convenience.
+
logger is optional. In case user have no need for log, do not pass it in options
+
+#### **Settings**
+
+The semaphore uses 4 settings to determine it's behavior, each of them can be overridden:
+
+1. **`Expiration`** - redis-semaphore must have an expiration time to ensure that after a while all evidence of the semaphore will disappear and your redis server will not be cluttered with unused keys.
+ Also, it represents the maximum amount of time mutual exclusion is guaranteed. Value is set to 1 minute by default.
+
+1. **`TryLockTimeout`** - each lock operation must be bounded by max running time and cannot block execution indefinitely. value is set to 30 seconds by default. This setting can be overridden to any duration between 1 second and semaphore expiration time.
+
+2. **`MaxParallelResources`** - redis-semaphore allows to define a set number of processes inside the semaphore-protected block (1 by default). All those processes can run in the critical section simultaneously.
+
+3. **`LockAttempts`** - user can choose to retry acquiring lock if timeout reached. All attempts will have the same timeout. Number of attempts is 1 be default (no retries).
+
+### **Usage**
+
+#### **Creating New Semaphore**
+
+```
+bindingKey = "my_lock_key"
+redisClient := semaphoreredis.NewRedisV5Client(redis.NewClient(&redis.Options{Addr: "localhost:6379"}))
+logger := semaphorelogger.NewLogrusLogger(logrus.New(), semaphorelogger.LogLevelInfo, bindingKey)
+overrideSettings := semaphore.Settings{
+ TryLockTimeout: 20 * time.Second,
+ LockAttempts: 2,
+ MaxParallelResources: 1,
+}
+
+s, err := semaphore.New(bindingKey, redisClient, logger, overrideSettings)
+```
+
+
Creates a new Semaphore. Mandatory params are binding key and Redis client. Optional params are logger and overrides to the default settings. Validation error will be returned on invalid params.
+
After semaphore is created, its settings cannot be modified. If you wish to alter semaphore setting, it would require creating and new object.
+ Note that creating multiple semaphores with the same binding key but different `MaxParallelResources` setting will have no effect. The setting of the first semaphore that will acquire lock will be applied until this semaphore will be expired.
+
+#### **Lock & Unlock**
+
+```
+token, err := s.Lock()
+isLockUsed, err := s.IsResourceLocked(token) //isLockUsed = true
+numFreeLocks, err := s.GetNumAvailableResources() //numFreeLocks = MaxParallelResources - 1
+err := s.Unlock(token) //don't forget this!
+isLockUsed, err := s.IsResourceLocked(token) //isLockUsed = false
+numFreeLocks, err := s.GetNumAvailableResources() //numFreeLocks = MaxParallelResources
+```
+
+
redis-semaphore enables separate lock & unlock operations.
+
Performing lock operation on the Semaphore creates all it's keys in redis if used for the first time or expired, and checks for expired locks otherwise (see expired resources section).
+
Lock function returns unique uuid representing the acquired lock. This string should be given as parameter to unlock function when we want to release the lock.
+
Resource will be locked until will be freed by unlocking it, or until semaphore will expire.
+
Performing lock or unlock oprations resets the semaphore's expiration time.
+
+#### **Execute With Mutex**
+```
+WithMutex(lockByKey string, redisClient Redis, logger Logger, safeCode func(), settings ...Settings) error
+```
+Wrapper for encapsulating semaphore internal implementation. Mandatory params are binding key, Redis client and block of code to run. Optional params are logger and settings overrides.
+
Function will create new semaphore, acquire lock, run function in critical section, and then release lock.
+ If error occurred while running code block, unlock procedure will run all the same.
+
+#### **Custom Timeout**
+```
+token, err := s.LockWithCustomTimeout(5 * time.Second)
+```
+User can choose to acquire lock using the same Semaphore but with alternating timeout for each lock operation. The custom timeout is subjects to the same limitations as `TryLockTimeout` parameter.
+Providing invalid timeout will result in validation error.
+
+### **Expired Resources**
+
+There are possible cases where non expired Semaphore will contain locks that passed their expiration time.
+The main reason for that is the extension of the Semaphore's expiration upon lock & unlock operations.
+Before every lock operation, expired resources (if exists) will be cleaned up and returned to available locks queue.
+
+### **Trying Lock On Expired Semaphore**
+
+Note that as opposed to locking algorithms that uses polling, in case semaphore expires while process awaits in the queue, it will be not possible to acquire lock!
+Client will have to wait until timeout will be reached and then he will be able to lock successfully at the next attempt.
\ No newline at end of file
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..dca65b2
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,15 @@
+module github.com/gtforge/redis-semaphore-go
+
+go 1.12
+
+require (
+ github.com/go-redis/redis v6.15.2+incompatible
+ github.com/golang/mock v1.3.1
+ github.com/onsi/ginkgo v1.9.0
+ github.com/onsi/gomega v1.6.0
+ github.com/pkg/errors v0.8.1
+ github.com/satori/go.uuid v1.2.0
+ github.com/sirupsen/logrus v1.4.2
+ github.com/tylerb/gls v0.0.0-20150407001822-e606233f194d
+ gopkg.in/redis.v5 v5.2.9
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..2fd25c1
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,49 @@
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
+github.com/go-redis/redis v6.15.2+incompatible h1:9SpNVG76gr6InJGxoZ6IuuxaCOQwDAhzyXg+Bs+0Sb4=
+github.com/go-redis/redis v6.15.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA=
+github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s=
+github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
+github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
+github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
+github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
+github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
+github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
+github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
+github.com/onsi/ginkgo v1.9.0 h1:SZjF721BByVj8QH636/8S2DnX4n0Re3SteMmw3N+tzc=
+github.com/onsi/ginkgo v1.9.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
+github.com/onsi/gomega v1.6.0 h1:8XTW0fcJZEq9q+Upcyws4JSGua2MFysCL5xkaSgHc+M=
+github.com/onsi/gomega v1.6.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
+github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
+github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
+github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
+github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
+github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
+github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
+github.com/tylerb/gls v0.0.0-20150407001822-e606233f194d h1:yYYPFFlbqxF5mrj5sEfETtM/Ssz2LTy0/VKlDdXYctc=
+github.com/tylerb/gls v0.0.0-20150407001822-e606233f194d/go.mod h1:0MwyId/pXK5wkYYEXe7NnVknX+aNBuF73fLV3U0reU8=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628=
+golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc=
+golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
+gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
+gopkg.in/redis.v5 v5.2.9 h1:MNZYOLPomQzZMfpN3ZtD1uyJ2IDonTTlxYiV/pEApiw=
+gopkg.in/redis.v5 v5.2.9/go.mod h1:6gtv0/+A4iM08kdRfocWYB3bLX2tebpNtfKlFT6H4mY=
+gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
+gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
+gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE=
+gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
diff --git a/semaphore/constructor.go b/semaphore/constructor.go
new file mode 100644
index 0000000..42da31c
--- /dev/null
+++ b/semaphore/constructor.go
@@ -0,0 +1,167 @@
+package semaphore
+
+import (
+ "fmt"
+ "time"
+
+ "github.com/gtforge/redis-semaphore-go/semaphore/semaphore-logger"
+
+ "github.com/gtforge/redis-semaphore-go/semaphore/semaphore-redis"
+
+ "github.com/pkg/errors"
+)
+
+const (
+ errLvl = semaphorelogger.LogLevelError
+ infoLvl = semaphorelogger.LogLevelInfo
+ debugLvl = semaphorelogger.LogLevelDebug
+)
+
+type Redis = semaphoreredis.Redis
+
+type Logger = semaphorelogger.Logger
+
+type semaphore struct {
+ lockByKey string
+ options Options
+ redis redisImpl
+}
+
+type redisImpl struct {
+ client Redis
+ keys []string
+}
+
+type Options struct {
+ TryLockTimeout time.Duration
+ LockAttempts int64
+ MaxParallelResources int64
+ Logger Logger
+ Expiration time.Duration
+}
+
+var defaultOptions = Options{
+ Expiration: 1 * time.Minute,
+ TryLockTimeout: 30 * time.Second,
+ LockAttempts: 1,
+ MaxParallelResources: 1,
+ Logger: semaphorelogger.NewEmptyLogger(),
+}
+
+func New(lockByKey string, redisClient Redis, options ...Options) (Semaphore, error) {
+ return create(lockByKey, redisClient, options...)
+}
+
+func create(lockByKey string, redisClient Redis, options ...Options) (*semaphore, error) {
+
+ s := &semaphore{
+ lockByKey: lockByKey,
+ options: setOptions(options...),
+ redis: redisImpl{client: redisClient},
+ }
+
+ err := s.validate()
+ if err != nil {
+ return nil, errors.Wrapf(err, "error in validating new semaphore")
+ }
+
+ s.setRedisKeys()
+
+ s.options.Logger.WithFields(map[string]interface{}{
+ "options": fmt.Sprintf("%+v", s.options),
+ }).Log(debugLvl, "new semaphore object created successfully")
+
+ return s, nil
+}
+
+func setOptions(overrides ...Options) Options {
+ options := defaultOptions
+
+ if len(overrides) == 0 {
+ return options
+ }
+
+ override := overrides[0]
+
+ if override.Expiration != 0 {
+ options.Expiration = override.Expiration
+ }
+
+ if override.TryLockTimeout != 0 {
+ options.TryLockTimeout = override.TryLockTimeout
+ }
+
+ if override.MaxParallelResources != 0 {
+ options.MaxParallelResources = override.MaxParallelResources
+ }
+
+ if override.LockAttempts != 0 {
+ options.LockAttempts = override.LockAttempts
+ }
+
+ if override.Logger != nil {
+ options.Logger = override.Logger
+ }
+
+ return options
+}
+
+func (s *semaphore) validate() error {
+ if s.lockByKey == "" {
+ return fmt.Errorf("lock by key field must be non empty")
+ }
+
+ if s.redis.client == nil {
+ return fmt.Errorf("redis client must be non nil")
+ }
+
+ if s.options.Expiration < time.Second {
+ return fmt.Errorf("expiration time must be at least 1 second, received %v", s.options.TryLockTimeout)
+ }
+
+ if s.options.TryLockTimeout < time.Second || s.options.TryLockTimeout > s.options.Expiration {
+ return fmt.Errorf("try lock timeout must be at least 1 second and smaller or equal to semaphore Expiration time, received %v", s.options.TryLockTimeout)
+ }
+
+ if s.options.MaxParallelResources <= 0 {
+ return fmt.Errorf("max parallel resources setting must be positive number, received %v", s.options.MaxParallelResources)
+ }
+
+ if s.options.LockAttempts <= 0 {
+ return fmt.Errorf("lock attempts setting must be positive number, received %v", s.options.LockAttempts)
+ }
+
+ return nil
+}
+
+func (s *semaphore) setRedisKeys() {
+ s.redis.keys = append(s.redis.keys, s.name(), s.availableQueueName(), s.lockedResourcesName(), s.version())
+}
+
+const (
+ namePrefix = "semaphore"
+ availableQueueNamePostfix = "available"
+ lockedQueueNamePostfix = "locked"
+ versionPostfix = "version"
+ releaseExpiredPostfix = "release_expired"
+)
+
+func (s *semaphore) name() string {
+ return fmt.Sprintf("%v:%v", namePrefix, s.lockByKey)
+}
+
+func (s *semaphore) availableQueueName() string {
+ return fmt.Sprintf("%v:%v", s.name(), availableQueueNamePostfix)
+}
+
+func (s *semaphore) lockedResourcesName() string {
+ return fmt.Sprintf("%v:%v", s.name(), lockedQueueNamePostfix)
+}
+
+func (s *semaphore) version() string {
+ return fmt.Sprintf("%v:%v", s.name(), versionPostfix)
+}
+
+func (s *semaphore) releaseExpiredLockName() string {
+ return fmt.Sprintf("%v:%v", s.name(), releaseExpiredPostfix)
+}
diff --git a/semaphore/init.go b/semaphore/init.go
new file mode 100644
index 0000000..1859f63
--- /dev/null
+++ b/semaphore/init.go
@@ -0,0 +1,68 @@
+package semaphore
+
+import (
+ "github.com/gtforge/redis-semaphore-go/semaphore/semaphore-redis"
+ "github.com/pkg/errors"
+ "github.com/satori/go.uuid"
+)
+
+const (
+ mutexExistsValue = "v"
+ semaphoreVersion = "1.0"
+)
+
+func (s *semaphore) init() (err error) {
+ isNewSemaphore, err := s.redis.client.SetNX(s.name(), mutexExistsValue, s.options.Expiration)
+ if err != nil {
+ return errors.Wrapf(err, "failed to setNX semaphore name")
+ }
+
+ if isNewSemaphore {
+ s.options.Logger.Log(debugLvl, "semaphore is used for the first time or expired, creating keys in redis")
+ return s.create()
+ } else { //semaphore with this key already exists
+ s.options.Logger.Log(debugLvl, "semaphore already exists in redis - checking if should release expired resources")
+ return s.releaseExpiredResources()
+ }
+}
+
+func (s *semaphore) create() error {
+ pipeErr := s.redis.client.TxPipelined(func(pipe semaphoreredis.Pipeline) error { //execute redis transaction
+
+ err := pipe.Del(s.lockedResourcesName())
+ if err != nil {
+ return errors.Wrapf(err, "failed to delete queue of locked resources")
+ }
+
+ err = pipe.Del(s.availableQueueName()) //in case last client crushed and did not expire key
+ if err != nil {
+ return errors.Wrapf(err, "failed to delete queue of available resources")
+ }
+
+ var resources []interface{}
+
+ for i := 0; i < int(s.options.MaxParallelResources); i++ {
+ resources = append(resources, uuid.NewV4().String())
+ }
+
+ err = pipe.RPush(s.availableQueueName(), resources...)
+ if err != nil {
+ return errors.Wrapf(err, "failed to add %v resources to available resources queue", s.options.MaxParallelResources)
+ }
+
+ err = s.redis.client.Set(s.version(), semaphoreVersion, s.options.Expiration)
+ if err != nil {
+ return errors.Wrapf(err, "failed to set semaphore version")
+ }
+
+ return nil
+ })
+
+ if pipeErr != nil {
+ return pipeErr
+ }
+
+ s.options.Logger.Log(debugLvl, "semaphore created in redis successfully")
+
+ return nil
+}
diff --git a/semaphore/is_locked.go b/semaphore/is_locked.go
new file mode 100644
index 0000000..a8a2dc0
--- /dev/null
+++ b/semaphore/is_locked.go
@@ -0,0 +1,27 @@
+package semaphore
+
+import (
+ "github.com/pkg/errors"
+)
+
+func (s *semaphore) IsResourceLocked(resource string) (bool, error) {
+ s.options.Logger.WithFields(map[string]interface{}{
+ "resource": resource,
+ }).Log(debugLvl, "received is resource locked request")
+
+ return s.isResourceLocked(resource)
+}
+
+func (s *semaphore) isResourceLocked(resource string) (bool, error) {
+ isResourceLocked, err := s.redis.client.HExists(s.lockedResourcesName(), resource)
+ if err != nil {
+ return false, errors.Wrapf(err, "failed to check if resource %v exists in locked resources queue", resource)
+ }
+
+ s.options.Logger.WithFields(map[string]interface{}{
+ "resource": resource,
+ "is_locked": isResourceLocked,
+ }).Log(debugLvl, "retrieved from redis resource lock status")
+
+ return isResourceLocked, nil
+}
diff --git a/semaphore/lock.go b/semaphore/lock.go
new file mode 100644
index 0000000..1b93d05
--- /dev/null
+++ b/semaphore/lock.go
@@ -0,0 +1,112 @@
+package semaphore
+
+import (
+ "fmt"
+ "time"
+
+ "github.com/gtforge/redis-semaphore-go/semaphore/semaphore-redis"
+
+ "github.com/pkg/errors"
+)
+
+var TimeoutError = errors.New("reached timeout while trying to acquire lock")
+
+func (s *semaphore) Lock() (string, error) {
+ return s.LockWithCustomTimeout(s.options.TryLockTimeout)
+}
+
+func (s *semaphore) LockWithCustomTimeout(timeout time.Duration) (string, error) {
+ if timeout < time.Second || timeout > s.options.Expiration {
+ return "", errors.New("try lock timeout must be at least 1 second and smaller or equal to semaphore Expiration time")
+ }
+
+ var (
+ isTimedOut bool
+ resource string
+ attempts int64
+ )
+
+ s.options.Logger.WithFields(map[string]interface{}{
+ "attempts": s.options.LockAttempts,
+ "timeout": timeout,
+ }).Log(debugLvl, "received lock request")
+
+ for ok := true; ok; ok = isTimedOut {
+
+ attempts++
+
+ if attempts > s.options.LockAttempts {
+ s.options.Logger.WithFields(map[string]interface{}{
+ "attempts": s.options.LockAttempts,
+ "timeout": timeout,
+ }).Log(infoLvl, "all attempts to acquire lock reached timeout")
+ return "", TimeoutError //reached timeout when trying to pop from queue
+ }
+
+ err := s.init()
+ if err != nil {
+ return "", err
+ }
+
+ s.options.Logger.WithFields(map[string]interface{}{
+ "attempt": attempts,
+ "timeout": timeout,
+ }).Log(debugLvl, "trying to acquire lock")
+
+ resource, isTimedOut, err = s.redis.client.BLPop(timeout, s.availableQueueName())
+ if err != nil {
+ return "", errors.Wrapf(err, "failed to pop available resource from queue")
+ }
+
+ if isTimedOut {
+ s.options.Logger.WithFields(map[string]interface{}{
+ "attempt": attempts,
+ "timeout": timeout,
+ "attempts_left": s.options.LockAttempts - attempts,
+ }).Log(infoLvl, "reached timeout while trying to acquire lock")
+ }
+ }
+
+ pipeErr := s.redis.client.TxPipelined(func(pipe semaphoreredis.Pipeline) error { //execute redis transaction
+
+ err := s.redis.client.HSet(s.lockedResourcesName(), resource, fmt.Sprint(time.Now().UnixNano())) //value = time of insertion so we would know when it is expired
+ if err != nil {
+ return errors.Wrapf(err, "failed to add resource %v to locked queue", resource)
+ }
+
+ return s.updateExpirationTime(pipe)
+ })
+
+ if pipeErr != nil {
+ return resource, pipeErr
+ }
+
+ s.options.Logger.WithFields(map[string]interface{}{
+ "resource": resource,
+ }).Log(infoLvl, "resource locked successfully")
+
+ return resource, nil
+}
+
+func (s *semaphore) updateExpirationTime(pipe semaphoreredis.Pipeline) error {
+ s.options.Logger.WithFields(map[string]interface{}{
+ "expiration_time": time.Now().Add(s.options.Expiration).Format("15:04:05.000"),
+ }).Log(debugLvl, "update semaphore redis keys Expiration time")
+
+ var err error
+
+ for _, k := range s.redis.keys {
+
+ if k == s.lockedResourcesName() {
+ err = pipe.PExpire(k, s.options.Expiration*2) //avoid race condition where semaphore will expire between init and lock
+ } else {
+ err = pipe.PExpire(k, s.options.Expiration)
+ }
+
+ if err != nil {
+ return errors.Wrapf(err, "failed to update Expiration time of key %v", k)
+ }
+ }
+
+ return nil
+}
diff --git a/semaphore/mock/semaphore_mock.go b/semaphore/mock/semaphore_mock.go
new file mode 100644
index 0000000..09ca245
--- /dev/null
+++ b/semaphore/mock/semaphore_mock.go
@@ -0,0 +1,109 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: ./semaphore/semaphore.go
+
+// Package mock_semaphore is a generated GoMock package.
+package mock_semaphore
+
+import (
+ reflect "reflect"
+ time "time"
+
+ gomock "github.com/golang/mock/gomock"
+)
+
+// MockSemaphore is a mock of Semaphore interface
+type MockSemaphore struct {
+ ctrl *gomock.Controller
+ recorder *MockSemaphoreMockRecorder
+}
+
+// MockSemaphoreMockRecorder is the mock recorder for MockSemaphore
+type MockSemaphoreMockRecorder struct {
+ mock *MockSemaphore
+}
+
+// NewMockSemaphore creates a new mock instance
+func NewMockSemaphore(ctrl *gomock.Controller) *MockSemaphore {
+ mock := &MockSemaphore{ctrl: ctrl}
+ mock.recorder = &MockSemaphoreMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use
+func (m *MockSemaphore) EXPECT() *MockSemaphoreMockRecorder {
+ return m.recorder
+}
+
+// Lock mocks base method
+func (m *MockSemaphore) Lock() (string, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Lock")
+ ret0, _ := ret[0].(string)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Lock indicates an expected call of Lock
+func (mr *MockSemaphoreMockRecorder) Lock() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Lock", reflect.TypeOf((*MockSemaphore)(nil).Lock))
+}
+
+// LockWithCustomTimeout mocks base method
+func (m *MockSemaphore) LockWithCustomTimeout(timeout time.Duration) (string, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "LockWithCustomTimeout", timeout)
+ ret0, _ := ret[0].(string)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// LockWithCustomTimeout indicates an expected call of LockWithCustomTimeout
+func (mr *MockSemaphoreMockRecorder) LockWithCustomTimeout(timeout interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LockWithCustomTimeout", reflect.TypeOf((*MockSemaphore)(nil).LockWithCustomTimeout), timeout)
+}
+
+// Unlock mocks base method
+func (m *MockSemaphore) Unlock(resource string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Unlock", resource)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Unlock indicates an expected call of Unlock
+func (mr *MockSemaphoreMockRecorder) Unlock(resource interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unlock", reflect.TypeOf((*MockSemaphore)(nil).Unlock), resource)
+}
+
+// IsResourceLocked mocks base method
+func (m *MockSemaphore) IsResourceLocked(resource string) (bool, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "IsResourceLocked", resource)
+ ret0, _ := ret[0].(bool)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// IsResourceLocked indicates an expected call of IsResourceLocked
+func (mr *MockSemaphoreMockRecorder) IsResourceLocked(resource interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsResourceLocked", reflect.TypeOf((*MockSemaphore)(nil).IsResourceLocked), resource)
+}
+
+// GetNumAvailableResources mocks base method
+func (m *MockSemaphore) GetNumAvailableResources() (int64, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetNumAvailableResources")
+ ret0, _ := ret[0].(int64)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetNumAvailableResources indicates an expected call of GetNumAvailableResources
+func (mr *MockSemaphoreMockRecorder) GetNumAvailableResources() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNumAvailableResources", reflect.TypeOf((*MockSemaphore)(nil).GetNumAvailableResources))
+}
diff --git a/semaphore/num_available.go b/semaphore/num_available.go
new file mode 100644
index 0000000..632e41c
--- /dev/null
+++ b/semaphore/num_available.go
@@ -0,0 +1,30 @@
+package semaphore
+
+import (
+ "github.com/pkg/errors"
+)
+
+func (s *semaphore) GetNumAvailableResources() (int64, error) {
+ s.options.Logger.Log(debugLvl, "received get num available resources request")
+
+ isSemaphoreExists, err := s.redis.client.Exists(s.name())
+ if err != nil {
+ return 0, errors.Wrapf(err, "failed to check if semaphore exists while getting num available resources")
+ }
+
+ if !isSemaphoreExists { //semaphore does not exists - return initial value
+ s.options.Logger.Log(debugLvl, "semaphore does not exists in redis - all resources are free")
+ return s.options.MaxParallelResources, nil
+ }
+
+ lenQueue, err := s.redis.client.LLen(s.availableQueueName())
+ if err != nil {
+ return 0, errors.Wrapf(err, "failed to get length of available resources queue")
+ }
+
+ s.options.Logger.WithFields(map[string]interface{}{
+ "num_available_resources": lenQueue,
+ }).Log(debugLvl, "retrieved num available resources for semaphore")
+
+ return lenQueue, nil
+}
diff --git a/semaphore/release_expired.go b/semaphore/release_expired.go
new file mode 100644
index 0000000..ce387db
--- /dev/null
+++ b/semaphore/release_expired.go
@@ -0,0 +1,80 @@
+package semaphore
+
+import (
+ "strconv"
+ "time"
+
+ "github.com/pkg/errors"
+)
+
+const (
+ secondaryLockValue = "v"
+ secondaryLockExpiration = 5 * time.Second
+)
+
+func (s *semaphore) releaseExpiredResources() error {
+
+ isSecondaryLockAcquired, err := s.redis.client.SetNX(s.releaseExpiredLockName(), secondaryLockValue, secondaryLockExpiration) //releasing expired resources should be done under lock
+ if err != nil {
+ return errors.Wrapf(err, "failed to setNX secondary lock for releasing expired resources")
+ }
+
+ if !isSecondaryLockAcquired {
+ s.options.Logger.Log(infoLvl, "semaphore failed to acquire secondary lock (already taken by another process) - skip releasing expired resources")
+ return nil //other process is already releasing expired resources
+ }
+
+ defer func() {
+ err := s.redis.client.Del(s.releaseExpiredLockName())
+ if err != nil {
+ s.options.Logger.WithFields(map[string]interface{}{
+ "error": err,
+ }).Log(errLvl, "failed to delete secondary lock while releasing expired resources")
+ }
+
+ }() //release secondary lock once finished
+
+ lockedResourcesMap, err := s.redis.client.HGetAll(s.lockedResourcesName())
+ if err != nil {
+ return errors.Wrapf(err, "failed to get all locked resources")
+ }
+
+ if len(lockedResourcesMap) == 0 {
+ s.options.Logger.Log(debugLvl, "all semaphore locks are free")
+ }
+
+ var expiredResources []string
+
+ for resource, lockedAtStr := range lockedResourcesMap {
+
+ lockedAt, err := strconv.ParseInt(lockedAtStr, 10, 64)
+ if err != nil {
+ return errors.Wrapf(err, "failed to parse locked resource Expiration time %v as integer", lockedAtStr)
+ }
+
+ expireAt := lockedAt + s.options.Expiration.Nanoseconds()
+
+ if expireAt <= (time.Now().UnixNano()) { //resource is using lock more than Expiration time - release index
+
+ expiredResources = append(expiredResources, resource)
+
+ s.options.Logger.WithFields(map[string]interface{}{
+ "resource": resource,
+ "locked_at": time.Unix(0, lockedAt).Format("15:04:05.000"),
+ "expired_at": time.Unix(0, expireAt).Format("15:04:05.000"),
+ }).Log(infoLvl, "found resource with expired lock - performing unlock")
+ } else {
+ s.options.Logger.WithFields(map[string]interface{}{
+ "resource": resource,
+ "locked_at": time.Unix(0, lockedAt).Format("15:04:05.000"),
+ "will_expire_at": time.Unix(0, expireAt).Format("15:04:05.000"),
+ }).Log(infoLvl, "resource is locked and not expired yet")
+ }
+ }
+
+ if len(expiredResources) > 0 {
+ return s.unlock(expiredResources...)
+ }
+
+ return nil
+}
diff --git a/semaphore/semaphore-logger/empty_logger_impl.go b/semaphore/semaphore-logger/empty_logger_impl.go
new file mode 100644
index 0000000..148b3d3
--- /dev/null
+++ b/semaphore/semaphore-logger/empty_logger_impl.go
@@ -0,0 +1,15 @@
+package semaphorelogger
+
+type emptyLoggerImpl struct {
+}
+
+func NewEmptyLogger() Logger {
+ return &emptyLoggerImpl{}
+}
+
+func (l *emptyLoggerImpl) WithFields(fields map[string]interface{}) Logger {
+ return l
+}
+
+func (l *emptyLoggerImpl) Log(level LogLevel, format string, args ...interface{}) {
+}
diff --git a/semaphore/semaphore-logger/logger_interface.go b/semaphore/semaphore-logger/logger_interface.go
new file mode 100644
index 0000000..54e40ae
--- /dev/null
+++ b/semaphore/semaphore-logger/logger_interface.go
@@ -0,0 +1,15 @@
+package semaphorelogger
+
+type LogLevel int
+
+const (
+ LogLevelError LogLevel = iota
+ LogLevelInfo
+ LogLevelDebug
+)
+
+//go:generate mockgen -source=./semaphore/semaphore-logger/logger_interface.go -destination=./semaphore/semaphore-logger/mock/logger_interface_mock.go Logger
+type Logger interface {
+ WithFields(fields map[string]interface{}) Logger
+ Log(level LogLevel, format string, args ...interface{})
+}
diff --git a/semaphore/semaphore-logger/logrus_impl.go b/semaphore/semaphore-logger/logrus_impl.go
new file mode 100644
index 0000000..c5c3917
--- /dev/null
+++ b/semaphore/semaphore-logger/logrus_impl.go
@@ -0,0 +1,44 @@
+package semaphorelogger
+
+import (
+ "fmt"
+
+ "github.com/sirupsen/logrus"
+)
+
+type logrusImpl struct {
+ logger *logrus.Entry
+ bindingKey string
+}
+
+func NewLogrusLogger(logger *logrus.Logger, level LogLevel, bindingKey string) Logger {
+ if logger == nil {
+ logger = logrus.StandardLogger()
+ }
+
+ logger.SetLevel(parseLevel(level))
+
+ return &logrusImpl{logger: logger.WithField("binding_key", bindingKey), bindingKey: bindingKey}
+}
+
+func (l *logrusImpl) WithFields(fields map[string]interface{}) Logger {
+ return &logrusImpl{logger: l.logger.WithFields(logrus.Fields(fields)).WithField("binding_key", l.bindingKey), bindingKey: l.bindingKey}
+}
+
+func (l *logrusImpl) Log(level LogLevel, format string, args ...interface{}) {
+ l.logger.Log(parseLevel(level), fmt.Sprintf(format, args...))
+}
+
+func parseLevel(level LogLevel) logrus.Level {
+ switch level {
+
+ case LogLevelError:
+ return logrus.ErrorLevel
+ case LogLevelInfo:
+ return logrus.InfoLevel
+ case LogLevelDebug:
+ return logrus.DebugLevel
+ default:
+ return logrus.TraceLevel
+ }
+}
diff --git a/semaphore/semaphore-logger/mock/logger_interface_mock.go b/semaphore/semaphore-logger/mock/logger_interface_mock.go
new file mode 100644
index 0000000..d0df685
--- /dev/null
+++ b/semaphore/semaphore-logger/mock/logger_interface_mock.go
@@ -0,0 +1,80 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: ./semaphore/semaphore-logger/logger_interface.go
+
+// Package mock_semaphorelogger is a generated GoMock package.
+package mock_semaphorelogger
+
+import (
+ reflect "reflect"
+
+ gomock "github.com/golang/mock/gomock"
+ semaphore_logger "github.com/gtforge/redis-semaphore-go/semaphore/semaphore-logger"
+)
+
+// MockLogger is a mock of Logger interface
+type MockLogger struct {
+ ctrl *gomock.Controller
+ recorder *MockLoggerMockRecorder
+}
+
+// MockLoggerMockRecorder is the mock recorder for MockLogger
+type MockLoggerMockRecorder struct {
+ mock *MockLogger
+}
+
+// NewMockLogger creates a new mock instance
+func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
+ mock := &MockLogger{ctrl: ctrl}
+ mock.recorder = &MockLoggerMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use
+func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
+ return m.recorder
+}
+
+// GetLevel mocks base method
+func (m *MockLogger) GetLevel() semaphore_logger.LogLevel {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetLevel")
+ ret0, _ := ret[0].(semaphore_logger.LogLevel)
+ return ret0
+}
+
+// GetLevel indicates an expected call of GetLevel
+func (mr *MockLoggerMockRecorder) GetLevel() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLevel", reflect.TypeOf((*MockLogger)(nil).GetLevel))
+}
+
+// WithFields mocks base method
+func (m *MockLogger) WithFields(fields map[string]interface{}) semaphore_logger.Logger {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "WithFields", fields)
+ ret0, _ := ret[0].(semaphore_logger.Logger)
+ return ret0
+}
+
+// WithFields indicates an expected call of WithFields
+func (mr *MockLoggerMockRecorder) WithFields(fields interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithFields", reflect.TypeOf((*MockLogger)(nil).WithFields), fields)
+}
+
+// Log mocks base method
+func (m *MockLogger) Log(level semaphore_logger.LogLevel, format string, args ...interface{}) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{level, format}
+ for _, a := range args {
+ varargs = append(varargs, a)
+ }
+ m.ctrl.Call(m, "Log", varargs...)
+}
+
+// Log indicates an expected call of Log
+func (mr *MockLoggerMockRecorder) Log(level, format interface{}, args ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{level, format}, args...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Log", reflect.TypeOf((*MockLogger)(nil).Log), varargs...)
+}
diff --git a/semaphore/semaphore-redis/go_redis_impl.go b/semaphore/semaphore-redis/go_redis_impl.go
new file mode 100644
index 0000000..8cc0c76
--- /dev/null
+++ b/semaphore/semaphore-redis/go_redis_impl.go
@@ -0,0 +1,102 @@
+package semaphoreredis
+
+import (
+ "fmt"
+ "time"
+
+ "github.com/go-redis/redis"
+ "github.com/pkg/errors"
+)
+
+type GoRedisImpl struct {
+ Client *redis.Client
+}
+
+func NewGoRedisImpl(client *redis.Client) Redis {
+ return &GoRedisImpl{Client: client}
+}
+
+func (c *GoRedisImpl) Set(key string, value interface{}, expiration time.Duration) error {
+ return c.Client.Set(key, value, expiration).Err()
+}
+
+func (c *GoRedisImpl) SetNX(key string, value interface{}, expiration time.Duration) (bool, error) {
+ return c.Client.SetNX(key, value, expiration).Result()
+}
+
+func (c *GoRedisImpl) Exists(key string) (bool, error) {
+ numExistingKeys, err := c.Client.Exists(key).Result()
+ if err != nil {
+ return false, err
+ }
+
+ return numExistingKeys == 1, nil
+}
+
+func (c *GoRedisImpl) TxPipelined(f func(pipe Pipeline) error) error {
+ _, err := c.Client.TxPipelined(func(pipeline redis.Pipeliner) error {
+ pipe := &goRedisPipelineImpl{
+ pipeline: pipeline,
+ }
+
+ return f(pipe)
+ })
+
+ return err
+}
+
+func (c *GoRedisImpl) BLPop(timeout time.Duration, keys ...string) (string, bool, error) {
+ keyVal, err := c.Client.BLPop(timeout, keys...).Result()
+ if err != nil {
+ if err == redis.Nil {
+ return "", true, nil
+ }
+ return "", false, errors.Wrapf(err, "failed to pop available resource from queue")
+ }
+
+ if len(keyVal) != 2 {
+ return "", false, fmt.Errorf("received unexpected value from redis in response to redis blpop command: %v", keyVal)
+ }
+
+ return keyVal[1], false, nil
+}
+
+func (c *GoRedisImpl) LLen(key string) (int64, error) {
+ return c.Client.LLen(key).Result()
+}
+
+func (c *GoRedisImpl) HSet(key, field, value string) error {
+ return c.Client.HSet(key, field, value).Err()
+}
+
+func (c *GoRedisImpl) HGetAll(key string) (map[string]string, error) {
+ return c.Client.HGetAll(key).Result()
+}
+
+func (c *GoRedisImpl) HExists(key, field string) (bool, error) {
+ return c.Client.HExists(key, field).Result()
+}
+
+func (c *GoRedisImpl) Del(keys ...string) error {
+ return c.Client.Del(keys...).Err()
+}
+
+type goRedisPipelineImpl struct {
+ pipeline redis.Pipeliner
+}
+
+func (c *goRedisPipelineImpl) Del(keys ...string) error {
+ return c.pipeline.Del(keys...).Err()
+}
+
+func (c *goRedisPipelineImpl) RPush(key string, values ...interface{}) error {
+ return c.pipeline.RPush(key, values...).Err()
+}
+
+func (c *goRedisPipelineImpl) HDel(key string, fields ...string) error {
+ return c.pipeline.HDel(key, fields...).Err()
+}
+
+func (c *goRedisPipelineImpl) PExpire(key string, expiration time.Duration) error {
+ return c.pipeline.PExpire(key, expiration).Err()
+}
diff --git a/semaphore/semaphore-redis/mock/redis_interface_mock.go b/semaphore/semaphore-redis/mock/redis_interface_mock.go
new file mode 100644
index 0000000..2726685
--- /dev/null
+++ b/semaphore/semaphore-redis/mock/redis_interface_mock.go
@@ -0,0 +1,285 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: ./semaphore/semaphore-redis/redis_interface.go
+
+// Package mock_semaphoreredis is a generated GoMock package.
+package mock_semaphoreredis
+
+import (
+ reflect "reflect"
+ time "time"
+
+ gomock "github.com/golang/mock/gomock"
+ semaphore_redis "github.com/gtforge/redis-semaphore-go/semaphore/semaphore-redis"
+)
+
+// MockRedis is a mock of Redis interface
+type MockRedis struct {
+ ctrl *gomock.Controller
+ recorder *MockRedisMockRecorder
+}
+
+// MockRedisMockRecorder is the mock recorder for MockRedis
+type MockRedisMockRecorder struct {
+ mock *MockRedis
+}
+
+// NewMockRedis creates a new mock instance
+func NewMockRedis(ctrl *gomock.Controller) *MockRedis {
+ mock := &MockRedis{ctrl: ctrl}
+ mock.recorder = &MockRedisMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use
+func (m *MockRedis) EXPECT() *MockRedisMockRecorder {
+ return m.recorder
+}
+
+// Set mocks base method
+func (m *MockRedis) Set(key string, value interface{}, expiration time.Duration) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Set", key, value, expiration)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Set indicates an expected call of Set
+func (mr *MockRedisMockRecorder) Set(key, value, expiration interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockRedis)(nil).Set), key, value, expiration)
+}
+
+// SetNX mocks base method
+func (m *MockRedis) SetNX(key string, value interface{}, expiration time.Duration) (bool, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "SetNX", key, value, expiration)
+ ret0, _ := ret[0].(bool)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// SetNX indicates an expected call of SetNX
+func (mr *MockRedisMockRecorder) SetNX(key, value, expiration interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNX", reflect.TypeOf((*MockRedis)(nil).SetNX), key, value, expiration)
+}
+
+// Exists mocks base method
+func (m *MockRedis) Exists(key string) (bool, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Exists", key)
+ ret0, _ := ret[0].(bool)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Exists indicates an expected call of Exists
+func (mr *MockRedisMockRecorder) Exists(key interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exists", reflect.TypeOf((*MockRedis)(nil).Exists), key)
+}
+
+// PExpire mocks base method
+func (m *MockRedis) PExpire(key string, expiration time.Duration) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "PExpire", key, expiration)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// PExpire indicates an expected call of PExpire
+func (mr *MockRedisMockRecorder) PExpire(key, expiration interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PExpire", reflect.TypeOf((*MockRedis)(nil).PExpire), key, expiration)
+}
+
+// TxPipelined mocks base method
+func (m *MockRedis) TxPipelined(f func(semaphore_redis.Pipeline) error) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "TxPipelined", f)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// TxPipelined indicates an expected call of TxPipelined
+func (mr *MockRedisMockRecorder) TxPipelined(f interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TxPipelined", reflect.TypeOf((*MockRedis)(nil).TxPipelined), f)
+}
+
+// BLPop mocks base method
+func (m *MockRedis) BLPop(timeout time.Duration, keys ...string) (string, bool, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{timeout}
+ for _, a := range keys {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "BLPop", varargs...)
+ ret0, _ := ret[0].(string)
+ ret1, _ := ret[1].(bool)
+ ret2, _ := ret[2].(error)
+ return ret0, ret1, ret2
+}
+
+// BLPop indicates an expected call of BLPop
+func (mr *MockRedisMockRecorder) BLPop(timeout interface{}, keys ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{timeout}, keys...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BLPop", reflect.TypeOf((*MockRedis)(nil).BLPop), varargs...)
+}
+
+// LLen mocks base method
+func (m *MockRedis) LLen(key string) (int64, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "LLen", key)
+ ret0, _ := ret[0].(int64)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// LLen indicates an expected call of LLen
+func (mr *MockRedisMockRecorder) LLen(key interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LLen", reflect.TypeOf((*MockRedis)(nil).LLen), key)
+}
+
+// HSet mocks base method
+func (m *MockRedis) HSet(key, field, value string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "HSet", key, field, value)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// HSet indicates an expected call of HSet
+func (mr *MockRedisMockRecorder) HSet(key, field, value interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HSet", reflect.TypeOf((*MockRedis)(nil).HSet), key, field, value)
+}
+
+// HGetAll mocks base method
+func (m *MockRedis) HGetAll(key string) (map[string]string, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "HGetAll", key)
+ ret0, _ := ret[0].(map[string]string)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// HGetAll indicates an expected call of HGetAll
+func (mr *MockRedisMockRecorder) HGetAll(key interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HGetAll", reflect.TypeOf((*MockRedis)(nil).HGetAll), key)
+}
+
+// HExists mocks base method
+func (m *MockRedis) HExists(key, field string) (bool, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "HExists", key, field)
+ ret0, _ := ret[0].(bool)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// HExists indicates an expected call of HExists
+func (mr *MockRedisMockRecorder) HExists(key, field interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HExists", reflect.TypeOf((*MockRedis)(nil).HExists), key, field)
+}
+
+// Del mocks base method
+func (m *MockRedis) Del(keys ...string) error {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{}
+ for _, a := range keys {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "Del", varargs...)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Del indicates an expected call of Del
+func (mr *MockRedisMockRecorder) Del(keys ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Del", reflect.TypeOf((*MockRedis)(nil).Del), keys...)
+}
+
+// MockPipeline is a mock of Pipeline interface
+type MockPipeline struct {
+ ctrl *gomock.Controller
+ recorder *MockPipelineMockRecorder
+}
+
+// MockPipelineMockRecorder is the mock recorder for MockPipeline
+type MockPipelineMockRecorder struct {
+ mock *MockPipeline
+}
+
+// NewMockPipeline creates a new mock instance
+func NewMockPipeline(ctrl *gomock.Controller) *MockPipeline {
+ mock := &MockPipeline{ctrl: ctrl}
+ mock.recorder = &MockPipelineMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use
+func (m *MockPipeline) EXPECT() *MockPipelineMockRecorder {
+ return m.recorder
+}
+
+// Del mocks base method
+func (m *MockPipeline) Del(keys ...string) error {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{}
+ for _, a := range keys {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "Del", varargs...)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Del indicates an expected call of Del
+func (mr *MockPipelineMockRecorder) Del(keys ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Del", reflect.TypeOf((*MockPipeline)(nil).Del), keys...)
+}
+
+// HDel mocks base method
+func (m *MockPipeline) HDel(key string, fields ...string) error {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{key}
+ for _, a := range fields {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "HDel", varargs...)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// HDel indicates an expected call of HDel
+func (mr *MockPipelineMockRecorder) HDel(key interface{}, fields ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{key}, fields...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HDel", reflect.TypeOf((*MockPipeline)(nil).HDel), varargs...)
+}
+
+// RPush mocks base method
+func (m *MockPipeline) RPush(key string, values ...interface{}) error {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{key}
+ for _, a := range values {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "RPush", varargs...)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// RPush indicates an expected call of RPush
+func (mr *MockPipelineMockRecorder) RPush(key interface{}, values ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{key}, values...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RPush", reflect.TypeOf((*MockPipeline)(nil).RPush), varargs...)
+}
diff --git a/semaphore/semaphore-redis/redis_interface.go b/semaphore/semaphore-redis/redis_interface.go
new file mode 100644
index 0000000..5c3970c
--- /dev/null
+++ b/semaphore/semaphore-redis/redis_interface.go
@@ -0,0 +1,26 @@
+package semaphoreredis
+
+import (
+ "time"
+)
+
+//go:generate mockgen -source=./semaphore/semaphore-redis/redis_interface.go -destination=./semaphore/semaphore-redis/mock/redis_interface_mock.go Redis
+type Redis interface {
+ Set(key string, value interface{}, expiration time.Duration) error
+ SetNX(key string, value interface{}, expiration time.Duration) (isSet bool, err error)
+ Exists(key string) (bool, error)
+ TxPipelined(f func(pipe Pipeline) error) error
+ BLPop(timeout time.Duration, keys ...string) (val string, isTimedOut bool, err error)
+ LLen(key string) (len int64, err error)
+ HSet(key, field, value string) error
+ HGetAll(key string) (keyVal map[string]string, err error)
+ HExists(key, field string) (isExists bool, err error)
+ Del(keys ...string) error
+}
+
+type Pipeline interface {
+ Del(keys ...string) error
+ HDel(key string, fields ...string) error
+ RPush(key string, values ...interface{}) error
+ PExpire(key string, expiration time.Duration) error
+}
diff --git a/semaphore/semaphore-redis/redis_v5_impl.go b/semaphore/semaphore-redis/redis_v5_impl.go
new file mode 100644
index 0000000..9e55160
--- /dev/null
+++ b/semaphore/semaphore-redis/redis_v5_impl.go
@@ -0,0 +1,97 @@
+package semaphoreredis
+
+import (
+ "fmt"
+ "time"
+
+ "github.com/pkg/errors"
+ "gopkg.in/redis.v5"
+)
+
+type RedisV5Impl struct {
+ Client *redis.Client
+}
+
+func NewRedisV5Client(client *redis.Client) Redis {
+ return &RedisV5Impl{Client: client}
+}
+
+func (c *RedisV5Impl) Set(key string, value interface{}, expiration time.Duration) error {
+ return c.Client.Set(key, value, expiration).Err()
+}
+
+func (c *RedisV5Impl) SetNX(key string, value interface{}, expiration time.Duration) (bool, error) {
+ return c.Client.SetNX(key, value, expiration).Result()
+}
+
+func (c *RedisV5Impl) Exists(key string) (bool, error) {
+ return c.Client.Exists(key).Result()
+}
+
+func (c *RedisV5Impl) TxPipelined(f func(pipe Pipeline) error) error {
+ _, err := c.Client.TxPipelined(func(pipeline *redis.Pipeline) error {
+ pipe := &redisV5PipelineImpl{
+ pipeline: pipeline,
+ }
+
+ return f(pipe)
+ })
+
+ return err
+}
+
+func (c *RedisV5Impl) BLPop(timeout time.Duration, keys ...string) (string, bool, error) {
+ keyVal, err := c.Client.BLPop(timeout, keys...).Result()
+ if err != nil {
+ if err == redis.Nil {
+ return "", true, nil
+ }
+ return "", false, errors.Wrapf(err, "failed to pop available resource from queue")
+ }
+
+ if len(keyVal) != 2 {
+ return "", false, fmt.Errorf("received unexpected value from redis in response to redis blpop command: %v", keyVal)
+ }
+
+ return keyVal[1], false, nil
+}
+
+func (c *RedisV5Impl) LLen(key string) (int64, error) {
+ return c.Client.LLen(key).Result()
+}
+
+func (c *RedisV5Impl) HSet(key, field, value string) error {
+ return c.Client.HSet(key, field, value).Err()
+}
+
+func (c *RedisV5Impl) HGetAll(key string) (map[string]string, error) {
+ return c.Client.HGetAll(key).Result()
+}
+
+func (c *RedisV5Impl) HExists(key, field string) (bool, error) {
+ return c.Client.HExists(key, field).Result()
+}
+
+func (c *RedisV5Impl) Del(keys ...string) error {
+ return c.Client.Del(keys...).Err()
+}
+
+type redisV5PipelineImpl struct {
+ pipeline *redis.Pipeline
+}
+
+func (c *redisV5PipelineImpl) Del(keys ...string) error {
+ return c.pipeline.Del(keys...).Err()
+}
+
+func (c *redisV5PipelineImpl) RPush(key string, values ...interface{}) error {
+ return c.pipeline.RPush(key, values...).Err()
+}
+
+func (c *redisV5PipelineImpl) HDel(key string, fields ...string) error {
+ return c.pipeline.HDel(key, fields...).Err()
+}
+
+func (c *redisV5PipelineImpl) PExpire(key string, expiration time.Duration) error {
+ return c.pipeline.PExpire(key, expiration).Err()
+}
diff --git a/semaphore/semaphore.go b/semaphore/semaphore.go
new file mode 100644
index 0000000..e06d66d
--- /dev/null
+++ b/semaphore/semaphore.go
@@ -0,0 +1,12 @@
+package semaphore
+
+import "time"
+
+//go:generate mockgen -source=./semaphore/semaphore.go -destination=./semaphore/mock/semaphore_mock.go Semaphore
+type Semaphore interface {
+ Lock() (string, error)
+ LockWithCustomTimeout(timeout time.Duration) (string, error)
+ Unlock(resource string) error
+ IsResourceLocked(resource string) (bool, error)
+ GetNumAvailableResources() (int64, error)
+}
diff --git a/semaphore/semaphore_test.go b/semaphore/semaphore_test.go
new file mode 100644
index 0000000..ebbb645
--- /dev/null
+++ b/semaphore/semaphore_test.go
@@ -0,0 +1,775 @@
+package semaphore
+
+import (
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/gtforge/redis-semaphore-go/semaphore/semaphore-logger"
+
+ "github.com/pkg/errors"
+
+ "github.com/tylerb/gls"
+
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+)
+
+func TestSemaphore(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "semaphore suit")
+}
+
+var _ = Describe("semaphore Tests", func() {
+
+ var (
+ s *semaphore
+ logger Logger
+ redis testRedis
+ lockByKey string
+ err error
+ )
+
+ BeforeEach(func() {
+ redis = redisClient
+ })
+
+ AfterEach(func() {
+ Expect(redis.FlushAll()).To(Succeed())
+ })
+
+ var _ = Describe(".New", func() {
+
+ var (
+ res *semaphore
+ overrides []Options
+ )
+
+ JustBeforeEach(func() {
+ res, err = create(lockByKey, redis, overrides...)
+ })
+
+ Context("when lock by key is empty", func() {
+ BeforeEach(func() {
+ lockByKey = ""
+ })
+
+ It("should return error", func() {
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("lock by key field must be non empty"))
+ Expect(res).To(BeZero())
+ })
+ })
+
+ Context("when lock by key is not empty", func() {
+ BeforeEach(func() {
+ lockByKey = "semaphore_constructor_test"
+ })
+
+ Context("when override Options is empty struct", func() {
+ BeforeEach(func() {
+ overrides = []Options{}
+ })
+
+ It("should create semaphore with default Options", func() {
+ Expect(err).To(Succeed())
+ Expect(res.options).To(Equal(defaultOptions))
+ })
+
+ It("should not create any key in redis", func() {
+ for _, key := range res.redis.keys {
+ exists, err := redis.Exists(key)
+ Expect(err).To(Succeed())
+ Expect(exists).To(BeFalse())
+ }
+ })
+ })
+
+ Context("when given more than one override Options", func() {
+ BeforeEach(func() {
+ overrides = []Options{{}, {MaxParallelResources: 3}}
+ })
+
+ It("should ignore all overrides except the first", func() {
+ Expect(err).To(Succeed())
+ Expect(res.options).To(Equal(defaultOptions))
+ })
+ })
+
+ Context("when given exactly one non empty override Options", func() {
+
+ var overrideSetting = Options{
+ Expiration: 1 * time.Minute,
+ TryLockTimeout: 30 * time.Second,
+ MaxParallelResources: 3,
+ LockAttempts: 2,
+ Logger: semaphorelogger.NewLogrusLogger(nil, debugLvl, lockByKey),
+ }
+
+ BeforeEach(func() {
+ overrides = []Options{overrideSetting}
+ })
+
+ Context("when overriding expiration with invalid value", func() {
+ BeforeEach(func() {
+ overrides[0].Expiration = time.Millisecond
+ })
+
+ It("should return error", func() {
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("expiration time must be at least 1 second"))
+ Expect(res).To(BeZero())
+ })
+ })
+
+ Context("when overriding expiration with valid value", func() {
+ BeforeEach(func() {
+ overrides[0].Expiration = overrideSetting.Expiration
+ })
+
+ Context("when overriding try lock timeout with invalid value", func() {
+ BeforeEach(func() {
+ overrides[0].TryLockTimeout = overrides[0].Expiration + 1
+ })
+
+ It("should return error", func() {
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("try lock timeout must be at least 1 second and smaller or equal to semaphore Expiration time"))
+ Expect(res).To(BeZero())
+ })
+ })
+
+ Context("when overriding try lock timeout with valid value", func() {
+
+ Context("when overriding max parallel resources with invalid value", func() {
+ BeforeEach(func() {
+ overrides[0].MaxParallelResources = -2
+ })
+
+ It("should return error", func() {
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("max parallel resources setting must be positive number"))
+ Expect(res).To(BeZero())
+ })
+ })
+
+ Context("when overriding max parallel resources with valid value", func() {
+ BeforeEach(func() {
+ overrides[0].MaxParallelResources = overrideSetting.MaxParallelResources
+ })
+
+ Context("when overriding lock attempts with invalid value", func() {
+ BeforeEach(func() {
+ overrides[0].LockAttempts = -1
+ })
+
+ It("should return error", func() {
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("lock attempts setting must be positive number"))
+ Expect(res).To(BeZero())
+ })
+ })
+
+ Context("when overriding lock attempts with valid value", func() {
+ BeforeEach(func() {
+ overrides[0].LockAttempts = overrideSetting.LockAttempts
+ })
+
+ Context("when not providing logger", func() {
+ BeforeEach(func() {
+ overrides[0].Logger = nil
+ })
+
+ It("should create semaphore with override options and empty logger", func() {
+ Expect(err).To(Succeed())
+ Expect(res.options.TryLockTimeout).To(Equal(overrideSetting.TryLockTimeout))
+ Expect(res.options.LockAttempts).To(Equal(overrideSetting.LockAttempts))
+ Expect(res.options.MaxParallelResources).To(Equal(overrideSetting.MaxParallelResources))
+ Expect(res.options.Logger).To(Equal(semaphorelogger.NewEmptyLogger()))
+
+ })
+ })
+
+ Context("when providing logger", func() {
+ BeforeEach(func() {
+ overrides[0].Logger = overrideSetting.Logger
+ })
+
+ It("should create semaphore with override options & logger", func() {
+ Expect(err).To(Succeed())
+ Expect(res.options).To(Equal(overrideSetting))
+ })
+ })
+ })
+ })
+ })
+ })
+ })
+ })
+ })
+
+ var _ = Describe(".WithMutex", func() {
+
+ var (
+ override Options
+ funcDuration time.Duration
+ functionCalledCounter int64
+ numLocksToPerform int64
+ wg sync.WaitGroup
+ done chan bool
+ resource string
+ )
+
+ BeforeEach(func() {
+ lockByKey = "with_mutex_test_key"
+ logger = semaphorelogger.NewLogrusLogger(nil, debugLvl, lockByKey)
+ override = Options{TryLockTimeout: time.Second, Expiration: 3 * time.Second, MaxParallelResources: 5, Logger: logger}
+ functionCalledCounter = 0
+ done = make(chan bool)
+ wg = sync.WaitGroup{}
+ })
+
+ BeforeEach(func() {
+ s, err = create(lockByKey, redis, override)
+ Expect(err).To(Succeed())
+ })
+
+ JustBeforeEach(func() {
+ gls.Go(func() {
+ time.Sleep(50 * time.Millisecond)
+ err = WithMutex(lockByKey, redis, func() { atomic.AddInt64(&functionCalledCounter, 1) }, []Options{override}...)
+ for i := 0; i < int(atomic.LoadInt64(&numLocksToPerform)); i++ {
+ wg.Done()
+ }
+ })
+ })
+
+ AfterEach(func() {
+ Expect(redis.FlushAll()).To(Succeed())
+ })
+
+ Context("when performing only a single lock", func() {
+ BeforeEach(func() {
+ addLocks(&wg, &numLocksToPerform, 1)
+ })
+
+ It("should lock, run func, unlock and return no error", func() {
+ wg.Wait()
+ Expect(err).To(Succeed())
+ Expect(atomic.LoadInt64(&functionCalledCounter)).To(Equal(atomic.LoadInt64(&numLocksToPerform)))
+ validateNumFreeLocks(s, s.options.MaxParallelResources)
+ })
+ })
+
+ Context("when all locks are taken", func() {
+ BeforeEach(func() {
+ for i := 0; i < int(s.options.MaxParallelResources); i++ {
+ resource, err = s.Lock()
+ Expect(err).To(Succeed())
+ atomic.AddInt64(&functionCalledCounter, 1)
+ }
+ addLocks(&wg, &numLocksToPerform, s.options.MaxParallelResources)
+ })
+
+ Context("when no lock released while waiting to acquire lock", func() {
+ BeforeEach(func() {
+ // do nothing
+ })
+
+ It("should not lock the last function and return timeout error", func() {
+ wg.Wait()
+ Expect(err).To(HaveOccurred())
+ Expect(errors.Cause(err)).To(Equal(TimeoutError))
+ })
+ })
+
+ Context("when one lock released while waiting to acquire lock", func() {
+ JustBeforeEach(func() {
+ Expect(s.Unlock(resource)).To(Succeed())
+ atomic.AddInt64(&functionCalledCounter, -1)
+ })
+
+ It("should lock, run func, unlock and return no error", func() {
+ wg.Wait()
+ Expect(err).To(Succeed())
+ Expect(atomic.LoadInt64(&functionCalledCounter)).To(Equal(atomic.LoadInt64(&numLocksToPerform)))
+ validateNumFreeLocks(s, 1)
+ })
+ })
+ })
+
+ Context("when function duration < try lock timeout time", func() {
+ BeforeEach(func() {
+ funcDuration = s.options.TryLockTimeout / 2
+ })
+
+ Context("when num simultaneous locks <= max parallel resources", func() {
+ BeforeEach(func() {
+ runLocksInParallel(s, &wg, funcDuration, &numLocksToPerform, &functionCalledCounter, s.options.MaxParallelResources, done)
+ })
+
+ It("should lock, run func, unlock and return no error", func() {
+ for i := 0; i < int(atomic.LoadInt64(&numLocksToPerform))-1; i++ {
+ <-done
+ }
+ Expect(err).To(Succeed())
+ Expect(atomic.LoadInt64(&functionCalledCounter)).To(Equal(atomic.LoadInt64(&numLocksToPerform)))
+ validateNumFreeLocks(s, s.options.MaxParallelResources)
+ })
+ })
+
+ Context("when num simultaneous locks > max parallel resources", func() {
+ BeforeEach(func() {
+ runLocksInParallel(s, &wg, funcDuration, &numLocksToPerform, &functionCalledCounter, s.options.MaxParallelResources+1, done)
+ })
+
+ It("should not lock the last function and return timeout error", func() {
+ for i := 0; i < int(atomic.LoadInt64(&numLocksToPerform))-1; i++ {
+ <-done
+ }
+ Expect(err).To(HaveOccurred())
+ Expect(errors.Cause(err)).To(Equal(TimeoutError))
+ Expect(atomic.LoadInt64(&functionCalledCounter)).To(Equal(atomic.LoadInt64(&numLocksToPerform) - 1))
+ validateNumFreeLocks(s, s.options.MaxParallelResources)
+ })
+ })
+ })
+
+ Context("when function duration > try lock timeout time", func() {
+ BeforeEach(func() {
+ funcDuration = s.options.TryLockTimeout * 2
+ })
+
+ Context("when num simultaneous locks <= max parallel resources", func() {
+ BeforeEach(func() {
+ runLocksInParallel(s, &wg, funcDuration, &numLocksToPerform, &functionCalledCounter, s.options.MaxParallelResources-1, done)
+ })
+
+ It("should not lock the last function and return timeout error", func() {
+ for i := 0; i < int(atomic.LoadInt64(&numLocksToPerform))-1; i++ {
+ <-done
+ }
+ Expect(err).To(Succeed())
+ Expect(atomic.LoadInt64(&functionCalledCounter)).To(Equal(atomic.LoadInt64(&numLocksToPerform)))
+ validateNumFreeLocks(s, s.options.MaxParallelResources)
+ })
+ })
+
+ Context("when num simultaneous locks > max parallel resources", func() {
+ BeforeEach(func() {
+ runLocksInParallel(s, &wg, funcDuration, &numLocksToPerform, &functionCalledCounter, s.options.MaxParallelResources+1, done)
+ })
+
+ It("should not lock the last function and return timeout error", func() {
+ for i := 0; i < int(atomic.LoadInt64(&numLocksToPerform))-1; i++ {
+ <-done
+ }
+ Expect(err).To(HaveOccurred())
+ Expect(errors.Cause(err)).To(Equal(TimeoutError))
+ Expect(atomic.LoadInt64(&functionCalledCounter)).To(Equal(atomic.LoadInt64(&numLocksToPerform) - 1))
+ validateNumFreeLocks(s, s.options.MaxParallelResources)
+ })
+ })
+ })
+ })
+
+ Describe(".Lock", func() {
+
+ var (
+ numLockedResources int64
+ res string
+ )
+
+ BeforeEach(func() {
+ lockByKey = "lock_test_key"
+ logger = semaphorelogger.NewLogrusLogger(nil, debugLvl, lockByKey)
+ s, err = create(lockByKey, redis, []Options{{TryLockTimeout: time.Second, Expiration: 2 * time.Second, MaxParallelResources: 3, LockAttempts: 2, Logger: logger}}...)
+ Expect(err).To(Succeed())
+ })
+
+ BeforeEach(func() {
+ numLockedResources = 0
+ })
+
+ JustBeforeEach(func() {
+ res, err = s.Lock()
+ numLockedResources++
+ })
+
+ Context("when we use this semaphore for the first time", func() {
+ BeforeEach(func() {
+ numLockedResources = 0
+ })
+
+ It("should lock resource and return no error", func() {
+ Expect(err).To(Succeed())
+ validateSuccessfulLock(s, redis, res, numLockedResources)
+ })
+ })
+
+ Context("when this semaphore is already in use", func() {
+
+ Context("when there is at least one free resource", func() {
+ BeforeEach(func() {
+ for i := 0; i < int(s.options.MaxParallelResources)-1; i++ {
+ _, err = s.Lock()
+ Expect(err).To(Succeed())
+ numLockedResources++
+ }
+ })
+
+ It("should lock resource and return no error", func() {
+ Expect(err).To(Succeed())
+ validateSuccessfulLock(s, redis, res, numLockedResources)
+ })
+ })
+
+ Context("when all resources are taken and not expired", func() {
+ BeforeEach(func() {
+ for i := 0; i < int(s.options.MaxParallelResources); i++ {
+ _, err = s.Lock()
+ Expect(err).To(Succeed())
+ numLockedResources++
+ }
+ })
+
+ It("should reach timeout and return timeout error", func() {
+ Expect(err).To(HaveOccurred())
+ Expect(err).To(Equal(TimeoutError))
+ Expect(res).To(BeZero())
+ })
+ })
+
+ Context("when reach timeout on first attempt but some resources expired and succeed on second attempt", func() {
+ BeforeEach(func() {
+ for i := 0; i < int(s.options.MaxParallelResources); i++ {
+ _, err = s.Lock()
+ Expect(err).To(Succeed())
+ numLockedResources++
+ time.Sleep(s.options.Expiration / time.Duration(s.options.MaxParallelResources+1))
+ }
+ })
+
+ It("should release expired resources, lock resource and return no error", func() {
+ Expect(err).To(Succeed())
+ validateSuccessfulLock(s, redis, res, numLockedResources-2) //two resources will be expired after first attempt
+ })
+ })
+ })
+ })
+
+ Describe(".LockWithCustomTimeout", func() {
+
+ var (
+ timeout time.Duration
+ res string
+ )
+
+ BeforeEach(func() {
+ lockByKey = "lock_with_custom_timeout_test_key"
+ logger = semaphorelogger.NewLogrusLogger(nil, debugLvl, lockByKey)
+ s, err = create(lockByKey, redis, Options{Logger: logger})
+ Expect(err).To(Succeed())
+ })
+
+ JustBeforeEach(func() {
+ res, err = s.LockWithCustomTimeout(timeout)
+ })
+
+ Context("when timeout is smaller than 1 second", func() {
+ BeforeEach(func() {
+ timeout = 0
+ })
+
+ It("should return error", func() {
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("try lock timeout must be at least 1 second and smaller or equal to semaphore Expiration time"))
+ })
+ })
+
+ Context("when timeout is greater than semaphore Expiration time", func() {
+ BeforeEach(func() {
+ timeout = s.options.Expiration + 1
+ })
+
+ It("should return error", func() {
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("try lock timeout must be at least 1 second and smaller or equal to semaphore Expiration time"))
+ })
+ })
+
+ Context("when timeout is a positive duration smaller than semaphore Expiration time", func() {
+ BeforeEach(func() {
+ timeout = s.options.Expiration - 1
+ })
+
+ It("should perform lock operation", func() {
+ Expect(err).To(Succeed())
+ Expect(res).To(Not(BeEmpty()))
+ })
+ })
+ })
+
+ Describe(".Unlock", func() {
+
+ var resource string
+
+ BeforeEach(func() {
+ lockByKey = "unlock_test_key"
+ logger = semaphorelogger.NewLogrusLogger(nil, debugLvl, lockByKey)
+ s, err = create(lockByKey, redis, []Options{{MaxParallelResources: 3, Logger: logger}}...)
+ Expect(err).To(Succeed())
+ })
+
+ JustBeforeEach(func() {
+ err = s.Unlock(resource)
+ })
+
+ Context("when resource is not locked", func() {
+ BeforeEach(func() {
+ // do nothing
+ })
+
+ It("should do nothing and return no error", func() {
+ Expect(err).To(Succeed())
+ })
+ })
+
+ Context("when resource is locked", func() {
+ BeforeEach(func() {
+ resource, err = s.Lock()
+ Expect(err).To(Succeed())
+ })
+
+ It("should unlock resource and return no error", func() {
+ Expect(err).To(Succeed())
+ validateSuccessfulUnLock(s, redis, resource, 0)
+ })
+ })
+ })
+
+ Describe(".IsResourceLocked", func() {
+
+ var (
+ checkedResource, lockedResource string
+ res bool
+ )
+
+ BeforeEach(func() {
+ lockByKey = "is_resource_locked_test"
+ logger = semaphorelogger.NewLogrusLogger(nil, debugLvl, lockByKey)
+ s, err = create(lockByKey, redis, []Options{{TryLockTimeout: time.Second, Expiration: time.Second, Logger: logger}}...)
+ Expect(err).To(Succeed())
+ })
+
+ JustBeforeEach(func() {
+ res, err = s.IsResourceLocked(checkedResource)
+ })
+
+ Context("when semaphore not being locked yet", func() {
+ BeforeEach(func() {
+ // do nothing
+ })
+
+ It("should return false", func() {
+ Expect(err).To(Succeed())
+ Expect(res).To(BeFalse())
+ })
+ })
+
+ Context("when there are locked resources", func() {
+ BeforeEach(func() {
+ lockedResource, err = s.Lock()
+ Expect(err).To(Succeed())
+ })
+
+ Context("when locked resource != checked resource", func() {
+ BeforeEach(func() {
+ checkedResource = lockedResource + "1"
+ })
+
+ It("should return false", func() {
+ Expect(err).To(Succeed())
+ Expect(res).To(BeFalse())
+ })
+ })
+
+ Context("when locked resource = checked resource", func() {
+ BeforeEach(func() {
+ checkedResource = lockedResource
+ })
+
+ Context("when semaphore name key has expired", func() {
+
+ Context("when semaphore locked queue key has not expired yet", func() {
+ BeforeEach(func() {
+ time.Sleep(s.options.Expiration)
+ })
+
+ It("should return true", func() {
+ Expect(err).To(Succeed())
+ Expect(res).To(BeTrue())
+ })
+ })
+
+ Context("when semaphore locked queue key has also expired", func() {
+ BeforeEach(func() {
+ time.Sleep(s.options.Expiration * 2)
+ })
+
+ It("should return false", func() {
+ Expect(err).To(Succeed())
+ Expect(res).To(BeFalse())
+ })
+ })
+ })
+
+ Context("when semaphore name key has not expired", func() {
+ BeforeEach(func() {
+ // do nothing
+ })
+
+ It("should return true", func() {
+ Expect(err).To(Succeed())
+ Expect(res).To(BeTrue())
+ })
+ })
+ })
+ })
+ })
+
+ Describe(".GetNumAvailableResources", func() {
+
+ var (
+ semaphoreSize, numLockedResources int64
+ res int64
+ )
+
+ BeforeEach(func() {
+ semaphoreSize = 4
+ numLockedResources = 2
+ })
+
+ BeforeEach(func() {
+ lockByKey = "num_available_resources_test_key"
+ logger = semaphorelogger.NewLogrusLogger(nil, debugLvl, lockByKey)
+ s, err = create(lockByKey, redis, []Options{{TryLockTimeout: time.Second, Expiration: time.Second, MaxParallelResources: semaphoreSize, Logger: logger}}...)
+ Expect(err).To(Succeed())
+ })
+
+ JustBeforeEach(func() {
+ res, err = s.GetNumAvailableResources()
+ })
+
+ Context("when semaphore not locked yet", func() {
+ BeforeEach(func() {
+ // do nothing
+ })
+
+ It("should return semaphore size", func() {
+ Expect(err).To(Succeed())
+ Expect(res).To(Equal(semaphoreSize))
+ })
+ })
+
+ Context("when semaphore locked before", func() {
+ BeforeEach(func() {
+ for i := 0; i < int(numLockedResources); i++ {
+ _, err = s.Lock()
+ Expect(err).To(Succeed())
+ }
+ })
+
+ Context("when semaphore has not expired", func() {
+ BeforeEach(func() {
+ // do nothing
+ })
+
+ It("should return number of free resources", func() {
+ Expect(err).To(Succeed())
+ Expect(res).To(Equal(semaphoreSize - numLockedResources))
+ })
+ })
+
+ Context("when semaphore has expired", func() {
+ BeforeEach(func() {
+ time.Sleep(s.options.Expiration)
+ })
+
+ It("should return semaphore size", func() {
+ Expect(err).To(Succeed())
+ Expect(res).To(Equal(semaphoreSize))
+ })
+ })
+ })
+ })
+})
+
+func runLocksInParallel(s *semaphore, wg *sync.WaitGroup, funcDuration time.Duration, numLocksToPerformAddr, functionCalledCounterAddr *int64, numLocksToPerform int64, done chan bool) {
+ addLocks(wg, numLocksToPerformAddr, numLocksToPerform)
+
+ for i := 0; i < int(numLocksToPerform)-1; i++ {
+ gls.Go(func() {
+ defer GinkgoRecover()
+ Expect(WithMutex(s.lockByKey, s.redis.client, func() { time.Sleep(funcDuration); atomic.AddInt64(functionCalledCounterAddr, 1); wg.Wait() }, s.options)).To(Succeed())
+ done <- true
+ })
+ }
+}
+
+func addLocks(wg *sync.WaitGroup, numLocksToPerformAddr *int64, numLocks int64) {
+ atomic.StoreInt64(numLocksToPerformAddr, numLocks)
+ wg.Add(int(numLocks))
+}
+
+func validateNumFreeLocks(s *semaphore, expectedNumFreeLocks int64) {
+ numFreeLocks, err := s.GetNumAvailableResources()
+ Expect(err).To(Succeed())
+ Expect(numFreeLocks).To(Equal(expectedNumFreeLocks))
+}
+
+func validateSuccessfulLock(s *semaphore, redis testRedis, resource string, numLockedResources int64) {
+ //check resource added to locked resources queue
+ isResourceLocked, err := s.IsResourceLocked(resource)
+ Expect(err).To(Succeed())
+ Expect(isResourceLocked).To(BeTrue())
+
+ //check resource deleted from available resources queue
+ numAvailableResources, err := s.GetNumAvailableResources()
+ Expect(err).To(Succeed())
+ Expect(numAvailableResources).To(Equal(s.options.MaxParallelResources - numLockedResources))
+
+ //check all redis keys Expiration time updated
+ validateExpirationTime(s, redis, numLockedResources)
+}
+
+func validateSuccessfulUnLock(s *semaphore, redis testRedis, resource string, numLockedResources int64) {
+ //check resource deleted from locked resources queue
+ isResourceLocked, err := s.IsResourceLocked(resource)
+ Expect(err).To(Succeed())
+ Expect(isResourceLocked).To(BeFalse())
+
+ //check resource added to available resources queue
+ numAvailableResources, err := s.GetNumAvailableResources()
+ Expect(err).To(Succeed())
+ Expect(numAvailableResources).To(Equal(s.options.MaxParallelResources - numLockedResources))
+
+ //check all redis keys Expiration time updated
+ validateExpirationTime(s, redis, numLockedResources)
+}
+
+func validateExpirationTime(s *semaphore, redis testRedis, numLockedResources int64) {
+ for _, key := range s.redis.keys {
+ ttl, err := redis.TTL(key)
+ Expect(err).To(Succeed())
+
+ if (key == s.availableQueueName() && numLockedResources == s.options.MaxParallelResources) || (key == s.lockedResourcesName() && numLockedResources == 0) { //no items in queue so key deleted
+ Expect(ttl.Seconds()).To(Equal(float64(-2))) //not exists
+ } else if key == s.lockedResourcesName() {
+ Expect(ttl.Seconds()).To(Equal(s.options.Expiration.Seconds() * 2))
+ } else {
+ Expect(ttl.Seconds()).To(Equal(s.options.Expiration.Seconds()))
+ }
+ }
+}
diff --git a/semaphore/test_redis.go b/semaphore/test_redis.go
new file mode 100644
index 0000000..7f1e11f
--- /dev/null
+++ b/semaphore/test_redis.go
@@ -0,0 +1,29 @@
+package semaphore
+
+import (
+ "time"
+
+ "gopkg.in/redis.v5"
+
+ "github.com/gtforge/redis-semaphore-go/semaphore/semaphore-redis"
+)
+
+type testRedis interface {
+ semaphoreredis.Redis
+ TTL(key string) (time.Duration, error)
+ FlushAll() error
+}
+
+type testRedisWrapper struct {
+ semaphoreredis.RedisV5Impl
+}
+
+var redisClient testRedis = &testRedisWrapper{RedisV5Impl: semaphoreredis.RedisV5Impl{Client: redis.NewClient(&redis.Options{Addr: "localhost:6379"})}}
+
+func (w *testRedisWrapper) TTL(key string) (time.Duration, error) {
+ return w.Client.TTL(key).Result()
+}
+
+func (w *testRedisWrapper) FlushAll() error {
+ return w.Client.FlushAll().Err()
+}
diff --git a/semaphore/unlock.go b/semaphore/unlock.go
new file mode 100644
index 0000000..9d2821b
--- /dev/null
+++ b/semaphore/unlock.go
@@ -0,0 +1,61 @@
+package semaphore
+
+import (
+ "fmt"
+
+ "github.com/gtforge/redis-semaphore-go/semaphore/semaphore-redis"
+ "github.com/pkg/errors"
+)
+
+func (s *semaphore) Unlock(resource string) error {
+ s.options.Logger.WithFields(map[string]interface{}{
+ "resource": resource,
+ }).Log(debugLvl, "received unlock request")
+
+ isResourceLocked, err := s.isResourceLocked(resource)
+ if err != nil {
+ return err
+ }
+
+ if !isResourceLocked {
+ s.options.Logger.WithFields(map[string]interface{}{
+ "resource": resource,
+ }).Log(infoLvl, "resource was not locked, no need to unlock")
+ return nil // index was not locked - no need to do anything
+ }
+
+ return s.unlock(resource)
+}
+
+func (s *semaphore) unlock(resources ...string) (err error) {
+ pipeErr := s.redis.client.TxPipelined(func(pipe semaphoreredis.Pipeline) error { //execute redis transaction
+
+ err = pipe.HDel(s.lockedResourcesName(), resources...)
+ if err != nil {
+ return errors.Wrapf(err, "failed to remove resources %+v from locked queue", resources)
+ }
+
+ var resourcesAsInterface []interface{}
+
+ for resource := range resources {
+ resourcesAsInterface = append(resourcesAsInterface, resource)
+ }
+
+ err = pipe.RPush(s.availableQueueName(), resourcesAsInterface...)
+ if err != nil {
+ return errors.Wrapf(err, "failed to add resources %+v to available queue", resources)
+ }
+
+ return s.updateExpirationTime(pipe)
+ })
+
+ if pipeErr != nil {
+ return pipeErr
+ }
+
+ s.options.Logger.WithFields(map[string]interface{}{
+ "resources": fmt.Sprintf("%+v", resources),
+ }).Log(infoLvl, "resources unlocked successfully")
+
+ return nil
+}
diff --git a/semaphore/wrapper.go b/semaphore/wrapper.go
new file mode 100644
index 0000000..684a9ce
--- /dev/null
+++ b/semaphore/wrapper.go
@@ -0,0 +1,31 @@
+package semaphore
+
+import (
+ "github.com/pkg/errors"
+)
+
+func WithMutex(lockByKey string, redisClient Redis, safeCode func(), options ...Options) error {
+ s, err := create(lockByKey, redisClient, options...)
+ if err != nil {
+ return errors.Wrapf(err, "failed to create semaphore %v", s.lockByKey)
+ }
+
+ resource, err := s.Lock()
+ if err != nil {
+ return errors.Wrapf(err, "failed to lock semaphore %v", s.lockByKey)
+ }
+
+ defer func() {
+ err := s.Unlock(resource)
+ if err != nil {
+ s.options.Logger.WithFields(map[string]interface{}{
+ "resource": resource,
+ "error": err,
+ }).Log(errLvl, "failed to unlock resource after critical section execution")
+ }
+ }()
+
+ safeCode() //execute locked function
+
+ return nil
+}