Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweaks to upgrade #50

Merged
merged 10 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 55 additions & 53 deletions src/contracts/timelock_upgrade.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
use starknet::ClassHash;
use core::num::traits::Zero;
use starknet::{ClassHash};

#[derive(Serde, Drop, Copy, Default, PartialEq, starknet::Store)]
struct PendingUpgrade {
// Gets the classhash after
implementation: ClassHash,
// Gets the timestamp when the upgrade is ready to be performed, 0 if no upgrade ongoing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be proposed implementation ?
also dont get the comment of 'gets classhash after?'
nit: 0 if no upgrade ongoing ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated, i drop the proposed_ since it's already in a struct called PendingUpgrade, calldata would also be proposed_

ready_at: u64,
// Gets the hash of the calldata used for the upgrade, 0 if no upgrade ongoing
calldata_hash: felt252,
}

#[starknet::interface]
pub trait ITimelockUpgrade<TContractState> {
Expand All @@ -19,14 +30,8 @@ pub trait ITimelockUpgrade<TContractState> {
/// @param calldata The calldata to be used for the upgrade
fn upgrade(ref self: TContractState, calldata: Array<felt252>);

/// @notice Gets the proposed implementation
fn get_proposed_implementation(self: @TContractState) -> ClassHash;

/// @notice Gets the timestamp when the upgrade is ready to be performed, 0 if no upgrade ongoing
fn get_upgrade_ready_at(self: @TContractState) -> u64;

/// @notice Gets the hash of the calldata used for the upgrade, 0 if no upgrade ongoing
fn get_calldata_hash(self: @TContractState) -> felt252;
/// @notice Gets the proposed upgrade
fn get_pending_upgrade(self: @TContractState) -> PendingUpgrade;
}

#[starknet::interface]
Expand All @@ -46,27 +51,26 @@ pub mod TimelockUpgradeComponent {
use starknet::{get_block_timestamp, ClassHash};
use super::{
ITimelockUpgrade, ITimelockUpgradeCallback, ITimelockUpgradeCallbackLibraryDispatcher,
ITimelockUpgradeCallbackDispatcherTrait
ITimelockUpgradeCallbackDispatcherTrait, PendingUpgrade
};

/// Time before the upgrade can be performed
const MIN_SECURITY_PERIOD: u64 = consteval_int!(7 * 24 * 60 * 60); // 7 days
/// Time window during which the upgrade can be performed
const VALID_WINDOW_PERIOD: u64 = consteval_int!(7 * 24 * 60 * 60); // 7 days


#[storage]
pub struct Storage {
pending_implementation: ClassHash,
ready_at: u64,
calldata_hash: felt252,
pending_upgrade: PendingUpgrade
}

#[event]
#[derive(Drop, starknet::Event)]
pub enum Event {
UpgradeProposed: UpgradeProposed,
UpgradeCancelled: UpgradeCancelled,
Upgraded: Upgraded,
UpgradedExecuted: UpgradedExecuted,
}

#[derive(Drop, starknet::Event)]
Expand All @@ -78,12 +82,13 @@ pub mod TimelockUpgradeComponent {

#[derive(Drop, starknet::Event)]
struct UpgradeCancelled {
cancelled_implementation: ClassHash
cancelled_upgrade: PendingUpgrade
}

#[derive(Drop, starknet::Event)]
struct Upgraded {
new_implementation: ClassHash
struct UpgradedExecuted {
new_implementation: ClassHash,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UpgradeExecuted*

Copy link
Contributor

@gaetbout gaetbout Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This event is actually never emitted

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

guessing it should be emitted by the callback impl?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's supposed to be emitted by the next implementatio

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i just meant the spelling is wrong is should be UpgradeExecuted

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it

calldata: Array<felt252>
}

#[embeddable_as(TimelockUpgradeImpl)]
Expand All @@ -99,53 +104,49 @@ pub mod TimelockUpgradeComponent {
self.assert_only_owner();
assert(new_implementation.is_non_zero(), 'upgrade/new-implementation-null');

let pending_implementation = self.pending_implementation.read();
if pending_implementation.is_non_zero() {
self.emit(UpgradeCancelled { cancelled_implementation: pending_implementation })
let pending_upgrade = self.pending_upgrade.read();
if pending_upgrade != Default::default() {
self.emit(UpgradeCancelled { cancelled_upgrade: pending_upgrade })
}

self.pending_implementation.write(new_implementation);
let ready_at = get_block_timestamp() + MIN_SECURITY_PERIOD;
self.ready_at.write(ready_at);
let calldata_hash = poseidon_hash_span(calldata.span());
self.calldata_hash.write(calldata_hash);
self
.pending_upgrade
.write(
PendingUpgrade {
implementation: new_implementation, ready_at, calldata_hash: poseidon_hash_span(calldata.span())
}
);
self.emit(UpgradeProposed { new_implementation, ready_at, calldata });
}

fn cancel_upgrade(ref self: ComponentState<TContractState>) {
self.assert_only_owner();
let proposed_implementation = self.pending_implementation.read();
assert(proposed_implementation.is_non_zero(), 'upgrade/no-new-implementation');
assert(self.ready_at.read() != 0, 'upgrade/not-ready');
self.emit(UpgradeCancelled { cancelled_implementation: proposed_implementation });
self.reset_storage();
let pending_upgrade = self.pending_upgrade.read();
assert(pending_upgrade != Default::default(), 'upgrade/no-pending-upgrade');
self.pending_upgrade.write(Default::default());
self.emit(UpgradeCancelled { cancelled_upgrade: pending_upgrade });
}

fn upgrade(ref self: ComponentState<TContractState>, calldata: Array<felt252>) {
self.assert_only_owner();
let new_implementation = self.pending_implementation.read();
let ready_at = self.ready_at.read();
let block_timestamp = get_block_timestamp();
let calldata_hash = poseidon_hash_span(calldata.span());
assert(calldata_hash == self.calldata_hash.read(), 'upgrade/invalid-calldata');
assert(new_implementation.is_non_zero(), 'upgrade/no-pending-upgrade');
assert(block_timestamp >= ready_at, 'upgrade/too-early');
assert(block_timestamp < ready_at + VALID_WINDOW_PERIOD, 'upgrade/upgrade-too-late');
self.reset_storage();
ITimelockUpgradeCallbackLibraryDispatcher { class_hash: new_implementation }
.perform_upgrade(new_implementation, calldata.span());
}
let pending_upgrade: PendingUpgrade = self.pending_upgrade.read();
assert(pending_upgrade != Default::default(), 'upgrade/no-pending-upgrade');
let PendingUpgrade { implementation, ready_at, calldata_hash } = pending_upgrade;

fn get_proposed_implementation(self: @ComponentState<TContractState>) -> ClassHash {
self.pending_implementation.read()
}
assert(calldata_hash == poseidon_hash_span(calldata.span()), 'upgrade/invalid-calldata');

fn get_upgrade_ready_at(self: @ComponentState<TContractState>) -> u64 {
self.ready_at.read()
let current_timestamp = get_block_timestamp();
assert(current_timestamp >= ready_at, 'upgrade/too-early');
assert(current_timestamp < ready_at + VALID_WINDOW_PERIOD, 'upgrade/upgrade-too-late');

self.pending_upgrade.write(Default::default());
ITimelockUpgradeCallbackLibraryDispatcher { class_hash: implementation }
.perform_upgrade(implementation, calldata.span());
}

fn get_calldata_hash(self: @ComponentState<TContractState>) -> felt252 {
self.calldata_hash.read()
fn get_pending_upgrade(self: @ComponentState<TContractState>) -> PendingUpgrade {
self.pending_upgrade.read()
}
}
#[generate_trait]
Expand All @@ -155,11 +156,12 @@ pub mod TimelockUpgradeComponent {
fn assert_only_owner(self: @ComponentState<TContractState>) {
get_dep_component!(self, Ownable).assert_only_owner();
}
}
}

fn reset_storage(ref self: ComponentState<TContractState>) {
self.pending_implementation.write(Zero::zero());
self.ready_at.write(0);
self.calldata_hash.write(0);
}

impl DefaultClassHash of Default<ClassHash> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fn reset_storage(ref self: ComponentState<TContractState>) {
self.pending_implementation.write(Zero::zero());
self.ready_at.write(0);
self.calldata_hash.write(0);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is just an extra line return 🤣

fn default() -> ClassHash {
Zero::zero()
}
}
1 change: 1 addition & 0 deletions src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ pub mod contracts {
mod mocks {
mod broken_erc20;
mod erc20;
mod future_factory;
mod reentrant_erc20;
}
55 changes: 55 additions & 0 deletions src/mocks/future_factory.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use starknet::{ContractAddress, ClassHash};

#[starknet::contract]
mod FutureFactory {
use argent_gifting::contracts::timelock_upgrade::{ITimelockUpgradeCallback, TimelockUpgradeComponent};
use core::panic_with_felt252;
use openzeppelin::access::ownable::OwnableComponent;
use starknet::{
ClassHash, ContractAddress, syscalls::deploy_syscall, get_caller_address, get_contract_address, account::Call,
get_block_timestamp
};

// Ownable
component!(path: OwnableComponent, storage: ownable, event: OwnableEvent);
#[abi(embed_v0)]
impl OwnableImpl = OwnableComponent::OwnableImpl<ContractState>;

// TimelockUpgradeable
component!(path: TimelockUpgradeComponent, storage: timelock_upgrade, event: TimelockUpgradeEvent);
#[abi(embed_v0)]
impl TimelockUpgradeImpl = TimelockUpgradeComponent::TimelockUpgradeImpl<ContractState>;

#[storage]
struct Storage {
#[substorage(v0)]
ownable: OwnableComponent::Storage,
#[substorage(v0)]
timelock_upgrade: TimelockUpgradeComponent::Storage,
}

#[event]
#[derive(Drop, starknet::Event)]
enum Event {
#[flat]
OwnableEvent: OwnableComponent::Event,
#[flat]
TimelockUpgradeEvent: TimelockUpgradeComponent::Event,
}

#[constructor]
fn constructor(ref self: ContractState) {}


#[external(v0)]
fn get_num(self: @ContractState) -> u128 {
1
}

#[abi(embed_v0)]
impl TimelockUpgradeCallbackImpl of ITimelockUpgradeCallback<ContractState> {
fn perform_upgrade(ref self: ContractState, new_implementation: ClassHash, data: Span<felt252>) {
starknet::syscalls::replace_class_syscall(new_implementation).unwrap();
}
}
}

This file was deleted.

This file was deleted.

59 changes: 37 additions & 22 deletions tests-integration/upgrade.test.ts
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did the asserts change to == and not should.equal..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm not a big fan of chai, feels like reinventing the wheel trying to replace all the stuff we use with functions.
> -> .isBelow
=== -> equals but there is also deep.equals , eq and eql
await -> eventually

i'm not really sure having to introducing all this stuff is worth it to get slightly better errors
i leave my rant here but i changed it to align with the rest of the project and get this merged

Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { CallData, hash } from "starknet";
import { assert } from "chai";
import { CallData, hash, num } from "starknet";
import {
deployer,
devnetAccount,
Expand All @@ -8,7 +9,6 @@ import {
protocolCache,
setupGiftProtocol,
} from "../lib";

// Time window which must pass before the upgrade can be performed
const MIN_SECURITY_PERIOD = 7n * 24n * 60n * 60n; // 7 day

Expand All @@ -20,27 +20,30 @@ const CURRENT_TIME = 1718898082n;
describe("Test Factory Upgrade", function () {
it("Upgrade", async function () {
const { factory } = await setupGiftProtocol();
const newFactoryClassHash = await manager.declareFixtureContract("GiftFactoryUpgrade");
const newFactoryClassHash = await manager.declareLocalContract("FutureFactory");
const calldata: any[] = [];

await manager.setTime(CURRENT_TIME);
factory.connect(deployer);
await factory.propose_upgrade(newFactoryClassHash, calldata);

await factory.get_upgrade_ready_at().should.eventually.equal(CURRENT_TIME + MIN_SECURITY_PERIOD);
await factory.get_proposed_implementation().should.eventually.equal(BigInt(newFactoryClassHash));
await factory.get_calldata_hash().should.eventually.equal(BigInt(hash.computePoseidonHashOnElements(calldata)));
let pendingUpgrade = await factory.get_pending_upgrade();
assert(pendingUpgrade.ready_at === CURRENT_TIME + MIN_SECURITY_PERIOD);
assert(pendingUpgrade.implementation === num.toBigInt(newFactoryClassHash));
assert(pendingUpgrade.calldata_hash === BigInt(hash.computePoseidonHashOnElements(calldata)));

await manager.setTime(CURRENT_TIME + MIN_SECURITY_PERIOD + 1n);
await factory.upgrade(calldata);

// reset storage
await factory.get_proposed_implementation().should.eventually.equal(0n);
await factory.get_upgrade_ready_at().should.eventually.equal(0n);
await factory.get_calldata_hash().should.eventually.equal(0n);
// check storage was reset
pendingUpgrade = await factory.get_pending_upgrade();
assert(pendingUpgrade.ready_at === 0n);
assert(pendingUpgrade.implementation === 0n);
assert(pendingUpgrade.calldata_hash === 0n);

await manager.getClassHashAt(factory.address).should.eventually.equal(newFactoryClassHash);

// test new factory has new method
const newFactory = await manager.loadContract(factory.address, newFactoryClassHash);
newFactory.connect(deployer);
await newFactory.get_num().should.eventually.equal(1n);
Expand Down Expand Up @@ -109,7 +112,8 @@ describe("Test Factory Upgrade", function () {
factory.connect(deployer);
await factory.propose_upgrade(newFactoryClassHash, []);

const readyAt = await factory.get_upgrade_ready_at();
const pendingUpgrade = await factory.get_pending_upgrade();
const readyAt = pendingUpgrade.ready_at;
await manager.setTime(CURRENT_TIME + readyAt + VALID_WINDOW_PERIOD);
await expectRevertWithErrorMessage("upgrade/upgrade-too-late", () => factory.upgrade([]));
});
Expand Down Expand Up @@ -144,23 +148,29 @@ describe("Test Factory Upgrade", function () {
await manager.setTime(CURRENT_TIME);
factory.connect(deployer);
const { transaction_hash: tx1 } = await factory.propose_upgrade(newClassHash, calldata);
await factory.get_proposed_implementation().should.eventually.equal(newClassHash);

const readyAt = await factory.get_upgrade_ready_at();
let pendingUpgrade = await factory.get_pending_upgrade();
assert(pendingUpgrade.implementation === newClassHash);

await expectEvent(tx1, {
from_address: factory.address,
eventName: "UpgradeProposed",
data: CallData.compile([newClassHash.toString(), readyAt.toString(), calldata]),
data: CallData.compile([newClassHash.toString(), pendingUpgrade.ready_at.toString(), calldata]),
});

const { transaction_hash: tx2 } = await factory.propose_upgrade(replacementClassHash, calldata);
await factory.get_proposed_implementation().should.eventually.equal(replacementClassHash);

pendingUpgrade = await factory.get_pending_upgrade();
assert(pendingUpgrade.implementation === replacementClassHash);

await expectEvent(tx2, {
from_address: factory.address,
eventName: "UpgradeCancelled",
data: [newClassHash.toString()],
data: [
newClassHash.toString(),
pendingUpgrade.ready_at.toString(),
BigInt(hash.computePoseidonHashOnElements(calldata)),
],
});
});
});
Expand All @@ -177,22 +187,27 @@ describe("Test Factory Upgrade", function () {

const { transaction_hash } = await factory.cancel_upgrade();

await factory.get_proposed_implementation().should.eventually.equal(0n);
await factory.get_upgrade_ready_at().should.eventually.equal(0n);
await factory.get_calldata_hash().should.eventually.equal(0n);

await expectEvent(transaction_hash, {
from_address: factory.address,
eventName: "UpgradeCancelled",
data: [newClassHash.toString()],
data: [
newClassHash.toString(),
(CURRENT_TIME + MIN_SECURITY_PERIOD).toString(),
BigInt(hash.computePoseidonHashOnElements(calldata)).toString(),
],
});

const pendingUpgrade = await factory.get_pending_upgrade();
assert(pendingUpgrade.ready_at === 0n);
assert(pendingUpgrade.implementation === 0n);
assert(pendingUpgrade.calldata_hash === 0n);
});

it("No new implementation", async function () {
const { factory } = await setupGiftProtocol();

factory.connect(deployer);
await expectRevertWithErrorMessage("upgrade/no-new-implementation", () => factory.cancel_upgrade());
await expectRevertWithErrorMessage("upgrade/no-pending-upgrade", () => factory.cancel_upgrade());
});

it("Only Owner", async function () {
Expand Down
Loading