Redo cert request guard
This commit is contained in:
parent
9456013dc7
commit
c0fdc1445d
2 changed files with 68 additions and 66 deletions
|
|
@ -63,7 +63,7 @@ pub struct User {
|
|||
|
||||
#[derive(sqlx::FromRow)]
|
||||
pub struct CertificateRecord {
|
||||
pub _hash: Vec<u8>,
|
||||
pub hash: Vec<u8>,
|
||||
pub banned: bool,
|
||||
}
|
||||
|
||||
|
|
@ -252,26 +252,13 @@ impl<const FORCE_BEARER_AUTH: bool, const USE_CERT: bool> Into<User>
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn handle_certificate(
|
||||
pub async fn link_certificate_to_pid(
|
||||
pool: &sqlx::PgPool,
|
||||
cert: &Certificate,
|
||||
pid: i32,
|
||||
) -> Result<(), Errors<'static>> {
|
||||
let hash = cert.hash();
|
||||
|
||||
let existing = sqlx::query_as::<_, CertificateRecord>(
|
||||
"SELECT hash, banned FROM certificates WHERE hash = $1"
|
||||
)
|
||||
.bind(&hash[..])
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|_| INVALID_TOKEN_ERRORS)?;
|
||||
|
||||
if let Some(cert_row) = existing {
|
||||
if cert_row.banned {
|
||||
return Err(INVALID_TOKEN_ERRORS);
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO certificate_pids (cert_hash, pid)
|
||||
VALUES ($1, $2)
|
||||
|
|
@ -283,31 +270,6 @@ pub async fn handle_certificate(
|
|||
.await
|
||||
.map_err(|_| INVALID_TOKEN_ERRORS)?;
|
||||
|
||||
} else {
|
||||
let mut tx = pool.begin().await.map_err(|_| INVALID_TOKEN_ERRORS)?;
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO certificates (hash, banned)
|
||||
VALUES ($1, false)"
|
||||
)
|
||||
.bind(&hash[..])
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|_| INVALID_TOKEN_ERRORS)?;
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO certificate_pids (cert_hash, pid)
|
||||
VALUES ($1, $2)"
|
||||
)
|
||||
.bind(&hash[..])
|
||||
.bind(pid)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|_| INVALID_TOKEN_ERRORS)?;
|
||||
|
||||
tx.commit().await.map_err(|_| INVALID_TOKEN_ERRORS)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -344,24 +306,6 @@ impl<'r, const FORCE_BEARER_AUTH: bool, const USE_CERT: bool> FromRequest<'r>
|
|||
return Outcome::Error((Status::BadRequest, INVALID_TOKEN_ERRORS));
|
||||
}
|
||||
|
||||
if USE_CERT {
|
||||
let cert_header = request_try!(
|
||||
request
|
||||
.headers()
|
||||
.get("X-Nintendo-Device-Cert")
|
||||
.next()
|
||||
.ok_or(INVALID_TOKEN_ERRORS)
|
||||
);
|
||||
|
||||
let Some(cert) = Certificate::new(&cert_header) else {
|
||||
return Outcome::Error((Status::BadGateway, INVALID_TOKEN_ERRORS));
|
||||
};
|
||||
|
||||
if let Err(_) = handle_certificate(pool, &cert, user.pid).await {
|
||||
return Outcome::Error((Status::BadRequest, INVALID_TOKEN_ERRORS));
|
||||
}
|
||||
}
|
||||
|
||||
// let user = User{
|
||||
// nex_password: format!("{:a>16}", user.nex_password),
|
||||
// ..user
|
||||
|
|
@ -371,6 +315,62 @@ impl<'r, const FORCE_BEARER_AUTH: bool, const USE_CERT: bool> FromRequest<'r>
|
|||
}
|
||||
}
|
||||
|
||||
pub struct DeviceCert(pub Certificate);
|
||||
|
||||
#[async_trait]
|
||||
impl<'r> FromRequest<'r> for DeviceCert {
|
||||
type Error = Errors<'static>;
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||
let pool: &sqlx::PgPool = request.rocket().state().unwrap();
|
||||
|
||||
let cert_header = match request.headers().get("X-Nintendo-Device-Cert").next() {
|
||||
Some(h) => h,
|
||||
None => return Outcome::Error((Status::BadRequest, INVALID_TOKEN_ERRORS)),
|
||||
};
|
||||
|
||||
let cert = match Certificate::new(cert_header) {
|
||||
Some(c) => c,
|
||||
None => return Outcome::Error((Status::BadGateway, INVALID_TOKEN_ERRORS)),
|
||||
};
|
||||
|
||||
let hash = cert.hash();
|
||||
|
||||
let existing = sqlx::query_as::<_, CertificateRecord>(
|
||||
"SELECT hash, banned FROM certificates WHERE hash = $1"
|
||||
)
|
||||
.bind(&hash[..])
|
||||
.fetch_optional(pool)
|
||||
.await;
|
||||
// .map_err(|_| INVALID_TOKEN_ERRORS);
|
||||
|
||||
let existing = match existing {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
println!("certificate query failed: {:?}", e);
|
||||
return Outcome::Error((Status::InternalServerError, INVALID_TOKEN_ERRORS));
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(row) = existing {
|
||||
if row.banned {
|
||||
return Outcome::Error((Status::Forbidden, INVALID_TOKEN_ERRORS));
|
||||
}
|
||||
} else {
|
||||
sqlx::query(
|
||||
"INSERT INTO certificates (hash, banned)
|
||||
VALUES ($1, false)"
|
||||
)
|
||||
.bind(&hash[..])
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(|_| INVALID_TOKEN_ERRORS);
|
||||
}
|
||||
|
||||
Outcome::Success(DeviceCert(cert))
|
||||
}
|
||||
}
|
||||
|
||||
#[binread]
|
||||
#[br(big)]
|
||||
#[derive(Debug)]
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
use rocket::{post, FromForm, State};
|
||||
use rocket::form::Form;
|
||||
use serde::{Serialize};
|
||||
use crate::account::account::User;
|
||||
use crate::account::account::{Auth, DeviceCert, User, link_certificate_to_pid};
|
||||
use crate::error::{Error, Errors};
|
||||
use crate::nnid::agreements::{CFIP, EVIL_AGREEMENT_THING};
|
||||
use crate::nnid::oauth::generate_token::token_type::{AUTH_REFRESH_TOKEN, AUTH_TOKEN};
|
||||
|
|
@ -101,7 +101,7 @@ pub struct TokenRequestReturnData{
|
|||
}
|
||||
|
||||
#[post("/v1/api/oauth20/access_token/generate", data="<data>")]
|
||||
pub async fn generate_token(pool: &State<Pool>, data: Form<TokenRequestData<'_>>, ip: CFIP) -> Result<Xml<TokenRequestReturnData>, Option<Errors<'static>>>{
|
||||
pub async fn generate_token(pool: &State<Pool>, data: Form<TokenRequestData<'_>>, ip: CFIP, cert: DeviceCert) -> Result<Xml<TokenRequestReturnData>, Option<Errors<'static>>>{
|
||||
let pool = pool.inner();
|
||||
|
||||
let user = User::get_by_username(data.user_id, pool).await
|
||||
|
|
@ -123,6 +123,8 @@ pub async fn generate_token(pool: &State<Pool>, data: Form<TokenRequestData<'_>>
|
|||
return Err(Some(ACCOUNT_BANNED_ERRORS));
|
||||
}
|
||||
|
||||
link_certificate_to_pid(&pool, &cert.0, user.pid).await?;
|
||||
|
||||
let access_token = TokenReturnData::new(user.pid, pool).await;
|
||||
|
||||
Ok(Xml(TokenRequestReturnData{
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue