diff --git a/aes.go b/aes.go index 231b75e2..95daeacf 100644 --- a/aes.go +++ b/aes.go @@ -20,7 +20,7 @@ type extraModes interface { NewGCMTLS() (cipher.AEAD, error) } -var _ extraModes = (*aesCipher)(nil) +var _ extraModes = (*aesWithCTR)(nil) func NewAESCipher(key []byte) (cipher.Block, error) { var kind cipherKind @@ -38,19 +38,32 @@ func NewAESCipher(key []byte) (cipher.Block, error) { if err != nil { return nil, err } - return &aesCipher{c}, nil + ac := aesCipher{c} + // The SymCrypt provider doesn't support AES-CTR. + // Prove that the provider supports AES-CTR before + // returning an aesWithCTR. + if loadCipher(kind, cipherModeCTR) != nil { + return &aesWithCTR{ac}, nil + } + return &ac, nil } // NewGCMTLS returns a GCM cipher specific to TLS // and should not be used for non-TLS purposes. func NewGCMTLS(c cipher.Block) (cipher.AEAD, error) { - return c.(*aesCipher).NewGCMTLS() + if c, ok := c.(*aesCipher); ok { + return c.NewGCMTLS() + } + return c.(*aesWithCTR).NewGCMTLS() } // NewGCMTLS13 returns a GCM cipher specific to TLS 1.3 and should not be used // for non-TLS purposes. func NewGCMTLS13(c cipher.Block) (cipher.AEAD, error) { - return c.(*aesCipher).NewGCMTLS13() + if c, ok := c.(*aesCipher); ok { + return c.NewGCMTLS13() + } + return c.(*aesWithCTR).NewGCMTLS13() } type aesCipher struct { @@ -83,10 +96,6 @@ func (c *aesCipher) NewCBCDecrypter(iv []byte) cipher.BlockMode { return c.newCBC(iv, cipherOpDecrypt) } -func (c *aesCipher) NewCTR(iv []byte) cipher.Stream { - return c.newCTR(iv) -} - func (c *aesCipher) NewGCM(nonceSize, tagSize int) (cipher.AEAD, error) { return c.newGCMChecked(nonceSize, tagSize) } @@ -98,3 +107,11 @@ func (c *aesCipher) NewGCMTLS() (cipher.AEAD, error) { func (c *aesCipher) NewGCMTLS13() (cipher.AEAD, error) { return c.newGCM(cipherGCMTLS13) } + +type aesWithCTR struct { + aesCipher +} + +func (c *aesWithCTR) NewCTR(iv []byte) cipher.Stream { + return c.newCTR(iv) +}