Redo cert request guard
This commit is contained in:
parent
9456013dc7
commit
c0fdc1445d
2 changed files with 68 additions and 66 deletions
|
|
@ -63,7 +63,7 @@ pub struct User {
|
||||||
|
|
||||||
#[derive(sqlx::FromRow)]
|
#[derive(sqlx::FromRow)]
|
||||||
pub struct CertificateRecord {
|
pub struct CertificateRecord {
|
||||||
pub _hash: Vec<u8>,
|
pub hash: Vec<u8>,
|
||||||
pub banned: bool,
|
pub banned: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -252,62 +252,24 @@ impl<const FORCE_BEARER_AUTH: bool, const USE_CERT: bool> Into<User>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_certificate(
|
pub async fn link_certificate_to_pid(
|
||||||
pool: &sqlx::PgPool,
|
pool: &sqlx::PgPool,
|
||||||
cert: &Certificate,
|
cert: &Certificate,
|
||||||
pid: i32,
|
pid: i32,
|
||||||
) -> Result<(), Errors<'static>> {
|
) -> Result<(), Errors<'static>> {
|
||||||
let hash = cert.hash();
|
let hash = cert.hash();
|
||||||
|
|
||||||
let existing = sqlx::query_as::<_, CertificateRecord>(
|
sqlx::query(
|
||||||
"SELECT hash, banned FROM certificates WHERE hash = $1"
|
"INSERT INTO certificate_pids (cert_hash, pid)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
ON CONFLICT DO NOTHING"
|
||||||
)
|
)
|
||||||
.bind(&hash[..])
|
.bind(&hash[..])
|
||||||
.fetch_optional(pool)
|
.bind(pid)
|
||||||
|
.execute(pool)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| INVALID_TOKEN_ERRORS)?;
|
.map_err(|_| INVALID_TOKEN_ERRORS)?;
|
||||||
|
|
||||||
if let Some(cert_row) = existing {
|
|
||||||
if cert_row.banned {
|
|
||||||
return Err(INVALID_TOKEN_ERRORS);
|
|
||||||
}
|
|
||||||
|
|
||||||
sqlx::query(
|
|
||||||
"INSERT INTO certificate_pids (cert_hash, pid)
|
|
||||||
VALUES ($1, $2)
|
|
||||||
ON CONFLICT DO NOTHING"
|
|
||||||
)
|
|
||||||
.bind(&hash[..])
|
|
||||||
.bind(pid)
|
|
||||||
.execute(pool)
|
|
||||||
.await
|
|
||||||
.map_err(|_| INVALID_TOKEN_ERRORS)?;
|
|
||||||
|
|
||||||
} else {
|
|
||||||
let mut tx = pool.begin().await.map_err(|_| INVALID_TOKEN_ERRORS)?;
|
|
||||||
|
|
||||||
sqlx::query(
|
|
||||||
"INSERT INTO certificates (hash, banned)
|
|
||||||
VALUES ($1, false)"
|
|
||||||
)
|
|
||||||
.bind(&hash[..])
|
|
||||||
.execute(&mut *tx)
|
|
||||||
.await
|
|
||||||
.map_err(|_| INVALID_TOKEN_ERRORS)?;
|
|
||||||
|
|
||||||
sqlx::query(
|
|
||||||
"INSERT INTO certificate_pids (cert_hash, pid)
|
|
||||||
VALUES ($1, $2)"
|
|
||||||
)
|
|
||||||
.bind(&hash[..])
|
|
||||||
.bind(pid)
|
|
||||||
.execute(&mut *tx)
|
|
||||||
.await
|
|
||||||
.map_err(|_| INVALID_TOKEN_ERRORS)?;
|
|
||||||
|
|
||||||
tx.commit().await.map_err(|_| INVALID_TOKEN_ERRORS)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -344,24 +306,6 @@ impl<'r, const FORCE_BEARER_AUTH: bool, const USE_CERT: bool> FromRequest<'r>
|
||||||
return Outcome::Error((Status::BadRequest, INVALID_TOKEN_ERRORS));
|
return Outcome::Error((Status::BadRequest, INVALID_TOKEN_ERRORS));
|
||||||
}
|
}
|
||||||
|
|
||||||
if USE_CERT {
|
|
||||||
let cert_header = request_try!(
|
|
||||||
request
|
|
||||||
.headers()
|
|
||||||
.get("X-Nintendo-Device-Cert")
|
|
||||||
.next()
|
|
||||||
.ok_or(INVALID_TOKEN_ERRORS)
|
|
||||||
);
|
|
||||||
|
|
||||||
let Some(cert) = Certificate::new(&cert_header) else {
|
|
||||||
return Outcome::Error((Status::BadGateway, INVALID_TOKEN_ERRORS));
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Err(_) = handle_certificate(pool, &cert, user.pid).await {
|
|
||||||
return Outcome::Error((Status::BadRequest, INVALID_TOKEN_ERRORS));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// let user = User{
|
// let user = User{
|
||||||
// nex_password: format!("{:a>16}", user.nex_password),
|
// nex_password: format!("{:a>16}", user.nex_password),
|
||||||
// ..user
|
// ..user
|
||||||
|
|
@ -371,6 +315,62 @@ impl<'r, const FORCE_BEARER_AUTH: bool, const USE_CERT: bool> FromRequest<'r>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct DeviceCert(pub Certificate);
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for DeviceCert {
|
||||||
|
type Error = Errors<'static>;
|
||||||
|
|
||||||
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
|
let pool: &sqlx::PgPool = request.rocket().state().unwrap();
|
||||||
|
|
||||||
|
let cert_header = match request.headers().get("X-Nintendo-Device-Cert").next() {
|
||||||
|
Some(h) => h,
|
||||||
|
None => return Outcome::Error((Status::BadRequest, INVALID_TOKEN_ERRORS)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let cert = match Certificate::new(cert_header) {
|
||||||
|
Some(c) => c,
|
||||||
|
None => return Outcome::Error((Status::BadGateway, INVALID_TOKEN_ERRORS)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let hash = cert.hash();
|
||||||
|
|
||||||
|
let existing = sqlx::query_as::<_, CertificateRecord>(
|
||||||
|
"SELECT hash, banned FROM certificates WHERE hash = $1"
|
||||||
|
)
|
||||||
|
.bind(&hash[..])
|
||||||
|
.fetch_optional(pool)
|
||||||
|
.await;
|
||||||
|
// .map_err(|_| INVALID_TOKEN_ERRORS);
|
||||||
|
|
||||||
|
let existing = match existing {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => {
|
||||||
|
println!("certificate query failed: {:?}", e);
|
||||||
|
return Outcome::Error((Status::InternalServerError, INVALID_TOKEN_ERRORS));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(row) = existing {
|
||||||
|
if row.banned {
|
||||||
|
return Outcome::Error((Status::Forbidden, INVALID_TOKEN_ERRORS));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO certificates (hash, banned)
|
||||||
|
VALUES ($1, false)"
|
||||||
|
)
|
||||||
|
.bind(&hash[..])
|
||||||
|
.execute(pool)
|
||||||
|
.await
|
||||||
|
.map_err(|_| INVALID_TOKEN_ERRORS);
|
||||||
|
}
|
||||||
|
|
||||||
|
Outcome::Success(DeviceCert(cert))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[binread]
|
#[binread]
|
||||||
#[br(big)]
|
#[br(big)]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
use rocket::{post, FromForm, State};
|
use rocket::{post, FromForm, State};
|
||||||
use rocket::form::Form;
|
use rocket::form::Form;
|
||||||
use serde::{Serialize};
|
use serde::{Serialize};
|
||||||
use crate::account::account::User;
|
use crate::account::account::{Auth, DeviceCert, User, link_certificate_to_pid};
|
||||||
use crate::error::{Error, Errors};
|
use crate::error::{Error, Errors};
|
||||||
use crate::nnid::agreements::{CFIP, EVIL_AGREEMENT_THING};
|
use crate::nnid::agreements::{CFIP, EVIL_AGREEMENT_THING};
|
||||||
use crate::nnid::oauth::generate_token::token_type::{AUTH_REFRESH_TOKEN, AUTH_TOKEN};
|
use crate::nnid::oauth::generate_token::token_type::{AUTH_REFRESH_TOKEN, AUTH_TOKEN};
|
||||||
|
|
@ -101,7 +101,7 @@ pub struct TokenRequestReturnData{
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/v1/api/oauth20/access_token/generate", data="<data>")]
|
#[post("/v1/api/oauth20/access_token/generate", data="<data>")]
|
||||||
pub async fn generate_token(pool: &State<Pool>, data: Form<TokenRequestData<'_>>, ip: CFIP) -> Result<Xml<TokenRequestReturnData>, Option<Errors<'static>>>{
|
pub async fn generate_token(pool: &State<Pool>, data: Form<TokenRequestData<'_>>, ip: CFIP, cert: DeviceCert) -> Result<Xml<TokenRequestReturnData>, Option<Errors<'static>>>{
|
||||||
let pool = pool.inner();
|
let pool = pool.inner();
|
||||||
|
|
||||||
let user = User::get_by_username(data.user_id, pool).await
|
let user = User::get_by_username(data.user_id, pool).await
|
||||||
|
|
@ -123,6 +123,8 @@ pub async fn generate_token(pool: &State<Pool>, data: Form<TokenRequestData<'_>>
|
||||||
return Err(Some(ACCOUNT_BANNED_ERRORS));
|
return Err(Some(ACCOUNT_BANNED_ERRORS));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
link_certificate_to_pid(&pool, &cert.0, user.pid).await?;
|
||||||
|
|
||||||
let access_token = TokenReturnData::new(user.pid, pool).await;
|
let access_token = TokenReturnData::new(user.pid, pool).await;
|
||||||
|
|
||||||
Ok(Xml(TokenRequestReturnData{
|
Ok(Xml(TokenRequestReturnData{
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue