capsules_aes_gcm/
aes_gcm.rs

1// Licensed under the Apache License, Version 2.0 or the MIT License.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3// Copyright Western Digital 2023.
4
//! An AES-GCM implementation built on top of an underlying
//! AES-CTR implementation.
//!
//! This capsule requires an AES-CTR implementation to support
//! AES-GCM. The underlying implementation must also provide AES-CBC,
//! AES-ECB and AES-CCM, so that this capsule can expose all of the
//! supported AES operations through a single API.
12
13use core::cell::Cell;
14use ghash::universal_hash::NewUniversalHash;
15use ghash::universal_hash::UniversalHash;
16use ghash::GHash;
17use ghash::Key;
18use kernel::hil::symmetric_encryption;
19use kernel::hil::symmetric_encryption::{
20    AES128Ctr, AES128, AES128CBC, AES128CCM, AES128ECB, AES128_BLOCK_SIZE, AES128_KEY_SIZE,
21};
22use kernel::utilities::cells::{OptionalCell, TakeCell};
23use kernel::ErrorCode;
24
/// Progress of the GCM operation currently owned by the capsule.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum GCMState {
    /// No GCM operation is in flight.
    Idle,
    /// The AES engine is encrypting a zeroed block to produce the GHASH key.
    GenerateHashKey,
    /// The AES engine is running the CTR pass over the scratch buffer.
    CtrEncrypt,
}
31
32pub struct Aes128Gcm<'a, A: AES128<'a> + AES128Ctr + AES128CBC + AES128ECB + AES128CCM<'a>> {
33    aes: &'a A,
34
35    mac: OptionalCell<GHash>,
36
37    crypt_buf: TakeCell<'static, [u8]>,
38
39    client: OptionalCell<&'a dyn symmetric_encryption::Client<'a>>,
40    ccm_client: OptionalCell<&'a dyn symmetric_encryption::CCMClient>,
41    gcm_client: OptionalCell<&'a dyn symmetric_encryption::GCMClient>,
42
43    state: Cell<GCMState>,
44    encrypting: Cell<bool>,
45
46    buf: TakeCell<'static, [u8]>,
47
48    pos: Cell<(usize, usize, usize)>,
49    key: Cell<[u8; AES128_KEY_SIZE]>,
50    iv: Cell<[u8; AES128_KEY_SIZE]>,
51}
52
53impl<'a, A: AES128<'a> + AES128Ctr + AES128CBC + AES128ECB + AES128CCM<'a>> Aes128Gcm<'a, A> {
54    pub fn new(aes: &'a A, crypt_buf: &'static mut [u8]) -> Aes128Gcm<'a, A> {
55        Aes128Gcm {
56            aes,
57
58            mac: OptionalCell::empty(),
59
60            crypt_buf: TakeCell::new(crypt_buf),
61
62            client: OptionalCell::empty(),
63            ccm_client: OptionalCell::empty(),
64            gcm_client: OptionalCell::empty(),
65
66            state: Cell::new(GCMState::Idle),
67            encrypting: Cell::new(false),
68
69            buf: TakeCell::empty(),
70            pos: Cell::new((0, 0, 0)),
71            key: Cell::new(Default::default()),
72            iv: Cell::new(Default::default()),
73        }
74    }
75
76    fn start_ctr_encrypt(&self) -> Result<(), ErrorCode> {
77        self.aes.set_mode_aes128ctr(self.encrypting.get())?;
78
79        let res = AES128::set_key(self.aes, &self.key.get());
80        if res != Ok(()) {
81            return res;
82        }
83
84        self.aes.set_iv(&self.iv.get()).unwrap();
85
86        self.aes.start_message();
87        let crypt_buf = self.crypt_buf.take().unwrap();
88        let (_aad_offset, message_offset, message_len) = self.pos.get();
89
90        match AES128::crypt(
91            self.aes,
92            None,
93            crypt_buf,
94            message_offset,
95            message_offset + message_len + AES128_BLOCK_SIZE,
96        ) {
97            None => {
98                self.state.set(GCMState::CtrEncrypt);
99                Ok(())
100            }
101            Some((res, _, crypt_buf)) => {
102                self.crypt_buf.replace(crypt_buf);
103                res
104            }
105        }
106    }
107
108    fn crypt_r(
109        &self,
110        buf: &'static mut [u8],
111        aad_offset: usize,
112        message_offset: usize,
113        message_len: usize,
114        encrypting: bool,
115    ) -> Result<(), (ErrorCode, &'static mut [u8])> {
116        if self.state.get() != GCMState::Idle {
117            return Err((ErrorCode::BUSY, buf));
118        }
119
120        self.encrypting.set(encrypting);
121
122        self.aes.set_mode_aes128ctr(self.encrypting.get()).unwrap();
123        AES128::set_key(self.aes, &self.key.get()).unwrap();
124        self.aes.set_iv(&[0; AES128_BLOCK_SIZE]).unwrap();
125
126        self.aes.start_message();
127        let crypt_buf = self.crypt_buf.take().unwrap();
128
129        for i in 0..AES128_BLOCK_SIZE {
130            crypt_buf[i] = 0;
131        }
132
133        match AES128::crypt(self.aes, None, crypt_buf, 0, AES128_BLOCK_SIZE) {
134            None => {
135                self.state.set(GCMState::GenerateHashKey);
136            }
137            Some((_res, _, crypt_buf)) => {
138                self.crypt_buf.replace(crypt_buf);
139            }
140        }
141
142        self.buf.replace(buf);
143        self.pos.set((aad_offset, message_offset, message_len));
144        Ok(())
145    }
146}
147
148impl<'a, A: AES128<'a> + AES128Ctr + AES128CBC + AES128ECB + AES128CCM<'a>>
149    symmetric_encryption::CCMClient for Aes128Gcm<'a, A>
150{
151    fn crypt_done(&self, buf: &'static mut [u8], res: Result<(), ErrorCode>, tag_is_valid: bool) {
152        self.ccm_client.map(move |client| {
153            client.crypt_done(buf, res, tag_is_valid);
154        });
155    }
156}
157
158impl<'a, A: AES128<'a> + AES128Ctr + AES128CBC + AES128ECB + AES128CCM<'a>>
159    symmetric_encryption::AES128GCM<'a> for Aes128Gcm<'a, A>
160{
161    fn set_client(&self, client: &'a dyn symmetric_encryption::GCMClient) {
162        self.gcm_client.set(client);
163    }
164
165    fn set_key(&self, key: &[u8]) -> Result<(), ErrorCode> {
166        if key.len() < AES128_KEY_SIZE {
167            Err(ErrorCode::INVAL)
168        } else {
169            let mut new_key = [0u8; AES128_KEY_SIZE];
170            new_key.copy_from_slice(key);
171            self.key.set(new_key);
172            Ok(())
173        }
174    }
175
176    fn set_iv(&self, nonce: &[u8]) -> Result<(), ErrorCode> {
177        let mut new_nonce = [0u8; AES128_KEY_SIZE];
178        let len = nonce.len().min(12);
179
180        new_nonce[0..len].copy_from_slice(&nonce[0..len]);
181        new_nonce[12..16].copy_from_slice(&[0, 0, 0, 1]);
182
183        self.iv.set(new_nonce);
184        Ok(())
185    }
186
187    fn crypt(
188        &self,
189        buf: &'static mut [u8],
190        aad_offset: usize,
191        message_offset: usize,
192        message_len: usize,
193        encrypting: bool,
194    ) -> Result<(), (ErrorCode, &'static mut [u8])> {
195        if self.state.get() != GCMState::Idle {
196            return Err((ErrorCode::BUSY, buf));
197        }
198
199        let _ = self
200            .crypt_r(buf, aad_offset, message_offset, message_len, encrypting)
201            .map_err(|(ecode, _)| {
202                self.buf.take().map(|buf| {
203                    self.gcm_client.map(move |client| {
204                        client.crypt_done(buf, Err(ecode), false);
205                    });
206                });
207            });
208
209        Ok(())
210    }
211}
212
213impl<'a, A: AES128<'a> + AES128Ctr + AES128CBC + AES128ECB + AES128CCM<'a>>
214    symmetric_encryption::AES128<'a> for Aes128Gcm<'a, A>
215{
216    fn enable(&self) {
217        self.aes.enable();
218    }
219
220    fn disable(&self) {
221        self.aes.disable();
222    }
223
224    fn set_client(&'a self, client: &'a dyn symmetric_encryption::Client<'a>) {
225        self.client.set(client);
226    }
227
228    fn set_key(&self, key: &[u8]) -> Result<(), ErrorCode> {
229        AES128::set_key(self.aes, key)
230    }
231
232    fn set_iv(&self, iv: &[u8]) -> Result<(), ErrorCode> {
233        self.aes.set_iv(iv)
234    }
235
236    fn start_message(&self) {
237        self.aes.start_message()
238    }
239
240    fn crypt(
241        &self,
242        source: Option<&'static mut [u8]>,
243        dest: &'static mut [u8],
244        start_index: usize,
245        stop_index: usize,
246    ) -> Option<(
247        Result<(), ErrorCode>,
248        Option<&'static mut [u8]>,
249        &'static mut [u8],
250    )> {
251        AES128::crypt(self.aes, source, dest, start_index, stop_index)
252    }
253}
254
255impl<'a, A: AES128<'a> + AES128Ctr + AES128CBC + AES128ECB + AES128CCM<'a> + AES128CCM<'a>>
256    symmetric_encryption::AES128CCM<'a> for Aes128Gcm<'a, A>
257{
258    fn set_client(&'a self, client: &'a dyn symmetric_encryption::CCMClient) {
259        self.ccm_client.set(client);
260    }
261
262    fn set_key(&self, key: &[u8]) -> Result<(), ErrorCode> {
263        AES128CCM::set_key(self.aes, key)
264    }
265
266    fn set_nonce(&self, nonce: &[u8]) -> Result<(), ErrorCode> {
267        self.aes.set_nonce(nonce)
268    }
269
270    fn crypt(
271        &self,
272        buf: &'static mut [u8],
273        a_off: usize,
274        m_off: usize,
275        m_len: usize,
276        mic_len: usize,
277        confidential: bool,
278        encrypting: bool,
279    ) -> Result<(), (ErrorCode, &'static mut [u8])> {
280        AES128CCM::crypt(
281            self.aes,
282            buf,
283            a_off,
284            m_off,
285            m_len,
286            mic_len,
287            confidential,
288            encrypting,
289        )
290    }
291}
292
293impl<'a, A: AES128<'a> + AES128Ctr + AES128CBC + AES128ECB + AES128CCM<'a>> AES128Ctr
294    for Aes128Gcm<'a, A>
295{
296    fn set_mode_aes128ctr(&self, encrypting: bool) -> Result<(), ErrorCode> {
297        self.aes.set_mode_aes128ctr(encrypting)
298    }
299}
300
301impl<'a, A: AES128<'a> + AES128Ctr + AES128CBC + AES128ECB + AES128CCM<'a>> AES128ECB
302    for Aes128Gcm<'a, A>
303{
304    fn set_mode_aes128ecb(&self, encrypting: bool) -> Result<(), ErrorCode> {
305        self.aes.set_mode_aes128ecb(encrypting)
306    }
307}
308
309impl<'a, A: AES128<'a> + AES128Ctr + AES128CBC + AES128ECB + AES128CCM<'a>> AES128CBC
310    for Aes128Gcm<'a, A>
311{
312    fn set_mode_aes128cbc(&self, encrypting: bool) -> Result<(), ErrorCode> {
313        self.aes.set_mode_aes128cbc(encrypting)
314    }
315}
316
317impl<'a, A: AES128<'a> + AES128Ctr + AES128CBC + AES128ECB + AES128CCM<'a>>
318    symmetric_encryption::Client<'a> for Aes128Gcm<'a, A>
319{
320    fn crypt_done(&self, _: Option<&'static mut [u8]>, crypt_buf: &'static mut [u8]) {
321        match self.state.get() {
322            GCMState::Idle => unreachable!(),
323            GCMState::GenerateHashKey => {
324                let (aad_offset, message_offset, message_len) = self.pos.get();
325
326                let mut mac = GHash::new(Key::from_slice(&crypt_buf[0..AES128_BLOCK_SIZE]));
327                let buf = self.buf.take().unwrap();
328
329                if self.encrypting.get() {
330                    mac.update_padded(&buf[aad_offset..message_offset]);
331
332                    crypt_buf[AES128_BLOCK_SIZE..(AES128_BLOCK_SIZE + message_len)]
333                        .copy_from_slice(&buf[message_offset..(message_offset + message_len)]);
334                    for i in 0..AES128_BLOCK_SIZE {
335                        crypt_buf[i] = 0;
336                    }
337
338                    self.mac.replace(mac);
339                } else {
340                    let copy_offset = (message_offset / AES128_BLOCK_SIZE) * AES128_BLOCK_SIZE;
341                    mac.update_padded(&buf[aad_offset..message_offset]);
342                    mac.update_padded(&buf[message_offset..(message_offset + message_len)]);
343
344                    let associated_data_bits = ((message_offset - aad_offset) as u64) * 8;
345                    let buffer_bits = (message_len as u64) * 8;
346
347                    let mut block = ghash::Block::default();
348                    block[..8].copy_from_slice(&associated_data_bits.to_be_bytes());
349                    block[8..].copy_from_slice(&buffer_bits.to_be_bytes());
350                    mac.update(&block);
351
352                    let mut tag = mac.finalize().into_bytes();
353
354                    for i in 0..AES128_BLOCK_SIZE {
355                        tag[i] ^= crypt_buf[copy_offset + i];
356                    }
357
358                    buf[0..AES128_BLOCK_SIZE].copy_from_slice(&tag);
359                }
360                self.crypt_buf.replace(crypt_buf);
361                self.buf.replace(buf);
362
363                self.start_ctr_encrypt().unwrap();
364            }
365            GCMState::CtrEncrypt => {
366                let buf = self.buf.take().unwrap();
367                let (aad_offset, message_offset, message_len) = self.pos.get();
368                let tag_offset = (message_offset / AES128_BLOCK_SIZE) * AES128_BLOCK_SIZE;
369                let copy_offset = (message_offset / AES128_BLOCK_SIZE).max(1) * AES128_BLOCK_SIZE;
370
371                if self.encrypting.get() {
372                    // Check the mac
373                    let mut mac = self.mac.take().unwrap();
374                    mac.update_padded(
375                        &crypt_buf[(message_offset + AES128_BLOCK_SIZE)
376                            ..(message_offset + message_len + AES128_BLOCK_SIZE)],
377                    );
378
379                    buf[0..message_len]
380                        .copy_from_slice(&crypt_buf[copy_offset..(copy_offset + message_len)]);
381
382                    let associated_data_bits = ((message_offset - aad_offset) as u64) * 8;
383                    let buffer_bits = (message_len as u64) * 8;
384
385                    let mut block = ghash::Block::default();
386                    block[..8].copy_from_slice(&associated_data_bits.to_be_bytes());
387                    block[8..].copy_from_slice(&buffer_bits.to_be_bytes());
388                    mac.update(&block);
389
390                    let mut tag = mac.finalize().into_bytes();
391
392                    for i in 0..AES128_BLOCK_SIZE {
393                        tag[i] ^= crypt_buf[tag_offset + i];
394                    }
395
396                    buf[(message_offset + message_len)
397                        ..(message_offset + message_len + AES128_BLOCK_SIZE)]
398                        .copy_from_slice(&tag);
399                } else {
400                    buf[0..message_len]
401                        .copy_from_slice(&crypt_buf[copy_offset..(copy_offset + message_len)]);
402                }
403
404                self.aes.disable();
405                self.crypt_buf.replace(crypt_buf);
406                self.state.set(GCMState::Idle);
407                self.gcm_client.map(move |client| {
408                    client.crypt_done(buf, Ok(()), true);
409                });
410            }
411        }
412    }
413}