diff --git a/rx.lua b/rx.lua index bef15ac..ffed788 100644 --- a/rx.lua +++ b/rx.lua @@ -1538,6 +1538,23 @@ function Observable:take(n) end local i = 1 + local subscription + + local function unsub() + if subscription then + subscription:unsubscribe() + end + end + + local function onCompleted() + observer:onCompleted() + unsub() + end + + local function onError(e) + observer:onError(e) + unsub() + end local function onNext(...) observer:onNext(...) @@ -1545,19 +1562,12 @@ function Observable:take(n) i = i + 1 if i > n then - observer:onCompleted() + onCompleted() end end - local function onError(e) - return observer:onError(e) - end - - local function onCompleted() - return observer:onCompleted() - end - - return self:subscribe(onNext, onError, onCompleted) + subscription = self:subscribe(onNext, onError, onCompleted) + return Subscription.create(unsub) end) end @@ -1572,6 +1582,13 @@ function Observable:takeLast(count) return Observable.create(function(observer) local buffer = {} + local subscription + + local function unsub() + if subscription then + subscription:unsubscribe() + end + end local function onNext(...) table.insert(buffer, util.pack(...)) @@ -1581,17 +1598,20 @@ function Observable:takeLast(count) end local function onError(message) - return observer:onError(message) + observer:onError(message) + unsub() end local function onCompleted() for i = 1, #buffer do observer:onNext(util.unpack(buffer[i])) end - return observer:onCompleted() + observer:onCompleted() + unsub() end - return self:subscribe(onNext, onError, onCompleted) + subscription = self:subscribe(onNext, onError, onCompleted) + return Subscription.create(unsub) end) end @@ -1600,21 +1620,32 @@ end -- @returns {Observable} function Observable:takeUntil(other) return Observable.create(function(observer) + local subscription + + local function unsub() + if subscription then + subscription:unsubscribe() + end + end + local function onNext(...) return observer:onNext(...) end local function onError(e) - return observer:onError(e) + observer:onError(e) + unsub() end local function onCompleted() - return observer:onCompleted() + observer:onCompleted() + unsub() end other:subscribe(onCompleted, onCompleted, onCompleted) - return self:subscribe(onNext, onError, onCompleted) + subscription = self:subscribe(onNext, onError, onCompleted) + return Subscription.create(unsub) end) end @@ -1626,6 +1657,23 @@ function Observable:takeWhile(predicate) return Observable.create(function(observer) local taking = true + local subscription + + local function unsub() + if subscription then + subscription:unsubscribe() + end + end + + local function onError(message) + observer:onError(message) + unsub() + end + + local function onCompleted() + observer:onCompleted() + unsub() + end local function onNext(...) if taking then @@ -1636,20 +1684,13 @@ function Observable:takeWhile(predicate) if taking then return observer:onNext(...) else - return observer:onCompleted() + return onCompleted() end end end - local function onError(message) - return observer:onError(message) - end - - local function onCompleted() - return observer:onCompleted() - end - - return self:subscribe(onNext, onError, onCompleted) + subscription = self:subscribe(onNext, onError, onCompleted) + return Subscription.create(unsub) end) end @@ -2310,4 +2351,4 @@ return { AsyncSubject = AsyncSubject, BehaviorSubject = BehaviorSubject, ReplaySubject = ReplaySubject -} \ No newline at end of file +} diff --git a/tests/take.lua b/tests/take.lua index a131cec..6b1bd8a 100644 --- a/tests/take.lua +++ b/tests/take.lua @@ -32,4 +32,23 @@ describe('take', function() expect(#onError).to.equal(0) expect(#onCompleted).to.equal(1) end) + + it('unsubscribes when it completes', function () + local keepGoing = true + local unsub = spy() + local observer + + local source = Rx.Observable.create(function (_observer) + observer = _observer + return Rx.Subscription.create(unsub) + end) + + source + :take(1) + :subscribe(Rx.Observer.create()) + expect(#unsub).to.equal(0) + + observer:onNext() + expect(#unsub).to.equal(1) + end) end) diff --git a/tests/takeLast.lua b/tests/takeLast.lua index 12bbe0b..92fdc1b 100644 --- a/tests/takeLast.lua +++ b/tests/takeLast.lua @@ -25,4 +25,22 @@ describe('takeLast', function() it('produces no elements if the source Observable produces no elements', function() expect(Rx.Observable.empty():takeLast(1)).to.produce.nothing() end) + + it('unsubscribes when it completes', function () + local unsub = spy() + local observer + + local source = Rx.Observable.create(function (_observer) + observer = _observer + return Rx.Subscription.create(unsub) + end) + + source + :takeLast(1) + :subscribe(Rx.Observer.create()) + expect(#unsub).to.equal(0) + + observer:onCompleted() + expect(#unsub).to.equal(1) + end) end) diff --git a/tests/takeUntil.lua b/tests/takeUntil.lua index 623e21d..d002dfa 100644 --- a/tests/takeUntil.lua +++ b/tests/takeUntil.lua @@ -43,4 +43,21 @@ describe('takeUntil', function() subject:onCompleted() expect(onNext).to.equal({}) end) + + it('unsubscribes when it completes', function () + local trigger = Rx.Subject.create() + local unsub = spy() + + local source = Rx.Observable.create(function (observer) + return Rx.Subscription.create(unsub) + end) + + source + :takeUntil(trigger) + :subscribe(Rx.Observer.create()) + expect(#unsub).to.equal(0) + + trigger() + expect(#unsub).to.equal(1) + end) end) diff --git a/tests/takeWhile.lua b/tests/takeWhile.lua index d48a75e..f73af5b 100644 --- a/tests/takeWhile.lua +++ b/tests/takeWhile.lua @@ -25,4 +25,27 @@ describe('takeWhile', function() it('calls onError if the predicate errors', function() expect(Rx.Observable.fromRange(3):takeWhile(error)).to.produce.error() end) + + it('unsubscribes when it completes', function () + local keepGoing = true + local unsub = spy() + local observer + + local source = Rx.Observable.create(function (_observer) + observer = _observer + return Rx.Subscription.create(unsub) + end) + + source + :takeWhile(function () return keepGoing end) + :subscribe(Rx.Observer.create()) + expect(#unsub).to.equal(0) + + observer:onNext() + expect(#unsub).to.equal(0) + + keepGoing = false + observer:onNext() + expect(#unsub).to.equal(1) + end) end)