From abf0452215954e6aef34faee52fb5f9b43f4dbea Mon Sep 17 00:00:00 2001 From: Sh1nku <42642351+Sh1nku@users.noreply.github.com> Date: Sat, 29 Jun 2024 18:04:30 +0200 Subject: [PATCH] Make structs subclassable --- wrappers/python/README.md | 85 ++++++++++++++++--- wrappers/python/src/clients.rs | 8 +- wrappers/python/src/hosts.rs | 10 +-- wrappers/python/src/models/auth.rs | 4 +- wrappers/python/src/models/context.rs | 2 +- wrappers/python/src/models/facet_set.rs | 14 ++- wrappers/python/src/models/group.rs | 4 +- wrappers/python/src/models/json_facet.rs | 6 +- wrappers/python/src/models/response.rs | 4 +- .../src/queries/components/facet_set.rs | 8 +- .../python/src/queries/components/grouping.rs | 2 +- .../src/queries/components/json_facet.rs | 10 +-- wrappers/python/src/queries/def_type.rs | 8 +- wrappers/python/src/queries/index.rs | 4 +- wrappers/python/src/queries/select.rs | 2 +- wrappers/python/tests/test_clients.py | 31 +++++++ 16 files changed, 153 insertions(+), 49 deletions(-) diff --git a/wrappers/python/README.md b/wrappers/python/README.md index 83f3157..ff924fc 100644 --- a/wrappers/python/README.md +++ b/wrappers/python/README.md @@ -1,25 +1,33 @@ # Solrstice: A Solr 8+ Client for Rust and Python + Solrstice is a solr client library written in rust. With this wrapper you can use it in python. Both asyncio and blocking clients are provided. All apis have type hints. Documentation can be found at [sh1nku.github.io/solrstice/python](https://sh1nku.github.io/solrstice/python) + ## Features + * Config API * Collection API * Alias API * Select Documents - * Grouping Component Query - * DefTypes (lucene, dismax, edismax) - * Facet Counts (Query, Field, Pivot) - * Json Facet (Query, Stat, Terms, Nested) + * Grouping Component Query + * DefTypes (lucene, dismax, edismax) + * Facet Counts (Query, Field, Pivot) + * Json Facet (Query, Stat, Terms, Nested) * Indexing Documents * Deleting Documents + ## Installation + ```bash pip install solrstice ``` + ## Basic Usage + ### Async + ```python import asyncio @@ -32,25 +40,28 @@ from solrstice.queries import DeleteQuery, SelectQuery, UpdateQuery context = SolrServerContext(SolrSingleServerHost('localhost:8983'), SolrBasicAuth('solr', 'SolrRocks')) client = AsyncSolrCloudClient(context) + async def main(): # Create config and collection await client.upload_config('example_config', 'path/to/config') await client.create_collection('example_collection', 'example_config', shards=1, replication_factor=1) - + # Index a document await client.index(UpdateQuery(), 'example_collection', [{'id': 'example_document', 'title': 'Example document'}]) - + # Search for the document response = await client.select(SelectQuery(fq=['title:Example document']), 'example_collection') docs = response.get_docs_response().get_docs() - + # Delete the document await client.delete(DeleteQuery(ids=['example_document']), 'example_collection') - + asyncio.run(main()) ``` + ### Blocking + ```python from solrstice.auth import SolrBasicAuth from solrstice.clients import BlockingSolrCloudClient @@ -77,7 +88,9 @@ client.delete(DeleteQuery(ids=['example_document']), 'example_collection') ``` ## Grouping component + ### Field grouping + ```python group_builder = GroupingComponent(fields=["age"], limit=10) select_builder = SelectQuery(fq=["age:[* TO *]"], grouping=group_builder) @@ -85,7 +98,9 @@ groups = await client.select(select_builder, "example_collection").get_groups() age_group = groups["age"] docs = age_group.get_field_result() ``` + ### Query grouping + ```python group_builder = GroupingComponent(queries=["age:[0 TO 59]", "age:[60 TO *]"], limit=10) select_builder = SelectQuery(fq=["age:[* TO *]"], grouping=group_builder) @@ -94,30 +109,40 @@ age_group = groups["age:[0 TO 59]"] group = age_group.get_query_result() docs = group.get_docs() ``` + ## Query parsers + ### Lucene + ```python query_parser = LuceneQuery(df="population") select_builder = SelectQuery(q="outdoors", def_type=query_parser) await client.select(select_builder, "example_collection") docs = response.get_docs_response().get_docs() ``` + ### Dismax + ```python query_parser = DismaxQuery(qf="interests^20", bq=["interests:cars^20"]) select_builder = SelectQuery(q="outdoors", def_type=query_parser) await client.select(select_builder, "example_collection") docs = response.get_docs_response().get_docs() ``` + ### Edismax + ```python query_parser = EdismaxQuery(qf="interests^20", bq=["interests:cars^20"]) select_builder = SelectQuery(q="outdoors", def_type=query_parser) await client.select(select_builder, "example_collection") docs = response.get_docs_response().get_docs() ``` + ## FacetSet Component + ### Pivot facet + ```python select_builder = SelectQuery(facet_set=FacetSetComponent(pivots=PivotFacetComponent(["interests,age"]))) await client.select(select_builder, "example_collection") @@ -125,7 +150,9 @@ facets = response.get_facet_set() pivots = facets.get_pivots() interests_age = pivot.get("interests,age") ``` + ### Field facet + ```python facet_set = FacetSetComponent(fields=FieldFacetComponent(fields=[FieldFacetEntry("age")])) select_builder = SelectQuery(facet_set=facet_set) @@ -134,7 +161,9 @@ facets = response.get_facet_set() fields = facets.get_fields() age = fields.get("age") ``` + ### Query facet + ```python select_builder = SelectQuery(facet_set=FacetSetComponent(queries=["age:[0 TO 59]"])) response = await client.select(select_builder, name) @@ -142,20 +171,25 @@ facets = response.get_facet_set() queries = facets.get_queries() query = queries.get("age:[0 TO 59]") ``` + ## Json Facet Component + ### Query + ```python select_builder = SelectQuery( - json_facet=JsonFacetComponent( - facets={"below_60": JsonQueryFacet("age:[0 TO 59]")} - ) + json_facet=JsonFacetComponent( + facets={"below_60": JsonQueryFacet("age:[0 TO 59]")} + ) ) response = await client.select(select_builder, "example_collection"") facets = response.get_json_facets() below_60 = facets.get_nested_facets().get("below_60") assert below_60.get_count() == 4 ``` + ### Stat + ```python select_builder = SelectQuery( json_facet=JsonFacetComponent( @@ -167,7 +201,9 @@ facets = response.get_json_facets() total_people = facets.get_flat_facets().get("total_people") assert total_people == 1000 ``` + ### Terms + ```python select_builder = SelectQuery( json_facet=JsonFacetComponent(facets={"age": JsonTermsFacet("age")}) @@ -177,7 +213,9 @@ facets = response.get_json_facets() age_buckets = facets.get_nested_facets().get("age").get_buckets() assert len(age_buckets) == 3 ``` + ### Nested + ```python select_builder = SelectQuery( json_facet=JsonFacetComponent( @@ -199,22 +237,29 @@ total_people = ( ) assert total_people == 750.0 ``` + ## Hosts + ### Single Server + ```python context = SolrServerContext(SolrSingleServerHost('localhost:8983'), SolrBasicAuth('solr', 'SolrRocks')) client = AsyncSolrCloudClient(context) ``` + ### Multiple servers + ```python # The client will randomly select a server to send requests to. It will wait 5 seconds for a response, before trying another server. context = SolrServerContext( - SolrMultipleServerHost(["localhost:8983", "localhost:8984"], 5), - SolrBasicAuth('solr', 'SolrRocks'), + SolrMultipleServerHost(["localhost:8983", "localhost:8984"], 5), + SolrBasicAuth('solr', 'SolrRocks'), ) client = AsyncSolrCloudClient(context) ``` + ### Zookeeper + ```python context = SolrServerContext( await ZookeeperEnsembleHostConnector(["localhost:2181"], 30).connect(), @@ -224,4 +269,16 @@ client = AsyncSolrCloudClient(context) ``` ## Notes -* Multiprocessing does not work, and will block forever. Normal multithreading works fine. \ No newline at end of file + +* Multiprocessing does not work, and will block forever. Normal multithreading works fine. +* Pyo3, the Rust library for creating bindings does not allow overriding the `__init__` method on objects from + Rust. `__new__` has to be overridden instead. + + For example, if you want to create a simpler way to create a client + ```python + class SolrClient(AsyncSolrCloudClient): + def __new__(cls, host: str, auth: Optional[SolrAuth] = None): + context = SolrServerContext(SolrSingleServerHost(host), auth) + return super().__new__(cls, context=context) + client = SolrClient(config.solr_host, SolrBasicAuth("username", "password")) + ``` \ No newline at end of file diff --git a/wrappers/python/src/clients.rs b/wrappers/python/src/clients.rs index 3d329a0..cba520c 100644 --- a/wrappers/python/src/clients.rs +++ b/wrappers/python/src/clients.rs @@ -25,7 +25,7 @@ pub fn clients(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { Ok(()) } -#[pyclass(name = "AsyncSolrCloudClient", module = "solrstice.clients")] +#[pyclass(name = "AsyncSolrCloudClient", module = "solrstice.clients", subclass)] #[derive(Clone)] pub struct AsyncSolrCloudClientWrapper(SolrServerContextWrapper); @@ -148,7 +148,11 @@ impl AsyncSolrCloudClientWrapper { } } -#[pyclass(name = "BlockingSolrCloudClient", module = "solrstice.clients")] +#[pyclass( + name = "BlockingSolrCloudClient", + module = "solrstice.clients", + subclass +)] #[derive(Clone)] pub struct BlockingSolrCloudClientWrapper(SolrServerContextWrapper); diff --git a/wrappers/python/src/hosts.rs b/wrappers/python/src/hosts.rs index a6df9d8..540b20e 100644 --- a/wrappers/python/src/hosts.rs +++ b/wrappers/python/src/hosts.rs @@ -21,7 +21,7 @@ pub fn hosts(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { Ok(()) } -#[pyclass(name = "SolrHost", subclass, module = "solrstice.hosts")] +#[pyclass(name = "SolrHost", module = "solrstice.hosts", subclass)] #[derive(Clone)] pub struct SolrHostWrapper { pub solr_host: Arc, @@ -34,7 +34,7 @@ impl SolrHost for SolrHostWrapper { } } -#[pyclass(name = "SolrSingleServerHost", extends = SolrHostWrapper, module= "solrstice.hosts")] +#[pyclass(name = "SolrSingleServerHost", extends = SolrHostWrapper, module= "solrstice.hosts", subclass)] #[derive(Clone)] pub struct SolrSingleServerHostWrapper; @@ -51,7 +51,7 @@ impl SolrSingleServerHostWrapper { } } -#[pyclass(name = "SolrMultipleServerHost", extends = SolrHostWrapper, module= "solrstice.hosts")] +#[pyclass(name = "SolrMultipleServerHost", extends = SolrHostWrapper, module= "solrstice.hosts", subclass)] #[derive(Clone)] pub struct SolrMultipleServerHostWrapper; @@ -71,11 +71,11 @@ impl SolrMultipleServerHostWrapper { } } -#[pyclass(name = "ZookeeperEnsembleHost", extends = SolrHostWrapper, module= "solrstice.hosts")] +#[pyclass(name = "ZookeeperEnsembleHost", extends = SolrHostWrapper, module= "solrstice.hosts", subclass)] #[derive(Clone)] pub struct ZookeeperEnsembleHostWrapper; -#[pyclass(name = "ZookeeperEnsembleHostConnector")] +#[pyclass(name = "ZookeeperEnsembleHostConnector", subclass)] #[derive(Clone)] pub struct ZookeeperEnsembleHostConnectorWrapper(ZookeeperEnsembleHostConnector); diff --git a/wrappers/python/src/models/auth.rs b/wrappers/python/src/models/auth.rs index 3f2bb0f..5ce33f1 100644 --- a/wrappers/python/src/models/auth.rs +++ b/wrappers/python/src/models/auth.rs @@ -9,7 +9,7 @@ pub fn auth(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { Ok(()) } -#[pyclass(name = "SolrAuth", subclass, module = "solrstice.auth")] +#[pyclass(name = "SolrAuth", module = "solrstice.auth", subclass)] #[derive(Clone)] pub struct SolrAuthWrapper { pub solr_auth: Arc, @@ -21,7 +21,7 @@ impl SolrAuth for SolrAuthWrapper { } } -#[pyclass(name = "SolrBasicAuth", extends=SolrAuthWrapper, module = "solrstice.auth")] +#[pyclass(name = "SolrBasicAuth", extends=SolrAuthWrapper, module = "solrstice.auth", subclass)] #[derive(Clone)] pub struct SolrBasicAuthWrapper {} diff --git a/wrappers/python/src/models/context.rs b/wrappers/python/src/models/context.rs index e06b883..d764b33 100644 --- a/wrappers/python/src/models/context.rs +++ b/wrappers/python/src/models/context.rs @@ -3,7 +3,7 @@ use crate::models::auth::SolrAuthWrapper; use pyo3::prelude::*; use solrstice::models::context::{SolrServerContext, SolrServerContextBuilder}; -#[pyclass(name = "SolrServerContext", subclass, module = "solrstice.hosts")] +#[pyclass(name = "SolrServerContext", module = "solrstice.hosts", subclass)] #[derive(Clone)] pub struct SolrServerContextWrapper(SolrServerContext); diff --git a/wrappers/python/src/models/facet_set.rs b/wrappers/python/src/models/facet_set.rs index a0ff12d..6a87cdc 100644 --- a/wrappers/python/src/models/facet_set.rs +++ b/wrappers/python/src/models/facet_set.rs @@ -24,7 +24,7 @@ pub fn facet_set(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { } #[derive(Clone, Debug, PartialEq, Default)] -#[pyclass(name = "SolrFacetSetResult", module = "solrstice.facet_set")] +#[pyclass(name = "SolrFacetSetResult", module = "solrstice.facet_set", subclass)] pub struct SolrFacetSetResultWrapper(SolrFacetSetResult); #[pymethods] @@ -75,7 +75,11 @@ impl From<&SolrFacetSetResult> for SolrFacetSetResultWrapper { } #[derive(Clone, Debug, PartialEq)] -#[pyclass(name = "SolrPivotFacetResult", module = "solrstice.facet_set")] +#[pyclass( + name = "SolrPivotFacetResult", + module = "solrstice.facet_set", + subclass +)] pub struct SolrPivotFacetResultWrapper(SolrPivotFacetResult); #[pymethods] @@ -128,7 +132,11 @@ impl From<&SolrPivotFacetResult> for SolrPivotFacetResultWrapper { } #[derive(Clone, Debug, PartialEq)] -#[pyclass(name = "SolrFieldFacetResult", module = "solrstice.facet_set")] +#[pyclass( + name = "SolrFieldFacetResult", + module = "solrstice.facet_set", + subclass +)] pub struct SolrFieldFacetResultWrapper(SolrFieldFacetResult); #[pymethods] diff --git a/wrappers/python/src/models/group.rs b/wrappers/python/src/models/group.rs index d6a3d5e..b10c2ef 100644 --- a/wrappers/python/src/models/group.rs +++ b/wrappers/python/src/models/group.rs @@ -15,7 +15,7 @@ pub fn group(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { } #[derive(Clone)] -#[pyclass(name = "SolrGroupResult", module = "solrstice.group")] +#[pyclass(name = "SolrGroupResult", module = "solrstice.group", subclass)] pub struct SolrGroupResultWrapper(SolrGroupResult); #[pymethods] @@ -62,7 +62,7 @@ impl From for SolrGroupResult { } #[derive(Clone)] -#[pyclass(name = "SolrGroupFieldResult", module = "solrstice.group")] +#[pyclass(name = "SolrGroupFieldResult", module = "solrstice.group", subclass)] pub struct SolrGroupFieldResultWrapper(SolrGroupFieldResult); #[pymethods] diff --git a/wrappers/python/src/models/json_facet.rs b/wrappers/python/src/models/json_facet.rs index 234f549..46377de 100644 --- a/wrappers/python/src/models/json_facet.rs +++ b/wrappers/python/src/models/json_facet.rs @@ -20,7 +20,11 @@ pub fn json_facet(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { } #[derive(Clone, Debug, PartialEq)] -#[pyclass(name = "SolrJsonFacetResponse", module = "solrstice.json_facet")] +#[pyclass( + name = "SolrJsonFacetResponse", + module = "solrstice.json_facet", + subclass +)] pub struct SolrJsonFacetResponseWrapper(SolrJsonFacetResponse); #[pymethods] diff --git a/wrappers/python/src/models/response.rs b/wrappers/python/src/models/response.rs index 925a05f..be3480f 100644 --- a/wrappers/python/src/models/response.rs +++ b/wrappers/python/src/models/response.rs @@ -17,7 +17,7 @@ pub fn response(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { } #[derive(Clone)] -#[pyclass(name = "SolrDocsResponse", module = "solrstice.response")] +#[pyclass(name = "SolrDocsResponse", module = "solrstice.response", subclass)] pub struct SolrDocsResponseWrapper(SolrDocsResponse); impl From for SolrDocsResponseWrapper { @@ -75,7 +75,7 @@ impl SolrDocsResponseWrapper { } #[derive(Clone)] -#[pyclass(name = "SolrResponse", module = "solrstice.response")] +#[pyclass(name = "SolrResponse", module = "solrstice.response", subclass)] pub struct SolrResponseWrapper(SolrResponse); impl From for SolrResponseWrapper { diff --git a/wrappers/python/src/queries/components/facet_set.rs b/wrappers/python/src/queries/components/facet_set.rs index 9c9aff9..f07941e 100644 --- a/wrappers/python/src/queries/components/facet_set.rs +++ b/wrappers/python/src/queries/components/facet_set.rs @@ -5,7 +5,7 @@ use solrstice::queries::components::facet_set::{ }; #[derive(Clone, Debug, PartialEq)] -#[pyclass(name = "FacetSetComponent", module = "solrstice.facet_set")] +#[pyclass(name = "FacetSetComponent", module = "solrstice.facet_set", subclass)] pub struct FacetSetComponentWrapper(FacetSetComponent); #[pymethods] @@ -55,7 +55,7 @@ impl From<&FacetSetComponent> for FacetSetComponentWrapper { } #[derive(Clone, Debug, PartialEq)] -#[pyclass(name = "PivotFacetComponent", module = "solrstice.facet_set")] +#[pyclass(name = "PivotFacetComponent", module = "solrstice.facet_set", subclass)] pub struct PivotFacetComponentWrapper(PivotFacetComponent); #[pymethods] @@ -95,7 +95,7 @@ impl From<&PivotFacetComponent> for PivotFacetComponentWrapper { } #[derive(Clone, Debug, PartialEq)] -#[pyclass(name = "FieldFacetComponent", module = "solrstice.facet_set")] +#[pyclass(name = "FieldFacetComponent", module = "solrstice.facet_set", subclass)] pub struct FieldFacetComponentWrapper(FieldFacetComponent); #[pymethods] @@ -188,7 +188,7 @@ impl From for FieldFacetMethodWrapper { } #[derive(Clone, Debug, PartialEq)] -#[pyclass(name = "FieldFacetEntry", module = "solrstice.facet_set")] +#[pyclass(name = "FieldFacetEntry", module = "solrstice.facet_set", subclass)] pub struct FieldFacetEntryWrapper(FieldFacetEntry); #[pymethods] diff --git a/wrappers/python/src/queries/components/grouping.rs b/wrappers/python/src/queries/components/grouping.rs index a80b574..bb1c369 100644 --- a/wrappers/python/src/queries/components/grouping.rs +++ b/wrappers/python/src/queries/components/grouping.rs @@ -2,7 +2,7 @@ use pyo3::prelude::*; use serde::{Deserialize, Serialize}; use solrstice::queries::components::grouping::{GroupFormatting, GroupingComponent}; -#[pyclass(name = "GroupingComponent", module = "solrstice.group")] +#[pyclass(name = "GroupingComponent", module = "solrstice.group", subclass)] #[derive(Clone, Serialize, Deserialize)] pub struct GroupingComponentWrapper(GroupingComponent); diff --git a/wrappers/python/src/queries/components/json_facet.rs b/wrappers/python/src/queries/components/json_facet.rs index 3dc188d..44bb9f9 100644 --- a/wrappers/python/src/queries/components/json_facet.rs +++ b/wrappers/python/src/queries/components/json_facet.rs @@ -5,7 +5,7 @@ use solrstice::queries::components::json_facet::{ use std::collections::HashMap; #[derive(Clone, Debug, PartialEq)] -#[pyclass(name = "JsonFacetComponent", module = "solrstice.json_facet")] +#[pyclass(name = "JsonFacetComponent", module = "solrstice.json_facet", subclass)] pub struct JsonFacetComponentWrapper(JsonFacetComponent); #[pymethods] @@ -45,7 +45,7 @@ impl From<&JsonFacetComponent> for JsonFacetComponentWrapper { } #[derive(Clone, Debug, PartialEq)] -#[pyclass(name = "JsonFacetType", subclass, module = "solrstice.json_facet")] +#[pyclass(name = "JsonFacetType", module = "solrstice.json_facet", subclass)] pub struct JsonFacetTypeWrapper(JsonFacetType); impl From for JsonFacetType { @@ -73,7 +73,7 @@ impl From<&JsonFacetType> for JsonFacetTypeWrapper { } #[derive(Clone, Debug, PartialEq)] -#[pyclass(name = "JsonTermsFacet", extends = JsonFacetTypeWrapper, module = "solrstice.json_facet")] +#[pyclass(name = "JsonTermsFacet", extends = JsonFacetTypeWrapper, module = "solrstice.json_facet", subclass)] pub struct JsonTermsFacetWrapper {} #[pymethods] @@ -107,7 +107,7 @@ impl JsonTermsFacetWrapper { } #[derive(Clone, Debug, PartialEq)] -#[pyclass(name = "JsonQueryFacet", extends = JsonFacetTypeWrapper ,module = "solrstice.json_facet")] +#[pyclass(name = "JsonQueryFacet", extends = JsonFacetTypeWrapper ,module = "solrstice.json_facet", subclass)] pub struct JsonQueryFacetWrapper {} #[pymethods] @@ -146,7 +146,7 @@ impl JsonQueryFacetWrapper { } #[derive(Clone, Debug, PartialEq)] -#[pyclass(name = "JsonStatFacet", extends = JsonFacetTypeWrapper, module = "solrstice.json_facet")] +#[pyclass(name = "JsonStatFacet", extends = JsonFacetTypeWrapper, module = "solrstice.json_facet", subclass)] pub struct JsonStatFacetWrapper {} #[pymethods] diff --git a/wrappers/python/src/queries/def_type.rs b/wrappers/python/src/queries/def_type.rs index 8147d38..caa749c 100644 --- a/wrappers/python/src/queries/def_type.rs +++ b/wrappers/python/src/queries/def_type.rs @@ -39,7 +39,7 @@ impl From for QueryOperatorWrapper { } } -#[pyclass(name = "DefType", subclass, module = "solrstice.def_type")] +#[pyclass(name = "DefType", module = "solrstice.def_type", subclass)] #[derive(Clone, Serialize, Deserialize)] pub struct DefTypeWrapper(DefType); @@ -65,7 +65,7 @@ impl DefTypeWrapper { } } -#[pyclass(name = "LuceneQuery", extends=DefTypeWrapper, module = "solrstice.def_type")] +#[pyclass(name = "LuceneQuery", extends=DefTypeWrapper, module = "solrstice.def_type", subclass)] #[derive(Clone, Serialize, Deserialize)] pub struct LuceneQueryWrapper {} @@ -85,7 +85,7 @@ impl LuceneQueryWrapper { } } -#[pyclass(name = "DismaxQuery", extends=DefTypeWrapper, module = "solrstice.def_type")] +#[pyclass(name = "DismaxQuery", extends=DefTypeWrapper, module = "solrstice.def_type", subclass)] #[derive(Clone, Serialize, Deserialize)] pub struct DismaxQueryWrapper {} @@ -117,7 +117,7 @@ impl DismaxQueryWrapper { } } -#[pyclass(name = "EdismaxQuery", extends=DefTypeWrapper, module = "solrstice.def_type")] +#[pyclass(name = "EdismaxQuery", extends=DefTypeWrapper, module = "solrstice.def_type", subclass)] #[derive(Clone, Serialize, Deserialize)] pub struct EdismaxQueryWrapper {} diff --git a/wrappers/python/src/queries/index.rs b/wrappers/python/src/queries/index.rs index f238a93..b04e69c 100644 --- a/wrappers/python/src/queries/index.rs +++ b/wrappers/python/src/queries/index.rs @@ -18,7 +18,7 @@ pub enum CommitTypeWrapper { } #[derive(Clone, Default, Serialize, Deserialize)] -#[pyclass(name = "UpdateQuery", module = "solrstice.queries")] +#[pyclass(name = "UpdateQuery", module = "solrstice.queries", subclass)] pub struct UpdateQueryWrapper(UpdateQuery); #[pymethods] @@ -123,7 +123,7 @@ impl From for CommitTypeWrapper { } #[derive(Clone, Default, Serialize, Deserialize)] -#[pyclass(name = "DeleteQuery", module = "solrstice.queries")] +#[pyclass(name = "DeleteQuery", module = "solrstice.queries", subclass)] pub struct DeleteQueryWrapper(DeleteQuery); #[pymethods] diff --git a/wrappers/python/src/queries/select.rs b/wrappers/python/src/queries/select.rs index 45025a2..12c2a6e 100644 --- a/wrappers/python/src/queries/select.rs +++ b/wrappers/python/src/queries/select.rs @@ -16,7 +16,7 @@ use solrstice::queries::components::json_facet::JsonFacetComponent; use solrstice::queries::def_type::DefType; use solrstice::queries::select::SelectQuery; -#[pyclass(name = "SelectQuery", module = "solrstice.queries")] +#[pyclass(name = "SelectQuery", module = "solrstice.queries", subclass)] #[derive(Clone, Serialize, Deserialize)] pub struct SelectQueryWrapper(SelectQuery); diff --git a/wrappers/python/tests/test_clients.py b/wrappers/python/tests/test_clients.py index 7f80677..3806dc7 100644 --- a/wrappers/python/tests/test_clients.py +++ b/wrappers/python/tests/test_clients.py @@ -1,9 +1,13 @@ import asyncio import pytest +from typing_extensions import Optional + from helpers import Config, create_config +from solrstice.auth import SolrAuth, SolrBasicAuth from solrstice.clients import AsyncSolrCloudClient, BlockingSolrCloudClient +from solrstice.hosts import SolrServerContext, SolrSingleServerHost from solrstice.queries import DeleteQuery, SelectQuery, UpdateQuery @@ -93,3 +97,30 @@ async def test_multiple_clients_works(): assert name in results[1] await client.delete_config(name) + + +@pytest.mark.asyncio +async def test_subclassing_client_works(): + class SolrClient(AsyncSolrCloudClient): + def __new__(cls, host: str, auth: Optional[SolrAuth] = None): + context = SolrServerContext(SolrSingleServerHost(host), auth) + return super().__new__(cls, context=context) + + def test_method(self) -> str: + return 'test' + + name = "SubclassingClientWorks" + + config = create_config() + + client = SolrClient(config.solr_host, SolrBasicAuth(config.solr_username, config.solr_password)) + + try: + await client.delete_config(name) + except: + pass + + await client.upload_config(name, config.config_path) + assert client.test_method() == 'test' + + await client.delete_config(name)