Cursor-based pagination with Diesel

Hi everyone. I’m trying to write a generic “cursor-based pagination” helper for Diesel. I’m kind of new to Rust/Diesel, so I’m not even sure whether what I’m trying to achieve is possible.

Basically, I want to extract a “cursor” for each row, consisting of the values of the columns used in the ORDER BY clause. For example, in SELECT * FROM users ORDER BY city, id, the cursor of a row would be its (city, id) values.
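To make the idea concrete, here is a std-only sketch (no Diesel involved) of what paging by such a cursor means. It assumes a hypothetical User row with city and id fields, ordered like the test further down (city DESC, then a unique column ASC). Each page's cursor is the (city, id) of its last row, and the next page is every row that sorts strictly after that cursor:

```rust
use std::cmp::Ordering;

/// Hypothetical row of the `users` table from the example query.
#[derive(Debug, Clone, PartialEq)]
struct User {
    id: i64,
    city: String,
}

/// The cursor is just the values of the ORDER BY columns: (city, id).
type Cursor = (String, i64);

fn cursor_of(u: &User) -> Cursor {
    (u.city.clone(), u.id)
}

/// Compares two cursors the way `ORDER BY city DESC, id ASC` would.
fn sort_order(a: &Cursor, b: &Cursor) -> Ordering {
    b.0.cmp(&a.0).then(a.1.cmp(&b.1))
}

/// Returns at most `per_page` rows that sort strictly after `after`.
fn page(rows: &[User], after: Option<&Cursor>, per_page: usize) -> Vec<User> {
    let mut sorted = rows.to_vec();
    sorted.sort_by(|a, b| sort_order(&cursor_of(a), &cursor_of(b)));
    sorted
        .into_iter()
        .filter(|u| after.map_or(true, |c| sort_order(&cursor_of(u), c) == Ordering::Greater))
        .take(per_page)
        .collect()
}

fn main() {
    let rows = vec![
        User { id: 1, city: "Berlin".into() },
        User { id: 2, city: "Paris".into() },
        User { id: 3, city: "Paris".into() },
        User { id: 4, city: "Athens".into() },
    ];
    let first = page(&rows, None, 2);
    // The cursor handed to the client is taken from the last row of the page.
    let next = cursor_of(first.last().unwrap());
    let second = page(&rows, Some(&next), 2);
    println!("{:?}", first.iter().map(|u| u.id).collect::<Vec<_>>());
    println!("{:?}", second.iter().map(|u| u.id).collect::<Vec<_>>());
}
```

Because the last ORDER BY column (id) is unique, no two rows share a cursor, so no row is skipped or repeated between pages.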

I pasted the code I have so far below (a lot of it was taken from the crates.io pagination helper). I’m stuck at the TODO: apparently SelectStatement is not part of Diesel’s public API, and I shouldn’t be trying to find out what’s in the ORDER BY clause. So I guess that approach won’t work.

Any tips on how to move forward? I guess I could make the .paginate() function take a tuple of the same columns that were passed to order_by()?
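If paginate() does learn which columns are ordered, the after parameter (currently a single String) has to carry all of their values. One possible encoding, purely as an illustration (encode_cursor and decode_cursor are hypothetical helpers, not part of Diesel or the crates.io code), hex-encodes each value and joins them with '.', so the separator can never appear inside a value:

```rust
/// Hypothetical helper: turn the ORDER BY values of the last row into an opaque token.
/// Each field is hex-encoded, so the '.' separator can never occur inside a field.
fn encode_cursor(fields: &[&str]) -> String {
    fields
        .iter()
        .map(|f| f.bytes().map(|b| format!("{:02x}", b)).collect::<String>())
        .collect::<Vec<_>>()
        .join(".")
}

/// Inverse of `encode_cursor`; returns None for malformed tokens.
fn decode_cursor(token: &str) -> Option<Vec<String>> {
    token
        .split('.')
        .map(|part| {
            // Reject odd lengths and anything that is not a plain hex digit
            // (this also keeps the byte slicing below on ASCII boundaries).
            if part.len() % 2 != 0 || !part.bytes().all(|b| b.is_ascii_hexdigit()) {
                return None;
            }
            let bytes = (0..part.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&part[i..i + 2], 16).ok())
                .collect::<Option<Vec<u8>>>()?;
            String::from_utf8(bytes).ok()
        })
        .collect()
}

fn main() {
    let token = encode_cursor(&["Paris", "alice"]);
    println!("{}", token);
    println!("{:?}", decode_cursor(&token));
}
```

The token is only opaque, not tamper-proof; that is fine as long as the decoded values are always passed as bind parameters rather than spliced into the SQL.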

use diesel::pg::Pg;
use diesel::prelude::*;
use diesel::query_builder::*;
use diesel::query_dsl::LoadQuery;
use diesel::sql_types::{BigInt, Text};
use indexmap::IndexMap;
pub type AppResult<T> = Result<T, Box<dyn std::error::Error>>;

#[derive(QueryableByName, Queryable, Debug)]
pub struct WithCount<T> {
    #[diesel(embed)]
    pub record: T,
    #[sql_type = "::diesel::sql_types::BigInt"]
    pub total: i64,
}

pub trait WithCountExtension<T> {
    fn records_and_total(self) -> (Vec<T>, i64);
}

impl<T> WithCountExtension<T> for Vec<WithCount<T>> {
    fn records_and_total(self) -> (Vec<T>, i64) {
        let cnt = self.get(0).map(|row| row.total).unwrap_or(0);
        let vec = self.into_iter().map(|row| row.record).collect();
        (vec, cnt)
    }
}

#[derive(Debug, Clone, Copy)]
pub struct PaginationOptions<'a> {
    after: Option<&'a String>,
    pub per_page: u32,
}

impl<'a> PaginationOptions<'a> {
    pub fn new(params: &'a IndexMap<String, String>) -> AppResult<Self> {
        const DEFAULT_PER_PAGE: u32 = 10;
        const MAX_PER_PAGE: u32 = 100;

        let per_page = params
            .get("per_page")
            .map(|s| s.parse().map_err(|_| "per page error"))
            .unwrap_or(Ok(DEFAULT_PER_PAGE))?;

        if per_page > MAX_PER_PAGE {
            return Err(format!("cannot request more than {} items", MAX_PER_PAGE,).into());
        }

        Ok(Self {
            after: params.get("after"),
            per_page,
        })
    }

    pub fn after(&self) -> Option<&String> {
        self.after
    }
}

pub trait Paginate: Sized {
    fn paginate<'a>(self, params: &'a IndexMap<String, String>) -> AppResult<PaginatedQuery<'a, Self>> {
        Ok(PaginatedQuery {
            query: self,
            options: PaginationOptions::new(params)?,
        })
    }
}

impl<T> Paginate for T {}

#[derive(Debug)]
pub struct Paginated<'a, T> {
    records_and_total: Vec<WithCount<T>>,
    options: PaginationOptions<'a>,
}

impl<'a, T> Paginated<'a, T> {
    pub fn total(&self) -> Option<i64> {
        Some(
            self.records_and_total
                .get(0)
                .map(|row| row.total)
                .unwrap_or_default(),
        )
    }
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        self.records_and_total.iter().map(|row| &row.record)
    }
}

impl<'a, T: 'static> IntoIterator for Paginated<'a, T> {
    type IntoIter = Box<dyn Iterator<Item = Self::Item>>;
    type Item = T;

    fn into_iter(self) -> Self::IntoIter {
        Box::new(self.records_and_total.into_iter().map(|row| row.record))
    }
}

#[derive(Debug)]
pub struct PaginatedQuery<'a, T> {
    query: T,
    options: PaginationOptions<'a>,
}

impl<'a, T> PaginatedQuery<'a, T> {
    pub fn load<U>(self, conn: &PgConnection) -> QueryResult<Paginated<'a, U>>
    where
        Self: LoadQuery<PgConnection, WithCount<U>>,
        T: std::fmt::Debug,
    {
        let options = self.options;
        // dbg!(&self.query.order); // no field `order` on type T
        // TODO: extract the ORDER BY clause
        let records_and_total = self.internal_load(conn)?;
        Ok(Paginated {
            records_and_total,
            options,
        })
    }
}

impl<'a, T> QueryId for PaginatedQuery<'a, T> {
    const HAS_STATIC_QUERY_ID: bool = false;
    type QueryId = ();
}

impl<'a, T: Query> Query for PaginatedQuery<'a, T> {
    type SqlType = (T::SqlType, BigInt);
}

impl<'a, T, DB> RunQueryDsl<DB> for PaginatedQuery<'a, T> {}

impl<'a, T> QueryFragment<Pg> for PaginatedQuery<'a, T>
where
    T: QueryFragment<Pg>,
{
    fn walk_ast(&self, mut out: AstPass<'_, Pg>) -> QueryResult<()> {
        out.push_sql("SELECT * FROM (");
        out.push_sql("SELECT *, COUNT(*) OVER () FROM (");
        self.query.walk_ast(out.reborrow())?;
        out.push_sql(") t ");
        out.push_sql(") s ");
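        // Placeholder: the cursor condition is hardcoded to `username` for now. A real
        // implementation must compare against every ORDER BY column, honoring each
        // column's ASC/DESC direction.
        if let Some(after) = self.options.after() {
            out.push_sql(" WHERE username < ");
            out.push_bind_param::<Text, _>(after)?;
        }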
        out.push_sql(" LIMIT ");
        // TODO: bind per_page + 1 to look ahead one row; if that extra row is not returned,
        // has_next_page = false.
        out.push_bind_param::<BigInt, _>(&i64::from(self.options.per_page))?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {

    use indexmap::IndexMap;

    use super::*;
    use diesel::prelude::Connection;
    use std::env;

    pub fn establish_connection() -> PgConnection {
        let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
        PgConnection::establish(&database_url)
            .unwrap_or_else(|_| panic!("Error connecting to {}", database_url))
    }

    #[derive(Queryable, Debug, Clone, PartialEq)]
    pub struct User {
        pub username: String,
        pub city: String,
        pub is_suspended: bool,
    }

    // Important:
    // For cursor-based pagination to work:
    // 1) The query must be ordered by at least one unique column (the last column in the ORDER BY
    //    clause).
    // 2) The sorted columns must be returned by the SELECT so that it's possible to form cursors.
    #[test]
    #[ignore] // needs a Postgres database reachable via DATABASE_URL
    fn cursor_paginate_no_after() -> AppResult<()> {
        use crate::schema::users::dsl::*;
        let conn = establish_connection();
        let query = users
            .select((username, city, is_suspended))
            .order_by((city.desc(), username.asc()))
            .filter(is_suspended.eq(false));
        let mut params = IndexMap::new();
        params.insert(String::from("per_page"), String::from("3"));

        let res = query.paginate(&params)?.load::<User>(&conn)?;
        dbg!(&res);

        Ok(())
    }
}
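The TODO in walk_ast (fetch per_page + 1 rows to detect whether there is a next page) only needs a little post-processing after load(). A std-only sketch, where Page and split_lookahead are hypothetical names:

```rust
/// Result of a page query that fetched `per_page + 1` rows.
#[derive(Debug, PartialEq)]
struct Page<T> {
    records: Vec<T>,
    has_next_page: bool,
}

/// `rows` is what the database returned for `LIMIT per_page + 1`.
/// If the extra row came back, there is at least one more page; it is dropped from the result.
fn split_lookahead<T>(mut rows: Vec<T>, per_page: usize) -> Page<T> {
    let has_next_page = rows.len() > per_page;
    rows.truncate(per_page);
    Page { records: rows, has_next_page }
}

fn main() {
    let per_page = 3;
    // Corresponds to binding i64::from(per_page) + 1 as the LIMIT.
    let limit = per_page + 1;
    // Simulate a table with 10 matching rows: the database returns `limit` of them.
    let fetched: Vec<u32> = (1..=10).take(limit).collect();
    let page = split_lookahead(fetched, per_page);
    println!("{:?}", page);
}
```

In the Diesel code this would mean binding i64::from(self.options.per_page) + 1 as the LIMIT and applying the same split to records_and_total.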

Just to have some answer here for now:

You cannot access the fields of a concrete type (SelectStatement) in a generic context. All rustc knows at the location of your TODO is that there is some type T that implements Debug, and that does not allow you to access any details of a concrete type there. This is a general Rust restriction, not something Diesel-specific.
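A small std-only illustration of that restriction (Select and OrderedQuery are made-up stand-ins, not Diesel types): a generic function can only use what its trait bounds promise, so reaching the ORDER BY would require a trait that explicitly exposes it.

```rust
use std::fmt::Debug;

/// Stand-in for a concrete query type (hypothetical, not Diesel's SelectStatement).
#[derive(Debug)]
struct Select {
    order: Vec<&'static str>,
}

// Does not compile: with only `T: Debug`, rustc knows nothing about a field named `order`.
// fn order_of_debug<T: Debug>(q: &T) -> &[&'static str] {
//     &q.order
// }

/// What a generic function *can* rely on: methods promised by its trait bounds.
trait OrderedQuery: Debug {
    fn order_columns(&self) -> &[&'static str];
}

impl OrderedQuery for Select {
    fn order_columns(&self) -> &[&'static str] {
        &self.order
    }
}

fn order_of<T: OrderedQuery>(q: &T) -> &[&'static str] {
    q.order_columns()
}

fn main() {
    let q = Select { order: vec!["city", "id"] };
    println!("{:?}", order_of(&q));
}
```

Since SelectStatement is not in Diesel's public API, no such trait is available for it, so the order information has to reach the helper some other way.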

(I know that does not address the underlying issue, but at least it gives you a pointer to what’s wrong with the current implementation. I will try to write a longer answer in the next few days/weeks.)

OK, I had some time to think about this. In my opinion, the best solution would be something like the following (only showing the relevant part of the code):

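use diesel::dsl;
use diesel::query_dsl::methods::OrderDsl;

pub struct PaginatedQuery<'a, T, O> {
    query: T,
    // A copy of the ORDER BY expression, kept so the cursor condition can be built from it.
    order_clause: O,
    options: PaginationOptions<'a>,
}

pub trait Paginate: Sized {
    fn paginate<'a, O>(
        self,
        cursor: O,
        params: &'a IndexMap<String, String>,
    ) -> AppResult<PaginatedQuery<'a, dsl::Order<Self, O>, O>>
    where
        Self: OrderDsl<O>,
        O: Clone,
    {
        Ok(PaginatedQuery {
            // Apply the ORDER BY here, so the helper knows exactly which expression was used.
            query: self.order(cursor.clone()),
            order_clause: cursor,
            options: PaginationOptions::new(params)?,
        })
    }
}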

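So instead of calling order() explicitly on your query before constructing the PaginatedQuery helper, you do it implicitly while constructing the helper. That way the helper keeps its own copy of the ORDER BY expression (order_clause) and can use it to build the cursor condition, without having to look inside SelectStatement.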
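With order_clause available, the helper still has to turn it into the WHERE condition for the after cursor. The following std-only sketch shows the shape of that condition; Dir and keyset_predicate are hypothetical, and a real version would walk Diesel's order expressions instead of (name, direction) pairs:

```rust
/// Sort direction of one ORDER BY column (hypothetical stand-in for `.asc()` / `.desc()`).
#[derive(Clone, Copy, Debug)]
enum Dir {
    Asc,
    Desc,
}

/// Builds the "strictly after the cursor" condition for keyset pagination.
/// For ORDER BY c1 d1, c2 d2, ..., cn dn this is
///   (c1 op1 $1) OR (c1 = $1 AND c2 op2 $2) OR ...
/// where op is `>` for ASC and `<` for DESC. Placeholders are numbered per column.
fn keyset_predicate(order: &[(&str, Dir)]) -> String {
    let mut disjuncts = Vec::new();
    for (i, (col, dir)) in order.iter().enumerate() {
        // All earlier columns must be equal to their cursor values...
        let mut terms: Vec<String> = order[..i]
            .iter()
            .enumerate()
            .map(|(j, (prev, _))| format!("{} = ${}", prev, j + 1))
            .collect();
        // ...and this column must be strictly past its cursor value in sort order.
        let op = match dir {
            Dir::Asc => ">",
            Dir::Desc => "<",
        };
        terms.push(format!("{} {} ${}", col, op, i + 1));
        disjuncts.push(format!("({})", terms.join(" AND ")));
    }
    disjuncts.join(" OR ")
}

fn main() {
    let order = [("city", Dir::Desc), ("username", Dir::Asc)];
    println!("{}", keyset_predicate(&order));
}
```

When every column is ASC, Postgres can express this as a single row-value comparison such as (city, id) > ($1, $2); with mixed directions, as in the test above, the expanded OR form is needed. Inside walk_ast, each occurrence of a cursor value would be pushed with its own push_bind_param call rather than reusing $1.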