diff --git a/siwe_auth/backend.py b/siwe_auth/backend.py index 3147f31..1a76968 100644 --- a/siwe_auth/backend.py +++ b/siwe_auth/backend.py @@ -115,7 +115,7 @@ def authenticate(self, request, signature: str = None, siwe_message: SiweMessage logging.info(f"Created group '{custom_group[0]}'.") group_manager: GroupManager = custom_group[1] if group_manager.is_member( - ethereum_address=wallet.ethereum_address, + wallet=wallet, provider=HTTPProvider(settings.PROVIDER), ): logging.info( diff --git a/siwe_auth/custom_groups/erc1155.py b/siwe_auth/custom_groups/erc1155.py index 7388a6e..01945e8 100644 --- a/siwe_auth/custom_groups/erc1155.py +++ b/siwe_auth/custom_groups/erc1155.py @@ -54,20 +54,22 @@ def _is_member( return expression(balance) @abstractmethod - def is_member(self, ethereum_address: str, provider: HTTPProvider) -> bool: + def is_member(self, wallet: object, provider: HTTPProvider) -> bool: pass class ERC1155OwnerManager(ERC1155Manager): - def is_member(self, ethereum_address: str, provider: HTTPProvider) -> bool: + def is_member(self, wallet: object, provider: HTTPProvider) -> bool: + if not self._valid_wallet(wallet=wallet): + return False try: return self._is_member( - ethereum_address=ethereum_address, + ethereum_address=wallet.ethereum_address, provider=provider, expression=lambda x: x > 0, ) except ValueError: logging.error( - f"Unable to verify membership of invalid address: {ethereum_address}" + f"Unable to verify membership of invalid address: {wallet.ethereum_address}" ) return False diff --git a/siwe_auth/custom_groups/erc20.py b/siwe_auth/custom_groups/erc20.py index e00b1ad..a3cb0a1 100644 --- a/siwe_auth/custom_groups/erc20.py +++ b/siwe_auth/custom_groups/erc20.py @@ -3,6 +3,7 @@ from typing import Callable from siwe_auth.custom_groups.group_manager import GroupManager +from siwe_auth.utils.data_classes import EthereumBaseClass from web3 import Web3, HTTPProvider @@ -41,20 +42,22 @@ def _is_member( return expression(balance) @abstractmethod - def is_member(self, ethereum_address: str, provider: HTTPProvider) -> bool: + def is_member(self, wallet: EthereumBaseClass, provider: HTTPProvider) -> bool: pass class ERC20OwnerManager(ERC20Manager): - def is_member(self, ethereum_address: str, provider: HTTPProvider) -> bool: + def is_member(self, wallet: EthereumBaseClass, provider: HTTPProvider) -> bool: + if not self._valid_wallet(wallet=wallet): + return False try: return self._is_member( - ethereum_address=ethereum_address, + ethereum_address=wallet.ethereum_address, provider=provider, expression=lambda x: x > 0, ) except ValueError: logging.error( - f"Unable to verify membership of invalid address: {ethereum_address}" + f"Unable to verify membership of invalid address: {wallet.ethereum_address}" ) return False diff --git a/siwe_auth/custom_groups/erc721.py b/siwe_auth/custom_groups/erc721.py index 4790395..8f1e974 100644 --- a/siwe_auth/custom_groups/erc721.py +++ b/siwe_auth/custom_groups/erc721.py @@ -41,20 +41,22 @@ def _is_member( return expression(balance) @abstractmethod - def is_member(self, ethereum_address: str, provider: HTTPProvider) -> bool: + def is_member(self, wallet: object, provider: HTTPProvider) -> bool: pass class ERC721OwnerManager(ERC721Manager): - def is_member(self, ethereum_address: str, provider: HTTPProvider) -> bool: + def is_member(self, wallet: object, provider: HTTPProvider) -> bool: + if not self._valid_wallet(wallet=wallet): + return False try: return self._is_member( - ethereum_address=ethereum_address, + ethereum_address=wallet.ethereum_address, provider=provider, expression=lambda x: x > 0, ) except ValueError: logging.error( - f"Unable to verify membership of invalid address: {ethereum_address}" + f"Unable to verify membership of invalid address: {wallet.ethereum_address}" ) return False diff --git a/siwe_auth/custom_groups/group_manager.py b/siwe_auth/custom_groups/group_manager.py index 4166529..dc12b00 100644 --- a/siwe_auth/custom_groups/group_manager.py +++ b/siwe_auth/custom_groups/group_manager.py @@ -12,11 +12,14 @@ def __init__(self, config: dict): pass @abstractmethod - def is_member(self, ethereum_address: str, provider: HTTPProvider) -> bool: + def is_member(self, wallet: object, provider: HTTPProvider) -> bool: """ Membership function to identify if a given ethereum address is part of this class' group. :param provider: Web3 provider to use for membership check. - :param ethereum_address: Address to check membership of. + :param wallet: Object with ethereum_address attribute to check membership of. :return: True if address is a member else False """ pass + + def _valid_wallet(self, wallet: object): + return wallet.__getattribute__('ethereum_address') is not None diff --git a/siwe_auth/models.py b/siwe_auth/models.py index b05721f..199b782 100644 --- a/siwe_auth/models.py +++ b/siwe_auth/models.py @@ -1,7 +1,5 @@ from datetime import datetime - from django.db import models - from django.contrib.auth.models import ( BaseUserManager, AbstractBaseUser,