Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweaks to upgrade #50

Merged
merged 10 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 62 additions & 53 deletions src/contracts/timelock_upgrade.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
use starknet::ClassHash;
use core::num::traits::Zero;
use starknet::{ClassHash};

#[derive(Serde, Drop, Copy, starknet::Store)]
struct PendingUpgrade {
// Gets the classhash after
implementation: ClassHash,
// Gets the timestamp when the upgrade is ready to be performed, 0 if no upgrade ongoing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be proposed implementation ?
also dont get the comment of 'gets classhash after?'
nit: 0 if no upgrade ongoing ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated, i drop the proposed_ since it's already in a struct called PendingUpgrade, calldata would also be proposed_

ready_at: u64,
// Gets the hash of the calldata used for the upgrade, 0 if no upgrade ongoing
calldata_hash: felt252,
}

#[starknet::interface]
pub trait ITimelockUpgrade<TContractState> {
Expand All @@ -19,14 +30,8 @@ pub trait ITimelockUpgrade<TContractState> {
/// @param calldata The calldata to be used for the upgrade
fn upgrade(ref self: TContractState, calldata: Array<felt252>);

/// @notice Gets the proposed implementation
fn get_proposed_implementation(self: @TContractState) -> ClassHash;

/// @notice Gets the timestamp when the upgrade is ready to be performed, 0 if no upgrade ongoing
fn get_upgrade_ready_at(self: @TContractState) -> u64;

/// @notice Gets the hash of the calldata used for the upgrade, 0 if no upgrade ongoing
fn get_calldata_hash(self: @TContractState) -> felt252;
/// @notice Gets the proposed upgrade
fn get_pending_upgrade(self: @TContractState) -> PendingUpgrade;
}

#[starknet::interface]
Expand All @@ -46,19 +51,18 @@ pub mod TimelockUpgradeComponent {
use starknet::{get_block_timestamp, ClassHash};
use super::{
ITimelockUpgrade, ITimelockUpgradeCallback, ITimelockUpgradeCallbackLibraryDispatcher,
ITimelockUpgradeCallbackDispatcherTrait
ITimelockUpgradeCallbackDispatcherTrait, PendingUpgrade, PendingUpgradeZero
};

/// Time before the upgrade can be performed
const MIN_SECURITY_PERIOD: u64 = consteval_int!(7 * 24 * 60 * 60); // 7 days
/// Time window during which the upgrade can be performed
const VALID_WINDOW_PERIOD: u64 = consteval_int!(7 * 24 * 60 * 60); // 7 days


#[storage]
pub struct Storage {
pending_implementation: ClassHash,
ready_at: u64,
calldata_hash: felt252,
pending_upgrade: PendingUpgrade
}

#[event]
Expand All @@ -78,12 +82,12 @@ pub mod TimelockUpgradeComponent {

#[derive(Drop, starknet::Event)]
struct UpgradeCancelled {
cancelled_implementation: ClassHash
cancelled_upgrade: PendingUpgrade
}

#[derive(Drop, starknet::Event)]
struct Upgraded {
new_implementation: ClassHash
executed_upgrade: PendingUpgrade
}

#[embeddable_as(TimelockUpgradeImpl)]
Expand All @@ -99,53 +103,51 @@ pub mod TimelockUpgradeComponent {
self.assert_only_owner();
assert(new_implementation.is_non_zero(), 'upgrade/new-implementation-null');

let pending_implementation = self.pending_implementation.read();
if pending_implementation.is_non_zero() {
self.emit(UpgradeCancelled { cancelled_implementation: pending_implementation })
let pending_upgrade = self.pending_upgrade.read();
if pending_upgrade.is_non_zero() {
self.emit(UpgradeCancelled { cancelled_upgrade: pending_upgrade })
}

self.pending_implementation.write(new_implementation);
let ready_at = get_block_timestamp() + MIN_SECURITY_PERIOD;
self.ready_at.write(ready_at);
let calldata_hash = poseidon_hash_span(calldata.span());
self.calldata_hash.write(calldata_hash);
self
.pending_upgrade
.write(
PendingUpgrade {
implementation: new_implementation, ready_at, calldata_hash: poseidon_hash_span(calldata.span())
}
);
self.emit(UpgradeProposed { new_implementation, ready_at, calldata });
}

fn cancel_upgrade(ref self: ComponentState<TContractState>) {
self.assert_only_owner();
let proposed_implementation = self.pending_implementation.read();
assert(proposed_implementation.is_non_zero(), 'upgrade/no-new-implementation');
assert(self.ready_at.read() != 0, 'upgrade/not-ready');
self.emit(UpgradeCancelled { cancelled_implementation: proposed_implementation });
self.reset_storage();
let proposed_implementation = self.pending_upgrade.read();
assert(proposed_implementation.is_non_zero(), 'upgrade/no-pending-upgrade');
self.pending_upgrade.write(Zero::zero());
self.emit(UpgradeCancelled { cancelled_upgrade: proposed_implementation });
}

fn upgrade(ref self: ComponentState<TContractState>, calldata: Array<felt252>) {
self.assert_only_owner();
let new_implementation = self.pending_implementation.read();
let ready_at = self.ready_at.read();
let block_timestamp = get_block_timestamp();
let calldata_hash = poseidon_hash_span(calldata.span());
assert(calldata_hash == self.calldata_hash.read(), 'upgrade/invalid-calldata');
assert(new_implementation.is_non_zero(), 'upgrade/no-pending-upgrade');
assert(block_timestamp >= ready_at, 'upgrade/too-early');
assert(block_timestamp < ready_at + VALID_WINDOW_PERIOD, 'upgrade/upgrade-too-late');
self.reset_storage();
ITimelockUpgradeCallbackLibraryDispatcher { class_hash: new_implementation }
.perform_upgrade(new_implementation, calldata.span());
}

fn get_proposed_implementation(self: @ComponentState<TContractState>) -> ClassHash {
self.pending_implementation.read()
let proposed_implementation = self.pending_upgrade.read();
assert(proposed_implementation.is_non_zero(), 'upgrade/no-pending-upgrade');

let current_timestamp = get_block_timestamp();
assert(
proposed_implementation.calldata_hash == poseidon_hash_span(calldata.span()), 'upgrade/invalid-calldata'
);

assert(current_timestamp >= proposed_implementation.ready_at, 'upgrade/too-early');
assert(
current_timestamp < proposed_implementation.ready_at + VALID_WINDOW_PERIOD, 'upgrade/upgrade-too-late'
);
self.pending_upgrade.write(Zero::zero());
ITimelockUpgradeCallbackLibraryDispatcher { class_hash: proposed_implementation.implementation }
.perform_upgrade(proposed_implementation.implementation, calldata.span());
}
sgc-code marked this conversation as resolved.
Show resolved Hide resolved

fn get_upgrade_ready_at(self: @ComponentState<TContractState>) -> u64 {
self.ready_at.read()
}

fn get_calldata_hash(self: @ComponentState<TContractState>) -> felt252 {
self.calldata_hash.read()
fn get_pending_upgrade(self: @ComponentState<TContractState>) -> PendingUpgrade {
self.pending_upgrade.read()
}
}
#[generate_trait]
Expand All @@ -155,11 +157,18 @@ pub mod TimelockUpgradeComponent {
fn assert_only_owner(self: @ComponentState<TContractState>) {
get_dep_component!(self, Ownable).assert_only_owner();
}
}
}

fn reset_storage(ref self: ComponentState<TContractState>) {
self.pending_implementation.write(Zero::zero());
self.ready_at.write(0);
self.calldata_hash.write(0);
}

impl PendingUpgradeZero of core::num::traits::Zero<PendingUpgrade> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fn reset_storage(ref self: ComponentState<TContractState>) {
self.pending_implementation.write(Zero::zero());
self.ready_at.write(0);
self.calldata_hash.write(0);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is just an extra line return 🤣

fn zero() -> PendingUpgrade {
PendingUpgrade { implementation: Zero::zero(), ready_at: 0, calldata_hash: 0 }
}
fn is_zero(self: @PendingUpgrade) -> bool {
*self.calldata_hash == 0
}
fn is_non_zero(self: @PendingUpgrade) -> bool {
sgc-code marked this conversation as resolved.
Show resolved Hide resolved
!self.is_zero()
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually since we never use is_zero() I would optimize is_non_zero() instead.
Although now that I think about it, maybe it makes more sense to derive Default and just compare against Default::default()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i like it, Default seems more appropriate 👍

1 change: 1 addition & 0 deletions src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ pub mod contracts {
mod mocks {
mod broken_erc20;
mod erc20;
mod future_factory;
mod reentrant_erc20;
}
55 changes: 55 additions & 0 deletions src/mocks/future_factory.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use starknet::{ContractAddress, ClassHash};

#[starknet::contract]
mod FutureFactory {
use argent_gifting::contracts::timelock_upgrade::{ITimelockUpgradeCallback, TimelockUpgradeComponent};
use core::panic_with_felt252;
use openzeppelin::access::ownable::OwnableComponent;
use starknet::{
ClassHash, ContractAddress, syscalls::deploy_syscall, get_caller_address, get_contract_address, account::Call,
get_block_timestamp
};

// Ownable
component!(path: OwnableComponent, storage: ownable, event: OwnableEvent);
#[abi(embed_v0)]
impl OwnableImpl = OwnableComponent::OwnableImpl<ContractState>;

// TimelockUpgradeable
component!(path: TimelockUpgradeComponent, storage: timelock_upgrade, event: TimelockUpgradeEvent);
#[abi(embed_v0)]
impl TimelockUpgradeImpl = TimelockUpgradeComponent::TimelockUpgradeImpl<ContractState>;

#[storage]
struct Storage {
#[substorage(v0)]
ownable: OwnableComponent::Storage,
#[substorage(v0)]
timelock_upgrade: TimelockUpgradeComponent::Storage,
}

#[event]
#[derive(Drop, starknet::Event)]
enum Event {
#[flat]
OwnableEvent: OwnableComponent::Event,
#[flat]
TimelockUpgradeEvent: TimelockUpgradeComponent::Event,
}

#[constructor]
fn constructor(ref self: ContractState) {}


#[external(v0)]
fn get_num(self: @ContractState) -> u128 {
1
}

#[abi(embed_v0)]
impl TimelockUpgradeCallbackImpl of ITimelockUpgradeCallback<ContractState> {
fn perform_upgrade(ref self: ContractState, new_implementation: ClassHash, data: Span<felt252>) {
starknet::syscalls::replace_class_syscall(new_implementation).unwrap();
}
}
}

This file was deleted.

This file was deleted.

58 changes: 36 additions & 22 deletions tests-integration/upgrade.test.ts
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did the asserts change to == and not should.equal..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm not a big fan of chai, feels like reinventing the wheel trying to replace all the stuff we use with functions.
> -> .isBelow
=== -> equals but there is also deep.equals , eq and eql
await -> eventually

i'm not really sure having to introducing all this stuff is worth it to get slightly better errors
i leave my rant here but i changed it to align with the rest of the project and get this merged

Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { CallData, hash } from "starknet";
import { assert } from "chai";
import { CallData, hash, num } from "starknet";
import {
deployer,
devnetAccount,
Expand All @@ -8,7 +9,6 @@ import {
protocolCache,
setupGiftProtocol,
} from "../lib";

// Time window which must pass before the upgrade can be performed
const MIN_SECURITY_PERIOD = 7n * 24n * 60n * 60n; // 7 day

Expand All @@ -20,24 +20,26 @@ const CURRENT_TIME = 1718898082n;
describe("Test Factory Upgrade", function () {
it("Upgrade", async function () {
const { factory } = await setupGiftProtocol();
const newFactoryClassHash = await manager.declareFixtureContract("GiftFactoryUpgrade");
const newFactoryClassHash = await manager.declareLocalContract("FutureFactory");
const calldata: any[] = [];

await manager.setTime(CURRENT_TIME);
factory.connect(deployer);
await factory.propose_upgrade(newFactoryClassHash, calldata);

await factory.get_upgrade_ready_at().should.eventually.equal(CURRENT_TIME + MIN_SECURITY_PERIOD);
await factory.get_proposed_implementation().should.eventually.equal(BigInt(newFactoryClassHash));
await factory.get_calldata_hash().should.eventually.equal(BigInt(hash.computePoseidonHashOnElements(calldata)));
let pendingUpgrade = await factory.get_pending_upgrade();
assert(pendingUpgrade.ready_at === CURRENT_TIME + MIN_SECURITY_PERIOD);
assert(pendingUpgrade.implementation === num.toBigInt(newFactoryClassHash));
assert(pendingUpgrade.calldata_hash === BigInt(hash.computePoseidonHashOnElements(calldata)));

await manager.setTime(CURRENT_TIME + MIN_SECURITY_PERIOD + 1n);
await factory.upgrade(calldata);

// reset storage
await factory.get_proposed_implementation().should.eventually.equal(0n);
await factory.get_upgrade_ready_at().should.eventually.equal(0n);
await factory.get_calldata_hash().should.eventually.equal(0n);
// check storage was reset
pendingUpgrade = await factory.get_pending_upgrade();
assert(pendingUpgrade.ready_at === 0n);
assert(pendingUpgrade.implementation === 0n);
assert(pendingUpgrade.calldata_hash === 0n);

await manager.getClassHashAt(factory.address).should.eventually.equal(newFactoryClassHash);

Expand Down Expand Up @@ -109,7 +111,8 @@ describe("Test Factory Upgrade", function () {
factory.connect(deployer);
await factory.propose_upgrade(newFactoryClassHash, []);

const readyAt = await factory.get_upgrade_ready_at();
const pendingUpgrade = await factory.get_pending_upgrade();
const readyAt = pendingUpgrade.ready_at;
await manager.setTime(CURRENT_TIME + readyAt + VALID_WINDOW_PERIOD);
await expectRevertWithErrorMessage("upgrade/upgrade-too-late", () => factory.upgrade([]));
});
Expand Down Expand Up @@ -144,23 +147,29 @@ describe("Test Factory Upgrade", function () {
await manager.setTime(CURRENT_TIME);
factory.connect(deployer);
const { transaction_hash: tx1 } = await factory.propose_upgrade(newClassHash, calldata);
await factory.get_proposed_implementation().should.eventually.equal(newClassHash);

const readyAt = await factory.get_upgrade_ready_at();
let pendingUpgrade = await factory.get_pending_upgrade();
assert(pendingUpgrade.implementation === newClassHash);

await expectEvent(tx1, {
from_address: factory.address,
eventName: "UpgradeProposed",
data: CallData.compile([newClassHash.toString(), readyAt.toString(), calldata]),
data: CallData.compile([newClassHash.toString(), pendingUpgrade.ready_at.toString(), calldata]),
});

const { transaction_hash: tx2 } = await factory.propose_upgrade(replacementClassHash, calldata);
await factory.get_proposed_implementation().should.eventually.equal(replacementClassHash);

pendingUpgrade = await factory.get_pending_upgrade();
assert(pendingUpgrade.implementation === replacementClassHash);

await expectEvent(tx2, {
from_address: factory.address,
eventName: "UpgradeCancelled",
data: [newClassHash.toString()],
data: [
newClassHash.toString(),
pendingUpgrade.ready_at.toString(),
BigInt(hash.computePoseidonHashOnElements(calldata)),
],
});
});
});
Expand All @@ -177,22 +186,27 @@ describe("Test Factory Upgrade", function () {

const { transaction_hash } = await factory.cancel_upgrade();

await factory.get_proposed_implementation().should.eventually.equal(0n);
await factory.get_upgrade_ready_at().should.eventually.equal(0n);
await factory.get_calldata_hash().should.eventually.equal(0n);

await expectEvent(transaction_hash, {
from_address: factory.address,
eventName: "UpgradeCancelled",
data: [newClassHash.toString()],
data: [
newClassHash.toString(),
(CURRENT_TIME + MIN_SECURITY_PERIOD).toString(),
BigInt(hash.computePoseidonHashOnElements(calldata)).toString(),
],
});

const pendingUpgrade = await factory.get_pending_upgrade();
assert(pendingUpgrade.ready_at === 0n);
assert(pendingUpgrade.implementation === 0n);
assert(pendingUpgrade.calldata_hash === 0n);
});

it("No new implementation", async function () {
const { factory } = await setupGiftProtocol();

factory.connect(deployer);
await expectRevertWithErrorMessage("upgrade/no-new-implementation", () => factory.cancel_upgrade());
await expectRevertWithErrorMessage("upgrade/no-pending-upgrade", () => factory.cancel_upgrade());
});

it("Only Owner", async function () {
Expand Down
Loading