Hi everyone. I’m trying to write a generic “cursor-based pagination” helper for Diesel. I’m kind of new to Rust/Diesel, so I’m not even sure whether what I’m trying to achieve is possible.
Basically, I want to extract a “cursor” for each row, consisting of the values of the columns used in the `ORDER BY` clause. For example, in `SELECT * FROM users ORDER BY city, id`, the cursor of a row would be its `city` and `id` columns.
I pasted the code I have so far (a lot of which was taken from the crates.io pagination helper). I’m kind of stuck at the `TODO`, and apparently `SelectStatement` is not part of Diesel’s public API, so I shouldn’t be trying to find out what’s in the `ORDER BY` clause. So I guess that approach won’t work.
Any tips for moving forward? I guess I could make the `.paginate()` function take a tuple with the same columns that were passed to `order_by()`?
use diesel::pg::Pg;
use diesel::prelude::*;
use diesel::query_builder::*;
use diesel::query_dsl::LoadQuery;
use diesel::sql_types::{BigInt, Text};
use indexmap::IndexMap;
/// Result type used throughout this module.
///
/// Every error is boxed as a trait object, so parse failures, validation
/// messages and Diesel errors can all be propagated with `?`.
pub type AppResult<T> = std::result::Result<T, Box<dyn std::error::Error + 'static>>;
/// A row returned by a paginated query.
///
/// Holds the caller's record plus the extra count column that
/// `PaginatedQuery`'s SQL appends with `COUNT(*) OVER ()`.
#[derive(QueryableByName, Queryable, Debug)]
pub struct WithCount<T> {
    // Field order matters for `Queryable`, which maps result columns by
    // position: the record's columns come first, then the trailing count
    // (matching `SqlType = (T::SqlType, BigInt)` on `PaginatedQuery`).
    #[diesel(embed)]
    pub record: T,
    // Value of the `COUNT(*) OVER ()` column; identical on every row of a result.
    #[sql_type = "::diesel::sql_types::BigInt"]
    pub total: i64,
}
pub trait WithCountExtension<T> {
fn records_and_total(self) -> (Vec<T>, i64);
}
impl<T> WithCountExtension<T> for Vec<WithCount<T>> {
fn records_and_total(self) -> (Vec<T>, i64) {
let cnt = self.get(0).map(|row| row.total).unwrap_or(0);
let vec = self.into_iter().map(|row| row.record).collect();
(vec, cnt)
}
}
/// Page-size and cursor settings parsed from request query parameters.
///
/// The `after` cursor is borrowed from the parameter map the options were
/// built from, hence the `'a` lifetime.
#[derive(Debug, Clone, Copy)]
pub struct PaginationOptions<'a> {
    // Raw `after` cursor value, if the caller supplied one.
    after: Option<&'a String>,
    /// Maximum number of rows to return for this page.
    pub per_page: u32,
}
impl<'a> PaginationOptions<'a> {
    /// Builds pagination options from request query parameters.
    ///
    /// Recognised keys:
    /// * `per_page`: page size. Defaults to 10 and must be between 1 and 100.
    /// * `after`: cursor value, borrowed from `params` and used by the query
    ///   builder to filter rows.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// * `per_page` is not a valid `u32`.
    /// * `per_page` is zero.
    /// * `per_page` exceeds 100.
    pub fn new(params: &'a IndexMap<String, String>) -> AppResult<Self> {
        const DEFAULT_PER_PAGE: u32 = 10;
        const MAX_PER_PAGE: u32 = 100;

        let per_page = match params.get("per_page") {
            // Keep the underlying parse error so callers can see why the value was rejected.
            Some(raw) => raw
                .parse::<u32>()
                .map_err(|e| format!("per page error: {}", e))?,
            None => DEFAULT_PER_PAGE,
        };
        if per_page == 0 {
            // A zero-sized page never returns a row to build the next cursor from.
            return Err("per_page must be at least 1".into());
        }
        if per_page > MAX_PER_PAGE {
            return Err(format!("cannot request more than {} items", MAX_PER_PAGE).into());
        }
        Ok(Self {
            after: params.get("after"),
            per_page,
        })
    }

    /// Returns the raw `after` cursor, if one was supplied.
    pub fn after(&self) -> Option<&String> {
        self.after
    }
}
/// Adds `.paginate(params)` to every type, wrapping the value in a [`PaginatedQuery`].
///
/// The blanket impl makes the method available on anything. `PaginatedQuery`
/// only implements `Query` and `QueryFragment<Pg>` when the wrapped value
/// does, so in practice it is used on Diesel queries.
pub trait Paginate: Sized {
    /// Wraps `self` with pagination options parsed from `params`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `PaginationOptions::new`, such as an invalid `per_page`.
    fn paginate(self, params: &IndexMap<String, String>) -> AppResult<PaginatedQuery<Self>> {
        let options = PaginationOptions::new(params)?;
        Ok(PaginatedQuery {
            query: self,
            options,
        })
    }
}

impl<T> Paginate for T {}
/// One page of results produced by [`PaginatedQuery::load`].
#[derive(Debug)]
pub struct Paginated<'a, T> {
    // Loaded rows, each paired with the count column.
    records_and_total: Vec<WithCount<T>>,
    // The options the page was loaded with.
    // NOTE(review): not read anywhere in this file yet; presumably kept for
    // building next-page cursors — confirm.
    options: PaginationOptions<'a>,
}
impl<'a, T> Paginated<'a, T> {
pub fn total(&self) -> Option<i64> {
Some(
self.records_and_total
.get(0)
.map(|row| row.total)
.unwrap_or_default(),
)
}
pub fn iter(&self) -> impl Iterator<Item = &T> {
self.records_and_total.iter().map(|row| &row.record)
}
}
impl<'a, T: 'static> IntoIterator for Paginated<'a, T> {
type IntoIter = Box<dyn Iterator<Item = Self::Item>>;
type Item = T;
fn into_iter(self) -> Self::IntoIter {
Box::new(self.records_and_total.into_iter().map(|row| row.record))
}
}
/// A query wrapped with pagination options.
///
/// Build it with [`Paginate::paginate`] and execute it with [`PaginatedQuery::load`].
#[derive(Debug)]
pub struct PaginatedQuery<'a, T> {
    // The wrapped query; emitted as a subquery by `walk_ast`.
    query: T,
    // Page size and cursor for this query.
    options: PaginationOptions<'a>,
}
impl<'a, T> PaginatedQuery<'a, T> {
    /// Executes the paginated query and returns one page of records along
    /// with the count column.
    ///
    /// `U` is the record type each row is loaded into; each row is wrapped in
    /// [`WithCount`]. The connection is borrowed only for the duration of the
    /// call. The returned page borrows the request parameters (`'a`), not `conn`.
    ///
    /// # Errors
    ///
    /// Returns any error reported by Diesel while running the query.
    pub fn load<U>(self, conn: &PgConnection) -> QueryResult<Paginated<'a, U>>
    where
        Self: LoadQuery<PgConnection, WithCount<U>>,
    {
        // Copy the options out first: `internal_load` consumes `self`.
        let options = self.options;
        // TODO: build a cursor for each row. Diesel does not expose the
        // `ORDER BY` clause of a `SelectStatement` publicly. The cursor columns
        // will likely have to be passed in by the caller (e.g. the same tuple
        // given to `order_by()`) rather than extracted here.
        let records_and_total = self.internal_load(conn)?;
        Ok(Paginated {
            records_and_total,
            options,
        })
    }
}
/// Reports no static query id.
///
/// The generated SQL varies at runtime: the WHERE clause is emitted only when
/// an `after` cursor is present. So the query cannot be identified by its type alone.
impl<'a, T> QueryId for PaginatedQuery<'a, T> {
    type QueryId = ();
    const HAS_STATIC_QUERY_ID: bool = false;
}

/// Each row yields the wrapped query's columns followed by a `BigInt` count.
impl<'a, T> Query for PaginatedQuery<'a, T>
where
    T: Query,
{
    type SqlType = (T::SqlType, BigInt);
}

/// Marker impl that gives the wrapper Diesel's `RunQueryDsl` execution methods.
impl<'a, T, DB> RunQueryDsl<DB> for PaginatedQuery<'a, T> {}
/// Renders the paginated SQL.
///
/// Shape of the generated statement:
///
/// ```sql
/// SELECT * FROM (SELECT *, COUNT(*) OVER () FROM (<query>) t ) s
///   [WHERE username < $after] LIMIT $per_page
/// ```
impl<'a, T> QueryFragment<Pg> for PaginatedQuery<'a, T>
where
    T: QueryFragment<Pg>,
{
    fn walk_ast(&self, mut out: AstPass<'_, Pg>) -> QueryResult<()> {
        // The window count is computed in the inner subquery, before the cursor
        // filter and LIMIT. It therefore counts every row of the wrapped query,
        // not just this page.
        out.push_sql("SELECT * FROM (");
        out.push_sql("SELECT *, COUNT(*) OVER () FROM (");
        self.query.walk_ast(out.reborrow())?;
        out.push_sql(") t ");
        out.push_sql(") s ");
        if let Some(after) = self.options.after() {
            // FIXME: placeholder cursor predicate. It hard-codes `username` and
            // `<`, so it is only correct for a query ordered by `username DESC`.
            // A generic version needs the ORDER BY columns and their directions.
            out.push_sql(" WHERE username < ");
            out.push_bind_param::<Text, _>(after)?;
        }
        // NOTE(review): the outer query does not repeat the ORDER BY, so the row
        // order relies on PostgreSQL preserving the subquery's order, which SQL
        // does not guarantee. Confirm, or re-apply the ordering here.
        out.push_sql(" LIMIT ");
        // TODO: fetch per_page + 1 rows to look ahead one row. If the extra row
        // is not returned, there is no next page.
        out.push_bind_param::<BigInt, _>(&i64::from(self.options.per_page))?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use diesel::prelude::Connection;
    use indexmap::IndexMap;
    use std::env;

    /// Opens a connection to the database named by `DATABASE_URL`.
    ///
    /// # Panics
    ///
    /// Panics if `DATABASE_URL` is unset or the connection cannot be established.
    pub fn establish_connection() -> PgConnection {
        let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
        PgConnection::establish(&database_url)
            .unwrap_or_else(|e| panic!("Error connecting to {}: {}", database_url, e))
    }

    #[derive(Queryable, Debug, Clone, PartialEq)]
    pub struct User {
        pub username: String,
        pub city: String,
        pub is_suspended: bool,
    }

    // Important: for cursor-based pagination to work,
    // 1) the query must be ordered by at least one unique column (the last
    //    column in the ORDER BY clause), and
    // 2) the sorted columns must be returned by the SELECT so that it's
    //    possible to form cursors.
    //
    // Ignored by default because it needs a live database (`DATABASE_URL`).
    #[test]
    #[ignore]
    fn cursor_paginate_no_after() -> AppResult<()> {
        use crate::schema::users::dsl::*;

        let conn = establish_connection();
        let query = users
            .select((username, city, is_suspended))
            .order_by((city.desc(), username.asc()))
            .filter(is_suspended.eq(false));
        let mut params = IndexMap::new();
        params.insert(String::from("per_page"), String::from("3"));

        // Propagate load errors instead of only printing them.
        let page = query.paginate(&params)?.load::<User>(&conn)?;

        // Without an `after` cursor, only LIMIT bounds the page.
        let shown = page.iter().count();
        assert!(shown <= 3);
        // The count column covers the whole result, so it is never smaller than the page.
        assert!(page.total().unwrap_or(0) >= i64::try_from(shown)?);
        assert!(page.iter().all(|user| !user.is_suspended));
        Ok(())
    }
}