From c72d48cf094bfad50ba81a9af156da853993758e Mon Sep 17 00:00:00 2001 From: efer-ms <112443284+efer-ms@users.noreply.github.com> Date: Fri, 10 Jan 2025 11:33:18 -0800 Subject: [PATCH] Add EC-DSA, custom certificate common names for DTLS --- src/crypto/dtls.rs | 36 ++++- src/crypto/mod.rs | 4 +- src/crypto/ossl/cert.rs | 27 +++- src/crypto/wincrypto/cert.rs | 16 +- src/crypto/wincrypto/mod.rs | 3 +- src/dtls.rs | 3 +- src/lib.rs | 93 +++++++----- wincrypto/Cargo.lock | 4 +- wincrypto/Cargo.toml | 2 +- wincrypto/src/cert.rs | 285 +++++++++++++++++++++++++++++------ wincrypto/src/dtls.rs | 48 +++--- wincrypto/src/sha1.rs | 20 +-- wincrypto/src/srtp.rs | 47 +++--- 13 files changed, 416 insertions(+), 172 deletions(-) diff --git a/src/crypto/dtls.rs b/src/crypto/dtls.rs index a1e8920a..8b05003b 100644 --- a/src/crypto/dtls.rs +++ b/src/crypto/dtls.rs @@ -15,7 +15,7 @@ use super::{CryptoError, Fingerprint, KeyingMaterial, SrtpProfile}; // // Pion also sets this to "WebRTC", maybe for compatibility reasons. // https://github.com/pion/webrtc/blob/eed2bb2d3b9f204f9de1cd7e1046ca5d652778d2/constants.go#L31 -pub const DTLS_CERT_IDENTITY: &str = "WebRTC"; +const DTLS_CERT_IDENTITY: &str = "WebRTC"; /// Events arising from a [`Dtls`] instance. pub enum DtlsEvent { @@ -34,6 +34,34 @@ pub enum DtlsEvent { Data(Vec), } +/// Defines the type of key pair to generate for the DTLS certificate. +#[derive(Clone, Debug, Default)] +pub enum DtlsPKeyType { + /// Generate an RSA key pair + Rsa2048, + /// Generate an EC-DSA key pair using the NIST P-256 curve + #[default] + EcDsaP256, +} + +/// Controls certificate generation options. +#[derive(Clone, Debug)] +pub struct DtlsCertOptions { + /// The common name for the certificate. + pub common_name: String, + /// The type of key to generate. + pub pkey_type: DtlsPKeyType, +} + +impl Default for DtlsCertOptions { + fn default() -> Self { + Self { + common_name: DTLS_CERT_IDENTITY.into(), + pkey_type: Default::default(), + } + } +} + /// Certificate used for DTLS. #[derive(Clone)] pub struct DtlsCert(DtlsCertInner); @@ -59,12 +87,12 @@ impl DtlsCert { /// /// * **openssl** (defaults to on) for crypto backed by OpenSSL. /// * **wincrypto** for crypto backed by windows crypto. - pub fn new(p: CryptoProvider) -> Self { + pub fn new(p: CryptoProvider, opts: DtlsCertOptions) -> Self { let inner = match p { CryptoProvider::OpenSsl => { #[cfg(feature = "openssl")] { - let cert = super::ossl::OsslDtlsCert::new(); + let cert = super::ossl::OsslDtlsCert::new(opts); DtlsCertInner::OpenSsl(cert) } #[cfg(not(feature = "openssl"))] @@ -75,7 +103,7 @@ impl DtlsCert { CryptoProvider::WinCrypto => { #[cfg(all(feature = "wincrypto", target_os = "windows"))] { - let cert = super::wincrypto::WinCryptoDtlsCert::new(); + let cert = super::wincrypto::WinCryptoDtlsCert::new(opts); DtlsCertInner::WinCrypto(cert) } #[cfg(not(all(feature = "wincrypto", target_os = "windows")))] diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index cbc4ea8e..c3ba1bce 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -76,8 +76,8 @@ mod ossl; mod wincrypto; mod dtls; -pub use dtls::DtlsCert; -pub(crate) use dtls::{DtlsEvent, DtlsImpl}; +pub(crate) use dtls::DtlsImpl; +pub use dtls::{DtlsCert, DtlsCertOptions, DtlsEvent, DtlsPKeyType}; mod finger; pub use finger::Fingerprint; diff --git a/src/crypto/ossl/cert.rs b/src/crypto/ossl/cert.rs index 6c5bafc0..7a5a536c 100644 --- a/src/crypto/ossl/cert.rs +++ b/src/crypto/ossl/cert.rs @@ -2,13 +2,14 @@ use std::time::SystemTime; use openssl::asn1::{Asn1Integer, Asn1Time, Asn1Type}; use openssl::bn::BigNum; +use openssl::ec::{EcGroup, EcKey}; use openssl::hash::MessageDigest; use openssl::nid::Nid; use openssl::pkey::{PKey, Private}; use openssl::rsa::Rsa; use openssl::x509::{X509Name, X509}; -use crate::crypto::dtls::DTLS_CERT_IDENTITY; +use crate::crypto::dtls::{DtlsCertOptions, DtlsPKeyType}; use crate::crypto::Fingerprint; use super::CryptoError; @@ -25,16 +26,26 @@ pub struct OsslDtlsCert { impl OsslDtlsCert { /// Creates a new (self signed) DTLS certificate. - pub fn new() -> Self { - Self::self_signed().expect("create dtls cert") + pub fn new(options: DtlsCertOptions) -> Self { + Self::self_signed(options).expect("create dtls cert") } // The libWebRTC code we try to match is at: // https://webrtc.googlesource.com/src/+/1568f1b1330f94494197696fe235094e6293b258/rtc_base/openssl_certificate.cc#58 - fn self_signed() -> Result { + fn self_signed(options: DtlsCertOptions) -> Result { let f4 = BigNum::from_u32(RSA_F4).unwrap(); - let key = Rsa::generate_with_e(2048, &f4)?; - let pkey = PKey::from_rsa(key)?; + let pkey = match options.pkey_type { + DtlsPKeyType::Rsa2048 => { + let key = Rsa::generate_with_e(2048, &f4)?; + PKey::from_rsa(key)? + } + DtlsPKeyType::EcDsaP256 => { + let nid = Nid::X9_62_PRIME256V1; // NIST P-256 curve + let group = EcGroup::from_curve_name(nid)?; + let key = EcKey::generate(&group)?; + PKey::from_ec_key(key)? + } + }; let mut x509b = X509::builder()?; x509b.set_version(2)?; // X509.V3 (zero indexed) @@ -64,7 +75,7 @@ impl OsslDtlsCert { let mut nameb = X509Name::builder()?; nameb.append_entry_by_nid_with_type( Nid::COMMONNAME, - DTLS_CERT_IDENTITY, + options.common_name.as_str(), Asn1Type::UTF8STRING, )?; @@ -73,7 +84,7 @@ impl OsslDtlsCert { x509b.set_subject_name(&name)?; x509b.set_issuer_name(&name)?; - x509b.sign(&pkey, MessageDigest::sha1())?; + x509b.sign(&pkey, MessageDigest::sha256())?; let x509 = x509b.build(); Ok(OsslDtlsCert { pkey, x509 }) diff --git a/src/crypto/wincrypto/cert.rs b/src/crypto/wincrypto/cert.rs index 10b8e0f2..3d596917 100644 --- a/src/crypto/wincrypto/cert.rs +++ b/src/crypto/wincrypto/cert.rs @@ -1,6 +1,6 @@ use super::CryptoError; use super::WinCryptoDtls; -use crate::crypto::dtls::DTLS_CERT_IDENTITY; +use crate::crypto::dtls::{DtlsCertOptions, DtlsPKeyType}; use crate::crypto::Fingerprint; use std::sync::Arc; use str0m_wincrypto::WinCryptoError; @@ -11,10 +11,18 @@ pub struct WinCryptoDtlsCert { } impl WinCryptoDtlsCert { - pub fn new() -> Self { + pub fn new(options: DtlsCertOptions) -> Self { + let use_ec_dsa_keys = match options.pkey_type { + DtlsPKeyType::Rsa2048 => false, + DtlsPKeyType::EcDsaP256 => true, + }; + let certificate = Arc::new( - str0m_wincrypto::Certificate::new_self_signed(&format!("CN={}", DTLS_CERT_IDENTITY)) - .expect("Failed to create self-signed certificate"), + str0m_wincrypto::Certificate::new_self_signed( + use_ec_dsa_keys, + &format!("CN={}", options.common_name), + ) + .expect("Failed to create self-signed certificate"), ); Self { certificate } } diff --git a/src/crypto/wincrypto/mod.rs b/src/crypto/wincrypto/mod.rs index 40dc2cff..3c54a1b6 100644 --- a/src/crypto/wincrypto/mod.rs +++ b/src/crypto/wincrypto/mod.rs @@ -11,8 +11,9 @@ pub use dtls::WinCryptoDtls; mod srtp; pub use srtp::WinCryptoSrtpCryptoImpl; +#[cfg(not(feature = "sha1"))] mod sha1; -#[allow(unused_imports)] // If 'sha1' feature is enabled this is not used. +#[cfg(not(feature = "sha1"))] pub use sha1::sha1_hmac; pub use str0m_wincrypto::WinCryptoError; diff --git a/src/dtls.rs b/src/dtls.rs index 1c3874d2..0954abb4 100644 --- a/src/dtls.rs +++ b/src/dtls.rs @@ -5,8 +5,7 @@ use thiserror::Error; use crate::crypto::{CryptoError, DtlsImpl, Fingerprint}; -pub use crate::crypto::DtlsCert; -pub(crate) use crate::crypto::DtlsEvent; +pub use crate::crypto::{DtlsCert, DtlsCertOptions, DtlsEvent}; use crate::net::DatagramSend; /// Errors that can arise in DTLS. diff --git a/src/lib.rs b/src/lib.rs index 483baf61..3c9f33c9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -626,8 +626,7 @@ use crypto::CryptoProvider; use crypto::Fingerprint; mod dtls; -use dtls::DtlsCert; -use dtls::{Dtls, DtlsEvent}; +use dtls::{Dtls, DtlsCert, DtlsCertOptions, DtlsEvent}; #[path = "ice/mod.rs"] mod ice_; @@ -637,7 +636,7 @@ pub use ice_::{Candidate, CandidateKind, IceConnectionState, IceCreds}; /// Additional configuration. pub mod config { - pub use super::crypto::{CryptoProvider, DtlsCert, Fingerprint}; + pub use super::crypto::{CryptoProvider, DtlsCert, DtlsCertOptions, DtlsPKeyType, Fingerprint}; } /// Low level ICE access. @@ -1141,10 +1140,9 @@ impl Rtc { ice.set_ice_lite(config.ice_lite); } - let dtls_cert = if let Some(c) = config.dtls_cert { - c - } else { - DtlsCert::new(config.crypto_provider) + let dtls_cert = match config.dtls_cert_config { + DtlsCertConfig::Options(options) => DtlsCert::new(config.crypto_provider, options), + DtlsCertConfig::PregeneratedCert(cert) => cert, }; let crypto_provider = dtls_cert.crypto_provider(); @@ -1854,6 +1852,25 @@ impl Rtc { } } +/// Configuation for the DTLS certificate used for the Rtc instance. This can be set to +/// allow a pregenerated certificate, or options to pass when generating a certificate +/// on-the-fly. +/// +/// The default value is DtlsCertConfig::Options(DtlsCertOptions::default()) +#[derive(Clone, Debug)] +pub enum DtlsCertConfig { + /// The options to use for the DTLS certificate generated for this Rtc instance. + Options(DtlsCertOptions), + /// A pregenerated certificate to use for this Rtc instance. + PregeneratedCert(DtlsCert), +} + +impl Default for DtlsCertConfig { + fn default() -> Self { + DtlsCertConfig::Options(DtlsCertOptions::default()) + } +} + /// Customized config for creating an [`Rtc`] instance. /// /// ``` @@ -1871,7 +1888,7 @@ impl Rtc { pub struct RtcConfig { local_ice_credentials: Option, crypto_provider: CryptoProvider, - dtls_cert: Option, + dtls_cert_config: DtlsCertConfig, fingerprint_verification: bool, ice_lite: bool, codec_config: CodecConfig, @@ -1921,7 +1938,7 @@ impl RtcConfig { /// /// This overrides what is set in [`CryptoProvider::install_process_default()`]. pub fn set_crypto_provider(mut self, p: CryptoProvider) -> Self { - if let Some(c) = &self.dtls_cert { + if let DtlsCertConfig::PregeneratedCert(c) = &self.dtls_cert_config { if p != c.crypto_provider() { panic!("set_dtls_cert() locked crypto provider to: {}", p); } @@ -1939,46 +1956,48 @@ impl RtcConfig { self.crypto_provider } - /// Get the configured DTLS certificate, if set. - /// - /// Returns [`None`] if no DTLS certificate is set. In such cases, - /// the certificate will be created on build and you can use the - /// direct API on an [`Rtc`] instance to obtain the local - /// DTLS fingerprint. + /// Returns the configured DTLS certificate configuration. /// + /// Defaults to a configuration similar to libwebrtc: /// ``` - /// # #[cfg(feature = "openssl")] { - /// # use str0m::RtcConfig; - /// let fingerprint = RtcConfig::default() - /// .build() - /// .direct_api() - /// .local_dtls_fingerprint(); - /// # } + /// # use str0m::DtlsCertConfig; + /// # use str0m::config::{DtlsCertOptions, DtlsPKeyType}; + /// + /// DtlsCertConfig::Options(DtlsCertOptions { + /// common_name: "WebRTC".into(), + /// pkey_type: DtlsPKeyType::EcDsaP256, + /// }); /// ``` - pub fn dtls_cert(&self) -> Option<&DtlsCert> { - self.dtls_cert.as_ref() + pub fn dtls_cert_config(&self) -> &DtlsCertConfig { + &self.dtls_cert_config } - /// Set the DTLS certificate for secure communication. + /// Set the DTLS certificate configuration for certificate generation. /// - /// Generating a certificate can be a time-consuming process. - /// Use this API to reuse a previously created [`DtlsCert`] if available. + /// Setting this permits you to assign a Pregenerated certificate, or + /// options for certificate generation, such as signing key type, and + /// subject name. /// - /// Setting this locks the `crypto_provider()` setting to the [`CryptoProvider`], - /// for the DTLS certificate. + /// If a Pregenerated certificate is set, this locks the `crypto_provider()` + /// setting to the [`CryptoProvider`], for the DTLS certificate. /// /// ``` - /// # use str0m::RtcConfig; - /// # use str0m::config::{DtlsCert, CryptoProvider}; + /// # use str0m::{DtlsCertConfig, RtcConfig}; + /// # use str0m::config::{DtlsCertOptions, DtlsPKeyType}; /// - /// let dtls_cert = DtlsCert::new(CryptoProvider::OpenSsl); + /// let dtls_cert_config = DtlsCertConfig::Options(DtlsCertOptions { + /// common_name: "Clark Kent".into(), + /// pkey_type: DtlsPKeyType::EcDsaP256, + /// }); /// /// let rtc_config = RtcConfig::default() - /// .set_dtls_cert(dtls_cert); + /// .set_dtls_cert_config(dtls_cert_config); /// ``` - pub fn set_dtls_cert(mut self, dtls_cert: DtlsCert) -> Self { - self.crypto_provider = dtls_cert.crypto_provider(); - self.dtls_cert = Some(dtls_cert); + pub fn set_dtls_cert_config(mut self, dtls_cert_config: DtlsCertConfig) -> Self { + if let DtlsCertConfig::PregeneratedCert(ref cert) = dtls_cert_config { + self.crypto_provider = cert.crypto_provider(); + } + self.dtls_cert_config = dtls_cert_config; self } @@ -2388,7 +2407,7 @@ impl Default for RtcConfig { Self { local_ice_credentials: None, crypto_provider: CryptoProvider::process_default().unwrap_or(CryptoProvider::OpenSsl), - dtls_cert: None, + dtls_cert_config: Default::default(), fingerprint_verification: true, ice_lite: false, codec_config: CodecConfig::new_with_defaults(), diff --git a/wincrypto/Cargo.lock b/wincrypto/Cargo.lock index 69fca75e..ac1edd55 100644 --- a/wincrypto/Cargo.lock +++ b/wincrypto/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "once_cell" @@ -33,7 +33,7 @@ dependencies = [ ] [[package]] -name = "str0m_wincrypto" +name = "str0m-wincrypto" version = "0.1.0" dependencies = [ "thiserror", diff --git a/wincrypto/Cargo.toml b/wincrypto/Cargo.toml index c5f24250..76aa6bb8 100644 --- a/wincrypto/Cargo.toml +++ b/wincrypto/Cargo.toml @@ -9,4 +9,4 @@ license = "MIT OR Apache-2.0" [dependencies] thiserror = { version = "1.0.38" } tracing = "0.1.37" -windows = { version = "0.58", features=["Win32_Security_Cryptography", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials",]} +windows = { version = "0.58", features=["Win32_Security_Cryptography", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_System_Rpc",]} diff --git a/wincrypto/src/cert.rs b/wincrypto/src/cert.rs index e7eac47f..cfca0809 100644 --- a/wincrypto/src/cert.rs +++ b/wincrypto/src/cert.rs @@ -1,12 +1,20 @@ use super::WinCryptoError; use windows::{ - core::{HSTRING, PSTR}, - Win32::Security::Cryptography::{ - szOID_RSA_SHA256RSA, BCryptCreateHash, BCryptDestroyHash, BCryptFinishHash, BCryptHashData, - CertCreateSelfSignCertificate, CertFreeCertificateContext, CertStrToNameW, - BCRYPT_HASH_HANDLE, BCRYPT_SHA256_ALG_HANDLE, CERT_CONTEXT, CERT_CREATE_SELFSIGN_FLAGS, - CERT_OID_NAME_STR, CRYPT_ALGORITHM_IDENTIFIER, CRYPT_INTEGER_BLOB, - HCRYPTPROV_OR_NCRYPT_KEY_HANDLE, X509_ASN_ENCODING, + core::{Owned, GUID, HSTRING, PSTR, PWSTR}, + Win32::{ + Foundation::GetLastError, + Security::Cryptography::{ + szOID_ECDSA_SHA256, szOID_RSA_SHA256RSA, BCryptCreateHash, BCryptFinishHash, + BCryptHashData, CertCreateSelfSignCertificate, CertFreeCertificateContext, + CertStrToNameW, NCryptCreatePersistedKey, NCryptDeleteKey, NCryptFinalizeKey, + NCryptOpenStorageProvider, BCRYPT_HASH_HANDLE, BCRYPT_SHA256_ALG_HANDLE, CERT_CONTEXT, + CERT_CREATE_SELFSIGN_FLAGS, CERT_KEY_SPEC, CERT_OID_NAME_STR, + CRYPT_ALGORITHM_IDENTIFIER, CRYPT_INTEGER_BLOB, CRYPT_KEY_PROV_INFO, + HCRYPTPROV_OR_NCRYPT_KEY_HANDLE, MS_KEY_STORAGE_PROVIDER, NCRYPT_ECDSA_P256_ALGORITHM, + NCRYPT_FLAGS, NCRYPT_KEY_HANDLE, NCRYPT_PROV_HANDLE, NCRYPT_SILENT_FLAG, + X509_ASN_ENCODING, + }, + System::Rpc::{UuidCreate, UuidToStringW, RPC_S_OK}, }, }; @@ -16,14 +24,17 @@ use windows::{ /// Certificate too early. It is also why access to the certificate pointer /// should remain hidden. #[derive(Debug)] -pub struct Certificate(pub(crate) *const CERT_CONTEXT); +pub struct Certificate { + cert_context: *const CERT_CONTEXT, + key_handle: NCRYPT_KEY_HANDLE, +} // SAFETY: CERT_CONTEXT pointers are safe to send between threads. unsafe impl Send for Certificate {} // SAFETY: CERT_CONTEXT pointers are safe to send between threads. unsafe impl Sync for Certificate {} impl Certificate { - pub fn new_self_signed(subject: &str) -> Result { + pub fn new_self_signed(use_ec_dsa_keys: bool, subject: &str) -> Result { let subject = HSTRING::from(subject); let mut subject_blob_buffer = vec![0u8; 256]; let mut subject_blob = CRYPT_INTEGER_BLOB { @@ -31,12 +42,6 @@ impl Certificate { pbData: subject_blob_buffer.as_mut_ptr(), }; - // Use RSA-SHA256 for the signature, since SHA1 is deprecated. - let signature_algorithm = CRYPT_ALGORITHM_IDENTIFIER { - pszObjId: PSTR::from_raw(szOID_RSA_SHA256RSA.as_ptr() as *mut u8), - Parameters: CRYPT_INTEGER_BLOB::default(), - }; - // SAFETY: The Windows APIs accept references, so normal borrow checker // behaviors work for those uses. The name_blob has a pointer to the buffer // which must exist for the duration of the unsafe block. @@ -51,39 +56,111 @@ impl Certificate { None, )?; + let mut key_handle = NCRYPT_KEY_HANDLE::default(); + // Generate the self-signed cert. - let cert_context = CertCreateSelfSignCertificate( - HCRYPTPROV_OR_NCRYPT_KEY_HANDLE(0), - &subject_blob, - CERT_CREATE_SELFSIGN_FLAGS(0), - None, - Some(&signature_algorithm), - None, - None, - None, - ); + let cert_context = if use_ec_dsa_keys { + let mut guid = GUID::default(); + let result = UuidCreate(&mut guid); + if result != RPC_S_OK { + return Err(WinCryptoError(format!( + "Failed to generate GUID for EC-DSA key: {:?}", + result + ))); + } + // A formated UUID is 20 characters long, plus null termination. + let mut guid_buffer = [0u16; 42]; + let mut guid_pwstr = PWSTR::from_raw(guid_buffer.as_mut_ptr()); + let result = UuidToStringW(&guid, &mut guid_pwstr); + if result != RPC_S_OK { + return Err(WinCryptoError(format!( + "Failed to format GUID for EC-DSA key: {:?}", + result + ))); + } + + // We need to first create a EC-DSA key. We need to use NCrypt APIs + // for this, although we don't really want to persist this key. + let mut h_provider = Owned::new(NCRYPT_PROV_HANDLE::default()); + NCryptOpenStorageProvider(&mut *h_provider, MS_KEY_STORAGE_PROVIDER, 0)?; + + NCryptCreatePersistedKey( + *h_provider, + &mut key_handle, + // Use EC-256 which corresponds to NID_X9_62_prime256v1 + NCRYPT_ECDSA_P256_ALGORITHM, + // Passing None makes this key ephemeral and not persisted. + guid_pwstr, + CERT_KEY_SPEC(0), + NCRYPT_SILENT_FLAG, + )?; + NCryptFinalizeKey(key_handle, NCRYPT_FLAGS(0))?; + + let key_prov_info = CRYPT_KEY_PROV_INFO { + pwszContainerName: guid_pwstr, + pwszProvName: PWSTR(MS_KEY_STORAGE_PROVIDER.as_ptr() as *mut u16), + ..Default::default() + }; + + let signature_algorithm = CRYPT_ALGORITHM_IDENTIFIER { + pszObjId: PSTR::from_raw(szOID_ECDSA_SHA256.as_ptr() as *mut u8), + ..Default::default() + }; + + CertCreateSelfSignCertificate( + HCRYPTPROV_OR_NCRYPT_KEY_HANDLE(key_handle.0), + &subject_blob, + CERT_CREATE_SELFSIGN_FLAGS(0), + Some(&key_prov_info as *const _ as *const _), + Some(&signature_algorithm), + None, + None, + None, + ) + } else { + // Use RSA-SHA256 for the signature, since SHA1 is deprecated. + let signature_algorithm = CRYPT_ALGORITHM_IDENTIFIER { + pszObjId: PSTR::from_raw(szOID_RSA_SHA256RSA.as_ptr() as *mut u8), + ..Default::default() + }; + + CertCreateSelfSignCertificate( + HCRYPTPROV_OR_NCRYPT_KEY_HANDLE(0), + &subject_blob, + CERT_CREATE_SELFSIGN_FLAGS(0), + None, + Some(&signature_algorithm), + None, + None, + None, + ) + }; if cert_context.is_null() { - Err(WinCryptoError( - "Failed to generate self-signed certificate".to_string(), - )) + let win_err = GetLastError(); + Err(WinCryptoError(format!( + "Failed to generate self-signed certificate: {:?}", + win_err + ))) } else { - Ok(Self(cert_context)) + Ok(Self { + cert_context, + key_handle, + }) } } } pub fn sha256_fingerprint(&self) -> Result<[u8; 32], WinCryptoError> { let mut hash = [0u8; 32]; - let mut hash_handle = BCRYPT_HASH_HANDLE::default(); - // SAFETY: The Windows APIs accept references, so normal borrow checker // behaviors work for those uses. unsafe { + let mut hash_handle = Owned::new(BCRYPT_HASH_HANDLE::default()); // Create the hash instance. if let Err(e) = WinCryptoError::from_ntstatus(BCryptCreateHash( BCRYPT_SHA256_ALG_HANDLE, - &mut hash_handle, + &mut *hash_handle, None, None, 0, @@ -92,9 +169,9 @@ impl Certificate { } // Hash the certificate contents. - let cert_info = *self.0; + let cert_info = *self.cert_context; if let Err(e) = WinCryptoError::from_ntstatus(BCryptHashData( - hash_handle, + *hash_handle, std::slice::from_raw_parts( cert_info.pbCertEncoded, cert_info.cbCertEncoded as usize, @@ -105,18 +182,22 @@ impl Certificate { } // Grab the result of the hash. - WinCryptoError::from_ntstatus(BCryptFinishHash(hash_handle, &mut hash, 0))?; - - // Destroy the allocated hash. - WinCryptoError::from_ntstatus(BCryptDestroyHash(hash_handle))?; + WinCryptoError::from_ntstatus(BCryptFinishHash(*hash_handle, &mut hash, 0))?; } Ok(hash) } + + pub fn context(&self) -> *const CERT_CONTEXT { + self.cert_context + } } impl From<*const CERT_CONTEXT> for Certificate { - fn from(value: *const CERT_CONTEXT) -> Self { - Self(value) + fn from(cert_context: *const CERT_CONTEXT) -> Self { + Self { + cert_context, + key_handle: NCRYPT_KEY_HANDLE::default(), + } } } @@ -125,30 +206,134 @@ impl Drop for Certificate { // SAFETY: The Certificate is no longer usable, so it's safe to pass the pointer // to Windows for release. unsafe { - _ = CertFreeCertificateContext(Some(self.0)); + _ = CertFreeCertificateContext(Some(self.cert_context)); + _ = NCryptDeleteKey(self.key_handle, NCRYPT_SILENT_FLAG.0); } } } #[cfg(test)] mod tests { + use std::ffi::CStr; + use windows::Win32::Security::Cryptography::{ + szOID_ECC_PUBLIC_KEY, szOID_RSA_RSA, CertNameToStrA, CERT_X500_NAME_STR, X509_ASN_ENCODING, + }; + #[test] - fn verify_self_signed() { - let cert = super::Certificate::new_self_signed("cn=WebRTC").unwrap(); + fn verify_self_signed_rsa() { + let cert = super::Certificate::new_self_signed(false, "cn=WebRTC-RSA").unwrap(); + let cert_context = cert.context(); // Verify it is self-signed. unsafe { - let subject = (*(*cert.0).pCertInfo).Subject; - let subject = std::slice::from_raw_parts(subject.pbData, subject.cbData as usize); - let issuer = (*(*cert.0).pCertInfo).Issuer; - let issuer = std::slice::from_raw_parts(issuer.pbData, issuer.cbData as usize); - assert_eq!(issuer, subject); + assert_eq!( + CStr::from_ptr( + (*(*cert_context).pCertInfo) + .SubjectPublicKeyInfo + .Algorithm + .pszObjId + .0 as *const i8 + ), + CStr::from_ptr(szOID_RSA_RSA.as_ptr() as *const i8) + ); + + let subject = (*(*cert_context).pCertInfo).Subject; + let issuer = (*(*cert_context).pCertInfo).Issuer; + // Verify raw contents are equivalent. + assert_eq!( + std::slice::from_raw_parts(issuer.pbData, issuer.cbData as usize), + std::slice::from_raw_parts(subject.pbData, subject.cbData as usize) + ); + + let mut buffer = [0u8; 128]; + CertNameToStrA( + X509_ASN_ENCODING, + &subject, + CERT_X500_NAME_STR, + Some(&mut buffer), + ); + let subject = CStr::from_bytes_until_nul(&buffer) + .unwrap() + .to_str() + .unwrap(); + assert_eq!("CN=WebRTC-RSA", subject); + + CertNameToStrA( + X509_ASN_ENCODING, + &issuer, + CERT_X500_NAME_STR, + Some(&mut buffer), + ); + let issuer = CStr::from_bytes_until_nul(&buffer) + .unwrap() + .to_str() + .unwrap(); + assert_eq!("CN=WebRTC-RSA", issuer); } } #[test] - fn verify_fingerprint() { - let cert = super::Certificate::new_self_signed("cn=WebRTC").unwrap(); + fn verify_self_signed_ec_dsa() { + let cert = super::Certificate::new_self_signed(true, "cn=ecDsa").unwrap(); + let cert_context = cert.context(); + + // Verify it is self-signed. + unsafe { + assert_eq!( + CStr::from_ptr( + (*(*cert_context).pCertInfo) + .SubjectPublicKeyInfo + .Algorithm + .pszObjId + .0 as *const i8 + ), + CStr::from_ptr(szOID_ECC_PUBLIC_KEY.as_ptr() as *const i8) + ); + let subject = (*(*cert_context).pCertInfo).Subject; + let issuer = (*(*cert_context).pCertInfo).Issuer; + // Verify raw contents are equivalent. + assert_eq!( + std::slice::from_raw_parts(issuer.pbData, issuer.cbData as usize), + std::slice::from_raw_parts(subject.pbData, subject.cbData as usize) + ); + + let mut buffer = [0u8; 128]; + CertNameToStrA( + X509_ASN_ENCODING, + &subject, + CERT_X500_NAME_STR, + Some(&mut buffer), + ); + let subject = CStr::from_bytes_until_nul(&buffer) + .unwrap() + .to_str() + .unwrap(); + assert_eq!("CN=ecDsa", subject); + + CertNameToStrA( + X509_ASN_ENCODING, + &issuer, + CERT_X500_NAME_STR, + Some(&mut buffer), + ); + let issuer = CStr::from_bytes_until_nul(&buffer) + .unwrap() + .to_str() + .unwrap(); + assert_eq!("CN=ecDsa", issuer); + } + } + + #[test] + fn verify_fingerprint_rsa() { + let cert = super::Certificate::new_self_signed(false, "cn=WebRTC").unwrap(); + let fingerprint = cert.sha256_fingerprint().unwrap(); + assert_eq!(fingerprint.len(), 32); + } + + #[test] + fn verify_fingerprint_ec_dsa() { + let cert = super::Certificate::new_self_signed(true, "cn=WebRTC").unwrap(); let fingerprint = cert.sha256_fingerprint().unwrap(); assert_eq!(fingerprint.len(), 32); } diff --git a/wincrypto/src/dtls.rs b/wincrypto/src/dtls.rs index 9206f1b4..06d53cef 100644 --- a/wincrypto/src/dtls.rs +++ b/wincrypto/src/dtls.rs @@ -124,8 +124,7 @@ impl Dtls { pub fn set_as_client(&mut self, active: bool) -> Result<(), WinCryptoError> { self.is_client = Some(active); - let mut cert_contexts = [self.cert.0]; - + let mut cert_contexts = [self.cert.context()]; let schannel_cred = SCHANNEL_CRED { dwVersion: SCHANNEL_CRED_VERSION, hRootStore: windows::Win32::Security::Cryptography::HCERTSTORE(std::ptr::null_mut()), @@ -234,21 +233,21 @@ impl Dtls { let mut output = vec![0u8; header_size + trailer_size + message_size]; output[header_size..header_size + message_size].copy_from_slice(data); - let sec_buffers = [ + let mut sec_buffers = [ SecBuffer { BufferType: SECBUFFER_STREAM_HEADER, cbBuffer: header_size as u32, - pvBuffer: &output[0] as *const _ as *mut _, + pvBuffer: output[0..].as_mut_ptr() as *mut _, }, SecBuffer { BufferType: SECBUFFER_DATA, cbBuffer: message_size as u32, - pvBuffer: &output[header_size] as *const _ as *mut _, + pvBuffer: output[header_size..].as_mut_ptr() as *mut _, }, SecBuffer { BufferType: SECBUFFER_STREAM_TRAILER, cbBuffer: trailer_size as u32, - pvBuffer: &output[header_size + message_size] as *const _ as *mut _, + pvBuffer: output[header_size + message_size..].as_mut_ptr() as *mut _, }, SecBuffer { cbBuffer: 0, @@ -259,7 +258,7 @@ impl Dtls { let sec_buffer_desc = SecBufferDesc { ulVersion: SECBUFFER_VERSION, cBuffers: 4, - pBuffers: &sec_buffers[0] as *const _ as *mut _, + pBuffers: sec_buffers.as_mut_ptr() as *mut _, }; // SAFETY: The references passed are all borrow checked. However, @@ -310,38 +309,38 @@ impl Dtls { let in_buffer_desc = match datagram { Some(datagram) => { buffers[2].cbBuffer = datagram.len() as u32; - buffers[2].pvBuffer = datagram.as_ptr() as *mut _; + buffers[2].pvBuffer = datagram as *const _ as *mut _; SecBufferDesc { ulVersion: SECBUFFER_VERSION, cBuffers: buffers.len() as u32, - pBuffers: &buffers[0] as *const _ as *mut _, + pBuffers: buffers.as_mut_ptr() as *mut _, } } None => SecBufferDesc { ulVersion: SECBUFFER_VERSION, cBuffers: 2, - pBuffers: &buffers[0] as *const _ as *mut _, + pBuffers: buffers.as_mut_ptr() as *mut _, }, }; - let token_buffer = [0u8; DATAGRAM_MTU]; - let alert_buffer = [0u8; DATAGRAM_MTU]; - let out_buffers = [ + let mut token_buffer = [0u8; DATAGRAM_MTU]; + let mut alert_buffer = [0u8; DATAGRAM_MTU]; + let mut out_buffers = [ SecBuffer { cbBuffer: token_buffer.len() as u32, BufferType: SECBUFFER_TOKEN, - pvBuffer: &token_buffer as *const _ as *mut _, + pvBuffer: token_buffer.as_mut_ptr() as *mut _, }, SecBuffer { cbBuffer: alert_buffer.len() as u32, BufferType: SECBUFFER_ALERT, - pvBuffer: &alert_buffer as *const _ as *mut _, + pvBuffer: alert_buffer.as_mut_ptr() as *mut _, }, ]; let mut out_buffer_desc = SecBufferDesc { cBuffers: out_buffers.len() as u32, - pBuffers: &out_buffers[0] as *const _ as *mut _, + pBuffers: out_buffers.as_mut_ptr() as *mut _, ulVersion: SECBUFFER_VERSION, }; @@ -376,7 +375,10 @@ impl Dtls { ) } else { // Server - debug!("AcceptSecurityContext {:?}", in_buffer_desc); + debug!( + "AcceptSecurityContext {:?} {:?}", + in_buffer_desc, out_buffer_desc + ); AcceptSecurityContext( self.cred_handle.as_ref().map(|r| r as *const _), self.security_ctx.as_ref().map(|r| r as *const _), @@ -538,14 +540,14 @@ impl Dtls { let header_size = self.encrypt_message_input_sizes.cbHeader as usize; let trailer_size = self.encrypt_message_input_sizes.cbTrailer as usize; - let output = datagram.to_vec(); - let alert = [0u8; 512]; + let mut output = datagram.to_vec(); + let mut alert = [0u8; 512]; - let sec_buffers = [ + let mut sec_buffers = [ SecBuffer { BufferType: SECBUFFER_DATA, cbBuffer: output.len() as u32, - pvBuffer: &output[0] as *const _ as *mut _, + pvBuffer: output.as_mut_ptr() as *mut _, }, SecBuffer { cbBuffer: 0, @@ -565,13 +567,13 @@ impl Dtls { SecBuffer { BufferType: SECBUFFER_ALERT, cbBuffer: alert.len() as u32, - pvBuffer: &alert[0] as *const _ as *mut _, + pvBuffer: alert.as_mut_ptr() as *mut _, }, ]; let sec_buffer_desc = SecBufferDesc { ulVersion: SECBUFFER_VERSION, cBuffers: 4, - pBuffers: &sec_buffers[0] as *const _ as *mut _, + pBuffers: sec_buffers.as_mut_ptr() as *mut _, }; // SAFETY: All the passed in values are borrow checked. The `sec_buffer_desc` diff --git a/wincrypto/src/sha1.rs b/wincrypto/src/sha1.rs index 2191445c..592186da 100644 --- a/wincrypto/src/sha1.rs +++ b/wincrypto/src/sha1.rs @@ -1,7 +1,10 @@ use super::WinCryptoError; -use windows::Win32::Security::Cryptography::{ - BCryptCreateHash, BCryptDestroyHash, BCryptFinishHash, BCryptHashData, BCRYPT_HASH_HANDLE, - BCRYPT_HMAC_SHA1_ALG_HANDLE, +use windows::{ + core::Owned, + Win32::Security::Cryptography::{ + BCryptCreateHash, BCryptFinishHash, BCryptHashData, BCRYPT_HASH_HANDLE, + BCRYPT_HMAC_SHA1_ALG_HANDLE, + }, }; pub fn sha1_hmac(key: &[u8], payloads: &[&[u8]]) -> Result<[u8; 20], WinCryptoError> { @@ -9,10 +12,10 @@ pub fn sha1_hmac(key: &[u8], payloads: &[&[u8]]) -> Result<[u8; 20], WinCryptoEr // behaviors work for these uses. unsafe { // Create hash. - let mut hash_handle = BCRYPT_HASH_HANDLE::default(); + let mut hash_handle = Owned::new(BCRYPT_HASH_HANDLE::default()); WinCryptoError::from_ntstatus(BCryptCreateHash( BCRYPT_HMAC_SHA1_ALG_HANDLE, - &mut hash_handle, + &mut *hash_handle, None, Some(key), 0, @@ -20,15 +23,12 @@ pub fn sha1_hmac(key: &[u8], payloads: &[&[u8]]) -> Result<[u8; 20], WinCryptoEr // Update hash with data. for payload in payloads { - WinCryptoError::from_ntstatus(BCryptHashData(hash_handle, payload, 0))?; + WinCryptoError::from_ntstatus(BCryptHashData(*hash_handle, payload, 0))?; } // Get the hash result. let mut hash = [0u8; 20]; - WinCryptoError::from_ntstatus(BCryptFinishHash(hash_handle, &mut hash, 0))?; - - // Free the hash. - WinCryptoError::from_ntstatus(BCryptDestroyHash(hash_handle))?; + WinCryptoError::from_ntstatus(BCryptFinishHash(*hash_handle, &mut hash, 0))?; Ok(hash) } diff --git a/wincrypto/src/srtp.rs b/wincrypto/src/srtp.rs index e59e4c87..df94be83 100644 --- a/wincrypto/src/srtp.rs +++ b/wincrypto/src/srtp.rs @@ -1,10 +1,13 @@ use super::WinCryptoError; use std::ptr::addr_of; -use windows::Win32::Security::Cryptography::{ - BCryptDecrypt, BCryptDestroyKey, BCryptEncrypt, BCryptGenerateSymmetricKey, - BCRYPT_AES_ECB_ALG_HANDLE, BCRYPT_AES_GCM_ALG_HANDLE, BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO, - BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO_VERSION, BCRYPT_BLOCK_PADDING, BCRYPT_FLAGS, - BCRYPT_KEY_HANDLE, +use windows::{ + core::Owned, + Win32::Security::Cryptography::{ + BCryptDecrypt, BCryptEncrypt, BCryptGenerateSymmetricKey, BCRYPT_AES_ECB_ALG_HANDLE, + BCRYPT_AES_GCM_ALG_HANDLE, BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO, + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO_VERSION, BCRYPT_BLOCK_PADDING, BCRYPT_FLAGS, + BCRYPT_KEY_HANDLE, + }, }; const MAX_BUFFER_SIZE: usize = 2048; @@ -15,7 +18,7 @@ const AEAD_AES_GCM_TAG_LEN: usize = 16; /// does NOT implement Clone/Copy, otherwise we could destroy the key /// too early. It is also why access to the key handle should remain /// hidden. -pub struct SrtpKey(BCRYPT_KEY_HANDLE); +pub struct SrtpKey(Owned); // SAFETY: BCRYPT_KEY_HANDLEs are safe to send between threads. unsafe impl Send for SrtpKey {} // SAFETY: BCRYPT_KEY_HANDLEs are safe to send between threads. @@ -30,45 +33,33 @@ impl SrtpKey { /// Creates a key from the given data for operating AES in ECB mode. pub fn create_aes_ecb_key(key: &[u8]) -> Result { - let mut key_handle = BCRYPT_KEY_HANDLE::default(); // SAFETY: The key and key_handle will exist before and after this call. unsafe { + let mut key_handle = Owned::new(BCRYPT_KEY_HANDLE::default()); WinCryptoError::from_ntstatus(BCryptGenerateSymmetricKey( BCRYPT_AES_ECB_ALG_HANDLE, - &mut key_handle, + &mut *key_handle, None, &key, 0, ))?; + Ok(Self(key_handle)) } - Ok(Self(key_handle)) } /// Creates a key from the given data for operating AES in GCM mode. pub fn create_aes_gcm_key(key: &[u8]) -> Result { - let mut key_handle = BCRYPT_KEY_HANDLE::default(); // SAFETY: The key and key_handle will exist before and after this call. unsafe { + let mut key_handle = Owned::new(BCRYPT_KEY_HANDLE::default()); WinCryptoError::from_ntstatus(BCryptGenerateSymmetricKey( BCRYPT_AES_GCM_ALG_HANDLE, - &mut key_handle, + &mut *key_handle, None, &key, 0, ))?; - } - Ok(Self(key_handle)) - } -} - -impl Drop for SrtpKey { - fn drop(&mut self) { - // SAFETY: The SrtpKey is being dropped it is safe to copy the handle - // because the handle will no longer be accessible after this. - unsafe { - if let Err(e) = WinCryptoError::from_ntstatus(BCryptDestroyKey(self.0)) { - error!("Failed to destory crypto key: {}", e); - } + Ok(Self(key_handle)) } } } @@ -84,7 +75,7 @@ pub fn srtp_aes_128_ecb_round( // behaviors work. unsafe { WinCryptoError::from_ntstatus(BCryptEncrypt( - key.0, + *key.0, Some(input), None, None, @@ -137,7 +128,7 @@ pub fn srtp_aes_128_cm( let encrypted_countered_iv = std::slice::from_raw_parts_mut(countered_iv.as_mut_ptr(), countered_iv.len()); WinCryptoError::from_ntstatus(BCryptEncrypt( - key.0, + *key.0, Some(&countered_iv[..offset]), None, None, @@ -194,7 +185,7 @@ pub fn srtp_aead_aes_128_gcm_encrypt( // `cipher_text` and `iv`) exists for the duration of the unsafe block. unsafe { WinCryptoError::from_ntstatus(BCryptEncrypt( - key.0, + *key.0, Some(plain_text), Some(addr_of!(auth_cipher_mode_info) as *const std::ffi::c_void), None, @@ -252,7 +243,7 @@ pub fn srtp_aead_aes_128_gcm_decrypt( // `cipher_text` and `iv`) exists for the duration of the unsafe block. unsafe { WinCryptoError::from_ntstatus(BCryptDecrypt( - key.0, + *key.0, Some(cipher_text), Some(addr_of!(auth_cipher_mode_info) as *const std::ffi::c_void), None,