diff --git a/virtio/src/device/balloon.rs b/virtio/src/device/balloon.rs index 7315662499fe1fc711dec3691aca35838a5d10c2..a44f2daa7f1768f72f14871c9241bac7024bae38 100644 --- a/virtio/src/device/balloon.rs +++ b/virtio/src/device/balloon.rs @@ -14,7 +14,7 @@ use std::mem::size_of; use std::os::unix::io::{AsRawFd, RawFd}; use std::rc::Rc; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, OnceLock}; use std::{ cmp::{self, Reverse}, time::Duration, @@ -70,7 +70,7 @@ const MEM_BUFFER_PERCENT_MAX: u32 = 80; const MONITOR_INTERVAL_SECOND_MIN: u32 = 5; const MONITOR_INTERVAL_SECOND_MAX: u32 = 300; -static mut BALLOON_DEV: Option>> = None; +static BALLOON_DEV: OnceLock>> = OnceLock::new(); /// IO vector, used to find memory segments. #[derive(Clone, Copy, Default)] @@ -726,9 +726,7 @@ impl BalloonIoHandler { } let req = Request::parse(&elem, OUT_IOVEC) .with_context(|| "Fail to parse available descriptor chain")?; - // SAFETY: There is no confliction when writing global variable BALLOON_DEV, in other - // words, this function will not be called simultaneously. - if let Some(dev) = unsafe { BALLOON_DEV.as_ref() } { + if let Some(dev) = BALLOON_DEV.get() { let mut balloon_dev = dev.lock().unwrap(); for iov in req.iovec.iter() { if let Some(stat) = iov_to_buf::(&self.mem_space, iov, 0) { @@ -1012,13 +1010,7 @@ impl Balloon { /// Init balloon object for global use. pub fn object_init(dev: Arc>) { - // SAFETY: there is no confliction when writing global variable BALLOON_DEV, in other - // words, this function will not be called simultaneously. - unsafe { - if BALLOON_DEV.is_none() { - BALLOON_DEV = Some(dev) - } - } + BALLOON_DEV.get_or_init(|| dev); } /// Notify configuration changes to VM. @@ -1224,9 +1216,7 @@ impl VirtioDevice for Balloon { } pub fn qmp_balloon(target: u64) -> bool { - // SAFETY: there is no confliction when writing global variable BALLOON_DEV, in other - // words, this function will not be called simultaneously. - if let Some(dev) = unsafe { BALLOON_DEV.as_ref() } { + if let Some(dev) = BALLOON_DEV.get() { match dev.lock().unwrap().set_guest_memory_size(target) { Ok(()) => { return true; @@ -1242,9 +1232,7 @@ pub fn qmp_balloon(target: u64) -> bool { } pub fn qmp_query_balloon() -> Option { - // SAFETY: There is no confliction when writing global variable BALLOON_DEV, in other - // words, this function will not be called simultaneously. - if let Some(dev) = unsafe { BALLOON_DEV.as_ref() } { + if let Some(dev) = BALLOON_DEV.get() { let unlocked_dev = dev.lock().unwrap(); return Some(unlocked_dev.get_guest_memory_size()); }