virtio/devices/
virtio_rng.rs

1// Licensed under the Apache License, Version 2.0 or the MIT License.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3// Copyright Tock Contributors 2022.
4
5use core::cell::Cell;
6
7use kernel::deferred_call::{DeferredCall, DeferredCallClient};
8use kernel::hil::rng::{Client as RngClient, Continue as RngCont, Rng};
9use kernel::utilities::cells::OptionalCell;
10use kernel::ErrorCode;
11
12use super::super::devices::{VirtIODeviceDriver, VirtIODeviceType};
13use super::super::queues::split_queue::{SplitVirtqueue, SplitVirtqueueClient, VirtqueueBuffer};
14
15pub struct VirtIORng<'a, 'b> {
16    virtqueue: &'a SplitVirtqueue<'a, 'b, 1>,
17    buffer_capacity: Cell<usize>,
18    callback_pending: Cell<bool>,
19    deferred_call: DeferredCall,
20    client: OptionalCell<&'a dyn RngClient>,
21}
22
23impl<'a, 'b> VirtIORng<'a, 'b> {
24    pub fn new(virtqueue: &'a SplitVirtqueue<'a, 'b, 1>) -> VirtIORng<'a, 'b> {
25        VirtIORng {
26            virtqueue,
27            buffer_capacity: Cell::new(0),
28            callback_pending: Cell::new(false),
29            deferred_call: DeferredCall::new(),
30            client: OptionalCell::empty(),
31        }
32    }
33
34    pub fn provide_buffer(&self, buf: &'b mut [u8]) -> Result<usize, (&'b mut [u8], ErrorCode)> {
35        let len = buf.len();
36        if len < 4 {
37            // We don't yet support merging of randomness of multiple buffers
38            //
39            // Allowing a buffer with less than 4 elements will cause
40            // the callback to never be called, while the buffer is
41            // reinserted into the queue
42            return Err((buf, ErrorCode::INVAL));
43        }
44
45        let mut buffer_chain = [Some(VirtqueueBuffer {
46            buf,
47            len,
48            device_writeable: true,
49        })];
50
51        let res = self.virtqueue.provide_buffer_chain(&mut buffer_chain);
52
53        match res {
54            Err(ErrorCode::NOMEM) => {
55                // Hand back the buffer, the queue MUST NOT write partial
56                // buffer chains
57                let buf = buffer_chain[0].take().unwrap().buf;
58                Err((buf, ErrorCode::NOMEM))
59            }
60            Err(e) => panic!("Unexpected error {:?}", e),
61            Ok(()) => {
62                let mut cap = self.buffer_capacity.get();
63                cap += len;
64                self.buffer_capacity.set(cap);
65                Ok(cap)
66            }
67        }
68    }
69
70    fn buffer_chain_callback(
71        &self,
72        buffer_chain: &mut [Option<VirtqueueBuffer<'b>>],
73        bytes_used: usize,
74    ) {
75        // Disable further callbacks, until we're sure we need them
76        //
77        // The used buffers should stay in the queue until a client is
78        // ready to consume them
79        self.virtqueue.disable_used_callbacks();
80
81        // We only have buffer chains of a single buffer
82        let buf = buffer_chain[0].take().unwrap().buf;
83
84        // We have taken out a buffer, hence decrease the available capacity
85        assert!(self.buffer_capacity.get() >= buf.len());
86
87        // It could've happened that we don't require the callback any
88        // more, hence check beforehand
89        let cont = if self.callback_pending.get() {
90            // The callback is no longer pending
91            self.callback_pending.set(false);
92
93            let mut u32randiter = buf[0..bytes_used].chunks(4).filter_map(|slice| {
94                if slice.len() < 4 {
95                    None
96                } else {
97                    Some(u32::from_le_bytes([slice[0], slice[1], slice[2], slice[3]]))
98                }
99            });
100
101            // For now we don't use left-over randomness and assume the
102            // client has consumed the entire iterator
103            self.client
104                .map(|client| client.randomness_available(&mut u32randiter, Ok(())))
105                .unwrap_or(RngCont::Done)
106        } else {
107            RngCont::Done
108        };
109
110        if let RngCont::More = cont {
111            // Returning more is the equivalent of calling .get() on
112            // the Rng trait.
113
114            // TODO: what if this call fails?
115            let _ = self.get();
116        }
117
118        // In any case, reinsert the buffer for further processing
119        self.provide_buffer(buf).expect("Buffer reinsertion failed");
120    }
121}
122
123impl<'a> Rng<'a> for VirtIORng<'a, '_> {
124    fn get(&self) -> Result<(), ErrorCode> {
125        // Minimum buffer capacity must be 4 bytes for a single 32-bit
126        // word
127        if self.buffer_capacity.get() < 4 {
128            Err(ErrorCode::FAIL)
129        } else if self.client.is_none() {
130            Err(ErrorCode::FAIL)
131        } else if self.callback_pending.get() {
132            Err(ErrorCode::OFF)
133        } else if self.virtqueue.used_descriptor_chains_count() < 1 {
134            // There is no buffer ready in the queue, so let's rely
135            // purely on queue callbacks to notify us of the next
136            // incoming one
137            self.callback_pending.set(true);
138            self.virtqueue.enable_used_callbacks();
139            Ok(())
140        } else {
141            // There is a buffer in the virtqueue, get it and return
142            // it to a client in a deferred call
143            self.callback_pending.set(true);
144            self.deferred_call.set();
145            Ok(())
146        }
147    }
148
149    fn cancel(&self) -> Result<(), ErrorCode> {
150        // Cancel by setting the callback_pending flag to false which
151        // MUST be checked prior to every callback
152        self.callback_pending.set(false);
153
154        // For efficiency reasons, also unsubscribe from the virtqueue
155        // callbacks, which will let the buffers remain in the queue
156        // for future use
157        self.virtqueue.disable_used_callbacks();
158
159        Ok(())
160    }
161
162    fn set_client(&self, client: &'a dyn RngClient) {
163        self.client.set(client);
164    }
165}
166
167impl<'b> SplitVirtqueueClient<'b> for VirtIORng<'_, 'b> {
168    fn buffer_chain_ready(
169        &self,
170        _queue_number: u32,
171        buffer_chain: &mut [Option<VirtqueueBuffer<'b>>],
172        bytes_used: usize,
173    ) {
174        self.buffer_chain_callback(buffer_chain, bytes_used)
175    }
176}
177
178impl DeferredCallClient for VirtIORng<'_, '_> {
179    fn register(&'static self) {
180        self.deferred_call.register(self);
181    }
182
183    fn handle_deferred_call(&self) {
184        // Try to extract a descriptor chain
185        if let Some((mut chain, bytes_used)) = self.virtqueue.pop_used_buffer_chain() {
186            self.buffer_chain_callback(&mut chain, bytes_used)
187        } else {
188            // If we don't get a buffer, this must be a race condition
189            // which should not occur
190            //
191            // Prior to setting a deferred call, all virtqueue
192            // interrupts must be disabled so that no used buffer is
193            // removed before the deferred call callback
194            panic!("VirtIO RNG: deferred call callback with empty queue");
195        }
196    }
197}
198
199impl VirtIODeviceDriver for VirtIORng<'_, '_> {
200    fn negotiate_features(&self, _offered_features: u64) -> Option<u64> {
201        // We don't support any special features and do not care about
202        // what the device offers.
203        Some(0)
204    }
205
206    fn device_type(&self) -> VirtIODeviceType {
207        VirtIODeviceType::EntropySource
208    }
209}