Initial commit

This commit is contained in:
zawz 2024-07-24 13:22:50 +02:00
commit 493517487a
44 changed files with 11131 additions and 0 deletions

6
.gitignore vendored Normal file
View file

@ -0,0 +1,6 @@
/target
.env
token.json
token.txt
*.sql.gz
.surrealdb

45
.vscode/launch.json vendored Normal file
View file

@ -0,0 +1,45 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Debug executable 'web-actix-surreal'",
"cargo": {
"args": [
"build",
"--bin=web-actix-surreal",
"--package=web-actix-surreal"
],
"filter": {
"name": "web-actix-surreal",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in executable 'web-actix-surreal'",
"cargo": {
"args": [
"test",
"--no-run",
"--bin=web-actix-surreal",
"--package=web-actix-surreal"
],
"filter": {
"name": "web-actix-surreal",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
}
]
}

6
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,6 @@
{
"rust-analyzer.checkOnSave.overrideCommand": [
"cargo", "check", "--message-format=json", "--target-dir", "target/lsp" ],
"rust-analyzer.cargo.buildScripts.overrideCommand": [
"cargo", "check", "--message-format=json", "--target-dir", "target/lsp" ]
}

4667
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

33
Cargo.toml Normal file
View file

@ -0,0 +1,33 @@
# workspace = { members = ["entity"] }
[package]
name = "web-actix-surreal"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
framework = { path = "framework", features = ["json", "moka", "actix"] }
framework_macros = { path = "framework/macros" }
actix-web = "4"
actix-web-httpauth = "0.8.1"
chrono = { version = "0.4.10", features = ["serde"] }
derive_more = "0.99.2"
dotenvy = "0.15.7"
futures = "0.3.1"
serde = "1.0"
serde_derive = "1.0"
serde_json = "1.0"
reqwest = { version = "^0.11", features = ["json"] }
json = "0.12.4"
thiserror = "1.0.56"
jwks-client-update = "0.2.1"
once_cell = "1.19.0"
listenfd = "1.0.1"
env_logger = "0.11.1"
log = "0.4.20"
surrealdb = "1.5.3"
const_format = "0.2.32"

48
docker-compose.yml Normal file
View file

@ -0,0 +1,48 @@
version: '3.8'
services:
keycloak:
image: quay.io/keycloak/keycloak:23.0.4
restart: unless-stopped
command:
- start-dev
- --db=postgres
- --db-url=jdbc:postgresql://postgres:5432/keycloak
- --db-username=postgres
- --db-password=postgres
ports:
- 8080:8080
volumes:
- /etc/localtime:/etc/localtime:ro
environment:
- KC_DB=postgres
- KEYCLOAK_HOSTNAME=localhost
- KEYCLOAK_ADMIN=admin
- KEYCLOAK_ADMIN_PASSWORD=admin
postgres:
image: postgres:15-alpine
restart: unless-stopped
ports:
- 5432:5432
volumes:
- database:/var/lib/postgresql/data:rw
environment:
POSTGRES_INITDB_ARGS: --encoding=UTF-8 --lc-collate=C --lc-ctype=C
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: keycloak
surreal:
image: surrealdb/surrealdb:latest
restart: unless-stopped
command: start --log trace --auth --user root --pass root file:/data/database.db
user: root
ports:
- 8888:8000
volumes:
- surrealdb:/data
volumes:
database:
surrealdb:

4268
framework/Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

27
framework/Cargo.toml Normal file
View file

@ -0,0 +1,27 @@
[package]
name = "framework"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
full = ["postcard", "json", "moka"]
postcard = ["dep:postcard"]
json = ["dep:serde_json"]
moka = ["dep:moka"]
actix = ["dep:actix-web-httpauth"]
[dependencies]
actix-web-httpauth = { version = "^0.8", optional = true }
log = "^0.4"
futures = "^0.3"
thiserror = "^1.0"
serde = "^1.0"
serde_derive = "^1.0"
moka = { version = "^0.12", features = ["future"], optional = true }
postcard = { version = "^1.0", features = ["alloc"], optional = true }
serde_json = { version = "^1.0", optional = true }
surrealdb = "^1.5"
chrono = { version = "^0.4", features = ["serde"] }
jwks-client-update = "^0.2"

View file

@ -0,0 +1,14 @@
[package]
name = "framework_macros"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
attribute-derive = "0.9.1"
quote = "1.0.36"
syn = { version = "^2.0", features = ["full"] }
[lib]
proc-macro = true

340
framework/macros/src/lib.rs Normal file
View file

@ -0,0 +1,340 @@
extern crate proc_macro;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident, PathArguments};
use attribute_derive::FromAttr;
/// Parameters accepted by the `#[filter(...)]` field attribute.
#[derive(Debug, FromAttr)]
// The derive registers the attribute as `filter` (see `attributes(filter)` on
// the derive entry point); the ident here must match so error messages and
// attribute matching are consistent. It previously said `attr_name`.
#[attribute(ident = filter)]
// overriding the builtin error messages
#[attribute(error(missing_field = "`{field}` was not specified"))]
struct FilterAttribute {
    /// Database field to compare against; defaults to the struct field name.
    field: Option<String>,
    /// Comparison operator; defaults to `=`.
    operator: Option<String>,
    /// Marks this field as the LIMIT value.
    limit: bool,
    /// Marks this field as the OFFSET value.
    offset: bool,
}
/// Parsed shape of the derive input; only structs are currently supported.
enum InputData<'a> {
    Struct(StructData<'a>),
    // Enum(EnumData<'a>),
}
/// The struct fields that carry a filter attribute.
struct StructData<'a> {
    fields: Vec<Field<'a>>,
}
// struct EnumData<'a> {
//     fields: Vec<Field<'a>>,
// }
/// Per-field information extracted for code generation.
#[derive(Debug)]
struct Field<'a> {
    /// The struct field identifier.
    ident: &'a Ident,
    /// Whether the field type is `Option<T>`; if so the filter is skipped when `None`.
    is_opt: bool,
    /// Parsed `#[filter(...)]` attribute parameters.
    attr: FilterAttribute,
}
/// Return the path of a bare path type (`Foo`, `std::foo::Bar`).
/// Qualified-self types (`<T as Trait>::Assoc`) and non-path types yield `None`.
fn extract_type_path(ty: &syn::Type) -> Option<&syn::Path> {
    if let syn::Type::Path(type_path) = ty {
        if type_path.qself.is_none() {
            return Some(&type_path.path);
        }
    }
    None
}
/// Auto generate a Filter implementation.
/// The #[filter] attribute indicates a field that will be used for filtering, with optional parameters.
/// - Auto-detects options and applies no filter if the value is None
/// - `operator` parameter defines the operator to use for comparison. Default "="
/// - `field` defines the field in database to filter with. Defaults to the name of the struct field
/// - `limit` if specified indicates the given field is the LIMIT
/// - `offset` if specified indicates the given field is the OFFSET
/// - Field must be formattable for insertion into string
///
/// Example:
/// ```
/// use framework_macros::Filter;
///
/// #[derive(Filter)]
/// pub struct MyFilter {
///     #[filter(operator = "~")]
///     pub first_name: Option<String>,
///     #[filter]
///     pub superuser: Option<bool>,
///     #[filter(field = "created_at", operator = "<")]
///     pub created_before: Option<chrono::NaiveDateTime>,
///     #[filter(limit)]
///     pub limit: Option<u64>,
///     #[filter(offset)]
///     pub offset: Option<u64>,
/// }
/// ```
#[proc_macro_derive(Filter, attributes(filter))]
pub fn derive(input: TokenStream) -> TokenStream {
    // Parse the input tokens into a syntax tree.
    let input = parse_macro_input!(input as DeriveInput);
    let name = input.ident;
    let fdata = get_object_data(&input.data);
    let expanded = match fdata {
        InputData::Struct(StructData { fields }) => {
            // Split pagination fields (limit/offset) from the WHERE fields.
            let (limit_offset_fields, filter_fields): (Vec<Field>, Vec<Field>) = fields
                .into_iter()
                .partition(|x| x.attr.limit || x.attr.offset);
            let (limit_fields, offset_fields): (Vec<Field>, Vec<Field>) =
                limit_offset_fields.into_iter().partition(|x| x.attr.limit);
            // TODO: improve output result (emit a compile_error! instead of panicking)
            if limit_fields.len() > 1 {
                panic!("Only a single field can be limit")
            }
            if offset_fields.len() > 1 {
                panic!("Only a single field can be offset")
            }
            // if offset_fields.len() + offset_fields.len() == 1 {
            //     panic!("Either none or both limit and offset have to be defined")
            // }
            // One `(db_field, operator, bind_name)` tuple per active filter field.
            let components_streams: Vec<_> = filter_fields
                .iter()
                .map(|f| {
                    let ident = &f.ident;
                    let name = ident.to_string();
                    let op = f.attr.operator.as_deref().unwrap_or("=");
                    let field = f.attr.field.as_deref().unwrap_or(&name);
                    let inner_line = quote! {
                        components.push((#field, #op, #name));
                    };
                    if f.is_opt {
                        // Optional fields only contribute a clause when Some.
                        quote! {
                            if self.#ident.is_some() {
                                #inner_line
                            }
                        }
                    } else {
                        inner_line
                    }
                })
                .collect();
            // Parameter binding for each filter field.
            let bind_streams: Vec<_> = filter_fields
                .iter()
                .map(|f| {
                    let ident = &f.ident;
                    let name = ident.to_string();
                    if f.is_opt {
                        quote! {
                            if let Some(v) = &self.#ident {
                                query = query.bind((#name, v));
                            }
                        }
                    } else {
                        // BUGFIX: must read the struct field (`self.#ident`);
                        // the bare ident does not exist in the generated scope.
                        quote! {
                            query = query.bind((#name, &self.#ident));
                        }
                    }
                })
                .collect();
            let limit_gen_stream = if limit_fields.len() == 1 {
                let field = &limit_fields[0];
                let ident = field.ident;
                let name = ident.to_string();
                let inner_line = quote! {
                    pagination += &format!(" LIMIT ${}", #name);
                };
                if field.is_opt {
                    quote! {
                        if self.#ident.is_some() {
                            #inner_line
                        }
                    }
                } else {
                    inner_line
                }
            } else {
                quote! {}
            };
            let limit_bind_stream = if limit_fields.len() == 1 {
                let field = &limit_fields[0];
                let ident = field.ident;
                let name = ident.to_string();
                if field.is_opt {
                    quote! {
                        if let Some(v) = &self.#ident {
                            query = query.bind((#name, v));
                        }
                    }
                } else {
                    // BUGFIX: see bind_streams — read via `self`.
                    quote! {
                        query = query.bind((#name, &self.#ident));
                    }
                }
            } else {
                quote! {}
            };
            let offset_gen_stream = if offset_fields.len() == 1 {
                let field = &offset_fields[0];
                let ident = field.ident;
                let name = ident.to_string();
                let inner_line = quote! {
                    pagination += &format!(" START ${}", #name);
                };
                if field.is_opt {
                    quote! {
                        if self.#ident.is_some() {
                            #inner_line
                        }
                    }
                } else {
                    inner_line
                }
            } else {
                quote! {}
            };
            let offset_bind_stream = if offset_fields.len() == 1 {
                let field = &offset_fields[0];
                let ident = field.ident;
                let name = ident.to_string();
                if field.is_opt {
                    quote! {
                        if let Some(v) = &self.#ident {
                            query = query.bind((#name, v));
                        }
                    }
                } else {
                    // BUGFIX: see bind_streams — read via `self`.
                    quote! {
                        query = query.bind((#name, &self.#ident));
                    }
                }
            } else {
                quote! {}
            };
            quote! {
                impl Filter for #name {
                    fn generate_filter(&self) -> String {
                        let mut components: Vec<(&str, &str, &str)> = vec!();
                        #(
                            #components_streams
                        )*
                        let mut where_filter = if components.len() == 0 {
                            String::from("")
                        } else {
                            let c = components[0];
                            let mut ret = format!("WHERE {} {} ${}", c.0, c.1, c.2);
                            for c in &components[1..] {
                                // BUGFIX: leading space — without it the clause
                                // rendered as e.g. "WHERE a = $aAND b = $b".
                                ret += &format!(" AND {} {} ${}", c.0, c.1, c.2);
                            }
                            ret
                        };
                        let mut pagination = String::from("");
                        #limit_gen_stream
                        #offset_gen_stream
                        return where_filter+&pagination;
                    }
                    fn bind<'a, C: surrealdb::Connection>(&self, mut query: surrealdb::method::Query<'a, C>) -> surrealdb::method::Query<'a, C> {
                        #(
                            #bind_streams
                        )*
                        #limit_bind_stream
                        #offset_bind_stream
                        query
                    }
                }
            }
        }
    };
    // NOTE: removed the unconditional `println!` debug dump of the expanded
    // tokens; re-add a local `eprintln!` when debugging macro output.
    expanded.into()
}
/// Extract the code-generation data from one struct field.
///
/// NOTE(review): only the *first* attribute on the field is parsed, so a field
/// carrying another attribute before `#[filter]` would mis-parse — confirm or
/// extend to scan all attributes (see TODO below).
fn field_data(field: &syn::Field) -> Field {
    let ty = &field.ty;
    // Detect whether the field's type is `Option<...>`.
    let (_, is_opt) = match extract_type_from_option(ty) {
        Some(v) => (v, true),
        None => (ty, false),
    };
    Field {
        // Safe for named-field structs, the only shape accepted by get_object_data.
        ident: field.ident.as_ref().unwrap(),
        // TODO: parse all attributes instead of first one
        attr: FilterAttribute::from_attribute(&field.attrs[0]).unwrap(),
        is_opt,
    }
}
// get relevant data from input
fn get_object_data<'a>(data: &'a Data) -> InputData<'a> {
match *data {
Data::Struct(ref data) => match data.fields {
Fields::Named(ref fields) => {
let fields: Vec<Field> = fields
.named
.iter()
.filter(|f| f.attrs.len() > 0)
.map(|f| field_data(f))
.collect();
InputData::Struct(StructData { fields })
}
_ => {
unimplemented!()
}
},
Data::Enum(_) => unimplemented!(),
Data::Union(_) => unimplemented!(),
}
}
/// If `ty` is `Option<T>` (also spelled `std::option::Option` or
/// `core::option::Option`), return the inner type `T`; otherwise `None`.
fn extract_type_from_option(ty: &syn::Type) -> Option<&syn::Type> {
    use syn::{GenericArgument, Path, PathSegment};
    // TODO store (with lazy static) the vec of string
    // TODO maybe optimization, reverse the order of segments
    fn extract_option_segment(path: &Path) -> Option<&PathSegment> {
        // Join segment idents into e.g. "std|option|Option|" for comparison.
        let idents_of_path = path
            .segments
            .iter()
            // (removed a redundant `.into_iter()` after `.iter()`)
            .fold(String::new(), |mut acc, v| {
                acc.push_str(&v.ident.to_string());
                acc.push('|');
                acc
            });
        // Array instead of `vec!` — no heap allocation needed for the lookup.
        ["Option|", "std|option|Option|", "core|option|Option|"]
            .into_iter()
            .find(|s| idents_of_path == *s)
            .and_then(|_| path.segments.last())
    }
    extract_type_path(ty)
        .and_then(|path| extract_option_segment(path))
        .and_then(|path_seg| {
            let type_params = &path_seg.arguments;
            // It should have only one angle-bracketed param ("<String>"):
            match *type_params {
                PathArguments::AngleBracketed(ref params) => params.args.first(),
                _ => None,
            }
        })
        .and_then(|generic_arg| match *generic_arg {
            GenericArgument::Type(ref ty) => Some(ty),
            _ => None,
        })
}

95
framework/src/auth/jwt.rs Normal file
View file

@ -0,0 +1,95 @@
use super::OAuthUser;
use jwks_client_update::{jwt::Jwt, keyset::KeyStore};
/// Anything that can hand out a raw bearer token string.
pub trait Bearer {
    /// The raw (still unverified) token carried by the request.
    fn token(&self) -> &str;
}
#[cfg(feature = "actix")]
pub mod actix_bearer_auth {
    //! Adapter implementing [`super::Bearer`] for actix-web's `BearerAuth` extractor.
    use actix_web_httpauth::extractors::bearer::BearerAuth;

    /// Wrapper around the actix `BearerAuth` extractor.
    pub struct ActixBearerAuth {
        bearer: BearerAuth,
    }

    impl From<BearerAuth> for ActixBearerAuth {
        fn from(value: BearerAuth) -> Self {
            Self { bearer: value }
        }
    }

    // Implementing `From<ActixBearerAuth> for BearerAuth` (instead of the
    // previous manual `Into` impl) is the idiomatic direction; the blanket
    // impl still provides `Into`, so existing `.into()` call sites keep working.
    impl From<ActixBearerAuth> for BearerAuth {
        fn from(value: ActixBearerAuth) -> Self {
            value.bearer
        }
    }

    impl super::Bearer for ActixBearerAuth {
        fn token(&self) -> &str {
            self.bearer.token()
        }
    }
}
#[cfg(feature = "actix")]
pub use actix_bearer_auth::ActixBearerAuth;
/// A decoded JWT split into its header and claims JSON objects.
pub struct JWT {
    /// JOSE header fields (e.g. `alg`, `kid`).
    pub headers: serde_json::Value,
    /// Payload claims (e.g. `sub`, `exp`).
    pub claims: serde_json::Value,
}
impl From<Jwt> for JWT {
    // NOTE(review): `.into().unwrap()` panics if the jwks-client header/payload
    // cannot be converted to JSON — confirm this cannot happen for any token
    // that passes `KeyStore::verify`.
    fn from(value: Jwt) -> Self {
        Self {
            headers: value.header().into().unwrap(),
            claims: value.payload().into().unwrap(),
        }
    }
}
impl JWT {
    /// Token expiry from the `exp` claim, if present and numeric.
    pub fn get_expiry_date(&self) -> Option<chrono::DateTime<chrono::Utc>> {
        let seconds = self.claims.get("exp").and_then(|v| v.as_i64())?;
        chrono::DateTime::from_timestamp(seconds, 0)
    }
    /// Raw JSON value of a claim, if present.
    pub fn get_claim(&self, key: &str) -> Option<&serde_json::Value> {
        self.claims.get(key)
    }
    /// String value of a claim; `None` if absent or not a string.
    pub fn get_claim_str(&self, key: &str) -> Option<&str> {
        self.get_claim(key).and_then(|v| v.as_str())
    }
    /// The `sub` claim (subject / user id).
    pub fn get_user_id(&self) -> Option<&str> {
        self.get_claim_str("sub")
    }
    /// The `sid` claim (session id).
    pub fn get_session_id(&self) -> Option<&str> {
        self.get_claim_str("sid")
    }
}
/// Wrapper around a jwks-client key store used to verify incoming tokens.
pub struct JWKS {
    /// The underlying key set fetched from the identity provider.
    pub store: KeyStore,
}
impl From<KeyStore> for JWKS {
    fn from(value: KeyStore) -> Self {
        Self { store: value }
    }
}
impl JWKS {
    /// Fetch the key set from a remote JWKS endpoint and build a `JWKS`.
    pub async fn new_from(url: String) -> Result<Self, jwks_client_update::error::Error> {
        let store = KeyStore::new_from(url).await?;
        Ok(Self { store })
    }
    /// Verify the bearer token against the key set and extract the user claims.
    pub fn auth_from_bearer<T: Bearer>(
        &self,
        bearer: T,
    ) -> Result<OAuthUser, jwks_client_update::error::Error> {
        let verified = self.store.verify(bearer.token())?;
        let jwt = JWT::from(verified);
        Ok(OAuthUser::from(&jwt))
    }
}

View file

@ -0,0 +1,8 @@
pub mod jwt;
pub mod permission;
pub mod user;
pub use permission::{Permission, Permissions};
pub use user::OAuthUser;
pub use jwt::{JWKS, JWT};

View file

@ -0,0 +1,35 @@
use std::collections::BTreeSet;
use serde::{Deserialize, Serialize};
/// A single permission, represented as an opaque string (newtype).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
pub struct Permission(String);
impl From<String> for Permission {
    fn from(value: String) -> Self {
        Self(value)
    }
}
impl From<&str> for Permission {
    fn from(value: &str) -> Self {
        Self(value.to_string())
    }
}
impl Permission {
    /// Borrow the permission name as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
/// An ordered set of permissions held by a subject.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Permissions {
    pub permissions: BTreeSet<Permission>,
}
impl Permissions {
    /// Whether `perm` is present in the set.
    pub fn has_perm(&self, perm: &Permission) -> bool {
        // `contains` expresses the membership test directly
        // (was `get(perm).is_some()`).
        self.permissions.contains(perm)
    }
}

View file

@ -0,0 +1,54 @@
use std::collections::BTreeSet;
use serde_derive::{Deserialize, Serialize};
use super::JWT;
/// User identity extracted from verified OIDC/JWT claims.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthUser {
    // `sub` claim — the identity provider's subject id.
    pub oidc_id: String,
    // `preferred_username` claim.
    pub username: String,
    pub first_name: Option<String>,
    pub last_name: Option<String>,
    pub email: Option<String>,
    // When true, has_perm() grants everything.
    pub superuser: bool,
    // Parsed from comma-separated claim strings; None when the claim is absent.
    pub roles: Option<BTreeSet<String>>,
    pub permissions: Option<BTreeSet<String>>,
}
impl OAuthUser {
    /// True when the user is a superuser or explicitly holds `perm`.
    pub fn has_perm(&self, perm: &str) -> bool {
        if self.superuser {
            return true;
        }
        // `contains` instead of `get(..).is_some()`; users with no permission
        // set (None) hold nothing.
        self.permissions
            .as_ref()
            .map_or(false, |perms| perms.contains(perm))
    }
}
impl From<&JWT> for OAuthUser {
    /// Build an [`OAuthUser`] from verified JWT claims.
    ///
    /// Panics if the `sub` or `preferred_username` claims are missing (the
    /// `.unwrap()` calls below). NOTE(review): confirm the identity provider
    /// always emits both claims before relying on this in production.
    fn from(value: &JWT) -> Self {
        Self {
            oidc_id: value.get_user_id().unwrap().to_string(),
            username: value
                .get_claim_str("preferred_username")
                .map(|x| x.to_string())
                .unwrap(),
            first_name: value.get_claim_str("given_name").map(|x| x.to_string()),
            last_name: value.get_claim_str("family_name").map(|x| x.to_string()),
            email: value.get_claim_str("email").map(|x| x.to_string()),
            // `superuser` is read as a string claim parsed with bool::parse;
            // anything other than "true" ends up false.
            superuser: value
                .get_claim_str("superuser")
                .map(|x| x.parse::<bool>().unwrap_or(false))
                .unwrap_or(false),
            // Comma-separated claim strings — assumes names contain no commas
            // (TODO confirm with the identity provider's claim mapping).
            roles: value
                .get_claim_str("roles")
                .map(|x| x.split(",").map(|x| x.to_string()).collect()),
            permissions: value
                .get_claim_str("permissions")
                .map(|x| x.split(",").map(|x| x.to_string()).collect()),
        }
    }
}

411
framework/src/cache/mod.rs vendored Normal file
View file

@ -0,0 +1,411 @@
use std::marker::PhantomData;
use futures::Future;
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[cfg(feature = "moka")]
pub mod moka;
#[cfg(feature = "moka")]
pub use moka::Cache as MokaCache;
#[cfg(feature = "moka")]
pub type MokaCacheSerde<K, Se> = CacheSerde<MokaCache<K, Vec<u8>>, K, Se>;
pub mod serialize;
pub use serialize::*;
/// Cache trait implementation.
/// Must implement the functions `get`, `insert` and `invalidate`,
/// other functions are optional and if undefined default behavior will be auto-implemented.
/// The cache must be async. Compatible with any async runtime.
pub trait Cache<K,V>
where
    V: Clone + Send+Sync + 'static,
{
    /// The error type returned if get/insert/invalidate fails.
    /// An error type must be provided even if the functions can never fail.
    type Error: std::error::Error + 'static;
    /// Try to get a value from the cache.
    /// The provided value has to be owned and not a reference.
    /// Must return Ok(None) if the cache is functional but the value doesn't exist.
    fn get(&self, key: &K) -> impl Future<Output = Result<Option<V>, Self::Error>>;
    /// Insert a value into cache.
    /// Consumes both key and value provided.
    fn insert(&self, key: K, value: V) -> impl Future<Output = Result<(), Self::Error>>;
    /// Evict a value from the cache.
    /// Must return Ok(()) if the cache works but the key doesn't exist
    fn invalidate(&self, key: &K) -> impl Future<Output = Result<(), Self::Error>>;
    /// Wrap a closure return value to cache with given key.
    /// When a cache failure occurs the closure is called and the value is inserted into cache.
    fn call<F>(&self, key: K, closure: F) -> impl Future<Output = V>
    where
        F: FnOnce() -> V,
    {
        async move {
            self.call_async(key, || async { closure() } ).await
        }
    }
    /// Same as get(key) but doesn't wrap a Result, instead returns None if there is an error.
    /// Errors are logged, not propagated. (Spelling "infaillible" is kept: it is public API.)
    fn get_infaillible(&self, key: &K) -> impl Future<Output = Option<V>> {
        async move {
            self.get(key).await.map_or_else(|e| { log::error!("Error on cache get: {}", e.to_string()); None }, |x| x )
        }
    }
    /// Same as insert(key, value) but doesn't wrap a Result, instead returns () if there is an error.
    /// Errors are logged, not propagated.
    fn insert_infaillible(&self, key: K, value: V) -> impl Future<Output = ()> {
        async move {
            self.insert(key, value).await.unwrap_or_else( |e| log::error!("Error on cache insert: {}", e.to_string()) )
        }
    }
    /// Wrap an async closure return value to cache with given key.
    /// When a cache failure occurs the closure is called and the value is inserted into cache.
    fn call_async<F,Fut>(&self, key: K, closure: F) -> impl Future<Output = V>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = V>,
    {
        async move {
            if let Some(v) = self.get_infaillible(&key).await {
                v
            } else {
                // if not cached: run and insert
                let ret = closure().await;
                self.insert_infaillible(key, ret.clone()).await;
                ret
            }
        }
    }
    /// Wrap a closure returning a result to cache with the given key.
    /// When a cache failure occurs the closure is called and the value is inserted into cache.
    fn call_result<F,FE>(&self, key: K, closure: F) -> impl Future<Output = Result<V, FE>>
    where
        F: FnOnce() -> Result<V,FE>,
    {
        async move {
            self.call_result_async(key, || async { closure() } ).await
        }
    }
    /// Same as call_result() but with an async closure.
    /// Only successful closure results are inserted into the cache.
    fn call_result_async<F,Fut,FE>(&self, key: K, closure: F) -> impl Future<Output = Result<V, FE>>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<V,FE>>,
    {
        async move {
            if let Some(v) = self.get_infaillible(&key).await {
                Ok(v)
            } else {
                let ret = closure().await;
                // cache only if return is non-error
                if let Ok(v) = &ret {
                    self.insert_infaillible(key, v.clone()).await;
                }
                ret
            }
        }
    }
    /// Generic call with transform function arguments. Mostly used for serializing and deserializing values.
    /// First the value is extracted from cache, and if successful the transform_out function is executed.
    /// If the value was not in cache, the closure is executed, and the transform_in function is executed.
    /// When a transform fails, the corresponding failure function gets executed.
    ///
    /// The call is focused on result safety, in case of failures, unless the closure itself returned fail a value will always be returned
    ///
    /// Scenarios:
    /// request -> cached -> transform_out -> ok -> return value
    ///     |                       \-> fail -> failure_out -> closure -> return value
    ///     |
    ///     \-> not cached -> closure -> transform_in -> ok -> insert -> return value
    ///                                        \-> fail -> return value
    fn call_transform_result_async<F,TI,TO,Fut,VE,VO, EI, EO, TE>(
        &self, key: K, closure: F,
        transform_in: TI, transform_out: TO,
        failure_in: EI, failure_out: EO,
    ) -> impl Future<Output = Result<VO, VE>>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<VO,VE>>,
        VO: Send+Sync + 'static,
        TI: FnOnce(&VO) -> Result<V, TE>,
        TO: FnOnce(&V) -> Result<VO, TE>,
        EI: FnOnce(TE, K),
        EO: FnOnce(TE, K),
    {
        async move {
            if let Some(v) = self.get_infaillible(&key).await {
                // value was cached: try to apply the transform
                match transform_out(&v) {
                    Ok(vo) => Ok(vo),
                    Err(e) => {
                        // transform fail: run the failure function and run closure
                        failure_out(e, key);
                        closure().await
                    }
                }
            } else {
                // no value was cached: run the closure
                let ret = closure().await;
                if let Ok(vo) = &ret {
                    // closure was success: try to transform
                    match transform_in(vo) {
                        Ok(c) => {
                            // transform pass: insert in cache
                            self.insert_infaillible(key, c).await;
                        },
                        Err(e) => {
                            // transform fail: run the failure function
                            failure_in(e, key);
                        }
                    }
                }
                // return closure result as-is, whether ok or fail
                ret
            }
        }
    }
    /// Generic wrap with transform function arguments. Mostly used for serializing and deserializing values.
    /// First the value is extracted from cache, and if successful the transform_out function is executed.
    /// If the value was not in cache, the closure is executed, and the transform_in function is executed.
    ///
    /// Unlike call_transform_result_async, transform failures are returned to the caller (outer Err).
    ///
    /// Scenarios:
    /// request -> cached -> transform_out -> ok -> return value
    ///     |                       \-> fail -> return err
    ///     |
    ///     \-> not cached -> closure -> transform_in -> ok -> insert -> return value
    ///                                        \-> fail -> return err
    fn wrap_transform_result_async<F,TI,TO,Fut,VE,VO,VTE>(
        &self, key: K, closure: F,
        transform_in: TI, transform_out: TO,
    ) -> impl Future<Output = Result<Result<VO, VE>, VTE>>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<VO,VE>>,
        VO: Send+Sync + 'static,
        TI: FnOnce(&VO) -> Result<V, VTE>,
        TO: FnOnce(&V) -> Result<VO, VTE>,
    {
        async move {
            if let Some(v) = self.get_infaillible(&key).await {
                // value was cached: apply the transform and return as-is
                transform_out(&v).map(|x| Ok(x))
            } else {
                // no value was cached: run the closure
                let ret = closure().await;
                if let Ok(vo) = &ret {
                    // closure success: transform
                    match transform_in(vo) {
                        Ok(c) => {
                            // transform success: insert
                            self.insert_infaillible(key, c).await;
                            Ok(ret)
                        },
                        Err(e) => {
                            // transform fail: fail
                            Err(e)
                        }
                    }
                } else {
                    // closure fail: return function success but closure fail
                    Ok(ret)
                }
            }
        }
    }
}
/// Errors surfaced by [`CacheSerde`].
#[derive(Error, Debug)]
pub enum Error {
    /// Failure coming from the underlying cache implementation.
    #[error(transparent)]
    InternalCacheError(Box<dyn std::error::Error>),
    /// Failure while serializing or deserializing a cached value.
    #[error(transparent)]
    SerializeError(Box<dyn std::error::Error>),
}
/// Cache wrapper that serializes values with `Se` and stores them as raw
/// bytes (`Vec<u8>`) in the underlying cache `C`.
#[derive(Clone)]
pub struct CacheSerde<C,K,Se>
where
    C: Cache<K,Vec<u8>>,
    K: std::fmt::Display,
    Se: CacheSerializer,
{
    // The underlying byte-oriented cache.
    cache: C,
    // Zero-sized marker for the key and serializer type parameters.
    _phantom: PhantomData<(K,Se)>,
}
impl<C,K,Se> CacheSerde<C,K,Se>
where
    C: Cache<K,Vec<u8>>,
    K: std::fmt::Display,
    Se: CacheSerializer,
{
    /// Initialize from any byte-oriented [`Cache`] instance.
    pub fn new(cache: C) -> Self {
        Self {
            cache,
            _phantom: Default::default(),
        }
    }
    /// Get raw data from cache.
    pub async fn get_raw(&self, key: &K) -> Result<Option<Vec<u8>>, Error> {
        self.cache.get(key).await.map_err(|e| Error::InternalCacheError(Box::new(e)))
    }
    /// Get raw data from cache; cache errors are logged and yield None.
    pub async fn get_raw_infaillible(&self, key: &K) -> Option<Vec<u8>> {
        self.cache.get_infaillible(key).await
    }
    /// Insert raw data into cache.
    pub async fn insert_raw(&self, key: K, value: Vec<u8>) -> Result<(), Error> {
        self.cache.insert(key, value).await.map_err(|e| Error::InternalCacheError(Box::new(e)))
    }
    /// Insert raw data into cache; cache errors are logged, not returned.
    pub async fn insert_raw_infaillible(&self, key: K, value: Vec<u8>) {
        self.cache.insert_infaillible(key, value).await
    }
    /// Evict the given value from cache.
    pub async fn invalidate(&self, key: &K) -> Result<(), Error> {
        self.cache.invalidate(key).await.map_err(|e| Error::InternalCacheError(Box::new(e)))
    }
    // Deserialize bytes with `Se`, wrapping failures in Error::SerializeError.
    fn deserialize<R>(value: &[u8]) -> Result<R, Error>
    where
        R: for<'a> Deserialize<'a> + Send+Sync + 'static,
    {
        Se::deserialize(value).map_err(|e| Error::SerializeError(Box::new(e)))
    }
    // Serialize a value with `Se`, wrapping failures in Error::SerializeError.
    fn serialize<R>(value: &R) -> Result<Vec<u8>, Error>
    where
        R: Serialize + Send+Sync + 'static,
    {
        Se::serialize(value).map_err(|e| Error::SerializeError(Box::new(e)))
    }
    /// Get from cache and deserialize.
    /// Outer Option = cache hit/miss; inner Result = deserialization outcome.
    pub async fn get<R>(&self, key: &K) -> Option<Result<R, Error>>
    where
        R: for<'a> Deserialize<'a> + Send+Sync + 'static,
    {
        self.get_raw_infaillible(key).await.map(|x| Self::deserialize(&x))
    }
    /// Serialize and insert to cache.
    pub async fn insert<R>(&self, key: K, value: R) -> Result<(), Error>
    where
        R: Serialize + Send+Sync + 'static,
    {
        self.insert_raw_infaillible(key, Self::serialize(&value)?).await;
        Ok(())
    }
    /// Map the closure return value to cache with given key.
    /// First tries to get the value from cache, but on failure will run the closure and cache the result.
    /// Doesn't panic, when an error occurs the closure is called instead.
    pub async fn call<F,R>(&self, key: K, closure: F) -> R
    where
        F: FnOnce() -> R,
        R: Serialize + for<'a> Deserialize<'a> + Send+Sync + 'static,
    {
        self.call_async(key, || async { closure() }).await
    }
    /// Same as call() but with an async closure.
    pub async fn call_async<F,Fut,R>(&self, key: K, closure: F) -> R
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = R>,
        R: Serialize + for<'a> Deserialize<'a> + Send+Sync + 'static,
    {
        // The Ok::<R,()> wrapper can never be Err, so the unwrap cannot panic.
        self.call_result_async(key, || async { Ok::<R,()>(closure().await) }).await.unwrap()
    }
    /// Same as call() but the closure returns a Result, where the value isn't cached if the closure returns an error.
    /// Can be used for closures which return errors that can't be serialized.
    pub async fn call_result<F,R,E>(&self, key: K, closure: F) -> Result<R,E>
    where
        F: FnOnce() -> Result<R,E>,
        R: Serialize + for<'a> Deserialize<'a> + Send+Sync + 'static,
    {
        self.call_result_async(key, || async { closure() }).await
    }
    /// Same as call_result() but with an async closure.
    pub async fn call_result_async<F,Fut,R,E>(&self, key: K, closure: F) -> Result<R,E>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<R,E>>,
        R: Serialize + for<'a> Deserialize<'a> + Send+Sync + 'static,
    {
        // Serialize/deserialize failures are logged and fall back to the closure
        // (see Cache::call_transform_result_async scenarios).
        self.cache.call_transform_result_async(key,
            closure,
            Self::serialize,
            |x| Self::deserialize(&x),
            |e,key| { log::error!("Cache serialize failed for key {} with error '{}'", key, e); },
            |e,key| { log::error!("Cache deserialize failed for key {} with error '{}'", key, e); },
        ).await
    }
    /// Map the closure return value to cache with given key.
    /// First tries to get the value from cache, but on failure will run the closure and cache the result.
    /// Value is wrapped in a Result with potential cache failures.
    pub async fn wrap<F,R>(&self, key: K, closure: F) -> Result<R, Error>
    where
        F: FnOnce() -> R,
        R: Serialize + for<'a> Deserialize<'a> + Send+Sync + 'static,
    {
        self.wrap_async(key, || async { closure() }).await
    }
    /// Same as wrap() but with an async closure.
    pub async fn wrap_async<F,Fut,R>(&self, key: K, closure: F) -> Result<R, Error>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = R>,
        R: Serialize + for<'a> Deserialize<'a> + Send+Sync + 'static,
    {
        // The inner Ok(..) can never be Err, so the unwrap cannot panic.
        self.wrap_result_async(key, || async { Ok(closure().await) }).await.unwrap()
    }
    /// Same as wrap() but the closure returns a Result, where the value isn't cached if the closure returns an error.
    /// Can be used for closures which return errors that can't be serialized.
    pub async fn wrap_result<F,R,E>(&self, key: K, closure: F) -> Result<Result<R,E>, Error>
    where
        F: FnOnce() -> Result<R,E>,
        R: Serialize + for<'a> Deserialize<'a> + Send+Sync + 'static,
    {
        self.wrap_result_async(key, || async { closure() }).await
    }
    /// Same as wrap_result() but with an async closure.
    /// Outer Result = cache/serialization outcome; inner Result = closure outcome.
    pub async fn wrap_result_async<F,Fut,R,E>(&self, key: K, closure: F) -> Result<Result<R,E>, Error>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<R,E>>,
        R: Serialize + for<'a> Deserialize<'a> + Send+Sync + 'static,
    {
        self.cache.wrap_transform_result_async(key,
            closure,
            Self::serialize,
            |x| Self::deserialize(&x),
        ).await.map_err(Into::into)
    }
}

73
framework/src/cache/moka.rs vendored Normal file
View file

@ -0,0 +1,73 @@
use thiserror::Error;
/// Error type for the moka-backed cache.
/// The moka operations used below never fail, so this exists only to satisfy
/// the `Cache::Error` associated type; `Null` is not returned by this module.
#[derive(Error,Debug)]
pub enum Error {
    #[error("unknown error")]
    Null,
}
pub use moka::future::CacheBuilder as CacheBuilder;
/// Object used to perform cache calls and wrappings, backed by moka.
#[derive(Clone)]
pub struct Cache<K,V>
where
    K: std::cmp::Eq + std::hash::Hash + std::fmt::Display + Send+Sync + 'static,
    V: Clone + Send+Sync + 'static,
{
    cache: moka::future::Cache<K, V>
}
impl<K,V> Cache<K,V>
where
    K: std::cmp::Eq + std::hash::Hash + std::fmt::Display + Send+Sync + 'static,
    V: Clone + Send+Sync + 'static,
{
    /// Convenience access to moka's cache builder.
    // The method-level where-clause that duplicated the impl bounds verbatim
    // was redundant and has been removed; the impl bounds already apply.
    pub fn builder() -> moka::future::CacheBuilder<K, V, moka::future::Cache<K,V>> {
        moka::future::Cache::builder()
    }
    /// Initialize from a moka Cache instance.
    pub fn new(cache: moka::future::Cache<K,V>) -> Self {
        Self {
            cache
        }
    }
}
/// Allow wrapping a pre-built moka cache directly.
impl<K,V> From<moka::future::Cache<K,V>> for Cache<K,V>
where
    K: std::cmp::Eq + std::hash::Hash + std::fmt::Display + Send+Sync + 'static,
    V: Clone + Send+Sync + 'static,
{
    fn from(inner: moka::future::Cache<K,V>) -> Self {
        Self { cache: inner }
    }
}
/// Adapt moka to the framework-wide [`super::Cache`] trait.
/// All three operations are infallible with moka, so every result is `Ok`.
impl<K,V> super::Cache<K,V> for Cache<K,V>
where
    K: std::cmp::Eq + std::hash::Hash + std::fmt::Display + Send+Sync + 'static,
    V: Clone + Send+Sync + 'static,
{
    type Error = Error;
    // Get a value from cache
    async fn get(&self, key: &K) -> Result<Option<V>, Error> {
        Ok(self.cache.get(key).await)
    }
    // Insert a value into cache
    async fn insert(&self, key: K, value: V) -> Result<(), Error> {
        Ok(self.cache.insert(key, value).await)
    }
    // Evict a value from cache
    async fn invalidate(&self, key: &K) -> Result<(), Error> {
        Ok(self.cache.invalidate(key).await)
    }
}

25
framework/src/cache/serialize/json.rs vendored Normal file
View file

@ -0,0 +1,25 @@
use serde::{Deserialize, Serialize};
use super::CacheSerializer;
/// Cache serializer that encodes values as JSON via `serde_json`.
#[derive(Clone)]
pub struct JsonSerializer;
impl CacheSerializer for JsonSerializer
{
    type Error = serde_json::Error;
    // Encode `value` as a JSON byte vector.
    fn serialize<V>(value: &V) -> Result<Vec<u8>, Self::Error>
    where
        V: Serialize + Send+Sync + 'static
    {
        serde_json::to_vec(value)
    }
    // Decode a JSON byte slice previously produced by `serialize`.
    fn deserialize<V>(value: &[u8]) -> Result<V, Self::Error>
    where
        V: for<'a> Deserialize<'a> + Send+Sync + 'static
    {
        serde_json::from_slice(value)
    }
}

26
framework/src/cache/serialize/mod.rs vendored Normal file
View file

@ -0,0 +1,26 @@
use serde::{Deserialize, Serialize};
/// Byte-level (de)serialization strategy plugged into `CacheSerde`.
pub trait CacheSerializer
{
    /// Error type produced by either direction of the conversion.
    type Error: std::error::Error + 'static;
    /// Serialize a value into a byte buffer.
    fn serialize<V>(value: &V) -> Result<Vec<u8>, Self::Error>
    where
        V: Serialize + Send+Sync + 'static
    ;
    /// Deserialize a value previously produced by `serialize`.
    fn deserialize<V>(value: &[u8]) -> Result<V, Self::Error>
    where
        V: for<'a> Deserialize<'a> + Send+Sync + 'static
    ;
}
#[cfg(feature = "postcard")]
pub mod postcard;
#[cfg(feature = "postcard")]
pub use postcard::PostcardSerializer;
#[cfg(feature = "json")]
pub mod json;
#[cfg(feature = "json")]
pub use json::JsonSerializer;

View file

@ -0,0 +1,25 @@
use serde::{Deserialize, Serialize};
use super::CacheSerializer;
/// Cache serializer using the compact binary `postcard` format.
#[derive(Clone)]
pub struct PostcardSerializer;
impl CacheSerializer for PostcardSerializer
{
    type Error = postcard::Error;
    // Encode `value` into a postcard byte vector.
    fn serialize<V>(value: &V) -> Result<Vec<u8>, Self::Error>
    where
        V: Serialize + Send+Sync + 'static
    {
        postcard::to_allocvec(value)
    }
    // Decode a postcard byte slice previously produced by `serialize`.
    fn deserialize<V>(value: &[u8]) -> Result<V, Self::Error>
    where
        V: for<'a> Deserialize<'a> + Send+Sync + 'static
    {
        postcard::from_bytes(value)
    }
}

View file

@ -0,0 +1,79 @@
/// Trait used to apply filters to surrealDB queries.
/// Works in two parts: generate a string to append to the SurrealDB query, and apply binds to said query.
/// Usually created from the [derive macro](../../../framework_macros/index.html)
///
/// Example:
/// ```
/// use framework::db::filter::Filter;
///
/// pub struct MyFilter {
///     pub first_name: Option<String>,
///     pub superuser: Option<bool>,
///     pub created_before: Option<chrono::NaiveDateTime>,
///     pub limit: Option<u64>,
///     pub offset: Option<u64>,
/// }
///
/// impl Filter for MyFilter {
///     fn generate_filter(&self) -> String {
///         let mut components: Vec<(&str, &str, &str)> = vec![];
///         if self.first_name.is_some() {
///             components.push(("first_name", "~", "first_name"));
///         }
///         if self.superuser.is_some() {
///             components.push(("superuser", "=", "superuser"));
///         }
///         if self.created_before.is_some() {
///             components.push(("created_at", "<", "created_before"));
///         }
///         let where_filter = if components.is_empty() {
///             String::from("")
///         } else {
///             let c = components[0];
///             let mut ret = format!("WHERE {} {} ${}", c.0, c.1, c.2);
///             for c in &components[1..] {
///                 // Leading space required: without it clauses fuse into
///                 // invalid SQL like "… $first_nameAND superuser …".
///                 ret += &format!(" AND {} {} ${}", c.0, c.1, c.2);
///             }
///             ret
///         };
///         let mut pagination = String::from("");
///         if self.limit.is_some() {
///             pagination += &format!(" LIMIT ${}", "limit");
///         }
///         if self.offset.is_some() {
///             pagination += &format!(" START ${}", "offset");
///         }
///         where_filter + &pagination
///     }
///
///     fn bind<'a, C: surrealdb::Connection>(
///         &self,
///         mut query: surrealdb::method::Query<'a, C>,
///     ) -> surrealdb::method::Query<'a, C> {
///         if let Some(v) = &self.first_name {
///             query = query.bind(("first_name", v));
///         }
///         if let Some(v) = &self.superuser {
///             query = query.bind(("superuser", v));
///         }
///         if let Some(v) = &self.created_before {
///             query = query.bind(("created_before", v));
///         }
///         if let Some(v) = &self.limit {
///             query = query.bind(("limit", v));
///         }
///         if let Some(v) = &self.offset {
///             query = query.bind(("offset", v));
///         }
///         query
///     }
/// }
/// ```
pub trait Filter {
    /// Produce the `WHERE` / `LIMIT` / `START` fragment to append to a query.
    fn generate_filter(&self) -> String;
    /// Attach the parameter values referenced by [`Filter::generate_filter`].
    fn bind<'a, C: surrealdb::Connection>(
        &self,
        query: surrealdb::method::Query<'a, C>,
    ) -> surrealdb::method::Query<'a, C>;
}

63
framework/src/db/mod.rs Normal file
View file

@ -0,0 +1,63 @@
use serde::{Deserialize, Serialize};
use surrealdb::sql::{Datetime, Id};
pub mod filter;
/// Thin wrapper around SurrealDB's `Id`, used as the `id` field of a record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecordId {
    pub id: Id,
}
// NOTE(review): deliberately `Into` rather than `From<RecordId> for Id` —
// adding `From<RecordId> for Id` would make the blanket
// `impl<T> From<T> for RecordId where Id: From<T>` below overlap with core's
// reflexive `From<RecordId> for RecordId`. Confirm before "fixing" to `From`.
impl Into<Id> for RecordId {
    fn into(self) -> Id {
        self.id
    }
}
impl<T> From<T> for RecordId
where
Id: From<T>,
{
fn from(value: T) -> Self {
Self {
id: Id::from(value),
}
}
}
/// Database record definition, wrapper for serializing/deserializing to/from surreal queries
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Record<T> {
    /// Structured record id as returned by SurrealDB.
    pub id: RecordId,
    /// Payload type; `flatten` merges its fields into the record object itself.
    #[serde(flatten)]
    pub value: T,
    // Optional — presumably some queries do not project these audit fields;
    // confirm against the query sites.
    pub created_at: Option<Datetime>,
    pub updated_at: Option<Datetime>,
}
/// Flattened database record definition, wrapper for serializing/deserializing externally.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlatRecord<T> {
    /// Raw string form of the record id (see `From<Record<T>>` below).
    pub id: String,
    #[serde(flatten)]
    pub value: T,
    pub created_at: Option<Datetime>,
    pub updated_at: Option<Datetime>,
}
impl<T> From<Record<T>> for FlatRecord<T> {
fn from(value: Record<T>) -> Self {
Self {
id: value.id.id.to_raw(),
value: value.value,
created_at: value.created_at,
updated_at: value.updated_at,
}
}
}
impl<T> Record<T> {
pub fn to_flat(self) -> FlatRecord<T> {
FlatRecord::from(self)
}
}

3
framework/src/lib.rs Normal file
View file

@ -0,0 +1,3 @@
pub mod auth;
pub mod cache;
pub mod db;

View file

@ -0,0 +1 @@
{"schemas":"DEFINE TABLE role SCHEMAFULL;\n\nDEFINE FIELD name ON TABLE role TYPE string;\nDEFINE FIELD permissions ON TABLE role TYPE array<string>;\nDEFINE FIELD created_at ON TABLE role TYPE datetime VALUE time::now() READONLY;\nDEFINE FIELD updated_at ON TABLE role TYPE datetime VALUE time::now();\n\nDEFINE INDEX roleName ON TABLE user COLUMNS name UNIQUE;\n\nDEFINE TABLE script_migration SCHEMAFULL\n PERMISSIONS\n FOR select FULL\n FOR create, update, delete NONE;\n\nDEFINE FIELD script_name ON script_migration TYPE string;\nDEFINE FIELD executed_at ON script_migration TYPE datetime VALUE time::now() READONLY;\nDEFINE TABLE user SCHEMAFULL;\n\nDEFINE FIELD oidc_id ON TABLE user TYPE string;\nDEFINE FIELD username ON TABLE user TYPE string;\nDEFINE FIELD first_name ON TABLE user TYPE option<string>;\nDEFINE FIELD last_name ON TABLE user TYPE option<string>;\nDEFINE FIELD email ON TABLE user TYPE option<string>;\nDEFINE FIELD superuser ON TABLE user TYPE bool DEFAULT false;\nDEFINE FIELD roles ON TABLE user TYPE array<record<role>>;\n\nDEFINE FIELD created_at ON TABLE user TYPE datetime VALUE time::now() READONLY;\nDEFINE FIELD updated_at ON TABLE user TYPE datetime VALUE time::now();\n\nDEFINE INDEX userOidcId ON TABLE user COLUMNS oidc_id UNIQUE;\nDEFINE INDEX userUsernameIndex ON TABLE user COLUMNS username UNIQUE;\nDEFINE INDEX userEmailIndex ON TABLE user COLUMNS email UNIQUE;\n","events":""}

8
schemas/role.surql Normal file
View file

@ -0,0 +1,8 @@
DEFINE TABLE role SCHEMAFULL;

DEFINE FIELD name ON TABLE role TYPE string;
DEFINE FIELD permissions ON TABLE role TYPE array<string>;
DEFINE FIELD created_at ON TABLE role TYPE datetime VALUE time::now() READONLY;
DEFINE FIELD updated_at ON TABLE role TYPE datetime VALUE time::now();

-- Fix: this unique index belongs on the `role` table. It was declared
-- `ON TABLE user`, which indexed the wrong table and left role names
-- non-unique. (The generated schemas JSON contains the same mistake.)
DEFINE INDEX roleName ON TABLE role COLUMNS name UNIQUE;

View file

@ -0,0 +1,7 @@
-- Bookkeeping table recording which migration scripts have already run.
-- Clients may read it, but only the migration tooling (running with elevated
-- rights) can create/update/delete rows.
DEFINE TABLE script_migration SCHEMAFULL
    PERMISSIONS
        FOR select FULL
        FOR create, update, delete NONE;

DEFINE FIELD script_name ON script_migration TYPE string;
-- Set once at creation time; READONLY makes it immutable afterwards.
DEFINE FIELD executed_at ON script_migration TYPE datetime VALUE time::now() READONLY;

16
schemas/user.surql Normal file
View file

@ -0,0 +1,16 @@
-- User table: one row per identity known to the application.
DEFINE TABLE user SCHEMAFULL;

-- Subject identifier issued by the OIDC provider (unique, see index below).
DEFINE FIELD oidc_id ON TABLE user TYPE string;
DEFINE FIELD username ON TABLE user TYPE string;
DEFINE FIELD first_name ON TABLE user TYPE option<string>;
DEFINE FIELD last_name ON TABLE user TYPE option<string>;
DEFINE FIELD email ON TABLE user TYPE option<string>;
DEFINE FIELD superuser ON TABLE user TYPE bool DEFAULT false;
-- Links to rows of the `role` table.
DEFINE FIELD roles ON TABLE user TYPE array<record<role>>;

DEFINE FIELD created_at ON TABLE user TYPE datetime VALUE time::now() READONLY;
DEFINE FIELD updated_at ON TABLE user TYPE datetime VALUE time::now();

DEFINE INDEX userOidcId ON TABLE user COLUMNS oidc_id UNIQUE;
DEFINE INDEX userUsernameIndex ON TABLE user COLUMNS username UNIQUE;
DEFINE INDEX userEmailIndex ON TABLE user COLUMNS email UNIQUE;

53
src/auth/mod.rs Normal file
View file

@ -0,0 +1,53 @@
use framework::{auth::jwt::ActixBearerAuth, db::Record};
use crate::AppState;
use actix_web_httpauth::extractors::bearer::BearerAuth;
pub mod user;
pub use user::UserWithPerm;
use crate::errors::ServiceError;
use framework::auth::JWKS;
use once_cell::sync::Lazy;
use std::sync::RwLock;
use framework::auth::OAuthUser;
/// URL of the OIDC provider's certificate/JWKS endpoint, read once from the
/// `OIDC_CERT_URL` environment variable (panics at first use if unset).
pub static OIDC_CERT_URL: Lazy<String> =
    Lazy::new(|| std::env::var("OIDC_CERT_URL").expect("variable OIDC_CERT_URL must be set"));
/// Process-wide JWKS key store; `None` until `update_key_store` first runs.
static KEY_STORE: Lazy<RwLock<Option<JWKS>>> = Lazy::new(|| RwLock::new(None));
/// Fetch a fresh JWKS from `url` and install it as the process-wide key store.
///
/// Fix: the original acquired the `RwLock` write guard *before* awaiting the
/// network fetch, holding a std (non-async) lock across an `.await` point —
/// which can stall every other task needing the lock for the whole fetch.
/// Fetch first, then take the lock only for the in-memory swap.
pub async fn update_key_store(url: String) -> Result<(), jwks_client_update::error::Error> {
    let jwks = JWKS::new_from(url).await?;
    let mut keystore_update = KEY_STORE.write().unwrap();
    *keystore_update = Some(jwks);
    Ok(())
}
/// Validate a bearer token against the cached JWKS and return the decoded
/// OAuth user. Lazily initializes the key store on first use.
pub async fn auth_from_bearer(bearer: BearerAuth) -> Result<OAuthUser, ServiceError> {
    // TODO: extract globals into configs/state ?
    let mut keystore = KEY_STORE.read().unwrap();
    if keystore.is_none() {
        // The read guard MUST be dropped before awaiting: update_key_store
        // takes the write lock, and std RwLock guards held across `.await`
        // would deadlock here.
        drop(keystore);
        update_key_store(OIDC_CERT_URL.clone()).await?;
        keystore = KEY_STORE.read().unwrap();
    }
    // Safe: the branch above guarantees the store was populated
    // (update_key_store either fills it or returns Err, propagated by `?`).
    let keystore = keystore.as_ref().unwrap();
    keystore
        .auth_from_bearer(ActixBearerAuth::from(bearer))
        .map_err(Into::into)
}
/// Resolve an authenticated bearer token to a registered database user.
///
/// Returns `Forbidden` when the token is valid but no matching user record
/// exists in the database.
pub async fn user_from_bearer(
    state: &AppState,
    bearer: BearerAuth,
) -> Result<Record<UserWithPerm>, ServiceError> {
    let oauth_user = auth_from_bearer(bearer).await?;
    UserWithPerm::new_from_oauth(oauth_user, state)
        .await?
        .ok_or_else(|| ServiceError::Forbidden("Unregistered user".to_string()))
}

48
src/auth/user.rs Normal file
View file

@ -0,0 +1,48 @@
use framework::auth::OAuthUser;
use framework::db::Record;
use serde_derive::{Deserialize, Serialize};
use framework::auth::{Permission, Permissions};
use crate::db::user;
use crate::errors::Result;
use crate::model::User;
use crate::AppState;
/// A user together with the permission set derived from its roles
/// (projected by the queries in `crate::db::user`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserWithPerm {
    // Both halves are serde-flattened so the wire shape stays a single object.
    #[serde(flatten)]
    pub user: User,
    #[serde(flatten)]
    pub permissions: Option<Permissions>,
}
/// Build the cache key under which a user's auth record is stored.
pub fn cache_key(oidc_id: &str) -> String {
    ["auth:user:", oidc_id].concat()
}
impl UserWithPerm {
    /// Look up (cache-first) the database user matching an authenticated
    /// OAuth identity. `None` means the identity has no user record.
    pub async fn new_from_oauth(
        authuser: OAuthUser,
        state: &AppState,
    ) -> Result<Option<Record<Self>>> {
        let key = cache_key(&authuser.oidc_id);
        state
            .cache
            .call_result_async(key, || async {
                user::get_from_oidc_id(&state.conn, &authuser.oidc_id).await
            })
            .await
    }

    /// A superuser holds every permission; otherwise defer to the loaded
    /// permission set (absent permissions mean "no").
    pub fn has_perm(&self, perm: &Permission) -> bool {
        self.user.superuser
            || self
                .permissions
                .as_ref()
                .map_or(false, |p| p.has_perm(perm))
    }
}

1
src/db/mod.rs Normal file
View file

@ -0,0 +1 @@
pub mod user;

135
src/db/user.rs Normal file
View file

@ -0,0 +1,135 @@
use crate::model::User;
use crate::{auth::UserWithPerm, model::user::SimpleUser};
use const_format::formatcp;
use framework::db::{Record, RecordId};
use serde::{Deserialize, Serialize};
use surrealdb::{engine::remote::ws::Client, sql::Thing, Surreal};
use crate::errors::Result;
use framework::db::filter::Filter;
use framework_macros::Filter;
const TABLE: &str = "user";
const BASE_FIELDS: &str =
"*, array::flatten(roles.permissions) AS permissions, roles.name AS roles";
/// Query-string filter for listing users. Each `#[filter]` attribute feeds
/// the derived [`Filter`] implementation (DB field, SQL operator, bind name).
#[derive(Debug, Deserialize, Filter)]
pub struct UserFilter {
    // `~` operator comparisons on the name/email fields.
    #[filter(operator = "~")]
    pub first_name: Option<String>,
    #[filter(operator = "~")]
    pub last_name: Option<String>,
    #[filter(operator = "~")]
    pub email: Option<String>,
    #[filter]
    pub superuser: Option<bool>,
    // Both bound against the `created_at` column under distinct bind names.
    #[filter(field = "created_at", operator = "<")]
    pub created_before: Option<chrono::NaiveDateTime>,
    #[filter(field = "created_at", operator = ">=")]
    pub created_after: Option<chrono::NaiveDateTime>,
    // NOTE(review): `sort`/`rev_sort` carry no #[filter] attribute, so they do
    // not appear to reach the generated query — confirm intended behavior.
    pub sort: Option<String>,
    pub rev_sort: Option<bool>,
    // Pagination: LIMIT / START fragments.
    #[filter(limit)]
    pub limit: Option<u64>,
    #[filter(offset)]
    pub offset: Option<u64>,
}
/// Fetch a single user (with flattened role permissions) by its record id.
pub async fn get_from_id<T>(conn: &Surreal<Client>, id: T) -> Result<Option<Record<UserWithPerm>>>
where
    T: Serialize + std::fmt::Display,
{
    conn.query(formatcp!("SELECT {BASE_FIELDS} FROM $user"))
        // Fix: `Thing::try_from` returns a `Result`; bind the `Thing` itself,
        // not the unconverted `Result` (which is not serializable). Matches
        // the pattern used by `get_from_any_id` — like there, the unwrap
        // panics on an id that cannot form a valid Thing.
        .bind((TABLE, Thing::try_from(format!("{}:{}", TABLE, id)).unwrap()))
        .await?
        .take(0)
        .map_err(Into::into)
}
/// Get a user that matches either id, oidc_id, or username
///
/// Runs three candidate lookups in a single round trip and returns the first
/// hit, in precedence order: record id, then oidc_id, then username.
pub async fn get_from_any_id<T>(
    conn: &Surreal<Client>,
    id: T,
) -> Result<Option<Record<UserWithPerm>>>
where
    T: Serialize + std::fmt::Display,
{
    let mut query = conn
        .query(formatcp!(
            r#"
            SELECT {BASE_FIELDS} FROM type::table($table) WHERE id = $userid;
            SELECT {BASE_FIELDS} FROM type::table($table) WHERE oidc_id = $str_id;
            SELECT {BASE_FIELDS} FROM type::table($table) WHERE username = $str_id;
        "#
        ))
        .bind(("table", TABLE))
        // TODO: cleaner way than Thing ?
        // NOTE(review): `unwrap` panics if `id` cannot form a valid Thing —
        // confirm callers always pass sanitized input.
        .bind((
            "userid",
            Thing::try_from(format!("{}:{}", TABLE, id)).unwrap(),
        ))
        .bind(("str_id", &id))
        .await?;
    // Statement results are indexed in the order they appear in the query
    // above; stop at the first statement that produced a row.
    let mut ret = None;
    for i in 0..3 {
        if let Some(v) = query.take(i)? {
            ret = Some(v);
            break;
        }
    }
    Ok(ret)
}
/// Fetch a single user (with flattened role permissions) by OIDC subject id.
pub async fn get_from_oidc_id<T>(
    conn: &Surreal<Client>,
    id: T,
) -> Result<Option<Record<UserWithPerm>>>
where
    T: Serialize,
{
    let mut response = conn
        .query(formatcp!(
            "SELECT {BASE_FIELDS} FROM type::table($table) WHERE oidc_id = $oidc_id"
        ))
        .bind(("table", TABLE))
        .bind(("oidc_id", id))
        .await?;
    response.take(0).map_err(Into::into)
}
pub async fn get_all(conn: &Surreal<Client>) -> Result<Vec<Record<UserWithPerm>>> {
conn.select("user").await.map_err(Into::into)
}
/// List users matching `filter`, including role-derived projections.
///
/// The filter contributes both the `WHERE`/`LIMIT`/`START` SQL fragment
/// (via `generate_filter`) and the corresponding bound parameters (via
/// `Filter::bind`), so the two stay in sync by construction.
pub async fn get_paged_filtered(
    conn: &Surreal<Client>,
    filter: UserFilter,
) -> Result<Vec<Record<UserWithPerm>>> {
    let query = conn.query(format!(
        "SELECT {BASE_FIELDS} FROM type::table($table) {}",
        filter.generate_filter()
    ));
    let ret = filter.bind(query).bind(("table", TABLE)).await?.take(0)?;
    Ok(ret)
}
/// Merge `user`'s fields into the existing record identified by `id`;
/// fields not present in `SimpleUser` (e.g. roles) are left untouched.
pub async fn update<I>(conn: &Surreal<Client>, id: I, user: SimpleUser) -> Result<()>
where
    I: Into<surrealdb::sql::Id>,
{
    // The type annotation is required so the driver knows how to deserialize
    // the (discarded) merge result.
    let _: Option<Record<SimpleUser>> = conn.update((TABLE, id)).merge(user).await?;
    Ok(())
}
/// Delete the user record identified by `id`. Succeeds (with no effect)
/// when the record does not exist — the deleted row is discarded.
pub async fn delete<I>(conn: &Surreal<Client>, id: I) -> Result<()>
where
    I: Into<surrealdb::sql::Id>,
{
    // Annotation tells the driver the shape of the (discarded) deleted row.
    let _: Option<Record<SimpleUser>> = conn.delete((TABLE, id)).await?;
    Ok(())
}

68
src/errors.rs Normal file
View file

@ -0,0 +1,68 @@
use actix_web::{error::ResponseError, HttpResponse};
use derive_more::Display;
use jwks_client_update::error::{Type as JwksErrorType , Error as JwksError};
/// Errors surfaced to HTTP clients; each variant maps to a status code in
/// the `ResponseError` implementation.
#[derive(thiserror::Error, Debug, Display)]
pub enum ServiceError {
    /// 500 — the boxed source error is logged, never sent to the client.
    #[display(fmt = "Internal Server Error: {}", _0)]
    InternalServerError(Box<dyn std::error::Error>),
    /// 401 — authentication failed (e.g. invalid or expired token).
    #[display(fmt = "Unauthorized: {}", _0)]
    Unauthorized(String),
    /// 403 — authenticated but not permitted.
    #[display(fmt = "Forbidden: {}", _0)]
    Forbidden(String),
    /// 400 — malformed client input.
    #[display(fmt = "BadRequest: {}", _0)]
    BadRequest(String),
}
/// Internal error type for database-layer operations.
#[derive(thiserror::Error,Debug)]
pub enum Error {
    /// Any error surfaced by the SurrealDB client.
    #[error(transparent)]
    SurrealDB(#[from] surrealdb::Error),
    /// Placeholder for failures with no more specific variant.
    #[error("unknown error")]
    Unknown,
}
/// Shorthand result alias used throughout the database layer.
pub type Result<T> = std::result::Result<T, Error>;
// impl ResponseError trait allows to convert our errors into http responses with appropriate data
impl ResponseError for ServiceError {
fn error_response(&self) -> HttpResponse {
match self {
ServiceError::BadRequest(ref message) => HttpResponse::BadRequest().json(message),
ServiceError::Unauthorized(ref message) => HttpResponse::Unauthorized().json(message),
ServiceError::Forbidden(ref message) => HttpResponse::Forbidden().json(message),
ServiceError::InternalServerError(e) => {
log::error!("{}", e.to_string());
HttpResponse::InternalServerError().json("Internal Server Error")
},
}
}
}
impl From<JwksError> for ServiceError {
fn from(value: JwksError) -> Self {
match value.typ {
JwksErrorType::Invalid => Self::Unauthorized("Invalid token".to_string()),
JwksErrorType::Expired => Self::Unauthorized("Token expired".to_string()),
_ => Self::InternalServerError(value.into()),
}
}
}
/// Any internal database error surfaces to HTTP as a generic 500.
impl From<Error> for ServiceError {
    fn from(value: Error) -> Self {
        Self::InternalServerError(value.into())
    }
}
impl Error {
pub fn to_service_err(self) -> ServiceError {
ServiceError::from(self)
}
}

71
src/handlers/mod.rs Normal file
View file

@ -0,0 +1,71 @@
use crate::auth::auth_from_bearer;
use crate::AppState;
use actix_web::web::{Bytes, Data};
use actix_web::{Error, HttpResponse};
use actix_web_httpauth::extractors::bearer::BearerAuth;
use framework::db::Record;
use crate::model::User;
pub mod user;
// pub mod role;
// pub mod user_role;
/// Render `Some(v)` through `func`, or an empty 404 when `value` is `None`.
pub fn map_option_response_404<T, F>(value: Option<T>, func: F) -> HttpResponse
where
    F: FnOnce(T) -> HttpResponse,
{
    value.map_or_else(|| HttpResponse::NotFound().finish(), func)
}
/// Liveness probe: always answers 200 with an empty body.
pub async fn ping() -> HttpResponse {
    HttpResponse::Ok().finish()
}
// Handler for POST /register
pub async fn register(bearer: BearerAuth, state: Data<AppState>) -> Result<HttpResponse, Error> {
let authuser = auth_from_bearer(bearer).await?;
let user: User = authuser.into();
let conn = &state.conn;
let _ = state
.cache
.invalidate(&crate::auth::user::cache_key(&user.oidc_id))
.await;
let qres = conn
.query("SELECT * FROM user WHERE oidc_id = $oidc_id")
.bind(("oidc_id", user.oidc_id.clone()))
.await
.unwrap()
.take(0);
let q: Option<Record<User>> = qres.unwrap();
if let Some(mut v) = q {
// user exists, check if it's different
let user = if !v.value.metadata_eq(&user) {
// update the existing user
v.value.metadata_from(user);
let updated: Option<Record<User>> =
conn.update(("user", v.id)).merge(v.value).await.unwrap();
updated.unwrap()
} else {
v
};
Ok(HttpResponse::Ok().json(user.to_flat()))
} else {
let mut created: Vec<Record<User>> = conn.create("user").content(user).await.unwrap();
if created.len() != 1 {
panic!("Unexpected error");
}
Ok(HttpResponse::Ok().json(created.pop().unwrap().to_flat()))
}
}
pub async fn echo(bytes: Bytes) -> Result<HttpResponse, Error> {
let str = String::from_utf8(bytes.to_vec()).unwrap();
log::debug!("echo request: {}", str);
Ok(HttpResponse::Ok().body(str))
}

123
src/handlers/user.rs Normal file
View file

@ -0,0 +1,123 @@
use crate::auth;
use crate::auth::{user_from_bearer, UserWithPerm};
use crate::errors::ServiceError;
use crate::model::user::SimpleUser;
use crate::AppState;
use actix_web::web::{Data, Json, Path, Query};
use actix_web::{Error, HttpResponse};
use actix_web_httpauth::extractors::bearer::BearerAuth;
use framework::db::FlatRecord;
use crate::db::user::{self, UserFilter};
const REQUIRED_PERM: &str = "user";
// info for current user
/// GET /api/user — return the authenticated caller's own record.
pub async fn get_current_user(
    bearer: BearerAuth,
    state: Data<AppState>,
) -> Result<HttpResponse, Error> {
    let record = user_from_bearer(&state, bearer).await?;
    Ok(HttpResponse::Ok().json(record.to_flat()))
}
// Handler for GET /users/{id}
/// Look up a user by record id, oidc_id, or username; requires `user_read`.
pub async fn get_by_id(
    bearer: BearerAuth,
    state: Data<AppState>,
    id: Path<String>,
) -> Result<HttpResponse, Error> {
    let authuser = user_from_bearer(&state, bearer).await?.value;
    if !authuser.has_perm(&format!("{}_read", REQUIRED_PERM).into()) {
        return Ok(HttpResponse::Forbidden().finish());
    }
    let user = user::get_from_any_id(&state.conn, &id.into_inner())
        .await
        .map_err(ServiceError::from)?;
    // 200 with the flattened record when found, 404 otherwise.
    Ok(match user {
        Some(u) => HttpResponse::Ok().json(u.to_flat()),
        None => HttpResponse::NotFound().finish(),
    })
}
// Handler for GET /users
/// List users, filtered/paged via query parameters; requires `user_read`.
pub async fn get(
    bearer: BearerAuth,
    state: Data<AppState>,
    filter: Query<UserFilter>,
) -> Result<HttpResponse, Error> {
    let authuser = user_from_bearer(&state, bearer).await?.value;
    if !authuser.has_perm(&format!("{}_read", REQUIRED_PERM).into()) {
        return Ok(HttpResponse::Forbidden().finish());
    }
    let filter: UserFilter = filter.into_inner();
    // A limit of zero can never return rows; reject it up front.
    // (Idiom fix: direct comparison replaces `is_some() && unwrap() == 0`.)
    if filter.limit == Some(0) {
        return Ok(HttpResponse::BadRequest().body("invalid limit"));
    }
    let res: Vec<FlatRecord<UserWithPerm>> = user::get_paged_filtered(&state.conn, filter)
        .await
        .map_err(|e| e.to_service_err())?
        .into_iter()
        .map(|x| x.to_flat())
        .collect();
    Ok(HttpResponse::Ok().json(res))
}
/// PUT /users/{id} — merge the submitted fields into an existing user;
/// requires `user_update`. Unknown ids yield 404.
pub async fn update(
    bearer: BearerAuth,
    state: Data<AppState>,
    id: Path<String>,
    data: Json<SimpleUser>,
) -> Result<HttpResponse, Error> {
    let authuser = user_from_bearer(&state, bearer).await?.value;
    if !authuser.has_perm(&format!("{}_update", REQUIRED_PERM).into()) {
        return Ok(HttpResponse::Forbidden().finish());
    }
    // Resolve the target record first so unknown ids can 404 early.
    let rec = match user::get_from_any_id(&state.conn, id.into_inner())
        .await
        .map_err(|e| e.to_service_err())?
    {
        Some(v) => v,
        None => return Ok(HttpResponse::NotFound().into()),
    };
    user::update(&state.conn, rec.id, data.into_inner())
        .await
        .map_err(|e| e.to_service_err())?;
    Ok(HttpResponse::Ok().into())
}
/// DELETE /users/{id} — remove a user; requires `user_delete`.
pub async fn delete(
    bearer: BearerAuth,
    state: Data<AppState>,
    id: Path<String>,
) -> Result<HttpResponse, Error> {
    let authuser = user_from_bearer(&state, bearer).await?.value;
    if !authuser.has_perm(&format!("{}_delete", REQUIRED_PERM).into()) {
        return Ok(HttpResponse::Forbidden().finish());
    }
    let rec = user::get_from_any_id(&state.conn, id.into_inner())
        .await
        .map_err(|e| e.to_service_err())?;
    let rec = match rec {
        Some(v) => v,
        None => return Ok(HttpResponse::NotFound().into()),
    };
    // Capture before `rec.id` is moved into the delete call.
    let deleted_oidc_id = rec.value.user.oidc_id.clone();
    user::delete(&state.conn, rec.id)
        .await
        .map_err(|e| e.to_service_err())?;
    // Fix: invalidate the cached auth entry of the *deleted* user. The
    // original invalidated the acting user's entry, leaving the deleted
    // user's cached record alive until eviction. Best-effort.
    let _ = state
        .cache
        .invalidate(&auth::user::cache_key(&deleted_oidc_id))
        .await;
    Ok(HttpResponse::Ok().into())
}

116
src/main.rs Normal file
View file

@ -0,0 +1,116 @@
use std::env;
use actix_web::{
web::{self, Data},
App, HttpServer,
};
// use actix_web_httpauth::middleware::HttpAuthentication;
pub mod auth;
pub mod db;
pub mod errors;
pub mod handlers;
pub mod model;
pub mod util;
use framework::cache::MokaCacheSerde;
use surrealdb::{
engine::remote::ws::{Client, Ws},
opt::auth::Root,
Surreal,
};
/// Shared application state; cloned into each actix worker's app factory.
#[derive(Clone)]
pub struct AppState {
    // SurrealDB websocket connection.
    conn: Surreal<Client>,
    // In-memory cache storing JSON-serialized values; capacity is weighed in
    // bytes by the weigher configured in main().
    cache: MokaCacheSerde<String, framework::cache::JsonSerializer>,
}
#[actix_web::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Read configuration from .env (if present) and initialize logging.
    dotenvy::dotenv().ok();
    env_logger::init();
    let db_url = env::var("DATABASE_URL").expect("Environment variable DATABASE_URL is not set");
    let host = env::var("LISTEN_HOST").unwrap_or_else(|_| "0.0.0.0".to_string());
    let port = env::var("LISTEN_PORT").unwrap_or_else(|_| "8000".to_string());
    let server_url = format!("{host}:{port}");
    let db_conn = Surreal::new::<Ws>(db_url).await?;
    // Signin as a namespace, database, or root user.
    // Fix: credentials were hard-coded as "root"/"root"; they can now be
    // overridden via the environment, keeping the same defaults for dev.
    let db_user = env::var("DATABASE_USER").unwrap_or_else(|_| "root".to_string());
    let db_pass = env::var("DATABASE_PASS").unwrap_or_else(|_| "root".to_string());
    db_conn
        .signin(Root {
            username: &db_user,
            password: &db_pass,
        })
        .await?;
    // Select a specific namespace / database
    db_conn
        .use_ns(env::var("DATABASE_NAMESPACE").unwrap_or_else(|_| "demo".into()))
        .use_db(env::var("DATABASE_DB").unwrap_or_else(|_| "demo".into()))
        .await?;
    // Prime the JWKS key store so the first request doesn't pay for the fetch.
    auth::update_key_store(auth::OIDC_CERT_URL.clone()).await?;
    let state = AppState {
        conn: db_conn,
        cache: MokaCacheSerde::new(
            framework::cache::MokaCache::builder()
                // Weigh entries by serialized size so max_capacity is in bytes.
                .weigher(|_key, value: &Vec<u8>| -> u32 {
                    value.len().try_into().unwrap_or(u32::MAX)
                })
                .max_capacity(32 * 1024 * 1024)
                .build()
                .into(),
        ),
    };
    // create server and try to serve over socket if possible
    let mut listenfd = listenfd::ListenFd::from_env();
    let mut server = HttpServer::new(move || {
        //let auth = HttpAuthentication::bearer(auth::validator);
        App::new()
            .app_data(Data::new(state.clone()))
            .service(
                web::scope("api/public")
                    .route("ping", web::get().to(handlers::ping))
                    .route("echo", web::post().to(handlers::echo)),
            )
            .service(
                web::scope("api")
                    .route("register", web::post().to(handlers::register))
                    .route("user", web::get().to(handlers::user::get_current_user))
                    .service(
                        web::scope("users")
                            .route("", web::get().to(handlers::user::get))
                            // .route("", web::post() .to(handlers::user::add))
                            .route("{id}", web::get().to(handlers::user::get_by_id))
                            .route("{id}", web::put().to(handlers::user::update))
                            .route("{id}", web::delete().to(handlers::user::delete)),
                        // .route("{id}/roles", web::put() .to(handlers::user_role::set))
                        // .route("{id}/roles", web::post() .to(handlers::user_role::add))
                        // .route("{id}/roles/{role_id}", web::delete().to(handlers::user_role::delete))
                    ), // .service(
                       //     web::scope("roles")
                       //         .route("", web::get() .to(handlers::role::get))
                       //         // .route("", web::post() .to(handlers::role::add))
                       //         .route("{id}", web::get() .to(handlers::role::get_by_id))
                       //         .route("{id}", web::delete() .to(handlers::role::delete))
                       // )
            )
    });
    // Prefer an inherited listener (listenfd) when one was passed in the
    // environment; otherwise bind the configured address ourselves.
    server = match listenfd.take_tcp_listener(0)? {
        Some(listener) => server.listen_auto_h2c(listener)?,
        None => server.bind_auto_h2c(&server_url)?,
    };
    println!("Starting server at {server_url}");
    server.run().await?;
    Ok(())
}

5
src/model/mod.rs Normal file
View file

@ -0,0 +1,5 @@
pub mod user;
pub mod role;
pub use user::User;
pub use role::Role;

4
src/model/role.rs Normal file
View file

@ -0,0 +1,4 @@
pub struct Role {
pub name: String,
pub permissions: Vec<String>,
}

57
src/model/user.rs Normal file
View file

@ -0,0 +1,57 @@
// use framework::db::RecordId;
use serde::{Deserialize, Serialize};
use framework::auth::OAuthUser;
/// User payload without role links; used as the merge body for updates
/// (see `crate::db::user::update`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimpleUser {
    pub oidc_id: String,
    pub username: String,
    pub first_name: Option<String>,
    pub last_name: Option<String>,
    pub email: Option<String>,
    pub superuser: bool,
}
/// Full user model as stored in the `user` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
    /// Subject identifier from the OIDC provider.
    pub oidc_id: String,
    pub username: String,
    pub first_name: Option<String>,
    pub last_name: Option<String>,
    pub email: Option<String>,
    pub superuser: bool,
    // NOTE(review): the schema declares `array<record<role>>`, while queries
    // project `roles.name AS roles` — so here these are role *names*; confirm
    // round-tripping via plain SELECT * behaves as intended.
    pub roles: Vec<String>,
}
impl From<OAuthUser> for User {
fn from(value: OAuthUser) -> Self {
crate::model::User {
oidc_id: value.oidc_id,
username: value.username,
first_name: value.first_name,
last_name: value.last_name,
email: value.email,
superuser: value.superuser,
roles: Vec::new(),
}
}
}
impl User {
pub fn metadata_eq(&self, value: &User) -> bool {
self.oidc_id == value.oidc_id
&& self.username == value.username
&& self.first_name == value.first_name
&& self.last_name == value.last_name
&& self.email == value.email
}
pub fn metadata_from(&mut self, value: User) {
self.oidc_id = value.oidc_id;
self.username = value.username;
self.first_name = value.first_name;
self.last_name = value.last_name;
self.email = value.email;
}
}

26
src/util.rs Normal file
View file

@ -0,0 +1,26 @@
use std::time::Duration;
/// Interpret a string as a boolean flag.
///
/// Recognizes a fixed set of truthy ("true"/"True"/"TRUE"/"yes"/"1") and
/// falsy ("false"/"False"/"FALSE"/"no"/"0") spellings; anything else is `None`.
pub fn str_bool_eval(val: &str) -> Option<bool> {
    const TRUTHY: [&str; 5] = ["true", "True", "TRUE", "yes", "1"];
    const FALSY: [&str; 5] = ["false", "False", "FALSE", "no", "0"];
    if TRUTHY.contains(&val) {
        Some(true)
    } else if FALSY.contains(&val) {
        Some(false)
    } else {
        None
    }
}
/// Pick a human-friendly unit for `dur` and return (value, unit suffix).
///
/// The unit is the smallest one whose value stays below 10000, from
/// nanoseconds up to whole seconds (values are truncated, not rounded).
pub fn dur_to_num(dur: Duration) -> (u128, &'static str) {
    let nanos = dur.as_nanos();
    if nanos < 10_000 {
        (nanos, "ns")
    } else if nanos < 10_000_000 {
        // < 10000 µs once truncated to microseconds
        (dur.as_micros(), "µs")
    } else if nanos < 10_000_000_000 {
        // < 10000 ms once truncated to milliseconds
        (dur.as_millis(), "ms")
    } else {
        (u128::from(dur.as_secs()), "s")
    }
}
/// Render `dur` as "value unit" using the scaling chosen by `dur_to_num`.
pub fn dur_to_str(dur: Duration) -> String {
    let (value, unit) = dur_to_num(dur);
    format!("{} {}", value, unit)
}

4
tools/api.sh Executable file
View file

@ -0,0 +1,4 @@
#!/bin/sh
# Call the API with the saved bearer token (see tools/get_token.sh).
# All arguments are forwarded to curl (URL, method flags, data, ...).
curl --fail-with-body -H "Authorization: Bearer $(cat token.txt)" "$@"
# Trailing newline so the response doesn't run into the next shell prompt.
echo

8
tools/db-init.sh Executable file
View file

@ -0,0 +1,8 @@
#!/bin/sh
# Import a SurrealDB schema/data file into the given namespace and database.
# Usage: db-init.sh [file] [namespace] [db]   (defaults: structure.surql test test)
file=${1-structure.surql}
ns=${2-test}
db=${3-test}
# Fix: quote the expansions so filenames/namespaces containing spaces or
# glob characters are passed through intact.
surreal import --conn http://localhost:8888 --user root --pass root --ns "$ns" --db "$db" "$file"

1
tools/dump.sh Executable file
View file

@ -0,0 +1 @@
# Dump all databases from the compose `postgres` service (extra args go to
# pg_dumpall) and gzip the result into dump.sql.gz.
docker compose exec postgres pg_dumpall -U postgres "$@" | gzip > dump.sql.gz

10
tools/get_token.sh Executable file
View file

@ -0,0 +1,10 @@
# Obtain an OIDC access token from the local Keycloak via the password grant
# (defaults to admin/admin). Saves the raw response to token.json and the
# bearer token to token.txt for use by tools/api.sh.
# NOTE(review): client_secret is a hard-coded development credential — never
# reuse it outside a local environment.
curl -sf -L -X POST 'http://localhost:8080/realms/master/protocol/openid-connect/token' \
-H 'Content-Type: application/x-www-form-urlencoded' \
--data-urlencode "client_id=actix" \
--data-urlencode "grant_type=password" \
--data-urlencode "client_secret=JQI6gnEeHlaOX3yzZsiVwAAUIouTlR4j" \
--data-urlencode "scope=openid" \
--data-urlencode "username=${1-admin}" \
--data-urlencode "password=${2-admin}" > token.json
jq -r .access_token < token.json > token.txt

9
tools/refresh_token.sh Executable file
View file

@ -0,0 +1,9 @@
# Refresh the OIDC access token using the refresh_token stored in token.json
# (written by tools/get_token.sh); updates token.json and token.txt in place.
# NOTE(review): client_secret is a hard-coded development credential — never
# reuse it outside a local environment.
curl --no-progress-meter -f -L -X POST 'http://localhost:8080/realms/master/protocol/openid-connect/token' \
-H 'Content-Type: application/x-www-form-urlencoded' \
--data-urlencode "client_id=actix" \
--data-urlencode "grant_type=refresh_token" \
--data-urlencode "client_secret=JQI6gnEeHlaOX3yzZsiVwAAUIouTlR4j" \
--data-urlencode "refresh_token=$(jq -r .refresh_token token.json)" \
> token.json
jq -r .access_token < token.json > token.txt

9
tools/seaorm_gen.sh Executable file
View file

@ -0,0 +1,9 @@
#!/bin/sh
# Run pending SeaORM migrations, then regenerate entity definitions.
sea-orm-cli migrate
sea-orm-cli generate entity -l \
    --with-serde both \
    --serde-skip-deserializing-primary-key \
    -o entity/src/
# Re-add the hand-written links module. The grep guard makes the append
# idempotent: the original appended unconditionally, which duplicates the
# line on repeated runs whenever lib.rs is not regenerated.
grep -qx 'pub mod links;' entity/src/lib.rs || echo 'pub mod links;' >> entity/src/lib.rs