diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/db.rs | 151 | ||||
| -rw-r--r-- | src/lib.rs | 107 |
2 files changed, 258 insertions, 0 deletions
diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..6b4240b --- /dev/null +++ b/src/db.rs @@ -0,0 +1,151 @@ +#![allow(non_camel_case_types, non_upper_case_globals)] +#![allow(unused)] // remove this (?) + +use diesel::backend::Backend; +use diesel::expression::{ + IsContainedInGroupBy, ValidGrouping, is_aggregate, is_contained_in_group_by, +}; +use diesel::internal::table_macro::{FromClause, SelectStatement}; +use diesel::prelude::*; +use diesel::prelude::*; +use diesel::query_builder::{AsQuery, AstPass, Query, QueryFragment, QueryId}; +use diesel::query_source::{AppearsInFromClause, Once, Table as TableTrait}; +use diesel::sql_types::{Text, Uuid}; + +pub use self::columns::*; + +pub mod columns { + use super::*; + + macro_rules! col { + ($col:ident, $sql_type:ty) => { + pub struct $col; + + impl Expression for $col { + type SqlType = $sql_type; + } + + impl<'a, QS> AppearsOnTable<QS> for $col where + QS: AppearsInFromClause<super::Table<'a>, Count = Once> + { + } + + impl<'a> SelectableExpression<super::Table<'a>> for $col {} + + impl ValidGrouping<()> for $col { + type IsAggregate = is_aggregate::No; + } + + impl<GB> ValidGrouping<GB> for $col + where + GB: IsContainedInGroupBy<$col, Output = is_contained_in_group_by::Yes>, + { + type IsAggregate = is_aggregate::Yes; + } + + impl<'a> Column for $col { + type Table = Table<'static>; + + const NAME: &'static str = stringify!($col); + } + + impl<DB> QueryFragment<DB> for $col + where + DB: Backend, + { + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, DB>) -> QueryResult<()> { + pass.push_identifier(<$col as Column>::NAME)?; + Ok(()) + } + } + + impl QueryId for $col { + type QueryId = $col; + + const HAS_STATIC_QUERY_ID: bool = true; + } + }; + } + + col!(id, Uuid); + col!(username, Text); + col!(token, Text); +} + +//pub const all_columns: <Table<'static> as TableTrait>::AllColumns = (id, username, token); + +pub type SqlType = (Uuid, Text, Text); + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub struct Table<'a> { + name: &'a str, + schema: Option<&'a str>, +} + +impl<'a> Table<'a> { + pub fn new(name: &'a str, schema: Option<&'a str>) -> Self { + Self { name, schema } + } +} + +impl QuerySource for Table<'_> { + type FromClause = Self; + type DefaultSelection = <Self as TableTrait>::AllColumns; + + fn from_clause(&self) -> Self::FromClause { + self.clone() + } + + fn default_selection(&self) -> Self::DefaultSelection { + <Self as TableTrait>::all_columns() + } +} + +impl AsQuery for Table<'_> { + type SqlType = SqlType; + type Query = SelectStatement<FromClause<Self>>; + + fn as_query(self) -> Self::Query { + SelectStatement::simple(self) + } +} + +impl TableTrait for Table<'_> +where + Self: QuerySource + AsQuery, +{ + type PrimaryKey = id; + type AllColumns = (id, username, token); + + fn primary_key(&self) -> Self::PrimaryKey { + id + } + + fn all_columns() -> Self::AllColumns { + (id, username, token) + } +} + +impl<DB> QueryFragment<DB> for Table<'_> +where + DB: Backend, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, DB>) -> QueryResult<()> { + if let Some(ref schema) = self.schema { + pass.push_identifier(schema)?; + pass.push_sql("."); + } + pass.push_identifier(self.name)?; + Ok(()) + } +} + +impl QueryId for Table<'_> { + type QueryId = (); + + const HAS_STATIC_QUERY_ID: bool = false; +} + +impl AppearsInFromClause<Table<'_>> for Table<'_> { + type Count = Once; +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..9edc99b --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,107 @@ +use constant_time_eq::constant_time_eq; +use diesel::expression::AsExpression; +use diesel::pg::PgConnection; +use diesel::prelude::*; +use diesel::sql_types::Bool; +use pamsm::{Pam, PamError, PamFlags, PamLibExt, PamResult, PamServiceModule, pam_module}; + +use self::db::*; + +mod db; + +struct PamTokenPg; + +fn authenticate(pam: Pam, flags: PamFlags, args: &[String]) -> PamResult<()> { + let mut dsn = None; + let mut table = None; + let mut schema = None; + + macro_rules! parse_options { + ($args:expr, $($opt:ident),+$(,)?) => { + for arg in $args { + if false { unreachable!() } + $(else if arg.starts_with(concat!(stringify!($opt), "=")) { + let value = &arg[(stringify!($opt).len() + 1)..]; + if $opt.is_none() { + $opt = Some(value); + } else { + pam.syslog(pamsm::LogLvl::ERR, concat!( + "improperly configured: option '", + stringify!($opt), + "' specified multiple times", + ))?; + return Err(PamError::SYSTEM_ERR); + } + })+ + else { + pam.syslog( + pamsm::LogLvl::ERR, + &format!("improperly configured: invalid option '{}'", arg), + )?; + return Err(PamError::SYSTEM_ERR); + } + } + } + } + parse_options!(args, dsn, table, schema); + let Some(dsn) = dsn else { + pam.syslog( + pamsm::LogLvl::ERR, + "improperly configured: missing required option 'dsn'", + )?; + return Err(PamError::SYSTEM_ERR); + }; + let Some(table) = table else { + pam.syslog( + pamsm::LogLvl::ERR, + "improperly configured: missing required option 'table'", + )?; + return Err(PamError::SYSTEM_ERR); + }; + let table = db::Table::new(table, schema); + + let pam_user = pam + .get_user(None) + .and_then(|user| user.ok_or(PamError::AUTH_ERR))? + .to_str() + .map_err(|_| PamError::SERVICE_ERR)?; + let pam_authtok = pam + .get_authtok(None) + .and_then(|authtok| authtok.ok_or(PamError::AUTH_ERR))? + .to_bytes(); + + let mut conn = PgConnection::establish(dsn).map_err(|_| PamError::AUTHINFO_UNAVAIL)?; + + let user_tokens = table + .select(token) + .filter(username.eq(pam_user)) + .filter(token.ne("").or(AsExpression::<Bool>::as_expression( + !flags.contains(PamFlags::DISALLOW_NULL_AUTHTOK), + ))) + .load::<String>(&mut conn) + .map_err(|_| PamError::AUTHINFO_UNAVAIL)?; + + for user_token in user_tokens { + let user_token = user_token.as_bytes(); + if constant_time_eq(user_token, pam_authtok) { + return Ok(()); + } + } + + Err(PamError::AUTH_ERR) +} + +impl PamServiceModule for PamTokenPg { + fn authenticate(pam: Pam, flags: PamFlags, args: Vec<String>) -> PamError { + authenticate(pam, flags, &args) + .err() + .unwrap_or(PamError::SUCCESS) + } +} + +pam_module!(PamTokenPg); + +#[cfg(test)] +mod tests { + use super::*; +} |
