Redo cert request guard

This commit is contained in:
red binder 2026-04-27 18:24:37 +02:00
commit c0fdc1445d
2 changed files with 68 additions and 66 deletions

View file

@ -63,7 +63,7 @@ pub struct User {
#[derive(sqlx::FromRow)] #[derive(sqlx::FromRow)]
pub struct CertificateRecord { pub struct CertificateRecord {
pub _hash: Vec<u8>, pub hash: Vec<u8>,
pub banned: bool, pub banned: bool,
} }
@ -252,62 +252,24 @@ impl<const FORCE_BEARER_AUTH: bool, const USE_CERT: bool> Into<User>
} }
} }
pub async fn handle_certificate( pub async fn link_certificate_to_pid(
pool: &sqlx::PgPool, pool: &sqlx::PgPool,
cert: &Certificate, cert: &Certificate,
pid: i32, pid: i32,
) -> Result<(), Errors<'static>> { ) -> Result<(), Errors<'static>> {
let hash = cert.hash(); let hash = cert.hash();
let existing = sqlx::query_as::<_, CertificateRecord>( sqlx::query(
"SELECT hash, banned FROM certificates WHERE hash = $1" "INSERT INTO certificate_pids (cert_hash, pid)
VALUES ($1, $2)
ON CONFLICT DO NOTHING"
) )
.bind(&hash[..]) .bind(&hash[..])
.fetch_optional(pool) .bind(pid)
.execute(pool)
.await .await
.map_err(|_| INVALID_TOKEN_ERRORS)?; .map_err(|_| INVALID_TOKEN_ERRORS)?;
if let Some(cert_row) = existing {
if cert_row.banned {
return Err(INVALID_TOKEN_ERRORS);
}
sqlx::query(
"INSERT INTO certificate_pids (cert_hash, pid)
VALUES ($1, $2)
ON CONFLICT DO NOTHING"
)
.bind(&hash[..])
.bind(pid)
.execute(pool)
.await
.map_err(|_| INVALID_TOKEN_ERRORS)?;
} else {
let mut tx = pool.begin().await.map_err(|_| INVALID_TOKEN_ERRORS)?;
sqlx::query(
"INSERT INTO certificates (hash, banned)
VALUES ($1, false)"
)
.bind(&hash[..])
.execute(&mut *tx)
.await
.map_err(|_| INVALID_TOKEN_ERRORS)?;
sqlx::query(
"INSERT INTO certificate_pids (cert_hash, pid)
VALUES ($1, $2)"
)
.bind(&hash[..])
.bind(pid)
.execute(&mut *tx)
.await
.map_err(|_| INVALID_TOKEN_ERRORS)?;
tx.commit().await.map_err(|_| INVALID_TOKEN_ERRORS)?;
}
Ok(()) Ok(())
} }
@ -344,24 +306,6 @@ impl<'r, const FORCE_BEARER_AUTH: bool, const USE_CERT: bool> FromRequest<'r>
return Outcome::Error((Status::BadRequest, INVALID_TOKEN_ERRORS)); return Outcome::Error((Status::BadRequest, INVALID_TOKEN_ERRORS));
} }
if USE_CERT {
let cert_header = request_try!(
request
.headers()
.get("X-Nintendo-Device-Cert")
.next()
.ok_or(INVALID_TOKEN_ERRORS)
);
let Some(cert) = Certificate::new(&cert_header) else {
return Outcome::Error((Status::BadGateway, INVALID_TOKEN_ERRORS));
};
if let Err(_) = handle_certificate(pool, &cert, user.pid).await {
return Outcome::Error((Status::BadRequest, INVALID_TOKEN_ERRORS));
}
}
// let user = User{ // let user = User{
// nex_password: format!("{:a>16}", user.nex_password), // nex_password: format!("{:a>16}", user.nex_password),
// ..user // ..user
@ -371,6 +315,62 @@ impl<'r, const FORCE_BEARER_AUTH: bool, const USE_CERT: bool> FromRequest<'r>
} }
} }
pub struct DeviceCert(pub Certificate);
#[async_trait]
impl<'r> FromRequest<'r> for DeviceCert {
type Error = Errors<'static>;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let pool: &sqlx::PgPool = request.rocket().state().unwrap();
let cert_header = match request.headers().get("X-Nintendo-Device-Cert").next() {
Some(h) => h,
None => return Outcome::Error((Status::BadRequest, INVALID_TOKEN_ERRORS)),
};
let cert = match Certificate::new(cert_header) {
Some(c) => c,
None => return Outcome::Error((Status::BadGateway, INVALID_TOKEN_ERRORS)),
};
let hash = cert.hash();
let existing = sqlx::query_as::<_, CertificateRecord>(
"SELECT hash, banned FROM certificates WHERE hash = $1"
)
.bind(&hash[..])
.fetch_optional(pool)
.await;
// .map_err(|_| INVALID_TOKEN_ERRORS);
let existing = match existing {
Ok(v) => v,
Err(e) => {
println!("certificate query failed: {:?}", e);
return Outcome::Error((Status::InternalServerError, INVALID_TOKEN_ERRORS));
}
};
if let Some(row) = existing {
if row.banned {
return Outcome::Error((Status::Forbidden, INVALID_TOKEN_ERRORS));
}
} else {
sqlx::query(
"INSERT INTO certificates (hash, banned)
VALUES ($1, false)"
)
.bind(&hash[..])
.execute(pool)
.await
.map_err(|_| INVALID_TOKEN_ERRORS);
}
Outcome::Success(DeviceCert(cert))
}
}
#[binread] #[binread]
#[br(big)] #[br(big)]
#[derive(Debug)] #[derive(Debug)]

View file

@ -2,7 +2,7 @@
use rocket::{post, FromForm, State}; use rocket::{post, FromForm, State};
use rocket::form::Form; use rocket::form::Form;
use serde::{Serialize}; use serde::{Serialize};
use crate::account::account::User; use crate::account::account::{Auth, DeviceCert, User, link_certificate_to_pid};
use crate::error::{Error, Errors}; use crate::error::{Error, Errors};
use crate::nnid::agreements::{CFIP, EVIL_AGREEMENT_THING}; use crate::nnid::agreements::{CFIP, EVIL_AGREEMENT_THING};
use crate::nnid::oauth::generate_token::token_type::{AUTH_REFRESH_TOKEN, AUTH_TOKEN}; use crate::nnid::oauth::generate_token::token_type::{AUTH_REFRESH_TOKEN, AUTH_TOKEN};
@ -101,7 +101,7 @@ pub struct TokenRequestReturnData{
} }
#[post("/v1/api/oauth20/access_token/generate", data="<data>")] #[post("/v1/api/oauth20/access_token/generate", data="<data>")]
pub async fn generate_token(pool: &State<Pool>, data: Form<TokenRequestData<'_>>, ip: CFIP) -> Result<Xml<TokenRequestReturnData>, Option<Errors<'static>>>{ pub async fn generate_token(pool: &State<Pool>, data: Form<TokenRequestData<'_>>, ip: CFIP, cert: DeviceCert) -> Result<Xml<TokenRequestReturnData>, Option<Errors<'static>>>{
let pool = pool.inner(); let pool = pool.inner();
let user = User::get_by_username(data.user_id, pool).await let user = User::get_by_username(data.user_id, pool).await
@ -123,6 +123,8 @@ pub async fn generate_token(pool: &State<Pool>, data: Form<TokenRequestData<'_>>
return Err(Some(ACCOUNT_BANNED_ERRORS)); return Err(Some(ACCOUNT_BANNED_ERRORS));
} }
link_certificate_to_pid(&pool, &cert.0, user.pid).await?;
let access_token = TokenReturnData::new(user.pid, pool).await; let access_token = TokenReturnData::new(user.pid, pool).await;
Ok(Xml(TokenRequestReturnData{ Ok(Xml(TokenRequestReturnData{