diff --git a/src/core/lockup_linear.cairo b/src/core/lockup_linear.cairo index 83321a2..3944034 100644 --- a/src/core/lockup_linear.cairo +++ b/src/core/lockup_linear.cairo @@ -172,6 +172,25 @@ trait ITokeiLockupLinear { /// Returns the admin address. fn get_admin(self: @TContractState) -> ContractAddress; + /// Returns the streams of the sender. + fn get_streams_by_sender( + self: @TContractState, sender: ContractAddress + ) -> Array; + + /// Returns the streams of the recipient. + fn get_streams_by_recipient( + self: @TContractState, recipient: ContractAddress + ) -> Array; + + /// Returns the streams ids of the sender. + fn get_streams_ids_by_sender(self: @TContractState, sender: ContractAddress) -> Array; + + /// Returns the streams ids of the recipient. + fn get_streams_ids_by_recipient( + self: @TContractState, recipient: ContractAddress + ) -> Array; + + ////////////////////////////////////////////////////////////////////////// //USER-FACING NON-CONSTANT FUNCTIONS ////////////////////////////////////////////////////////////////////////// @@ -638,6 +657,7 @@ mod TokeiLockupLinear { let stream_updated = LockupLinearStream { sender: stream.sender, asset: stream.asset, + recipient: stream.recipient, start_time: stream.start_time, cliff_time: stream.cliff_time, end_time: stream.end_time, @@ -811,6 +831,100 @@ mod TokeiLockupLinear { self.admin.read() } + fn get_streams_by_sender( + self: @ContractState, sender: ContractAddress + ) -> Array { + let max_stream_id = self.next_stream_id.read(); + let mut streams: Array = ArrayTrait::new(); + let mut i = 1; //Since the stream id starts from 1 + loop { + if i >= max_stream_id { + break; + } + let stream = self.streams.read(i); + if stream.sender == sender { + streams.append(stream); + } + i += 1; + }; + streams + } + + fn get_streams_by_recipient( + self: @ContractState, recipient: ContractAddress + ) -> Array { + let max_stream_id = self.next_stream_id.read(); + let mut streams: Array = ArrayTrait::new(); + let mut i = 1; //Since the stream id starts from 1 + loop { + if i >= max_stream_id { + break; + } + let stream = self.streams.read(i); + if stream.recipient == recipient { + streams.append(stream); + } + i += 1; + }; + streams + } + + fn get_streams_ids_by_sender(self: @ContractState, sender: ContractAddress) -> Array { + let max_stream_id = self.next_stream_id.read(); + let mut stream_ids: Array = ArrayTrait::new(); + let mut i = 1; // As the stream id starts from 1 + loop { + if i >= max_stream_id { + break; + } + let stream = self.streams.read(i); + if (stream.sender == sender) { + stream_ids.append(i); + } + + i += 1; + }; + stream_ids + } + + fn get_streams_ids_by_recipient( + self: @ContractState, recipient: ContractAddress + ) -> Array { + let max_stream_id = self.next_stream_id.read(); + let mut stream_ids: Array = ArrayTrait::new(); + let mut i = 1; // As the stream id starts from 1 + loop { + if i >= max_stream_id { + break; + } + let stream = self.streams.read(i); + if (stream.recipient == recipient) { + stream_ids.append(i); + } + + i += 1; + }; + stream_ids + } + + // fn get_streams_ids_by_recipient( + // self: @ContractState, recipient: ContractAddress + // ) -> Array { + // let streams: Array = self.get_streams_by_recipient(recipient); + // let mut stream_ids: Array = ArrayTrait::new(); + // let mut i = 0; + // loop { + // if i >= streams.len() { + // break; + // } + // let stream = *streams.at(i); + // let stream_id = self.stream_id.read(stream); + // stream_ids.append(stream_id); + // i += 1; + // }; + // stream_ids + // } + /// Creates a new stream with a given range. /// # Arguments /// * `sender` - The address streaming the assets, with the ability to cancel the stream. @@ -1316,6 +1430,7 @@ mod TokeiLockupLinear { let stream_updated = LockupLinearStream { sender: stream.sender, asset: stream.asset, + recipient: stream.recipient, start_time: stream.start_time, cliff_time: stream.cliff_time, end_time: stream.end_time, @@ -1338,6 +1453,7 @@ mod TokeiLockupLinear { let _stream_updated = LockupLinearStream { sender: stream.sender, asset: stream.asset, + recipient: stream.recipient, start_time: stream.start_time, cliff_time: stream.cliff_time, end_time: stream.end_time, @@ -1370,6 +1486,7 @@ mod TokeiLockupLinear { let stream_updated = LockupLinearStream { sender: stream.sender, asset: stream.asset, + recipient: stream.recipient, start_time: stream.start_time, cliff_time: stream.cliff_time, end_time: stream.end_time, @@ -1412,6 +1529,7 @@ mod TokeiLockupLinear { let stream_updated = LockupLinearStream { sender: stream.sender, asset: stream.asset, + recipient: stream.recipient, start_time: stream.start_time, cliff_time: stream.cliff_time, end_time: stream.end_time, @@ -1432,6 +1550,7 @@ mod TokeiLockupLinear { let stream_updated = LockupLinearStream { sender: stream.sender, asset: stream.asset, + recipient: stream.recipient, start_time: stream.start_time, cliff_time: stream.cliff_time, end_time: stream.end_time, @@ -1566,6 +1685,7 @@ mod TokeiLockupLinear { let stream = LockupLinearStream { sender, asset, + recipient, start_time: range.start, cliff_time: range.cliff, end_time: range.end, diff --git a/src/tests/test_lockup_linear.cairo b/src/tests/test_lockup_linear.cairo index d71ab27..8b63abb 100644 --- a/src/tests/test_lockup_linear.cairo +++ b/src/tests/test_lockup_linear.cairo @@ -79,6 +79,33 @@ fn create_with_duration() -> (ITokeiLockupLinearDispatcher, ERC20ABIDispatcher, stop_prank(CheatTarget::One(tokei.contract_address)); (tokei, token_dispatcher, stream_id) } + +fn create_with_duration_common( + sender: ContractAddress, + recipient: ContractAddress, + token: ERC20ABIDispatcher, + tokei: ITokeiLockupLinearDispatcher +) -> u64 { + let (alice, _, total_amount, _, cancelable, transferable, range, broker) = + Defaults::create_with_durations(); + start_prank(CheatTarget::One(tokei.contract_address), sender); + start_warp(CheatTarget::One(tokei.contract_address), 1000); + + let stream_id = tokei + .create_with_duration( + sender, + recipient, + total_amount, + token.contract_address, + cancelable, + transferable, + range, + broker, + ); + stop_warp(CheatTarget::One(tokei.contract_address)); + stop_prank(CheatTarget::One(tokei.contract_address)); + stream_id +} #[test] fn test_set_protocol_fee() { let (tokei) = setup(ADMIN()); @@ -292,6 +319,7 @@ fn test_create_with_duration() { let expected_stream = LockupLinearStream { sender: ALICE(), asset: token, + recipient: RECIPIENT(), start_time: 100, cliff_time: 100 + 2500, end_time: 100 + 10000, @@ -550,6 +578,7 @@ fn test_get_stream_when_status_settled() { let expected_stream = LockupLinearStream { sender: ALICE(), asset: token.contract_address, + recipient: RECIPIENT(), start_time: 1000, cliff_time: 1000 + 2500, end_time: 1000 + 4000, @@ -575,6 +604,7 @@ fn test_get_stream_when_not_settled() { let expected_stream = LockupLinearStream { sender: ALICE(), asset: token.contract_address, + recipient: RECIPIENT(), start_time: 1000, cliff_time: 1000 + 2500, end_time: 1000 + 4000, @@ -1081,6 +1111,90 @@ fn test_transfer_admin() { assert(tokei.get_admin() == BOB(), 'Invalid admin'); } +#[test] +fn test_get_streams_by_sender() { + // create_with_duration function passed Alice as sender and recipient as RECIPIENT + let (tokei, token, stream_id_1) = create_with_duration(); + let stream_id_2 = create_with_duration_common(ALICE(), RECIPIENT(), token, tokei); + let stream_id_3 = create_with_duration_common(BOB(), CHARLIE(), token, tokei); + + let streams_alice = tokei.get_streams_by_sender(ALICE()); + assert(streams_alice.len() == 2, 'Invalid stream count'); + + let streams_bob = tokei.get_streams_by_sender(BOB()); + assert(streams_bob.len() == 1, 'Invalid stream count'); + + let stream_alice_1 = tokei.get_stream(stream_id_1); + + let stream_alice_2 = tokei.get_stream(stream_id_2); + let stream_bob_1 = tokei.get_stream(stream_id_3); + + assert(*streams_alice.at(0) == stream_alice_1, 'Invalid stream'); + assert(*streams_alice.at(1) == stream_alice_2, 'Invalid stream'); + assert(*streams_bob.at(0) == stream_bob_1, 'Invalid stream'); +} +#[test] +fn test_get_streams_by_receiver() { + // create_with_duration function passed Alice as sender and recipient as RECIPIENT + let (tokei, token, stream_id_1) = create_with_duration(); + let stream_id_2 = create_with_duration_common(ALICE(), RECIPIENT(), token, tokei); + let stream_id_3 = create_with_duration_common(BOB(), CHARLIE(), token, tokei); + + let streams_recipient = tokei.get_streams_by_recipient(RECIPIENT()); + let streams_charlie = tokei.get_streams_by_recipient(CHARLIE()); + + let stream_1 = tokei.get_stream(stream_id_1); + let stream_2 = tokei.get_stream(stream_id_2); + let stream_3 = tokei.get_stream(stream_id_3); + + assert(streams_recipient.len() == 2, 'Invalid stream count'); + assert(streams_charlie.len() == 1, 'Invalid stream count'); + assert(*streams_recipient.at(0) == stream_1, 'Invalid stream'); + assert(*streams_recipient.at(1) == stream_2, 'Invalid stream'); + assert(*streams_charlie.at(0) == stream_3, 'Invalid stream'); +} + +#[test] +fn test_get_stream_ids_by_receiver() { + // create_with_duration function passed Alice as sender and recipient as RECIPIENT + let (tokei, token, stream_id_1) = create_with_duration(); + let stream_id_2 = create_with_duration_common(ALICE(), RECIPIENT(), token, tokei); + let stream_id_3 = create_with_duration_common(BOB(), CHARLIE(), token, tokei); + + let stream_ids_recipient = tokei.get_streams_ids_by_recipient(RECIPIENT()); + let stream_ids_charlie = tokei.get_streams_ids_by_recipient(CHARLIE()); + + let expected_stream_id_1 = 1_u64; + let expected_stream_id_2 = 2_u64; + let expected_stream_id_3 = 3_u64; + + assert(stream_ids_recipient.len() == 2, 'Invalid stream count'); + assert(stream_ids_charlie.len() == 1, 'Invalid stream count'); + assert(*stream_ids_recipient.at(0) == expected_stream_id_1, 'Invalid stream'); + assert(*stream_ids_recipient.at(1) == expected_stream_id_2, 'Invalid stream'); + assert(*stream_ids_charlie.at(0) == expected_stream_id_3, 'Invalid stream'); +} + +#[test] +fn test_get_stream_ids_by_sender() { + // create_with_duration function passed Alice as sender and recipient as RECIPIENT + let (tokei, token, stream_id_1) = create_with_duration(); + let stream_id_2 = create_with_duration_common(ALICE(), RECIPIENT(), token, tokei); + let stream_id_3 = create_with_duration_common(BOB(), CHARLIE(), token, tokei); + + let stream_ids_alice = tokei.get_streams_ids_by_sender(ALICE()); + let stream_ids_bob = tokei.get_streams_ids_by_sender(BOB()); + + let expected_stream_id_1 = 1_u64; + let expected_stream_id_2 = 2_u64; + let expected_stream_id_3 = 3_u64; + + assert(stream_ids_alice.len() == 2, 'Invalid stream count'); + assert(stream_ids_bob.len() == 1, 'Invalid stream count'); + assert(*stream_ids_alice.at(0) == expected_stream_id_1, 'Invalid stream'); + assert(*stream_ids_alice.at(1) == expected_stream_id_2, 'Invalid stream'); + assert(*stream_ids_bob.at(0) == expected_stream_id_3, 'Invalid stream'); +} #[test] #[should_panic(expected: ('lockup_unauthorized',))] diff --git a/src/tests/utils/defaults.cairo b/src/tests/utils/defaults.cairo index 65f40e8..1e21318 100644 --- a/src/tests/utils/defaults.cairo +++ b/src/tests/utils/defaults.cairo @@ -63,6 +63,7 @@ mod Defaults { LockupLinearStream { sender: contract_address_const::<'sender'>(), asset: contract_address_const::<'asset'>(), + recipient: contract_address_const::<'recipient'>(), start_time: START_TIME, cliff_time: CLIFF_TIME, end_time: END_TIME, diff --git a/src/types/lockup_linear.cairo b/src/types/lockup_linear.cairo index 4d5682b..2e8a2c9 100644 --- a/src/types/lockup_linear.cairo +++ b/src/types/lockup_linear.cairo @@ -19,6 +19,8 @@ struct LockupLinearStream { sender: ContractAddress, /// The contract address of the ERC-20 asset used for streaming. asset: ContractAddress, + /// The address receiving the streamed assets. + recipient: ContractAddress, /// The Unix timestamp indicating the stream's start. start_time: u64, /// The Unix timestamp indicating the stream's cliff period's end.