Skip to content

Commit

Permalink
fix hashing and repr strings for sql format params
Browse files Browse the repository at this point in the history
  • Loading branch information
jessekrubin committed Dec 18, 2024
1 parent bc7f1e3 commit 654c97c
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 5 deletions.
63 changes: 58 additions & 5 deletions crates/ryo3-sqlformat/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::unused_self)]

use std::collections::HashMap;

use pyo3::prelude::PyModule;
use pyo3::prelude::*;
use sqlformat::{self, QueryParams};
use std::collections::{BTreeMap, HashMap};
use std::hash::{Hash, Hasher};

#[pyclass(name = "SqlfmtQueryParams", module = "ryo3")]
#[derive(Debug, Clone)]
Expand All @@ -33,16 +33,16 @@ impl PySqlfmtQueryParams {
sqlfmt_params(Some(params))
}

fn __str__(&self) -> String {
fn __repr__(&self) -> String {
match &self.params {
QueryParams::Named(p) => {
// collect into string for display
let s = p
.iter()
.map(|(k, v)| format!("(\"{k}, \"{v}\")"))
.map(|(k, v)| format!("\"{k}\": \"{v}\""))
.collect::<Vec<String>>()
.join(", ");
format!("SqlfmtQueryParams({s})")
format!("SqlfmtQueryParams({{{s}}})")
}
QueryParams::Indexed(p) => {
let s = p
Expand All @@ -55,6 +55,59 @@ impl PySqlfmtQueryParams {
QueryParams::None => String::from("SqlfmtQueryParams(None)"),
}
}

fn __str__(&self) -> String {
self.__repr__()
}

fn __len__(&self) -> usize {
match &self.params {
QueryParams::Named(p) => p.len(),
QueryParams::Indexed(p) => p.len(),
QueryParams::None => 0,
}
}

fn __eq__(&self, other: &PySqlfmtQueryParams) -> bool {
match (&self.params, &other.params) {
(QueryParams::Named(p1), QueryParams::Named(p2)) => {
// make 2 treeeeees...
let p1: HashMap<&str, &str> =
p1.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
let p2: HashMap<&str, &str> =
p2.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
p1 == p2
}
(QueryParams::Indexed(p1), QueryParams::Indexed(p2)) => p1 == p2,
(QueryParams::None, QueryParams::None) => true,
_ => false,
}
}

fn __ne__(&self, other: &PySqlfmtQueryParams) -> bool {
!self.__eq__(other)
}

fn __hash__(&self) -> PyResult<u64> {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
match &self.params {
QueryParams::Named(p) => {
let p: BTreeMap<&str, &str> =
p.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
for (k, v) in p.iter() {
k.hash(&mut hasher);
v.hash(&mut hasher);
}
}
QueryParams::Indexed(p) => {
for v in p.iter() {
v.hash(&mut hasher);
}
}
QueryParams::None => {}
}
Ok(hasher.finish())
}
}

impl From<Vec<(String, String)>> for PySqlfmtQueryParams {
Expand Down
1 change: 1 addition & 0 deletions tests/sqlformat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from __future__ import annotations
File renamed without changes.
32 changes: 32 additions & 0 deletions tests/sqlformat/test_sqlparams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import ry

params_list: list[tuple[str, int | str | float]] = [
("zoom_level", "0"),
("tile_column", 0),
("tile_row", "0"),
]

params_arr = [
params_list,
[("zoom_level", "0"), ("tile_column", "0"), ("tile_row", "0")],
{"zoom_level": "0", "tile_column": "0", "tile_row": "0"},
{"zoom_level": 0, "tile_column": 0, "tile_row": 0},
]

import pytest


@pytest.mark.parametrize("params", params_arr)
def test_sqlparams(params):
sqlfmt_params_obj = ry.sqlfmt_params(
params,
)

# test the repr
repr_str = "ry." + repr(sqlfmt_params_obj)
print(repr_str)
# exec
round_tripped = eval(repr_str)
print("ry." + repr(round_tripped))
assert sqlfmt_params_obj == round_tripped
assert hash(sqlfmt_params_obj) == hash(round_tripped)

0 comments on commit 654c97c

Please sign in to comment.