diff --git a/backends/rapidpro/channel.go b/backends/rapidpro/channel.go index 7a6e57862..882759042 100644 --- a/backends/rapidpro/channel.go +++ b/backends/rapidpro/channel.go @@ -14,25 +14,25 @@ import ( // It will return an error if the channel does not exist or is not active. func getChannel(b *backend, channelType courier.ChannelType, channelUUID courier.ChannelUUID) (courier.Channel, error) { // look for the channel locally - channel, localErr := getLocalChannel(channelType, channelUUID) + cachedChannel, localErr := getCachedChannel(channelType, channelUUID) // found it? return it if localErr == nil { - return channel, nil + return cachedChannel, nil } // look in our database instead - dbErr := loadChannelFromDB(b, channel, channelType, channelUUID) + channel, dbErr := loadChannelFromDB(b, channelType, channelUUID) // if it wasn't found in the DB, clear our cache and return that it wasn't found if dbErr == courier.ErrChannelNotFound { clearLocalChannel(channelUUID) - return channel, dbErr + return cachedChannel, dbErr } // if we had some other db error, return it if our cached channel was only just expired if dbErr != nil && localErr == courier.ErrChannelExpired { - return channel, nil + return cachedChannel, nil } // no cached channel, oh well, we fail @@ -41,7 +41,7 @@ func getChannel(b *backend, channelType courier.ChannelType, channelUUID courier } // we found it in the db, cache it locally - cacheLocalChannel(channel) + cacheChannel(channel) return channel, nil } @@ -51,31 +51,33 @@ FROM channels_channel WHERE uuid = $1 AND is_active = true AND org_id IS NOT NULL` // ChannelForUUID attempts to look up the channel with the passed in UUID, returning it -func loadChannelFromDB(b *backend, channel *DBChannel, channelType courier.ChannelType, uuid courier.ChannelUUID) error { +func loadChannelFromDB(b *backend, channelType courier.ChannelType, uuid courier.ChannelUUID) (*DBChannel, error) { + channel := &DBChannel{UUID_: uuid} + // select just the fields we need err := b.db.Get(channel, lookupChannelFromUUIDSQL, uuid) // we didn't find a match if err == sql.ErrNoRows { - return courier.ErrChannelNotFound + return nil, courier.ErrChannelNotFound } // other error if err != nil { - return err + return nil, err } // is it the right type? if channelType != courier.AnyChannelType && channelType != channel.ChannelType() { - return courier.ErrChannelWrongType + return nil, courier.ErrChannelWrongType } // found it, return it - return nil + return channel, nil } -// getLocalChannel returns a Channel object for the passed in type and UUID. -func getLocalChannel(channelType courier.ChannelType, uuid courier.ChannelUUID) (*DBChannel, error) { +// getCachedChannel returns a Channel object for the passed in type and UUID. +func getCachedChannel(channelType courier.ChannelType, uuid courier.ChannelUUID) (*DBChannel, error) { // first see if the channel exists in our local cache cacheMutex.RLock() channel, found := channelCache[uuid] @@ -84,10 +86,10 @@ func getLocalChannel(channelType courier.ChannelType, uuid courier.ChannelUUID) if found { // if it was found but the type is wrong, that's an error if channelType != courier.AnyChannelType && channel.ChannelType() != channelType { - return &DBChannel{ChannelType_: channelType, UUID_: uuid}, courier.ErrChannelWrongType + return nil, courier.ErrChannelWrongType } - // if we've expired, clear our cache and return it + // if we've expired, we return it with an error if channel.expiration.Before(time.Now()) { return channel, courier.ErrChannelExpired } @@ -95,14 +97,13 @@ func getLocalChannel(channelType courier.ChannelType, uuid courier.ChannelUUID) return channel, nil } - return &DBChannel{ChannelType_: channelType, UUID_: uuid}, courier.ErrChannelNotFound + return nil, courier.ErrChannelNotFound } -func cacheLocalChannel(channel *DBChannel) { +func cacheChannel(channel *DBChannel) { // set our expiration channel.expiration = time.Now().Add(localTTL * time.Second) - // first write to our local cache cacheMutex.Lock() channelCache[channel.UUID()] = channel cacheMutex.Unlock()