diff --git a/crates/ryo3-sqlformat/src/lib.rs b/crates/ryo3-sqlformat/src/lib.rs index b3bb438..44c9484 100644 --- a/crates/ryo3-sqlformat/src/lib.rs +++ b/crates/ryo3-sqlformat/src/lib.rs @@ -14,11 +14,11 @@ #![allow(clippy::module_name_repetitions)] #![allow(clippy::unused_self)] -use std::collections::HashMap; - use pyo3::prelude::PyModule; use pyo3::prelude::*; use sqlformat::{self, QueryParams}; +use std::collections::{BTreeMap, HashMap}; +use std::hash::{Hash, Hasher}; #[pyclass(name = "SqlfmtQueryParams", module = "ryo3")] #[derive(Debug, Clone)] @@ -33,16 +33,16 @@ impl PySqlfmtQueryParams { sqlfmt_params(Some(params)) } - fn __str__(&self) -> String { + fn __repr__(&self) -> String { match &self.params { QueryParams::Named(p) => { // collect into string for display let s = p .iter() - .map(|(k, v)| format!("(\"{k}, \"{v}\")")) + .map(|(k, v)| format!("\"{k}\": \"{v}\"")) .collect::>() .join(", "); - format!("SqlfmtQueryParams({s})") + format!("SqlfmtQueryParams({{{s}}})") } QueryParams::Indexed(p) => { let s = p @@ -55,6 +55,59 @@ impl PySqlfmtQueryParams { QueryParams::None => String::from("SqlfmtQueryParams(None)"), } } + + fn __str__(&self) -> String { + self.__repr__() + } + + fn __len__(&self) -> usize { + match &self.params { + QueryParams::Named(p) => p.len(), + QueryParams::Indexed(p) => p.len(), + QueryParams::None => 0, + } + } + + fn __eq__(&self, other: &PySqlfmtQueryParams) -> bool { + match (&self.params, &other.params) { + (QueryParams::Named(p1), QueryParams::Named(p2)) => { + // make 2 treeeeees... + let p1: HashMap<&str, &str> = + p1.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect(); + let p2: HashMap<&str, &str> = + p2.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect(); + p1 == p2 + } + (QueryParams::Indexed(p1), QueryParams::Indexed(p2)) => p1 == p2, + (QueryParams::None, QueryParams::None) => true, + _ => false, + } + } + + fn __ne__(&self, other: &PySqlfmtQueryParams) -> bool { + !self.__eq__(other) + } + + fn __hash__(&self) -> PyResult { + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + match &self.params { + QueryParams::Named(p) => { + let p: BTreeMap<&str, &str> = + p.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect(); + for (k, v) in p.iter() { + k.hash(&mut hasher); + v.hash(&mut hasher); + } + } + QueryParams::Indexed(p) => { + for v in p.iter() { + v.hash(&mut hasher); + } + } + QueryParams::None => {} + } + Ok(hasher.finish()) + } } impl From> for PySqlfmtQueryParams { diff --git a/tests/sqlformat/__init__.py b/tests/sqlformat/__init__.py new file mode 100644 index 0000000..9d48db4 --- /dev/null +++ b/tests/sqlformat/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/tests/test_sqlfmt.py b/tests/sqlformat/test_sqlfmt.py similarity index 100% rename from tests/test_sqlfmt.py rename to tests/sqlformat/test_sqlfmt.py diff --git a/tests/sqlformat/test_sqlparams.py b/tests/sqlformat/test_sqlparams.py new file mode 100644 index 0000000..1b98767 --- /dev/null +++ b/tests/sqlformat/test_sqlparams.py @@ -0,0 +1,32 @@ +import ry + +params_list: list[tuple[str, int | str | float]] = [ + ("zoom_level", "0"), + ("tile_column", 0), + ("tile_row", "0"), +] + +params_arr = [ + params_list, + [("zoom_level", "0"), ("tile_column", "0"), ("tile_row", "0")], + {"zoom_level": "0", "tile_column": "0", "tile_row": "0"}, + {"zoom_level": 0, "tile_column": 0, "tile_row": 0}, +] + +import pytest + + +@pytest.mark.parametrize("params", params_arr) +def test_sqlparams(params): + sqlfmt_params_obj = ry.sqlfmt_params( + params, + ) + + # test the repr + repr_str = "ry." + repr(sqlfmt_params_obj) + print(repr_str) + # exec + round_tripped = eval(repr_str) + print("ry." + repr(round_tripped)) + assert sqlfmt_params_obj == round_tripped + assert hash(sqlfmt_params_obj) == hash(round_tripped)