implement address reuse check and related refactoring
Jwyman328 committed Nov 5, 2024
1 parent 4ec2da7 commit 5bb23d7
Showing 7 changed files with 368 additions and 194 deletions.
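
At a glance, a hedged sketch of the flow this commit introduces (class and method names are taken from the diff below; the txid value is hypothetical, and a Flask app context plus a connected wallet are assumed):

from src.services import PrivacyMetricsService

# Refetches outputs if the last fetch is older than five minutes, then fails
# the metric when any output address for this transaction appears on more
# than one output row in the local database.
passes = PrivacyMetricsService.analyze_no_address_reuse(txid="replace-with-a-real-txid")
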
17 changes: 17 additions & 0 deletions backend/src/models/last_fetched.py
@@ -0,0 +1,17 @@
from sqlalchemy import Enum, Integer
from enum import Enum as PyEnum


from src.database import DB


class LastFetchedType(PyEnum):
OUTPUTS = "outputs"


class LastFetched(DB.Model):
__tablename__ = "last_fetched"

id = DB.Column(Integer, primary_key=True, autoincrement=True)
type = DB.Column(Enum(LastFetchedType), unique=True, nullable=False)
timestamp = DB.Column(DB.DateTime, nullable=False, default=None)
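
Since type is unique, the table holds at most one row per LastFetchedType, so writes are effectively upserts. A minimal read sketch, assuming a Flask app context (the service added later in this commit wraps exactly this query):

from src.models.last_fetched import LastFetched, LastFetchedType

# None until outputs have been fetched at least once.
row = LastFetched.query.filter_by(type=LastFetchedType.OUTPUTS).first()
last_synced = row.timestamp if row else None
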
8 changes: 6 additions & 2 deletions backend/src/models/outputs.py
@@ -16,8 +16,12 @@ class Output(DB.Model):
)
vout = DB.Column(DB.Integer, nullable=False, default=0)

address = DB.Column(DB.String(), nullable=False)

# Relationship to labels
labels = DB.relationship("Label", secondary=output_labels, back_populates="outputs")
labels = DB.relationship(
"Label", secondary=output_labels, back_populates="outputs")

# Unique constraint on the combination of txid and vout
__table_args__ = (DB.UniqueConstraint("txid", "vout", name="uq_txid_vout"),)
__table_args__ = (DB.UniqueConstraint(
"txid", "vout", name="uq_txid_vout"),)
2 changes: 2 additions & 0 deletions backend/src/services/__init__.py
@@ -4,3 +4,5 @@
from src.services.hardware_wallet.hardware_wallet import HardwareWalletService

from src.services.privacy_metrics.privacy_metrics import PrivacyMetricsService

from src.services.last_fetched.last_fetched_service import LastFetchedService
40 changes: 40 additions & 0 deletions backend/src/services/last_fetched/last_fetched_service.py
@@ -0,0 +1,40 @@
from src.models.last_fetched import LastFetched, LastFetchedType
from datetime import datetime
from src.database import DB
from typing import Optional

import structlog

LOGGER = structlog.get_logger()


class LastFetchedService:
@classmethod
def update_last_fetched_outputs_type(
cls,
) -> None:
"""Update the last fetched time for the outputs."""
timestamp = datetime.now()
current_last_fetched_output = LastFetched.query.filter_by(
type=LastFetchedType.OUTPUTS
).first()
if current_last_fetched_output:
current_last_fetched_output.timestamp = timestamp
else:
last_fetched_output = LastFetched(
type=LastFetchedType.OUTPUTS, timestamp=timestamp
)
DB.session.add(last_fetched_output)
DB.session.commit()

@classmethod
def get_last_fetched_output_datetime(
cls,
) -> Optional[datetime]:
"""Get the last fetched time for the outputs."""
last_fetched_output = LastFetched.query.filter_by(
type=LastFetchedType.OUTPUTS
).first()
if last_fetched_output:
return last_fetched_output.timestamp
return None
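
A usage sketch of the new service, assuming it runs inside a Flask app context with the database initialized; the five-minute staleness window mirrors the one used in privacy_metrics.py below:

from datetime import datetime, timedelta

from src.services import LastFetchedService

LastFetchedService.update_last_fetched_outputs_type()  # record that outputs were just fetched

last = LastFetchedService.get_last_fetched_output_datetime()
stale = last is None or datetime.now() - last > timedelta(minutes=5)
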
40 changes: 38 additions & 2 deletions backend/src/services/privacy_metrics/privacy_metrics.py
@@ -1,5 +1,12 @@
from src.database import DB
from src.models.privacy_metric import PrivacyMetric, PrivacyMetricName
from src.models.outputs import Output as OutputModel
from datetime import datetime, timedelta


import structlog

LOGGER = structlog.get_logger()


class PrivacyMetricsService:
@@ -93,11 +100,40 @@ def analyze_annominit_set(cls, txid: str, desired_annominity_set: int) -> bool:
return True

@classmethod
def analyze_no_address_reuse(cls, txid: str) -> bool:
def analyze_no_address_reuse(
cls,
txid: str,
) -> bool:
# can I inject this in?
# circular imports are currently preventing it.
from src.services.wallet.wallet import WalletService
from src.services.last_fetched.last_fetched_service import LastFetchedService

# Check that all outputs have already been fetched recently
last_fetched_output_datetime = (
LastFetchedService.get_last_fetched_output_datetime()
)
now = datetime.now()
refetch_interval = timedelta(minutes=5)
should_refetch_outputs = (
last_fetched_output_datetime is None
or now - last_fetched_output_datetime > refetch_interval
)

if should_refetch_outputs:
LOGGER.info("Outputs have not been fetched recently, fetching all outputs")
# this will get all the outputs and add them to the database, ensuring that they exist
WalletService.get_all_outputs()
outputs = OutputModel.query.filter_by(txid=txid).all()
for output in outputs:
if WalletService.is_address_reused(output.address):
LOGGER.info(f"Address {output.address} has been reused")
return False

return True

@classmethod
def analyze_minimal_wealth_reveal(cls, txid: str) -> bool:
def analyze_minimal_wealth_reveal(
cls,
txid: str,
) -> bool:
return True

@classmethod
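
An illustrative test sketch (not part of this commit) of what the metric is expected to catch, assuming a pytest setup with an app-context fixture and with WalletService.get_all_outputs stubbed so no electrum connection is needed; the fixture name and the txid/address values are hypothetical:

from src.services.privacy_metrics.privacy_metrics import PrivacyMetricsService
from src.services.wallet.wallet import WalletService

def test_analyze_no_address_reuse_flags_reuse(app_context, monkeypatch):  # app_context fixture is assumed
    # Avoid hitting electrum: pretend the refetch returns nothing new.
    monkeypatch.setattr(WalletService, "get_all_outputs", lambda: [])
    # Two outputs sharing one address, so the metric should fail.
    WalletService.add_output_to_db(txid="tx_a", vout=0, address="bc1q_reused")
    WalletService.add_output_to_db(txid="tx_b", vout=1, address="bc1q_reused")

    assert PrivacyMetricsService.analyze_no_address_reuse(txid="tx_a") is False
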
80 changes: 50 additions & 30 deletions backend/src/services/wallet/wallet.py
@@ -23,6 +23,7 @@
PopulateOutputLabelsRequestDto,
)
from src.my_types.transactions import LiveWalletOutput
from src.services.last_fetched.last_fetched_service import LastFetchedService
from src.services.wallet.raw_output_script_examples import (
p2pkh_raw_output_script,
p2pk_raw_output_script,
@@ -161,8 +162,7 @@ def connect_wallet(
)

wallet_change_descriptor = (
bdk.Descriptor(change_descriptor,
bdk.Network._value2member_map_[network])
bdk.Descriptor(change_descriptor, bdk.Network._value2member_map_[network])
if change_descriptor
else None
)
@@ -184,8 +184,7 @@ def connect_wallet(
database_config=db_config,
)

LOGGER.info(
f"Connecting a new wallet to electrum server {wallet_details_id}")
LOGGER.info(f"Connecting a new wallet to electrum server {wallet_details_id}")
LOGGER.info(f"xpub {wallet_descriptor.as_string()}")

wallet.sync(blockchain, None)
@@ -234,8 +233,7 @@ def create_spendable_descriptor(
twelve_word_secret = bdk.Mnemonic(bdk.WordCount.WORDS12)

# xpriv
descriptor_secret_key = bdk.DescriptorSecretKey(
network, twelve_word_secret, "")
descriptor_secret_key = bdk.DescriptorSecretKey(network, twelve_word_secret, "")

wallet_descriptor = None
if script_type == ScriptType.P2PKH:
@@ -291,12 +289,13 @@ def get_all_utxos(self) -> List[bdk.LocalUtxo]:
utxos = self.wallet.list_unspent()
return utxos

@classmethod
def get_all_transactions(
self,
cls,
) -> List[Transaction]:
"""Get all transactions for the current wallet."""
wallet_details = Wallet.get_current_wallet()
if self.wallet is None or wallet_details is None:
if cls.wallet is None or wallet_details is None:
LOGGER.error("No electrum wallet or wallet details found.")
return []

@@ -311,7 +310,7 @@ def get_all_transactions(
LOGGER.error("No electrum url or port found in the wallet details")
return []

transactions = self.wallet.list_transactions(False)
transactions = cls.wallet.list_transactions(False)

all_tx_details: List[Transaction] = []

@@ -332,24 +331,27 @@ def get_all_transactions(
all_tx_details.append(electrum_response.data)
return all_tx_details

def get_all_outputs(self) -> List[LiveWalletOutput]:
@classmethod
def get_all_outputs(cls) -> List[LiveWalletOutput]:
"""Get all spent and unspent transaction outputs for the current wallet and mutate them as needed.
Calculate the annominity set for each output.
Attach the txid to each output.
Attach all labels to each output.
Sync the database with the incoming outputs.
"""
all_transactions = self.get_all_transactions()
all_transactions = cls.get_all_transactions()
all_outputs: List[LiveWalletOutput] = []
for transaction in all_transactions:
annominity_sets = self.calculate_output_annominity_sets(
transaction.outputs)
annominity_sets = cls.calculate_output_annominity_sets(transaction.outputs)
for output in transaction.outputs:
db_output = self.sync_local_db_with_incoming_output(
txid=transaction.txid, vout=output.output_n
)
script = bdk.Script(output.script.raw)
if self.wallet and self.wallet.is_mine(script):
if cls.wallet and cls.wallet.is_mine(script):
db_output = cls.sync_local_db_with_incoming_output(
txid=transaction.txid,
vout=output.output_n,
address=output.address,
)
LastFetchedService.update_last_fetched_outputs_type()
annominity_set = annominity_sets.get(output.value, 1)

extended_output = LiveWalletOutput(
@@ -364,27 +366,33 @@ def get_all_outputs(self) -> List[LiveWalletOutput]:
return all_outputs

# TODO add a better name since this is just adding the output to the db
@classmethod
def sync_local_db_with_incoming_output(
self,
cls,
txid: str,
vout: int,
address: str,
) -> OutputModel:
"""Sync the local database with the incoming output.
If the output is not in the database, add it.
"""

db_output = OutputModel.query.filter_by(txid=txid, vout=vout).first()
db_output = OutputModel.query.filter_by(
txid=txid, vout=vout, address=address
).first()
if not db_output:
db_output = self.add_output_to_db(txid=txid, vout=vout)
db_output = cls.add_output_to_db(txid=txid, vout=vout, address=address)
return db_output

def add_output_to_db(self, vout: int, txid: str) -> OutputModel:
db_output = OutputModel(txid=txid, vout=vout, labels=[])
@classmethod
def add_output_to_db(cls, vout: int, txid: str, address: str) -> OutputModel:
db_output = OutputModel(txid=txid, vout=vout, address=address, labels=[])
DB.session.add(db_output)
DB.session.commit()
return db_output

@classmethod
def calculate_output_annominity_sets(
self, transaction_outputs: List[Output]
) -> dict[str, int]: # -> {"value": count }
@@ -435,7 +443,7 @@ def get_output_labels_unique(

for output in outputs_with_label:
for label in output.labels:
key = f"{output.txid}-{output.vout}"
key = f"{output.txid}-{output.vout}-{output.address}"
if result.get(key, None) is None:
result[key] = [
OutputLabelDto(
@@ -455,25 +463,27 @@ def get_output_labels_unique(

return result

@classmethod
def populate_outputs_and_labels(
self, populate_output_labels: PopulateOutputLabelsRequestDto
cls, populate_output_labels: PopulateOutputLabelsRequestDto
) -> None:  # TODO maybe a success or fail return type?
try:
model_dump = populate_output_labels.model_dump()
for unique_output_txid_vout in model_dump.keys():
txid, vout = unique_output_txid_vout.split("-")
self.sync_local_db_with_incoming_output(txid, int(vout))
txid, vout, address = unique_output_txid_vout.split("-")
cls.sync_local_db_with_incoming_output(txid, int(vout), address)
output_labels = model_dump[unique_output_txid_vout]
for label in output_labels:
display_name = label["display_name"]
self.add_label_to_output(txid, int(vout), display_name)
cls.add_label_to_output(txid, int(vout), display_name)
except Exception as e:
LOGGER.error("Error populating outputs and labels", error=e)
DB.session.rollback()

# TODO should this even go here or in its own service?
@classmethod
def add_label_to_output(
self, txid: str, vout: int, label_display_name: str
cls, txid: str, vout: int, label_display_name: str
) -> list[Label]:
"""Add a label to an output in the db."""
db_output = OutputModel.query.filter_by(txid=txid, vout=vout).first()
@@ -556,8 +566,7 @@ def build_transaction(
script, amount_per_recipient_output
)

built_transaction: bdk.TxBuilderResult = tx_builder.finish(
self.wallet)
built_transaction: bdk.TxBuilderResult = tx_builder.finish(self.wallet)

built_transaction.transaction_details.transaction
return BuildTransactionResponseType(
@@ -632,3 +641,14 @@ def get_fee_estimate_for_utxos_from_request(
)

return fee_estimate_response

@classmethod
def is_address_reused(cls, address: str) -> bool:
"""Check if the address has been used in the wallet more than once."""
outputs_with_this_address = OutputModel.query.filter_by(address=address).all()
address_used_count = len(outputs_with_this_address)

if address_used_count > 1:
return True
else:
return False
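
Putting the wallet.py changes together, a rough sketch of the intended interaction, assuming a Flask app context and a connected wallet with at least one of its own outputs so the fetch timestamp gets written; the address string is hypothetical:

from src.services.last_fetched.last_fetched_service import LastFetchedService
from src.services.wallet.wallet import WalletService

outputs = WalletService.get_all_outputs()  # syncs outputs (now with addresses) into the DB
fetched_at = LastFetchedService.get_last_fetched_output_datetime()  # stamped during the sync

reused = WalletService.is_address_reused("bc1q_example")  # True if the address is on more than one output row
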