Skip to content

Commit

Permalink
prevent possibility of concurrent writes to shared db channel by thro…
Browse files Browse the repository at this point in the history
…wing away object during loads
  • Loading branch information
nicpottier committed Sep 6, 2017
1 parent 0903c8d commit e6bb2ac
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions backends/rapidpro/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,25 @@ import (
// It will return an error if the channel does not exist or is not active.
func getChannel(b *backend, channelType courier.ChannelType, channelUUID courier.ChannelUUID) (courier.Channel, error) {
// look for the channel locally
channel, localErr := getLocalChannel(channelType, channelUUID)
cachedChannel, localErr := getCachedChannel(channelType, channelUUID)

// found it? return it
if localErr == nil {
return channel, nil
return cachedChannel, nil
}

// look in our database instead
dbErr := loadChannelFromDB(b, channel, channelType, channelUUID)
channel, dbErr := loadChannelFromDB(b, channelType, channelUUID)

// if it wasn't found in the DB, clear our cache and return that it wasn't found
if dbErr == courier.ErrChannelNotFound {
clearLocalChannel(channelUUID)
return channel, dbErr
return cachedChannel, dbErr
}

// if we had some other db error, return it if our cached channel was only just expired
if dbErr != nil && localErr == courier.ErrChannelExpired {
return channel, nil
return cachedChannel, nil
}

// no cached channel, oh well, we fail
Expand All @@ -41,7 +41,7 @@ func getChannel(b *backend, channelType courier.ChannelType, channelUUID courier
}

// we found it in the db, cache it locally
cacheLocalChannel(channel)
cacheChannel(channel)
return channel, nil
}

Expand All @@ -51,31 +51,33 @@ FROM channels_channel
WHERE uuid = $1 AND is_active = true AND org_id IS NOT NULL`

// ChannelForUUID attempts to look up the channel with the passed in UUID, returning it
func loadChannelFromDB(b *backend, channel *DBChannel, channelType courier.ChannelType, uuid courier.ChannelUUID) error {
func loadChannelFromDB(b *backend, channelType courier.ChannelType, uuid courier.ChannelUUID) (*DBChannel, error) {
channel := &DBChannel{UUID_: uuid}

// select just the fields we need
err := b.db.Get(channel, lookupChannelFromUUIDSQL, uuid)

// we didn't find a match
if err == sql.ErrNoRows {
return courier.ErrChannelNotFound
return nil, courier.ErrChannelNotFound
}

// other error
if err != nil {
return err
return nil, err
}

// is it the right type?
if channelType != courier.AnyChannelType && channelType != channel.ChannelType() {
return courier.ErrChannelWrongType
return nil, courier.ErrChannelWrongType
}

// found it, return it
return nil
return channel, nil
}

// getLocalChannel returns a Channel object for the passed in type and UUID.
func getLocalChannel(channelType courier.ChannelType, uuid courier.ChannelUUID) (*DBChannel, error) {
// getCachedChannel returns a Channel object for the passed in type and UUID.
func getCachedChannel(channelType courier.ChannelType, uuid courier.ChannelUUID) (*DBChannel, error) {
// first see if the channel exists in our local cache
cacheMutex.RLock()
channel, found := channelCache[uuid]
Expand All @@ -84,25 +86,24 @@ func getLocalChannel(channelType courier.ChannelType, uuid courier.ChannelUUID)
if found {
// if it was found but the type is wrong, that's an error
if channelType != courier.AnyChannelType && channel.ChannelType() != channelType {
return &DBChannel{ChannelType_: channelType, UUID_: uuid}, courier.ErrChannelWrongType
return nil, courier.ErrChannelWrongType
}

// if we've expired, clear our cache and return it
// if we've expired, we return it with an error
if channel.expiration.Before(time.Now()) {
return channel, courier.ErrChannelExpired
}

return channel, nil
}

return &DBChannel{ChannelType_: channelType, UUID_: uuid}, courier.ErrChannelNotFound
return nil, courier.ErrChannelNotFound
}

func cacheLocalChannel(channel *DBChannel) {
func cacheChannel(channel *DBChannel) {
// set our expiration
channel.expiration = time.Now().Add(localTTL * time.Second)

// first write to our local cache
cacheMutex.Lock()
channelCache[channel.UUID()] = channel
cacheMutex.Unlock()
Expand Down

0 comments on commit e6bb2ac

Please sign in to comment.