Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[db-queries] Allow join expressions in paginated-multicolumn #6530

Merged
merged 2 commits into from
Sep 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 228 additions & 38 deletions nexus/db-queries/src/db/pagination.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use diesel::helper_types::*;
use diesel::pg::Pg;
use diesel::query_builder::AsQuery;
use diesel::query_dsl::methods as query_methods;
use diesel::query_source::QuerySource;
use diesel::sql_types::{Bool, SqlType};
use diesel::AppearsOnTable;
use diesel::Column;
Expand Down Expand Up @@ -70,7 +71,7 @@ where
}
}

/// Uses `pagparams` to list a subset of rows in `table`, ordered by `c1, and
/// Uses `pagparams` to list a subset of rows in `query`, ordered by `c1`, and
/// then by `c2`.
///
/// This is a two-column variation of the [`paginated`] function.
Expand All @@ -79,40 +80,56 @@ where
// columns" implement a subset of ExpressionMethods) or making a macro to generate
// all the necessary bounds we need.
pub fn paginated_multicolumn<T, C1, C2, M1, M2>(
table: T,
query: T,
(c1, c2): (C1, C2),
pagparams: &DataPageParams<'_, (M1, M2)>,
) -> BoxedQuery<T>
) -> <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output
where
// T is a table which can create a BoxedQuery.
T: diesel::Table,
T: query_methods::BoxedDsl<
'static,
Pg,
Output = diesel::internal::table_macro::BoxedSelectStatement<
'static,
TableSqlType<T>,
diesel::internal::table_macro::FromClause<T>,
Pg,
>,
>,
// T is a table^H^H^H^H^Hquery source which can create a BoxedQuery.
T: QuerySource,
T: AsQuery,
<T as QuerySource>::DefaultSelection:
Expression<SqlType = <T as AsQuery>::SqlType>,
T::Query: query_methods::BoxedDsl<'static, Pg>,
// Required for...everything.
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output: QueryDsl,
// C1 & C2 are columns which appear in T.
C1: 'static + Column + Copy + ExpressionMethods + AppearsOnTable<T>,
C2: 'static + Column + Copy + ExpressionMethods + AppearsOnTable<T>,
C1: 'static + Column + Copy + ExpressionMethods,
C2: 'static + Column + Copy + ExpressionMethods,
// Required to compare the columns with the marker types.
C1::SqlType: SqlType,
C2::SqlType: SqlType,
M1: Clone + AsExpression<C1::SqlType>,
M2: Clone + AsExpression<C2::SqlType>,
// Necessary for `query.limit(...)`
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::LimitDsl<
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,
// Necessary for "query.order(c1.desc())"
BoxedQuery<T>: query_methods::OrderDsl<Desc<C1>, Output = BoxedQuery<T>>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::OrderDsl<
Desc<C1>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,
// Necessary for "query.order(...).then_order_by(c2.desc())"
BoxedQuery<T>:
query_methods::ThenOrderDsl<Desc<C2>, Output = BoxedQuery<T>>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::ThenOrderDsl<
Desc<C2>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,
// Necessary for "query.order(c1.asc())"
BoxedQuery<T>: query_methods::OrderDsl<Asc<C1>, Output = BoxedQuery<T>>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::OrderDsl<
Asc<C1>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,
// Necessary for "query.order(...).then_order_by(c2.asc())"
BoxedQuery<T>: query_methods::ThenOrderDsl<Asc<C2>, Output = BoxedQuery<T>>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::ThenOrderDsl<
Asc<C2>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,

// We'd like to be able to call:
//
Expand All @@ -126,10 +143,11 @@ where
// The RHS (c2.gt(v2)) must be a boolean expression:
Gt<C2, M2>: Expression<SqlType = Bool>,
// Putting it together, we should be able to filter by LHS.and(RHS):
BoxedQuery<T>: query_methods::FilterDsl<
And<Eq<C1, M1>, Gt<C2, M2>>,
Output = BoxedQuery<T>,
>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::FilterDsl<
And<Eq<C1, M1>, Gt<C2, M2>>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,

// We'd also like to be able to call:
//
Expand All @@ -138,19 +156,30 @@ where
// We've already defined the bound on the LHS, so we add the equivalent
// bounds on the RHS for the "Less than" variant.
Lt<C2, M2>: Expression<SqlType = Bool>,
BoxedQuery<T>: query_methods::FilterDsl<
And<Eq<C1, M1>, Lt<C2, M2>>,
Output = BoxedQuery<T>,
>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::FilterDsl<
And<Eq<C1, M1>, Lt<C2, M2>>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,

// Necessary for "query.or_filter(c1.gt(v1))"
BoxedQuery<T>:
query_methods::OrFilterDsl<Gt<C1, M1>, Output = BoxedQuery<T>>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::OrFilterDsl<
Gt<C1, M1>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,
// Necessary for "query.or_filter(c1.lt(v1))"
BoxedQuery<T>:
query_methods::OrFilterDsl<Lt<C1, M1>, Output = BoxedQuery<T>>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::OrFilterDsl<
Lt<C1, M1>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,
{
let mut query = table.into_boxed().limit(pagparams.limit.get().into());
use query_methods::BoxedDsl;
let mut query = query
.as_query()
.internal_into_boxed()
.limit(pagparams.limit.get().into());
let marker = pagparams.marker.map(|m| m.clone());
match pagparams.direction {
dropshot::PaginationOrder::Ascending => {
Expand Down Expand Up @@ -315,6 +344,7 @@ mod test {

use crate::db;
use async_bb8_diesel::{AsyncRunQueryDsl, AsyncSimpleConnection};
use diesel::JoinOnDsl;
use diesel::SelectableHelper;
use dropshot::PaginationOrder;
use nexus_test_utils::db::test_setup_database;
Expand All @@ -333,9 +363,18 @@ mod test {
height -> Int8,
}
}

// Diesel schema for the `test_phone_numbers` test table. The primary key is
// the composite `(user_id, phone_number)`.
table! {
test_phone_numbers (user_id, phone_number) {
user_id -> Uuid,
phone_number -> Int8,
}
}

// Lets `test_users` and `test_phone_numbers` be referenced together in a
// single query, which the join-based pagination test relies on.
allow_tables_to_appear_in_same_query!(test_users, test_phone_numbers,);
}

use schema::test_users;
use schema::{test_phone_numbers, test_users};

#[derive(Clone, Debug, Queryable, Insertable, PartialEq, Selectable)]
#[diesel(table_name = test_users)]
Expand All @@ -345,13 +384,39 @@ mod test {
height: i64,
}

/// One row of the `test_phone_numbers` table.
#[derive(Clone, Debug, Queryable, Insertable, PartialEq, Selectable)]
#[diesel(table_name = test_phone_numbers)]
struct PhoneNumber {
// Joined against `test_users.id` by the join pagination test.
user_id: Uuid,
// Stored as a plain integer; used as the second pagination column in the
// join pagination test.
phone_number: i64,
}

/// A `User` paired with one `PhoneNumber`.
///
/// The join pagination test builds one of these from each `(User, PhoneNumber)`
/// row it loads, so results can be compared against compact tuple literals.
#[derive(Debug)]
struct UserAndPhoneNumber {
user: User,
phone_number: PhoneNumber,
}

/// Allows comparing a joined row against a `((age, height), phone_number)`
/// tuple, delegating to the tuple/integer comparisons defined for `User` and
/// `PhoneNumber`.
impl PartialEq<((i64, i64), i64)> for UserAndPhoneNumber {
    fn eq(&self, expected: &((i64, i64), i64)) -> bool {
        let (expected_user, expected_phone) = *expected;
        // The user half is checked first; the phone number is only compared
        // when the user already matches.
        self.user == expected_user && self.phone_number == expected_phone
    }
}

/// Allows comparing a `User` against an `(age, height)` tuple; any other
/// fields of `User` are ignored by this comparison.
impl PartialEq<(i64, i64)> for User {
    fn eq(&self, &(age, height): &(i64, i64)) -> bool {
        self.age == age && self.height == height
    }
}

/// Allows comparing a `PhoneNumber` against a bare integer; only the
/// `phone_number` field takes part, `user_id` is ignored.
impl PartialEq<i64> for PhoneNumber {
    fn eq(&self, expected: &i64) -> bool {
        self.phone_number == *expected
    }
}

async fn populate_users(pool: &db::Pool, values: &Vec<(i64, i64)>) {
use schema::test_phone_numbers::dsl as phone_numbers_dsl;
use schema::test_users::dsl;

let conn = pool.claim().await.unwrap();
Expand All @@ -365,8 +430,17 @@ mod test {
height INT NOT NULL
);

CREATE TABLE test_phone_numbers (
user_id UUID NOT NULL,
-- This is definitely the correct way to store a
-- phone number in the database. :)
phone_number INT NOT NULL,
PRIMARY KEY (user_id, phone_number)
);

CREATE INDEX ON test_users (age, height);
CREATE INDEX ON test_users (height, age);",
CREATE INDEX ON test_users (height, age);
CREATE INDEX ON test_phone_numbers (user_id);",
)
.await
.unwrap();
Expand All @@ -381,7 +455,22 @@ mod test {
.collect();

diesel::insert_into(dsl::test_users)
.values(users)
.values(users.clone())
.execute_async(&*conn)
.await
.unwrap();

let mut phone_numbers = Vec::new();
for (i, user) in users.iter().enumerate() {
for j in 0..3 {
phone_numbers.push(PhoneNumber {
user_id: user.id,
phone_number: (i as i64 + 1) * 10 + j,
});
}
}
diesel::insert_into(phone_numbers_dsl::test_phone_numbers)
.values(phone_numbers)
.execute_async(&*conn)
.await
.unwrap();
Expand Down Expand Up @@ -574,6 +663,107 @@ mod test {
logctx.cleanup_successful();
}

/// Checks that `paginated_multicolumn` accepts a join expression
/// (`test_users INNER JOIN test_phone_numbers`) rather than a bare table, and
/// that paging through it by `(age, phone_number)` yields every joined row
/// exactly once, in ascending order.
#[tokio::test]
async fn test_paginated_multicolumn_works_with_joins() {
use async_bb8_diesel::AsyncConnection;

let logctx =
dev::test_setup_log("test_paginated_multicolumn_works_with_joins");
let mut db = test_setup_database(&logctx.log).await;
let cfg = db::Config { url: db.pg_config().clone() };
let pool = db::Pool::new_single_host(&logctx.log, &cfg);

use schema::test_phone_numbers::dsl as phone_numbers_dsl;
use schema::test_users::dsl;

// Five users as (age, height) pairs; `populate_users` also inserts three
// phone numbers per user, numbered `(i + 1) * 10 + j` for `j` in `0..3`,
// giving 15 joined rows in total.
populate_users(&pool, &vec![(1, 1), (1, 2), (2, 1), (2, 3), (3, 1)])
.await;

// Loads one page of the joined query described by `pagparams` and wraps
// each `(User, PhoneNumber)` row in a `UserAndPhoneNumber`.
async fn get_page(
pool: &db::Pool,
pagparams: &DataPageParams<'_, (i64, i64)>,
) -> Vec<UserAndPhoneNumber> {
let conn = pool.claim().await.unwrap();
// NOTE(review): presumably the transaction exists so that the
// full-table-scan setting below applies to the paginated query issued on
// the same connection -- confirm against ALLOW_FULL_TABLE_SCAN_SQL.
conn.transaction_async(|conn| async move {
// I couldn't figure out how to make this work without requiring a full
// table scan, and I just want the test to work so that I can get on
// with my life...
conn.batch_execute_async(
crate::db::queries::ALLOW_FULL_TABLE_SCAN_SQL,
)
.await
.unwrap();

// The "table" argument is a join expression; the pagination columns come
// from different tables (age from users, phone_number from phone numbers).
paginated_multicolumn(
dsl::test_users.inner_join(
phone_numbers_dsl::test_phone_numbers
.on(phone_numbers_dsl::user_id.eq(dsl::id)),
),
(dsl::age, phone_numbers_dsl::phone_number),
&pagparams,
)
.select((User::as_select(), PhoneNumber::as_select()))
.load_async(&conn)
.await
})
.await
.unwrap()
.into_iter()
.map(|(user, phone_number)| UserAndPhoneNumber {
user,
phone_number,
})
.collect::<Vec<_>>()
}

// Get the first paginated result.
let mut pagparams = DataPageParams::<(i64, i64)> {
marker: None,
direction: PaginationOrder::Ascending,
limit: NonZeroU32::new(1).unwrap(),
};
let observed = get_page(&pool, &pagparams).await;
assert_eq!(dbg!(&observed), &[((1, 1), 10)]);

// Get the next paginated results, check that they arrived in the order
// we expected.
// The marker is the (age, phone_number) pair of the last row seen so far.
let marker =
(observed[0].user.age, observed[0].phone_number.phone_number);
pagparams.marker = Some(&marker);
pagparams.limit = NonZeroU32::new(10).unwrap();
let observed = get_page(&pool, &pagparams).await;
// 14 rows remain after the first page; the limit caps this page at 10.
assert_eq!(
dbg!(&observed),
&[
((1, 1), 11),
((1, 1), 12),
((1, 2), 20),
((1, 2), 21),
((1, 2), 22),
((2, 1), 30),
((2, 1), 31),
((2, 1), 32),
((2, 3), 40),
((2, 3), 41),
]
);

// Get the next paginated results, check that they arrived in the order
// we expected.
let marker =
(observed[9].user.age, observed[9].phone_number.phone_number);
pagparams.marker = Some(&marker);
// The limit is already 10 at this point; this reassignment has no effect.
pagparams.limit = NonZeroU32::new(10).unwrap();
let observed = get_page(&pool, &pagparams).await;
// Only the final 4 of the 15 joined rows are left, so this page is short.
assert_eq!(
dbg!(&observed),
&[((2, 3), 42), ((3, 1), 50), ((3, 1), 51), ((3, 1), 52)]
);

let _ = db.cleanup().await;
logctx.cleanup_successful();
}

#[test]
fn test_paginator() {
// The doctest exercises a basic case for Paginator. Here we test some
Expand Down
Loading