From c0fdc1445d1ecaa93eaa7ac39c99a1636ad14888 Mon Sep 17 00:00:00 2001 From: red binder Date: Mon, 27 Apr 2026 18:24:37 +0200 Subject: [PATCH] Redo cert request guard --- src/account/account.rs | 128 +++++++++++++++---------------- src/nnid/oauth/generate_token.rs | 6 +- 2 files changed, 68 insertions(+), 66 deletions(-) diff --git a/src/account/account.rs b/src/account/account.rs index 12c759a..8e81069 100644 --- a/src/account/account.rs +++ b/src/account/account.rs @@ -63,7 +63,7 @@ pub struct User { #[derive(sqlx::FromRow)] pub struct CertificateRecord { - pub _hash: Vec, + pub hash: Vec, pub banned: bool, } @@ -252,62 +252,24 @@ impl Into } } -pub async fn handle_certificate( +pub async fn link_certificate_to_pid( pool: &sqlx::PgPool, cert: &Certificate, pid: i32, ) -> Result<(), Errors<'static>> { let hash = cert.hash(); - let existing = sqlx::query_as::<_, CertificateRecord>( - "SELECT hash, banned FROM certificates WHERE hash = $1" + sqlx::query( + "INSERT INTO certificate_pids (cert_hash, pid) + VALUES ($1, $2) + ON CONFLICT DO NOTHING" ) .bind(&hash[..]) - .fetch_optional(pool) + .bind(pid) + .execute(pool) .await .map_err(|_| INVALID_TOKEN_ERRORS)?; - if let Some(cert_row) = existing { - if cert_row.banned { - return Err(INVALID_TOKEN_ERRORS); - } - - sqlx::query( - "INSERT INTO certificate_pids (cert_hash, pid) - VALUES ($1, $2) - ON CONFLICT DO NOTHING" - ) - .bind(&hash[..]) - .bind(pid) - .execute(pool) - .await - .map_err(|_| INVALID_TOKEN_ERRORS)?; - - } else { - let mut tx = pool.begin().await.map_err(|_| INVALID_TOKEN_ERRORS)?; - - sqlx::query( - "INSERT INTO certificates (hash, banned) - VALUES ($1, false)" - ) - .bind(&hash[..]) - .execute(&mut *tx) - .await - .map_err(|_| INVALID_TOKEN_ERRORS)?; - - sqlx::query( - "INSERT INTO certificate_pids (cert_hash, pid) - VALUES ($1, $2)" - ) - .bind(&hash[..]) - .bind(pid) - .execute(&mut *tx) - .await - .map_err(|_| INVALID_TOKEN_ERRORS)?; - - tx.commit().await.map_err(|_| INVALID_TOKEN_ERRORS)?; - } - Ok(()) } @@ -344,24 +306,6 @@ impl<'r, const FORCE_BEARER_AUTH: bool, const USE_CERT: bool> FromRequest<'r> return Outcome::Error((Status::BadRequest, INVALID_TOKEN_ERRORS)); } - if USE_CERT { - let cert_header = request_try!( - request - .headers() - .get("X-Nintendo-Device-Cert") - .next() - .ok_or(INVALID_TOKEN_ERRORS) - ); - - let Some(cert) = Certificate::new(&cert_header) else { - return Outcome::Error((Status::BadGateway, INVALID_TOKEN_ERRORS)); - }; - - if let Err(_) = handle_certificate(pool, &cert, user.pid).await { - return Outcome::Error((Status::BadRequest, INVALID_TOKEN_ERRORS)); - } - } - // let user = User{ // nex_password: format!("{:a>16}", user.nex_password), // ..user @@ -371,6 +315,62 @@ impl<'r, const FORCE_BEARER_AUTH: bool, const USE_CERT: bool> FromRequest<'r> } } +pub struct DeviceCert(pub Certificate); + +#[async_trait] +impl<'r> FromRequest<'r> for DeviceCert { + type Error = Errors<'static>; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let pool: &sqlx::PgPool = request.rocket().state().unwrap(); + + let cert_header = match request.headers().get("X-Nintendo-Device-Cert").next() { + Some(h) => h, + None => return Outcome::Error((Status::BadRequest, INVALID_TOKEN_ERRORS)), + }; + + let cert = match Certificate::new(cert_header) { + Some(c) => c, + None => return Outcome::Error((Status::BadGateway, INVALID_TOKEN_ERRORS)), + }; + + let hash = cert.hash(); + + let existing = sqlx::query_as::<_, CertificateRecord>( + "SELECT hash, banned FROM certificates WHERE hash = $1" + ) + .bind(&hash[..]) + .fetch_optional(pool) + .await; + // .map_err(|_| INVALID_TOKEN_ERRORS); + + let existing = match existing { + Ok(v) => v, + Err(e) => { + println!("certificate query failed: {:?}", e); + return Outcome::Error((Status::InternalServerError, INVALID_TOKEN_ERRORS)); + } + }; + + if let Some(row) = existing { + if row.banned { + return Outcome::Error((Status::Forbidden, INVALID_TOKEN_ERRORS)); + } + } else { + sqlx::query( + "INSERT INTO certificates (hash, banned) + VALUES ($1, false)" + ) + .bind(&hash[..]) + .execute(pool) + .await + .map_err(|_| INVALID_TOKEN_ERRORS); + } + + Outcome::Success(DeviceCert(cert)) + } +} + #[binread] #[br(big)] #[derive(Debug)] diff --git a/src/nnid/oauth/generate_token.rs b/src/nnid/oauth/generate_token.rs index f5de9af..26c8775 100644 --- a/src/nnid/oauth/generate_token.rs +++ b/src/nnid/oauth/generate_token.rs @@ -2,7 +2,7 @@ use rocket::{post, FromForm, State}; use rocket::form::Form; use serde::{Serialize}; -use crate::account::account::User; +use crate::account::account::{Auth, DeviceCert, User, link_certificate_to_pid}; use crate::error::{Error, Errors}; use crate::nnid::agreements::{CFIP, EVIL_AGREEMENT_THING}; use crate::nnid::oauth::generate_token::token_type::{AUTH_REFRESH_TOKEN, AUTH_TOKEN}; @@ -101,7 +101,7 @@ pub struct TokenRequestReturnData{ } #[post("/v1/api/oauth20/access_token/generate", data="")] -pub async fn generate_token(pool: &State, data: Form>, ip: CFIP) -> Result, Option>>{ +pub async fn generate_token(pool: &State, data: Form>, ip: CFIP, cert: DeviceCert) -> Result, Option>>{ let pool = pool.inner(); let user = User::get_by_username(data.user_id, pool).await @@ -123,6 +123,8 @@ pub async fn generate_token(pool: &State, data: Form> return Err(Some(ACCOUNT_BANNED_ERRORS)); } + link_certificate_to_pid(&pool, &cert.0, user.pid).await?; + let access_token = TokenReturnData::new(user.pid, pool).await; Ok(Xml(TokenRequestReturnData{