diff --git a/src/graphql/mod.rs b/src/graphql/mod.rs index 86655d5..4e4a9a0 100644 --- a/src/graphql/mod.rs +++ b/src/graphql/mod.rs @@ -2,11 +2,33 @@ use chrono::NaiveDateTime; use juniper::{graphql_object, EmptyMutation, EmptySubscription, GraphQLObject, RootNode}; use rocket::response::content::RawHtml; use rocket::State; +use rocket::request::{FromRequest, Outcome, Request}; +use std::env; +use once_cell::sync::Lazy; // use crate::account::account::{read_basic_auth_token, read_bearer_auth_token}; use crate::nnid::oauth::TokenData; use crate::Pool; +pub static API_KEY: Lazy = Lazy::new(|| { + env::var("GRAPHQL_API_KEY").expect("GRAPHQL_API_KEY not set") +}); +#[rocket::async_trait] +impl<'r> FromRequest<'r> for Context { + type Error = (); + + async fn from_request(req: &'r Request<'_>) -> Outcome { + let pool = req.rocket().state::().cloned().unwrap(); // assume Pool is managed as state + + // Grab API key from header + let api_key = req.headers().get_one("X-API-Key").map(|s| s.to_string()); + + Outcome::Success(Context { + pool, + api_key, + }) + } +} pub type Schema = RootNode< 'static, @@ -16,8 +38,12 @@ pub type Schema = RootNode< >; -pub struct Context(pub Pool); -impl juniper::Context for Context{} +pub struct Context { + pub pool: Pool, + pub api_key: Option, +} +impl juniper::Context for Context {} + #[derive(GraphQLObject)] #[graphql(description = "Data inside of a token")] @@ -27,6 +53,25 @@ struct TokenInfo { title_id: Option } +#[derive(GraphQLObject)] +#[graphql(description = "User information from a token")] +struct UserInfo { + username: String, + account_level: i32, + nex_password: String, + mii_data: String, +} + +#[derive(GraphQLObject)] +#[graphql(description = "User information from a username")] +pub struct UserInfoWithPId { + pub username: String, + pub account_level: i32, + pub nex_password: String, + pub mii_data: String, + pub pid: i32, +} + pub struct Query; #[graphql_object] @@ -47,7 +92,7 @@ impl Query { "select * from tokens where pid = $1 and token_id = $2 and random = $3", data.pid, data.token_id, data.random ). - fetch_one(&context.0).await.ok()?; + fetch_one(&context.pool).await.ok()?; Some(TokenInfo{ pid: data.pid, @@ -56,7 +101,84 @@ impl Query { }) } + async fn user_from_token( + token_data: String, + context: &Context, + ) -> Option { + let data = match TokenData::decode(&token_data) { + Some(data) => data, + None => { + eprintln!("Failed to decode token"); + return None; + } + }; + let user = match sqlx::query!( + "SELECT username, account_level, nex_password, mii_data FROM users WHERE pid = $1", + data.pid + ) + .fetch_one(&context.pool) + .await + .ok() { + Some(user) => user, + None => { + eprintln!("No user found for PID {}", data.pid); + return None; + } + }; + + Some(UserInfo { + username: user.username, + account_level: user.account_level, + nex_password: user.nex_password, + mii_data: user.mii_data.replace('\n', "").replace('\r', ""), + }) + } + + async fn user_by_pid(pid: i32, context: &Context) -> Option { + if context.api_key.as_deref() != Some(&*API_KEY) { + eprintln!("Rejected request: invalid API key"); + return None; + } + + let user = sqlx::query!( + "SELECT username, account_level, nex_password, mii_data FROM users WHERE pid = $1", + pid + ) + .fetch_one(&context.pool) + .await + .ok()?; + + Some(UserInfo { + username: user.username, + account_level: user.account_level, + nex_password: user.nex_password, + mii_data: user.mii_data, + }) + } + + async fn user_by_username(username: String, context: &Context) -> Option { + if context.api_key.as_deref() != Some(&*API_KEY) { + eprintln!("Rejected request: invalid API key"); + return None; + } + + let user = sqlx::query!( + "SELECT pid, username, account_level, nex_password, mii_data FROM users WHERE username = $1", + username, + ) + .fetch_one(&context.pool) + .await + .ok()?; + + Some(UserInfoWithPId { + username: user.username, + account_level: user.account_level, + nex_password: user.nex_password, + mii_data: user.mii_data, + pid: user.pid, + }) + } } @@ -76,31 +198,31 @@ impl Mutation { } */ -#[rocket::get("/graphiql")] -pub fn graphiql() -> RawHtml { - juniper_rocket::graphiql_source("/graphql", None) -} - - -#[rocket::get("/playground")] -pub fn playground() -> RawHtml { - juniper_rocket::playground_source("/graphql", None) -} +// #[rocket::get("/graphiql")] +// pub fn graphiql() -> RawHtml { +// juniper_rocket::graphiql_source("/graphql", None) +// } +// +// +// #[rocket::get("/playground")] +// pub fn playground() -> RawHtml { +// juniper_rocket::playground_source("/graphql", None) +// } #[rocket::get("/graphql?")] pub async fn get_graphql( - db: &State, request: juniper_rocket::GraphQLRequest, schema: &State, + context: Context ) -> juniper_rocket::GraphQLResponse { - request.execute(schema, db).await + request.execute(schema, &context).await } #[rocket::post("/graphql", data = "")] pub async fn post_graphql( - db: &State, request: juniper_rocket::GraphQLRequest, schema: &State, + context: Context ) -> juniper_rocket::GraphQLResponse { - request.execute(schema, db).await -} \ No newline at end of file + request.execute(schema, &context).await +} diff --git a/src/main.rs b/src/main.rs index b8737b0..25882b9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -119,6 +119,10 @@ async fn launch() -> _ { env::var("S3_PASSWD").expect("S3_PASSWD not specified").into_boxed_str() ); + pub static CDN_URL: Lazy> = Lazy::new(|| + env::var("CDN_URL").expect("CDN_URL not specified").into_boxed_str() + ); + let s3_client = ClientBuilder::new(S3_URL.clone()) .provider(Some(Box::new(StaticProvider::new(&S3_USER, &S3_PASSWD, None)))) .build() @@ -137,7 +141,6 @@ async fn launch() -> _ { .manage(S3ClientState { client: Arc::new(s3_client), }) - .manage(graphql::Context(graph_pool)) .manage(Schema::new( Query, EmptyMutation::new(), @@ -178,9 +181,8 @@ async fn launch() -> _ { nnid::mapped_ids::mapped_ids, papi::login::login, papi::user::get_user, - //graphql - graphql::graphiql, - graphql::playground, + // graphql::graphiql, + // graphql::playground, graphql::get_graphql, graphql::post_graphql, ]) diff --git a/src/nnid/people.rs b/src/nnid/people.rs index 0eaef7e..07ecf62 100644 --- a/src/nnid/people.rs +++ b/src/nnid/people.rs @@ -1,8 +1,5 @@ use std::env; -use std::fs; -use std::fs::File; use std::io::Write; -use std::path::Path; use chrono::{NaiveDate, NaiveDateTime}; use gxhash::{gxhash32, gxhash64}; use minio::s3::builders::{ObjectContent}; @@ -23,7 +20,6 @@ use crate::email::send_verification_email; use rand::Rng; use mii::{get_image_png, get_image_tga}; use minio::s3::client::Client; -use minio::s3::args::PutObjectArgs; use std::sync::Arc; const DATABASE_ERROR: Errors = Errors{ @@ -60,7 +56,7 @@ fn get_mii_img_url_path(pid: i32, format: &str) -> String{ } fn get_mii_img_url(pid: i32, format: &str) -> String{ - format!("{}/pn-boss/{}", &*S3_URL_STRING, get_mii_img_url_path(pid, format)) + format!("{}/{}/{}", &*S3_URL_STRING, &*S3_BUCKET, get_mii_img_url_path(pid, format)) } pub async fn generate_s3_images(pid: i32, mii_data: &str) { diff --git a/src/nnid/support.rs b/src/nnid/support.rs index d3e1a5d..28c749c 100644 --- a/src/nnid/support.rs +++ b/src/nnid/support.rs @@ -1,7 +1,6 @@ use rocket::{State, post, FromForm, put}; use crate::Pool; use rocket::form::Form; -use crate::email::send_verification_email; use crate::error::{Error, Errors}; use chrono::Utc; diff --git a/src/papi/login.rs b/src/papi/login.rs index 7e228e7..45a1df1 100644 --- a/src/papi/login.rs +++ b/src/papi/login.rs @@ -1,5 +1,4 @@ use rocket::{post, State}; -use rocket::form::Form; use serde::Deserialize; use serde::Serialize; use crate::Pool; diff --git a/src/papi/user.rs b/src/papi/user.rs index 770f786..ccfda69 100644 --- a/src/papi/user.rs +++ b/src/papi/user.rs @@ -1,7 +1,13 @@ +use std::env; +use once_cell::sync::Lazy; use rocket::{get}; use crate::account::account::{Auth}; use rocket::serde::json::Json; +pub static CDN_URL: Lazy> = Lazy::new(|| + env::var("CDN_URL").expect("CDN_URL not specified").into_boxed_str() +); + #[derive(serde::Serialize)] struct EmailInfo { address: String, @@ -89,7 +95,7 @@ pub async fn get_user(auth: Auth) -> Json { .map(|v| v.name) .unwrap_or_else(|| "INVALID".to_string()) }, - image_url: format!("https://cdn.spfn.cc/mii/{}/normal_face.png", user.pid), + image_url: format!("https://{}/mii/{}/normal_face.png", &CDN_URL.to_string(), user.pid), }, flags: FlagsInfo { marketing: user.marketing_allowed,