diff --git a/Cargo.lock b/Cargo.lock index f4c523c63c3d505cb527f107e3631a27fdc687fd..36c8f841879cdb9a30fb1911130c24957782e5ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,6 +28,7 @@ dependencies = [ "nix 0.26.2", "once_cell", "thiserror", + "trace", "util", "vmm-sys-util", ] @@ -48,7 +49,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8512c9117059663fb5606788fbca3619e2a91dac0e3fe516242eab1fa6be5e44" dependencies = [ "alsa-sys", - "bitflags", + "bitflags 1.3.2", "libc", "nix 0.24.3", ] @@ -82,7 +83,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ba16453d10c712284061a05f6510f75abeb92b56ba88dfeb48c74775020cc22" dependencies = [ "atk-sys", - "bitflags", + "bitflags 1.3.2", "glib", "libc", ] @@ -117,7 +118,7 @@ version = "0.65.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfdf7b466f9a4903edc73f95d6d2bcd5baf8ae620638762244d3f60143643cc5" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cexpr", "clang-sys", "lazy_static", @@ -140,6 +141,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" + [[package]] name = "bitintr" version = "0.3.0" @@ -199,7 +206,7 @@ version = "0.17.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab3603c4028a5e368d09b51c8b624b9a46edcd7c3778284077a6125af73c9f0a" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cairo-sys-rs", "glib", "libc", @@ -279,7 +286,7 @@ version = "4.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f13b9c79b5d1dd500d20ef541215a6423c75829ef43117e1b4d17fd8af0b5d76" dependencies = [ - "bitflags", + "bitflags 1.3.2", "clap_derive", "clap_lex", "once_cell", @@ -307,6 +314,18 @@ dependencies = [ "os_str_bytes", ] +[[package]] +name = "code_generator" +version = "2.4.0" +dependencies = [ + "proc-macro2", + "quote", + "regex", + "serde", + "syn 2.0.18", + "toml", +] + [[package]] name = "const_format" version = "0.2.31" @@ -368,6 +387,7 @@ dependencies = [ "clap", "cpu", "drm-fourcc", + "hisysevent", "libc", "libpulse-binding", "libpulse-simple-binding", @@ -381,6 +401,8 @@ dependencies = [ "rusb", "serde", "serde_json", + "strum", + "strum_macros", "thiserror", "trace", "ui", @@ -492,7 +514,7 @@ version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be1df5ea52cccd7e3a0897338b5564968274b52f5fd12601e0afa44f454c74d3" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cairo-rs", "gdk-pixbuf", "gdk-sys", @@ -508,7 +530,7 @@ version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b023fbe0c6b407bd3d9805d107d9800da3829dc5a676653210f1d5f16d7f59bf" dependencies = [ - "bitflags", + "bitflags 1.3.2", "gdk-pixbuf-sys", "gio", "glib", @@ -583,7 +605,7 @@ version = "0.17.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d14522e56c6bcb6f7a3aebc25cbcfb06776af4c0c25232b601b4383252d7cb92" dependencies = [ - "bitflags", + "bitflags 1.3.2", "futures-channel", "futures-core", "futures-io", @@ -616,7 +638,7 @@ version = "0.17.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7f1de7cbde31ea4f0a919453a2dcece5d54d5b70e08f8ad254dc4840f5f09b6" dependencies = [ - "bitflags", + "bitflags 1.3.2", "futures-channel", "futures-core", "futures-executor", @@ -682,7 +704,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c4222ab92b08d4d0bab90ddb6185b4e575ceeea8b8cdf00b938d7b6661d966" dependencies = [ "atk", - "bitflags", + "bitflags 1.3.2", "cairo-rs", "field-offset", "futures-channel", @@ -748,6 +770,17 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hisysevent" +version = "2.4.0" +dependencies = [ + "anyhow", + "code_generator", + "lazy_static", + "libloading", + "log", +] + [[package]] name = "hypervisor" version = "2.4.0" @@ -785,7 +818,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b7b36074613a723279637061b40db993208908a94f10ccb14436ce735bc0f57" dependencies = [ - "bitflags", + "bitflags 1.3.2", "libc", ] @@ -806,19 +839,20 @@ dependencies = [ [[package]] name = "kvm-bindings" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efe70e65a5b092161d17f5005b66e5eefe7a94a70c332e755036fc4af78c4e79" +checksum = "081fbd8164229a990fbf24a1f35d287740db110c2b5d42addf460165f1b0e032" dependencies = [ "vmm-sys-util", ] [[package]] name = "kvm-ioctls" -version = "0.15.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9bdde2b46ee7b6587ef79f751019c4726c4f2d3e4628df5d69f3f9c5cb6c6bd4" +checksum = "9002dff009755414f22b962ec6ae6980b07d6d8b06e5297b1062019d72bd6a8c" dependencies = [ + "bitflags 2.5.0", "kvm-bindings", "libc", "vmm-sys-util", @@ -858,7 +892,7 @@ version = "2.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1745b20bfc194ac12ef828f144f0ec2d4a7fe993281fa3567a0bd4969aee6890" dependencies = [ - "bitflags", + "bitflags 1.3.2", "libc", "libpulse-sys", "num-derive", @@ -902,9 +936,9 @@ dependencies = [ [[package]] name = "libusb1-sys" -version = "0.6.4" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9d0e2afce4245f2c9a418511e5af8718bcaf2fa408aefb259504d1a9cb25f27" +checksum = "17f6bace2f39082e9787c851afce469e7b2fe0f1cc64bbc68ca96653b63d8f17" dependencies = [ "cc", "libc", @@ -1072,7 +1106,7 @@ version = "0.24.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa52e972a9a719cecb6864fb88568781eb706bac2cd1d4f04a648542dbf78069" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cfg-if", "libc", ] @@ -1083,7 +1117,7 @@ version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cfg-if", "libc", "memoffset 0.7.1", @@ -1189,7 +1223,7 @@ version = "0.17.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52c280b82a881e4208afb3359a8e7fde27a1b272280981f1f34610bed5770d37" dependencies = [ - "bitflags", + "bitflags 1.3.2", "gio", "glib", "libc", @@ -1574,6 +1608,7 @@ name = "stratovirt" version = "2.4.0" dependencies = [ "anyhow", + "hisysevent", "log", "machine", "machine_manager", @@ -1726,6 +1761,7 @@ version = "2.4.0" dependencies = [ "anyhow", "lazy_static", + "libloading", "log", "regex", "trace_generator", @@ -1853,6 +1889,7 @@ dependencies = [ "address_space", "anyhow", "byteorder", + "clap", "devices", "hypervisor", "kvm-bindings", @@ -1901,11 +1938,11 @@ dependencies = [ [[package]] name = "vmm-sys-util" -version = "0.11.1" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd64fe09d8e880e600c324e7d664760a17f56e9672b7495a86381b49e4f72f46" +checksum = "1d1435039746e20da4f8d507a72ee1b916f7b4b05af7a91c093d2c6561934ede" dependencies = [ - "bitflags", + "bitflags 1.3.2", "libc", ] diff --git a/Cargo.toml b/Cargo.toml index a82f9b7914906205b86ae1fd4a4f0b3f96bdaaa1..4c78415f7a98ca013c4d9aa334841e3bb31e8225 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ machine = { path = "machine" } machine_manager = { path = "machine_manager" } util = { path = "util" } trace = { path = "trace" } +hisysevent = { path = "hisysevent" } [workspace] members = [ @@ -39,9 +40,18 @@ vnc_auth = ["machine/vnc_auth"] ohui_srv = ["machine/ohui_srv"] ramfb = ["machine/ramfb"] virtio_gpu = ["machine/virtio_gpu"] -trace_to_logger = ["trace/trace_to_logger"] -trace_to_ftrace = ["trace/trace_to_ftrace"] -trace_to_hitrace = ["trace/trace_to_hitrace"] +trace_to_logger = ["trace/trace_to_logger", "machine/trace_to_logger"] +trace_to_ftrace = ["trace/trace_to_ftrace", "machine/trace_to_ftrace"] +trace_to_hitrace = ["trace/trace_to_hitrace", "machine/trace_to_hitrace"] +hisysevent = ["hisysevent/hisysevent"] +vfio = ["machine/vfio_device"] +usb_uas = ["machine/usb_uas"] +virtio_rng = ["machine/virtio_rng"] +virtio_scsi = ["machine/virtio_scsi"] +vhost_vsock = ["machine/vhost_vsock"] +vhostuser_block = ["machine/vhostuser_block"] +vhostuser_net = ["machine/vhostuser_net"] +vhost_net = ["machine/vhost_net"] [package.metadata.rpm.cargo] buildflags = ["--release"] @@ -55,3 +65,4 @@ panic = "abort" [profile.release] panic = "abort" lto = true +debug = true diff --git a/_typos.toml b/_typos.toml index 1d161a7f2e07c03f29db87bb48ff7f5c025f9f3f..9bbf0e694252c678de7726faf1fc0a25392c923f 100644 --- a/_typos.toml +++ b/_typos.toml @@ -18,6 +18,8 @@ RTC_MIS = "RTC_MIS" SECCOMP_FILETER_FLAG_TSYNC = "SECCOMP_FILETER_FLAG_TSYNC" test_ths = "test_ths" UART_LSR_THRE = "UART_LSR_THRE" +closID = "closID" +CLOS = "CLOS" [default.extend-words] ba = "ba" diff --git a/acpi/src/acpi_device.rs b/acpi/src/acpi_device.rs index aff03b84d8a65211f3dc9788a8395e7ff80667a9..d315fc40bd5a383b4270fa51e2ec3f1795cb6c6c 100644 --- a/acpi/src/acpi_device.rs +++ b/acpi/src/acpi_device.rs @@ -52,7 +52,7 @@ impl AcpiPMTimer { } let now = Instant::now(); let time_nanos = now.duration_since(self.start).as_nanos(); - let counter: u128 = (time_nanos * PM_TIMER_FREQUENCY) / (NANOSECONDS_PER_SECOND as u128); + let counter: u128 = (time_nanos * PM_TIMER_FREQUENCY) / u128::from(NANOSECONDS_PER_SECOND); data.copy_from_slice(&((counter & 0xFFFF_FFFF) as u32).to_le_bytes()); true diff --git a/acpi/src/aml_compiler.rs b/acpi/src/aml_compiler.rs index 2ad8d51ae662482c4f8a534eb3754191fdc80652..f97216ae2eae04e4963fd37d419ae0650fd2ea52 100644 --- a/acpi/src/aml_compiler.rs +++ b/acpi/src/aml_compiler.rs @@ -1616,7 +1616,7 @@ impl AmlIrqNoFlags { impl AmlBuilder for AmlIrqNoFlags { fn aml_bytes(&self) -> Vec { - let irq_mask = 1 << (self.irq as u16); + let irq_mask = 1 << u16::from(self.irq); vec![0x22, (irq_mask & 0xFF) as u8, (irq_mask >> 8) as u8] } } @@ -1819,7 +1819,7 @@ mod test { let elem4 = AmlFieldUnit::new(Some("FLD3"), 4); let elem5 = AmlFieldUnit::new(Some("FLD4"), 12); - for e in vec![elem1, elem2, elem3, elem4, elem5] { + for e in [elem1, elem2, elem3, elem4, elem5] { field.append_child(e); } @@ -1991,7 +1991,7 @@ mod test { if_scope1.append_child(AmlReturn::new()); method1.append_child(if_scope1); - let store1 = AmlStore::new(AmlArg(0), AmlLocal(0).clone()); + let store1 = AmlStore::new(AmlArg(0), AmlLocal(0)); method1.append_child(store1); let mut while_scope = AmlWhile::new(AmlLLess::new(AmlLocal(0), AmlArg(1))); @@ -2031,7 +2031,7 @@ mod test { method2.append_child(store2); let mut pkg1 = AmlPackage::new(3); - vec![0x01, 0x03F8, 0x03FF].iter().for_each(|&x| { + [0x01, 0x03F8, 0x03FF].iter().for_each(|&x| { pkg1.append_child(AmlInteger(x as u64)); }); let named_pkg1 = AmlNameDecl::new("PKG1", pkg1); diff --git a/acpi/src/table_loader.rs b/acpi/src/table_loader.rs index 0a58c6de6e53ccef3919daffe619388058b52f2d..e968e4d07a865a192171b4d5d641264f5d01d270 100644 --- a/acpi/src/table_loader.rs +++ b/acpi/src/table_loader.rs @@ -356,7 +356,7 @@ impl TableLoader { dst_file_entry.file_blob.lock().unwrap() [offset as usize..(offset as usize + size as usize)] - .copy_from_slice(&(src_offset as u64).as_bytes()[0..size as usize]); + .copy_from_slice(&u64::from(src_offset).as_bytes()[0..size as usize]); self.cmds.push(TableLoaderEntry::new_add_pointer_entry( dst_file, src_file, offset, size, @@ -382,7 +382,7 @@ mod test { let file_bytes = file_name.as_bytes(); // SATETY: The "alloc" field of union consists of u8 members, so the access is safe. - let alloc = unsafe { &table_loader.cmds.get(0).unwrap().entry.alloc }; + let alloc = unsafe { &table_loader.cmds.first().unwrap().entry.alloc }; assert_eq!( alloc.file[0..file_bytes.len()].to_vec(), file_bytes.to_vec() @@ -450,7 +450,7 @@ mod test { .add_cksum_entry(&file, 0_u32, 0_u32, file_len + 1) .is_err()); assert!(table_loader - .add_cksum_entry(&file, (file_len - 1) as u32, 80, 20) + .add_cksum_entry(&file, file_len - 1, 80, 20) .is_ok()); assert!(table_loader .add_cksum_entry(&file, file_len - 1, 0, 50) diff --git a/address_space/Cargo.toml b/address_space/Cargo.toml index 7ff40882a51df2b5fa33d5d705017bb6aff11656..c77820d3f8cd796b2ff0b588a86321e571dddd8c 100644 --- a/address_space/Cargo.toml +++ b/address_space/Cargo.toml @@ -10,7 +10,7 @@ description = "provide memory management for VM" libc = "0.2" log = "0.4" nix = { version = "0.26.2", default-features = false, features = ["fs", "feature"] } -vmm-sys-util = "0.11.1" +vmm-sys-util = "0.12.1" arc-swap = "1.6.0" thiserror = "1.0" anyhow = "1.0" @@ -19,3 +19,4 @@ machine_manager = { path = "../machine_manager" } migration = { path = "../migration" } migration_derive = { path = "../migration/migration_derive" } util = { path = "../util" } +trace = { path = "../trace" } diff --git a/address_space/src/address.rs b/address_space/src/address.rs index a5ace6f03a372f0dc7cfbc05b06047ac668d73fd..5d1d02b670711b2effedf950002a328f1d7d1a97 100644 --- a/address_space/src/address.rs +++ b/address_space/src/address.rs @@ -15,6 +15,15 @@ use std::ops::{BitAnd, BitOr}; use util::num_ops::{round_down, round_up}; +#[derive(PartialEq, Eq)] +pub enum AddressAttr { + Ram, + MMIO, + RamDevice, + RomDevice, + RomDeviceForce, +} + /// Represent the address in given address space. #[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct GuestAddress(pub u64); @@ -166,16 +175,17 @@ impl AddressRange { /// /// * `other` - Other AddressRange. pub fn find_intersection(&self, other: AddressRange) -> Option { - let begin = self.base.raw_value() as u128; - let end = self.size as u128 + begin; - let other_begin = other.base.raw_value() as u128; - let other_end = other.size as u128 + other_begin; + let begin = u128::from(self.base.raw_value()); + let end = u128::from(self.size) + begin; + let other_begin = u128::from(other.base.raw_value()); + let other_end = u128::from(other.size) + other_begin; if end <= other_begin || other_end <= begin { return None; } let start = std::cmp::max(self.base, other.base); - let size_inter = (std::cmp::min(end, other_end) - start.0 as u128) as u64; + // SAFETY: The range of a region will not exceed 64 bits. + let size_inter = (std::cmp::min(end, other_end) - u128::from(start.0)) as u64; Some(AddressRange { base: start, diff --git a/address_space/src/address_space.rs b/address_space/src/address_space.rs index ae18dd86c1367f0e25d4c438f5447c1ac1b8f549..95cad2ac921a774de08029dd1ff1fae781a30a38 100644 --- a/address_space/src/address_space.rs +++ b/address_space/src/address_space.rs @@ -12,19 +12,17 @@ use std::fmt; use std::fmt::Debug; -use std::io::Write; use std::sync::{Arc, Mutex}; -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, bail, Context, Result}; use arc_swap::ArcSwap; use log::error; use once_cell::sync::OnceCell; use crate::{ - AddressRange, AddressSpaceError, FlatRange, GuestAddress, Listener, ListenerReqType, Region, - RegionIoEventFd, RegionType, + AddressAttr, AddressRange, AddressSpaceError, FlatRange, GuestAddress, Listener, + ListenerReqType, Region, RegionIoEventFd, RegionType, }; -use migration::{migration::Migratable, MigrationManager}; use util::aio::Iovec; use util::byte_code::ByteCode; @@ -41,10 +39,23 @@ impl FlatView { } } - fn read(&self, dst: &mut dyn std::io::Write, addr: GuestAddress, count: u64) -> Result<()> { + fn read( + &self, + dst: &mut dyn std::io::Write, + addr: GuestAddress, + count: u64, + attr: AddressAttr, + ) -> Result<()> { let mut len = count; let mut l = count; let mut start = addr; + let region_type = match attr { + AddressAttr::Ram => RegionType::Ram, + AddressAttr::MMIO => RegionType::IO, + AddressAttr::RamDevice => RegionType::RamDevice, + AddressAttr::RomDevice => RegionType::RomDevice, + AddressAttr::RomDeviceForce => RegionType::RomDevice, + }; loop { if let Some(fr) = self.find_flatrange(start) { @@ -53,6 +64,19 @@ impl FlatView { let region_base = fr.addr_range.base.unchecked_sub(fr.offset_in_region); let fr_remain = fr.addr_range.size - fr_offset; + if !util::test_helper::is_test_enabled() && fr.owner.region_type() != region_type { + // Read op RomDevice in I/O access mode as MMIO + if region_type == RegionType::IO + && fr.owner.region_type() == RegionType::RomDevice + { + if fr.owner.get_rom_device_romd().unwrap() { + bail!("mismatch region type") + } + } else { + bail!("mismatch region type") + } + } + if fr.owner.region_type() == RegionType::Ram || fr.owner.region_type() == RegionType::RamDevice { @@ -74,17 +98,42 @@ impl FlatView { } } - fn write(&self, src: &mut dyn std::io::Read, addr: GuestAddress, count: u64) -> Result<()> { + fn write( + &self, + src: &mut dyn std::io::Read, + addr: GuestAddress, + count: u64, + attr: AddressAttr, + ) -> Result<()> { let mut l = count; let mut len = count; let mut start = addr; + let region_type = match attr { + AddressAttr::Ram => RegionType::Ram, + AddressAttr::MMIO => RegionType::IO, + AddressAttr::RamDevice => RegionType::RamDevice, + AddressAttr::RomDeviceForce => RegionType::RomDevice, + _ => { + bail!("Error write attr") + } + }; loop { if let Some(fr) = self.find_flatrange(start) { let fr_offset = start.offset_from(fr.addr_range.base); let region_offset = fr.offset_in_region + fr_offset; let region_base = fr.addr_range.base.unchecked_sub(fr.offset_in_region); let fr_remain = fr.addr_range.size - fr_offset; + + // Read/Write ops to RomDevice is MMIO. + if !util::test_helper::is_test_enabled() + && fr.owner.region_type() != region_type + && !(region_type == RegionType::IO + && fr.owner.region_type() == RegionType::RomDevice) + { + bail!("mismatch region type") + } + if fr.owner.region_type() == RegionType::Ram || fr.owner.region_type() == RegionType::RamDevice { @@ -230,7 +279,7 @@ impl AddressSpace { } locked_listener.enable(); - let mut idx = 0; + let mut idx = 0_usize; let mut mls = self.listeners.lock().unwrap(); for ml in mls.iter() { if ml.lock().unwrap().priority() >= locked_listener.priority() { @@ -381,8 +430,8 @@ impl AddressSpace { /// * `new_evtfds` - New `RegionIoEventFd` array. fn update_ioeventfds_pass(&self, new_evtfds: &[RegionIoEventFd]) -> Result<()> { let old_evtfds = self.ioeventfds.lock().unwrap(); - let mut old_idx = 0; - let mut new_idx = 0; + let mut old_idx = 0_usize; + let mut new_idx = 0_usize; while old_idx < old_evtfds.len() || new_idx < new_evtfds.len() { let old_fd = old_evtfds.get(old_idx); @@ -450,19 +499,27 @@ impl AddressSpace { Ok(()) } - /// Return the host address according to the given `GuestAddress`. + /// Return the host address according to the given `GuestAddress`. It is dangerous to + /// read and write directly to hva. We strongly recommend that you use the read and + /// write interface provided by AddressSpace unless you know exactly what you need and + /// are sure it is safe. /// /// # Arguments /// /// * `addr` - Guest address. - pub fn get_host_address(&self, addr: GuestAddress) -> Option { + /// + /// # Safety + /// + /// Using this function, the caller needs to make it clear that hva is always in the ram + /// range of the virtual machine. And if you want to operate [hva,hva+size], the range + /// from hva to hva+size needs to be in the ram range. + pub unsafe fn get_host_address(&self, addr: GuestAddress, attr: AddressAttr) -> Option { let view = self.flat_view.load(); - view.find_flatrange(addr).and_then(|range| { let offset = addr.offset_from(range.addr_range.base); range .owner - .get_host_address() + .get_host_address(attr) .map(|host| host + range.offset_in_region + offset) }) } @@ -474,7 +531,7 @@ impl AddressSpace { /// * `addr` - Guest address. /// Return Error if the `addr` is not mapped. /// or return the HVA address and available mem length - pub fn addr_cache_init(&self, addr: GuestAddress) -> Option<(u64, u64)> { + pub fn addr_cache_init(&self, addr: GuestAddress, attr: AddressAttr) -> Option<(u64, u64)> { let view = self.flat_view.load(); if let Some(flat_range) = view.find_flatrange(addr) { @@ -484,12 +541,15 @@ impl AddressSpace { let region_remain = flat_range.owner.size() - region_offset; let fr_remain = flat_range.addr_range.size - fr_offset; - return flat_range.owner.get_host_address().map(|host| { - ( - host + region_offset, - std::cmp::min(fr_remain, region_remain), - ) - }); + // SAFETY: addr and size is in ram region. + return unsafe { + flat_range.owner.get_host_address(attr).map(|host| { + ( + host + region_offset, + std::cmp::min(fr_remain, region_remain), + ) + }) + }; } None @@ -503,6 +563,7 @@ impl AddressSpace { /// * `count` - Memory needed length pub fn get_address_map( &self, + cache: &Option, addr: GuestAddress, count: u64, res: &mut Vec, @@ -512,7 +573,7 @@ impl AddressSpace { loop { let io_vec = self - .addr_cache_init(start) + .get_host_address_from_cache(start, cache) .map(|(hva, fr_len)| Iovec { iov_base: hva, iov_len: std::cmp::min(len, fr_len), @@ -542,7 +603,7 @@ impl AddressSpace { cache: &Option, ) -> Option<(u64, u64)> { if cache.is_none() { - return self.addr_cache_init(addr); + return self.addr_cache_init(addr, AddressAttr::Ram); } let region_cache = cache.unwrap(); if addr.0 >= region_cache.start && addr.0 < region_cache.end { @@ -551,7 +612,7 @@ impl AddressSpace { region_cache.end - addr.0, )) } else { - self.addr_cache_init(addr) + self.addr_cache_init(addr, AddressAttr::Ram) } } @@ -569,13 +630,14 @@ impl AddressSpace { }) } - pub fn get_region_cache(&self, addr: GuestAddress) -> Option { + pub fn get_region_cache(&self, addr: GuestAddress, attr: AddressAttr) -> Option { let view = &self.flat_view.load(); if let Some(range) = view.find_flatrange(addr) { let reg_type = range.owner.region_type(); let start = range.addr_range.base.0; let end = range.addr_range.end_addr().0; - let host_base = self.get_host_address(GuestAddress(start)).unwrap_or(0); + // SAFETY: the size is in region range, and the type will be checked in get_host_address. + let host_base = unsafe { self.get_host_address(GuestAddress(start), attr) }?; let cache = RegionCache { reg_type, host_base, @@ -609,10 +671,17 @@ impl AddressSpace { /// # Errors /// /// Return Error if the `addr` is not mapped. - pub fn read(&self, dst: &mut dyn std::io::Write, addr: GuestAddress, count: u64) -> Result<()> { + pub fn read( + &self, + dst: &mut dyn std::io::Write, + addr: GuestAddress, + count: u64, + attr: AddressAttr, + ) -> Result<()> { + trace::address_space_read(&addr, count); let view = self.flat_view.load(); - view.read(dst, addr, count)?; + view.read(dst, addr, count, attr)?; Ok(()) } @@ -627,16 +696,25 @@ impl AddressSpace { /// # Errors /// /// Return Error if the `addr` is not mapped. - pub fn write(&self, src: &mut dyn std::io::Read, addr: GuestAddress, count: u64) -> Result<()> { + pub fn write( + &self, + src: &mut dyn std::io::Read, + addr: GuestAddress, + count: u64, + attr: AddressAttr, + ) -> Result<()> { + trace::address_space_write(&addr, count); let view = self.flat_view.load(); + let mut buf = Vec::new(); + src.read_to_end(&mut buf).unwrap(); + if !*self.hyp_ioevtfd_enabled.get_or_init(|| false) { let ioeventfds = self.ioeventfds.lock().unwrap(); - if let Ok(index) = ioeventfds - .as_slice() - .binary_search_by(|ioevtfd| ioevtfd.addr_range.base.cmp(&addr)) - { - let evtfd = &ioeventfds[index]; + for evtfd in ioeventfds.as_slice() { + if evtfd.addr_range.base != addr { + continue; + } if count == evtfd.addr_range.size || evtfd.addr_range.size == 0 { if !evtfd.data_match { if let Err(e) = evtfd.fd.write(1) { @@ -645,25 +723,27 @@ impl AddressSpace { return Ok(()); } - let mut buf = Vec::new(); - src.read_to_end(&mut buf).unwrap(); + let mut buf_temp = buf.clone(); - if buf.len() <= 8 { - let data = u64::from_bytes(buf.as_slice()).unwrap(); + if buf_temp.len() <= 8 { + buf_temp.resize(8, 0); + let data = u64::from_bytes(buf_temp.as_slice()).unwrap(); if *data == evtfd.data { if let Err(e) = evtfd.fd.write(1) { error!("Failed to write ioeventfd {:?}: {}", evtfd, e); } return Ok(()); + } else { + continue; } } - view.write(&mut buf.as_slice(), addr, count)?; + view.write(&mut buf_temp.as_slice(), addr, count, attr)?; return Ok(()); } } } - view.write(src, addr, count)?; + view.write(&mut buf.as_slice(), addr, count, attr)?; Ok(()) } @@ -676,30 +756,19 @@ impl AddressSpace { /// /// # Note /// To use this method, it is necessary to implement `ByteCode` trait for your object. - pub fn write_object(&self, data: &T, addr: GuestAddress) -> Result<()> { - self.write(&mut data.as_bytes(), addr, std::mem::size_of::() as u64) - .with_context(|| "Failed to write object") - } - - /// Write an object to memory via host address. - /// - /// # Arguments - /// - /// * `data` - The object that will be written to the memory. - /// * `host_addr` - The start host address where the object will be written to. - /// - /// # Note - /// To use this method, it is necessary to implement `ByteCode` trait for your object. - pub fn write_object_direct(&self, data: &T, host_addr: u64) -> Result<()> { - // Mark vmm dirty page manually if live migration is active. - MigrationManager::mark_dirty_log(host_addr, data.as_bytes().len() as u64); - - // SAFETY: The host addr is managed by memory space, it has been verified. - let mut dst = unsafe { - std::slice::from_raw_parts_mut(host_addr as *mut u8, std::mem::size_of::()) - }; - dst.write_all(data.as_bytes()) - .with_context(|| "Failed to write object via host address") + pub fn write_object( + &self, + data: &T, + addr: GuestAddress, + attr: AddressAttr, + ) -> Result<()> { + self.write( + &mut data.as_bytes(), + addr, + std::mem::size_of::() as u64, + attr, + ) + .with_context(|| "Failed to write object") } /// Read some data from memory to form an object. @@ -710,40 +779,21 @@ impl AddressSpace { /// /// # Note /// To use this method, it is necessary to implement `ByteCode` trait for your object. - pub fn read_object(&self, addr: GuestAddress) -> Result { + pub fn read_object(&self, addr: GuestAddress, attr: AddressAttr) -> Result { let mut obj = T::default(); self.read( &mut obj.as_mut_bytes(), addr, std::mem::size_of::() as u64, + attr, ) .with_context(|| "Failed to read object")?; Ok(obj) } - /// Read some data from memory to form an object via host address. - /// - /// # Arguments - /// - /// * `hoat_addr` - The start host address where the data will be read from. - /// - /// # Note - /// To use this method, it is necessary to implement `ByteCode` trait for your object. - pub fn read_object_direct(&self, host_addr: u64) -> Result { - let mut obj = T::default(); - let mut dst = obj.as_mut_bytes(); - // SAFETY: host_addr is managed by address_space, it has been verified for legality. - let src = unsafe { - std::slice::from_raw_parts_mut(host_addr as *mut u8, std::mem::size_of::()) - }; - dst.write_all(src) - .with_context(|| "Failed to read object via host address")?; - - Ok(obj) - } - /// Update the topology of memory. pub fn update_topology(&self) -> Result<()> { + trace::trace_scope_start!(address_update_topology); let old_fv = self.flat_view.load(); let addr_range = AddressRange::new(GuestAddress(0), self.root.size()); @@ -775,7 +825,7 @@ mod test { use vmm_sys_util::eventfd::EventFd; use super::*; - use crate::{HostMemMapping, RegionOps}; + use crate::{AddressAttr, HostMemMapping, RegionOps}; #[derive(Default, Clone)] struct TestListener { @@ -920,10 +970,10 @@ mod test { let listener3 = Arc::new(Mutex::new(ListenerPrior3::default())); let listener4 = Arc::new(Mutex::new(ListenerPrior4::default())); let listener5 = Arc::new(Mutex::new(ListenerNeg::default())); - space.register_listener(listener1.clone()).unwrap(); + space.register_listener(listener1).unwrap(); space.register_listener(listener3.clone()).unwrap(); - space.register_listener(listener5.clone()).unwrap(); - space.register_listener(listener2.clone()).unwrap(); + space.register_listener(listener5).unwrap(); + space.register_listener(listener2).unwrap(); space.register_listener(listener4.clone()).unwrap(); let mut pre_prior = std::i32::MIN; @@ -968,13 +1018,13 @@ mod test { let space = AddressSpace::new(root, "space", None).unwrap(); let listener1 = Arc::new(Mutex::new(ListenerPrior0::default())); let listener2 = Arc::new(Mutex::new(ListenerPrior0::default())); - space.register_listener(listener1.clone()).unwrap(); + space.register_listener(listener1).unwrap(); space.register_listener(listener2.clone()).unwrap(); space.unregister_listener(listener2).unwrap(); assert_eq!(space.listeners.lock().unwrap().len(), 1); for listener in space.listeners.lock().unwrap().iter() { - assert_eq!(listener.lock().unwrap().enabled(), true); + assert!(listener.lock().unwrap().enabled()); } } @@ -1014,7 +1064,7 @@ mod test { .reqs .lock() .unwrap() - .get(0) + .first() .unwrap() .1, AddressRange::new(region_c.offset(), region_c.size()) @@ -1037,7 +1087,7 @@ mod test { assert_eq!(locked_listener.reqs.lock().unwrap().len(), 4); // delete flat-range 0~6000 first, belonging to region_c assert_eq!( - locked_listener.reqs.lock().unwrap().get(0).unwrap().1, + locked_listener.reqs.lock().unwrap().first().unwrap().1, AddressRange::new(region_c.offset(), region_c.size()) ); // add range 0~2000, belonging to region_c @@ -1187,16 +1237,16 @@ mod test { ram2.start_address().unchecked_add(ram2.size()) ); assert!(space.address_in_memory(GuestAddress(0), 0)); - assert_eq!(space.address_in_memory(GuestAddress(1000), 0), false); - assert_eq!(space.address_in_memory(GuestAddress(1500), 0), false); + assert!(!space.address_in_memory(GuestAddress(1000), 0)); + assert!(!space.address_in_memory(GuestAddress(1500), 0)); assert!(space.address_in_memory(GuestAddress(2900), 0)); assert_eq!( - space.get_host_address(GuestAddress(500)), + unsafe { space.get_host_address(GuestAddress(500), AddressAttr::Ram) }, Some(ram1.host_address() + 500) ); assert_eq!( - space.get_host_address(GuestAddress(2500)), + unsafe { space.get_host_address(GuestAddress(2500), AddressAttr::Ram) }, Some(ram2.host_address() + 500) ); @@ -1217,18 +1267,22 @@ mod test { ram2.start_address().unchecked_add(ram2.size()) ); assert!(space.address_in_memory(GuestAddress(0), 0)); - assert_eq!(space.address_in_memory(GuestAddress(1000), 0), false); - assert_eq!(space.address_in_memory(GuestAddress(1500), 0), false); - assert_eq!(space.address_in_memory(GuestAddress(2400), 0), false); + assert!(!space.address_in_memory(GuestAddress(1000), 0)); + assert!(!space.address_in_memory(GuestAddress(1500), 0)); + assert!(!space.address_in_memory(GuestAddress(2400), 0)); assert!(space.address_in_memory(GuestAddress(2900), 0)); assert_eq!( - space.get_host_address(GuestAddress(500)), + unsafe { space.get_host_address(GuestAddress(500), AddressAttr::Ram) }, Some(ram1.host_address() + 500) ); - assert!(space.get_host_address(GuestAddress(2400)).is_none()); + assert!(unsafe { + space + .get_host_address(GuestAddress(2400), AddressAttr::Ram) + .is_none() + }); assert_eq!( - space.get_host_address(GuestAddress(2500)), + unsafe { space.get_host_address(GuestAddress(2500), AddressAttr::Ram) }, Some(ram2.host_address() + 500) ); } @@ -1245,9 +1299,15 @@ mod test { .unwrap(); let data: u64 = 10000; - assert!(space.write_object(&data, GuestAddress(992)).is_ok()); - let data1: u64 = space.read_object(GuestAddress(992)).unwrap(); + assert!(space + .write_object(&data, GuestAddress(992), AddressAttr::Ram) + .is_ok()); + let data1: u64 = space + .read_object(GuestAddress(992), AddressAttr::Ram) + .unwrap(); assert_eq!(data1, 10000); - assert!(space.write_object(&data, GuestAddress(993)).is_err()); + assert!(space + .write_object(&data, GuestAddress(993), AddressAttr::Ram) + .is_err()); } } diff --git a/address_space/src/host_mmap.rs b/address_space/src/host_mmap.rs index ed2ce2e9854600add884f48307da53b22300c9ba..d31d4fd34f4a37b940e0d53bc07a7b398938998d 100644 --- a/address_space/src/host_mmap.rs +++ b/address_space/src/host_mmap.rs @@ -25,12 +25,9 @@ use nix::unistd::{mkstemp, sysconf, unlink, SysconfVar}; use crate::{AddressRange, GuestAddress, Region}; use machine_manager::config::{HostMemPolicy, MachineMemConfig, MemZoneConfig}; -use util::{ - syscall::mbind, - unix::{do_mmap, host_page_size}, -}; +use util::unix::{do_mmap, host_page_size, mbind}; -const MAX_PREALLOC_THREAD: u8 = 16; +const MAX_PREALLOC_THREAD: i64 = 16; /// Verify existing pages in the mapping. const MPOL_MF_STRICT: u32 = 1; /// Move pages owned by this process to conform to mapping. @@ -59,9 +56,9 @@ impl FileBackend { /// # Arguments /// /// * `fd` - Opened backend file. - pub fn new_common(fd: File) -> Self { + pub fn new_common(fd: Arc) -> Self { Self { - file: Arc::new(fd), + file: fd, offset: 0, page_size: 0, } @@ -171,7 +168,8 @@ fn max_nr_threads(nr_vcpus: u8) -> u8 { return 1; } - min(min(nr_host_cpu as u8, MAX_PREALLOC_THREAD), nr_vcpus) + // MAX_PREALLOC_THREAD's value(16) is less than 255. + min(min(nr_host_cpu, MAX_PREALLOC_THREAD) as u8, nr_vcpus) } /// Touch pages to pre-alloc memory for VM. @@ -205,11 +203,12 @@ fn touch_pages(start: u64, page_size: u64, nr_pages: u64) { /// * `size` - Size of memory. /// * `nr_vcpus` - Number of vcpus. fn mem_prealloc(host_addr: u64, size: u64, nr_vcpus: u8) { + trace::trace_scope_start!(pre_alloc, args = (size)); let page_size = host_page_size(); let threads = max_nr_threads(nr_vcpus); let nr_pages = (size + page_size - 1) / page_size; - let pages_per_thread = nr_pages / (threads as u64); - let left = nr_pages % (threads as u64); + let pages_per_thread = nr_pages / u64::from(threads); + let left = nr_pages % u64::from(threads); let mut addr = host_addr; let mut threads_join = Vec::new(); for i in 0..threads { @@ -294,7 +293,7 @@ pub fn create_default_mem(mem_config: &MachineMemConfig, thread_num: u8) -> Resu pub fn create_backend_mem(mem_config: &MemZoneConfig, thread_num: u8) -> Result { let mut f_back: Option = None; - if mem_config.memfd { + if mem_config.memfd() { let anon_fd = memfd_create( &CString::new("stratovirt_anon_mem")?, MemFdCreateFlag::empty(), @@ -368,15 +367,20 @@ fn set_host_memory_policy(mem_mappings: &Arc, zone: &MemZoneConf nmask = vec![0_u64; max_node]; } - mbind( - host_addr_start, - zone.size, - policy as u32, - nmask, - max_node as u64, - MPOL_MF_STRICT | MPOL_MF_MOVE, - ) - .with_context(|| "Failed to call mbind")?; + // SAFETY: + // 1. addr is managed by memory mapping, it can be guaranteed legal. + // 2. node_mask was created in this function. + // 3. Upper limit of max_node is MAX_NODES. + unsafe { + mbind( + host_addr_start, + zone.size, + policy as u32, + nmask, + max_node as u64, + MPOL_MF_STRICT | MPOL_MF_MOVE, + )?; + } Ok(()) } diff --git a/address_space/src/lib.rs b/address_space/src/lib.rs index f9feddf38ed49e6b7818c0d00160d4bbe776d162..ae8760a4ce2ae9957a5d9edc4124df283b8bd026 100644 --- a/address_space/src/lib.rs +++ b/address_space/src/lib.rs @@ -19,7 +19,7 @@ //! use std::sync::{Arc, Mutex}; //! extern crate address_space; //! use address_space::{ -//! AddressSpace, FileBackend, GuestAddress, HostMemMapping, Region, RegionOps, +//! AddressAttr, AddressSpace, FileBackend, GuestAddress, HostMemMapping, Region, RegionOps, //! }; //! //! struct DummyDevice; @@ -76,7 +76,7 @@ //! space.root().add_subregion(io_region, 0x2000); //! //! // 5. access address_space -//! space.write_object(&0x11u64, GuestAddress(0)); +//! space.write_object(&0x11u64, GuestAddress(0), AddressAttr::Ram); //! } //! ``` @@ -90,7 +90,7 @@ mod region; mod state; pub use crate::address_space::{AddressSpace, RegionCache}; -pub use address::{AddressRange, GuestAddress}; +pub use address::{AddressAttr, AddressRange, GuestAddress}; pub use error::AddressSpaceError; pub use host_mmap::{create_backend_mem, create_default_mem, FileBackend, HostMemMapping}; pub use listener::{Listener, ListenerReqType, MemSlot}; diff --git a/address_space/src/region.rs b/address_space/src/region.rs index f5e7c77ab67055e047db43d368358d0e07b1a32a..48c70302a7454bd0a0971b850ee4c061cb0ea65d 100644 --- a/address_space/src/region.rs +++ b/address_space/src/region.rs @@ -21,8 +21,8 @@ use log::{debug, warn}; use crate::address_space::FlatView; use crate::{ - AddressRange, AddressSpace, AddressSpaceError, FileBackend, GuestAddress, HostMemMapping, - RegionOps, + AddressAttr, AddressRange, AddressSpace, AddressSpaceError, FileBackend, GuestAddress, + HostMemMapping, RegionOps, }; use migration::{migration::Migratable, MigrationManager}; @@ -201,7 +201,7 @@ macro_rules! rw_multi_ops { let offset = $args.offset; let cnt = $args.count; let access_size = $args.access_size; - let mut pos = 0; + let mut pos = 0_u64; for _ in 0..(cnt / access_size) { if !$ops( &mut $slice[pos as usize..(pos + access_size) as usize], @@ -467,11 +467,21 @@ impl Region { /// Get the host address if this region is backed by host-memory, /// Return `None` if it is not a Ram-type region. - pub fn get_host_address(&self) -> Option { - if self.region_type != RegionType::Ram - && self.region_type != RegionType::RamDevice - && self.region_type != RegionType::RomDevice - { + /// + /// # Safety + /// + /// Need to make it clear that hva is always in the ram range of the virtual machine. + /// And if you want to operate [hva,hva+size], the range from hva to hva+size needs + /// to be in the ram range. + pub unsafe fn get_host_address(&self, attr: AddressAttr) -> Option { + let region_type = match attr { + AddressAttr::Ram => RegionType::Ram, + AddressAttr::MMIO => return None, + AddressAttr::RamDevice => RegionType::RamDevice, + AddressAttr::RomDevice | AddressAttr::RomDeviceForce => RegionType::RomDevice, + }; + + if self.region_type != region_type { return None; } self.mem_mapping.as_ref().map(|r| r.host_address()) @@ -1155,7 +1165,7 @@ mod test { assert_eq!(&data, &mut res_data); assert_eq!( - ram_region.get_host_address().unwrap(), + unsafe { ram_region.get_host_address(AddressAttr::Ram).unwrap() }, mem_mapping.host_address() ); @@ -1208,7 +1218,7 @@ mod test { let mut device_locked = test_dev_clone.lock().unwrap(); device_locked.read(data, addr, offset) }; - let test_dev_clone = test_dev.clone(); + let test_dev_clone = test_dev; let write_ops = move |data: &[u8], addr: GuestAddress, offset: u64| -> bool { let mut device_locked = test_dev_clone.lock().unwrap(); device_locked.write(data, addr, offset) @@ -1219,7 +1229,7 @@ mod test { write: Arc::new(write_ops), }; - let io_region = Region::init_io_region(16, test_dev_ops.clone(), "io_region"); + let io_region = Region::init_io_region(16, test_dev_ops, "io_region"); let data = [0x01u8; 8]; let mut data_res = [0x0u8; 8]; let count = data.len() as u64; @@ -1235,7 +1245,7 @@ mod test { .is_ok()); assert_eq!(data.to_vec(), data_res.to_vec()); - assert!(io_region.get_host_address().is_none()); + assert!(unsafe { io_region.get_host_address(AddressAttr::Ram).is_none() }); } #[test] @@ -1299,7 +1309,7 @@ mod test { }; let io_region = Region::init_io_region(1 << 4, default_ops.clone(), "io1"); - let io_region2 = Region::init_io_region(1 << 4, default_ops.clone(), "io2"); + let io_region2 = Region::init_io_region(1 << 4, default_ops, "io2"); io_region2.set_priority(10); // add duplicate io-region or ram-region will fail @@ -1323,7 +1333,7 @@ mod test { .subregions .read() .unwrap() - .get(0) + .first() .unwrap() .priority(), 10 @@ -1364,9 +1374,9 @@ mod test { region_b.set_priority(2); region_c.set_priority(1); region_a.add_subregion(region_b.clone(), 2000).unwrap(); - region_a.add_subregion(region_c.clone(), 0).unwrap(); - region_b.add_subregion(region_d.clone(), 0).unwrap(); - region_b.add_subregion(region_e.clone(), 2000).unwrap(); + region_a.add_subregion(region_c, 0).unwrap(); + region_b.add_subregion(region_d, 0).unwrap(); + region_b.add_subregion(region_e, 2000).unwrap(); let addr_range = AddressRange::from((0u64, region_a.size())); let view = region_a @@ -1405,14 +1415,14 @@ mod test { let region_b = Region::init_container_region(5000, "region_b"); let region_c = Region::init_io_region(1000, default_ops.clone(), "regionc"); let region_d = Region::init_io_region(3000, default_ops.clone(), "region_d"); - let region_e = Region::init_io_region(2000, default_ops.clone(), "region_e"); + let region_e = Region::init_io_region(2000, default_ops, "region_e"); region_a.add_subregion(region_b.clone(), 2000).unwrap(); - region_a.add_subregion(region_c.clone(), 0).unwrap(); + region_a.add_subregion(region_c, 0).unwrap(); region_d.set_priority(2); region_e.set_priority(3); - region_b.add_subregion(region_d.clone(), 0).unwrap(); - region_b.add_subregion(region_e.clone(), 2000).unwrap(); + region_b.add_subregion(region_d, 0).unwrap(); + region_b.add_subregion(region_e, 2000).unwrap(); let addr_range = AddressRange::from((0u64, region_a.size())); let view = region_a diff --git a/address_space/src/state.rs b/address_space/src/state.rs index 6a6d81ff7984f3f55b9d70bfe186512e775ef869..2f28232589e8c46b22df5383008e75a6db56e17c 100644 --- a/address_space/src/state.rs +++ b/address_space/src/state.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use anyhow::{bail, Context, Result}; -use crate::{AddressSpace, FileBackend, GuestAddress, HostMemMapping, Region}; +use crate::{AddressAttr, AddressSpace, FileBackend, GuestAddress, HostMemMapping, Region}; use migration::{ error::MigrationError, DeviceStateDesc, FieldDesc, MemBlock, MigrationHook, StateTransfer, }; @@ -168,14 +168,14 @@ impl MigrationHook for AddressSpace { } fn send_memory(&self, fd: &mut dyn Write, range: MemBlock) -> Result<()> { - self.read(fd, GuestAddress(range.gpa), range.len) + self.read(fd, GuestAddress(range.gpa), range.len, AddressAttr::Ram) .map_err(|e| MigrationError::SendVmMemoryErr(e.to_string()))?; Ok(()) } fn recv_memory(&self, fd: &mut dyn Read, range: MemBlock) -> Result<()> { - self.write(fd, GuestAddress(range.gpa), range.len) + self.write(fd, GuestAddress(range.gpa), range.len, AddressAttr::Ram) .map_err(|e| MigrationError::RecvVmMemoryErr(e.to_string()))?; Ok(()) diff --git a/block_backend/Cargo.toml b/block_backend/Cargo.toml index 6f7c45b3d904151d20dddaf9e2832676b9ff4944..d052bd0d55d48d3c2fb4e0eef0e09b98cc251585 100644 --- a/block_backend/Cargo.toml +++ b/block_backend/Cargo.toml @@ -7,7 +7,7 @@ license = "Mulan PSL v2" [dependencies] thiserror = "1.0" -vmm-sys-util = "0.11.0" +vmm-sys-util = "0.12.1" anyhow = "1.0" log = "0.4" byteorder = "1.4.3" diff --git a/block_backend/src/file.rs b/block_backend/src/file.rs index 56fac1e08f245a6e27ab22782f756a0f50a0a3d7..3acef03816abdc7d334d02efed5c0374ee27cfeb 100644 --- a/block_backend/src/file.rs +++ b/block_backend/src/file.rs @@ -14,7 +14,10 @@ use std::{ cell::RefCell, fs::File, io::{Seek, SeekFrom}, - os::unix::prelude::{AsRawFd, RawFd}, + os::{ + linux::fs::MetadataExt, + unix::prelude::{AsRawFd, RawFd}, + }, rc::Rc, sync::{ atomic::{AtomicBool, AtomicI64, AtomicU32, AtomicU64, Ordering}, @@ -26,7 +29,7 @@ use anyhow::{Context, Result}; use log::error; use vmm_sys_util::epoll::EventSet; -use crate::{BlockIoErrorCallback, BlockProperty}; +use crate::{qcow2::DEFAULT_SECTOR_SIZE, BlockIoErrorCallback, BlockProperty}; use machine_manager::event_loop::{register_event_helper, unregister_event_helper}; use util::{ aio::{Aio, AioCb, AioEngine, Iovec, OpCode}, @@ -52,7 +55,7 @@ impl CombineRequest { } pub struct FileDriver { - pub file: File, + pub file: Arc, aio: Rc>>, pub incomplete: Arc, delete_evts: Vec, @@ -60,7 +63,7 @@ pub struct FileDriver { } impl FileDriver { - pub fn new(file: File, aio: Aio, block_prop: BlockProperty) -> Self { + pub fn new(file: Arc, aio: Aio, block_prop: BlockProperty) -> Self { Self { file, incomplete: aio.incomplete_cnt.clone(), @@ -102,7 +105,7 @@ impl FileDriver { completecb: T, ) -> Result<()> { if req_list.is_empty() { - return self.complete_request(opcode, &Vec::new(), 0, 0, completecb); + return self.complete_request(opcode, 0, completecb); } let single_req = req_list.len() == 1; let cnt = Arc::new(AtomicU32::new(req_list.len() as u32)); @@ -127,16 +130,10 @@ impl FileDriver { self.process_request(OpCode::Preadv, req_list, completecb) } - fn complete_request( - &mut self, - opcode: OpCode, - iovec: &[Iovec], - offset: usize, - nbytes: u64, - completecb: T, - ) -> Result<()> { - let aiocb = self.package_aiocb(opcode, iovec.to_vec(), offset, nbytes, completecb); - (self.aio.borrow_mut().complete_func)(&aiocb, nbytes as i64) + pub fn complete_request(&mut self, opcode: OpCode, res: i64, completecb: T) -> Result<()> { + let iovec: Vec = Vec::new(); + let aiocb = self.package_aiocb(opcode, iovec.to_vec(), 0, 0, completecb); + (self.aio.borrow_mut().complete_func)(&aiocb, res) } pub fn write_vectored(&mut self, req_list: Vec, completecb: T) -> Result<()> { @@ -194,16 +191,22 @@ impl FileDriver { unregister_event_helper(self.block_prop.iothread.as_ref(), &mut self.delete_evts) } + pub fn actual_size(&mut self) -> Result { + let meta_data = self.file.metadata()?; + Ok(meta_data.st_blocks() * DEFAULT_SECTOR_SIZE) + } + pub fn disk_size(&mut self) -> Result { let disk_size = self .file + .as_ref() .seek(SeekFrom::End(0)) .with_context(|| "Failed to seek the end for file")?; Ok(disk_size) } pub fn extend_to_len(&mut self, len: u64) -> Result<()> { - let file_end = self.file.seek(SeekFrom::End(0))?; + let file_end = self.file.as_ref().seek(SeekFrom::End(0))?; if len > file_end { self.file.set_len(len)?; } diff --git a/block_backend/src/lib.rs b/block_backend/src/lib.rs index cbe37dfaa8b6c4afd8e8474e1939a131a26e75ac..5e711a8906413b08d10f1004edc5bf3c91fde13a 100644 --- a/block_backend/src/lib.rs +++ b/block_backend/src/lib.rs @@ -15,6 +15,7 @@ pub mod qcow2; pub mod raw; use std::{ + fmt, fs::File, sync::{ atomic::{AtomicBool, AtomicU64, Ordering}, @@ -151,6 +152,66 @@ impl CreateOptions { } } +// Transform size into string with storage units. +fn size_to_string(size: f64) -> Result { + let units = ["", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB"]; + + // Switch to higher power if the integer part is >= 1000, + // For example: 1000 * 2^30 bytes + // It's better to output 0.978 TiB, rather than 1000 GiB. + let n = (size / 1000.0 * 1024.0).log2() as u64; + let idx = n / 10; + if idx >= units.len() as u64 { + bail!("Input value {} is too large", size); + } + let div = 1_u64 << (idx * 10); + + // Keep three significant digits and do not output any extra zeros, + // For example: 512 * 2^20 bytes + // It's better to output 512 MiB, rather than 512.000 MiB. + let num_str = format!("{:.3}", size / div as f64); + let num_str = num_str.trim_end_matches('0').trim_end_matches('.'); + + let res = format!("{} {}", num_str, units[idx as usize]); + Ok(res) +} + +#[derive(Default)] +pub struct ImageInfo { + pub path: String, + pub format: String, + pub actual_size: u64, + pub virtual_size: u64, + pub cluster_size: Option, + pub snap_lists: Option, +} + +impl fmt::Display for ImageInfo { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "image: {}\n\ + file format: {}\n\ + virtual size: {} ({} bytes)\n\ + disk size: {}", + self.path, + self.format, + size_to_string(self.virtual_size as f64).unwrap_or_else(|e| format!("{:?}", e)), + self.virtual_size, + size_to_string(self.actual_size as f64).unwrap_or_else(|e| format!("{:?}", e)) + )?; + + if let Some(cluster_size) = self.cluster_size { + writeln!(f, "cluster_size: {}", cluster_size)?; + } + + if let Some(snap_lists) = &self.snap_lists { + write!(f, "Snapshot list:\n{}", snap_lists)?; + } + Ok(()) + } +} + #[derive(Default, Clone, Copy)] pub struct DiskFragments { pub allocated_clusters: u64, @@ -288,9 +349,16 @@ impl Default for BlockProperty { } } +pub enum BlockAllocStatus { + DATA, + ZERO, +} + pub trait BlockDriverOps: Send { fn create_image(&mut self, options: &CreateOptions) -> Result; + fn query_image(&mut self, image_info: &mut ImageInfo) -> Result<()>; + fn check_image(&mut self, res: &mut CheckResult, quite: bool, fix: u64) -> Result<()>; fn disk_size(&mut self) -> Result; @@ -328,21 +396,55 @@ pub trait BlockDriverOps: Send { fn unregister_io_event(&mut self) -> Result<()>; fn get_status(&mut self) -> Arc>; + + // Get a continuous address with the same allocation status starting from `offset`. + // Returns this continuous address's allocation status(data or zero) and size. + fn get_address_alloc_status( + &mut self, + _offset: u64, + _bytes: u64, + ) -> Result<(BlockAllocStatus, u64)> { + bail!("Not support!"); + } } pub fn create_block_backend( - file: File, + file: Arc, aio: Aio, prop: BlockProperty, ) -> Result>>> { + let cloned_drive_id = prop.id.clone(); + // NOTE: we can drain request when request in io thread. + let drain = prop.iothread.is_some(); match prop.format { DiskFormat::Raw => { let mut raw_file = RawDriver::new(file, aio, prop.clone()); let file_size = raw_file.disk_size()?; - if file_size & (prop.req_align as u64 - 1) != 0 { + if file_size & (u64::from(prop.req_align) - 1) != 0 { bail!("The size of raw file is not aligned to {}.", prop.req_align); } - Ok(Arc::new(Mutex::new(raw_file))) + let new_raw = Arc::new(Mutex::new(raw_file)); + + let cloned_raw = Arc::downgrade(&new_raw); + let exit_notifier = Arc::new(move || { + if let Some(raw) = cloned_raw.upgrade() { + info!("clean up raw {:?} resources.", cloned_drive_id); + if drain { + info!("Drain the inflight IO for drive \"{}\"", cloned_drive_id); + let incomplete = raw.lock().unwrap().get_inflight(); + while incomplete.load(Ordering::SeqCst) != 0 { + yield_now(); + } + } + info!( + "Drain the inflight IO for drive \"{}\" ends.", + cloned_drive_id + ); + } + }) as Arc; + TempCleaner::add_exit_notifier(prop.id, exit_notifier); + + Ok(new_raw) } DiskFormat::Qcow2 => { let mut qcow2 = Qcow2Driver::new(file, aio, prop.clone()) @@ -352,7 +454,7 @@ pub fn create_block_backend( .with_context(|| "Failed to load metadata")?; let file_size = qcow2.disk_size()?; - if file_size & (prop.req_align as u64 - 1) != 0 { + if file_size & (u64::from(prop.req_align) - 1) != 0 { bail!( "The size of qcow2 file is not aligned to {}.", prop.req_align @@ -363,11 +465,8 @@ pub fn create_block_backend( .lock() .unwrap() .insert(prop.id.clone(), new_qcow2.clone()); - let cloned_qcow2 = Arc::downgrade(&new_qcow2); - // NOTE: we can drain request when request in io thread. - let drain = prop.iothread.is_some(); - let cloned_drive_id = prop.id.clone(); + let cloned_qcow2 = Arc::downgrade(&new_qcow2); let exit_notifier = Arc::new(move || { if let Some(qcow2) = cloned_qcow2.upgrade() { info!("clean up qcow2 {:?} resources.", cloned_drive_id); @@ -381,6 +480,7 @@ pub fn create_block_backend( if let Err(e) = qcow2.lock().unwrap().flush() { error!("Failed to flush qcow2 {:?}", e); } + info!("Flush qcow2 {} metadata success.", cloned_drive_id); } }) as Arc; TempCleaner::add_exit_notifier(prop.id.clone(), exit_notifier); diff --git a/block_backend/src/qcow2/cache.rs b/block_backend/src/qcow2/cache.rs index 445ddc31e996af41dbb9876c4ff7d579a6ed90fa..e54739f81ebaa6f8ccf76a0843aa87e13efa59db 100644 --- a/block_backend/src/qcow2/cache.rs +++ b/block_backend/src/qcow2/cache.rs @@ -85,7 +85,7 @@ impl CacheTable { bail!("Invalid idx {}", idx); } let v = match self.entry_size { - ENTRY_SIZE_U16 => BigEndian::read_u16(&self.table_data[start..end]) as u64, + ENTRY_SIZE_U16 => u64::from(BigEndian::read_u16(&self.table_data[start..end])), ENTRY_SIZE_U64 => BigEndian::read_u64(&self.table_data[start..end]), _ => bail!("Unsupported entry size {}", self.entry_size), }; @@ -190,7 +190,7 @@ impl Qcow2Cache { ) -> Option>> { let mut replaced_entry: Option>> = None; let mut lru_count = u64::MAX; - let mut target_idx = 0; + let mut target_idx: u64 = 0; self.check_refcount(); entry.borrow_mut().lru_count = self.lru_count; self.lru_count += 1; @@ -262,7 +262,7 @@ mod test { for i in 0..buf.len() { vec.append(&mut buf[i].to_be_bytes().to_vec()); } - let mut entry = CacheTable::new(0x00 as u64, vec, ENTRY_SIZE_U64).unwrap(); + let mut entry = CacheTable::new(0x00_u64, vec, ENTRY_SIZE_U64).unwrap(); assert_eq!(entry.get_entry_map(0).unwrap(), 0x00); assert_eq!(entry.get_entry_map(3).unwrap(), 0x03); assert_eq!(entry.get_entry_map(4).unwrap(), 0x04); @@ -279,19 +279,19 @@ mod test { vec.append(&mut buf[i].to_be_bytes().to_vec()); } let entry_0 = Rc::new(RefCell::new( - CacheTable::new(0x00 as u64, vec.clone(), ENTRY_SIZE_U64).unwrap(), + CacheTable::new(0x00_u64, vec.clone(), ENTRY_SIZE_U64).unwrap(), )); entry_0.borrow_mut().lru_count = 0; let entry_1 = Rc::new(RefCell::new( - CacheTable::new(0x00 as u64, vec.clone(), ENTRY_SIZE_U64).unwrap(), + CacheTable::new(0x00_u64, vec.clone(), ENTRY_SIZE_U64).unwrap(), )); entry_1.borrow_mut().lru_count = 1; let entry_2 = Rc::new(RefCell::new( - CacheTable::new(0x00 as u64, vec.clone(), ENTRY_SIZE_U64).unwrap(), + CacheTable::new(0x00_u64, vec.clone(), ENTRY_SIZE_U64).unwrap(), )); entry_2.borrow_mut().lru_count = 2; let entry_3 = Rc::new(RefCell::new( - CacheTable::new(0x00 as u64, vec.clone(), ENTRY_SIZE_U64).unwrap(), + CacheTable::new(0x00_u64, vec.clone(), ENTRY_SIZE_U64).unwrap(), )); entry_3.borrow_mut().lru_count = 3; @@ -315,7 +315,7 @@ mod test { )); let mut qcow2_cache = Qcow2Cache::new(2); - qcow2_cache.lru_replace(addr, entry.clone()); + qcow2_cache.lru_replace(addr, entry); qcow2_cache.lru_count = u64::MAX - cnt / 2; // Not in cache. assert!(qcow2_cache.get(0).is_none()); diff --git a/block_backend/src/qcow2/check.rs b/block_backend/src/qcow2/check.rs index 2f6bb9f7055be21756b7a9f2ba81d430520293bf..4313111a7a58483c455c56df9f8b60ab736ecd51 100644 --- a/block_backend/src/qcow2/check.rs +++ b/block_backend/src/qcow2/check.rs @@ -105,7 +105,7 @@ impl RefcountBlock { let start_bytes = idx * self.entry_bytes; let end_bytes = start_bytes + self.entry_bytes; let value = match self.entry_bytes { - ENTRY_SIZE_U16 => BigEndian::read_u16(&self.data[start_bytes..end_bytes]) as u64, + ENTRY_SIZE_U16 => u64::from(BigEndian::read_u16(&self.data[start_bytes..end_bytes])), ENTRY_SIZE_U64 => BigEndian::read_u64(&self.data[start_bytes..end_bytes]), _ => bail!("Entry size is unsupported"), }; @@ -221,7 +221,7 @@ impl Qcow2Driver { let snapshot_table_length = size_of::() as u64; let snapshot_table_offset = self.header.snapshots_offset; // Validate snapshot table. - if (u64::MAX - nb_snapshots as u64 * snapshot_table_length) < snapshot_table_offset + if (u64::MAX - u64::from(nb_snapshots) * snapshot_table_length) < snapshot_table_offset || !is_aligned(self.header.cluster_size(), snapshot_table_offset) { res.err_num += 1; @@ -300,7 +300,7 @@ impl Qcow2Driver { /// Rebuild a new refcount table according to metadata, including active l1 table, active l2 table, /// snapshot table, refcount table and refcount block. pub(crate) fn check_refcounts(&mut self, check: &mut Qcow2Check) -> Result<()> { - let cluster_bits = self.header.cluster_bits as u64; + let cluster_bits = u64::from(self.header.cluster_bits); let cluster_size = 1 << cluster_bits; let virtual_size = self.header.size; check.res.disk_frag.total_clusters = div_round_up(virtual_size, cluster_size).unwrap(); @@ -321,7 +321,7 @@ impl Qcow2Driver { if check.res.need_rebuild && check.fix & FIX_ERRORS != 0 { let old_res = check.res; - let mut fresh_leak = 0; + let mut fresh_leak: i32 = 0; output_msg!(check.quite, "Rebuilding refcount structure"); self.rebuild_refcount_structure(check)?; @@ -380,14 +380,14 @@ impl Qcow2Driver { 0, self.header.cluster_size(), file_len, - self.header.cluster_bits as u64, + u64::from(self.header.cluster_bits), check, )?; // Increase the refcount of active l1 table. let active_l1_offset = self.header.l1_table_offset; let active_l1_size = self.header.l1_size; - self.check_refcounts_l1(active_l1_offset, active_l1_size as u64, true, check)?; + self.check_refcounts_l1(active_l1_offset, u64::from(active_l1_size), true, check)?; // Increase the refcount of snapshot table. for idx in 0..self.header.nb_snapshots { @@ -404,7 +404,7 @@ impl Qcow2Driver { continue; } - if snap_l1_size as u64 > QCOW2_MAX_L1_SIZE / ENTRY_SIZE { + if u64::from(snap_l1_size) > QCOW2_MAX_L1_SIZE / ENTRY_SIZE { output_msg!( check.quite, "ERROR snapshot {:?}({:?}) l1_size={:?} l1 table is too large; snapshot table entry courropted", @@ -414,7 +414,7 @@ impl Qcow2Driver { continue; } - self.check_refcounts_l1(snap_l1_offset, snap_l1_size as u64, false, check)?; + self.check_refcounts_l1(snap_l1_offset, u64::from(snap_l1_size), false, check)?; } let snap_table_offset = self.header.snapshots_offset; @@ -424,19 +424,19 @@ impl Qcow2Driver { snap_table_offset, snap_table_size, file_len, - self.header.cluster_bits as u64, + u64::from(self.header.cluster_bits), check, )?; } let reftable_offset = self.header.refcount_table_offset; let reftable_bytes = - self.header.refcount_table_clusters as u64 * self.header.cluster_size(); + u64::from(self.header.refcount_table_clusters) * self.header.cluster_size(); self.increase_refcounts( reftable_offset, reftable_bytes, file_len, - self.header.cluster_bits as u64, + u64::from(self.header.cluster_bits), check, )?; @@ -463,7 +463,7 @@ impl Qcow2Driver { l1_offset, l1_size_bytes, file_len, - self.header.cluster_bits as u64, + u64::from(self.header.cluster_bits), check, )?; let l1_table = self @@ -497,7 +497,7 @@ impl Qcow2Driver { l2_offset, self.header.cluster_size(), file_len, - self.header.cluster_bits as u64, + u64::from(self.header.cluster_bits), check, )?; @@ -529,7 +529,7 @@ impl Qcow2Driver { file_len: u64, check: &mut Qcow2Check, ) -> Result<()> { - let cluster_bits = self.header.cluster_bits as u64; + let cluster_bits = u64::from(self.header.cluster_bits); let cluster_size = 1 << cluster_bits; let l2_size = cluster_size >> ENTRY_BITS; @@ -661,7 +661,7 @@ impl Qcow2Driver { } fn check_refcount_block(&mut self, check: &mut Qcow2Check) -> Result<()> { - let cluster_bits = self.header.cluster_bits as u64; + let cluster_bits = u64::from(self.header.cluster_bits); let cluster_size = 1 << cluster_bits; let file_len = self.driver.disk_size()?; let nb_clusters = bytes_to_clusters(file_len, cluster_size)?; @@ -794,7 +794,7 @@ impl Qcow2Driver { ); if need_fixed { - let added = rc_value_2 as i32 - rc_value_1 as i32; + let added = i32::from(rc_value_2) - i32::from(rc_value_1); let cluster_offset = cluster_idx << cluster_bits; self.refcount.update_refcount( cluster_offset, @@ -874,7 +874,7 @@ impl Qcow2Driver { } } - let mut num_repaired = 0; + let mut num_repaired: i32 = 0; let l2_buf = self.load_cluster(l2_offset)?; let l2_table = Rc::new(RefCell::new(CacheTable::new( l2_offset, @@ -964,8 +964,8 @@ impl Qcow2Driver { let mut reftable_offset: u64 = 0; let mut new_reftable: Vec = Vec::new(); let mut reftable_clusters: u64 = 0; - let cluster_bits = self.header.cluster_bits as u64; - let refblock_bits: u64 = cluster_bits + 3 - self.header.refcount_order as u64; + let cluster_bits = u64::from(self.header.cluster_bits); + let refblock_bits: u64 = cluster_bits + 3 - u64::from(self.header.refcount_order); let refblock_size: u64 = 1 << refblock_bits; // self.refblock.nb_clusters means the maximum number of clusters that can be represented by @@ -1124,8 +1124,8 @@ mod test { assert!(refblock.set_refcount(9, 9).is_ok()); // Get inner dat - let mut vec_1 = (1 as u16).to_be_bytes().to_vec(); - let mut vec_2 = (7 as u16).to_be_bytes().to_vec(); + let mut vec_1 = 1_u16.to_be_bytes().to_vec(); + let mut vec_2 = 7_u16.to_be_bytes().to_vec(); vec_1.append(&mut vec_2); let buf = refblock.get_data(0, 2); assert_eq!(buf, vec_1); diff --git a/block_backend/src/qcow2/header.rs b/block_backend/src/qcow2/header.rs index d2b9f8f7845edeb520e5a0e03875177795c6ac2c..fa3681d9cc3f014e7b23285f61df5f20a91d68db 100644 --- a/block_backend/src/qcow2/header.rs +++ b/block_backend/src/qcow2/header.rs @@ -138,7 +138,7 @@ impl QcowHeader { if !(MIN_CLUSTER_BIT..=MAX_CLUSTER_BIT).contains(&self.cluster_bits) { bail!("Invalid cluster bits {}", self.cluster_bits); } - if self.header_length as u64 > self.cluster_size() { + if u64::from(self.header_length) > self.cluster_size() { bail!( "Header length {} over cluster size {}", self.header_length, @@ -168,7 +168,7 @@ impl QcowHeader { if self.refcount_table_clusters == 0 { bail!("Refcount table clusters is zero"); } - if self.refcount_table_clusters as u64 > MAX_REFTABLE_SIZE / self.cluster_size() { + if u64::from(self.refcount_table_clusters) > MAX_REFTABLE_SIZE / self.cluster_size() { bail!( "Refcount table size over limit {}", self.refcount_table_clusters @@ -181,7 +181,7 @@ impl QcowHeader { ); } self.refcount_table_offset - .checked_add(self.refcount_table_clusters as u64 * self.cluster_size()) + .checked_add(u64::from(self.refcount_table_clusters) * self.cluster_size()) .with_context(|| { format!( "Invalid offset {} or refcount table clusters {}", @@ -192,7 +192,7 @@ impl QcowHeader { } fn check_l1_table(&self) -> Result<()> { - if self.l1_size as u64 > MAX_L1TABLE_SIZE / ENTRY_SIZE { + if u64::from(self.l1_size) > MAX_L1TABLE_SIZE / ENTRY_SIZE { bail!("L1 table size over limit {}", self.l1_size); } if !self.cluster_aligned(self.l1_table_offset) { @@ -201,7 +201,7 @@ impl QcowHeader { let size_per_l1_entry = self.cluster_size() * self.cluster_size() / ENTRY_SIZE; let l1_need_sz = div_round_up(self.size, size_per_l1_entry).with_context(|| "Failed to get l1 size")?; - if (self.l1_size as u64) < l1_need_sz { + if u64::from(self.l1_size) < l1_need_sz { bail!( "L1 table is too small, l1 size {} expect {}", self.l1_size, @@ -209,7 +209,7 @@ impl QcowHeader { ); } self.l1_table_offset - .checked_add(self.l1_size as u64 * ENTRY_SIZE) + .checked_add(u64::from(self.l1_size) * ENTRY_SIZE) .with_context(|| { format!( "Invalid offset {} or entry size {}", @@ -320,21 +320,21 @@ mod test { fn invalid_header_list() -> Vec<(Vec, String)> { let mut list = Vec::new(); // Invalid buffer length. - list.push((vec![0_u8; 16], format!("Invalid header len"))); + list.push((vec![0_u8; 16], "Invalid header len".to_string())); // Invalid buffer length for v3. let buf = valid_header_v3(); list.push(( buf[0..90].to_vec(), - format!("Invalid header len for version 3"), + "Invalid header len for version 3".to_string(), )); // Invalid magic. let mut buf = valid_header_v2(); BigEndian::write_u32(&mut buf[0..4], 1234); - list.push((buf, format!("Invalid format"))); + list.push((buf, "Invalid format".to_string())); // Invalid version. let mut buf = valid_header_v3(); BigEndian::write_u32(&mut buf[4..8], 1); - list.push((buf, format!("Invalid version"))); + list.push((buf, "Invalid version".to_string())); // Large header length. let mut buf = valid_header_v3(); BigEndian::write_u32(&mut buf[100..104], 0x10000000_u32); @@ -345,23 +345,23 @@ mod test { // Small cluster bit. let mut buf = valid_header_v3(); BigEndian::write_u32(&mut buf[20..24], 0); - list.push((buf, format!("Invalid cluster bit"))); + list.push((buf, "Invalid cluster bit".to_string())); // Large cluster bit. let mut buf = valid_header_v3(); BigEndian::write_u32(&mut buf[20..24], 65); - list.push((buf, format!("Invalid cluster bit"))); + list.push((buf, "Invalid cluster bit".to_string())); // Invalid backing file offset. let mut buf = valid_header_v3(); BigEndian::write_u32(&mut buf[8..16], 0x2000); - list.push((buf, format!("Don't support backing file offset"))); + list.push((buf, "Don't support backing file offset".to_string())); // Invalid refcount order. let mut buf = valid_header_v3(); BigEndian::write_u32(&mut buf[96..100], 5); - list.push((buf, format!("Invalid refcount order"))); + list.push((buf, "Invalid refcount order".to_string())); // Refcount table offset is not aligned. let mut buf = valid_header_v3(); BigEndian::write_u64(&mut buf[48..56], 0x1234); - list.push((buf, format!("Refcount table offset not aligned"))); + list.push((buf, "Refcount table offset not aligned".to_string())); // Refcount table offset is large. let mut buf = valid_header_v3(); BigEndian::write_u32(&mut buf[36..40], 4 * 1024 * 1024); @@ -377,15 +377,15 @@ mod test { // Invalid refcount table cluster. let mut buf = valid_header_v3(); BigEndian::write_u32(&mut buf[56..60], 256); - list.push((buf, format!("Refcount table size over limit"))); + list.push((buf, "Refcount table size over limit".to_string())); // Refcount table cluster is 0. let mut buf = valid_header_v3(); BigEndian::write_u32(&mut buf[56..60], 0); - list.push((buf, format!("Refcount table clusters is zero"))); + list.push((buf, "Refcount table clusters is zero".to_string())); // L1 table offset is not aligned. let mut buf = valid_header_v3(); BigEndian::write_u64(&mut buf[40..48], 0x123456); - list.push((buf, format!("L1 table offset not aligned"))); + list.push((buf, "L1 table offset not aligned".to_string())); // L1 table offset is large. let mut buf = valid_header_v3(); BigEndian::write_u32(&mut buf[36..40], 4 * 1024 * 1024); @@ -401,12 +401,12 @@ mod test { // Invalid l1 table size. let mut buf = valid_header_v3(); BigEndian::write_u32(&mut buf[36..40], 0xffff_0000_u32); - list.push((buf, format!("L1 table size over limit"))); + list.push((buf, "L1 table size over limit".to_string())); // File size is large than l1 table size. let mut buf = valid_header_v3(); BigEndian::write_u64(&mut buf[24..32], 0xffff_ffff_ffff_0000_u64); BigEndian::write_u32(&mut buf[36..40], 10); - list.push((buf, format!("L1 table is too small"))); + list.push((buf, "L1 table is too small".to_string())); list } diff --git a/block_backend/src/qcow2/mod.rs b/block_backend/src/qcow2/mod.rs index 2dd1e4ac504335f20c9bcc882595a59e4cea585b..1151798d7a9c0daaa67a09cf88803a359ca117ba 100644 --- a/block_backend/src/qcow2/mod.rs +++ b/block_backend/src/qcow2/mod.rs @@ -49,8 +49,8 @@ use crate::{ snapshot::{InternalSnapshot, QcowSnapshot, QcowSnapshotExtraData, QCOW2_MAX_SNAPSHOTS}, table::{Qcow2ClusterType, Qcow2Table}, }, - BlockDriverOps, BlockIoErrorCallback, BlockProperty, BlockStatus, CheckResult, CreateOptions, - SECTOR_SIZE, + BlockAllocStatus, BlockDriverOps, BlockIoErrorCallback, BlockProperty, BlockStatus, + CheckResult, CreateOptions, ImageInfo, SECTOR_SIZE, }; use machine_manager::event_loop::EventLoop; use machine_manager::qmp::qmp_schema::SnapshotInfo; @@ -77,7 +77,7 @@ pub const QCOW2_OFLAG_ZERO: u64 = 1 << 0; const QCOW2_OFFSET_COMPRESSED: u64 = 1 << 62; pub const QCOW2_OFFSET_COPIED: u64 = 1 << 63; const MAX_L1_SIZE: u64 = 32 * (1 << 20); -const DEFAULT_SECTOR_SIZE: u64 = 512; +pub(crate) const DEFAULT_SECTOR_SIZE: u64 = 512; pub(crate) const QCOW2_MAX_L1_SIZE: u64 = 1 << 25; // The default flush interval is 30s. @@ -269,7 +269,7 @@ pub fn qcow2_flush_metadata( } impl Qcow2Driver { - pub fn new(file: File, aio: Aio, conf: BlockProperty) -> Result { + pub fn new(file: Arc, aio: Aio, conf: BlockProperty) -> Result { let fd = file.as_raw_fd(); let sync_aio = Rc::new(RefCell::new(SyncAioInfo::new(fd, conf.clone())?)); Ok(Self { @@ -335,12 +335,19 @@ impl Qcow2Driver { } pub fn load_refcount_table(&mut self) -> Result<()> { - let sz = - self.header.refcount_table_clusters as u64 * (self.header.cluster_size() / ENTRY_SIZE); + let sz = u64::from(self.header.refcount_table_clusters) + * (self.header.cluster_size() / ENTRY_SIZE); self.refcount.refcount_table = self .sync_aio .borrow_mut() .read_ctrl_cluster(self.header.refcount_table_offset, sz)?; + for block_offset in &self.refcount.refcount_table { + if *block_offset == 0 { + continue; + } + let rfb_offset = block_offset & REFCOUNT_TABLE_OFFSET_MASK; + self.refcount.refcount_table_map.insert(rfb_offset, 1); + } Ok(()) } @@ -383,9 +390,9 @@ impl Qcow2Driver { expect_len ); } - let mut host_start = 0; + let mut host_start: u64 = 0; let mut first_cluster_type = Qcow2ClusterType::Unallocated; - let mut cnt = 0; + let mut cnt: u64 = 0; while cnt < clusters { let offset = cnt * self.header.cluster_size(); let l2_entry = self.get_l2_entry(begin + offset)?; @@ -435,7 +442,7 @@ impl Qcow2Driver { l2_entry &= !QCOW2_OFLAG_ZERO; let mut cluster_addr = l2_entry & L2_TABLE_OFFSET_MASK; if cluster_addr == 0 { - let new_addr = self.alloc_cluster(1, true)?; + let new_addr = self.alloc_cluster(1, false)?; l2_entry = new_addr | QCOW2_OFFSET_COPIED; cluster_addr = new_addr & L2_TABLE_OFFSET_MASK; } else if l2_entry & QCOW2_OFFSET_COPIED == 0 { @@ -472,7 +479,7 @@ impl Qcow2Driver { /// Extend the l1 table. pub fn grow_l1_table(&mut self, new_l1_size: u64) -> Result<()> { - let old_l1_size = self.header.l1_size as u64; + let old_l1_size = u64::from(self.header.l1_size); if new_l1_size <= old_l1_size { return Ok(()); } @@ -535,7 +542,7 @@ impl Qcow2Driver { /// Output: target entry. pub fn get_table_cluster(&mut self, guest_offset: u64) -> Result>> { let l1_index = self.table.get_l1_table_index(guest_offset); - if l1_index >= self.header.l1_size as u64 { + if l1_index >= u64::from(self.header.l1_size) { bail!("Need to grow l1 table size."); } @@ -821,11 +828,11 @@ impl Qcow2Driver { let snap = self.snapshot.snapshots[snap_id as usize].clone(); // Validate snapshot table - if snap.l1_size as u64 > MAX_L1_SIZE / ENTRY_SIZE { + if u64::from(snap.l1_size) > MAX_L1_SIZE / ENTRY_SIZE { bail!("Snapshot L1 table too large"); } - if i64::MAX as u64 - snap.l1_size as u64 * ENTRY_SIZE < snap.l1_table_offset + if i64::MAX as u64 - u64::from(snap.l1_size) * ENTRY_SIZE < snap.l1_table_offset || !is_aligned(self.header.cluster_size(), snap.l1_table_offset) { bail!("Snapshot L1 table offset invalid"); @@ -835,12 +842,12 @@ impl Qcow2Driver { let mut snap_l1_table = self .sync_aio .borrow_mut() - .read_ctrl_cluster(snap.l1_table_offset, snap.l1_size as u64)?; + .read_ctrl_cluster(snap.l1_table_offset, u64::from(snap.l1_size))?; // SAFETY: Upper limit of l1_size is decided by disk virtual size. snap_l1_table.resize(snap.l1_size as usize, 0); let cluster_size = self.header.cluster_size(); - let snap_l1_table_bytes = snap.l1_size as u64 * ENTRY_SIZE; + let snap_l1_table_bytes = u64::from(snap.l1_size) * ENTRY_SIZE; let snap_l1_table_clusters = bytes_to_clusters(snap_l1_table_bytes, cluster_size).unwrap(); let new_l1_table_offset = self.alloc_cluster(snap_l1_table_clusters, true)?; @@ -865,12 +872,17 @@ impl Qcow2Driver { self.table.l1_table_offset = new_l1_table_offset; self.table.l1_size = snap.l1_size; self.table.l1_table = snap_l1_table; + self.table.l1_table_map.clear(); + for l1_entry in self.table.l1_table.iter() { + let addr = l1_entry & L1_TABLE_OFFSET_MASK; + self.table.l1_table_map.insert(addr, 1); + } self.qcow2_update_snapshot_refcount(old_l1_table_offset, old_l1_size as usize, -1)?; // Free the snaphshot L1 table. let old_l1_table_clusters = - bytes_to_clusters(old_l1_size as u64 * ENTRY_SIZE, cluster_size).unwrap(); + bytes_to_clusters(u64::from(old_l1_size) * ENTRY_SIZE, cluster_size).unwrap(); self.refcount.update_refcount( old_l1_table_offset, old_l1_table_clusters, @@ -901,18 +913,22 @@ impl Qcow2Driver { bail!("Snapshot with name {} does not exist", name); } + // Record the old snapshot table size which will be used to free these old snapshot table clusters. + let cluster_size = self.header.cluster_size(); + let old_snapshot_table_clusters = + bytes_to_clusters(self.snapshot.snapshot_size, cluster_size).unwrap(); + // Delete snapshot information in memory. let snap = self.snapshot.del_snapshot(snapshot_idx as usize); // Alloc new cluster to save snapshots(except the deleted one) to disk. - let cluster_size = self.header.cluster_size(); let mut new_snapshots_offset = 0_u64; - let snapshot_table_clusters = + let new_snapshot_table_clusters = bytes_to_clusters(self.snapshot.snapshot_size, cluster_size).unwrap(); if self.snapshot.snapshots_number() > 0 { - new_snapshots_offset = self.alloc_cluster(snapshot_table_clusters, true)?; + new_snapshots_offset = self.alloc_cluster(new_snapshot_table_clusters, true)?; self.snapshot - .save_snapshot_table(new_snapshots_offset, &snap, false)?; + .save_snapshot_table(new_snapshots_offset, Some(&snap), false)?; } self.snapshot.snapshot_table_offset = new_snapshots_offset; @@ -921,7 +937,7 @@ impl Qcow2Driver { // Free the snaphshot L1 table. let l1_table_clusters = - bytes_to_clusters(snap.l1_size as u64 * ENTRY_SIZE, cluster_size).unwrap(); + bytes_to_clusters(u64::from(snap.l1_size) * ENTRY_SIZE, cluster_size).unwrap(); self.refcount.update_refcount( snap.l1_table_offset, l1_table_clusters, @@ -940,7 +956,7 @@ impl Qcow2Driver { // Free the cluster of the old snapshot table. self.refcount.update_refcount( self.header.snapshots_offset, - snapshot_table_clusters, + old_snapshot_table_clusters, -1, false, &Qcow2DiscardType::Snapshot, @@ -952,7 +968,7 @@ impl Qcow2Driver { self.table.save_l1_table()?; // Update the snapshot information in qcow2 header. - self.update_snapshot_info_in_header(new_snapshots_offset, false)?; + self.update_snapshot_info_in_header(new_snapshots_offset, -1)?; // Discard unused clusters. self.refcount.sync_process_discards(OpCode::Discard); @@ -960,7 +976,7 @@ impl Qcow2Driver { Ok(SnapshotInfo { id: snap.id.to_string(), name: snap.name.clone(), - vm_state_size: snap.vm_state_size as u64, + vm_state_size: u64::from(snap.vm_state_size), date_sec: snap.date_sec, date_nsec: snap.date_nsec, vm_clock_nsec: snap.vm_clock_nsec, @@ -981,7 +997,7 @@ impl Qcow2Driver { // Alloc cluster and copy L1 table for snapshot. let cluster_size = self.header.cluster_size(); - let l1_table_len = self.header.l1_size as u64 * ENTRY_SIZE; + let l1_table_len = u64::from(self.header.l1_size) * ENTRY_SIZE; let l1_table_clusters = bytes_to_clusters(l1_table_len, cluster_size).unwrap(); let new_l1_table_offset = self.alloc_cluster(l1_table_clusters, true)?; self.sync_aio @@ -997,6 +1013,17 @@ impl Qcow2Driver { // Alloc new snapshot table. let (date_sec, date_nsec) = gettime()?; + // Note: The `Snapshots` chapter in Qcow2 spec states: + // Snapshot table entry: + // Byte 16 - 19: Time at which the snapshot was taken in seconds since the + // Epoch + // Byte 20 - 23: Subsecond part of the time at which the snapshot was taken + // in nanoseconds + // + // 32 bits of seconds can represent a range of approximately 136 years since 1970. + // It's enough for current use. If an incorrect host time is used to inject error, + // there may be an issue of inaccurate creation time in the snapshot description. + // Considering compatibility, this issue of inaccurate time is acceptable. let snap = QcowSnapshot { l1_table_offset: new_l1_table_offset, l1_size: self.header.l1_size, @@ -1004,8 +1031,8 @@ impl Qcow2Driver { name, disk_size: self.virtual_disk_size(), vm_state_size: 0, - date_sec, - date_nsec, + date_sec: date_sec as u32, + date_nsec: date_nsec as u32, vm_clock_nsec, icount: u64::MAX, extra_data_size: size_of::() as u32, @@ -1021,7 +1048,7 @@ impl Qcow2Driver { // Append the new snapshot to the snapshot table and write new snapshot table to file. self.snapshot - .save_snapshot_table(new_snapshots_offset, &snap, true)?; + .save_snapshot_table(new_snapshots_offset, Some(&snap), true)?; // Free the old snapshot table cluster if snapshot exists. if self.header.snapshots_offset != 0 { @@ -1041,7 +1068,7 @@ impl Qcow2Driver { self.table.save_l1_table()?; // Update snapshot offset and num in qcow2 header. - self.update_snapshot_info_in_header(new_snapshots_offset, true)?; + self.update_snapshot_info_in_header(new_snapshots_offset, 1)?; // Add and update snapshot information in memory. self.snapshot.add_snapshot(snap); @@ -1053,14 +1080,10 @@ impl Qcow2Driver { Ok(()) } - fn update_snapshot_info_in_header(&mut self, snapshot_offset: u64, add: bool) -> Result<()> { + fn update_snapshot_info_in_header(&mut self, snapshot_offset: u64, add: i32) -> Result<()> { let mut new_header = self.header.clone(); new_header.snapshots_offset = snapshot_offset; - if add { - new_header.nb_snapshots += 1; - } else { - new_header.nb_snapshots -= 1; - } + new_header.nb_snapshots = (new_header.nb_snapshots as i32 + add) as u32; self.sync_aio .borrow_mut() .write_buffer(0, &new_header.to_vec())?; @@ -1260,7 +1283,7 @@ impl Qcow2Driver { _ => snap.icount.to_string(), }; - let date = get_format_time(snap.date_sec as i64); + let date = get_format_time(i64::from(snap.date_sec)); let date_str = format!( "{:04}-{:02}-{:02} {:02}:{:02}:{:02}", date[0], date[1], date[2], date[3], date[4], date[5] @@ -1291,13 +1314,14 @@ impl Qcow2Driver { return 0; } - if check & METADATA_OVERLAP_CHECK_MAINHEADER != 0 && offset < self.header.cluster_size() { + let cluster_size = self.header.cluster_size(); + if check & METADATA_OVERLAP_CHECK_MAINHEADER != 0 && offset < cluster_size { return METADATA_OVERLAP_CHECK_MAINHEADER as i64; } let size = round_up( self.refcount.offset_into_cluster(offset) + size, - self.header.cluster_size(), + cluster_size, ) .unwrap() as usize; let offset = self.refcount.start_of_cluster(offset) as usize; @@ -1321,15 +1345,10 @@ impl Qcow2Driver { } if check & METADATA_OVERLAP_CHECK_ACTIVEL2 != 0 { - for l1_entry in &self.table.l1_table { - if ranges_overlap( - offset, - size, - (l1_entry & L1_TABLE_OFFSET_MASK) as usize, - self.header.cluster_size() as usize, - ) - .unwrap() - { + let num = size as u64 / cluster_size; + for i in 0..num { + let addr = offset as u64 + i * cluster_size; + if self.table.l1_table_map.contains_key(&addr) { return METADATA_OVERLAP_CHECK_ACTIVEL2 as i64; } } @@ -1340,7 +1359,7 @@ impl Qcow2Driver { offset, size, self.header.refcount_table_offset as usize, - self.header.refcount_table_clusters as usize * self.header.cluster_size() as usize, + self.header.refcount_table_clusters as usize * cluster_size as usize, ) .unwrap() { @@ -1348,15 +1367,10 @@ impl Qcow2Driver { } if check & METADATA_OVERLAP_CHECK_REFCOUNTBLOCK != 0 { - for block_offset in &self.refcount.refcount_table { - if ranges_overlap( - offset, - size, - (block_offset & REFCOUNT_TABLE_OFFSET_MASK) as usize, - self.header.cluster_size() as usize, - ) - .unwrap() - { + let num = size as u64 / cluster_size; + for i in 0..num { + let addr = offset as u64 + i * cluster_size; + if self.refcount.refcount_table_map.contains_key(&addr) { return METADATA_OVERLAP_CHECK_REFCOUNTBLOCK as i64; } } @@ -1404,6 +1418,11 @@ pub trait InternalSnapshotOps: Send + Sync { fn apply_snapshot(&mut self, name: String) -> Result<()>; fn list_snapshots(&self) -> String; fn get_status(&self) -> Arc>; + fn rename_snapshot( + &mut self, + old_snapshot_name: String, + new_snapshot_name: String, + ) -> Result<()>; } impl InternalSnapshotOps for Qcow2Driver { @@ -1443,6 +1462,52 @@ impl InternalSnapshotOps for Qcow2Driver { fn get_status(&self) -> Arc> { self.status.clone() } + + fn rename_snapshot( + &mut self, + old_snapshot_name: String, + new_snapshot_name: String, + ) -> Result<()> { + if self.get_snapshot_by_name(&new_snapshot_name) != -1 { + bail!("New snapshot name {} exits!", new_snapshot_name); + } + + let snap_id = self.get_snapshot_by_name(&old_snapshot_name); + if snap_id < 0 { + bail!("Snapshot name {} doesn't exit!", old_snapshot_name); + } + + // Update snapshot info in memory. Note: Stratovirt-img will exit if next actions fail. + // And these modified snapshot information in memory will not affect. + self.snapshot.snapshots[snap_id as usize].name = new_snapshot_name; + + // Write new snapshot info to new snapshot table. + let old_snapshots_offset = self.header.snapshots_offset; + let cluster_size = self.header.cluster_size(); + let snapshot_table_len = self.snapshot.snapshot_size; + let snapshot_table_clusters = bytes_to_clusters(snapshot_table_len, cluster_size).unwrap(); + let new_snapshots_offset = self.alloc_cluster(snapshot_table_clusters, true)?; + self.snapshot + .save_snapshot_table(new_snapshots_offset, None, true)?; + + // Update the snapshot information in qcow2 header. + self.update_snapshot_info_in_header(new_snapshots_offset, 0)?; + + // Delete old snapshot: Free the cluster of the old snapshot table. + self.refcount.update_refcount( + old_snapshots_offset, + snapshot_table_clusters, + -1, + false, + &Qcow2DiscardType::Snapshot, + )?; + self.flush()?; + + // Discard unused clusters. + self.refcount.sync_process_discards(OpCode::Discard); + + Ok(()) + } } // SAFETY: Send and Sync is not auto-implemented for raw pointer type in Aio. @@ -1611,19 +1676,25 @@ impl BlockDriverOps for Qcow2Driver { // Write zero. for i in 0..3 { let offset = i * cluster_size; - self.driver.file.seek(SeekFrom::Start(offset))?; - self.driver.file.write_all(&zero_buf.to_vec())? + self.driver.file.as_ref().seek(SeekFrom::Start(offset))?; + self.driver.file.as_ref().write_all(&zero_buf.to_vec())? } - self.driver.file.rewind()?; - self.driver.file.write_all(&self.header.to_vec())?; + self.driver.file.as_ref().rewind()?; + self.driver.file.as_ref().write_all(&self.header.to_vec())?; // Refcount table. - self.driver.file.seek(SeekFrom::Start(cluster_size))?; - self.driver.file.write_all(&rc_table)?; + self.driver + .file + .as_ref() + .seek(SeekFrom::Start(cluster_size))?; + self.driver.file.as_ref().write_all(&rc_table)?; // Refcount block table. - self.driver.file.seek(SeekFrom::Start(cluster_size * 2))?; - self.driver.file.write_all(&rc_block)?; + self.driver + .file + .as_ref() + .seek(SeekFrom::Start(cluster_size * 2))?; + self.driver.file.as_ref().write_all(&rc_block)?; // Create qcow2 driver. self.load_refcount_table()?; @@ -1640,6 +1711,18 @@ impl BlockDriverOps for Qcow2Driver { Ok(image_info) } + fn query_image(&mut self, info: &mut ImageInfo) -> Result<()> { + info.format = "qcow2".to_string(); + info.virtual_size = self.disk_size()?; + info.actual_size = self.driver.actual_size()?; + info.cluster_size = Some(self.header.cluster_size()); + + if !self.snapshot.snapshots.is_empty() { + info.snap_lists = Some(self.qcow2_list_snapshots()); + } + Ok(()) + } + fn check_image(&mut self, res: &mut CheckResult, quite: bool, fix: u64) -> Result<()> { let cluster_size = self.header.cluster_size(); let refcount_order = self.header.refcount_order; @@ -1669,11 +1752,11 @@ impl BlockDriverOps for Qcow2Driver { let mut left = iovec; let mut req_list: Vec = Vec::new(); - let mut copied = 0; + let mut copied: u64 = 0; while copied < nbytes { let pos = offset as u64 + copied; - match self.host_offset_for_read(pos, nbytes - copied)? { - HostRange::DataAddress(host_offset, cnt) => { + match self.host_offset_for_read(pos, nbytes - copied) { + Ok(HostRange::DataAddress(host_offset, cnt)) => { let (begin, end) = iovecs_split(left, cnt); left = end; req_list.push(CombineRequest { @@ -1683,12 +1766,17 @@ impl BlockDriverOps for Qcow2Driver { }); copied += cnt; } - HostRange::DataNotInit(cnt) => { + Ok(HostRange::DataNotInit(cnt)) => { let (begin, end) = iovecs_split(left, cnt); left = end; - iovec_write_zero(&begin); + // SAFETY: iovecs is generated by address_space. + unsafe { iovec_write_zero(&begin) }; copied += cnt; } + Err(e) => { + error!("Failed to read vectored: {:?}", e); + return self.driver.complete_request(OpCode::Preadv, -1, completecb); + } } } @@ -1702,11 +1790,19 @@ impl BlockDriverOps for Qcow2Driver { trace::block_write_vectored(&self.driver.block_prop.id, offset, nbytes); let mut req_list: Vec = Vec::new(); - let mut copied = 0; + let mut copied: u64 = 0; while copied < nbytes { let pos = offset as u64 + copied; let count = self.cluster_aligned_bytes(pos, nbytes - copied); - let host_offset = self.host_offset_for_write(pos, count)?; + let host_offset = match self.host_offset_for_write(pos, count) { + Ok(host_offset) => host_offset, + Err(e) => { + error!("Failed to write vectored: {:?}", e); + return self + .driver + .complete_request(OpCode::Pwritev, -1, completecb); + } + }; if let Some(end) = req_list.last_mut() { if end.offset + end.nbytes == host_offset { end.nbytes += count; @@ -1886,6 +1982,17 @@ impl BlockDriverOps for Qcow2Driver { fn get_status(&mut self) -> Arc> { self.status.clone() } + + fn get_address_alloc_status( + &mut self, + offset: u64, + bytes: u64, + ) -> Result<(BlockAllocStatus, u64)> { + match self.host_offset_for_read(offset, bytes)? { + HostRange::DataNotInit(size) => Ok((BlockAllocStatus::ZERO, size)), + HostRange::DataAddress(_, size) => Ok((BlockAllocStatus::DATA, size)), + } + } } pub fn is_aligned(cluster_sz: u64, offset: u64) -> bool { @@ -1952,19 +2059,19 @@ mod test { .custom_flags(libc::O_CREAT | libc::O_TRUNC) .open(path) .unwrap(); - file.set_len(cluster_sz * 3 + header.l1_size as u64 * ENTRY_SIZE) + file.set_len(cluster_sz * 3 + u64::from(header.l1_size) * ENTRY_SIZE) .unwrap(); let zero_buf = - vec![0_u8; (cluster_sz * 3 + header.l1_size as u64 * ENTRY_SIZE) as usize]; + vec![0_u8; (cluster_sz * 3 + u64::from(header.l1_size) * ENTRY_SIZE) as usize]; file.write_all(&zero_buf).unwrap(); file.seek(SeekFrom::Start(0)).unwrap(); file.write_all(&header.to_vec()).unwrap(); // Cluster 1 is the refcount table. - assert_eq!(header.refcount_table_offset, cluster_sz * 1); + assert_eq!(header.refcount_table_offset, cluster_sz); let mut refcount_table = [0_u8; ENTRY_SIZE as usize]; BigEndian::write_u64(&mut refcount_table, cluster_sz * 2); - file.seek(SeekFrom::Start(cluster_sz * 1)).unwrap(); + file.seek(SeekFrom::Start(cluster_sz)).unwrap(); file.write_all(&refcount_table).unwrap(); // Clusters which has been allocated. @@ -1988,11 +2095,13 @@ mod test { } fn create_qcow2_driver(&self, conf: BlockProperty) -> Qcow2Driver<()> { - let file = std::fs::OpenOptions::new() - .read(true) - .write(true) - .open(&self.path) - .unwrap(); + let file = Arc::new( + std::fs::OpenOptions::new() + .read(true) + .write(true) + .open(&self.path) + .unwrap(), + ); let aio = Aio::new( Arc::new(SyncAioInfo::complete_func), util::aio::AioEngine::Off, @@ -2063,11 +2172,13 @@ mod test { pub fn create_qcow2(path: &str) -> (TestImage, Qcow2Driver<()>) { let mut image = TestImage::new(path, 30, 16); - let file = std::fs::OpenOptions::new() - .read(true) - .write(true) - .open(path) - .unwrap(); + let file = Arc::new( + std::fs::OpenOptions::new() + .read(true) + .write(true) + .open(path) + .unwrap(), + ); let aio = Aio::new( Arc::new(SyncAioInfo::complete_func), util::aio::AioEngine::Off, @@ -2276,8 +2387,10 @@ mod test { let mut wbuf = vec![0; case.sz as usize]; let mut rbuf = vec![0; case.sz as usize]; - let wsz = iov_to_buf_direct(&case.wiovec, 0, &mut wbuf).unwrap(); - let rsz = iov_to_buf_direct(&case.riovec, 0, &mut rbuf).unwrap(); + // SAFETY: wiovec is valid. + let wsz = unsafe { iov_to_buf_direct(&case.wiovec, 0, &mut wbuf).unwrap() }; + // SAFETY: riovec is valid. + let rsz = unsafe { iov_to_buf_direct(&case.riovec, 0, &mut rbuf).unwrap() }; assert_eq!(wsz, case.sz as usize); assert_eq!(rsz, case.sz as usize); assert_eq!(wbuf, rbuf); @@ -2320,7 +2433,7 @@ mod test { riovec, wiovec, data: test_data, - offset: 1 * CLUSTER_SIZE as usize, + offset: CLUSTER_SIZE as usize, sz: CLUSTER_SIZE, }); let test_data = vec![TestData::new(5, CLUSTER_SIZE as usize)]; @@ -2351,8 +2464,10 @@ mod test { let mut wbuf = vec![0; case.sz as usize]; let mut rbuf = vec![0; case.sz as usize]; - let wsz = iov_to_buf_direct(&case.wiovec, 0, &mut wbuf).unwrap(); - let rsz = iov_to_buf_direct(&case.riovec, 0, &mut rbuf).unwrap(); + // SAFETY: wiovec is valid. + let wsz = unsafe { iov_to_buf_direct(&case.wiovec, 0, &mut wbuf).unwrap() }; + // SAFETY: riovec is valid. + let rsz = unsafe { iov_to_buf_direct(&case.riovec, 0, &mut rbuf).unwrap() }; assert_eq!(wsz, case.sz as usize); assert_eq!(rsz, case.sz as usize); assert_eq!(wbuf, rbuf); @@ -2706,9 +2821,8 @@ mod test { .borrow_mut() .get_entry_map(l2_index as usize) .unwrap(); - let host_offset = l2_entry & L2_TABLE_OFFSET_MASK; - host_offset + l2_entry & L2_TABLE_OFFSET_MASK } // Change snapshot table offset to unaligned address which will lead to error in refcount update diff --git a/block_backend/src/qcow2/refcount.rs b/block_backend/src/qcow2/refcount.rs index 1a3dfdddcd367fd8780abe06b3e8fc5b2b48fe92..1c7ff10e2b6b40ef19e4b2e3a829f9e7917e2f29 100644 --- a/block_backend/src/qcow2/refcount.rs +++ b/block_backend/src/qcow2/refcount.rs @@ -10,7 +10,7 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use std::{cell::RefCell, rc::Rc}; +use std::{cell::RefCell, collections::HashMap, rc::Rc}; use anyhow::{bail, Context, Result}; use log::{error, info}; @@ -32,6 +32,9 @@ use util::{ // The max refcount table size default is 4 clusters; const MAX_REFTABLE_NUM: u64 = 4; +// Default refcount table map length, which can describe 512GiB data for 64Kib cluster. +const REFCOUNT_TABLE_MAP_LEN: usize = 256; + #[derive(Eq, PartialEq, Clone)] pub enum Qcow2DiscardType { Never, @@ -64,6 +67,7 @@ impl DiscardTask { #[derive(Clone)] pub struct RefCount { pub refcount_table: Vec, + pub refcount_table_map: HashMap, sync_aio: Rc>, pub(crate) refcount_blk_cache: Qcow2Cache, pub discard_list: Vec, @@ -87,6 +91,7 @@ impl RefCount { pub fn new(sync_aio: Rc>) -> Self { RefCount { refcount_table: Vec::new(), + refcount_table_map: HashMap::with_capacity(REFCOUNT_TABLE_MAP_LEN), sync_aio, refcount_blk_cache: Qcow2Cache::default(), discard_list: Vec::new(), @@ -114,7 +119,7 @@ impl RefCount { self.refcount_table_offset = header.refcount_table_offset; self.refcount_table_clusters = header.refcount_table_clusters; self.refcount_table_size = - header.refcount_table_clusters as u64 * header.cluster_size() / ENTRY_SIZE; + u64::from(header.refcount_table_clusters) * header.cluster_size() / ENTRY_SIZE; self.refcount_blk_bits = header.cluster_bits + 3 - header.refcount_order; self.refcount_blk_size = 1 << self.refcount_blk_bits; self.cluster_bits = header.cluster_bits; @@ -136,7 +141,7 @@ impl RefCount { } fn cluster_in_rc_block(&self, cluster_index: u64) -> u64 { - cluster_index & (self.refcount_blk_size - 1) as u64 + cluster_index & u64::from(self.refcount_blk_size - 1) } /// Allocate a continuous space that is not referenced by existing refcount table @@ -146,7 +151,7 @@ impl RefCount { } let nb_clusters = bytes_to_clusters(size, self.cluster_size).unwrap(); - let mut free_clusters = 0; + let mut free_clusters: u64 = 0; while free_clusters < nb_clusters { let offset = self.free_cluster_index << self.cluster_bits; self.free_cluster_index += 1; @@ -177,7 +182,7 @@ impl RefCount { let (table, blocks) = refcount_metadata_size( clusters, self.cluster_size, - header.refcount_order as u64, + u64::from(header.refcount_order), true, )?; self.extend_refcount_table(header, start_idx, table, blocks)?; @@ -217,9 +222,11 @@ impl RefCount { new_table.resize(new_table_size as usize, 0); let start_offset = start_idx * self.cluster_size; let mut table_offset = start_offset; + let mut added_rb = Vec::new(); for i in 0..new_block_clusters { if new_table[i as usize] == 0 { new_table[i as usize] = table_offset; + added_rb.push(table_offset & REFCOUNT_TABLE_OFFSET_MASK); table_offset += self.cluster_size; } } @@ -247,6 +254,9 @@ impl RefCount { let old_table_offset = self.refcount_table_offset; let old_table_clusters = self.refcount_table_clusters; self.refcount_table = new_table; + for rb_offset in added_rb.iter() { + self.refcount_table_map.insert(*rb_offset, 1); + } self.refcount_table_offset = header.refcount_table_offset; self.refcount_table_clusters = header.refcount_table_clusters; self.refcount_table_size = new_table_size; @@ -264,7 +274,7 @@ impl RefCount { // Free the old cluster of refcount table. self.update_refcount( old_table_offset, - old_table_clusters as u64, + u64::from(old_table_clusters), -1, true, &Qcow2DiscardType::Other, @@ -316,8 +326,8 @@ impl RefCount { bail!("Failed to update refcount, offset is not aligned to cluster"); } let first_cluster = bytes_to_clusters(offset, self.cluster_size).unwrap(); - let mut rc_vec = Vec::new(); - let mut i = 0; + let mut rc_vec: Vec<(u64, u64, usize)> = Vec::with_capacity(clusters as usize); + let mut i: u64 = 0; while i < clusters { let rt_idx = (first_cluster + i) >> self.refcount_blk_bits; if rt_idx >= self.refcount_table_size { @@ -381,6 +391,22 @@ impl RefCount { self.refcount_blk_cache.flush(self.sync_aio.clone()) } + fn get_refcount_block_cache(&mut self, rt_idx: u64) -> Result>> { + let entry = self.refcount_blk_cache.get(rt_idx); + let cache_entry = if let Some(entry) = entry { + entry.clone() + } else { + self.load_refcount_block(rt_idx).with_context(|| { + format!("Failed to get refcount block cache, index is {}", rt_idx) + })?; + self.refcount_blk_cache + .get(rt_idx) + .with_context(|| format!("Not found refcount block cache, index is {}", rt_idx))? + .clone() + }; + Ok(cache_entry) + } + fn set_refcount( &mut self, rt_idx: u64, @@ -391,18 +417,10 @@ impl RefCount { ) -> Result<()> { let is_add = added > 0; let added_value = added.unsigned_abs() as u16; - if !self.refcount_blk_cache.contains_keys(rt_idx) { - self.load_refcount_block(rt_idx).with_context(|| { - format!("Failed to get refcount block cache, index is {}", rt_idx) - })?; - } let cache_entry = self - .refcount_blk_cache - .get(rt_idx) - .with_context(|| format!("Not found refcount block cache, index is {}", rt_idx))? - .clone(); - - let mut rb_vec = Vec::new(); + .get_refcount_block_cache(rt_idx) + .with_context(|| "Get refcount block cache failed")?; + let mut rb_vec: Vec = Vec::with_capacity(clusters); let mut borrowed_entry = cache_entry.borrow_mut(); let is_dirty = borrowed_entry.dirty_info.is_dirty; for i in 0..clusters { @@ -425,7 +443,7 @@ impl RefCount { ) })? }; - let cluster_idx = rt_idx * self.refcount_blk_size as u64 + rb_idx + i as u64; + let cluster_idx = rt_idx * u64::from(self.refcount_blk_size) + rb_idx + i as u64; if rc_value == 0 { if self.discard_passthrough.contains(discard_type) { // update refcount discard. @@ -442,7 +460,7 @@ impl RefCount { } for (idx, rc_value) in rb_vec.iter().enumerate() { - borrowed_entry.set_entry_map(rb_idx as usize + idx, *rc_value as u64)?; + borrowed_entry.set_entry_map(rb_idx as usize + idx, u64::from(*rc_value))?; } if !is_dirty { self.refcount_blk_cache.add_dirty_table(cache_entry.clone()); @@ -471,17 +489,9 @@ impl RefCount { ); } - if !self.refcount_blk_cache.contains_keys(rt_idx) { - self.load_refcount_block(rt_idx).with_context(|| { - format!("Failed to get refcount block cache, index is {}", rt_idx) - })?; - } let cache_entry = self - .refcount_blk_cache - .get(rt_idx) - .with_context(|| format!("Not found refcount block cache, index is {}", rt_idx))? - .clone(); - + .get_refcount_block_cache(rt_idx) + .with_context(|| "Get refcount block cache failed")?; let rb_idx = self.cluster_in_rc_block(cluster) as usize; let rc_value = cache_entry.borrow_mut().get_entry_map(rb_idx).unwrap(); @@ -490,18 +500,8 @@ impl RefCount { /// Add discard task to the list. fn update_discard_list(&mut self, offset: u64, nbytes: u64) -> Result<()> { - let mut discard_task = DiscardTask { offset, nbytes }; - let len = self.discard_list.len(); - let mut discard_list: Vec = Vec::with_capacity(len + 1); - for task in self.discard_list.iter() { - if discard_task.is_overlap(task) { - discard_task.merge_task(task); - } else { - discard_list.push(task.clone()); - } - } - discard_list.push(discard_task); - self.discard_list = discard_list; + let discard_task = DiscardTask { offset, nbytes }; + self.discard_list.push(discard_task); Ok(()) } @@ -542,6 +542,8 @@ impl RefCount { // Update refcount table. self.refcount_table[rt_idx as usize] = alloc_offset; + let rb_offset = alloc_offset & REFCOUNT_TABLE_OFFSET_MASK; + self.refcount_table_map.insert(rb_offset, 1); let rc_block = vec![0_u8; self.cluster_size as usize]; let cache_entry = Rc::new(RefCell::new(CacheTable::new( alloc_offset, @@ -657,11 +659,11 @@ pub fn refcount_metadata_size( ) -> Result<(u64, u64)> { let reftable_entries = cluster_size / ENTRY_SIZE; let refblock_entries = cluster_size * 8 / (1 << refcount_order); - let mut table = 0; - let mut blocks = 0; + let mut table: u64 = 0; + let mut blocks: u64 = 0; let mut clusters = nb_clusters; let mut last_clusters; - let mut total_clusters = 0; + let mut total_clusters: u64 = 0; loop { last_clusters = total_clusters; @@ -779,8 +781,8 @@ mod test { path: &str, img_bits: u32, cluster_bits: u32, - ) -> (Qcow2Driver<()>, File) { - let file = image_create(path, img_bits, cluster_bits); + ) -> (Qcow2Driver<()>, Arc) { + let file = Arc::new(image_create(path, img_bits, cluster_bits)); let aio = Aio::new( Arc::new(SyncAioInfo::complete_func), util::aio::AioEngine::Off, @@ -799,10 +801,9 @@ mod test { l2_cache_size: None, refcount_cache_size: None, }; - let cloned_file = file.try_clone().unwrap(); - let mut qcow2_driver = Qcow2Driver::new(file, aio, conf.clone()).unwrap(); + let mut qcow2_driver = Qcow2Driver::new(file.clone(), aio, conf.clone()).unwrap(); qcow2_driver.load_metadata(conf).unwrap(); - (qcow2_driver, cloned_file) + (qcow2_driver, file) } #[test] @@ -818,14 +819,15 @@ mod test { let free_cluster_index = 3 + ((header.l1_size * ENTRY_SIZE as u32 + cluster_sz as u32 - 1) >> cluster_bits); let addr = qcow2.alloc_cluster(1, true).unwrap(); - assert_eq!(addr, cluster_sz * free_cluster_index as u64); + assert_eq!(addr, cluster_sz * u64::from(free_cluster_index)); qcow2.flush().unwrap(); // Check if the refcount of the cluster is updated to the disk. let mut rc_value = [0_u8; 2]; cloned_file + .as_ref() .read_at( &mut rc_value, - cluster_sz * 2 + 2 * free_cluster_index as u64, + cluster_sz * 2 + 2 * u64::from(free_cluster_index), ) .unwrap(); assert_eq!(1, BigEndian::read_u16(&rc_value)); @@ -861,6 +863,7 @@ mod test { let table_size = div_round_up(image_size, block_size * cluster_size).unwrap(); let mut refcount_table = vec![0_u8; table_size as usize * ENTRY_SIZE as usize]; assert!(cloned_file + .as_ref() .read_at(&mut refcount_table, table_offset) .is_ok()); for i in 0..table_size { @@ -877,7 +880,7 @@ mod test { for j in (i + 1)..len { let addr2 = res_data[j].0 as usize; let size2 = res_data[j].1 as usize; - assert_eq!(ranges_overlap(addr1, size1, addr2, size2).unwrap(), false); + assert!(!ranges_overlap(addr1, size1, addr2, size2).unwrap()); } } @@ -916,10 +919,14 @@ mod test { let old_rct_size = cluster_sz as usize * rct_clusters as usize; let new_rct_size = cluster_sz as usize * new_rct_clusters as usize; let mut old_rc_table = vec![0_u8; old_rct_size]; - cloned_file.read_at(&mut old_rc_table, rct_offset).unwrap(); + cloned_file + .as_ref() + .read_at(&mut old_rc_table, rct_offset) + .unwrap(); let mut new_rc_table = vec![0_u8; new_rct_size]; cloned_file - .read_at(&mut new_rc_table, new_rct_offset as u64) + .as_ref() + .read_at(&mut new_rc_table, new_rct_offset) .unwrap(); for i in 0..old_rct_size { assert_eq!(old_rc_table[i], new_rc_table[i]); @@ -975,7 +982,7 @@ mod test { &Qcow2DiscardType::Never, ); if let Err(err) = ret { - let err_msg = format!("Invalid refcount block address 0x0, index is 2"); + let err_msg = "Invalid refcount block address 0x0, index is 2".to_string(); assert_eq!(err.to_string(), err_msg); } else { assert!(false); @@ -999,7 +1006,7 @@ mod test { // Test refcount overflow. let ret = refcount.set_refcount(0, 0, 1, 65535, &Qcow2DiscardType::Never); if let Err(err) = ret { - let err_msg = format!("Refcount 2 add 65535 cause overflows, index is 0"); + let err_msg = "Refcount 2 add 65535 cause overflows, index is 0".to_string(); assert_eq!(err.to_string(), err_msg); } else { assert!(false); @@ -1008,7 +1015,7 @@ mod test { // Test refcount underflow. let ret = refcount.set_refcount(0, 0, 1, -65535, &Qcow2DiscardType::Never); if let Err(err) = ret { - let err_msg = format!("Refcount 2 sub 65535 cause overflows, index is 0"); + let err_msg = "Refcount 2 sub 65535 cause overflows, index is 0".to_string(); assert_eq!(err.to_string(), err_msg); } else { assert!(false); diff --git a/block_backend/src/qcow2/snapshot.rs b/block_backend/src/qcow2/snapshot.rs index b5ee4ba37c797e3aebe9c1c26a937f0f6f80daf8..f0b638f2a345c727a9cfdeba1afd3d5b731dd246 100644 --- a/block_backend/src/qcow2/snapshot.rs +++ b/block_backend/src/qcow2/snapshot.rs @@ -85,7 +85,7 @@ impl InternalSnapshot { } pub fn find_new_snapshot_id(&self) -> u64 { - let mut id_max = 0; + let mut id_max: u64 = 0; for snap in &self.snapshots { if id_max < snap.id { id_max = snap.id; @@ -98,18 +98,20 @@ impl InternalSnapshot { pub fn save_snapshot_table( &self, addr: u64, - extra_snap: &QcowSnapshot, + extra_snap: Option<&QcowSnapshot>, attach: bool, ) -> Result<()> { let mut buf = Vec::new(); for snap in &self.snapshots { - if !attach && snap.id == extra_snap.id { + if !attach && extra_snap.is_some() && snap.id == extra_snap.unwrap().id { continue; } buf.append(&mut snap.gen_snapshot_table_entry()); } if attach { - buf.append(&mut extra_snap.gen_snapshot_table_entry()); + if let Some(extra) = extra_snap { + buf.append(&mut extra.gen_snapshot_table_entry()); + } } self.sync_aio.borrow_mut().write_buffer(addr, &buf) } @@ -139,7 +141,7 @@ impl InternalSnapshot { for i in 0..nb_snapshots { let offset = addr + self.snapshot_size; - let mut pos = 0; + let mut pos: usize = 0; let header_size = size_of::(); let mut header_buf = vec![0_u8; header_size]; self.sync_aio @@ -281,7 +283,7 @@ impl QcowSnapshot { // Snapshot Extra data. // vm_state_size_large is used for vm snapshot. // It's equal to vm_state_size which is also 0 in disk snapshot. - BigEndian::write_u64(&mut buf[40..48], self.vm_state_size as u64); + BigEndian::write_u64(&mut buf[40..48], u64::from(self.vm_state_size)); BigEndian::write_u64(&mut buf[48..56], self.disk_size); if self.extra_data_size == SNAPSHOT_EXTRA_DATA_LEN_24 as u32 { BigEndian::write_u64(&mut buf[56..64], self.icount); diff --git a/block_backend/src/qcow2/table.rs b/block_backend/src/qcow2/table.rs index 5886becb4971142f2035e4be691aa70891392a04..3abd1fe4978d6e81eb13525427ef35ec177dc3e7 100644 --- a/block_backend/src/qcow2/table.rs +++ b/block_backend/src/qcow2/table.rs @@ -10,7 +10,7 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use std::{cell::RefCell, rc::Rc}; +use std::{cell::RefCell, collections::HashMap, rc::Rc}; use anyhow::{Context, Result}; use log::info; @@ -28,6 +28,9 @@ use crate::{ use machine_manager::config::MAX_L2_CACHE_SIZE; use util::num_ops::div_round_up; +// Default l1 table map length, which can describe 512GiB data for 64KiB cluster. +const L1_TABLE_MAP_LEN: usize = 1024; + #[derive(PartialEq, Eq, Debug)] pub enum Qcow2ClusterType { /// Cluster is unallocated. @@ -81,6 +84,7 @@ pub struct Qcow2Table { cluster_bits: u64, cluster_size: u64, pub l1_table: Vec, + pub l1_table_map: HashMap, pub l1_table_offset: u64, pub l1_size: u32, pub l2_table_cache: Qcow2Cache, @@ -96,6 +100,7 @@ impl Qcow2Table { cluster_bits: 0, cluster_size: 0, l1_table: Vec::new(), + l1_table_map: HashMap::with_capacity(L1_TABLE_MAP_LEN), l1_table_offset: 0, l1_size: 0, l2_table_cache: Qcow2Cache::default(), @@ -128,9 +133,9 @@ impl Qcow2Table { }; info!("Driver {} l2 cache size {}", conf.id, cache_size); let l2_table_cache: Qcow2Cache = Qcow2Cache::new(cache_size as usize); - self.cluster_bits = header.cluster_bits as u64; + self.cluster_bits = u64::from(header.cluster_bits); self.cluster_size = header.cluster_size(); - self.l2_bits = header.cluster_bits as u64 - ENTRY_BITS; + self.l2_bits = u64::from(header.cluster_bits) - ENTRY_BITS; self.l2_size = header.cluster_size() / ENTRY_SIZE; self.l2_table_cache = l2_table_cache; self.l1_table_offset = header.l1_table_offset; @@ -142,7 +147,11 @@ impl Qcow2Table { self.l1_table = self .sync_aio .borrow_mut() - .read_ctrl_cluster(self.l1_table_offset, self.l1_size as u64)?; + .read_ctrl_cluster(self.l1_table_offset, u64::from(self.l1_size))?; + for l1_entry in &self.l1_table { + let l1_entry_addr = l1_entry & L1_TABLE_OFFSET_MASK; + self.l1_table_map.insert(l1_entry_addr, 1); + } Ok(()) } @@ -185,7 +194,11 @@ impl Qcow2Table { } pub fn update_l1_table(&mut self, l1_index: usize, l2_address: u64) { + let old_addr = self.l1_table[l1_index] & L1_TABLE_OFFSET_MASK; + let new_addr = l2_address & L1_TABLE_OFFSET_MASK; self.l1_table[l1_index] = l2_address; + self.l1_table_map.remove(&old_addr); + self.l1_table_map.insert(new_addr, 1); } pub fn update_l2_table( @@ -256,7 +269,7 @@ mod test { let addr = qcow2.alloc_cluster(1, true).unwrap(); let l2_cluster: Vec = vec![0_u8; cluster_size]; let l2_table = Rc::new(RefCell::new( - CacheTable::new(addr, l2_cluster.clone(), ENTRY_SIZE_U64).unwrap(), + CacheTable::new(addr, l2_cluster, ENTRY_SIZE_U64).unwrap(), )); qcow2.table.cache_l2_table(l2_table.clone()).unwrap(); diff --git a/block_backend/src/raw.rs b/block_backend/src/raw.rs index c051b3f100470bdf93fa37745074349b69042ef6..d49578641b72eebabb14bb8f99e80d409babc724 100644 --- a/block_backend/src/raw.rs +++ b/block_backend/src/raw.rs @@ -25,7 +25,7 @@ use crate::{ file::{CombineRequest, FileDriver}, qcow2::is_aligned, BlockDriverOps, BlockIoErrorCallback, BlockProperty, BlockStatus, CheckResult, CreateOptions, - SECTOR_SIZE, + ImageInfo, SECTOR_SIZE, }; use util::{ aio::{get_iov_size, raw_write, Aio, Iovec}, @@ -45,7 +45,7 @@ unsafe impl Send for RawDriver {} unsafe impl Sync for RawDriver {} impl RawDriver { - pub fn new(file: File, aio: Aio, prop: BlockProperty) -> Self { + pub fn new(file: Arc, aio: Aio, prop: BlockProperty) -> Self { Self { driver: FileDriver::new(file, aio, prop), status: Arc::new(Mutex::new(BlockStatus::Init)), @@ -56,24 +56,27 @@ impl RawDriver { // get_file_alignment() detects the alignment length by submitting IO to the first sector. // If this area is fallocated, misaligned IO will also return success, so we pre fill this area. pub fn alloc_first_block(&mut self, new_size: u64) -> Result<()> { - let write_size = if new_size < MAX_FILE_ALIGN as u64 { + let write_size = if new_size < u64::from(MAX_FILE_ALIGN) { SECTOR_SIZE } else { - MAX_FILE_ALIGN as u64 + u64::from(MAX_FILE_ALIGN) }; - let max_align = std::cmp::max(MAX_FILE_ALIGN as u64, host_page_size()) as usize; + let max_align = std::cmp::max(u64::from(MAX_FILE_ALIGN), host_page_size()) as usize; // SAFETY: allocate aligned memory and free it later. let align_buf = unsafe { libc::memalign(max_align, write_size as usize) }; if align_buf.is_null() { bail!("Failed to alloc memory for write."); } - let ret = raw_write( - self.driver.file.as_raw_fd(), - align_buf as u64, - write_size as usize, - 0, - ); + // SAFETY: align_buf is valid and large enough. + let ret = unsafe { + raw_write( + self.driver.file.as_raw_fd(), + align_buf as u64, + write_size as usize, + 0, + ) + }; // SAFETY: the memory is allocated in this function. unsafe { libc::free(align_buf) }; @@ -92,6 +95,13 @@ impl BlockDriverOps for RawDriver { Ok(image_info) } + fn query_image(&mut self, info: &mut ImageInfo) -> Result<()> { + info.format = "raw".to_string(); + info.virtual_size = self.disk_size()?; + info.actual_size = self.driver.actual_size()?; + Ok(()) + } + fn check_image(&mut self, _res: &mut CheckResult, _quite: bool, _fix: u64) -> Result<()> { bail!("This image format does not support checks"); } diff --git a/boot_loader/Cargo.toml b/boot_loader/Cargo.toml index d04c4ceaf87d5eb8eaa340476c7635ae27cd3a62..e2c9c45f69aac4c55dd13eb2fe95eda0a51e27f0 100644 --- a/boot_loader/Cargo.toml +++ b/boot_loader/Cargo.toml @@ -8,7 +8,7 @@ license = "Mulan PSL v2" [dependencies] thiserror = "1.0" anyhow = "1.0" -kvm-bindings = { version = "0.6.0", features = ["fam-wrappers"] } +kvm-bindings = { version = "0.7.0", features = ["fam-wrappers"] } log = "0.4" address_space = { path = "../address_space" } devices = { path = "../devices" } diff --git a/boot_loader/src/aarch64/mod.rs b/boot_loader/src/aarch64/mod.rs index 20d564c7da4cfe78a166f7f37c7ef488302ae575..d06a95af65a57a31a6884132f1a0eb99efcd7f84 100644 --- a/boot_loader/src/aarch64/mod.rs +++ b/boot_loader/src/aarch64/mod.rs @@ -19,7 +19,7 @@ use anyhow::{anyhow, Context, Result}; use log::info; use crate::error::BootLoaderError; -use address_space::{AddressSpace, GuestAddress}; +use address_space::{AddressAttr, AddressSpace, GuestAddress}; use devices::legacy::{error::LegacyError as FwcfgErrorKind, FwCfgEntryType, FwCfgOps}; use util::byte_code::ByteCode; @@ -85,7 +85,12 @@ fn load_kernel( ))); } sys_mem - .write(&mut kernel_image, GuestAddress(kernel_start), kernel_size) + .write( + &mut kernel_image, + GuestAddress(kernel_start), + kernel_size, + AddressAttr::Ram, + ) .with_context(|| "Fail to write kernel to guest memory")?; } Ok(kernel_end) @@ -129,7 +134,12 @@ fn load_initrd( .with_context(|| FwcfgErrorKind::AddEntryErr("InitrdData".to_string()))?; } else { sys_mem - .write(&mut initrd_image, GuestAddress(initrd_start), initrd_size) + .write( + &mut initrd_image, + GuestAddress(initrd_start), + initrd_size, + AddressAttr::Ram, + ) .with_context(|| "Fail to write initrd to guest memory")?; } diff --git a/boot_loader/src/x86_64/direct_boot/gdt.rs b/boot_loader/src/x86_64/direct_boot/gdt.rs index 62c70be683c848f61f2a115a0097331d6d691bf6..07995a061ccaefae082c856131c72d4012ac0743 100644 --- a/boot_loader/src/x86_64/direct_boot/gdt.rs +++ b/boot_loader/src/x86_64/direct_boot/gdt.rs @@ -19,7 +19,7 @@ use super::super::BootGdtSegment; use super::super::{ BOOT_GDT_MAX, BOOT_GDT_OFFSET, BOOT_IDT_OFFSET, GDT_ENTRY_BOOT_CS, GDT_ENTRY_BOOT_DS, }; -use address_space::{AddressSpace, GuestAddress}; +use address_space::{AddressAttr, AddressSpace, GuestAddress}; // /* // * Constructor for a conventional segment GDT (or LDT) entry. @@ -94,7 +94,7 @@ fn write_gdt_table(table: &[u64], guest_mem: &Arc) -> Result<()> { let mut boot_gdt_addr = BOOT_GDT_OFFSET; for (_, entry) in table.iter().enumerate() { guest_mem - .write_object(entry, GuestAddress(boot_gdt_addr)) + .write_object(entry, GuestAddress(boot_gdt_addr), AddressAttr::Ram) .with_context(|| format!("Failed to load gdt to 0x{:x}", boot_gdt_addr))?; boot_gdt_addr += 8; } @@ -104,7 +104,7 @@ fn write_gdt_table(table: &[u64], guest_mem: &Arc) -> Result<()> { fn write_idt_value(val: u64, guest_mem: &Arc) -> Result<()> { let boot_idt_addr = BOOT_IDT_OFFSET; guest_mem - .write_object(&val, GuestAddress(boot_idt_addr)) + .write_object(&val, GuestAddress(boot_idt_addr), AddressAttr::Ram) .with_context(|| format!("Failed to load gdt to 0x{:x}", boot_idt_addr))?; Ok(()) @@ -119,9 +119,9 @@ pub fn setup_gdt(guest_mem: &Arc) -> Result { ]; let mut code_seg: kvm_segment = GdtEntry(gdt_table[GDT_ENTRY_BOOT_CS as usize]).into(); - code_seg.selector = GDT_ENTRY_BOOT_CS as u16 * 8; + code_seg.selector = u16::from(GDT_ENTRY_BOOT_CS) * 8; let mut data_seg: kvm_segment = GdtEntry(gdt_table[GDT_ENTRY_BOOT_DS as usize]).into(); - data_seg.selector = GDT_ENTRY_BOOT_DS as u16 * 8; + data_seg.selector = u16::from(GDT_ENTRY_BOOT_DS) * 8; write_gdt_table(&gdt_table[..], guest_mem)?; write_idt_value(0, guest_mem)?; diff --git a/boot_loader/src/x86_64/direct_boot/mod.rs b/boot_loader/src/x86_64/direct_boot/mod.rs index ddeede393860f661e97d694c66bfed3cb4aeb9ab..c910d8225ec772480d8a6cf81d275d321c18ee5a 100644 --- a/boot_loader/src/x86_64/direct_boot/mod.rs +++ b/boot_loader/src/x86_64/direct_boot/mod.rs @@ -29,7 +29,7 @@ use super::{ INITRD_ADDR_MAX, PDE_START, PDPTE_START, PML4_START, VMLINUX_STARTUP, ZERO_PAGE_START, }; use crate::error::BootLoaderError; -use address_space::{AddressSpace, GuestAddress}; +use address_space::{AddressAttr, AddressSpace, GuestAddress}; use util::byte_code::ByteCode; /// Load bzImage linux kernel to Guest Memory. @@ -66,7 +66,7 @@ fn load_bzimage(kernel_image: &mut File) -> Result { return Err(e); } - let mut setup_size = boot_hdr.setup_sects as u64; + let mut setup_size = u64::from(boot_hdr.setup_sects); if setup_size == 0 { setup_size = 4; } @@ -91,7 +91,12 @@ fn load_image(image: &mut File, start_addr: u64, sys_mem: &Arc) -> let len = image.seek(SeekFrom::End(0))?; image.seek(SeekFrom::Start(curr_loc))?; - sys_mem.write(image, GuestAddress(start_addr), len - curr_loc)?; + sys_mem.write( + image, + GuestAddress(start_addr), + len - curr_loc, + AddressAttr::Ram, + )?; Ok(()) } @@ -107,8 +112,8 @@ fn load_kernel_image( let (boot_hdr, kernel_start, vmlinux_start) = if let Ok(hdr) = load_bzimage(&mut kernel_image) { ( hdr, - hdr.code32_start as u64 + BZIMAGE_BOOT_OFFSET, - hdr.code32_start as u64, + u64::from(hdr.code32_start) + BZIMAGE_BOOT_OFFSET, + u64::from(hdr.code32_start), ) } else { ( @@ -163,13 +168,13 @@ fn setup_page_table(sys_mem: &Arc) -> Result { // Entry covering VA [0..512GB) let pdpte = boot_pdpte_addr | 0x03; sys_mem - .write_object(&pdpte, GuestAddress(boot_pml4_addr)) + .write_object(&pdpte, GuestAddress(boot_pml4_addr), AddressAttr::Ram) .with_context(|| format!("Failed to load PD PTE to 0x{:x}", boot_pml4_addr))?; // Entry covering VA [0..1GB) let pde = boot_pde_addr | 0x03; sys_mem - .write_object(&pde, GuestAddress(boot_pdpte_addr)) + .write_object(&pde, GuestAddress(boot_pdpte_addr), AddressAttr::Ram) .with_context(|| format!("Failed to load PDE to 0x{:x}", boot_pdpte_addr))?; // 512 2MB entries together covering VA [0..1GB). Note we are assuming @@ -177,7 +182,7 @@ fn setup_page_table(sys_mem: &Arc) -> Result { for i in 0..512u64 { let pde = (i << 21) + 0x83u64; sys_mem - .write_object(&pde, GuestAddress(boot_pde_addr + i * 8)) + .write_object(&pde, GuestAddress(boot_pde_addr + i * 8), AddressAttr::Ram) .with_context(|| format!("Failed to load PDE to 0x{:x}", boot_pde_addr + i * 8))?; } @@ -192,7 +197,11 @@ fn setup_boot_params( let mut boot_params = BootParams::new(*boot_hdr); boot_params.setup_e820_entries(config, sys_mem); sys_mem - .write_object(&boot_params, GuestAddress(ZERO_PAGE_START)) + .write_object( + &boot_params, + GuestAddress(ZERO_PAGE_START), + AddressAttr::Ram, + ) .with_context(|| format!("Failed to load zero page to 0x{:x}", ZERO_PAGE_START))?; Ok(()) @@ -209,7 +218,8 @@ fn setup_kernel_cmdline( sys_mem.write( &mut config.kernel_cmdline.as_bytes(), GuestAddress(CMDLINE_START), - cmdline_len as u64, + u64::from(cmdline_len), + AddressAttr::Ram, )?; Ok(()) @@ -305,18 +315,24 @@ mod test { .unwrap(); assert_eq!(setup_page_table(&space).unwrap(), 0x0000_9000); assert_eq!( - space.read_object::(GuestAddress(0x0000_9000)).unwrap(), + space + .read_object::(GuestAddress(0x0000_9000), AddressAttr::Ram) + .unwrap(), 0x0000_a003 ); assert_eq!( - space.read_object::(GuestAddress(0x0000_a000)).unwrap(), + space + .read_object::(GuestAddress(0x0000_a000), AddressAttr::Ram) + .unwrap(), 0x0000_b003 ); let mut page_addr: u64 = 0x0000_b000; let mut tmp_value: u64 = 0x83; for _ in 0..512u64 { assert_eq!( - space.read_object::(GuestAddress(page_addr)).unwrap(), + space + .read_object::(GuestAddress(page_addr), AddressAttr::Ram) + .unwrap(), tmp_value ); page_addr += 8; @@ -378,7 +394,11 @@ mod test { let mut arr: Vec = Vec::new(); let mut boot_addr: u64 = 0x500; for _ in 0..BOOT_GDT_MAX { - arr.push(space.read_object(GuestAddress(boot_addr)).unwrap()); + arr.push( + space + .read_object(GuestAddress(boot_addr), AddressAttr::Ram) + .unwrap(), + ); boot_addr += 8; } assert_eq!(arr[0], 0); @@ -395,6 +415,7 @@ mod test { &mut read_buffer.as_mut(), GuestAddress(0x0002_0000), cmd_len, + AddressAttr::Ram, ) .unwrap(); let s = String::from_utf8(read_buffer.to_vec()).unwrap(); diff --git a/boot_loader/src/x86_64/direct_boot/mptable.rs b/boot_loader/src/x86_64/direct_boot/mptable.rs index 8eee6d1c1529961e75ec40c3e7f1cd0514b3393e..8ea1ce2d224f26eff1a656c16bd49000012289e2 100644 --- a/boot_loader/src/x86_64/direct_boot/mptable.rs +++ b/boot_loader/src/x86_64/direct_boot/mptable.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use anyhow::{anyhow, Result}; use crate::error::BootLoaderError; -use address_space::{AddressSpace, GuestAddress}; +use address_space::{AddressAttr, AddressSpace, GuestAddress}; use util::byte_code::ByteCode; use util::checksum::obj_checksum; @@ -267,7 +267,7 @@ impl LocalInterruptEntry { macro_rules! write_entry { ( $d:expr, $t:ty, $m:expr, $o:expr, $s:expr ) => { let entry = $d; - $m.write_object(&entry, GuestAddress($o))?; + $m.write_object(&entry, GuestAddress($o), AddressAttr::Ram)?; $o += std::mem::size_of::<$t>() as u64; $s = $s.wrapping_add(obj_checksum(&entry)); }; @@ -294,6 +294,7 @@ pub fn setup_isa_mptable( sys_mem.write_object( &FloatingPointer::new(header as u32), GuestAddress(start_addr), + AddressAttr::Ram, )?; let mut offset = header + std::mem::size_of::() as u64; @@ -345,6 +346,7 @@ pub fn setup_isa_mptable( sys_mem.write_object( &ConfigTableHeader::new((offset - header) as u16, sum, lapic_addr), GuestAddress(header), + AddressAttr::Ram, )?; Ok(()) diff --git a/boot_loader/src/x86_64/standard_boot/elf.rs b/boot_loader/src/x86_64/standard_boot/elf.rs index 2817010ae86ed396ce842c04867548760f769377..9158d4c1c43aca747097ff5301cc285a57cb8cf6 100644 --- a/boot_loader/src/x86_64/standard_boot/elf.rs +++ b/boot_loader/src/x86_64/standard_boot/elf.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use anyhow::{bail, Context, Result}; -use address_space::{AddressSpace, GuestAddress}; +use address_space::{AddressAttr, AddressSpace, GuestAddress}; use devices::legacy::{FwCfgEntryType, FwCfgOps}; use util::byte_code::ByteCode; use util::num_ops::round_up; @@ -163,7 +163,12 @@ pub fn load_elf_kernel( if ph.p_type == PT_LOAD { kernel_image.seek(SeekFrom::Start(ph.p_offset))?; - sys_mem.write(kernel_image, GuestAddress(ph.p_paddr), ph.p_filesz)?; + sys_mem.write( + kernel_image, + GuestAddress(ph.p_paddr), + ph.p_filesz, + AddressAttr::Ram, + )?; addr_low = std::cmp::min(addr_low, ph.p_paddr); addr_max = std::cmp::max(addr_max, ph.p_paddr); @@ -181,10 +186,11 @@ pub fn load_elf_kernel( let p_align = ph.p_align; let aligned_namesz = - round_up(note_hdr.namesz as u64, p_align).with_context(|| { + round_up(u64::from(note_hdr.namesz), p_align).with_context(|| { format!( "Overflows when align up: num 0x{:x}, alignment 0x{:x}", - note_hdr.namesz as u64, p_align, + u64::from(note_hdr.namesz), + p_align, ) })?; if note_hdr.type_ == XEN_ELFNOTE_PHYS32_ENTRY { @@ -195,11 +201,12 @@ pub fn load_elf_kernel( pvh_start_addr = Some(entry_addr); break; } else { - let aligned_descsz = - round_up(note_hdr.descsz as u64, p_align).with_context(|| { + let aligned_descsz = round_up(u64::from(note_hdr.descsz), p_align) + .with_context(|| { format!( "Overflows when align up, num 0x{:x}, alignment 0x{:x}", - note_hdr.descsz as u64, p_align, + u64::from(note_hdr.descsz), + p_align, ) })?; let tail_size = aligned_namesz + aligned_descsz; diff --git a/boot_loader/src/x86_64/standard_boot/mod.rs b/boot_loader/src/x86_64/standard_boot/mod.rs index 0c17349cee0fc9764501736e9f648b88f507a41e..5ad697bf769c4401d6a93f8084e965e0f626505d 100644 --- a/boot_loader/src/x86_64/standard_boot/mod.rs +++ b/boot_loader/src/x86_64/standard_boot/mod.rs @@ -59,7 +59,7 @@ fn load_kernel_image( header: &RealModeKernelHeader, fwcfg: &mut dyn FwCfgOps, ) -> Result> { - let mut setup_size = header.setup_sects as u64; + let mut setup_size = u64::from(header.setup_sects); if setup_size == 0 { setup_size = 4; } diff --git a/build.rs b/build.rs index 96f77ffd45080de6df817a5eee400bc0bf54dc94..13b5a89850ba2c00a0087c8f80ff47b88c170594 100644 --- a/build.rs +++ b/build.rs @@ -16,6 +16,9 @@ fn ohos_env_configure() { println!("cargo:rustc-link-arg=--verbose"); println!("cargo:rustc-link-arg=--sysroot={}/sysroot", ohos_sdk_path); println!("cargo:rustc-link-arg=-lpixman_static"); + if cfg!(feature = "usb_host") { + println!("cargo:rustc-link-arg=-lusb-1.0"); + } println!( "cargo:rustc-link-search={}/sysroot/usr/lib/aarch64-linux-ohos", ohos_sdk_path diff --git a/chardev_backend/Cargo.toml b/chardev_backend/Cargo.toml index 83b47b3379d07e9c08cd51a709bb89f02364a04f..8a1b15358ac4ed7dc68ddd2dc2ebb6f4339eb7c7 100644 --- a/chardev_backend/Cargo.toml +++ b/chardev_backend/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" license = "Mulan PSL v2" [dependencies] -vmm-sys-util = "0.11.0" +vmm-sys-util = "0.12.1" anyhow = "1.0" log = "0.4" libc = "0.2" diff --git a/chardev_backend/src/chardev.rs b/chardev_backend/src/chardev.rs index 7a07a78ac17bb433824aa0951122c117cc14ac3d..3f9e58b062e579508fa76061f87d001ffc044bfa 100644 --- a/chardev_backend/src/chardev.rs +++ b/chardev_backend/src/chardev.rs @@ -10,8 +10,9 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. +use std::collections::VecDeque; use std::fs::{read_link, File, OpenOptions}; -use std::io::{Stdin, Stdout}; +use std::io::{ErrorKind, Stdin, Stdout}; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::path::PathBuf; use std::rc::Rc; @@ -24,11 +25,12 @@ use nix::fcntl::{fcntl, FcntlArg, OFlag}; use nix::pty::openpty; use nix::sys::termios::{cfmakeraw, tcgetattr, tcsetattr, SetArg, Termios}; use vmm_sys_util::epoll::EventSet; +use vmm_sys_util::eventfd::EventFd; use machine_manager::event_loop::EventLoop; use machine_manager::machine::{PathInfo, PTY_PATH}; use machine_manager::{ - config::{ChardevConfig, ChardevType}, + config::{ChardevConfig, ChardevType, SocketType}, temp_cleaner::TempCleaner, }; use util::file::clear_file; @@ -39,6 +41,8 @@ use util::set_termi_raw_mode; use util::socket::{SocketListener, SocketStream}; use util::unix::limit_permission; +const BUF_QUEUE_SIZE: usize = 128; + /// Provide the trait that helps handle the input data. pub trait InputReceiver: Send { /// Handle the input data and trigger interrupt if necessary. @@ -85,13 +89,17 @@ pub struct Chardev { /// Scheduled DPC to unpause input stream. /// Unpause must be done inside event-loop unpause_timer: Option, + /// output listener to notify when output stream fd can be written + output_listener_fd: Option>, + /// output buffer queue + outbuf: VecDeque>, } impl Chardev { pub fn new(chardev_cfg: ChardevConfig) -> Self { Chardev { - id: chardev_cfg.id, - backend: chardev_cfg.backend, + id: chardev_cfg.id(), + backend: chardev_cfg.classtype, listener: None, input: None, output: None, @@ -100,17 +108,19 @@ impl Chardev { dev: None, wait_port: false, unpause_timer: None, + output_listener_fd: None, + outbuf: VecDeque::with_capacity(BUF_QUEUE_SIZE), } } pub fn realize(&mut self) -> Result<()> { match &self.backend { - ChardevType::Stdio => { + ChardevType::Stdio { .. } => { set_termi_raw_mode().with_context(|| "Failed to set terminal to raw mode")?; self.input = Some(Arc::new(Mutex::new(std::io::stdin()))); self.output = Some(Arc::new(Mutex::new(std::io::stdout()))); } - ChardevType::Pty => { + ChardevType::Pty { .. } => { let (master, path) = set_pty_raw_mode().with_context(|| "Failed to set pty to raw mode")?; info!("Pty path is: {:?}", path); @@ -125,58 +135,43 @@ impl Chardev { self.input = Some(master_arc.clone()); self.output = Some(master_arc); } - ChardevType::UnixSocket { - path, - server, - nowait, - } => { + ChardevType::Socket { server, nowait, .. } => { if !*server || !*nowait { bail!( "Argument \'server\' and \'nowait\' are both required for chardev \'{}\'", &self.id ); } - - clear_file(path.clone())?; - let listener = SocketListener::bind_by_uds(path).with_context(|| { - format!( - "Failed to bind socket for chardev \'{}\', path: {}", - &self.id, path - ) - })?; - self.listener = Some(listener); - - // add file to temporary pool, so it could be cleaned when vm exit. - TempCleaner::add_path(path.clone()); - limit_permission(path).with_context(|| { - format!( - "Failed to change file permission for chardev \'{}\', path: {}", - &self.id, path - ) - })?; - } - ChardevType::TcpSocket { - host, - port, - server, - nowait, - } => { - if !*server || !*nowait { - bail!( - "Argument \'server\' and \'nowait\' are both required for chardev \'{}\'", - &self.id - ); + let socket_type = self.backend.socket_type()?; + if let SocketType::Tcp { host, port } = socket_type { + let listener = SocketListener::bind_by_tcp(&host, port).with_context(|| { + format!( + "Failed to bind socket for chardev \'{}\', address: {}:{}", + &self.id, host, port + ) + })?; + self.listener = Some(listener); + } else if let SocketType::Unix { path } = socket_type { + clear_file(path.clone())?; + let listener = SocketListener::bind_by_uds(&path).with_context(|| { + format!( + "Failed to bind socket for chardev \'{}\', path: {}", + &self.id, path + ) + })?; + self.listener = Some(listener); + + // add file to temporary pool, so it could be cleaned when vm exit. + TempCleaner::add_path(path.clone()); + limit_permission(&path).with_context(|| { + format!( + "Failed to change file permission for chardev \'{}\', path: {}", + &self.id, path + ) + })?; } - - let listener = SocketListener::bind_by_tcp(host, *port).with_context(|| { - format!( - "Failed to bind socket for chardev \'{}\', address: {}:{}", - &self.id, host, port - ) - })?; - self.listener = Some(listener); } - ChardevType::File(path) => { + ChardevType::File { path, .. } => { let file = Arc::new(Mutex::new( OpenOptions::new() .read(true) @@ -237,7 +232,7 @@ impl Chardev { let unpause_fn = Box::new(move || { let res = EventLoop::update_event( vec![EventNotifier::new( - NotifierOperation::Modify, + NotifierOperation::AddEvents, input_fd, None, EventSet::IN | EventSet::HANG_UP, @@ -261,6 +256,113 @@ impl Chardev { self.unpause_timer = None; } } + + fn clear_outbuf(&mut self) { + self.outbuf.clear(); + } + + pub fn outbuf_is_full(&self) -> bool { + self.outbuf.len() == self.outbuf.capacity() + } + + pub fn fill_outbuf(&mut self, buf: Vec, listener_fd: Option>) -> Result<()> { + match self.backend { + ChardevType::File { .. } | ChardevType::Pty { .. } | ChardevType::Stdio { .. } => { + if self.output.is_none() { + bail!("chardev has no output"); + } + return write_buffer_sync(self.output.as_ref().unwrap().clone(), buf); + } + ChardevType::Socket { .. } => { + if self.output.is_none() { + return Ok(()); + } + if listener_fd.is_none() { + return write_buffer_sync(self.output.as_ref().unwrap().clone(), buf); + } + } + } + + if self.outbuf_is_full() { + bail!("Failed to append buffer because output buffer queue is full"); + } + self.outbuf.push_back(buf); + self.output_listener_fd = listener_fd; + + let event_notifier = EventNotifier::new( + NotifierOperation::AddEvents, + self.stream_fd.unwrap(), + None, + EventSet::OUT, + Vec::new(), + ); + EventLoop::update_event(vec![event_notifier], None)?; + Ok(()) + } + + fn consume_outbuf(&mut self) -> Result<()> { + if self.output.is_none() { + bail!("no output interface"); + } + let output = self.output.as_ref().unwrap(); + while !self.outbuf.is_empty() { + if write_buffer_async(output.clone(), self.outbuf.front_mut().unwrap())? { + break; + } + self.outbuf.pop_front(); + } + Ok(()) + } +} + +fn write_buffer_sync(writer: Arc>, buf: Vec) -> Result<()> { + let len = buf.len(); + let mut written = 0_usize; + let mut locked_writer = writer.lock().unwrap(); + + while written < len { + match locked_writer.write(&buf[written..len]) { + Ok(n) => written += n, + Err(e) => bail!("chardev failed to write file with error {:?}", e), + } + } + locked_writer + .flush() + .with_context(|| "chardev failed to flush")?; + Ok(()) +} + +// If write is blocked, return true. Otherwise return false. +fn write_buffer_async( + writer: Arc>, + buf: &mut Vec, +) -> Result { + let len = buf.len(); + let mut locked_writer = writer.lock().unwrap(); + let mut written = 0_usize; + + while written < len { + match locked_writer.write(&buf[written..len]) { + Ok(0) => break, + Ok(n) => written += n, + Err(e) => { + let err_type = e.kind(); + if err_type != ErrorKind::WouldBlock && err_type != ErrorKind::Interrupted { + bail!("chardev failed to write data with error {:?}", e); + } + break; + } + } + } + locked_writer + .flush() + .with_context(|| "chardev failed to flush")?; + + if written == len { + return Ok(false); + } + buf.drain(0..written); + Ok(true) } fn set_pty_raw_mode() -> Result<(i32, PathBuf)> { @@ -453,10 +555,10 @@ fn get_socket_notifier(chardev: Arc>) -> Option { locked_receiver.set_paused(); return Some(vec![EventNotifier::new( - NotifierOperation::Modify, + NotifierOperation::DeleteEvents, stream_fd, None, - EventSet::HANG_UP, + EventSet::IN, vec![], )]); } @@ -486,12 +588,53 @@ fn get_socket_notifier(chardev: Arc>) -> Option { None }); + let handling_chardev = cloned_chardev.clone(); + let output_handler = Rc::new(move |event, fd| { + if event & EventSet::OUT != EventSet::OUT { + return None; + } + + let mut locked_cdev = handling_chardev.lock().unwrap(); + if let Err(e) = locked_cdev.consume_outbuf() { + error!("Failed to consume outbuf with error {:?}", e); + locked_cdev.clear_outbuf(); + return Some(vec![EventNotifier::new( + NotifierOperation::DeleteEvents, + fd, + None, + EventSet::OUT, + Vec::new(), + )]); + } + + if locked_cdev.output_listener_fd.is_some() { + let fd = locked_cdev.output_listener_fd.as_ref().unwrap(); + if let Err(e) = fd.write(1) { + error!("Failed to write eventfd with error {:?}", e); + return None; + } + locked_cdev.output_listener_fd = None; + } + + if locked_cdev.outbuf.is_empty() { + Some(vec![EventNotifier::new( + NotifierOperation::DeleteEvents, + fd, + None, + EventSet::OUT, + Vec::new(), + )]) + } else { + None + } + }); + Some(vec![EventNotifier::new( NotifierOperation::AddShared, stream_fd, Some(listener_fd), EventSet::IN | EventSet::HANG_UP, - vec![input_handler], + vec![input_handler, output_handler], )]) }); @@ -510,11 +653,10 @@ impl EventNotifierHelper for Chardev { let notifier = { let backend = chardev.lock().unwrap().backend.clone(); match backend { - ChardevType::Stdio => get_terminal_notifier(chardev), - ChardevType::Pty => get_terminal_notifier(chardev), - ChardevType::UnixSocket { .. } => get_socket_notifier(chardev), - ChardevType::TcpSocket { .. } => get_socket_notifier(chardev), - ChardevType::File(_) => None, + ChardevType::Stdio { .. } => get_terminal_notifier(chardev), + ChardevType::Pty { .. } => get_terminal_notifier(chardev), + ChardevType::Socket { .. } => get_socket_notifier(chardev), + ChardevType::File { .. } => None, } }; notifier.map_or(Vec::new(), |value| vec![value]) diff --git a/cpu/Cargo.toml b/cpu/Cargo.toml index a665edecbe023195464e9b1f4a61cd7cf099208b..57e4d7e26099f38205c0feb8cca8b735cc0f6092 100644 --- a/cpu/Cargo.toml +++ b/cpu/Cargo.toml @@ -9,11 +9,11 @@ description = "CPU emulation" [dependencies] thiserror = "1.0" anyhow = "1.0" -kvm-bindings = { version = "0.6.0", features = ["fam-wrappers"] } +kvm-bindings = { version = "0.7.0", features = ["fam-wrappers"] } nix = { version = "0.26.2", default-features = false, features = ["fs", "feature"] } log = "0.4" libc = "0.2" -vmm-sys-util = "0.11.1" +vmm-sys-util = "0.12.1" machine_manager = { path = "../machine_manager" } migration = { path = "../migration" } migration_derive = { path = "../migration/migration_derive" } diff --git a/cpu/src/aarch64/mod.rs b/cpu/src/aarch64/mod.rs index 3aebcbc0ae2f7a757b8414d18ae77a8dd4a4f1a0..9d89b48979feca2b5b75316efb37d9748d83b8b1 100644 --- a/cpu/src/aarch64/mod.rs +++ b/cpu/src/aarch64/mod.rs @@ -108,6 +108,8 @@ pub struct ArmCPUState { pub features: ArmCPUFeatures, /// Virtual timer count. pub vtimer_cnt: u64, + /// Virtual timer count valid. + pub vtimer_cnt_valid: bool, } impl ArmCPUState { diff --git a/cpu/src/lib.rs b/cpu/src/lib.rs index 7a1162951d2e4e5141fc22555b0c79d7443aa5b3..21f6bb0144ad836a929824f3abedaf10093fc802 100644 --- a/cpu/src/lib.rs +++ b/cpu/src/lib.rs @@ -65,14 +65,16 @@ use std::cell::RefCell; use std::sync::atomic::{fence, AtomicBool, Ordering}; use std::sync::{Arc, Barrier, Condvar, Mutex, Weak}; use std::thread; +use std::time::Duration; +use std::time::Instant; -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, bail, Context, Result}; use log::{error, info, warn}; use nix::unistd::gettid; use machine_manager::config::ShutdownAction::{ShutdownActionPause, ShutdownActionPoweroff}; use machine_manager::event; -use machine_manager::machine::{HypervisorType, MachineInterface}; +use machine_manager::machine::{HypervisorType, MachineInterface, VmState}; use machine_manager::qmp::{qmp_channel::QmpChannel, qmp_schema}; // SIGRTMIN = 34 (GNU, in MUSL is 35) and SIGRTMAX = 64 in linux, VCPU signal @@ -118,7 +120,7 @@ pub trait CPUInterface { /// Realize `CPU` structure, set registers value for `CPU`. fn realize( &self, - boot: &CPUBootConfig, + boot: &Option, topology: &CPUTopology, #[cfg(target_arch = "aarch64")] features: &CPUFeatures, ) -> Result<()>; @@ -283,6 +285,7 @@ impl CPU { /// Set thread id for `CPU`. pub fn set_tid(&self, tid: Option) { if tid.is_none() { + // Cast is safe as tid is not negative. *self.tid.lock().unwrap() = Some(gettid().as_raw() as u64); } else { *self.tid.lock().unwrap() = tid; @@ -310,7 +313,7 @@ impl CPU { impl CPUInterface for CPU { fn realize( &self, - boot: &CPUBootConfig, + boot: &Option, topology: &CPUTopology, #[cfg(target_arch = "aarch64")] config: &CPUFeatures, ) -> Result<()> { @@ -323,14 +326,16 @@ impl CPUInterface for CPU { )))); } - self.hypervisor_cpu - .set_boot_config( - self.arch_cpu.clone(), - boot, - #[cfg(target_arch = "aarch64")] - config, - ) - .with_context(|| "Failed to realize arch cpu")?; + if let Some(boot) = boot { + self.hypervisor_cpu + .set_boot_config( + self.arch_cpu.clone(), + boot, + #[cfg(target_arch = "aarch64")] + config, + ) + .with_context(|| "Failed to realize arch cpu")?; + } self.arch_cpu .lock() @@ -412,7 +417,17 @@ impl CPUInterface for CPU { vm.lock().unwrap().destroy(); } ShutdownActionPause => { - vm.lock().unwrap().pause(); + let now = Instant::now(); + while !vm.lock().unwrap().pause() { + thread::sleep(Duration::from_millis(5)); + if now.elapsed() > Duration::from_secs(2) { + // Not use resume() to avoid unnecessary qmp event. + vm.lock() + .unwrap() + .notify_lifecycle(VmState::Paused, VmState::Running); + bail!("Failed to pause VM"); + } + } } } } else { @@ -450,7 +465,7 @@ pub struct CPUThreadWorker { } impl CPUThreadWorker { - thread_local!(static LOCAL_THREAD_VCPU: RefCell> = RefCell::new(None)); + thread_local!(static LOCAL_THREAD_VCPU: RefCell> = const { RefCell::new(None) }); /// Allocates a new `CPUThreadWorker`. fn new(thread_cpu: Arc) -> Self { @@ -589,16 +604,18 @@ impl CpuTopology { /// # Arguments /// /// * `vcpu_id` - ID of vcpu. - fn get_topo_item(&self, vcpu_id: usize) -> (u8, u8, u8, u8, u8) { - let socketid: u8 = vcpu_id as u8 / (self.dies * self.clusters * self.cores * self.threads); - let dieid: u8 = (vcpu_id as u8 / (self.clusters * self.cores * self.threads)) % self.dies; - let clusterid: u8 = (vcpu_id as u8 / (self.cores * self.threads)) % self.clusters; - let coreid: u8 = (vcpu_id as u8 / self.threads) % self.cores; - let threadid: u8 = vcpu_id as u8 % self.threads; + fn get_topo_item(&self, vcpu_id: u8) -> (u8, u8, u8, u8, u8) { + // nr_cpus is no more than u8::MAX, multiply will not overflow. + // nr_xxx is no less than 1, div and mod operations will not panic. + let socketid: u8 = vcpu_id / (self.dies * self.clusters * self.cores * self.threads); + let dieid: u8 = (vcpu_id / (self.clusters * self.cores * self.threads)) % self.dies; + let clusterid: u8 = (vcpu_id / (self.cores * self.threads)) % self.clusters; + let coreid: u8 = (vcpu_id / self.threads) % self.cores; + let threadid: u8 = vcpu_id % self.threads; (socketid, dieid, clusterid, coreid, threadid) } - pub fn get_topo_instance_for_qmp(&self, cpu_index: usize) -> qmp_schema::CpuInstanceProperties { + pub fn get_topo_instance_for_qmp(&self, cpu_index: u8) -> qmp_schema::CpuInstanceProperties { let (socketid, _dieid, _clusterid, coreid, threadid) = self.get_topo_item(cpu_index); qmp_schema::CpuInstanceProperties { node_id: None, diff --git a/cpu/src/x86_64/cpuid.rs b/cpu/src/x86_64/cpuid.rs index 58ecfbd8d9c70c634c76bcdc5dc9c714372baf6f..f7b8d7525bb9d0027ac6d3e4b9adf5dd5c0b5359 100644 --- a/cpu/src/x86_64/cpuid.rs +++ b/cpu/src/x86_64/cpuid.rs @@ -12,21 +12,17 @@ use core::arch::x86_64::__cpuid_count; -pub fn host_cpuid( +pub unsafe fn host_cpuid( leaf: u32, subleaf: u32, - eax: *mut u32, - ebx: *mut u32, - ecx: *mut u32, - edx: *mut u32, + eax: &mut u32, + ebx: &mut u32, + ecx: &mut u32, + edx: &mut u32, ) { - // SAFETY: cpuid is created in get_supported_cpuid(). - unsafe { - let cpuid = __cpuid_count(leaf, subleaf); - - *eax = cpuid.eax; - *ebx = cpuid.ebx; - *ecx = cpuid.ecx; - *edx = cpuid.edx; - } + let cpuid = __cpuid_count(leaf, subleaf); + *eax = cpuid.eax; + *ebx = cpuid.ebx; + *ecx = cpuid.ecx; + *edx = cpuid.edx; } diff --git a/cpu/src/x86_64/mod.rs b/cpu/src/x86_64/mod.rs index 0a8ad16905d85d16ef2bc278e0c78337e727a60f..06b2cc4eba75b046b074a93ed4fff983709f0f11 100644 --- a/cpu/src/x86_64/mod.rs +++ b/cpu/src/x86_64/mod.rs @@ -75,7 +75,7 @@ pub enum X86RegsIndex { /// X86 CPU booting configure information #[allow(clippy::upper_case_acronyms)] -#[derive(Default, Clone, Debug)] +#[derive(Default, Clone, Debug, Copy)] pub struct X86CPUBootConfig { pub prot64_mode: bool, /// Register %rip value @@ -120,14 +120,14 @@ impl X86CPUTopology { /// The state of vCPU's register. #[allow(clippy::upper_case_acronyms)] #[repr(C)] -#[derive(Copy, Clone, Desc, ByteCode)] +#[derive(Desc, ByteCode)] #[desc_version(compat_version = "0.1.0")] pub struct X86CPUState { - max_vcpus: u32, - nr_threads: u32, - nr_cores: u32, - nr_dies: u32, - nr_sockets: u32, + max_vcpus: u8, + nr_threads: u8, + nr_cores: u8, + nr_dies: u8, + nr_sockets: u8, pub apic_id: u32, pub regs: Regs, pub sregs: Sregs, @@ -142,6 +142,34 @@ pub struct X86CPUState { pub debugregs: DebugRegs, } +impl Clone for X86CPUState { + fn clone(&self) -> Self { + let mut xsave: Xsave = Default::default(); + // we just clone xsave.region, because xsave.extra does not save + // valid values and it is not allowed to be cloned. + xsave.region = self.xsave.region; + Self { + max_vcpus: self.max_vcpus, + nr_threads: self.nr_threads, + nr_cores: self.nr_cores, + nr_dies: self.nr_dies, + nr_sockets: self.nr_sockets, + apic_id: self.apic_id, + regs: self.regs, + sregs: self.sregs, + fpu: self.fpu, + mp_state: self.mp_state, + lapic: self.lapic, + msr_len: self.msr_len, + msr_list: self.msr_list, + cpu_events: self.cpu_events, + xsave, + xcrs: self.xcrs, + debugregs: self.debugregs, + } + } +} + impl X86CPUState { /// Allocates a new `X86CPUState`. /// @@ -149,7 +177,7 @@ impl X86CPUState { /// /// * `vcpu_id` - ID of this `CPU`. /// * `max_vcpus` - Number of vcpus. - pub fn new(vcpu_id: u32, max_vcpus: u32) -> Self { + pub fn new(vcpu_id: u32, max_vcpus: u8) -> Self { let mp_state = MpState { mp_state: if vcpu_id == 0 { MP_STATE_RUNNABLE @@ -181,7 +209,8 @@ impl X86CPUState { self.msr_len = locked_cpu_state.msr_len; self.msr_list = locked_cpu_state.msr_list; self.cpu_events = locked_cpu_state.cpu_events; - self.xsave = locked_cpu_state.xsave; + self.xsave = Default::default(); + self.xsave.region = locked_cpu_state.xsave.region; self.xcrs = locked_cpu_state.xcrs; self.debugregs = locked_cpu_state.debugregs; } @@ -192,9 +221,9 @@ impl X86CPUState { /// /// * `topology` - X86 CPU Topology pub fn set_cpu_topology(&mut self, topology: &X86CPUTopology) -> Result<()> { - self.nr_threads = topology.threads as u32; - self.nr_cores = topology.cores as u32; - self.nr_dies = topology.dies as u32; + self.nr_threads = topology.threads; + self.nr_cores = topology.cores; + self.nr_dies = topology.dies; Ok(()) } @@ -210,19 +239,19 @@ impl X86CPUState { self.lapic = lapic; // SAFETY: The member regs in struct LapicState is a u8 array with 1024 entries, - // so it's saft to cast u8 pointer to u32 at position APIC_LVT0 and APIC_LVT1. + // so it's safe to cast u8 pointer to u32 at position APIC_LVT0 and APIC_LVT1. // Safe because all value in this unsafe block is certain. unsafe { let apic_lvt_lint0 = &mut self.lapic.regs[APIC_LVT0..] as *mut [i8] as *mut u32; - *apic_lvt_lint0 &= !0x700; - *apic_lvt_lint0 |= APIC_MODE_EXTINT << 8; + let modified = (apic_lvt_lint0.read_unaligned() & !0x700) | (APIC_MODE_EXTINT << 8); + apic_lvt_lint0.write_unaligned(modified); let apic_lvt_lint1 = &mut self.lapic.regs[APIC_LVT1..] as *mut [i8] as *mut u32; - *apic_lvt_lint1 &= !0x700; - *apic_lvt_lint1 |= APIC_MODE_NMI << 8; + let modified = (apic_lvt_lint1.read_unaligned() & !0x700) | (APIC_MODE_NMI << 8); + apic_lvt_lint1.write_unaligned(modified); let apic_id = &mut self.lapic.regs[APIC_ID..] as *mut [i8] as *mut u32; - *apic_id = self.apic_id << 24; + apic_id.write_unaligned(self.apic_id << 24); } Ok(()) @@ -242,17 +271,17 @@ impl X86CPUState { pub fn setup_sregs(&mut self, sregs: Sregs, boot_config: &X86CPUBootConfig) -> Result<()> { self.sregs = sregs; - self.sregs.cs.base = (boot_config.boot_selector as u64) << 4; + self.sregs.cs.base = u64::from(boot_config.boot_selector) << 4; self.sregs.cs.selector = boot_config.boot_selector; - self.sregs.ds.base = (boot_config.boot_selector as u64) << 4; + self.sregs.ds.base = u64::from(boot_config.boot_selector) << 4; self.sregs.ds.selector = boot_config.boot_selector; - self.sregs.es.base = (boot_config.boot_selector as u64) << 4; + self.sregs.es.base = u64::from(boot_config.boot_selector) << 4; self.sregs.es.selector = boot_config.boot_selector; - self.sregs.fs.base = (boot_config.boot_selector as u64) << 4; + self.sregs.fs.base = u64::from(boot_config.boot_selector) << 4; self.sregs.fs.selector = boot_config.boot_selector; - self.sregs.gs.base = (boot_config.boot_selector as u64) << 4; + self.sregs.gs.base = u64::from(boot_config.boot_selector) << 4; self.sregs.gs.selector = boot_config.boot_selector; - self.sregs.ss.base = (boot_config.boot_selector as u64) << 4; + self.sregs.ss.base = u64::from(boot_config.boot_selector) << 4; self.sregs.ss.selector = boot_config.boot_selector; if boot_config.prot64_mode { @@ -330,6 +359,7 @@ impl X86CPUState { data, ..Default::default() }; + // usize is enough for storing msr len. self.msr_len += 1; } } @@ -363,6 +393,7 @@ impl X86CPUState { } pub fn setup_cpuid(&self, cpuid: &mut CpuId) -> Result<()> { + // nr_xx is no less than 1. let core_offset = 32u32 - (self.nr_threads - 1).leading_zeros(); let die_offset = (32u32 - (self.nr_cores - 1).leading_zeros()) + core_offset; let pkg_offset = (32u32 - (self.nr_dies - 1).leading_zeros()) + die_offset; @@ -379,29 +410,36 @@ impl X86CPUState { } } 2 => { - host_cpuid( - 2, - 0, - &mut entry.eax, - &mut entry.ebx, - &mut entry.ecx, - &mut entry.edx, - ); + // SAFETY: entry is from KVM_GET_SUPPORTED_CPUID. + unsafe { + host_cpuid( + 2, + 0, + &mut entry.eax, + &mut entry.ebx, + &mut entry.ecx, + &mut entry.edx, + ); + } } 4 => { // cache info: needed for Pentium Pro compatibility // Passthrough host cache info directly to guest - host_cpuid( - 4, - entry.index, - &mut entry.eax, - &mut entry.ebx, - &mut entry.ecx, - &mut entry.edx, - ); + // SAFETY: entry is from KVM_GET_SUPPORTED_CPUID. + unsafe { + host_cpuid( + 4, + entry.index, + &mut entry.eax, + &mut entry.ebx, + &mut entry.ecx, + &mut entry.edx, + ); + } entry.eax &= !0xfc00_0000; if entry.eax & 0x0001_ffff != 0 && self.max_vcpus > 1 { - entry.eax |= (self.max_vcpus - 1) << 26; + // max_vcpus is no less than 1. + entry.eax |= (u32::from(self.max_vcpus) - 1) << 26; } } 6 => { @@ -423,12 +461,13 @@ impl X86CPUState { match entry.index { 0 => { entry.eax = core_offset; - entry.ebx = self.nr_threads; + entry.ebx = u32::from(self.nr_threads); entry.ecx |= ECX_THREAD; } 1 => { entry.eax = pkg_offset; - entry.ebx = self.nr_threads * self.nr_cores; + // nr_cpus is no more than u8::MAX, multiply will not overflow. + entry.ebx = u32::from(self.nr_threads * self.nr_cores); entry.ecx |= ECX_CORE; } _ => { @@ -454,17 +493,19 @@ impl X86CPUState { match entry.index { 0 => { entry.eax = core_offset; - entry.ebx = self.nr_threads; + entry.ebx = u32::from(self.nr_threads); entry.ecx |= ECX_THREAD; } 1 => { entry.eax = die_offset; - entry.ebx = self.nr_cores * self.nr_threads; + // nr_cpus is no more than u8::MAX, multiply will not overflow. + entry.ebx = u32::from(self.nr_cores * self.nr_threads); entry.ecx |= ECX_CORE; } 2 => { entry.eax = pkg_offset; - entry.ebx = self.nr_dies * self.nr_cores * self.nr_threads; + // nr_cpus is no more than u8::MAX, multiply will not overflow. + entry.ebx = u32::from(self.nr_dies * self.nr_cores * self.nr_threads); entry.ecx |= ECX_DIE; } _ => { @@ -476,14 +517,17 @@ impl X86CPUState { } 0x8000_0002..=0x8000_0004 => { // Passthrough host cpu model name directly to guest - host_cpuid( - entry.function, - entry.index, - &mut entry.eax, - &mut entry.ebx, - &mut entry.ecx, - &mut entry.edx, - ); + // SAFETY: entry is from KVM_GET_SUPPORTED_CPUID. + unsafe { + host_cpuid( + entry.function, + entry.index, + &mut entry.eax, + &mut entry.ebx, + &mut entry.ecx, + &mut entry.edx, + ); + } } _ => (), } @@ -512,11 +556,11 @@ impl StateTransfer for CPU { } fn set_state(&self, state: &[u8]) -> Result<()> { - let cpu_state = *X86CPUState::from_bytes(state) + let cpu_state = X86CPUState::from_bytes(state) .with_context(|| MigrationError::FromBytesError("CPU"))?; let mut cpu_state_locked = self.arch_cpu.lock().unwrap(); - *cpu_state_locked = cpu_state; + *cpu_state_locked = cpu_state.clone(); Ok(()) } diff --git a/devices/Cargo.toml b/devices/Cargo.toml index e19dea9b06c6f7d421c4edd03b80935435fbda24..3dca2b4185028a50057e79935c404deccc6da0ce 100644 --- a/devices/Cargo.toml +++ b/devices/Cargo.toml @@ -12,7 +12,9 @@ anyhow = "1.0" libc = "0.2" log = "0.4" serde = { version = "1.0", features = ["derive"] } -vmm-sys-util = "0.11.1" +strum = "0.24.1" +strum_macros = "0.24.3" +vmm-sys-util = "0.12.1" byteorder = "1.4.3" drm-fourcc = ">=2.2.0" once_cell = "1.18.0" @@ -33,9 +35,10 @@ pulse = { version = "2.27", package = "libpulse-binding", optional = true } psimple = { version = "2.27", package = "libpulse-simple-binding", optional = true } alsa = { version = "0.7.0", optional = true } rusb = { version = "0.9", optional = true } -libusb1-sys = { version = "0.6.4", optional = true } +libusb1-sys = { version = "0.6.5", optional = true } trace = { path = "../trace" } clap = { version = "=4.1.4", default-features = false, features = ["std", "derive"] } +hisysevent = { path = "../hisysevent" } [features] default = [] @@ -45,8 +48,12 @@ scream_pulseaudio = ["scream", "dep:pulse", "dep:psimple", "machine_manager/scre scream_ohaudio = ["scream", "machine_manager/scream_ohaudio", "util/scream_ohaudio"] pvpanic = ["machine_manager/pvpanic"] demo_device = ["machine_manager/demo_device", "ui/console", "util/pixman"] -usb_host = ["dep:libusb1-sys", "dep:rusb", "machine_manager/usb_host"] +usb_host = ["dep:libusb1-sys", "dep:rusb", "machine_manager/usb_host", "util/usb_host"] usb_camera = ["machine_manager/usb_camera"] usb_camera_v4l2 = ["usb_camera", "dep:v4l2-sys-mit", "machine_manager/usb_camera_v4l2", "util/usb_camera_v4l2"] usb_camera_oh = ["usb_camera", "machine_manager/usb_camera_oh", "util/usb_camera_oh"] ramfb = ["ui/console", "util/pixman"] +usb_uas = [] +trace_to_logger = [] +trace_to_ftrace = [] +trace_to_hitrace = [] diff --git a/devices/src/acpi/cpu_controller.rs b/devices/src/acpi/cpu_controller.rs index 1259e8d2d4e1c94cee2693a08036fcbd01c9d2a0..3f85e288da849b4c207080f2cc7813fb3665ff85 100644 --- a/devices/src/acpi/cpu_controller.rs +++ b/devices/src/acpi/cpu_controller.rs @@ -19,8 +19,8 @@ use anyhow::{bail, Context, Result}; use log::{error, info}; use vmm_sys_util::eventfd::EventFd; -use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysRes}; -use crate::{Device, DeviceBase}; +use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps}; +use crate::{convert_bus_mut, Device, DeviceBase, MUT_SYS_BUS}; use acpi::{ AcpiError, AcpiLocalApic, AmlAcquire, AmlAddressSpaceType, AmlArg, AmlBuffer, AmlBuilder, AmlCallWithArgs1, AmlCallWithArgs2, AmlCallWithArgs4, AmlDevice, AmlEisaId, AmlEqual, AmlField, @@ -32,6 +32,7 @@ use acpi::{ use address_space::GuestAddress; use cpu::{CPUBootConfig, CPUInterface, CPUTopology, CpuLifecycleState, CPU}; use migration::MigrationManager; +use util::gen_base_func; const CPU_ENABLE_FLAG: u8 = 1; const CPU_INSERTING_FLAG: u8 = 2; @@ -87,24 +88,28 @@ pub struct CpuController { } impl CpuController { - pub fn realize( - mut self, - sysbus: &mut SysBus, + pub fn new( max_cpus: u8, + sysbus: &Arc>, region_base: u64, region_size: u64, cpu_config: CpuConfig, hotplug_cpu_req: Arc, - ) -> Result>> { - self.max_cpus = max_cpus; - self.cpu_config = Some(cpu_config); - self.hotplug_cpu_req = Some(hotplug_cpu_req); - self.set_sys_resource(sysbus, region_base, region_size) + boot_vcpus: Vec>, + ) -> Result { + let mut cpu_controller = CpuController { + max_cpus, + cpu_config: Some(cpu_config), + hotplug_cpu_req: Some(hotplug_cpu_req), + ..Default::default() + }; + cpu_controller + .set_sys_resource(sysbus, region_base, region_size, "CPUController") .with_context(|| AcpiError::Alignment(region_size.try_into().unwrap()))?; - let dev = Arc::new(Mutex::new(self)); - let ret_dev = dev.clone(); - sysbus.attach_device(&dev, region_base, region_size, "CPUController")?; - Ok(ret_dev) + cpu_controller.set_boot_vcpu(boot_vcpus)?; + cpu_controller.set_parent_bus(sysbus.clone()); + + Ok(cpu_controller) } fn eject_cpu(&mut self, vcpu_id: u8) -> Result<()> { @@ -157,8 +162,8 @@ impl CpuController { None } - pub fn get_boot_config(&self) -> &CPUBootConfig { - &self.cpu_config.as_ref().unwrap().boot_config + pub fn get_boot_config(&self) -> CPUBootConfig { + self.cpu_config.as_ref().unwrap().boot_config } pub fn get_hotplug_cpu_info(&self) -> (String, u8) { @@ -242,23 +247,19 @@ impl CpuController { } impl Device for CpuController { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base + fn realize(self) -> Result>> { + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + MUT_SYS_BUS!(parent_bus, locked_bus, sysbus); + let dev = Arc::new(Mutex::new(self)); + sysbus.attach_device(&dev)?; + Ok(dev) } } impl SysBusDevOps for CpuController { - fn sysbusdev_base(&self) -> &SysBusDevBase { - &self.base - } - - fn sysbusdev_base_mut(&mut self) -> &mut SysBusDevBase { - &mut self.base - } + gen_base_func!(sysbusdev_base, sysbusdev_base_mut, SysBusDevBase, base); fn read(&mut self, data: &mut [u8], _base: GuestAddress, offset: u64) -> bool { data[0] = 0; @@ -329,15 +330,11 @@ impl SysBusDevOps for CpuController { } true } - - fn get_sys_resource_mut(&mut self) -> Option<&mut SysRes> { - Some(&mut self.base.res) - } } impl AmlBuilder for CpuController { fn aml_bytes(&self) -> Vec { - let res = self.base.res; + let res = self.base.res.clone(); let mut cpu_hotplug_controller = AmlDevice::new("PRES"); cpu_hotplug_controller.append_child(AmlNameDecl::new("_HID", AmlEisaId::new("PNP0A06"))); cpu_hotplug_controller.append_child(AmlNameDecl::new( diff --git a/devices/src/acpi/ged.rs b/devices/src/acpi/ged.rs index f50e1b11805cba2065d8ed39a1dfd0bf2471ca1b..76a15032c40bb50016719b46daa43fb281d45266 100644 --- a/devices/src/acpi/ged.rs +++ b/devices/src/acpi/ged.rs @@ -19,8 +19,8 @@ use anyhow::{Context, Result}; use vmm_sys_util::epoll::EventSet; use vmm_sys_util::eventfd::EventFd; -use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysRes}; -use crate::{Device, DeviceBase}; +use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps}; +use crate::{convert_bus_mut, Device, DeviceBase, MUT_SYS_BUS}; use acpi::{ AcpiError, AmlActiveLevel, AmlAddressSpaceType, AmlAnd, AmlBuilder, AmlDevice, AmlEdgeLevel, AmlEqual, AmlExtendedInterrupt, AmlField, AmlFieldAccessType, AmlFieldLockRule, AmlFieldUnit, @@ -35,8 +35,11 @@ use address_space::GuestAddress; use machine_manager::event; use machine_manager::event_loop::EventLoop; use machine_manager::qmp::qmp_channel::QmpChannel; -use util::loop_context::{read_fd, EventNotifier, NotifierOperation}; -use util::{loop_context::NotifierCallback, num_ops::write_data_u32}; +use util::gen_base_func; +use util::loop_context::{ + create_new_eventfd, read_fd, EventNotifier, NotifierCallback, NotifierOperation, +}; +use util::num_ops::write_data_u32; #[derive(Clone, Copy)] pub enum AcpiEvent { @@ -51,6 +54,7 @@ pub enum AcpiEvent { const AML_GED_EVT_REG: &str = "EREG"; const AML_GED_EVT_SEL: &str = "ESEL"; +#[derive(Clone)] pub struct GedEvent { power_button: Arc, #[cfg(target_arch = "x86_64")] @@ -75,42 +79,29 @@ pub struct Ged { base: SysBusDevBase, notification_type: Arc, battery_present: bool, -} - -impl Default for Ged { - fn default() -> Self { - Self { - base: SysBusDevBase::default(), - notification_type: Arc::new(AtomicU32::new(AcpiEvent::Nothing as u32)), - battery_present: false, - } - } + ged_event: GedEvent, } impl Ged { - pub fn realize( - mut self, - sysbus: &mut SysBus, - ged_event: GedEvent, + pub fn new( battery_present: bool, + sysbus: &Arc>, region_base: u64, region_size: u64, - ) -> Result>> { - self.base.interrupt_evt = Some(Arc::new(EventFd::new(libc::EFD_NONBLOCK)?)); - self.set_sys_resource(sysbus, region_base, region_size) + ged_event: GedEvent, + ) -> Result { + let mut ged = Self { + base: SysBusDevBase::default(), + notification_type: Arc::new(AtomicU32::new(AcpiEvent::Nothing as u32)), + battery_present, + ged_event, + }; + ged.base.interrupt_evt = Some(Arc::new(create_new_eventfd()?)); + ged.set_sys_resource(sysbus, region_base, region_size, "Ged") .with_context(|| AcpiError::Alignment(region_size as u32))?; - self.battery_present = battery_present; - - let dev = Arc::new(Mutex::new(self)); - sysbus.attach_device(&dev, region_base, region_size, "Ged")?; + ged.set_parent_bus(sysbus.clone()); - let ged = dev.lock().unwrap(); - ged.register_acpi_powerdown_event(ged_event.power_button) - .with_context(|| "Failed to register ACPI powerdown event.")?; - #[cfg(target_arch = "x86_64")] - ged.register_acpi_cpu_resize_event(ged_event.cpu_resize) - .with_context(|| "Failed to register ACPI cpu resize event.")?; - Ok(dev.clone()) + Ok(ged) } fn register_acpi_powerdown_event(&self, power_button: Arc) -> Result<()> { @@ -120,8 +111,9 @@ impl Ged { read_fd(power_down_fd); ged_clone .notification_type - .store(AcpiEvent::PowerDown as u32, Ordering::SeqCst); + .fetch_or(AcpiEvent::PowerDown as u32, Ordering::SeqCst); ged_clone.inject_interrupt(); + trace::ged_inject_acpi_event(AcpiEvent::PowerDown as u32); if QmpChannel::is_connected() { event!(Powerdown); } @@ -149,8 +141,9 @@ impl Ged { read_fd(cpu_resize_fd); clone_ged .notification_type - .store(AcpiEvent::CpuResize as u32, Ordering::SeqCst); + .fetch_or(AcpiEvent::CpuResize as u32, Ordering::SeqCst); clone_ged.inject_interrupt(); + trace::ged_inject_acpi_event(AcpiEvent::CpuResize as u32); if QmpChannel::is_connected() { event!(CpuResize); } @@ -174,27 +167,32 @@ impl Ged { self.notification_type .fetch_or(evt as u32, Ordering::SeqCst); self.inject_interrupt(); + trace::ged_inject_acpi_event(evt as u32); } } impl Device for Ged { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base + fn realize(self) -> Result>> { + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + MUT_SYS_BUS!(parent_bus, locked_bus, sysbus); + let ged_event = self.ged_event.clone(); + let dev = Arc::new(Mutex::new(self)); + sysbus.attach_device(&dev)?; + + let ged = dev.lock().unwrap(); + ged.register_acpi_powerdown_event(ged_event.power_button) + .with_context(|| "Failed to register ACPI powerdown event.")?; + #[cfg(target_arch = "x86_64")] + ged.register_acpi_cpu_resize_event(ged_event.cpu_resize) + .with_context(|| "Failed to register ACPI cpu resize event.")?; + Ok(dev.clone()) } } impl SysBusDevOps for Ged { - fn sysbusdev_base(&self) -> &SysBusDevBase { - &self.base - } - - fn sysbusdev_base_mut(&mut self) -> &mut SysBusDevBase { - &mut self.base - } + gen_base_func!(sysbusdev_base, sysbusdev_base_mut, SysBusDevBase, base); fn read(&mut self, data: &mut [u8], _base: GuestAddress, offset: u64) -> bool { if offset != 0 { @@ -203,16 +201,13 @@ impl SysBusDevOps for Ged { let value = self .notification_type .swap(AcpiEvent::Nothing as u32, Ordering::SeqCst); + trace::ged_read(value); write_data_u32(data, value) } fn write(&mut self, _data: &[u8], _base: GuestAddress, _offset: u64) -> bool { true } - - fn get_sys_resource_mut(&mut self) -> Option<&mut SysRes> { - Some(&mut self.base.res) - } } impl AmlBuilder for Ged { diff --git a/devices/src/acpi/power.rs b/devices/src/acpi/power.rs index a51071d8b5bccafb70fd23162fb175fd58eb4b90..ae3d0bf61c004e56ea90f7d80f3943ca052c5d87 100644 --- a/devices/src/acpi/power.rs +++ b/devices/src/acpi/power.rs @@ -18,8 +18,8 @@ use anyhow::{Context, Result}; use log::info; use crate::acpi::ged::{AcpiEvent, Ged}; -use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysRes}; -use crate::{Device, DeviceBase}; +use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps}; +use crate::{convert_bus_mut, Device, DeviceBase, MUT_SYS_BUS}; use acpi::{ AcpiError, AmlAddressSpaceType, AmlBuilder, AmlDevice, AmlField, AmlFieldAccessType, AmlFieldLockRule, AmlFieldUnit, AmlFieldUpdateRule, AmlIndex, AmlInteger, AmlMethod, AmlName, @@ -30,6 +30,7 @@ use machine_manager::event_loop::EventLoop; use migration::{DeviceStateDesc, FieldDesc, MigrationHook, MigrationManager, StateTransfer}; use migration_derive::{ByteCode, Desc}; use util::byte_code::ByteCode; +use util::gen_base_func; use util::num_ops::write_data_u32; const AML_ACAD_REG: &str = "ADPM"; @@ -80,8 +81,13 @@ pub struct PowerDev { } impl PowerDev { - pub fn new(ged_dev: Arc>) -> Self { - Self { + pub fn new( + ged_dev: Arc>, + sysbus: &Arc>, + region_base: u64, + region_size: u64, + ) -> Result { + let mut pdev = Self { base: SysBusDevBase::default(), regs: vec![0; POWERDEV_REGS_SIZE], state: PowerDevState { @@ -90,7 +96,11 @@ impl PowerDev { last_bat_lvl: 0xffffffff, }, ged: ged_dev, - } + }; + pdev.set_sys_resource(sysbus, region_base, region_size, "PowerDev") + .with_context(|| AcpiError::Alignment(region_size as u32))?; + pdev.set_parent_bus(sysbus.clone()); + Ok(pdev) } fn read_sysfs_power_props( @@ -154,6 +164,8 @@ impl PowerDev { // unit: mW self.regs[REG_IDX_BAT_PRATE] = (self.regs[REG_IDX_BAT_PRATE] * self.regs[REG_IDX_BAT_PVOLT]) / 1000; + + trace::power_status_read(&self.regs); Ok(()) } @@ -174,37 +186,6 @@ impl PowerDev { } } -impl PowerDev { - pub fn realize( - mut self, - sysbus: &mut SysBus, - region_base: u64, - region_size: u64, - ) -> Result<()> { - self.set_sys_resource(sysbus, region_base, region_size) - .with_context(|| AcpiError::Alignment(region_size as u32))?; - - let dev = Arc::new(Mutex::new(self)); - sysbus.attach_device(&dev, region_base, region_size, "PowerDev")?; - - let pdev_available: bool; - { - let mut pdev = dev.lock().unwrap(); - pdev_available = pdev.power_battery_init_info().is_ok(); - if pdev_available { - pdev.send_power_event(AcpiEvent::BatteryInf); - } - } - if pdev_available { - power_status_update(&dev.clone()); - } else { - let mut pdev = dev.lock().unwrap(); - pdev.power_load_static_status(); - } - Ok(()) - } -} - impl StateTransfer for PowerDev { fn get_state_vec(&self) -> Result> { Ok(self.state.as_bytes().to_vec()) @@ -229,23 +210,35 @@ impl MigrationHook for PowerDev { } impl Device for PowerDev { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); + + fn realize(self) -> Result>> { + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + MUT_SYS_BUS!(parent_bus, locked_bus, sysbus); + let dev = Arc::new(Mutex::new(self)); + sysbus.attach_device(&dev)?; + + let pdev_available: bool; + { + let mut pdev = dev.lock().unwrap(); + pdev_available = pdev.power_battery_init_info().is_ok(); + if pdev_available { + pdev.send_power_event(AcpiEvent::BatteryInf); + } + } + if pdev_available { + power_status_update(&dev); + } else { + let mut pdev = dev.lock().unwrap(); + pdev.power_load_static_status(); + } - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base + Ok(dev) } } impl SysBusDevOps for PowerDev { - fn sysbusdev_base(&self) -> &SysBusDevBase { - &self.base - } - - fn sysbusdev_base_mut(&mut self) -> &mut SysBusDevBase { - &mut self.base - } + gen_base_func!(sysbusdev_base, sysbusdev_base_mut, SysBusDevBase, base); fn read(&mut self, data: &mut [u8], _base: GuestAddress, offset: u64) -> bool { let reg_idx: u64 = offset / 4; @@ -253,16 +246,13 @@ impl SysBusDevOps for PowerDev { return false; } let value = self.regs[reg_idx as usize]; + trace::power_read(reg_idx, value); write_data_u32(data, value) } fn write(&mut self, _data: &[u8], _base: GuestAddress, _offset: u64) -> bool { true } - - fn get_sys_resource_mut(&mut self) -> Option<&mut SysRes> { - Some(&mut self.base.res) - } } impl AmlBuilder for PowerDev { @@ -366,7 +356,7 @@ impl AmlBuilder for PowerDev { acpi_bat_dev.append_child(method); let mut bst_pkg = AmlPackage::new(4); - bst_pkg.append_child(AmlInteger(ACPI_BATTERY_STATE_CHARGING as u64)); + bst_pkg.append_child(AmlInteger(u64::from(ACPI_BATTERY_STATE_CHARGING))); bst_pkg.append_child(AmlInteger(0xFFFFFFFF)); bst_pkg.append_child(AmlInteger(0xFFFFFFFF)); bst_pkg.append_child(AmlInteger(0xFFFFFFFF)); @@ -395,7 +385,7 @@ impl AmlBuilder for PowerDev { acpi_acad_dev .aml_bytes() .into_iter() - .chain(acpi_bat_dev.aml_bytes().into_iter()) + .chain(acpi_bat_dev.aml_bytes()) .collect() } } diff --git a/devices/src/camera_backend/demo.rs b/devices/src/camera_backend/demo.rs index 0640c781224677a0dac3d4b0ded3cc6b84c1be2a..725d3de84251e164b8a1c7116062021c08c0267f 100644 --- a/devices/src/camera_backend/demo.rs +++ b/devices/src/camera_backend/demo.rs @@ -204,7 +204,7 @@ impl DemoCameraBackend { notify(); } let interval = if locked_fmt.fps != 0 { - 1000 / locked_fmt.fps as u64 + 1000 / u64::from(locked_fmt.fps) } else { 20 }; @@ -467,7 +467,8 @@ impl CameraBackend for DemoCameraBackend { let start = frame_offset + copied; let end = start + cnt; let tmp = &locked_frame.image[start..end]; - mem_from_buf(tmp, iov.iov_base) + // SAFETY: iovecs is generated by address_space and len is not less than tmp's. + unsafe { mem_from_buf(tmp, iov.iov_base) } .with_context(|| format!("Failed to write data to {:x}", iov.iov_base))?; copied += cnt; } @@ -544,19 +545,19 @@ fn convert_to_nv12(source: &[u8], width: u32, height: u32) -> Vec { for i in 0..len { let idx = (i * pixel) as usize; let (b, g, r) = ( - source[idx] as f32, - source[idx + 1] as f32, - source[idx + 2] as f32, + f32::from(source[idx]), + f32::from(source[idx + 1]), + f32::from(source[idx + 2]), ); let y = (0.299 * r + 0.587 * g + 0.114 * b) as u8; - img_nv12.push(y as u8); + img_nv12.push(y); } for i in 0..(width * height / 2) { let idx = (i * 2 * pixel) as usize; let (b, g, r) = ( - source[idx] as f32, - source[idx + 1] as f32, - source[idx + 2] as f32, + f32::from(source[idx]), + f32::from(source[idx + 1]), + f32::from(source[idx + 2]), ); let u = (-0.147 * r - 0.289 * g + 0.436 * b + 128_f32) as u8; let v = (0.615 * r - 0.515 * g - 0.100 * b + 128_f32) as u8; diff --git a/devices/src/camera_backend/mod.rs b/devices/src/camera_backend/mod.rs index 00723e5e2fffb5dc5803df7925af94e7a3f45ea0..a02f492185d5d19e7c1a05228ffd7b7f5c196ae4 100644 --- a/devices/src/camera_backend/mod.rs +++ b/devices/src/camera_backend/mod.rs @@ -54,7 +54,7 @@ impl CamBasicFmt { } } -#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Default)] +#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, PartialOrd, Default)] pub enum FmtType { #[default] Yuy2 = 0, @@ -123,9 +123,9 @@ pub fn get_video_frame_size(width: u32, height: u32, fmt: FmtType) -> Result Result { let fm_size = get_video_frame_size(width, height, fmt)?; - let size_in_bit = fm_size as u64 * INTERVALS_PER_SEC as u64 * 8; + let size_in_bit = u64::from(fm_size) * u64::from(INTERVALS_PER_SEC) * 8; let rate = size_in_bit - .checked_div(interval as u64) + .checked_div(u64::from(interval)) .with_context(|| format!("Invalid size {} or interval {}", size_in_bit, interval))?; Ok(rate as u32) } @@ -184,12 +184,16 @@ pub trait CameraBackend: Send + Sync { /// Register broken callback which is called when backend is broken. fn register_broken_cb(&mut self, cb: CameraBrokenCallback); + + /// Pause/resume stream. + fn pause(&mut self, _paused: bool) {} } #[allow(unused_variables)] pub fn create_cam_backend( config: UsbCameraConfig, cameradev: CameraDevConfig, + _tokenid: u64, ) -> Result>> { let cam: Arc> = match cameradev.backend { #[cfg(feature = "usb_camera_v4l2")] @@ -202,6 +206,7 @@ pub fn create_cam_backend( CamBackendType::OhCamera => Arc::new(Mutex::new(OhCameraBackend::new( cameradev.id, cameradev.path, + _tokenid, )?)), CamBackendType::Demo => Arc::new(Mutex::new(DemoCameraBackend::new( config.id, diff --git a/devices/src/camera_backend/ohcam.rs b/devices/src/camera_backend/ohcam.rs index 2bf3253f577922786ddc8f937d6efa2d4868f212..51d7d0493ec929955a32ca82fed7832c30cb39f3 100755 --- a/devices/src/camera_backend/ohcam.rs +++ b/devices/src/camera_backend/ohcam.rs @@ -10,6 +10,8 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. +use std::collections::HashMap; +use std::ffi::CStr; use std::sync::RwLock; use anyhow::{bail, Context, Result}; @@ -20,21 +22,24 @@ use crate::camera_backend::{ CamBasicFmt, CameraBackend, CameraBrokenCallback, CameraFormatList, CameraFrame, CameraNotifyCallback, FmtType, }; +#[cfg(any( + feature = "trace_to_logger", + feature = "trace_to_ftrace", + all(target_env = "ohos", feature = "trace_to_hitrace") +))] +use trace::trace_scope::Scope; use util::aio::Iovec; use util::ohos_binding::camera::*; +use util::ohos_binding::misc::bound_tokenid; -type OhCamCB = RwLock; -static OHCAM_CALLBACK: Lazy = Lazy::new(|| RwLock::new(OhCamCallBack::default())); +type OhCamCB = RwLock>; +static OHCAM_CALLBACKS: Lazy = Lazy::new(|| RwLock::new(HashMap::new())); // In UVC, interval's unit is 100ns. // So, fps * interval / 10_000_000 == 1. const FPS_INTERVAL_TRANS: u32 = 10_000_000; const RESOLUTION_WHITELIST: [(i32, i32); 2] = [(640, 480), (1280, 720)]; -const FRAME_FORMAT_WHITELIST: [i32; 3] = [ - CAMERA_FORMAT_YUV420SP, - CAMERA_FORMAT_YUYV422, - CAMERA_FORMAT_NV12, -]; +const FRAME_FORMAT_WHITELIST: [i32; 2] = [CAMERA_FORMAT_YUYV422, CAMERA_FORMAT_NV12]; const FPS_WHITELIST: [i32; 1] = [30]; #[derive(Default)] @@ -83,14 +88,52 @@ impl OhCamCallBack { } } +#[cfg(any( + feature = "trace_to_logger", + feature = "trace_to_ftrace", + all(target_env = "ohos", feature = "trace_to_hitrace") +))] +#[derive(Clone, Default)] +struct OhCameraAsyncScope { + next_frame_id: u64, + async_scope: Option, +} + +#[cfg(any( + feature = "trace_to_logger", + feature = "trace_to_ftrace", + all(target_env = "ohos", feature = "trace_to_hitrace") +))] +impl OhCameraAsyncScope { + fn start(&mut self) { + self.async_scope = Some(trace::ohcam_next_frame(true, self.next_frame_id)); + self.next_frame_id += 1; + } + + fn stop(&mut self) { + self.async_scope = None; + } +} + #[derive(Clone)] pub struct OhCameraBackend { + // ID for this OhCameraBackend. id: String, - camidx: u8, + // ID of OH camera device. + camid: String, profile_cnt: u8, ctx: OhCamera, fmt_list: Vec, selected_profile: u8, + stream_on: bool, + paused: bool, + #[cfg(any( + feature = "trace_to_logger", + feature = "trace_to_ftrace", + all(target_env = "ohos", feature = "trace_to_hitrace") + ))] + async_scope: Box, + tokenid: u64, } // SAFETY: Send and Sync is not auto-implemented for raw pointer type. @@ -111,20 +154,32 @@ fn cam_fmt_from_oh(t: i32) -> Result { Ok(fmt) } -impl OhCameraBackend { - pub fn new(id: String, camid: String) -> Result { - let idx = camid.parse::().with_context(|| "Invalid PATH format")?; - let ctx = OhCamera::new(idx as i32)?; +impl Drop for OhCameraBackend { + fn drop(&mut self) { + OHCAM_CALLBACKS.write().unwrap().remove_entry(&self.camid); + } +} - let profile_cnt = ctx.get_fmt_nums(idx as i32)? as u8; +impl OhCameraBackend { + pub fn new(id: String, cam_name: String, tokenid: u64) -> Result { + let (ctx, profile_cnt) = OhCamera::new(cam_name.clone())?; Ok(OhCameraBackend { id, - camidx: idx, - profile_cnt, + camid: cam_name, + profile_cnt: profile_cnt as u8, ctx, fmt_list: vec![], selected_profile: 0, + stream_on: false, + paused: false, + #[cfg(any( + feature = "trace_to_logger", + feature = "trace_to_ftrace", + all(target_env = "ohos", feature = "trace_to_hitrace") + ))] + async_scope: Box::::default(), + tokenid, }) } } @@ -148,8 +203,7 @@ impl CameraBackend for OhCameraBackend { } self.selected_profile = fmt.fmt_index - 1; - self.ctx - .set_fmt(self.camidx as i32, self.selected_profile as i32)?; + self.ctx.set_fmt(i32::from(self.selected_profile))?; return Ok(()); } } @@ -161,12 +215,26 @@ impl CameraBackend for OhCameraBackend { } fn video_stream_on(&mut self) -> Result<()> { - self.ctx.start_stream(on_buffer_available, on_broken) + if self.tokenid != 0 { + bound_tokenid(self.tokenid)?; + } + self.ctx.start_stream(on_buffer_available, on_broken)?; + self.stream_on = true; + Ok(()) } fn video_stream_off(&mut self) -> Result<()> { self.ctx.stop_stream(); - OHCAM_CALLBACK.write().unwrap().clear_buffer(); + if let Some(cb) = OHCAM_CALLBACKS.write().unwrap().get_mut(&self.camid) { + cb.clear_buffer(); + } + self.stream_on = false; + #[cfg(any( + feature = "trace_to_logger", + feature = "trace_to_ftrace", + all(target_env = "ohos", feature = "trace_to_hitrace") + ))] + self.async_scope.stop(); Ok(()) } @@ -174,7 +242,7 @@ impl CameraBackend for OhCameraBackend { let mut fmt_list: Vec = Vec::new(); for idx in 0..self.profile_cnt { - match self.ctx.get_profile(self.camidx as i32, idx as i32) { + match self.ctx.get_profile(i32::from(idx)) { Ok((fmt, width, height, fps)) => { if !FRAME_FORMAT_WHITELIST.iter().any(|&x| x == fmt) || !RESOLUTION_WHITELIST.iter().any(|&x| x == (width, height)) @@ -192,19 +260,42 @@ impl CameraBackend for OhCameraBackend { fmt_list.push(CameraFormatList { format: cam_fmt_from_oh(fmt)?, frame: vec![frame], - fmt_index: (idx) + 1, + fmt_index: idx.checked_add(1).unwrap_or_else(|| { + error!("list_format: too much profile ID"); + u8::MAX + }), }); } Err(e) => error!("{:?}", e), } } + + // Just for APP ToDesk, This stupid APP uses the format reported first + // to realize camera-related functions. It doesn't support NV12, so + // we put YUY2 forward.s + fmt_list.sort_by(|a, b| a.format.partial_cmp(&b.format).unwrap()); self.fmt_list = fmt_list.clone(); Ok(fmt_list) } fn reset(&mut self) { - OHCAM_CALLBACK.write().unwrap().clear_buffer(); - self.ctx.reset_camera(); + if let Some(cb) = OHCAM_CALLBACKS.write().unwrap().get_mut(&self.camid) { + cb.clear_buffer(); + } + if self.stream_on { + self.video_stream_off().unwrap_or_else(|e| { + error!("OHCAM: stream off failed: {:?}", e); + }); + } + if let Err(e) = self.ctx.reset_camera(self.camid.clone()) { + error!("OHCAM: reset failed, err: {e}"); + } + #[cfg(any( + feature = "trace_to_logger", + feature = "trace_to_ftrace", + all(target_env = "ohos", feature = "trace_to_hitrace") + ))] + self.async_scope.stop(); } fn get_format_by_index(&self, format_index: u8, frame_index: u8) -> Result { @@ -240,26 +331,45 @@ impl CameraBackend for OhCameraBackend { } fn get_frame_size(&self) -> usize { - OHCAM_CALLBACK.read().unwrap().get_buffer().1 as usize + if let Some(cb) = OHCAM_CALLBACKS.read().unwrap().get(&self.camid) { + return cb.get_buffer().1 as usize; + } + 0 } fn next_frame(&mut self) -> Result<()> { + #[cfg(any( + feature = "trace_to_logger", + feature = "trace_to_ftrace", + all(target_env = "ohos", feature = "trace_to_hitrace") + ))] + self.async_scope.start(); self.ctx.next_frame(); - OHCAM_CALLBACK.write().unwrap().clear_buffer(); + if let Some(cb) = OHCAM_CALLBACKS.write().unwrap().get_mut(&self.camid) { + cb.clear_buffer(); + } Ok(()) } fn get_frame(&self, iovecs: &[Iovec], frame_offset: usize, len: usize) -> Result { - let (src, src_len) = OHCAM_CALLBACK.read().unwrap().get_buffer(); - if src_len == 0 { - bail!("Invalid frame src_len {}", src_len); + let (src, src_len) = OHCAM_CALLBACKS + .read() + .unwrap() + .get(&self.camid) + .with_context(|| "Invalid camid in callback table")? + .get_buffer(); + + if src.is_none() || src.unwrap() == 0 { + bail!("Invalid frame src") } - if frame_offset + len > src_len as usize { - bail!("Invalid frame offset {} or len {}", frame_offset, len); + if src_len == 0_u64 { + bail!("Invalid frame src_len {}", src_len); } - let mut copied = 0; + trace::trace_scope_start!(ohcam_get_frame, args = (frame_offset, len)); + + let mut copied = 0_usize; for iov in iovecs { if len == copied { break; @@ -276,24 +386,81 @@ impl CameraBackend for OhCameraBackend { } fn register_notify_cb(&mut self, cb: CameraNotifyCallback) { - OHCAM_CALLBACK.write().unwrap().set_notify_cb(cb); + OHCAM_CALLBACKS + .write() + .unwrap() + .entry(self.camid.clone()) + .or_insert(OhCamCallBack::default()) + .set_notify_cb(cb); } fn register_broken_cb(&mut self, cb: CameraBrokenCallback) { - OHCAM_CALLBACK.write().unwrap().set_broken_cb(cb); + OHCAM_CALLBACKS + .write() + .unwrap() + .entry(self.camid.clone()) + .or_insert(OhCamCallBack::default()) + .set_broken_cb(cb); } + + fn pause(&mut self, paused: bool) { + if self.paused == paused { + return; + } + + if paused { + // If stream is off, we don't need to set self.paused. + // Because it's not required to re-open stream while + // vm is resuming. + if !self.stream_on { + return; + } + self.paused = true; + self.video_stream_off().unwrap_or_else(|e| { + error!("ohcam pause: failed to pause stream {:?}", e); + }); + } else { + self.paused = false; + self.video_stream_on().unwrap_or_else(|e| { + error!("ohcam resume: failed to resume stream {:?}", e); + }) + } + } +} + +fn cstr_to_string(src: *const u8) -> Result { + if src.is_null() { + bail!("cstr_to_string: src is null"); + } + // SAFETY: we promise that 'src' ends with "null" symbol. + let src_cstr = unsafe { CStr::from_ptr(src) }; + let target_string = src_cstr + .to_str() + .with_context(|| "cstr_to_string: failed to transfer camid")? + .to_owned(); + + Ok(target_string) } // SAFETY: use RW lock to ensure the security of resources. -unsafe extern "C" fn on_buffer_available(src_buffer: u64, length: i32) { - OHCAM_CALLBACK - .write() - .unwrap() - .set_buffer(src_buffer, length); - OHCAM_CALLBACK.read().unwrap().notify(); +unsafe extern "C" fn on_buffer_available(src_buffer: u64, length: i32, camid: *const u8) { + let cam = cstr_to_string(camid).unwrap_or_else(|e| { + error!("{e}"); + "".to_string() + }); + if let Some(cb) = OHCAM_CALLBACKS.write().unwrap().get_mut(&cam) { + cb.set_buffer(src_buffer, length); + cb.notify(); + } } // SAFETY: use RW lock to ensure the security of resources. -unsafe extern "C" fn on_broken() { - OHCAM_CALLBACK.read().unwrap().broken(); +unsafe extern "C" fn on_broken(camid: *const u8) { + let cam = cstr_to_string(camid).unwrap_or_else(|e| { + error!("{e}"); + "".to_string() + }); + if let Some(cb) = OHCAM_CALLBACKS.read().unwrap().get(&cam) { + cb.broken(); + } } diff --git a/devices/src/camera_backend/v4l2.rs b/devices/src/camera_backend/v4l2.rs index 0885d1da78846fe6bc34ede94239307b2601012e..f86b13ea8b31856e286b7178e6673c5a34db37ee 100644 --- a/devices/src/camera_backend/v4l2.rs +++ b/devices/src/camera_backend/v4l2.rs @@ -225,8 +225,8 @@ impl V4l2CameraBackend { ); continue; } - let interval = - (numerator as u64 * INTERVALS_PER_SEC as u64 / denominator as u64) as u32; + let interval = (u64::from(numerator) * u64::from(INTERVALS_PER_SEC) + / u64::from(denominator)) as u32; list.push(interval); } Ok(list) @@ -498,7 +498,7 @@ impl V4l2IoHandler { let iov = locked_buf .get(buf.index as usize) .with_context(|| "Buffer index overflow")?; - if buf.bytesused as u64 > iov.iov_len { + if u64::from(buf.bytesused) > iov.iov_len { bail!( "Buffer overflow, bytesused {} iov len {}", buf.bytesused, @@ -506,7 +506,7 @@ impl V4l2IoHandler { ); } locked_sample.addr = iov.iov_base; - locked_sample.used_len = buf.bytesused as u64; + locked_sample.used_len = u64::from(buf.bytesused); locked_sample.buf_index = buf.index; drop(locked_sample); // Notify the camera to deal with request. diff --git a/devices/src/interrupt_controller/aarch64/state.rs b/devices/src/interrupt_controller/aarch64/state.rs index 926096103ea228cbd51479f1fbaf3c71638a9a1e..a099790bd2925265859ffd00ad4ac2479c8e81bc 100644 --- a/devices/src/interrupt_controller/aarch64/state.rs +++ b/devices/src/interrupt_controller/aarch64/state.rs @@ -280,7 +280,7 @@ impl GICv3 { ..Default::default() }; - let offset = dist.irq_base / (GIC_IRQ_INTERNAL as u64 / REGISTER_SIZE); + let offset = dist.irq_base / (u64::from(GIC_IRQ_INTERNAL) / REGISTER_SIZE); self.access_gic_distributor(GICD_IGROUPR + offset, &mut dist.gicd_igroupr, false)?; self.access_gic_distributor(GICD_ISENABLER + offset, &mut dist.gicd_isenabler, false)?; self.access_gic_distributor(dist.irq_base, &mut dist.line_level, false)?; @@ -290,7 +290,7 @@ impl GICv3 { // edge trigger for i in 0..NR_GICD_ICFGR { if ((i * GIC_IRQ_INTERNAL as usize / NR_GICD_ICFGR) as u64 + dist.irq_base) - > self.nr_irqs as u64 + > u64::from(self.nr_irqs) { break; } @@ -299,7 +299,7 @@ impl GICv3 { } for i in 0..NR_GICD_IPRIORITYR { - if (i as u64 * REGISTER_SIZE + dist.irq_base) > self.nr_irqs as u64 { + if (i as u64 * REGISTER_SIZE + dist.irq_base) > u64::from(self.nr_irqs) { break; } let offset = dist.irq_base + REGISTER_SIZE * i as u64; @@ -311,7 +311,7 @@ impl GICv3 { } for i in 0..NR_GICD_IROUTER { - if (i as u64 + dist.irq_base) > self.nr_irqs as u64 { + if (i as u64 + dist.irq_base) > u64::from(self.nr_irqs) { break; } let offset = dist.irq_base + i as u64; @@ -328,12 +328,12 @@ impl GICv3 { } fn set_dist(&self, mut dist: GICv3DistState) -> Result<()> { - let offset = dist.irq_base / (GIC_IRQ_INTERNAL as u64 / REGISTER_SIZE); + let offset = dist.irq_base / (u64::from(GIC_IRQ_INTERNAL) / REGISTER_SIZE); self.access_gic_distributor(GICD_ISENABLER + offset, &mut dist.gicd_isenabler, true)?; self.access_gic_distributor(GICD_IGROUPR + offset, &mut dist.gicd_igroupr, true)?; for i in 0..NR_GICD_IROUTER { - if (i as u64 + dist.irq_base) > self.nr_irqs as u64 { + if (i as u64 + dist.irq_base) > u64::from(self.nr_irqs) { break; } let offset = dist.irq_base + i as u64; @@ -349,7 +349,7 @@ impl GICv3 { // edge trigger for i in 0..NR_GICD_ICFGR { if ((i * GIC_IRQ_INTERNAL as usize / NR_GICD_ICFGR) as u64 + dist.irq_base) - > self.nr_irqs as u64 + > u64::from(self.nr_irqs) { break; } @@ -362,7 +362,7 @@ impl GICv3 { self.access_gic_distributor(GICD_ISACTIVER + offset, &mut dist.gicd_isactiver, true)?; for i in 0..NR_GICD_IPRIORITYR { - if (i as u64 * REGISTER_SIZE + dist.irq_base) > self.nr_irqs as u64 { + if (i as u64 * REGISTER_SIZE + dist.irq_base) > u64::from(self.nr_irqs) { break; } let offset = dist.irq_base + REGISTER_SIZE * i as u64; @@ -634,7 +634,7 @@ impl StateTransfer for GICv3 { .map_err(|e| MigrationError::GetGicRegsError("gicd_statusr", e.to_string()))?; for irq in (GIC_IRQ_INTERNAL..self.nr_irqs).step_by(32) { state.irq_dist[state.dist_len] = self - .get_dist(irq as u64) + .get_dist(u64::from(irq)) .map_err(|e| MigrationError::GetGicRegsError("dist", e.to_string()))?; state.dist_len += 1; } diff --git a/devices/src/interrupt_controller/mod.rs b/devices/src/interrupt_controller/mod.rs index 4a0453c8fdea6d2784d056b84c5a1a3d408b0050..d1b04a168168cbed9602a34006654090a406dd22 100644 --- a/devices/src/interrupt_controller/mod.rs +++ b/devices/src/interrupt_controller/mod.rs @@ -81,6 +81,8 @@ pub trait LineIrqManager: Send + Sync { } pub trait MsiIrqManager: Send + Sync { + fn irqfd_enable(&self) -> bool; + fn allocate_irq(&self, _vector: MsiVector) -> Result { Ok(0) } @@ -140,6 +142,10 @@ impl IrqState { } pub fn register_irq(&mut self) -> Result<()> { + if self.irq_handler.is_none() { + return Ok(()); + } + let irq_handler = self.irq_handler.as_ref().unwrap(); if !irq_handler.irqfd_enable() { self.irq_fd = None; diff --git a/devices/src/legacy/fwcfg.rs b/devices/src/legacy/fwcfg.rs index 2abf4368e03b72230c6fc00ea83b2394584787ec..59b2d250da187e39f2f53c876fdc041fb089177d 100644 --- a/devices/src/legacy/fwcfg.rs +++ b/devices/src/legacy/fwcfg.rs @@ -19,8 +19,8 @@ use byteorder::{BigEndian, ByteOrder}; use log::{error, warn}; use crate::legacy::error::LegacyError; -use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType, SysRes}; -use crate::{Device, DeviceBase}; +use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType}; +use crate::{convert_bus_mut, Device, DeviceBase, MUT_SYS_BUS}; use acpi::{ AmlBuilder, AmlDevice, AmlInteger, AmlNameDecl, AmlResTemplate, AmlScopeBuilder, AmlString, }; @@ -28,10 +28,10 @@ use acpi::{ use acpi::{AmlIoDecode, AmlIoResource}; #[cfg(target_arch = "aarch64")] use acpi::{AmlMemory32Fixed, AmlReadAndWrite}; -use address_space::{AddressSpace, GuestAddress}; +use address_space::{AddressAttr, AddressSpace, GuestAddress}; use util::byte_code::ByteCode; use util::num_ops::extract_u64; -use util::offset_of; +use util::{gen_base_func, offset_of}; #[cfg(target_arch = "x86_64")] const FW_CFG_IO_BASE: u64 = 0x510; @@ -243,12 +243,14 @@ fn write_dma_memory( mut buf: &[u8], len: u64, ) -> Result<()> { - addr_space.write(&mut buf, addr, len).with_context(|| { - format!( - "Failed to write dma memory of fwcfg at gpa=0x{:x} len=0x{:x}", - addr.0, len - ) - })?; + addr_space + .write(&mut buf, addr, len, AddressAttr::Ram) + .with_context(|| { + format!( + "Failed to write dma memory of fwcfg at gpa=0x{:x} len=0x{:x}", + addr.0, len + ) + })?; Ok(()) } @@ -260,12 +262,14 @@ fn read_dma_memory( mut buf: &mut [u8], len: u64, ) -> Result<()> { - addr_space.read(&mut buf, addr, len).with_context(|| { - format!( - "Failed to read dma memory of fwcfg at gpa=0x{:x} len=0x{:x}", - addr.0, len - ) - })?; + addr_space + .read(&mut buf, addr, len, AddressAttr::Ram) + .with_context(|| { + format!( + "Failed to read dma memory of fwcfg at gpa=0x{:x} len=0x{:x}", + addr.0, len + ) + })?; Ok(()) } @@ -361,11 +365,14 @@ impl FwCfgCommon { /// Select the entry by the key specified fn select_entry(&mut self, key: u16) { + let ret; self.cur_offset = 0; if (key & FW_CFG_ENTRY_MASK) >= self.max_entry() { self.cur_entry = FW_CFG_INVALID; + ret = 0; } else { self.cur_entry = key; + ret = 1; // unwrap() is safe because we have checked the range of `key`. let selected_entry = self.get_entry_mut().unwrap(); @@ -373,6 +380,8 @@ impl FwCfgCommon { cb.select_callback(); } } + + trace::fwcfg_select_entry(key, get_key_name(key as usize), ret); } fn add_entry( @@ -404,11 +413,12 @@ impl FwCfgCommon { warn!("Entry not empty, will override"); } - entry.data = data; + entry.data = data.clone(); entry.select_cb = select_cb; entry.allow_write = allow_write; entry.write_cb = write_cb; + trace::fwcfg_add_entry(key, get_key_name(key as usize), data); Ok(()) } @@ -467,11 +477,8 @@ impl FwCfgCommon { } } - let file = FwCfgFile::new( - data.len() as u32, - FW_CFG_FILE_FIRST + index as u16, - filename, - ); + let data_len = data.len(); + let file = FwCfgFile::new(data_len as u32, FW_CFG_FILE_FIRST + index as u16, filename); self.files.insert(index, file); self.files.iter_mut().skip(index + 1).for_each(|f| { f.select += 1; @@ -489,6 +496,8 @@ impl FwCfgCommon { FW_CFG_FILE_FIRST as usize + index, FwCfgEntry::new(data, select_cb, write_cb, allow_write), ); + + trace::fwcfg_add_file(index, filename, data_len); Ok(()) } @@ -591,7 +600,7 @@ impl FwCfgCommon { &mem_space, GuestAddress(dma.address), data.as_slice(), - len as u64, + u64::from(len), ) .is_err() { @@ -614,7 +623,7 @@ impl FwCfgCommon { &mem_space, GuestAddress(dma.address), &entry.data[offset as usize..], - len as u64, + u64::from(len), ) .is_err() { @@ -624,7 +633,7 @@ impl FwCfgCommon { if is_write { let mut dma_read_error = false; let data = &mut entry.data[offset as usize..]; - if read_dma_memory(&mem_space, GuestAddress(dma.address), data, len as u64) + if read_dma_memory(&mem_space, GuestAddress(dma.address), data, u64::from(len)) .is_err() { dma_read_error = true; @@ -636,7 +645,7 @@ impl FwCfgCommon { if let Some(cb) = &entry.write_cb { cb.lock().unwrap().write_callback( data.to_vec(), - offset as u64, + u64::from(offset), len as usize, ); } @@ -645,11 +654,13 @@ impl FwCfgCommon { offset += len; } dma.length -= len; - dma.address += len as u64 + dma.address += u64::from(len) } self.cur_offset = offset; write_dma_result(&self.mem_space, dma_addr, dma.control)?; + + trace::fwcfg_read_data(0); Ok(()) } @@ -698,7 +709,7 @@ impl FwCfgCommon { fn dma_mem_read(&self, addr: u64, size: u32) -> Result { extract_u64( FW_CFG_DMA_SIGNATURE as u64, - ((8 - addr - size as u64) * 8) as u32, + ((8 - addr - u64::from(size)) * 8) as u32, size * 8, ) .with_context(|| "Failed to extract bits from u64") @@ -732,7 +743,7 @@ impl FwCfgCommon { && cur_offset < entry.data.len() as u32 { loop { - value = (value << 8) | entry.data[cur_offset as usize] as u64; + value = (value << 8) | u64::from(entry.data[cur_offset as usize]); cur_offset += 1; size -= 1; @@ -740,9 +751,11 @@ impl FwCfgCommon { break; } } - value <<= 8 * size as u64; + value <<= 8 * u64::from(size); } self.cur_offset = cur_offset; + + trace::fwcfg_read_data(value); Ok(value) } @@ -833,28 +846,22 @@ pub struct FwCfgMem { #[cfg(target_arch = "aarch64")] impl FwCfgMem { - pub fn new(sys_mem: Arc) -> Self { - FwCfgMem { - base: SysBusDevBase::new(SysBusDevType::FwCfg), - fwcfg: FwCfgCommon::new(sys_mem), - } - } - - pub fn realize( - mut self, - sysbus: &mut SysBus, + pub fn new( + sys_mem: Arc, + sysbus: &Arc>, region_base: u64, region_size: u64, - ) -> Result>> { - self.fwcfg.common_realize()?; - self.set_sys_resource(sysbus, region_base, region_size) + ) -> Result { + let mut fwcfgmem = FwCfgMem { + base: SysBusDevBase::new(SysBusDevType::FwCfg), + fwcfg: FwCfgCommon::new(sys_mem), + }; + fwcfgmem + .set_sys_resource(sysbus, region_base, region_size, "FwCfgMem") .with_context(|| "Failed to allocate system resource for FwCfg.")?; + fwcfgmem.set_parent_bus(sysbus.clone()); - let dev = Arc::new(Mutex::new(self)); - sysbus - .attach_device(&dev, region_base, region_size, "FwCfgMem") - .with_context(|| "Failed to attach FwCfg device to system bus.")?; - Ok(dev) + Ok(fwcfgmem) } } @@ -922,24 +929,28 @@ impl FwCfgOps for FwCfgMem { #[cfg(target_arch = "aarch64")] impl Device for FwCfgMem { - fn device_base(&self) -> &DeviceBase { - &self.base.base + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); + + fn reset(&mut self, _reset_child_device: bool) -> Result<()> { + self.fwcfg.select_entry(FwCfgEntryType::Signature as u16); + Ok(()) } - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base + fn realize(mut self) -> Result>> { + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + MUT_SYS_BUS!(parent_bus, locked_bus, sysbus); + self.fwcfg.common_realize()?; + let dev = Arc::new(Mutex::new(self)); + sysbus + .attach_device(&dev) + .with_context(|| "Failed to attach FwCfg device to system bus.")?; + Ok(dev) } } #[cfg(target_arch = "aarch64")] impl SysBusDevOps for FwCfgMem { - fn sysbusdev_base(&self) -> &SysBusDevBase { - &self.base - } - - fn sysbusdev_base_mut(&mut self) -> &mut SysBusDevBase { - &mut self.base - } + gen_base_func!(sysbusdev_base, sysbusdev_base_mut, SysBusDevBase, base); fn read(&mut self, data: &mut [u8], base: GuestAddress, offset: u64) -> bool { common_read(self, data, base, offset) @@ -948,9 +959,9 @@ impl SysBusDevOps for FwCfgMem { fn write(&mut self, data: &[u8], _base: GuestAddress, offset: u64) -> bool { let size = data.len() as u32; let value = match size { - 1 => data[0] as u64, - 2 => BigEndian::read_u16(data) as u64, - 4 => BigEndian::read_u32(data) as u64, + 1 => u64::from(data[0]), + 2 => u64::from(BigEndian::read_u16(data)), + 4 => u64::from(BigEndian::read_u32(data)), 8 => BigEndian::read_u64(data), _ => 0, }; @@ -963,12 +974,8 @@ impl SysBusDevOps for FwCfgMem { self.fwcfg.select_entry(value as u16); } 16..=23 => { - if self - .fwcfg - .dma_mem_write(offset - 0x10, value, size) - .is_err() - { - error!("Failed to write dma at offset=0x{:x}.", offset); + if let Err(e) = self.fwcfg.dma_mem_write(offset - 0x10, value, size) { + error!("Failed to write dma at offset=0x{:x} {:?}.", offset, e); return false; } } @@ -980,24 +987,15 @@ impl SysBusDevOps for FwCfgMem { true } - fn get_sys_resource_mut(&mut self) -> Option<&mut SysRes> { - Some(&mut self.base.res) - } - fn set_sys_resource( &mut self, - _sysbus: &mut SysBus, + _sysbus: &Arc>, region_base: u64, region_size: u64, + region_name: &str, ) -> Result<()> { - let res = self.get_sys_resource_mut().unwrap(); - res.region_base = region_base; - res.region_size = region_size; - Ok(()) - } - - fn reset(&mut self) -> Result<()> { - self.fwcfg.select_entry(FwCfgEntryType::Signature as u16); + self.sysbusdev_base_mut() + .set_sys(-1, region_base, region_size, region_name); Ok(()) } } @@ -1011,34 +1009,17 @@ pub struct FwCfgIO { #[cfg(target_arch = "x86_64")] impl FwCfgIO { - pub fn new(sys_mem: Arc) -> Self { - FwCfgIO { - base: SysBusDevBase { - base: DeviceBase::default(), - dev_type: SysBusDevType::FwCfg, - res: SysRes { - region_base: FW_CFG_IO_BASE, - region_size: FW_CFG_IO_SIZE, - irq: -1, - }, - ..Default::default() - }, + pub fn new(sys_mem: Arc, sysbus: &Arc>) -> Result { + let mut fwcfg = FwCfgIO { + base: SysBusDevBase::new(SysBusDevType::FwCfg), fwcfg: FwCfgCommon::new(sys_mem), - } - } - - pub fn realize(mut self, sysbus: &mut SysBus) -> Result>> { - self.fwcfg.common_realize()?; - let region_base = self.base.res.region_base; - let region_size = self.base.res.region_size; - self.set_sys_resource(sysbus, region_base, region_size) + }; + fwcfg + .set_sys_resource(sysbus, FW_CFG_IO_BASE, FW_CFG_IO_SIZE, "FwCfgIO") .with_context(|| "Failed to allocate system resource for FwCfg.")?; + fwcfg.set_parent_bus(sysbus.clone()); - let dev = Arc::new(Mutex::new(self)); - sysbus - .attach_device(&dev, region_base, region_size, "FwCfgIO") - .with_context(|| "Failed to attach FwCfg device to system bus.")?; - Ok(dev) + Ok(fwcfg) } } @@ -1104,24 +1085,28 @@ impl FwCfgOps for FwCfgIO { #[cfg(target_arch = "x86_64")] impl Device for FwCfgIO { - fn device_base(&self) -> &DeviceBase { - &self.base.base + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); + + fn reset(&mut self, _reset_child_device: bool) -> Result<()> { + self.fwcfg.select_entry(FwCfgEntryType::Signature as u16); + Ok(()) } - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base + fn realize(mut self) -> Result>> { + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + MUT_SYS_BUS!(parent_bus, locked_bus, sysbus); + self.fwcfg.common_realize()?; + let dev = Arc::new(Mutex::new(self)); + sysbus + .attach_device(&dev) + .with_context(|| "Failed to attach FwCfg device to system bus.")?; + Ok(dev) } } #[cfg(target_arch = "x86_64")] impl SysBusDevOps for FwCfgIO { - fn sysbusdev_base(&self) -> &SysBusDevBase { - &self.base - } - - fn sysbusdev_base_mut(&mut self) -> &mut SysBusDevBase { - &mut self.base - } + gen_base_func!(sysbusdev_base, sysbusdev_base_mut, SysBusDevBase, base); fn read(&mut self, data: &mut [u8], base: GuestAddress, offset: u64) -> bool { common_read(self, data, base, offset) @@ -1142,9 +1127,9 @@ impl SysBusDevOps for FwCfgIO { } 4..=11 => { let value = match size { - 1 => data[0] as u64, - 2 => BigEndian::read_u16(data) as u64, - 4 => BigEndian::read_u32(data) as u64, + 1 => u64::from(data[0]), + 2 => u64::from(BigEndian::read_u16(data)), + 4 => u64::from(BigEndian::read_u32(data)), 8 => BigEndian::read_u64(data), _ => 0, }; @@ -1165,24 +1150,15 @@ impl SysBusDevOps for FwCfgIO { true } - fn get_sys_resource_mut(&mut self) -> Option<&mut SysRes> { - Some(&mut self.base.res) - } - fn set_sys_resource( &mut self, - _sysbus: &mut SysBus, + _sysbus: &Arc>, region_base: u64, region_size: u64, + region_name: &str, ) -> Result<()> { - let res = self.get_sys_resource_mut().unwrap(); - res.region_base = region_base; - res.region_size = region_size; - Ok(()) - } - - fn reset(&mut self) -> Result<()> { - self.fwcfg.select_entry(FwCfgEntryType::Signature as u16); + self.sysbusdev_base_mut() + .set_sys(-1, region_base, region_size, region_name); Ok(()) } } @@ -1302,58 +1278,8 @@ impl AmlBuilder for FwCfgIO { #[cfg(test)] mod test { use super::*; - use crate::sysbus::{IRQ_BASE, IRQ_MAX}; - use address_space::{AddressSpace, HostMemMapping, Region}; - - fn sysbus_init() -> SysBus { - let sys_mem = AddressSpace::new( - Region::init_container_region(u64::max_value(), "sys_mem"), - "sys_mem", - None, - ) - .unwrap(); - #[cfg(target_arch = "x86_64")] - let sys_io = AddressSpace::new( - Region::init_container_region(1 << 16, "sys_io"), - "sys_io", - None, - ) - .unwrap(); - let free_irqs: (i32, i32) = (IRQ_BASE, IRQ_MAX); - let mmio_region: (u64, u64) = (0x0A00_0000, 0x1000_0000); - SysBus::new( - #[cfg(target_arch = "x86_64")] - &sys_io, - &sys_mem, - free_irqs, - mmio_region, - ) - } - - fn address_space_init() -> Arc { - let root = Region::init_container_region(1 << 36, "root"); - let sys_space = AddressSpace::new(root, "sys_space", None).unwrap(); - let host_mmap = Arc::new( - HostMemMapping::new( - GuestAddress(0), - None, - 0x1000_0000, - None, - false, - false, - false, - ) - .unwrap(), - ); - sys_space - .root() - .add_subregion( - Region::init_ram_region(host_mmap.clone(), "region_1"), - host_mmap.start_address().raw_value(), - ) - .unwrap(); - sys_space - } + use crate::sysbus::sysbus_init; + use crate::test::address_space_init; #[test] fn test_entry_functions() { @@ -1506,7 +1432,12 @@ mod test { let addr = GuestAddress(0x0000); fwcfg_common .mem_space - .write(&mut dma_request.as_ref(), addr, dma_request.len() as u64) + .write( + &mut dma_request.as_ref(), + addr, + dma_request.len() as u64, + AddressAttr::Ram, + ) .unwrap(); // [2]set dma addr. @@ -1516,7 +1447,13 @@ mod test { assert_eq!(fwcfg_common.handle_dma_request().is_ok(), true); // [4]check dma response. - assert_eq!(fwcfg_common.mem_space.read_object::(addr).unwrap(), 0); + assert_eq!( + fwcfg_common + .mem_space + .read_object::(addr, AddressAttr::Ram) + .unwrap(), + 0 + ); // [5]check dma write result. let mut read_dma_buf = Vec::new(); @@ -1524,7 +1461,12 @@ mod test { let len = sig_entry_data.len(); fwcfg_common .mem_space - .read(&mut read_dma_buf, GuestAddress(0xffff), len as u64) + .read( + &mut read_dma_buf, + GuestAddress(0xffff), + len as u64, + AddressAttr::Ram, + ) .unwrap(); assert_eq!(read_dma_buf, sig_entry_data); @@ -1537,7 +1479,12 @@ mod test { let addr = GuestAddress(0x0000); fwcfg_common .mem_space - .write(&mut dma_request.as_ref(), addr, dma_request.len() as u64) + .write( + &mut dma_request.as_ref(), + addr, + dma_request.len() as u64, + AddressAttr::Ram, + ) .unwrap(); fwcfg_common.dma_addr = addr; @@ -1545,14 +1492,25 @@ mod test { assert_eq!(fwcfg_common.handle_dma_request().is_ok(), true); // Result should be all zero. - assert_eq!(fwcfg_common.mem_space.read_object::(addr).unwrap(), 0); + assert_eq!( + fwcfg_common + .mem_space + .read_object::(addr, AddressAttr::Ram) + .unwrap(), + 0 + ); let mut read_dma_buf = Vec::new(); let all_zero = vec![0x0_u8; 4]; let len = all_zero.len(); fwcfg_common .mem_space - .read(&mut read_dma_buf, GuestAddress(0xffff), len as u64) + .read( + &mut read_dma_buf, + GuestAddress(0xffff), + len as u64, + AddressAttr::Ram, + ) .unwrap(); assert_eq!(read_dma_buf, all_zero); } @@ -1562,9 +1520,9 @@ mod test { fn test_read_write_aarch64() { let mut sys_bus = sysbus_init(); let sys_mem = address_space_init(); - let fwcfg = FwCfgMem::new(sys_mem); + let fwcfg = FwCfgMem::new(sys_mem, &mut sys_bus, 0x0902_0000, 0x0000_0018).unwrap(); - let fwcfg_dev = FwCfgMem::realize(fwcfg, &mut sys_bus, 0x0902_0000, 0x0000_0018).unwrap(); + let fwcfg_dev = fwcfg.realize().unwrap(); // Read FW_CFG_DMA_SIGNATURE entry. let base = GuestAddress(0x0000); let mut read_data = vec![0xff_u8, 0xff, 0xff, 0xff]; @@ -1602,9 +1560,9 @@ mod test { fn test_read_write_x86_64() { let mut sys_bus = sysbus_init(); let sys_mem = address_space_init(); - let fwcfg = FwCfgIO::new(sys_mem); + let fwcfg = FwCfgIO::new(sys_mem, &mut sys_bus).unwrap(); - let fwcfg_dev = FwCfgIO::realize(fwcfg, &mut sys_bus).unwrap(); + let fwcfg_dev = fwcfg.realize().unwrap(); // Read FW_CFG_DMA_SIGNATURE entry. let base = GuestAddress(0x0000); let mut read_data = vec![0xff_u8, 0xff, 0xff, 0xff]; diff --git a/devices/src/legacy/mod.rs b/devices/src/legacy/mod.rs index 00632feb418b69cc4f9aaabc8636480b1f58980b..74cc2e4a7a21094610148cd2764db1b5065c702a 100644 --- a/devices/src/legacy/mod.rs +++ b/devices/src/legacy/mod.rs @@ -53,5 +53,5 @@ pub use pl011::PL011; #[cfg(target_arch = "aarch64")] pub use pl031::{PL031, RTC_CR, RTC_DR, RTC_IMSC, RTC_LR}; #[cfg(all(feature = "ramfb", target_arch = "aarch64"))] -pub use ramfb::Ramfb; +pub use ramfb::{Ramfb, RamfbConfig}; pub use serial::{Serial, SERIAL_ADDR}; diff --git a/devices/src/legacy/pflash.rs b/devices/src/legacy/pflash.rs index e10039228582fd6e8ebadaa132793a83bda043f4..58e752f1d8ea824120d82a969b7879b66b28ca8f 100644 --- a/devices/src/legacy/pflash.rs +++ b/devices/src/legacy/pflash.rs @@ -18,10 +18,11 @@ use anyhow::{anyhow, bail, Context, Result}; use log::{error, warn}; use super::error::LegacyError; -use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType, SysRes}; -use crate::{Device, DeviceBase}; +use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType}; +use crate::{convert_bus_mut, Device, DeviceBase, MUT_SYS_BUS}; use acpi::AmlBuilder; -use address_space::{FileBackend, GuestAddress, HostMemMapping, Region}; +use address_space::{AddressAttr, FileBackend, GuestAddress, HostMemMapping, Region}; +use util::gen_base_func; use util::num_ops::{deposit_u32, extract_u32, read_data_u32, round_up, write_data_u32}; use util::unix::host_page_size; @@ -57,20 +58,21 @@ pub struct PFlash { write_blk_size: u32, /// ROM region of PFlash. rom: Option, + /// backend: Option, + host_mmap: Arc, } impl PFlash { fn flash_region_size( region_max_size: u64, - backend: &Option, + backend: &Option>, read_only: bool, ) -> Result { // We don't have to occupy the whole memory region. - // If flash is read-only, expose just real data size, - // rounded up to page_size + // Expose just real data size, rounded up to page_size. if let Some(fd) = backend.as_ref() { - let len = fd.metadata().unwrap().len(); - if len > region_max_size || len == 0 || (!read_only && len != region_max_size) { + let len = fd.as_ref().metadata().unwrap().len(); + if len > region_max_size || len == 0 || (!read_only && len % host_page_size() != 0) { bail!( "Invalid flash file: Region size 0x{region_max_size:X}, file size 0x{len:X}; read_only {read_only}" ); @@ -98,18 +100,21 @@ impl PFlash { /// * block-length is zero. /// * PFlash size is zero. /// * flash is writable and file size is smaller than region_max_size. + #[allow(clippy::too_many_arguments)] pub fn new( region_max_size: u64, - backend: &Option, + backend: Option>, block_len: u32, bank_width: u32, device_width: u32, read_only: bool, + sysbus: &Arc>, + region_base: u64, ) -> Result { if block_len == 0 { bail!("PFlash: block-length is zero which is invalid."); } - let size = Self::flash_region_size(region_max_size, backend, read_only)?; + let size = Self::flash_region_size(region_max_size, &backend, read_only)?; let blocks_per_device: u32 = size as u32 / block_len; if blocks_per_device == 0 { bail!("PFlash: num-blocks is zero which is invalid."); @@ -186,9 +191,21 @@ impl PFlash { // Number of protection fields. cfi_table[0x3f] = 0x01; - Ok(PFlash { + let has_backend = backend.is_some(); + let region_size = Self::flash_region_size(region_max_size, &backend, read_only)?; + let host_mmap = Arc::new(HostMemMapping::new( + GuestAddress(region_base), + None, + region_size, + backend.map(FileBackend::new_common), + false, + true, + read_only, + )?); + + let mut pflash = PFlash { base: SysBusDevBase::new(SysBusDevType::Flash), - has_backend: backend.is_some(), + has_backend, block_len, bank_width, // device id for Intel PFlash. @@ -203,43 +220,14 @@ impl PFlash { counter: 0, write_blk_size, rom: None, - }) - } + host_mmap, + }; - pub fn realize( - mut self, - sysbus: &mut SysBus, - region_base: u64, - region_max_size: u64, - backend: Option, - ) -> Result<()> { - let region_size = Self::flash_region_size(region_max_size, &backend, self.read_only)?; - self.set_sys_resource(sysbus, region_base, region_size) + pflash + .set_sys_resource(sysbus, region_base, region_size, "PflashRom") .with_context(|| "Failed to allocate system resource for PFlash.")?; - - let host_mmap = Arc::new(HostMemMapping::new( - GuestAddress(region_base), - None, - region_size, - backend.map(FileBackend::new_common), - false, - true, - self.read_only, - )?); - - let dev = Arc::new(Mutex::new(self)); - let region_ops = sysbus.build_region_ops(&dev); - - let rom_region = Region::init_rom_device_region(host_mmap, region_ops, "PflashRom"); - dev.lock().unwrap().rom = Some(rom_region.clone()); - sysbus - .sys_mem - .root() - .add_subregion(rom_region, region_base) - .with_context(|| "Failed to attach PFlash to system bus")?; - sysbus.devices.push(dev); - - Ok(()) + pflash.set_parent_bus(sysbus.clone()); + Ok(pflash) } fn set_read_array_mode(&mut self, is_illegal_cmd: bool) -> Result<()> { @@ -318,7 +306,7 @@ impl PFlash { } // Repeat data for PFlash device which supports x16-mode but works in x8-mode. for i in 1..self.max_device_width { - resp = deposit_u32(resp, 8 * i, 8, self.cfi_table[index as usize] as u32) + resp = deposit_u32(resp, 8 * i, 8, u32::from(self.cfi_table[index as usize])) .with_context(|| "Failed to deposit bits to u32")?; } } @@ -341,17 +329,23 @@ impl PFlash { } // Unwrap is safe, because after realize function, rom isn't none. let mr = self.rom.as_ref().unwrap(); - if offset + size as u64 > mr.size() { + if offset + .checked_add(size as u64) + .map(|sum| sum > mr.size()) + .unwrap_or(true) + { return Err(anyhow!(LegacyError::PFlashWriteOverflow( mr.size(), offset, - size as u64 + u64::from(size) ))); } - let addr: u64 = mr - .get_host_address() - .with_context(|| "Failed to get host address.")?; + // SAFETY: size has been checked. + let addr: u64 = unsafe { + mr.get_host_address(AddressAttr::RomDevice) + .with_context(|| "Failed to get host address.") + }?; let ret = // SAFETY: addr and size are valid. unsafe { @@ -371,14 +365,19 @@ impl PFlash { fn read_data(&mut self, data: &mut [u8], offset: u64) -> Result<()> { // Unwrap is safe, because after realize function, rom isn't none. let mr = self.rom.as_ref().unwrap(); - if offset + data.len() as u64 > mr.size() { + if offset + .checked_add(data.len() as u64) + .map(|sum| sum > mr.size()) + .unwrap_or(true) + { return Err(anyhow!(LegacyError::PFlashReadOverflow( mr.size(), offset, data.len() as u64 ))); } - let host_addr = mr.get_host_address().unwrap(); + // SAFETY: size has been checked. + let host_addr = unsafe { mr.get_host_address(AddressAttr::RomDevice).unwrap() }; let src = // SAFETY: host_addr of the region is local allocated and sanity has been checked. unsafe { std::slice::from_raw_parts_mut((host_addr + offset) as *mut u8, data.len()) }; @@ -399,14 +398,19 @@ impl PFlash { ); // Unwrap is safe, because after realize function, rom isn't none. let mr = self.rom.as_ref().unwrap(); - if offset + data.len() as u64 > mr.size() { + if offset + .checked_add(data.len() as u64) + .map(|sum| sum > mr.size()) + .unwrap_or(true) + { return Err(anyhow!(LegacyError::PFlashWriteOverflow( mr.size(), offset, data.len() as u64 ))); } - let host_addr = mr.get_host_address().unwrap(); + // SAFETY: size has been checked. + let host_addr = unsafe { mr.get_host_address(AddressAttr::RomDevice).unwrap() }; let mut dst = // SAFETY: host_addr of the region is local allocated and sanity has been checked. unsafe { std::slice::from_raw_parts_mut((host_addr + offset) as *mut u8, data.len()) }; @@ -434,7 +438,7 @@ impl PFlash { trace::pflash_write("single byte program (0)".to_string(), cmd); } 0x20 => { - let offset_mask = offset & !(self.block_len as u64 - 1); + let offset_mask = offset & !(u64::from(self.block_len) - 1); trace::pflash_write_block_erase(offset, self.block_len); if !self.read_only { let all_one = vec![0xff_u8; self.block_len as usize]; @@ -630,7 +634,7 @@ impl PFlash { } self.status |= 0x80; if self.counter == 0 { - let mask: u64 = !(self.write_blk_size as u64 - 1); + let mask: u64 = !(u64::from(self.write_blk_size) - 1); trace::pflash_write("block write finished".to_string(), self.cmd); self.write_cycle = self.write_cycle.wrapping_add(1); if !self.read_only { @@ -691,23 +695,42 @@ impl PFlash { } impl Device for PFlash { - fn device_base(&self) -> &DeviceBase { - &self.base.base + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); + + fn reset(&mut self, _reset_child_device: bool) -> Result<()> { + self.rom + .as_ref() + .unwrap() + .set_rom_device_romd(true) + .with_context(|| "Fail to set PFlash rom region read only")?; + self.cmd = 0x00; + self.write_cycle = 0; + self.status = 0x80; + Ok(()) } - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base + fn realize(self) -> Result>> { + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + MUT_SYS_BUS!(parent_bus, locked_bus, sysbus); + let region_base = self.base.res.region_base; + let host_mmap = self.host_mmap.clone(); + let dev = Arc::new(Mutex::new(self)); + let region_ops = sysbus.build_region_ops(&dev); + let rom_region = Region::init_rom_device_region(host_mmap, region_ops, "PflashRom"); + dev.lock().unwrap().rom = Some(rom_region.clone()); + sysbus + .sys_mem + .root() + .add_subregion(rom_region, region_base) + .with_context(|| "Failed to attach PFlash to system bus")?; + sysbus.sysbus_attach_child(dev.clone())?; + + Ok(dev) } } impl SysBusDevOps for PFlash { - fn sysbusdev_base(&self) -> &SysBusDevBase { - &self.base - } - - fn sysbusdev_base_mut(&mut self) -> &mut SysBusDevBase { - &mut self.base - } + gen_base_func!(sysbusdev_base, sysbusdev_base_mut, SysBusDevBase, base); fn read(&mut self, data: &mut [u8], _base: GuestAddress, offset: u64) -> bool { let mut index: u64; @@ -720,8 +743,8 @@ impl SysBusDevOps for PFlash { // - cmd 0x98 represents PFlash CFI query. match self.cmd { 0x00 => { - if self.read_data(data, offset).is_err() { - error!("Failed to read data from PFlash."); + if let Err(e) = self.read_data(data, offset) { + error!("Failed to read data from PFlash {:?}.", e); } return true; } @@ -733,15 +756,15 @@ impl SysBusDevOps for PFlash { // 0x70: Status Register. // 0xe8: Write block. // Just read status register, return every device status in bank. - ret = self.status as u32; + ret = u32::from(self.status); if self.device_width != 0 && data_len > self.device_width { let mut shift: u32 = self.device_width * 8; while shift + self.device_width * 8 <= data_len * 8 { - ret |= (self.status as u32) << shift; + ret |= u32::from(self.status) << shift; shift += self.device_width * 8; } } else if self.device_width == 0 && data_len > 2 { - ret |= (self.status as u32) << 16; + ret |= u32::from(self.status) << 16; } trace::pflash_read_status(ret); } @@ -775,7 +798,7 @@ impl SysBusDevOps for PFlash { // combine serval queries into one response. let mut i: u32 = 0; while i < data_len { - match self.query_devid(offset + (i * self.bank_width) as u64) { + match self.query_devid(offset + u64::from(i * self.bank_width)) { Err(e) => { error!("Failed to query devid {:?}", e); break; @@ -815,7 +838,7 @@ impl SysBusDevOps for PFlash { } else { let mut i: u32 = 0; while i < data_len { - match self.query_cfi(offset + (i * self.bank_width) as u64) { + match self.query_cfi(offset + u64::from(i * self.bank_width)) { Err(e) => { error!("Failed to query devid, {:?}", e); break; @@ -854,7 +877,7 @@ impl SysBusDevOps for PFlash { } fn write(&mut self, data: &[u8], _base: GuestAddress, offset: u64) -> bool { - let mut value = 0; + let mut value = 0_u32; if !read_data_u32(data, &mut value) { return false; } @@ -862,15 +885,10 @@ impl SysBusDevOps for PFlash { let data_len: u8 = data.len() as u8; trace::pflash_io_write(offset, data_len, value, self.write_cycle); - if self.write_cycle == 0 - && self - .rom - .as_ref() - .unwrap() - .set_rom_device_romd(false) - .is_err() - { - error!("Failed PFlash to set device to read array mode."); + if self.write_cycle == 0 { + if let Err(e) = self.rom.as_ref().unwrap().set_rom_device_romd(false) { + error!("Failed PFlash to set device to read array mode {:?}.", e); + } } // Write: @@ -893,32 +911,15 @@ impl SysBusDevOps for PFlash { } } - fn get_sys_resource_mut(&mut self) -> Option<&mut SysRes> { - Some(&mut self.base.res) - } - fn set_sys_resource( &mut self, - _sysbus: &mut SysBus, + _sysbus: &Arc>, region_base: u64, region_size: u64, + region_name: &str, ) -> Result<()> { - let res = self.get_sys_resource_mut().unwrap(); - res.region_base = region_base; - res.region_size = region_size; - res.irq = 0; - Ok(()) - } - - fn reset(&mut self) -> Result<()> { - self.rom - .as_ref() - .unwrap() - .set_rom_device_romd(true) - .with_context(|| "Fail to set PFlash rom region read only")?; - self.cmd = 0x00; - self.write_cycle = 0; - self.status = 0x80; + self.sysbusdev_base_mut() + .set_sys(0, region_base, region_size, region_name); Ok(()) } } @@ -935,33 +936,7 @@ mod test { use std::fs::File; use super::*; - use crate::sysbus::{IRQ_BASE, IRQ_MAX}; - use address_space::AddressSpace; - - fn sysbus_init() -> SysBus { - let sys_mem = AddressSpace::new( - Region::init_container_region(u64::max_value(), "sys_mem"), - "sys_mem", - None, - ) - .unwrap(); - #[cfg(target_arch = "x86_64")] - let sys_io = AddressSpace::new( - Region::init_container_region(1 << 16, "sys_io"), - "sys_io", - None, - ) - .unwrap(); - let free_irqs: (i32, i32) = (IRQ_BASE, IRQ_MAX); - let mmio_region: (u64, u64) = (0x0A00_0000, 0x1000_0000); - SysBus::new( - #[cfg(target_arch = "x86_64")] - &sys_io, - &sys_mem, - free_irqs, - mmio_region, - ) - } + use crate::sysbus::sysbus_init; fn pflash_dev_init(file_name: &str) -> Arc> { let sector_len: u32 = 0x40_000; @@ -973,32 +948,20 @@ mod test { fd.set_len(flash_size).unwrap(); drop(fd); - let fd = Some( + let fd = Some(Arc::new( std::fs::OpenOptions::new() .read(true) .write(true) .open(file_name) .unwrap(), - ); - let pflash = PFlash::new(flash_size, &fd, sector_len, 4, 2, read_only).unwrap(); + )); let sysbus = sysbus_init(); - let dev = Arc::new(Mutex::new(pflash)); - let region_ops = sysbus.build_region_ops(&dev); - let host_mmap = Arc::new( - HostMemMapping::new( - GuestAddress(flash_base), - None, - flash_size, - fd.map(FileBackend::new_common), - false, - true, - false, - ) - .unwrap(), - ); + let pflash = PFlash::new( + flash_size, fd, sector_len, 4, 2, read_only, &sysbus, flash_base, + ) + .unwrap(); + let dev = pflash.realize().unwrap(); - let rom_region = Region::init_rom_device_region(host_mmap, region_ops, "pflash-dev"); - dev.lock().unwrap().rom = Some(rom_region); dev } diff --git a/devices/src/legacy/pl011.rs b/devices/src/legacy/pl011.rs index 24a34ef8ebc2439282402430a71ebf802ccf73fc..465173ba8eea13e76f721e8fb7e4f7041f4b1579 100644 --- a/devices/src/legacy/pl011.rs +++ b/devices/src/legacy/pl011.rs @@ -13,12 +13,11 @@ use std::sync::{Arc, Mutex}; use anyhow::{Context, Result}; -use log::{debug, error}; -use vmm_sys_util::eventfd::EventFd; +use log::error; use super::error::LegacyError; -use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType, SysRes}; -use crate::{Device, DeviceBase}; +use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType}; +use crate::{convert_bus_mut, Device, DeviceBase, MUT_SYS_BUS}; use acpi::{ AmlActiveLevel, AmlBuilder, AmlDevice, AmlEdgeLevel, AmlExtendedInterrupt, AmlIntShare, AmlInteger, AmlMemory32Fixed, AmlNameDecl, AmlReadAndWrite, AmlResTemplate, AmlResourceUsage, @@ -26,17 +25,16 @@ use acpi::{ }; use address_space::GuestAddress; use chardev_backend::chardev::{Chardev, InputReceiver}; -use machine_manager::{ - config::{BootSource, Param, SerialConfig}, - event_loop::EventLoop, -}; +use machine_manager::config::SerialConfig; +use machine_manager::event_loop::EventLoop; use migration::{ snapshot::PL011_SNAPSHOT_ID, DeviceStateDesc, FieldDesc, MigrationError, MigrationHook, MigrationManager, StateTransfer, }; use migration_derive::{ByteCode, Desc}; use util::byte_code::ByteCode; -use util::loop_context::EventNotifierHelper; +use util::gen_base_func; +use util::loop_context::{create_new_eventfd, EventNotifierHelper}; use util::num_ops::read_data_u32; const PL011_FLAG_TXFE: u8 = 0x80; @@ -99,7 +97,7 @@ impl PL011State { fn new() -> Self { PL011State { rfifo: [0; PL011_FIFO_SIZE], - flags: (PL011_FLAG_TXFE | PL011_FLAG_RXFE) as u32, + flags: u32::from(PL011_FLAG_TXFE | PL011_FLAG_RXFE), lcr: 0, rsr: 0, cr: 0x300, @@ -131,17 +129,28 @@ pub struct PL011 { impl PL011 { /// Create a new `PL011` instance with default parameters. - pub fn new(cfg: SerialConfig) -> Result { - Ok(PL011 { + pub fn new( + cfg: SerialConfig, + sysbus: &Arc>, + region_base: u64, + region_size: u64, + ) -> Result { + let mut pl011 = PL011 { base: SysBusDevBase { dev_type: SysBusDevType::PL011, - interrupt_evt: Some(Arc::new(EventFd::new(libc::EFD_NONBLOCK)?)), + interrupt_evt: Some(Arc::new(create_new_eventfd()?)), ..Default::default() }, paused: false, state: PL011State::new(), chardev: Arc::new(Mutex::new(Chardev::new(cfg.chardev))), - }) + }; + pl011 + .set_sys_resource(sysbus, region_base, region_size, "PL011") + .with_context(|| "Failed to set system resource for PL011.")?; + pl011.set_parent_bus(sysbus.clone()); + + Ok(pl011) } fn interrupt(&mut self) { @@ -154,45 +163,6 @@ impl PL011 { } } - pub fn realize( - mut self, - sysbus: &mut SysBus, - region_base: u64, - region_size: u64, - bs: &Arc>, - ) -> Result<()> { - self.chardev - .lock() - .unwrap() - .realize() - .with_context(|| "Failed to realize chardev")?; - self.set_sys_resource(sysbus, region_base, region_size) - .with_context(|| "Failed to set system resource for PL011.")?; - - let dev = Arc::new(Mutex::new(self)); - sysbus - .attach_device(&dev, region_base, region_size, "PL011") - .with_context(|| "Failed to attach PL011 to system bus.")?; - - bs.lock().unwrap().kernel_cmdline.push(Param { - param_type: "earlycon".to_string(), - value: format!("pl011,mmio,0x{:08x}", region_base), - }); - MigrationManager::register_device_instance( - PL011State::descriptor(), - dev.clone(), - PL011_SNAPSHOT_ID, - ); - let locked_dev = dev.lock().unwrap(); - locked_dev.chardev.lock().unwrap().set_receiver(&dev); - EventLoop::update_event( - EventNotifierHelper::internal_notifiers(locked_dev.chardev.clone()), - None, - ) - .with_context(|| LegacyError::RegNotifierErr)?; - Ok(()) - } - fn unpause_rx(&mut self) { if self.paused { trace::pl011_unpause_rx(); @@ -204,20 +174,20 @@ impl PL011 { impl InputReceiver for PL011 { fn receive(&mut self, data: &[u8]) { - self.state.flags &= !PL011_FLAG_RXFE as u32; + self.state.flags &= u32::from(!PL011_FLAG_RXFE); for val in data { let mut slot = (self.state.read_pos + self.state.read_count) as usize; if slot >= PL011_FIFO_SIZE { slot -= PL011_FIFO_SIZE; } - self.state.rfifo[slot] = *val as u32; + self.state.rfifo[slot] = u32::from(*val); self.state.read_count += 1; trace::pl011_receive(self.state.rfifo[slot], self.state.read_count); } // If in character-mode, or in FIFO-mode and FIFO is full, trigger the interrupt. if ((self.state.lcr & 0x10) == 0) || (self.state.read_count as usize == PL011_FIFO_SIZE) { - self.state.flags |= PL011_FLAG_RXFF as u32; + self.state.flags |= u32::from(PL011_FLAG_RXFF); trace::pl011_receive_full(); } if self.state.read_count >= self.state.read_trigger { @@ -237,23 +207,40 @@ impl InputReceiver for PL011 { } impl Device for PL011 { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base + fn realize(self) -> Result>> { + self.chardev + .lock() + .unwrap() + .realize() + .with_context(|| "Failed to realize chardev")?; + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + MUT_SYS_BUS!(parent_bus, locked_bus, sysbus); + let dev = Arc::new(Mutex::new(self)); + sysbus + .attach_device(&dev) + .with_context(|| "Failed to attach PL011 to system bus.")?; + drop(locked_bus); + MigrationManager::register_device_instance( + PL011State::descriptor(), + dev.clone(), + PL011_SNAPSHOT_ID, + ); + let locked_dev = dev.lock().unwrap(); + locked_dev.chardev.lock().unwrap().set_receiver(&dev); + EventLoop::update_event( + EventNotifierHelper::internal_notifiers(locked_dev.chardev.clone()), + None, + ) + .with_context(|| LegacyError::RegNotifierErr)?; + drop(locked_dev); + Ok(dev) } } impl SysBusDevOps for PL011 { - fn sysbusdev_base(&self) -> &SysBusDevBase { - &self.base - } - - fn sysbusdev_base_mut(&mut self) -> &mut SysBusDevBase { - &mut self.base - } + gen_base_func!(sysbusdev_base, sysbusdev_base_mut, SysBusDevBase, base); fn read(&mut self, data: &mut [u8], _base: GuestAddress, offset: u64) -> bool { if data.len() > 4 { @@ -267,7 +254,7 @@ impl SysBusDevOps for PL011 { // Data register. self.unpause_rx(); - self.state.flags &= !(PL011_FLAG_RXFF as u32); + self.state.flags &= !u32::from(PL011_FLAG_RXFF); let c = self.state.rfifo[self.state.read_pos as usize]; if self.state.read_count > 0 { @@ -278,7 +265,7 @@ impl SysBusDevOps for PL011 { } } if self.state.read_count == 0 { - self.state.flags |= PL011_FLAG_RXFE as u32; + self.state.flags |= u32::from(PL011_FLAG_RXFE); } if self.state.read_count == self.state.read_trigger - 1 { self.state.int_level &= !INT_RX; @@ -330,7 +317,7 @@ impl SysBusDevOps for PL011 { 0x3f8..=0x400 => { // Register 0xFE0~0xFFC is UART Peripheral Identification Registers // and PrimeCell Identification Registers. - ret = *self.state.id.get(((offset - 0xfe0) >> 2) as usize).unwrap() as u32; + ret = u32::from(*self.state.id.get(((offset - 0xfe0) >> 2) as usize).unwrap()); } _ => { error!("Failed to read pl011: Invalid offset 0x{:x}", offset); @@ -353,20 +340,10 @@ impl SysBusDevOps for PL011 { match offset >> 2 { 0 => { let ch = value as u8; - - if let Some(output) = &mut self.chardev.lock().unwrap().output { - let mut locked_output = output.lock().unwrap(); - if let Err(e) = locked_output.write_all(&[ch]) { - debug!("Failed to write to pl011 output fd, error is {:?}", e); - } - if let Err(e) = locked_output.flush() { - debug!("Failed to flush pl011, error is {:?}", e); - } - } else { - debug!("Failed to get output fd"); + if let Err(e) = self.chardev.lock().unwrap().fill_outbuf(vec![ch], None) { + error!("Failed to append pl011 data to outbuf of chardev, {:?}", e); return false; } - self.state.int_level |= INT_TX; self.interrupt(); } @@ -425,10 +402,6 @@ impl SysBusDevOps for PL011 { true } - - fn get_sys_resource_mut(&mut self) -> Option<&mut SysRes> { - Some(&mut self.base.res) - } } impl StateTransfer for PL011 { @@ -481,18 +454,21 @@ impl AmlBuilder for PL011 { #[cfg(test)] mod test { use super::*; + use crate::sysbus::sysbus_init; use machine_manager::config::{ChardevConfig, ChardevType}; #[test] fn test_receive() { let chardev_cfg = ChardevConfig { - id: "chardev".to_string(), - backend: ChardevType::Stdio, + classtype: ChardevType::Stdio { + id: "chardev".to_string(), + }, }; - let mut pl011_dev = PL011::new(SerialConfig { + let config = SerialConfig { chardev: chardev_cfg, - }) - .unwrap(); + }; + let sysbus = sysbus_init(); + let mut pl011_dev = PL011::new(config, &sysbus, 0x0900_0000, 0x0000_1000).unwrap(); assert_eq!(pl011_dev.state.rfifo, [0; PL011_FIFO_SIZE]); assert_eq!(pl011_dev.state.flags, 0x90); assert_eq!(pl011_dev.state.lcr, 0); @@ -517,7 +493,7 @@ mod test { pl011_dev.receive(&data); assert_eq!(pl011_dev.state.read_count, data.len() as u32); for i in 0..data.len() { - assert_eq!(pl011_dev.state.rfifo[i], data[i] as u32); + assert_eq!(pl011_dev.state.rfifo[i], u32::from(data[i])); } assert_eq!(pl011_dev.state.flags, 0xC0); assert_eq!(pl011_dev.state.int_level, INT_RX); diff --git a/devices/src/legacy/pl031.rs b/devices/src/legacy/pl031.rs index a5790dd6d9c303d83d4b0be36c35129687e3e0f6..d6a9792205e0b0ddb0be20286250127bb8968eff 100644 --- a/devices/src/legacy/pl031.rs +++ b/devices/src/legacy/pl031.rs @@ -15,11 +15,10 @@ use std::time::{Instant, SystemTime, UNIX_EPOCH}; use anyhow::{Context, Result}; use byteorder::{ByteOrder, LittleEndian}; -use vmm_sys_util::eventfd::EventFd; use super::error::LegacyError; -use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType, SysRes}; -use crate::{Device, DeviceBase}; +use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType}; +use crate::{convert_bus_mut, Device, DeviceBase, MUT_SYS_BUS}; use acpi::AmlBuilder; use address_space::GuestAddress; use migration::{ @@ -28,6 +27,8 @@ use migration::{ }; use migration_derive::{ByteCode, Desc}; use util::byte_code::ByteCode; +use util::gen_base_func; +use util::loop_context::create_new_eventfd; use util::num_ops::write_data_u32; /// Registers for pl031 from ARM PrimeCell Real Time Clock Technical Reference Manual. @@ -78,9 +79,9 @@ pub struct PL031 { base_time: Instant, } -impl Default for PL031 { - fn default() -> Self { - Self { +impl PL031 { + pub fn new(sysbus: &Arc>, region_base: u64, region_size: u64) -> Result { + let mut pl031 = Self { base: SysBusDevBase::new(SysBusDevType::Rtc), state: PL031State::default(), // since 1970-01-01 00:00:00,it never cause overflow. @@ -89,57 +90,43 @@ impl Default for PL031 { .expect("time wrong") .as_secs() as u32, base_time: Instant::now(), - } - } -} - -impl PL031 { - pub fn realize( - mut self, - sysbus: &mut SysBus, - region_base: u64, - region_size: u64, - ) -> Result<()> { - self.base.interrupt_evt = Some(Arc::new(EventFd::new(libc::EFD_NONBLOCK)?)); - self.set_sys_resource(sysbus, region_base, region_size) + }; + pl031.base.interrupt_evt = Some(Arc::new(create_new_eventfd()?)); + pl031 + .set_sys_resource(sysbus, region_base, region_size, "PL031") .with_context(|| LegacyError::SetSysResErr)?; + pl031.set_parent_bus(sysbus.clone()); - let dev = Arc::new(Mutex::new(self)); - sysbus.attach_device(&dev, region_base, region_size, "PL031")?; - - MigrationManager::register_device_instance( - PL031State::descriptor(), - dev, - PL031_SNAPSHOT_ID, - ); - - Ok(()) + Ok(pl031) } /// Get current clock value. fn get_current_value(&self) -> u32 { - (self.base_time.elapsed().as_secs() as u128 + self.tick_offset as u128) as u32 + (u128::from(self.base_time.elapsed().as_secs()) + u128::from(self.tick_offset)) as u32 } } impl Device for PL031 { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); + + fn realize(self) -> Result>> { + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + MUT_SYS_BUS!(parent_bus, locked_bus, sysbus); + let dev = Arc::new(Mutex::new(self)); + sysbus.attach_device(&dev)?; - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base + MigrationManager::register_device_instance( + PL031State::descriptor(), + dev.clone(), + PL031_SNAPSHOT_ID, + ); + + Ok(dev) } } impl SysBusDevOps for PL031 { - fn sysbusdev_base(&self) -> &SysBusDevBase { - &self.base - } - - fn sysbusdev_base_mut(&mut self) -> &mut SysBusDevBase { - &mut self.base - } + gen_base_func!(sysbusdev_base, sysbusdev_base_mut, SysBusDevBase, base); /// Read data from registers by guest. fn read(&mut self, data: &mut [u8], _base: GuestAddress, offset: u64) -> bool { @@ -198,10 +185,6 @@ impl SysBusDevOps for PL031 { true } - - fn get_sys_resource_mut(&mut self) -> Option<&mut SysRes> { - Some(&mut self.base.res) - } } impl AmlBuilder for PL031 { @@ -234,13 +217,15 @@ impl MigrationHook for PL031 {} #[cfg(test)] mod test { use super::*; + use crate::sysbus::sysbus_init; use util::time::mktime64; const WIGGLE: u32 = 2; #[test] fn test_set_year_20xx() { - let mut rtc = PL031::default(); + let sysbus = sysbus_init(); + let mut rtc = PL031::new(&sysbus, 0x0901_0000, 0x0000_1000).unwrap(); // Set rtc time: 2013-11-13 02:04:56. let mut wtick = mktime64(2013, 11, 13, 2, 4, 56) as u32; let mut data = [0; 4]; @@ -266,7 +251,8 @@ mod test { #[test] fn test_set_year_1970() { - let mut rtc = PL031::default(); + let sysbus = sysbus_init(); + let mut rtc = PL031::new(&sysbus, 0x0901_0000, 0x0000_1000).unwrap(); // Set rtc time (min): 1970-01-01 00:00:00. let wtick = mktime64(1970, 1, 1, 0, 0, 0) as u32; let mut data = [0; 4]; diff --git a/devices/src/legacy/ramfb.rs b/devices/src/legacy/ramfb.rs index 470fcd796137a87bf05be065afe175061e4a315c..b0a1dc98b1f751ac04e43ea647c10c8d15b588fd 100644 --- a/devices/src/legacy/ramfb.rs +++ b/devices/src/legacy/ramfb.rs @@ -16,20 +16,23 @@ use std::sync::{Arc, Mutex, Weak}; use std::time::Duration; use anyhow::{Context, Result}; +use clap::{ArgAction, Parser}; use drm_fourcc::DrmFourcc; use log::error; use super::fwcfg::{FwCfgOps, FwCfgWriteCallback}; use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType}; -use crate::{Device, DeviceBase}; +use crate::{convert_bus_mut, Device, DeviceBase, MUT_SYS_BUS}; use acpi::AmlBuilder; -use address_space::{AddressSpace, GuestAddress}; +use address_space::{AddressAttr, AddressSpace, GuestAddress}; +use machine_manager::config::valid_id; use machine_manager::event_loop::EventLoop; use ui::console::{ console_init, display_graphic_update, display_replace_surface, ConsoleType, DisplayConsole, DisplaySurface, HardWareOperations, }; use ui::input::{key_event, KEYCODE_RET}; +use util::gen_base_func; use util::pixman::{pixman_format_bpp, pixman_format_code_t, pixman_image_create_bits}; const BYTES_PER_PIXELS: u32 = 8; @@ -39,6 +42,17 @@ const INSTALL_CHECK_INTERVEL_MS: u64 = 500; const INSTALL_RELEASE_INTERVEL_MS: u64 = 200; const INSTALL_PRESS_INTERVEL_MS: u64 = 100; +#[derive(Parser, Debug, Clone)] +#[command(no_binary_name(true))] +pub struct RamfbConfig { + #[arg(long, value_parser = ["ramfb"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] + pub id: String, + #[arg(long, default_value = "false", action = ArgAction::Append)] + pub install: bool, +} + #[repr(packed)] struct RamfbCfg { _addr: u64, @@ -107,13 +121,17 @@ impl RamfbState { } if stride == 0 { - let linesize = width * pixman_format_bpp(format as u32) as u32 / BYTES_PER_PIXELS; + let linesize = width * u32::from(pixman_format_bpp(format as u32)) / BYTES_PER_PIXELS; stride = linesize; } - let fb_addr = match self.sys_mem.addr_cache_init(GuestAddress(addr)) { + let fb_addr = match self + .sys_mem + .addr_cache_init(GuestAddress(addr), AddressAttr::Ram) + { Some((hva, len)) => { - if len < stride as u64 { + let sf_len = u64::from(stride) * u64::from(height); + if len < sf_len { error!("Insufficient contiguous memory length"); return; } @@ -230,38 +248,35 @@ pub struct Ramfb { } impl Ramfb { - pub fn new(sys_mem: Arc, install: bool) -> Self { - Ramfb { + pub fn new(sys_mem: Arc, sysbus: &Arc>, install: bool) -> Self { + let mut ramfb = Ramfb { base: SysBusDevBase::new(SysBusDevType::Ramfb), ramfb_state: RamfbState::new(sys_mem, install), - } - } - - pub fn realize(self, sysbus: &mut SysBus) -> Result<()> { - let dev = Arc::new(Mutex::new(self)); - sysbus.attach_dynamic_device(&dev)?; - Ok(()) + }; + ramfb.set_parent_bus(sysbus.clone()); + ramfb } } impl Device for Ramfb { - fn device_base(&self) -> &DeviceBase { - &self.base.base + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); + + fn reset(&mut self, _reset_child_device: bool) -> Result<()> { + self.ramfb_state.reset_ramfb_state(); + Ok(()) } - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base + fn realize(self) -> Result>> { + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + MUT_SYS_BUS!(parent_bus, locked_bus, sysbus); + let dev = Arc::new(Mutex::new(self)); + sysbus.attach_device(&dev)?; + Ok(dev) } } impl SysBusDevOps for Ramfb { - fn sysbusdev_base(&self) -> &SysBusDevBase { - &self.base - } - - fn sysbusdev_base_mut(&mut self) -> &mut SysBusDevBase { - &mut self.base - } + gen_base_func!(sysbusdev_base, sysbusdev_base_mut, SysBusDevBase, base); fn read(&mut self, _data: &mut [u8], _base: GuestAddress, _offset: u64) -> bool { error!("Ramfb can not be read!"); @@ -272,11 +287,6 @@ impl SysBusDevOps for Ramfb { error!("Ramfb can not be written!"); false } - - fn reset(&mut self) -> Result<()> { - self.ramfb_state.reset_ramfb_state(); - Ok(()) - } } impl AmlBuilder for Ramfb { @@ -317,3 +327,25 @@ fn set_press_event(install: Arc, data: *const u8) { install.store(false, Ordering::Release); } } + +#[cfg(test)] +mod tests { + use super::*; + use machine_manager::config::str_slip_to_clap; + + #[test] + fn test_ramfb_config_cmdline_parser() { + // Test1: install. + let ramfb_cmd1 = "ramfb,id=ramfb0,install=true"; + let ramfb_config = + RamfbConfig::try_parse_from(str_slip_to_clap(ramfb_cmd1, true, false)).unwrap(); + assert_eq!(ramfb_config.id, "ramfb0"); + assert_eq!(ramfb_config.install, true); + + // Test2: Default. + let ramfb_cmd2 = "ramfb,id=ramfb0"; + let ramfb_config = + RamfbConfig::try_parse_from(str_slip_to_clap(ramfb_cmd2, true, false)).unwrap(); + assert_eq!(ramfb_config.install, false); + } +} diff --git a/devices/src/legacy/rtc.rs b/devices/src/legacy/rtc.rs index 8334c2d19ba8216e3623eab17dc715dd4e761e06..4c324bf5e6c22acf464cb1032bb45f914632251e 100644 --- a/devices/src/legacy/rtc.rs +++ b/devices/src/legacy/rtc.rs @@ -15,15 +15,16 @@ use std::time::{Instant, SystemTime, UNIX_EPOCH}; use anyhow::Result; use log::{debug, error, warn}; -use vmm_sys_util::eventfd::EventFd; -use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType, SysRes}; -use crate::{Device, DeviceBase}; +use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType}; +use crate::{convert_bus_mut, Device, DeviceBase, MUT_SYS_BUS}; use acpi::{ AmlBuilder, AmlDevice, AmlEisaId, AmlIoDecode, AmlIoResource, AmlIrqNoFlags, AmlNameDecl, AmlResTemplate, AmlScopeBuilder, }; use address_space::GuestAddress; +use util::gen_base_func; +use util::loop_context::create_new_eventfd; use util::time::{mktime64, NANOSECONDS_PER_SECOND}; /// IO port of RTC device to select Register to read/write. @@ -94,7 +95,7 @@ fn bcd_to_bin(src: u8) -> u64 { return 0_u64; } - (((src >> 4) * 10) + (src & 0x0f)) as u64 + u64::from(((src >> 4) * 10) + (src & 0x0f)) } #[allow(clippy::upper_case_acronyms)] @@ -117,16 +118,11 @@ pub struct RTC { impl RTC { /// Construct function of RTC device. - pub fn new() -> Result { + pub fn new(sysbus: &Arc>) -> Result { let mut rtc = RTC { base: SysBusDevBase { dev_type: SysBusDevType::Rtc, - res: SysRes { - region_base: RTC_PORT_INDEX, - region_size: 8, - irq: -1, - }, - interrupt_evt: Some(Arc::new(EventFd::new(libc::EFD_NONBLOCK)?)), + interrupt_evt: Some(Arc::new(create_new_eventfd()?)), ..Default::default() }, cmos_data: [0_u8; 128], @@ -146,6 +142,9 @@ impl RTC { rtc.init_rtc_reg(); + rtc.set_sys_resource(sysbus, RTC_PORT_INDEX, 8, "RTC")?; + rtc.set_parent_bus(sysbus.clone()); + Ok(rtc) } @@ -266,19 +265,9 @@ impl RTC { true } - pub fn realize(mut self, sysbus: &mut SysBus) -> Result<()> { - let region_base = self.base.res.region_base; - let region_size = self.base.res.region_size; - self.set_sys_resource(sysbus, region_base, region_size)?; - - let dev = Arc::new(Mutex::new(self)); - sysbus.attach_device(&dev, region_base, region_size, "RTC")?; - Ok(()) - } - /// Get current clock value. fn get_current_value(&self) -> i64 { - (self.base_time.elapsed().as_secs() as i128 + self.tick_offset as i128) as i64 + (i128::from(self.base_time.elapsed().as_secs()) + i128::from(self.tick_offset)) as i64 } fn set_rtc_cmos(&mut self, tm: libc::tm) { @@ -332,7 +321,13 @@ impl RTC { + bcd_to_bin(self.cmos_data[RTC_CENTURY_BCD as usize]) * 100; // Check rtc time is valid to prevent tick_offset overflow. - if year < 1970 || !(1..=12).contains(&mon) || !(1..=31).contains(&day) { + if year < 1970 + || !(1..=12).contains(&mon) + || !(1..=31).contains(&day) + || !(0..=24).contains(&hour) + || !(0..=60).contains(&min) + || !(0..=60).contains(&sec) + { warn!( "RTC: the updated rtc time {}-{}-{} may be invalid.", year, mon, day @@ -351,23 +346,26 @@ impl RTC { } impl Device for RTC { - fn device_base(&self) -> &DeviceBase { - &self.base.base + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); + + fn reset(&mut self, _reset_child_device: bool) -> Result<()> { + self.cmos_data.fill(0); + self.init_rtc_reg(); + self.set_memory(self.mem_size, self.gap_start); + Ok(()) } - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base + fn realize(self) -> Result>> { + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + MUT_SYS_BUS!(parent_bus, locked_bus, sysbus); + let dev = Arc::new(Mutex::new(self)); + sysbus.attach_device(&dev)?; + Ok(dev) } } impl SysBusDevOps for RTC { - fn sysbusdev_base(&self) -> &SysBusDevBase { - &self.base - } - - fn sysbusdev_base_mut(&mut self) -> &mut SysBusDevBase { - &mut self.base - } + gen_base_func!(sysbusdev_base, sysbusdev_base_mut, SysBusDevBase, base); fn read(&mut self, data: &mut [u8], base: GuestAddress, offset: u64) -> bool { if offset == 0 { @@ -390,17 +388,6 @@ impl SysBusDevOps for RTC { self.write_data(data) } } - - fn get_sys_resource_mut(&mut self) -> Option<&mut SysRes> { - Some(&mut self.base.res) - } - - fn reset(&mut self) -> Result<()> { - self.cmos_data.fill(0); - self.init_rtc_reg(); - self.set_memory(self.mem_size, self.gap_start); - Ok(()) - } } impl AmlBuilder for RTC { @@ -428,6 +415,7 @@ mod test { use anyhow::Context; use super::*; + use crate::sysbus::sysbus_init; use address_space::GuestAddress; const WIGGLE: u8 = 2; @@ -448,7 +436,8 @@ mod test { #[test] fn test_set_year_20xx() -> Result<()> { - let mut rtc = RTC::new().with_context(|| "Failed to create RTC device")?; + let sysbus = sysbus_init(); + let mut rtc = RTC::new(&sysbus).with_context(|| "Failed to create RTC device")?; // Set rtc time: 2013-11-13 02:04:56 cmos_write(&mut rtc, RTC_CENTURY_BCD, 0x20); cmos_write(&mut rtc, RTC_YEAR, 0x13); @@ -482,7 +471,8 @@ mod test { #[test] fn test_set_year_1970() -> Result<()> { - let mut rtc = RTC::new().with_context(|| "Failed to create RTC device")?; + let sysbus = sysbus_init(); + let mut rtc = RTC::new(&sysbus).with_context(|| "Failed to create RTC device")?; // Set rtc time (min): 1970-01-01 00:00:00 cmos_write(&mut rtc, RTC_CENTURY_BCD, 0x19); cmos_write(&mut rtc, RTC_YEAR, 0x70); @@ -505,7 +495,8 @@ mod test { #[test] fn test_invalid_rtc_time() -> Result<()> { - let mut rtc = RTC::new().with_context(|| "Failed to create RTC device")?; + let sysbus = sysbus_init(); + let mut rtc = RTC::new(&sysbus).with_context(|| "Failed to create RTC device")?; // Set rtc year: 1969 cmos_write(&mut rtc, RTC_CENTURY_BCD, 0x19); cmos_write(&mut rtc, RTC_YEAR, 0x69); diff --git a/devices/src/legacy/serial.rs b/devices/src/legacy/serial.rs index 7c5b90729e29adbf11a4eb88745f57278c99e12e..e21112b3cf214dcfa7beb85c9e34c253af509b2b 100644 --- a/devices/src/legacy/serial.rs +++ b/devices/src/legacy/serial.rs @@ -15,11 +15,10 @@ use std::sync::{Arc, Mutex}; use anyhow::{bail, Context, Result}; use log::{debug, error}; -use vmm_sys_util::eventfd::EventFd; use super::error::LegacyError; -use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType, SysRes}; -use crate::{Device, DeviceBase}; +use crate::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType}; +use crate::{convert_bus_mut, Device, DeviceBase, MUT_SYS_BUS}; use acpi::{ AmlActiveLevel, AmlBuilder, AmlDevice, AmlEdgeLevel, AmlEisaId, AmlExtendedInterrupt, AmlIntShare, AmlInteger, AmlIoDecode, AmlIoResource, AmlNameDecl, AmlResTemplate, @@ -34,7 +33,8 @@ use migration::{ }; use migration_derive::{ByteCode, Desc}; use util::byte_code::ByteCode; -use util::loop_context::EventNotifierHelper; +use util::gen_base_func; +use util::loop_context::{create_new_eventfd, EventNotifierHelper}; pub const SERIAL_ADDR: u64 = 0x3f8; @@ -124,46 +124,25 @@ pub struct Serial { } impl Serial { - pub fn new(cfg: SerialConfig) -> Self { - Serial { + pub fn new( + cfg: SerialConfig, + sysbus: &Arc>, + region_base: u64, + region_size: u64, + ) -> Result { + let mut serial = Serial { base: SysBusDevBase::new(SysBusDevType::Serial), paused: false, rbr: VecDeque::new(), state: SerialState::new(), chardev: Arc::new(Mutex::new(Chardev::new(cfg.chardev))), - } - } - pub fn realize( - mut self, - sysbus: &mut SysBus, - region_base: u64, - region_size: u64, - ) -> Result<()> { - self.chardev - .lock() - .unwrap() - .realize() - .with_context(|| "Failed to realize chardev")?; - self.base.interrupt_evt = Some(Arc::new(EventFd::new(libc::EFD_NONBLOCK)?)); - self.set_sys_resource(sysbus, region_base, region_size) + }; + serial.base.interrupt_evt = Some(Arc::new(create_new_eventfd()?)); + serial + .set_sys_resource(sysbus, region_base, region_size, "Serial") .with_context(|| LegacyError::SetSysResErr)?; - - let dev = Arc::new(Mutex::new(self)); - sysbus.attach_device(&dev, region_base, region_size, "Serial")?; - - MigrationManager::register_device_instance( - SerialState::descriptor(), - dev.clone(), - SERIAL_SNAPSHOT_ID, - ); - let locked_dev = dev.lock().unwrap(); - locked_dev.chardev.lock().unwrap().set_receiver(&dev); - EventLoop::update_event( - EventNotifierHelper::internal_notifiers(locked_dev.chardev.clone()), - None, - ) - .with_context(|| LegacyError::RegNotifierErr)?; - Ok(()) + serial.set_parent_bus(sysbus.clone()); + Ok(serial) } fn unpause_rx(&mut self) { @@ -297,19 +276,10 @@ impl Serial { self.rbr.push_back(data); self.state.lsr |= UART_LSR_DR; - } else { - let output = self.chardev.lock().unwrap().output.clone(); - if output.is_none() { - self.update_iir(); - bail!("serial: failed to get output fd."); - } - let mut locked_output = output.as_ref().unwrap().lock().unwrap(); - locked_output - .write_all(&[data]) - .with_context(|| "serial: failed to write.")?; - locked_output - .flush() - .with_context(|| "serial: failed to flush.")?; + } else if let Err(e) = + self.chardev.lock().unwrap().fill_outbuf(vec![data], None) + { + bail!("Failed to append data to output buffer of chardev, {:?}", e); } self.update_iir(); @@ -381,23 +351,38 @@ impl InputReceiver for Serial { } impl Device for Serial { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); + + fn realize(self) -> Result>> { + self.chardev + .lock() + .unwrap() + .realize() + .with_context(|| "Failed to realize chardev")?; + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + MUT_SYS_BUS!(parent_bus, locked_bus, sysbus); + let dev = Arc::new(Mutex::new(self)); + sysbus.attach_device(&dev)?; - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base + MigrationManager::register_device_instance( + SerialState::descriptor(), + dev.clone(), + SERIAL_SNAPSHOT_ID, + ); + let locked_dev = dev.lock().unwrap(); + locked_dev.chardev.lock().unwrap().set_receiver(&dev); + EventLoop::update_event( + EventNotifierHelper::internal_notifiers(locked_dev.chardev.clone()), + None, + ) + .with_context(|| LegacyError::RegNotifierErr)?; + drop(locked_dev); + Ok(dev) } } impl SysBusDevOps for Serial { - fn sysbusdev_base(&self) -> &SysBusDevBase { - &self.base - } - - fn sysbusdev_base_mut(&mut self) -> &mut SysBusDevBase { - &mut self.base - } + gen_base_func!(sysbusdev_base, sysbusdev_base_mut, SysBusDevBase, base); fn read(&mut self, data: &mut [u8], _base: GuestAddress, offset: u64) -> bool { data[0] = self.read_internal(offset); @@ -416,10 +401,6 @@ impl SysBusDevOps for Serial { fn get_irq(&self, _sysbus: &mut SysBus) -> Result { Ok(UART_IRQ) } - - fn get_sys_resource_mut(&mut self) -> Option<&mut SysRes> { - Some(&mut self.base.res) - } } impl AmlBuilder for Serial { @@ -484,18 +465,22 @@ impl MigrationHook for Serial {} #[cfg(test)] mod test { use super::*; + use crate::sysbus::sysbus_init; use machine_manager::config::{ChardevConfig, ChardevType}; #[test] fn test_methods_of_serial() { // test new method let chardev_cfg = ChardevConfig { - id: "chardev".to_string(), - backend: ChardevType::Stdio, + classtype: ChardevType::Stdio { + id: "chardev".to_string(), + }, }; - let mut usart = Serial::new(SerialConfig { + let sysbus = sysbus_init(); + let config = SerialConfig { chardev: chardev_cfg.clone(), - }); + }; + let mut usart = Serial::new(config, &sysbus, SERIAL_ADDR, 8).unwrap(); assert_eq!(usart.state.ier, 0); assert_eq!(usart.state.iir, 1); assert_eq!(usart.state.lcr, 3); @@ -545,12 +530,15 @@ mod test { #[test] fn test_serial_migration_interface() { let chardev_cfg = ChardevConfig { - id: "chardev".to_string(), - backend: ChardevType::Stdio, + classtype: ChardevType::Stdio { + id: "chardev".to_string(), + }, }; - let mut usart = Serial::new(SerialConfig { + let config = SerialConfig { chardev: chardev_cfg, - }); + }; + let sysbus = sysbus_init(); + let mut usart = Serial::new(config, &sysbus, SERIAL_ADDR, 8).unwrap(); // Get state vector for usart let serial_state_result = usart.get_state_vec(); assert!(serial_state_result.is_ok()); diff --git a/devices/src/lib.rs b/devices/src/lib.rs index a155ff64306f1558c214c2b8b36528492ab66283..d335d1a7c2b933e0fbecc1dd899cb2c3986fb3dd 100644 --- a/devices/src/lib.rs +++ b/devices/src/lib.rs @@ -39,25 +39,53 @@ pub use legacy::error::LegacyError as LegacyErrs; pub use scsi::bus as ScsiBus; pub use scsi::disk as ScsiDisk; +use std::any::Any; +use std::any::TypeId; +use std::collections::BTreeMap; +use std::sync::{Arc, Mutex, Weak}; + +use anyhow::{bail, Context, Result}; +use util::AsAny; + #[derive(Clone, Default)] pub struct DeviceBase { /// Name of this device pub id: String, /// Whether it supports hot-plug/hot-unplug. pub hotpluggable: bool, + /// parent bus. + pub parent: Option>>, + /// Child bus. + pub child: Option>>, } impl DeviceBase { - pub fn new(id: String, hotpluggable: bool) -> Self { - DeviceBase { id, hotpluggable } + pub fn new(id: String, hotpluggable: bool, parent: Option>>) -> Self { + DeviceBase { + id, + hotpluggable, + parent, + child: None, + } } } -pub trait Device { +pub trait Device: Any + AsAny + Send + Sync { fn device_base(&self) -> &DeviceBase; fn device_base_mut(&mut self) -> &mut DeviceBase; + /// `Any` trait requires a `'static` lifecycle. Error "argument requires that `device` is borrowed for `'static`" + /// will be reported when using `as_any` directly for local variables which don't have `'static` lifecycle. + /// Encapsulation of `as_any` can solve this problem. + fn device_as_any(&mut self) -> &mut dyn Any { + self.as_any_mut() + } + + fn device_type_id(&self) -> TypeId { + self.type_id() + } + /// Get device name. fn name(&self) -> String { self.device_base().id.clone() @@ -67,4 +95,235 @@ pub trait Device { fn hotpluggable(&self) -> bool { self.device_base().hotpluggable } + + /// Get the bus which this device is mounted on. + fn parent_bus(&self) -> Option>> { + self.device_base().parent.clone() + } + + fn set_parent_bus(&mut self, bus: Arc>) { + self.device_base_mut().parent = Some(Arc::downgrade(&bus)); + } + + /// Get the bus which this device has. + fn child_bus(&self) -> Option>> { + self.device_base().child.clone() + } + + fn reset(&mut self, _reset_child_device: bool) -> Result<()> { + Ok(()) + } + + /// Realize device. + fn realize(self) -> Result>> + where + Self: Sized, + { + // Note: Only PciHost does not have its own realization logic, + // but it will not be called. + bail!("Realize of the device {} is not implemented", self.name()); + } + + /// Unrealize device. + fn unrealize(&mut self) -> Result<()> { + bail!("Unrealize of the device {} is not implemented", self.name()); + } +} + +/// Macro `convert_device_ref!`: Convert from Arc> to &$device_type. +/// +/// # Arguments +/// +/// * `$trait_device` - Variable defined as Arc>. +/// * `$lock_device` - Variable used to get MutexGuard<'_, dyn Device>. +/// * `$struct_device` - Variable used to get &$device_type. +/// * `$device_type` - Struct corresponding to device type. +#[macro_export] +macro_rules! convert_device_ref { + ($trait_device:expr, $lock_device: ident, $struct_device: ident, $device_type: ident) => { + let mut $lock_device = $trait_device.lock().unwrap(); + let $struct_device = $lock_device + .device_as_any() + .downcast_ref::<$device_type>() + .unwrap(); + }; +} + +/// Macro `convert_device_mut!`: Convert from Arc> to &mut $device_type. +/// +/// # Arguments +/// +/// * `$trait_device` - Variable defined as Arc>. +/// * `$lock_device` - Variable used to get MutexGuard<'_, dyn Device>. +/// * `$struct_device` - Variable used to get &mut $device_type. +/// * `$device_type` - Struct corresponding to device type. +#[macro_export] +macro_rules! convert_device_mut { + ($trait_device:expr, $lock_device: ident, $struct_device: ident, $device_type: ident) => { + let mut $lock_device = $trait_device.lock().unwrap(); + let $struct_device = $lock_device + .device_as_any() + .downcast_mut::<$device_type>() + .unwrap(); + }; +} + +#[derive(Default)] +pub struct BusBase { + /// Name of this bus. + pub name: String, + /// Parent device. + pub parent: Option>>, + /// Children devices. + /// + /// Note: + /// 1. The construction of FDT table needs to strictly follow the order of sysbus, + /// so `BTreemap` needs to be used. + /// 2. every device has a unique address on the bus. Using `u64` is sufficient for we can + /// convert it to u8(devfn) for PCI bus and convert it to (u8, u16)(target, lun) for SCSI bus. + /// SysBus doesn't need this unique `u64` address, so we will incrementally fill in a useless number. + pub children: BTreeMap>>, +} + +impl BusBase { + fn new(name: String) -> BusBase { + Self { + name, + ..Default::default() + } + } +} + +pub trait Bus: Any + AsAny + Send + Sync { + fn bus_base(&self) -> &BusBase; + + fn bus_base_mut(&mut self) -> &mut BusBase; + + /// `Any` trait requires a `'static` lifecycle. Error "argument requires that `bus` is borrowed for `'static`" + /// will be reported when using `as_any` directly for local variables which don't have `'static` lifecycle. + /// Encapsulation of `as_any` can solve this problem. + fn bus_as_any(&mut self) -> &mut dyn Any { + self.as_any_mut() + } + + /// Get the name of this bus. + fn name(&self) -> String { + self.bus_base().name.clone() + } + + /// Get the device that owns this bus. + fn parent_device(&self) -> Option>> { + self.bus_base().parent.clone() + } + + /// Get the devices mounted on this bus. + fn child_devices(&self) -> BTreeMap>> { + self.bus_base().children.clone() + } + + /// Get the specific device mounted on this bus. + fn child_dev(&self, key: u64) -> Option<&Arc>> { + self.bus_base().children.get(&key) + } + + /// Attach device to this bus. + fn attach_child(&mut self, key: u64, dev: Arc>) -> Result<()> { + let children = &mut self.bus_base_mut().children; + if children.get(&key).is_some() { + bail!( + "Location of the device {} is same as one of the bus {}", + dev.lock().unwrap().name(), + self.name() + ); + } + children.insert(key, dev); + + Ok(()) + } + + /// Detach device from this bus. + fn detach_child(&mut self, key: u64) -> Result<()> { + self.bus_base_mut() + .children + .remove(&key) + .with_context(|| format!("No such device using key {} in bus {}.", key, self.name()))?; + + Ok(()) + } + + /// Bus reset means that all devices attached to this bus should reset. + fn reset(&self) -> Result<()> { + for dev in self.child_devices().values() { + let mut locked_dev = dev.lock().unwrap(); + locked_dev + .reset(true) + .with_context(|| format!("Failed to reset device {}", locked_dev.name()))?; + } + + Ok(()) + } +} + +/// Macro `convert_bus_ref!`: Convert from Arc> to &$bus_type. +/// +/// # Arguments +/// +/// * `$trait_bus` - Variable defined as Arc>. +/// * `$lock_bus` - Variable used to get MutexGuard<'_, dyn Bus>. +/// * `$struct_bus` - Variable used to get &$bus_type. +/// * `$bus_type` - Struct corresponding to bus type. +#[macro_export] +macro_rules! convert_bus_ref { + ($trait_bus:expr, $lock_bus: ident, $struct_bus: ident, $bus_type: ident) => { + let mut $lock_bus = $trait_bus.lock().unwrap(); + let $struct_bus = $lock_bus.bus_as_any().downcast_ref::<$bus_type>().unwrap(); + }; +} + +/// Macro `convert_bus_mut!`: Convert from Arc> to &mut $bus_type. +/// +/// # Arguments +/// +/// * `$trait_bus` - Variable defined as Arc>. +/// * `$lock_bus` - Variable used to get MutexGuard<'_, dyn Bus>. +/// * `$struct_bus` - Variable used to get &mut $bus_type. +/// * `$bus_type` - Struct corresponding to bus type. +#[macro_export] +macro_rules! convert_bus_mut { + ($trait_bus:expr, $lock_bus: ident, $struct_bus: ident, $bus_type: ident) => { + let mut $lock_bus = $trait_bus.lock().unwrap(); + let $struct_bus = $lock_bus.bus_as_any().downcast_mut::<$bus_type>().unwrap(); + }; +} + +#[cfg(test)] +pub mod test { + use std::sync::Arc; + + use address_space::{AddressSpace, GuestAddress, HostMemMapping, Region}; + + pub fn address_space_init() -> Arc { + let root = Region::init_container_region(1 << 36, "root"); + let sys_space = AddressSpace::new(root, "sys_space", None).unwrap(); + let host_mmap = Arc::new( + HostMemMapping::new( + GuestAddress(0), + None, + 0x1000_0000, + None, + false, + false, + false, + ) + .unwrap(), + ); + sys_space + .root() + .add_subregion( + Region::init_ram_region(host_mmap.clone(), "region_1"), + host_mmap.start_address().raw_value(), + ) + .unwrap(); + sys_space + } } diff --git a/devices/src/misc/ivshmem.rs b/devices/src/misc/ivshmem.rs index e5c7cd74fb888d2bd821304c61b2cdb1472211b2..86b92381757430a0ffe321a822787fcbeb0ec6d6 100644 --- a/devices/src/misc/ivshmem.rs +++ b/devices/src/misc/ivshmem.rs @@ -11,21 +11,22 @@ // See the Mulan PSL v2 for more details. use std::sync::{ - atomic::{AtomicU16, Ordering}, - Arc, Mutex, Weak, + atomic::{AtomicBool, AtomicU16, Ordering}, + Arc, Mutex, RwLock, Weak, }; -use anyhow::{bail, Result}; +use anyhow::Result; +use log::error; -use crate::pci::{ - config::{ - PciConfig, RegionType, DEVICE_ID, PCI_CLASS_MEMORY_RAM, PCI_CONFIG_SPACE_SIZE, - PCI_VENDOR_ID_REDHAT_QUMRANET, REVISION_ID, SUB_CLASS_CODE, VENDOR_ID, - }, - le_write_u16, PciBus, PciDevBase, PciDevOps, +use crate::pci::config::{ + PciConfig, RegionType, DEVICE_ID, PCI_CLASS_MEMORY_RAM, PCI_CONFIG_SPACE_SIZE, + PCI_VENDOR_ID_REDHAT_QUMRANET, REVISION_ID, SUB_CLASS_CODE, VENDOR_ID, }; -use crate::{Device, DeviceBase}; +use crate::pci::msix::init_msix; +use crate::pci::{le_write_u16, PciBus, PciDevBase, PciDevOps}; +use crate::{convert_bus_ref, Bus, Device, DeviceBase, PCI_BUS}; use address_space::{GuestAddress, Region, RegionOps}; +use util::gen_base_func; const PCI_VENDOR_ID_IVSHMEM: u16 = PCI_VENDOR_ID_REDHAT_QUMRANET; const PCI_DEVICE_ID_IVSHMEM: u16 = 0x1110; @@ -35,51 +36,116 @@ const PCI_BAR_MAX_IVSHMEM: u8 = 3; const IVSHMEM_REG_BAR_SIZE: u64 = 0x100; +const IVSHMEM_BAR0_IRQ_MASK: u64 = 0; +const IVSHMEM_BAR0_IRQ_STATUS: u64 = 4; +const IVSHMEM_BAR0_IVPOSITION: u64 = 8; +const IVSHMEM_BAR0_DOORBELL: u64 = 12; + +type Bar0Write = dyn Fn(&[u8], u64) -> bool + Send + Sync; +type Bar0Read = dyn Fn(&mut [u8], u64) -> bool + Send + Sync; + +#[derive(Default)] +struct Bar0Ops { + write: Option>, + read: Option>, +} + /// Intel-VM shared memory device structure. pub struct Ivshmem { base: PciDevBase, dev_id: Arc, ram_mem_region: Region, + vector_nr: u32, + bar0_ops: Arc>, + reset_cb: Option>, } impl Ivshmem { pub fn new( name: String, devfn: u8, - parent_bus: Weak>, + parent_bus: Weak>, ram_mem_region: Region, + vector_nr: u32, ) -> Self { Self { base: PciDevBase { - base: DeviceBase::new(name, false), - config: PciConfig::new(PCI_CONFIG_SPACE_SIZE, PCI_BAR_MAX_IVSHMEM), + base: DeviceBase::new(name, false, Some(parent_bus)), + config: PciConfig::new(devfn, PCI_CONFIG_SPACE_SIZE, PCI_BAR_MAX_IVSHMEM), devfn, - parent_bus, + bme: Arc::new(AtomicBool::new(false)), }, dev_id: Arc::new(AtomicU16::new(0)), ram_mem_region, + vector_nr, + bar0_ops: Arc::new(RwLock::new(Bar0Ops::default())), + reset_cb: None, } } fn register_bars(&mut self) -> Result<()> { - // Currently, ivshmem uses only the shared memory and does not use interrupt. - // Therefore, bar0 read and write callback is not implemented. - let reg_read = move |_: &mut [u8], _: GuestAddress, _: u64| -> bool { true }; - let reg_write = move |_: &[u8], _: GuestAddress, _: u64| -> bool { true }; + // Currently, ivshmem does not support intx interrupt, ivposition and doorbell. + let bar0_ops = self.bar0_ops.clone(); + let reg_read = move |data: &mut [u8], _: GuestAddress, offset: u64| -> bool { + if offset >= IVSHMEM_REG_BAR_SIZE { + error!("ivshmem: read offset {} exceeds bar0 size", offset); + return true; + } + match offset { + IVSHMEM_BAR0_IRQ_MASK | IVSHMEM_BAR0_IRQ_STATUS | IVSHMEM_BAR0_IVPOSITION => {} + _ => { + if let Some(rcb) = bar0_ops.read().unwrap().read.as_ref() { + return rcb(data, offset); + } + } + } + true + }; + let bar0_ops = self.bar0_ops.clone(); + let reg_write = move |data: &[u8], _: GuestAddress, offset: u64| -> bool { + if offset >= IVSHMEM_REG_BAR_SIZE { + error!("ivshmem: write offset {} exceeds bar0 size", offset); + return true; + } + match offset { + IVSHMEM_BAR0_IRQ_MASK | IVSHMEM_BAR0_IRQ_STATUS | IVSHMEM_BAR0_DOORBELL => {} + _ => { + if let Some(wcb) = bar0_ops.read().unwrap().write.as_ref() { + return wcb(data, offset); + } + } + } + true + }; let reg_region_ops = RegionOps { read: Arc::new(reg_read), write: Arc::new(reg_write), }; // bar0: mmio register + let mut bar0_region = + Region::init_io_region(IVSHMEM_REG_BAR_SIZE, reg_region_ops, "IvshmemIo"); + bar0_region.set_access_size(4); self.base.config.register_bar( 0, - Region::init_io_region(IVSHMEM_REG_BAR_SIZE, reg_region_ops, "IvshmemIo"), - RegionType::Mem64Bit, + bar0_region, + RegionType::Mem32Bit, false, IVSHMEM_REG_BAR_SIZE, )?; + // bar1: msix + if self.vector_nr > 0 { + init_msix( + &mut self.base, + 1, + self.vector_nr, + self.dev_id.clone(), + None, + None, + )?; + } + // bar2: ram self.base.config.register_bar( 2, @@ -89,28 +155,32 @@ impl Ivshmem { self.ram_mem_region.size(), ) } -} -impl Device for Ivshmem { - fn device_base(&self) -> &DeviceBase { - &self.base.base + pub fn trigger_msix(&self, vector: u16) { + if self.vector_nr == 0 { + return; + } + if let Some(msix) = self.base.config.msix.as_ref() { + msix.lock() + .unwrap() + .notify(vector, self.dev_id.load(Ordering::Acquire)); + } } - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base + pub fn set_bar0_ops(&mut self, bar0_ops: (Arc, Arc)) { + self.bar0_ops.write().unwrap().write = Some(bar0_ops.0); + self.bar0_ops.write().unwrap().read = Some(bar0_ops.1); } -} -impl PciDevOps for Ivshmem { - fn pci_base(&self) -> &PciDevBase { - &self.base + pub fn register_reset_callback(&mut self, cb: Box) { + self.reset_cb = Some(cb); } +} - fn pci_base_mut(&mut self) -> &mut PciDevBase { - &mut self.base - } +impl Device for Ivshmem { + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); - fn realize(mut self) -> Result<()> { + fn realize(mut self) -> Result>> { self.init_write_mask(false)?; self.init_write_clear_mask(false)?; le_write_u16( @@ -134,33 +204,37 @@ impl PciDevOps for Ivshmem { self.register_bars()?; // Attach to the PCI bus. - let pci_bus = self.base.parent_bus.upgrade().unwrap(); - let mut locked_pci_bus = pci_bus.lock().unwrap(); - let pci_device = locked_pci_bus.devices.get(&self.base.devfn); - match pci_device { - Some(device) => bail!( - "Devfn {:?} has been used by {:?}", - &self.base.devfn, - device.lock().unwrap().name() - ), - None => locked_pci_bus - .devices - .insert(self.base.devfn, Arc::new(Mutex::new(self))), - }; + let bus = self.parent_bus().unwrap().upgrade().unwrap(); + PCI_BUS!(bus, locked_bus, pci_bus); + self.dev_id + .store(pci_bus.generate_dev_id(self.base.devfn), Ordering::Release); + let dev = Arc::new(Mutex::new(self)); + locked_bus.attach_child(u64::from(dev.lock().unwrap().base.devfn), dev.clone())?; + Ok(dev) + } + + fn reset(&mut self, _reset_child_device: bool) -> Result<()> { + if let Some(cb) = &self.reset_cb { + cb(); + } Ok(()) } +} + +impl PciDevOps for Ivshmem { + gen_base_func!(pci_base, pci_base_mut, PciDevBase, base); fn write_config(&mut self, offset: usize, data: &[u8]) { - let parent_bus = self.base.parent_bus.upgrade().unwrap(); - let locked_parent_bus = parent_bus.lock().unwrap(); + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + PCI_BUS!(parent_bus, locked_bus, pci_bus); self.base.config.write( offset, data, self.dev_id.load(Ordering::Acquire), #[cfg(target_arch = "x86_64")] - Some(&locked_parent_bus.io_region), - Some(&locked_parent_bus.mem_region), + Some(&pci_bus.io_region), + Some(&pci_bus.mem_region), ); } } diff --git a/devices/src/misc/mod.rs b/devices/src/misc/mod.rs index 36c2d9c5b73f8d930149e841e0b3b1d87c340b88..0e0c015a4a70e58c01414969100122fbcb707292 100644 --- a/devices/src/misc/mod.rs +++ b/devices/src/misc/mod.rs @@ -14,7 +14,7 @@ pub mod scream; #[cfg(feature = "scream")] -mod ivshmem; +pub mod ivshmem; #[cfg(feature = "pvpanic")] pub mod pvpanic; diff --git a/devices/src/misc/pvpanic.rs b/devices/src/misc/pvpanic.rs index e2d29dd12609b566754c0e23b016daecf270c248..08cbebd02d28a14d81bcabcd63fdaac6ad8d7a77 100644 --- a/devices/src/misc/pvpanic.rs +++ b/devices/src/misc/pvpanic.rs @@ -11,25 +11,26 @@ // See the Mulan PSL v2 for more details. use std::sync::{ - atomic::{AtomicU16, Ordering}, + atomic::{AtomicBool, AtomicU16, Ordering}, Arc, Mutex, Weak, }; use anyhow::{bail, Context, Result}; +use clap::Parser; use log::{debug, error, info}; +use serde::{Deserialize, Serialize}; -use crate::pci::{ - config::{ - PciConfig, RegionType, CLASS_PI, DEVICE_ID, HEADER_TYPE, PCI_CLASS_SYSTEM_OTHER, - PCI_CONFIG_SPACE_SIZE, PCI_DEVICE_ID_REDHAT_PVPANIC, PCI_SUBDEVICE_ID_QEMU, - PCI_VENDOR_ID_REDHAT, PCI_VENDOR_ID_REDHAT_QUMRANET, REVISION_ID, SUBSYSTEM_ID, - SUBSYSTEM_VENDOR_ID, SUB_CLASS_CODE, VENDOR_ID, - }, - le_write_u16, PciBus, PciDevBase, PciDevOps, +use crate::pci::config::{ + PciConfig, RegionType, CLASS_PI, DEVICE_ID, HEADER_TYPE, PCI_CLASS_SYSTEM_OTHER, + PCI_CONFIG_SPACE_SIZE, PCI_DEVICE_ID_REDHAT_PVPANIC, PCI_SUBDEVICE_ID_QEMU, + PCI_VENDOR_ID_REDHAT, PCI_VENDOR_ID_REDHAT_QUMRANET, REVISION_ID, SUBSYSTEM_ID, + SUBSYSTEM_VENDOR_ID, SUB_CLASS_CODE, VENDOR_ID, }; -use crate::{Device, DeviceBase}; +use crate::pci::{le_write_u16, PciBus, PciDevBase, PciDevOps}; +use crate::{convert_bus_mut, convert_bus_ref, Bus, Device, DeviceBase, MUT_PCI_BUS, PCI_BUS}; use address_space::{GuestAddress, Region, RegionOps}; -use machine_manager::config::{PvpanicDevConfig, PVPANIC_CRASHLOADED, PVPANIC_PANICKED}; +use machine_manager::config::{get_pci_df, valid_id}; +use util::gen_base_func; const PVPANIC_PCI_REVISION_ID: u8 = 1; const PVPANIC_PCI_VENDOR_ID: u16 = PCI_VENDOR_ID_REDHAT_QUMRANET; @@ -40,6 +41,33 @@ const PVPANIC_REG_BAR_SIZE: u64 = 0x4; #[cfg(target_arch = "x86_64")] const PVPANIC_REG_BAR_SIZE: u64 = 0x1; +pub const PVPANIC_PANICKED: u32 = 1 << 0; +pub const PVPANIC_CRASHLOADED: u32 = 1 << 1; + +#[derive(Parser, Debug, Clone, Serialize, Deserialize)] +#[command(no_binary_name(true))] +pub struct PvpanicDevConfig { + #[arg(long, value_parser = ["pvpanic"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] + pub id: String, + #[arg(long)] + pub bus: String, + #[arg(long, value_parser = get_pci_df)] + pub addr: (u8, u8), + #[arg(long, alias = "supported-features", default_value = "3", value_parser = valid_supported_features)] + pub supported_features: u32, +} + +fn valid_supported_features(f: &str) -> Result { + let features = f.parse::()?; + let supported_features = match features & !(PVPANIC_PANICKED | PVPANIC_CRASHLOADED) { + 0 => features, + _ => bail!("Unsupported pvpanic device features {}", features), + }; + Ok(supported_features) +} + #[derive(Copy, Clone)] pub struct PvPanicState { supported_features: u32, @@ -59,12 +87,14 @@ impl PvPanicState { if (event & PVPANIC_PANICKED) == PVPANIC_PANICKED && (self.supported_features & PVPANIC_PANICKED) == PVPANIC_PANICKED { + hisysevent::STRATOVIRT_PVPANIC("PANICKED".to_string()); info!("pvpanic: panicked event"); } if (event & PVPANIC_CRASHLOADED) == PVPANIC_CRASHLOADED && (self.supported_features & PVPANIC_CRASHLOADED) == PVPANIC_CRASHLOADED { + hisysevent::STRATOVIRT_PVPANIC("CRASHLOADED".to_string()); info!("pvpanic: crashloaded event"); } @@ -79,13 +109,13 @@ pub struct PvPanicPci { } impl PvPanicPci { - pub fn new(config: &PvpanicDevConfig, devfn: u8, parent_bus: Weak>) -> Self { + pub fn new(config: &PvpanicDevConfig, devfn: u8, parent_bus: Weak>) -> Self { Self { base: PciDevBase { - base: DeviceBase::new(config.id.clone(), false), - config: PciConfig::new(PCI_CONFIG_SPACE_SIZE, 1), + base: DeviceBase::new(config.id.clone(), false, Some(parent_bus)), + config: PciConfig::new(devfn, PCI_CONFIG_SPACE_SIZE, 1), devfn, - parent_bus, + bme: Arc::new(AtomicBool::new(false)), }, dev_id: AtomicU16::new(0), pvpanic: Arc::new(PvPanicState::new(config.supported_features)), @@ -114,7 +144,7 @@ impl PvPanicPci { } }); - matches!(cloned_pvpanic_write.handle_event(val as u32), Ok(())) + matches!(cloned_pvpanic_write.handle_event(u32::from(val)), Ok(())) }); let bar0_region_ops = RegionOps { @@ -137,25 +167,9 @@ impl PvPanicPci { } impl Device for PvPanicPci { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } - - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base - } -} + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); -impl PciDevOps for PvPanicPci { - fn pci_base(&self) -> &PciDevBase { - &self.base - } - - fn pci_base_mut(&mut self) -> &mut PciDevBase { - &mut self.base - } - - fn realize(mut self) -> Result<()> { + fn realize(mut self) -> Result>> { self.init_write_mask(false)?; self.init_write_clear_mask(false)?; le_write_u16( @@ -200,42 +214,37 @@ impl PciDevOps for PvPanicPci { // Attach to the PCI bus. let devfn = self.base.devfn; let dev = Arc::new(Mutex::new(self)); - let pci_bus = dev.lock().unwrap().base.parent_bus.upgrade().unwrap(); - let mut locked_pci_bus = pci_bus.lock().unwrap(); - let device_id = locked_pci_bus.generate_dev_id(devfn); + let bus = dev.lock().unwrap().parent_bus().unwrap().upgrade().unwrap(); + MUT_PCI_BUS!(bus, locked_bus, pci_bus); + let device_id = pci_bus.generate_dev_id(devfn); dev.lock() .unwrap() .dev_id .store(device_id, Ordering::Release); - let pci_device = locked_pci_bus.devices.get(&devfn); - if pci_device.is_none() { - locked_pci_bus.devices.insert(devfn, dev); - } else { - bail!( - "pvpanic: Devfn {:?} has been used by {:?}", - &devfn, - pci_device.unwrap().lock().unwrap().name() - ); - } + locked_bus.attach_child(u64::from(devfn), dev.clone())?; - Ok(()) + Ok(dev) } fn unrealize(&mut self) -> Result<()> { Ok(()) } +} + +impl PciDevOps for PvPanicPci { + gen_base_func!(pci_base, pci_base_mut, PciDevBase, base); fn write_config(&mut self, offset: usize, data: &[u8]) { - let parent_bus = self.base.parent_bus.upgrade().unwrap(); - let locked_parent_bus = parent_bus.lock().unwrap(); + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + PCI_BUS!(parent_bus, locked_bus, pci_bus); self.base.config.write( offset, data, self.dev_id.load(Ordering::Acquire), #[cfg(target_arch = "x86_64")] - Some(&locked_parent_bus.io_region), - Some(&locked_parent_bus.mem_region), + Some(&pci_bus.io_region), + Some(&pci_bus.mem_region), ); } } @@ -244,113 +253,118 @@ impl PciDevOps for PvPanicPci { mod tests { use super::*; use crate::pci::{host::tests::create_pci_host, le_read_u16, PciHost}; + use crate::{convert_bus_ref, convert_device_mut, PCI_BUS}; + use machine_manager::config::str_slip_to_clap; + + /// Convert from Arc> to &mut PvPanicPci. + #[macro_export] + macro_rules! MUT_PVPANIC_PCI { + ($trait_device:expr, $lock_device: ident, $struct_device: ident) => { + convert_device_mut!($trait_device, $lock_device, $struct_device, PvPanicPci); + }; + } fn init_pvpanic_dev(devfn: u8, supported_features: u32, dev_id: &str) -> Arc> { let pci_host = create_pci_host(); let locked_pci_host = pci_host.lock().unwrap(); - let root_bus = Arc::downgrade(&locked_pci_host.root_bus); + let root_bus = Arc::downgrade(&locked_pci_host.child_bus().unwrap()); let config = PvpanicDevConfig { id: dev_id.to_string(), supported_features, + classtype: "".to_string(), + bus: "pcie.0".to_string(), + addr: (3, 0), }; - let pvpanic_dev = PvPanicPci::new(&config, devfn, root_bus.clone()); + let pvpanic_dev = PvPanicPci::new(&config, devfn, root_bus); assert_eq!(pvpanic_dev.base.base.id, "pvpanic_test".to_string()); pvpanic_dev.realize().unwrap(); - drop(root_bus); drop(locked_pci_host); pci_host } + #[test] + fn test_pvpanic_cmdline_parser() { + // Test1: Right. + let cmdline = "pvpanic,id=pvpanic0,bus=pcie.0,addr=0x7,supported-features=0"; + let result = PvpanicDevConfig::try_parse_from(str_slip_to_clap(cmdline, true, false)); + assert_eq!(result.unwrap().supported_features, 0); + + // Test2: Default value. + let cmdline = "pvpanic,id=pvpanic0,bus=pcie.0,addr=0x7"; + let result = PvpanicDevConfig::try_parse_from(str_slip_to_clap(cmdline, true, false)); + assert_eq!(result.unwrap().supported_features, 3); + + // Test3: Illegal value. + let cmdline = "pvpanic,id=pvpanic0,bus=pcie.0,addr=0x7,supported-features=4"; + let result = PvpanicDevConfig::try_parse_from(str_slip_to_clap(cmdline, true, false)); + assert!(result.is_err()); + } + #[test] fn test_pvpanic_attached() { let pci_host = init_pvpanic_dev(7, PVPANIC_PANICKED | PVPANIC_CRASHLOADED, "pvpanic_test"); - let locked_pci_host = pci_host.lock().unwrap(); - let root_bus = Arc::downgrade(&locked_pci_host.root_bus); - - let pvpanic_dev = root_bus.upgrade().unwrap().lock().unwrap().get_device(0, 7); + let root_bus = pci_host.lock().unwrap().child_bus().unwrap(); + PCI_BUS!(root_bus, locked_bus, pci_bus); + let pvpanic_dev = pci_bus.get_device(0, 7); + drop(locked_bus); assert!(pvpanic_dev.is_some()); assert_eq!( - pvpanic_dev.unwrap().lock().unwrap().pci_base().base.id, + pvpanic_dev.unwrap().lock().unwrap().name(), "pvpanic_test".to_string() ); - let info = PciBus::find_attached_bus(&locked_pci_host.root_bus, "pvpanic_test"); + let info = PciBus::find_attached_bus(&root_bus, "pvpanic_test"); assert!(info.is_some()); let (bus, dev) = info.unwrap(); - assert_eq!(bus.lock().unwrap().name, "pcie.0"); + assert_eq!(bus.lock().unwrap().name(), "pcie.0"); assert_eq!(dev.lock().unwrap().name(), "pvpanic_test"); } #[test] fn test_pvpanic_config() { let pci_host = init_pvpanic_dev(7, PVPANIC_PANICKED | PVPANIC_CRASHLOADED, "pvpanic_test"); - let locked_pci_host = pci_host.lock().unwrap(); - let root_bus = Arc::downgrade(&locked_pci_host.root_bus); - - let pvpanic_dev = root_bus - .upgrade() - .unwrap() - .lock() - .unwrap() - .get_device(0, 7) - .unwrap(); - - let info = le_read_u16( - &pvpanic_dev.lock().unwrap().pci_base_mut().config.config, - VENDOR_ID as usize, - ) - .unwrap_or_else(|_| 0); + let root_bus = pci_host.lock().unwrap().child_bus().unwrap(); + PCI_BUS!(root_bus, locked_bus, pci_bus); + let pvpanic_dev = pci_bus.get_device(0, 7).unwrap(); + MUT_PVPANIC_PCI!(pvpanic_dev, locked_dev, pvpanic); + let info = le_read_u16(&pvpanic.pci_base_mut().config.config, VENDOR_ID as usize) + .unwrap_or_else(|_| 0); assert_eq!(info, PCI_VENDOR_ID_REDHAT); - let info = le_read_u16( - &pvpanic_dev.lock().unwrap().pci_base_mut().config.config, - DEVICE_ID as usize, - ) - .unwrap_or_else(|_| 0); + let info = le_read_u16(&pvpanic.pci_base_mut().config.config, DEVICE_ID as usize) + .unwrap_or_else(|_| 0); assert_eq!(info, PCI_DEVICE_ID_REDHAT_PVPANIC); let info = le_read_u16( - &pvpanic_dev.lock().unwrap().pci_base_mut().config.config, + &pvpanic.pci_base_mut().config.config, SUB_CLASS_CODE as usize, ) .unwrap_or_else(|_| 0); assert_eq!(info, PCI_CLASS_SYSTEM_OTHER); - let info = le_read_u16( - &pvpanic_dev.lock().unwrap().pci_base_mut().config.config, - SUBSYSTEM_VENDOR_ID, - ) - .unwrap_or_else(|_| 0); + let info = le_read_u16(&pvpanic.pci_base_mut().config.config, SUBSYSTEM_VENDOR_ID) + .unwrap_or_else(|_| 0); assert_eq!(info, PVPANIC_PCI_VENDOR_ID); - let info = le_read_u16( - &pvpanic_dev.lock().unwrap().pci_base_mut().config.config, - SUBSYSTEM_ID, - ) - .unwrap_or_else(|_| 0); + let info = + le_read_u16(&pvpanic.pci_base_mut().config.config, SUBSYSTEM_ID).unwrap_or_else(|_| 0); assert_eq!(info, PCI_SUBDEVICE_ID_QEMU); } #[test] fn test_pvpanic_read_features() { let pci_host = init_pvpanic_dev(7, PVPANIC_PANICKED | PVPANIC_CRASHLOADED, "pvpanic_test"); - let locked_pci_host = pci_host.lock().unwrap(); - let root_bus = Arc::downgrade(&locked_pci_host.root_bus); - - let pvpanic_dev = root_bus - .upgrade() - .unwrap() - .lock() - .unwrap() - .get_device(0, 7) - .unwrap(); + let root_bus = pci_host.lock().unwrap().child_bus().unwrap(); + PCI_BUS!(root_bus, locked_bus, pci_bus); + let pvpanic_dev = pci_bus.get_device(0, 7).unwrap(); + MUT_PVPANIC_PCI!(pvpanic_dev, locked_dev, pvpanic); // test read supported_features let mut data_read = [0xffu8; 1]; - let result = &pvpanic_dev.lock().unwrap().pci_base_mut().config.bars[0] + let result = &pvpanic.pci_base_mut().config.bars[0] .region .as_ref() .unwrap() @@ -365,21 +379,15 @@ mod tests { #[test] fn test_pvpanic_write_panicked() { let pci_host = init_pvpanic_dev(7, PVPANIC_PANICKED | PVPANIC_CRASHLOADED, "pvpanic_test"); - let locked_pci_host = pci_host.lock().unwrap(); - let root_bus = Arc::downgrade(&locked_pci_host.root_bus); - - let pvpanic_dev = root_bus - .upgrade() - .unwrap() - .lock() - .unwrap() - .get_device(0, 7) - .unwrap(); + let root_bus = pci_host.lock().unwrap().child_bus().unwrap(); + PCI_BUS!(root_bus, locked_bus, pci_bus); + let pvpanic_dev = pci_bus.get_device(0, 7).unwrap(); + MUT_PVPANIC_PCI!(pvpanic_dev, locked_dev, pvpanic); // test write panicked event let data_write = [PVPANIC_PANICKED as u8; 1]; let count = data_write.len() as u64; - let result = &pvpanic_dev.lock().unwrap().pci_base_mut().config.bars[0] + let result = &pvpanic.pci_base_mut().config.bars[0] .region .as_ref() .unwrap() @@ -390,21 +398,15 @@ mod tests { #[test] fn test_pvpanic_write_crashload() { let pci_host = init_pvpanic_dev(7, PVPANIC_PANICKED | PVPANIC_CRASHLOADED, "pvpanic_test"); - let locked_pci_host = pci_host.lock().unwrap(); - let root_bus = Arc::downgrade(&locked_pci_host.root_bus); - - let pvpanic_dev = root_bus - .upgrade() - .unwrap() - .lock() - .unwrap() - .get_device(0, 7) - .unwrap(); + let root_bus = pci_host.lock().unwrap().child_bus().unwrap(); + PCI_BUS!(root_bus, locked_bus, pci_bus); + let pvpanic_dev = pci_bus.get_device(0, 7).unwrap(); + MUT_PVPANIC_PCI!(pvpanic_dev, locked_dev, pvpanic); // test write crashload event let data_write = [PVPANIC_CRASHLOADED as u8; 1]; let count = data_write.len() as u64; - let result = &pvpanic_dev.lock().unwrap().pci_base_mut().config.bars[0] + let result = &pvpanic.pci_base_mut().config.bars[0] .region .as_ref() .unwrap() @@ -416,20 +418,15 @@ mod tests { fn test_pvpanic_write_unknown() { let pci_host = init_pvpanic_dev(7, PVPANIC_PANICKED | PVPANIC_CRASHLOADED, "pvpanic_test"); let locked_pci_host = pci_host.lock().unwrap(); - let root_bus = Arc::downgrade(&locked_pci_host.root_bus); - - let pvpanic_dev = root_bus - .upgrade() - .unwrap() - .lock() - .unwrap() - .get_device(0, 7) - .unwrap(); + let root_bus = locked_pci_host.child_bus().unwrap(); + PCI_BUS!(root_bus, locked_bus, pci_bus); + let pvpanic_dev = pci_bus.get_device(0, 7).unwrap(); + MUT_PVPANIC_PCI!(pvpanic_dev, locked_dev, pvpanic); // test write unknown event let data_write = [100u8; 1]; let count = data_write.len() as u64; - let result = &pvpanic_dev.lock().unwrap().pci_base_mut().config.bars[0] + let result = &pvpanic.pci_base_mut().config.bars[0] .region .as_ref() .unwrap() diff --git a/devices/src/misc/scream/alsa.rs b/devices/src/misc/scream/alsa.rs index e3ca324cb22e8bd79170fd395a8a7d2a91439e75..f6a8689eaa6e0d0dca190ceea534b53ffe138faf 100644 --- a/devices/src/misc/scream/alsa.rs +++ b/devices/src/misc/scream/alsa.rs @@ -24,8 +24,8 @@ use anyhow::Result; use log::{debug, error, warn}; use super::{ - AudioInterface, ScreamDirection, ShmemStreamFmt, StreamData, AUDIO_SAMPLE_RATE_44KHZ, - TARGET_LATENCY_MS, + AudioInterface, AudioStatus, ScreamDirection, ShmemStreamFmt, StreamData, + AUDIO_SAMPLE_RATE_44KHZ, TARGET_LATENCY_MS, }; const MAX_CHANNELS: u8 = 8; @@ -81,7 +81,7 @@ impl AlsaStreamData { hwp.set_rate_resample(true)?; hwp.set_access(Access::RWInterleaved)?; hwp.set_format(self.format)?; - hwp.set_channels(channels as u32)?; + hwp.set_channels(u32::from(channels))?; hwp.set_rate(self.rate, ValueOr::Nearest)?; // Set the latency in microseconds. hwp.set_buffer_time_near(self.latency * 1000, ValueOr::Nearest)?; @@ -168,7 +168,7 @@ impl AudioInterface for AlsaStreamData { return; } - let mut frames = 0; + let mut frames = 0_u32; let mut io = self.pcm.as_ref().unwrap().io_bytes(); // Make sure audio read does not bypass chunk_idx read. @@ -184,12 +184,14 @@ impl AudioInterface for AlsaStreamData { }; let samples = - recv_data.audio_size / (self.bytes_per_sample * recv_data.fmt.channels as u32); + recv_data.audio_size / (self.bytes_per_sample * u32::from(recv_data.fmt.channels)); while frames < samples { let send_frame_num = min(samples - frames, MAX_FRAME_NUM); - let offset = (frames * self.bytes_per_sample * recv_data.fmt.channels as u32) as usize; + let offset = + (frames * self.bytes_per_sample * u32::from(recv_data.fmt.channels)) as usize; let end = offset - + (send_frame_num * self.bytes_per_sample * recv_data.fmt.channels as u32) as usize; + + (send_frame_num * self.bytes_per_sample * u32::from(recv_data.fmt.channels)) + as usize; match io.write(&data[offset..end]) { Err(e) => { debug!("Failed to write data to ALSA buffer: {:?}", e); @@ -203,7 +205,8 @@ impl AudioInterface for AlsaStreamData { } Ok(n) => { trace::scream_alsa_send_frames(frames, offset, end); - frames += n as u32 / (self.bytes_per_sample * recv_data.fmt.channels as u32); + frames += + n as u32 / (self.bytes_per_sample * u32::from(recv_data.fmt.channels)); } } } @@ -215,7 +218,7 @@ impl AudioInterface for AlsaStreamData { return 0; } - let mut frames = 0; + let mut frames = 0_u32; let mut io = self.pcm.as_ref().unwrap().io_bytes(); // Make sure audio read does not bypass chunk_idx read. @@ -231,11 +234,12 @@ impl AudioInterface for AlsaStreamData { }; let samples = - recv_data.audio_size / (self.bytes_per_sample * recv_data.fmt.channels as u32); + recv_data.audio_size / (self.bytes_per_sample * u32::from(recv_data.fmt.channels)); while frames < samples { - let offset = (frames * self.bytes_per_sample * recv_data.fmt.channels as u32) as usize; + let offset = + (frames * self.bytes_per_sample * u32::from(recv_data.fmt.channels)) as usize; let end = offset - + ((samples - frames) * self.bytes_per_sample * recv_data.fmt.channels as u32) + + ((samples - frames) * self.bytes_per_sample * u32::from(recv_data.fmt.channels)) as usize; match io.read(&mut data[offset..end]) { Err(e) => { @@ -250,7 +254,8 @@ impl AudioInterface for AlsaStreamData { } Ok(n) => { trace::scream_alsa_receive_frames(frames, offset, end); - frames += n as u32 / (self.bytes_per_sample * recv_data.fmt.channels as u32); + frames += + n as u32 / (self.bytes_per_sample * u32::from(recv_data.fmt.channels)); // During the host headset switchover, io.read is blocked for a long time. // As a result, the VM recording delay exceeds 1s. Thereforce, check whether @@ -259,7 +264,7 @@ impl AudioInterface for AlsaStreamData { warn!("Scream alsa can't get frames delay: {e:?}"); 0 }); - if delay > self.rate as i64 >> 1 { + if delay > i64::from(self.rate) >> 1 { warn!("Scream alsa read audio blocked too long, delay {delay} frames, init again!"); self.init = false; } @@ -283,4 +288,12 @@ impl AudioInterface for AlsaStreamData { self.init = false; } + + fn get_status(&self) -> AudioStatus { + if self.init { + AudioStatus::Started + } else { + AudioStatus::Ready + } + } } diff --git a/devices/src/misc/scream/audio_demo.rs b/devices/src/misc/scream/audio_demo.rs index d22b7bc325374c119f8b6badb238e9022006c651..c709c21644d32d19e799c9d797ad581f3f47db16 100644 --- a/devices/src/misc/scream/audio_demo.rs +++ b/devices/src/misc/scream/audio_demo.rs @@ -10,6 +10,7 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. +use std::sync::{Arc, Mutex, RwLock}; use std::{ fs::{File, OpenOptions}, io::{Read, Write}, @@ -19,7 +20,50 @@ use std::{ use core::time; use log::error; -use super::{AudioInterface, ScreamDirection, StreamData}; +use super::{AudioExtension, AudioInterface, AudioStatus, ScreamDirection, StreamData}; +use crate::misc::ivshmem::Ivshmem; + +pub const INITIAL_VOLUME_VAL: u32 = 0xaa; +const IVSHMEM_VOLUME_SYNC_VECTOR: u16 = 0; + +pub struct DemoAudioVolume { + shm_dev: Arc>, + vol: RwLock, +} + +// SAFETY: all fields are protected by lock +unsafe impl Send for DemoAudioVolume {} +// SAFETY: all fields are protected by lock +unsafe impl Sync for DemoAudioVolume {} + +impl AudioExtension for DemoAudioVolume { + fn get_host_volume(&self) -> u32 { + *self.vol.read().unwrap() + } + + fn set_host_volume(&self, vol: u32) { + *self.vol.write().unwrap() = vol; + } +} + +impl DemoAudioVolume { + pub fn new(shm_dev: Arc>) -> Arc { + let vol = Arc::new(Self { + shm_dev, + vol: RwLock::new(0), + }); + vol.notify(INITIAL_VOLUME_VAL); + vol + } + + fn notify(&self, vol: u32) { + *self.vol.write().unwrap() = vol; + self.shm_dev + .lock() + .unwrap() + .trigger_msix(IVSHMEM_VOLUME_SYNC_VECTOR); + } +} pub struct AudioDemo { file: File, @@ -86,4 +130,8 @@ impl AudioInterface for AudioDemo { } fn destroy(&mut self) {} + + fn get_status(&self) -> AudioStatus { + AudioStatus::Started + } } diff --git a/devices/src/misc/scream/mod.rs b/devices/src/misc/scream/mod.rs index 13659f4b1f3e8cf4968cfa4e0c01cab4e46f9527..9dc1b3966407b8def4024064bc31ea24006d8031 100644 --- a/devices/src/misc/scream/mod.rs +++ b/devices/src/misc/scream/mod.rs @@ -12,56 +12,101 @@ #[cfg(feature = "scream_alsa")] mod alsa; -mod audio_demo; +pub mod audio_demo; #[cfg(all(target_env = "ohos", feature = "scream_ohaudio"))] mod ohaudio; #[cfg(feature = "scream_pulseaudio")] mod pulseaudio; -use std::{ - mem, - str::FromStr, - sync::{ - atomic::{fence, Ordering}, - Arc, Mutex, RwLock, Weak, - }, - thread, -}; +use std::str::FromStr; +use std::sync::atomic::{fence, Ordering}; +use std::sync::{Arc, Condvar, Mutex, RwLock, Weak}; +use std::{mem, thread}; use anyhow::{anyhow, bail, Context, Result}; use clap::{ArgAction, Parser}; use core::time; -use log::{error, warn}; +use log::{error, info, warn}; +use once_cell::sync::Lazy; #[cfg(feature = "scream_alsa")] use self::alsa::AlsaStreamData; -use self::audio_demo::AudioDemo; +use self::audio_demo::{AudioDemo, DemoAudioVolume}; use super::ivshmem::Ivshmem; -use crate::pci::{PciBus, PciDevOps}; +use crate::pci::{le_read_u32, le_write_u32}; +use crate::{Bus, Device}; use address_space::{GuestAddress, HostMemMapping, Region}; use machine_manager::config::{get_pci_df, parse_bool, valid_id}; +use machine_manager::notifier::register_vm_pause_notifier; +use machine_manager::state_query::register_state_query_callback; #[cfg(all(target_env = "ohos", feature = "scream_ohaudio"))] -use ohaudio::OhAudio; +use ohaudio::{OhAudio, OhAudioVolume}; #[cfg(feature = "scream_pulseaudio")] use pulseaudio::PulseStreamData; #[cfg(all(target_env = "ohos", feature = "scream_ohaudio"))] -use util::ohos_binding::misc::{get_firstcaller_tokenid, set_firstcaller_tokenid}; +use util::ohos_binding::misc::bound_tokenid; pub const AUDIO_SAMPLE_RATE_44KHZ: u32 = 44100; pub const AUDIO_SAMPLE_RATE_48KHZ: u32 = 48000; pub const WINDOWS_SAMPLE_BASE_RATE: u8 = 128; -pub const TARGET_LATENCY_MS: u32 = 50; +pub const TARGET_LATENCY_MS: u32 = 20; -// A frame of back-end audio data is 50ms, and the next frame of audio data needs -// to be trained in polling within 50ms. Theoretically, the shorter the polling time, +#[cfg(all(target_env = "ohos", feature = "scream_ohaudio"))] +const IVSHMEM_VOLUME_SYNC_VECTOR: u16 = 0; +const IVSHMEM_STATUS_CHANGE_VECTOR: u16 = 1; +const IVSHMEM_VECTORS_NR: u32 = 2; +pub const IVSHMEM_BAR0_VOLUME: u64 = 240; +pub const IVSHMEM_BAR0_STATUS: u64 = 244; + +pub const STATUS_PLAY_BIT: u32 = 0x1; +pub const STATUS_START_BIT: u32 = 0x2; +const STATUS_MIC_AVAIL_BIT: u32 = 0x4; + +// A frame of back-end audio data is 20ms, and the next frame of audio data needs +// to be trained in polling within 20ms. Theoretically, the shorter the polling time, // the better. However, if the value is too small, the overhead is high. So take a -// compromise: 50 * 1000 / 8 us. +// compromise: 20 * 1000 / 8 us. const POLL_DELAY_US: u64 = (TARGET_LATENCY_MS as u64) * 1000 / 8; pub const SCREAM_MAGIC: u64 = 0x02032023; +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd)] +pub enum AudioStatus { + // Processor is ready and waiting for play/capture. + #[default] + Ready, + // Processor is started and doing job. + Started, + // OH audio framework error occurred. + Error, + // OH audio stream is interrupted. + Intr, + // OH audio stream interruption ends. + IntrResume, +} + +type AuthorityNotify = dyn Fn() + Send + Sync; + +#[derive(Clone)] +pub struct AuthorityInformation { + state: bool, + notify: Option>, +} + +impl AuthorityInformation { + const fn default() -> AuthorityInformation { + AuthorityInformation { + state: true, + notify: None, + } + } +} + +type AuthInfo = RwLock; +static AUTH_INFO: Lazy = Lazy::new(|| RwLock::new(AuthorityInformation::default())); + /// The scream device defines the audio directions. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum ScreamDirection { @@ -88,25 +133,24 @@ pub struct ShmemStreamHeader { pub fmt: ShmemStreamFmt, } -fn record_authority_rw(auth: bool, write: bool) -> bool { - static AUTH: RwLock = RwLock::new(true); - if write { - *AUTH.write().unwrap() = auth; +pub fn set_record_authority(auth: bool) { + AUTH_INFO.write().unwrap().state = auth; + if let Some(auth_notify) = &AUTH_INFO.read().unwrap().notify { + auth_notify(); } - *AUTH.read().unwrap() } -pub fn set_record_authority(auth: bool) { - record_authority_rw(auth, true); +pub fn set_authority_notify(notify: Option>) { + AUTH_INFO.write().unwrap().notify = notify; } -fn get_record_authority() -> bool { - record_authority_rw(false, false) +pub fn get_record_authority() -> bool { + AUTH_INFO.read().unwrap().state } impl ShmemStreamHeader { pub fn check(&self, last_end: u64) -> bool { - if (self.offset as u64) < last_end { + if u64::from(self.offset) < last_end { warn!( "Guest set bad offset {} exceeds last stream buffer end {}", self.offset, last_end @@ -115,11 +159,12 @@ impl ShmemStreamHeader { if self.chunk_idx > self.max_chunks { error!( - "The chunk index of stream {} exceeds the maximum number of chunks {}", - self.chunk_idx, self.max_chunks + "Invalid max_chunks: {} or chunk_idx: {}", + self.max_chunks, self.chunk_idx ); return false; } + if self.fmt.channels == 0 || self.fmt.channel_map == 0 { error!( "The fmt channels {} or channel_map {} is invalid", @@ -181,34 +226,118 @@ impl ShmemStreamFmt { } else { AUDIO_SAMPLE_RATE_48KHZ }; - sample_rate * (self.rate % WINDOWS_SAMPLE_BASE_RATE) as u32 + sample_rate * u32::from(self.rate % WINDOWS_SAMPLE_BASE_RATE) + } +} + +struct ScreamCond { + cond: Condvar, + paused: Mutex, +} + +impl ScreamCond { + const STREAM_PAUSE_BIT: u8 = 0x1; + const VM_PAUSE_BIT: u8 = 0x2; + + fn new() -> Arc { + Arc::new(Self { + cond: Condvar::default(), + paused: Mutex::new(Self::STREAM_PAUSE_BIT), + }) + } + + fn wait_if_paused(&self, interface: Arc>) { + let mut destroy_thread_handle = None; + let mut locked_pause = self.paused.lock().unwrap(); + while *locked_pause != 0 { + if destroy_thread_handle.is_none() { + let cloned_interface = interface.clone(); + destroy_thread_handle = Some( + thread::Builder::new() + .name("scream destroy".to_string()) + .spawn(move || { + cloned_interface.lock().unwrap().destroy(); + }) + .unwrap(), + ); + } + locked_pause = self.cond.wait(locked_pause).unwrap(); + } + drop(locked_pause); + if let Some(handle) = destroy_thread_handle { + if let Err(e) = handle.join() { + error!("failed to join destroy thread, {:?}", e); + } + } + } + + fn set_value(&self, bv: u8, set: bool) { + let mut locked_pause = self.paused.lock().unwrap(); + let old_val = *locked_pause; + match set { + true => *locked_pause = old_val | bv, + false => *locked_pause = old_val & !bv, + } + if *locked_pause == 0 { + self.cond.notify_all(); + } + } + + fn set_vm_pause(&self, paused: bool) { + self.set_value(Self::VM_PAUSE_BIT, paused); + } + + fn set_stream_pause(&self, paused: bool) { + self.set_value(Self::STREAM_PAUSE_BIT, paused); + } + + fn stream_paused(&self) -> bool { + *self.paused.lock().unwrap() != 0 } } /// Audio stream data structure. -#[derive(Default)] +#[derive(Debug, Default)] pub struct StreamData { pub fmt: ShmemStreamFmt, + max_chunks: u16, chunk_idx: u16, + /// Start address of header implies. + start_addr: u64, + /// Length of total data which header implies. + data_shm_len: u64, /// Size of the data to be played or recorded. pub audio_size: u32, /// Location of the played or recorded audio data in the shared memory. pub audio_base: u64, + /// VM pause notifier id. + pause_notifier_id: u64, } impl StreamData { - fn init(&mut self, header: &ShmemStreamHeader) { + fn init(&mut self, header: &ShmemStreamHeader, hva: u64) { fence(Ordering::Acquire); self.fmt = header.fmt; self.chunk_idx = header.chunk_idx; + self.max_chunks = header.max_chunks; + self.data_shm_len = u64::from(header.chunk_size) * u64::from(self.max_chunks); + self.start_addr = hva + u64::from(header.offset); + self.audio_size = header.chunk_size; + } + + fn register_pause_notifier(&mut self, cond: Arc) { + let pause_notify = Arc::new(move |paused: bool| { + cond.set_vm_pause(paused); + }); + self.pause_notifier_id = register_vm_pause_notifier(pause_notify); } fn wait_for_ready( &mut self, interface: Arc>, dir: ScreamDirection, - poll_delay_us: u64, hva: u64, + cond: Arc, ) { // SAFETY: hva is the shared memory base address. It already verifies the validity // of the address range during the scream realize. @@ -218,43 +347,38 @@ impl StreamData { ScreamDirection::Playback => &header.play, ScreamDirection::Record => &header.capt, }; - trace::scream_init(&dir, &stream_header); loop { - if header.magic != SCREAM_MAGIC || stream_header.is_started == 0 { + let mut locked_paused = cond.paused.lock().unwrap(); + while *locked_paused != 0 { interface.lock().unwrap().destroy(); - while header.magic != SCREAM_MAGIC || stream_header.is_started == 0 { - thread::sleep(time::Duration::from_millis(10)); - header = - // SAFETY: hva is allocated by libc:::mmap, it can be guaranteed to be legal. - &unsafe { std::slice::from_raw_parts(hva as *const ShmemHeader, 1) }[0]; - } - self.init(stream_header); + locked_paused = cond.cond.wait(locked_paused).unwrap(); } - // Audio playback requires waiting for the guest to play audio data. - if dir == ScreamDirection::Playback && self.chunk_idx == stream_header.chunk_idx { - thread::sleep(time::Duration::from_micros(poll_delay_us)); + if header.magic != SCREAM_MAGIC || stream_header.is_started == 0 { + *locked_paused |= ScreamCond::STREAM_PAUSE_BIT; continue; } - let mut last_end = 0; + header = + // SAFETY: hva is allocated by libc:::mmap, it can be guaranteed to be legal. + &unsafe { std::slice::from_raw_parts(hva as *const ShmemHeader, 1) }[0]; + self.init(stream_header, hva); + + let mut last_end = 0_u64; // The recording buffer is behind the playback buffer. Thereforce, the end position of // the playback buffer must be calculted to determine whether the two buffers overlap. if dir == ScreamDirection::Record && header.play.is_started != 0 { - last_end = header.play.offset as u64 - + header.play.chunk_size as u64 * header.play.max_chunks as u64; + last_end = u64::from(header.play.offset) + + u64::from(header.play.chunk_size) * u64::from(header.play.max_chunks); } if !stream_header.check(last_end) { + *locked_paused |= ScreamCond::STREAM_PAUSE_BIT; continue; } - // Guest reformats the audio, and the scream device also needs to be init. - if self.fmt != stream_header.fmt { - self.init(stream_header); - continue; - } + trace::scream_init(&dir, &stream_header); return; } @@ -266,18 +390,14 @@ impl StreamData { shmem_size: u64, stream_header: &ShmemStreamHeader, ) -> bool { - self.audio_size = stream_header.chunk_size; - self.audio_base = hva - + stream_header.offset as u64 - + (stream_header.chunk_size as u64) * (self.chunk_idx as u64); - - if (self.audio_base + self.audio_size as u64) > (hva + shmem_size) { + self.audio_base = self + .start_addr + .saturating_add(u64::from(self.audio_size) * u64::from(self.chunk_idx)); + let buf_end = hva + shmem_size; + if self.audio_base.saturating_add(u64::from(self.audio_size)) > buf_end { error!( "Scream: wrong header: offset {} chunk_idx {} chunk_size {} max_chunks {}", - stream_header.offset, - stream_header.chunk_idx, - stream_header.chunk_size, - stream_header.max_chunks, + stream_header.offset, stream_header.chunk_idx, self.audio_size, self.max_chunks, ); return false; } @@ -289,6 +409,7 @@ impl StreamData { hva: u64, shmem_size: u64, interface: Arc>, + cond: Arc, ) { // SAFETY: hva is the shared memory base address. It already verifies the validity // of the address range during the header check. @@ -296,6 +417,8 @@ impl StreamData { let play = &header.play; loop { + cond.wait_if_paused(interface.clone()); + if play.fmt.fmt_generation != self.fmt.fmt_generation { break; } @@ -308,15 +431,15 @@ impl StreamData { // slow and the backward data is skipped. if play .chunk_idx - .wrapping_add(play.max_chunks) + .wrapping_add(self.max_chunks) .wrapping_sub(self.chunk_idx) - % play.max_chunks + % self.max_chunks > 4 { self.chunk_idx = - play.chunk_idx.wrapping_add(play.max_chunks).wrapping_sub(1) % play.max_chunks; + play.chunk_idx.wrapping_add(self.max_chunks).wrapping_sub(1) % self.max_chunks; } else { - self.chunk_idx = (self.chunk_idx + 1) % play.max_chunks; + self.chunk_idx = (self.chunk_idx + 1) % self.max_chunks; } if !self.update_buffer_by_chunk_idx(hva, shmem_size, play) { @@ -331,48 +454,53 @@ impl StreamData { hva: u64, shmem_size: u64, interface: Arc>, + cond: Arc, ) { // SAFETY: hva is the shared memory base address. It already verifies the validity // of the address range during the header check. let header = &mut unsafe { std::slice::from_raw_parts_mut(hva as *mut ShmemHeader, 1) }[0]; let capt = &mut header.capt; - let addr = hva + capt.offset as u64; - let mut locked_interface = interface.lock().unwrap(); - locked_interface.pre_receive(addr, capt); while capt.is_started != 0 { + cond.wait_if_paused(interface.clone()); + + if capt.fmt.fmt_generation != self.fmt.fmt_generation { + return; + } + if !self.update_buffer_by_chunk_idx(hva, shmem_size, capt) { return; } let recv_chunks_cnt: i32 = if get_record_authority() { - locked_interface.receive(self) + interface.lock().unwrap().receive(self) } else { - locked_interface.destroy(); + interface.lock().unwrap().destroy(); 0 }; - if recv_chunks_cnt > 0 { - self.chunk_idx = (self.chunk_idx + recv_chunks_cnt as u16) % capt.max_chunks; - - // Make sure chunk_idx write does not bypass audio chunk write. - fence(Ordering::SeqCst); - capt.chunk_idx = self.chunk_idx; + match recv_chunks_cnt.cmp(&0) { + std::cmp::Ordering::Less => thread::sleep(time::Duration::from_millis(100)), + std::cmp::Ordering::Greater => { + self.chunk_idx = match (self.chunk_idx + recv_chunks_cnt as u16) + .checked_rem(capt.max_chunks) + { + Some(idx) => idx, + None => { + warn!("Scream: capture header might be cleared by driver"); + return; + } + }; + // Make sure chunk_idx write does not bypass audio chunk write. + fence(Ordering::SeqCst); + capt.chunk_idx = self.chunk_idx; + } + std::cmp::Ordering::Equal => continue, } } } } -#[cfg(all(target_env = "ohos", feature = "scream_ohaudio"))] -fn bound_tokenid(token_id: u64) -> Result<()> { - if token_id == 0 { - bail!("UI token ID not passed."); - } else if token_id != get_firstcaller_tokenid()? { - set_firstcaller_tokenid(token_id)?; - } - Ok(()) -} - #[derive(Clone, Debug)] enum ScreamInterface { #[cfg(feature = "scream_alsa")] @@ -402,8 +530,10 @@ impl FromStr for ScreamInterface { } #[derive(Parser, Debug, Clone)] -#[command(name = "ivshmem_scream")] +#[command(no_binary_name(true))] pub struct ScreamConfig { + #[arg(long)] + pub classtype: String, #[arg(long, value_parser = valid_id)] id: String, #[arg(long)] @@ -428,17 +558,31 @@ pub struct Scream { size: u64, config: ScreamConfig, token_id: Option>>, + interface_resource: Vec>>, } impl Scream { - pub fn new(size: u64, config: ScreamConfig, token_id: Option>>) -> Self { + pub fn new( + size: u64, + config: ScreamConfig, + token_id: Option>>, + ) -> Result { set_record_authority(config.record_auth); - Self { + let header_size = mem::size_of::() as u64; + if size < header_size { + bail!( + "The size {} of the shared memory is smaller than audio header {}", + size, + header_size + ); + } + Ok(Self { hva: 0, size, config, token_id, - } + interface_resource: Vec::new(), + }) } #[allow(unused_variables)] @@ -458,48 +602,59 @@ impl Scream { } } - fn start_play_thread_fn(&self) -> Result<()> { + fn start_play_thread_fn(&mut self, cond: Arc) -> Result<()> { let hva = self.hva; let shmem_size = self.size; let interface = self.interface_init("ScreamPlay", ScreamDirection::Playback); + self.interface_resource.push(interface.clone()); + self.register_state_query("scream-play".to_string(), cond.clone()); thread::Builder::new() .name("scream audio play worker".to_string()) .spawn(move || { let clone_interface = interface.clone(); let mut play_data = StreamData::default(); + play_data.register_pause_notifier(cond.clone()); loop { play_data.wait_for_ready( clone_interface.clone(), ScreamDirection::Playback, - POLL_DELAY_US, hva, + cond.clone(), ); - play_data.playback_trans(hva, shmem_size, clone_interface.clone()); + play_data.playback_trans( + hva, + shmem_size, + clone_interface.clone(), + cond.clone(), + ); } }) .with_context(|| "Failed to create thread scream")?; Ok(()) } - fn start_record_thread_fn(&self) -> Result<()> { + fn start_record_thread_fn(&mut self, cond: Arc) -> Result<()> { let hva = self.hva; let shmem_size = self.size; let interface = self.interface_init("ScreamCapt", ScreamDirection::Record); let _ti = self.token_id.clone(); + self.interface_resource.push(interface.clone()); + self.register_state_query("scream-record".to_string(), cond.clone()); thread::Builder::new() .name("scream audio capt worker".to_string()) .spawn(move || { let clone_interface = interface.clone(); let mut capt_data = StreamData::default(); + capt_data.register_pause_notifier(cond.clone()); loop { capt_data.wait_for_ready( clone_interface.clone(), ScreamDirection::Record, - POLL_DELAY_US, hva, + cond.clone(), ); #[cfg(all(target_env = "ohos", feature = "scream_ohaudio"))] @@ -507,23 +662,24 @@ impl Scream { bound_tokenid(*token_id.read().unwrap()) .unwrap_or_else(|e| error!("bound token ID failed: {}", e)); } - capt_data.capture_trans(hva, shmem_size, clone_interface.clone()); + capt_data.capture_trans(hva, shmem_size, clone_interface.clone(), cond.clone()); } }) .with_context(|| "Failed to create thread scream")?; Ok(()) } - pub fn realize(mut self, devfn: u8, parent_bus: Weak>) -> Result<()> { - let header_size = mem::size_of::() as u64; - if self.size < header_size { - bail!( - "The size {} of the shared memory is smaller then audio header {}", - self.size, - header_size - ); - } + fn register_state_query(&self, module: String, cond: Arc) { + register_state_query_callback( + module, + Arc::new(move || match cond.stream_paused() { + false => "On".to_string(), + true => "Off".to_string(), + }), + ); + } + pub fn realize(&mut self, parent_bus: Weak>) -> Result<()> { let host_mmap = Arc::new(HostMemMapping::new( GuestAddress(0), None, @@ -535,21 +691,124 @@ impl Scream { )?); self.hva = host_mmap.host_address(); + let devfn = (self.config.addr.0 << 3) + self.config.addr.1; let mem_region = Region::init_ram_region(host_mmap, "ivshmem_ram"); + let ivshmem = Ivshmem::new( + "ivshmem".to_string(), + devfn, + parent_bus, + mem_region, + IVSHMEM_VECTORS_NR, + ); + let ivshmem = ivshmem.realize()?; + let ivshmem_cloned = ivshmem.clone(); + + let play_cond = ScreamCond::new(); + let capt_cond = ScreamCond::new(); + self.set_ivshmem_ops(ivshmem, play_cond.clone(), capt_cond.clone()); + + let author_notify = Arc::new(move || { + ivshmem_cloned + .lock() + .unwrap() + .trigger_msix(IVSHMEM_STATUS_CHANGE_VECTOR); + }); + set_authority_notify(Some(author_notify)); + + self.start_play_thread_fn(play_cond)?; + self.start_record_thread_fn(capt_cond) + } - let ivshmem = Ivshmem::new("ivshmem".to_string(), devfn, parent_bus, mem_region); - ivshmem.realize()?; + fn set_ivshmem_ops( + &mut self, + ivshmem: Arc>, + play_cond: Arc, + capt_cond: Arc, + ) { + let cloned_play_cond = play_cond.clone(); + let cloned_capt_cond = capt_cond.clone(); + let cb = Box::new(move || { + info!("Scream: device is reset."); + cloned_play_cond.set_stream_pause(true); + cloned_capt_cond.set_stream_pause(true); + }); + ivshmem.lock().unwrap().register_reset_callback(cb); + + let interface = self.create_audio_extension(ivshmem.clone()); + let interface2 = interface.clone(); + let bar0_write = Arc::new(move |data: &[u8], offset: u64| { + match offset { + IVSHMEM_BAR0_VOLUME => { + interface.set_host_volume(le_read_u32(data, 0).unwrap()); + } + IVSHMEM_BAR0_STATUS => { + let val = le_read_u32(data, 0).unwrap(); + if val & STATUS_PLAY_BIT == STATUS_PLAY_BIT { + play_cond.set_stream_pause(val & STATUS_START_BIT != STATUS_START_BIT); + } else { + capt_cond.set_stream_pause(val & STATUS_START_BIT != STATUS_START_BIT); + } + } + _ => { + info!("ivshmem-scream: unsupported write: {offset}"); + } + } + true + }); + let bar0_read = Arc::new(move |data: &mut [u8], offset: u64| { + match offset { + IVSHMEM_BAR0_VOLUME => { + let _ = le_write_u32(data, 0, interface2.get_host_volume()); + } + IVSHMEM_BAR0_STATUS => { + let _ = le_write_u32(data, 0, interface2.get_status_register()); + } + _ => { + info!("ivshmem-scream: unsupported read: {offset}"); + } + } + true + }); + ivshmem + .lock() + .unwrap() + .set_bar0_ops((bar0_write, bar0_read)); + } - self.start_play_thread_fn()?; - self.start_record_thread_fn() + fn create_audio_extension(&self, _ivshmem: Arc>) -> Arc { + match self.config.interface { + #[cfg(all(target_env = "ohos", feature = "scream_ohaudio"))] + ScreamInterface::OhAudio => OhAudioVolume::new(_ivshmem), + ScreamInterface::Demo => DemoAudioVolume::new(_ivshmem), + #[allow(unreachable_patterns)] + _ => Arc::new(AudioExtensionDummy {}), + } } } pub trait AudioInterface: Send { fn send(&mut self, recv_data: &StreamData); - // For OHOS's audio task. It confirms shmem info. - #[allow(unused_variables)] - fn pre_receive(&mut self, start_addr: u64, sh_header: &ShmemStreamHeader) {} fn receive(&mut self, recv_data: &StreamData) -> i32; fn destroy(&mut self); + fn get_status(&self) -> AudioStatus; } + +pub trait AudioExtension: Send + Sync { + fn set_host_volume(&self, _vol: u32) {} + fn get_host_volume(&self) -> u32 { + 0 + } + fn get_status_register(&self) -> u32 { + match get_record_authority() { + true => STATUS_MIC_AVAIL_BIT, + false => 0, + } + } +} + +struct AudioExtensionDummy; +impl AudioExtension for AudioExtensionDummy {} +// SAFETY: it is a dummy +unsafe impl Send for AudioExtensionDummy {} +// SAFETY: it is a dummy +unsafe impl Sync for AudioExtensionDummy {} diff --git a/devices/src/misc/scream/ohaudio.rs b/devices/src/misc/scream/ohaudio.rs index bd7d89c293ec02bca43743374ebefd54fdc77bbd..315c6c337977cd39f4ab6d883962a40ebd3badc0 100755 --- a/devices/src/misc/scream/ohaudio.rs +++ b/devices/src/misc/scream/ohaudio.rs @@ -10,52 +10,174 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. +use std::collections::VecDeque; use std::os::raw::c_void; use std::sync::{ - atomic::{fence, AtomicBool, AtomicI32, Ordering}, - Arc, Mutex, + atomic::{fence, AtomicBool, Ordering}, + Arc, Condvar, Mutex, RwLock, }; -use std::{cmp, ptr, thread, time}; +use std::{cmp, io::Read, ptr, thread, time::Duration}; -use log::error; +use log::{error, info, warn}; -use crate::misc::scream::{AudioInterface, ScreamDirection, ShmemStreamHeader, StreamData}; +use crate::misc::ivshmem::Ivshmem; +use crate::misc::scream::{ + AudioExtension, AudioInterface, AudioStatus, ScreamDirection, StreamData, + IVSHMEM_VOLUME_SYNC_VECTOR, TARGET_LATENCY_MS, +}; +use machine_manager::event_loop::EventLoop; use util::ohos_binding::audio::*; +const STREAM_DATA_VEC_CAPACITY: usize = 15; +const FLUSH_DELAY_MS: u64 = 5; +const FLUSH_DELAY_CNT: u64 = 200; +const SCREAM_MAX_VOLUME: u32 = 110; +const CAPTURE_WAIT_TIMEOUT: u64 = 500; +const RENDER_WAIT_TIMEOUT: u64 = TARGET_LATENCY_MS as u64 * 2; +const MS_PER_SECOND: u64 = 1000; + trait OhAudioProcess { fn init(&mut self, stream: &StreamData) -> bool; fn destroy(&mut self); - fn preprocess(&mut self, _start_addr: u64, _sh_header: &ShmemStreamHeader) {} fn process(&mut self, recv_data: &StreamData) -> i32; + fn get_status(&self) -> AudioStatus; } #[derive(Debug, Clone, Copy)] struct StreamUnit { - pub addr: u64, - pub len: u64, + addr: usize, + len: usize, } -const STREAM_DATA_VEC_CAPACITY: usize = 30; -const FLUSH_DELAY_THRESHOLD_MS: u64 = 100; -const FLUSH_DELAY_MS: u64 = 5; -const FLUSH_DELAY_CNT: u64 = 200; +impl Read for StreamUnit { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let len = cmp::min(self.len, buf.len()); + // SAFETY: all the source data are in scream BAR. + unsafe { ptr::copy_nonoverlapping(self.addr as *const u8, buf.as_mut_ptr(), len) }; + self.len -= len; + self.addr += len; + Ok(len) + } +} + +impl StreamUnit { + #[inline] + fn is_empty(&self) -> bool { + self.len == 0 + } + + fn new(addr: usize, len: usize) -> Self { + Self { addr, len } + } + + #[inline] + fn len(&self) -> usize { + self.len + } +} + +struct StreamQueue { + queue: VecDeque, + data_size: usize, +} + +impl Read for StreamQueue { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let len = buf.len(); + let mut ret = 0_usize; + while ret < len { + if self.queue.is_empty() { + break; + } + let unit = match self.queue.front_mut() { + Some(u) => u, + None => break, + }; + let rlen = unit.read(&mut buf[ret..len]).unwrap(); + ret += rlen; + self.data_size -= rlen; + if unit.is_empty() { + self.pop_front(); + } + } + Ok(ret) + } + + // If there's no enough data, let's fill the whole buffer with 0. + fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> { + let len = buf.len(); + match self.read(buf) { + Ok(ret) => { + if ret < len { + self.read_zero(&mut buf[ret..len]); + } + Ok(()) + } + Err(e) => Err(e), + } + } +} + +impl StreamQueue { + fn new(capacity: usize) -> Self { + Self { + queue: VecDeque::with_capacity(capacity), + data_size: 0, + } + } + + fn clear(&mut self) { + self.queue.clear(); + } + + #[inline] + fn data_size(&self) -> usize { + self.data_size + } + + fn pop_front(&mut self) { + if let Some(elem) = self.queue.pop_front() { + self.data_size -= elem.len(); + } + } + + fn push_back(&mut self, unit: StreamUnit) { + // When audio data is not consumed in time, this buffer + // might be full. So let's keep the max size by dropping + // the old data. This can guarantee sound playing can't + // be delayed too much and the buffer won't become too + // large. + if self.queue.len() == self.queue.capacity() { + self.pop_front(); + } + self.data_size += unit.len; + self.queue.push_back(unit); + } + + fn read_zero(&mut self, buf: &mut [u8]) { + // SAFETY: the buffer is guaranteed by the caller. + unsafe { + ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len()); + } + } +} struct OhAudioRender { ctx: Option, - stream_data: Arc>>, - data_size: AtomicI32, - start: bool, + stream_data: Arc>, flushing: AtomicBool, + status: AudioStatus, + cond: Condvar, } impl Default for OhAudioRender { fn default() -> OhAudioRender { OhAudioRender { ctx: None, - stream_data: Arc::new(Mutex::new(Vec::with_capacity(STREAM_DATA_VEC_CAPACITY))), - data_size: AtomicI32::new(0), - start: false, + stream_data: Arc::new(Mutex::new(StreamQueue::new(STREAM_DATA_VEC_CAPACITY))), flushing: AtomicBool::new(false), + status: AudioStatus::default(), + cond: Condvar::new(), } } } @@ -74,17 +196,30 @@ impl OhAudioRender { } fn flush(&mut self) { - self.flushing.store(true, Ordering::Release); - let mut cnt = 0; - while (cnt < FLUSH_DELAY_CNT) && (self.flushing.load(Ordering::Acquire)) { - thread::sleep(time::Duration::from_millis(FLUSH_DELAY_MS)); + self.set_flushing(true); + let mut cnt = 0_u64; + while cnt < FLUSH_DELAY_CNT { + thread::sleep(Duration::from_millis(FLUSH_DELAY_MS)); cnt += 1; + if self.stream_data.lock().unwrap().data_size() == 0 { + break; + } } - // We need to wait for 100ms to ensure the audio data has - // been flushed before stop renderer. - thread::sleep(time::Duration::from_millis(FLUSH_DELAY_THRESHOLD_MS)); + } + + fn flush_renderer(&self) { let _ = self.ctx.as_ref().unwrap().flush_renderer(); } + + #[inline(always)] + fn is_flushing(&self) -> bool { + self.flushing.load(Ordering::Acquire) + } + + #[inline(always)] + fn set_flushing(&mut self, flush: bool) { + self.flushing.store(flush, Ordering::Release); + } } impl OhAudioProcess for OhAudioRender { @@ -95,7 +230,7 @@ impl OhAudioProcess for OhAudioRender { stream.fmt.size, stream.fmt.get_rate(), stream.fmt.channels, - AudioProcessCb::RendererCb(Some(on_write_data_cb)), + AudioProcessCb::RendererCb(Some(on_write_data_cb), Some(render_on_interrupt_cb)), ptr::addr_of!(*self) as *mut c_void, ) { Ok(()) => self.ctx = Some(context), @@ -107,27 +242,33 @@ impl OhAudioProcess for OhAudioRender { } match self.ctx.as_ref().unwrap().start() { Ok(()) => { - self.start = true; + info!("Renderer start"); + self.status = AudioStatus::Started; + trace::oh_scream_render_init(&self.ctx); } Err(e) => { error!("failed to start oh audio renderer: {}", e); } } - self.start + self.status == AudioStatus::Started } fn destroy(&mut self) { - if self.ctx.is_some() { - if self.start { - self.flush(); - self.ctx.as_mut().unwrap().stop(); - self.start = false; + info!("Renderer destroy"); + match self.status { + AudioStatus::Error | AudioStatus::Intr => { + self.ctx = None; + self.status = AudioStatus::Ready; + return; } - self.ctx = None; + AudioStatus::Started => self.flush(), + _ => {} } - let mut locked_data = self.stream_data.lock().unwrap(); - locked_data.clear(); - self.data_size.store(0, Ordering::Relaxed); + self.ctx = None; + self.stream_data.lock().unwrap().clear(); + self.set_flushing(false); + self.status = AudioStatus::Ready; + trace::oh_scream_render_destroy(); } fn process(&mut self, recv_data: &StreamData) -> i32 { @@ -135,33 +276,93 @@ impl OhAudioProcess for OhAudioRender { fence(Ordering::Acquire); - let su = StreamUnit { - addr: recv_data.audio_base, - len: recv_data.audio_size as u64, - }; - let mut locked_data = self.stream_data.lock().unwrap(); - locked_data.push(su); - self.data_size - .fetch_add(recv_data.audio_size as i32, Ordering::Relaxed); - drop(locked_data); + trace::trace_scope_start!(ohaudio_render_process, args = (recv_data)); + + self.stream_data.lock().unwrap().push_back(StreamUnit::new( + recv_data.audio_base as usize, + recv_data.audio_size as usize, + )); + self.cond.notify_all(); - if !self.start && !self.init(recv_data) { + if self.status == AudioStatus::Error || self.status == AudioStatus::Intr { + error!( + "Audio server {:?} occurred. Destroy and reconnect it.", + self.status + ); + self.destroy(); + } + + if self.status == AudioStatus::Ready && !self.init(recv_data) { error!("failed to init oh audio"); self.destroy(); } 0 } + + fn get_status(&self) -> AudioStatus { + self.status + } +} + +struct CaptureStream { + cond: Condvar, + data: Mutex>, + expected: usize, +} + +impl Default for CaptureStream { + fn default() -> Self { + Self { + cond: Condvar::new(), + data: Mutex::new(Vec::with_capacity(1 << 20)), + expected: 0, + } + } +} + +impl CaptureStream { + fn wait_for_data(&mut self, buf: &mut [u8]) -> bool { + let mut locked_data = self.data.lock().unwrap(); + self.expected = buf.len(); + while locked_data.len() < self.expected { + let ret = self + .cond + .wait_timeout(locked_data, Duration::from_millis(CAPTURE_WAIT_TIMEOUT)) + .unwrap(); + if ret.1.timed_out() { + return false; + } + locked_data = ret.0; + } + buf.copy_from_slice(&locked_data[..self.expected]); + *locked_data = locked_data[self.expected..].to_vec(); + self.expected = 0; + true + } + + fn append_data(&mut self, buf: &[u8]) { + let mut locked_data = self.data.lock().unwrap(); + locked_data.extend_from_slice(buf); + if locked_data.len() > self.expected { + self.cond.notify_all(); + } + } + + fn reset(&mut self) { + let mut locked_data = self.data.lock().unwrap(); + locked_data.clear(); + self.expected = 0; + self.cond.notify_all(); + } } #[derive(Default)] struct OhAudioCapture { ctx: Option, - align: u32, - new_chunks: AtomicI32, - shm_addr: u64, - shm_len: u64, - cur_pos: u64, - start: bool, + status: AudioStatus, + stream: CaptureStream, + timer_start: bool, + data_size_per_second: u64, } impl OhAudioCapture { @@ -185,7 +386,7 @@ impl OhAudioProcess for OhAudioCapture { stream.fmt.size, stream.fmt.get_rate(), stream.fmt.channels, - AudioProcessCb::CapturerCb(Some(on_read_data_cb)), + AudioProcessCb::CapturerCb(Some(on_read_data_cb), Some(capture_on_interrupt_cb)), ptr::addr_of!(*self) as *mut c_void, ) { Ok(()) => self.ctx = Some(context), @@ -194,9 +395,14 @@ impl OhAudioProcess for OhAudioCapture { return false; } } + self.data_size_per_second = (stream.fmt.size as u64 >> 3) + * stream.fmt.get_rate() as u64 + * stream.fmt.channels as u64; match self.ctx.as_ref().unwrap().start() { Ok(()) => { - self.start = true; + info!("Capturer start"); + self.status = AudioStatus::Started; + trace::oh_scream_capture_init(&self.ctx); true } Err(e) => { @@ -207,90 +413,171 @@ impl OhAudioProcess for OhAudioCapture { } fn destroy(&mut self) { - if self.ctx.is_some() { - if self.start { - self.ctx.as_mut().unwrap().stop(); - self.start = false; - } - self.ctx = None; - } - } - - fn preprocess(&mut self, start_addr: u64, sh_header: &ShmemStreamHeader) { - self.align = sh_header.chunk_size; - self.new_chunks.store(0, Ordering::Release); - self.shm_addr = start_addr; - self.shm_len = sh_header.max_chunks as u64 * sh_header.chunk_size as u64; - self.cur_pos = start_addr + sh_header.chunk_idx as u64 * sh_header.chunk_size as u64; + info!("Capturer destroy"); + self.status = AudioStatus::Ready; + self.ctx = None; + self.stream.reset(); + trace::oh_scream_capture_destroy(); } fn process(&mut self, recv_data: &StreamData) -> i32 { self.check_fmt_update(recv_data); - if !self.start && !self.init(recv_data) { + + trace::trace_scope_start!(ohaudio_capturer_process, args = (recv_data)); + + // We expect capture stream can be resumed when another stream stops interrupting it, + // but OHOS does not resume it. We resume it manually. + if self.status == AudioStatus::Error || self.status == AudioStatus::IntrResume { self.destroy(); - return 0; } - self.new_chunks.store(0, Ordering::Release); - while self.new_chunks.load(Ordering::Acquire) == 0 { - thread::sleep(time::Duration::from_millis(10)); + + if self.status == AudioStatus::Ready && !self.init(recv_data) { + self.destroy(); + return -1; } + // SAFETY: the buffer is from ivshmem and the caller ensures its validation. + let buf = unsafe { + std::slice::from_raw_parts_mut( + recv_data.audio_base as *mut u8, + recv_data.audio_size as usize, + ) + }; + if self.status == AudioStatus::Intr { + // When capture stream is interrupted, we need to send mute data to front end. Without this, + // some applications may pop-up error window for no data received. + if !self.timer_start { + let period: u64 = MS_PER_SECOND * buf.len() as u64 / self.data_size_per_second; + mute_capture_data_gen(ptr::addr_of_mut!(*self), period, buf.len()); + return 0; + } + } + if !self.stream.wait_for_data(buf) && self.status != AudioStatus::Intr { + warn!("timed out to wait for capture audio data"); + self.status = AudioStatus::Error; + return 0; + } + 1 + } - self.new_chunks.load(Ordering::Acquire) + fn get_status(&self) -> AudioStatus { + self.status } } -extern "C" fn on_write_data_cb( +fn mute_capture_data_gen(capture: *mut OhAudioCapture, period: u64, len: usize) { + let buffer: Vec = vec![0; len]; + let buf = buffer.as_slice(); + + //SAFETY: we make sure that capture we passed is valid. + let capt = unsafe { capture.as_mut().unwrap_unchecked() }; + capt.stream.append_data(buf); + + let mute_capture_cb = Box::new(move || { + mute_capture_data_gen(capture, period, len); + }); + + if capt.status == AudioStatus::Intr { + EventLoop::get_ctx(None) + .unwrap() + .timer_add(mute_capture_cb, Duration::from_millis(period)); + capt.timer_start = true; + } else { + capt.timer_start = false; + } +} + +extern "C" fn render_on_interrupt_cb( _renderer: *mut OhAudioRenderer, user_data: *mut ::std::os::raw::c_void, - buffer: *mut ::std::os::raw::c_void, - length: i32, + source_type: capi::OHAudioInterruptSourceType, + hint: capi::OHAudioInterruptHint, ) -> i32 { + info!( + "Render interrupts, type is {}, hint is {}", + source_type, hint + ); // SAFETY: we make sure that it is OhAudioRender when register callback. let render = unsafe { (user_data as *mut OhAudioRender) .as_mut() .unwrap_unchecked() }; - - let data_size = render.data_size.load(Ordering::Relaxed); - if !render.flushing.load(Ordering::Acquire) && data_size < length { - // SAFETY: we checked len. - unsafe { ptr::write_bytes(buffer as *mut u8, 0, length as usize) }; - return 0; + if hint == capi::AUDIOSTREAM_INTERRUPT_HINT_PAUSE { + render.status = AudioStatus::Intr; } + 0 +} - // Copy stream data from shared memory to buffer. - let mut dst_addr = buffer as u64; - let mut left = length as u64; - let mut su_list = render.stream_data.lock().unwrap(); - while left > 0 && su_list.len() > 0 { - let su = &mut su_list[0]; - let len = cmp::min(left, su.len); - - // SAFETY: we checked len. - unsafe { - ptr::copy_nonoverlapping(su.addr as *const u8, dst_addr as *mut u8, len as usize) - }; +extern "C" fn capture_on_interrupt_cb( + _capturer: *mut OhAudioCapturer, + user_data: *mut ::std::os::raw::c_void, + source_type: capi::OHAudioInterruptSourceType, + hint: capi::OHAudioInterruptHint, +) -> i32 { + info!( + "Capture interrupts, type is {}, hint is {}", + source_type, hint + ); - dst_addr += len; - left -= len; - su.len -= len; - if su.len == 0 { - su_list.remove(0); - } else { - su.addr += len; - } + // SAFETY: we make sure that it is OhAudioCapture when register callback. + let capture = unsafe { + (user_data as *mut OhAudioCapture) + .as_mut() + .unwrap_unchecked() + }; + if hint == capi::AUDIOSTREAM_INTERRUPT_HINT_PAUSE { + capture.status = AudioStatus::Intr; + } else if hint == capi::AUDIOSTREAM_INTERRUPT_HINT_RESUME { + capture.status = AudioStatus::IntrResume; } - render - .data_size - .fetch_sub(length - left as i32, Ordering::Relaxed); + 0 +} - if left > 0 { - // SAFETY: we checked len. - unsafe { ptr::write_bytes(dst_addr as *mut u8, 0, left as usize) }; +extern "C" fn on_write_data_cb( + _renderer: *mut OhAudioRenderer, + user_data: *mut ::std::os::raw::c_void, + buffer: *mut ::std::os::raw::c_void, + length: i32, +) -> i32 { + if buffer.is_null() || user_data.is_null() { + error!("on_write_data_cb: Invalid input"); + return 0; } - if render.flushing.load(Ordering::Acquire) && su_list.is_empty() { - render.flushing.store(false, Ordering::Release); + + // SAFETY: we make sure that it is OhAudioRender when register callback. + let render = unsafe { + (user_data as *mut OhAudioRender) + .as_mut() + .unwrap_unchecked() + }; + + let len = length as usize; + // SAFETY: the buffer is guaranteed by OH audio framework. + let wbuf = unsafe { std::slice::from_raw_parts_mut(buffer as *mut u8, len) }; + + trace::oh_scream_on_write_data_cb(len); + trace::trace_scope_start!(ohaudio_write_cb, args = (len)); + let is_empty = render.stream_data.lock().unwrap().data_size() == 0; + let mut locked_stream_data = match is_empty { + true => { + render + .cond + .wait_timeout( + render.stream_data.lock().unwrap(), + Duration::from_millis(RENDER_WAIT_TIMEOUT), + ) + .unwrap() + .0 + } + false => render.stream_data.lock().unwrap(), + }; + match locked_stream_data.read_exact(wbuf) { + Ok(()) => { + if render.is_flushing() { + render.flush_renderer(); + } + } + Err(e) => error!("Failed to read stream data {:?}", e), } 0 } @@ -301,6 +588,11 @@ extern "C" fn on_read_data_cb( buffer: *mut ::std::os::raw::c_void, length: i32, ) -> i32 { + if buffer.is_null() || user_data.is_null() { + error!("on_read_data_cb: Invalid input"); + return 0; + } + // SAFETY: we make sure that it is OhAudioCapture when register callback. let capture = unsafe { (user_data as *mut OhAudioCapture) @@ -308,43 +600,15 @@ extern "C" fn on_read_data_cb( .unwrap_unchecked() }; - loop { - if !capture.start { - return 0; - } - if capture.new_chunks.load(Ordering::Acquire) == 0 { - break; - } - } - let old_pos = capture.cur_pos - ((capture.cur_pos - capture.shm_addr) % capture.align as u64); - let buf_end = capture.shm_addr + capture.shm_len; - let mut src_addr = buffer as u64; - let mut left = length as u64; - while left > 0 { - let len = cmp::min(left, buf_end - capture.cur_pos); - // SAFETY: we checked len. - unsafe { - ptr::copy_nonoverlapping( - src_addr as *const u8, - capture.cur_pos as *mut u8, - len as usize, - ) - }; - left -= len; - src_addr += len; - capture.cur_pos += len; - if capture.cur_pos == buf_end { - capture.cur_pos = capture.shm_addr; - } + trace::trace_scope_start!(ohaudio_read_cb, args = (length)); + + if capture.status != AudioStatus::Started { + return 0; } - let new_chunks = match capture.cur_pos <= old_pos { - true => (capture.shm_len - (old_pos - capture.cur_pos)) / capture.align as u64, - false => (capture.cur_pos - old_pos) / capture.align as u64, - }; - capture - .new_chunks - .store(new_chunks as i32, Ordering::Release); + // SAFETY: the buffer is checked above. + let buf = unsafe { std::slice::from_raw_parts(buffer as *mut u8, length as usize) }; + capture.stream.append_data(buf); 0 } @@ -373,10 +637,6 @@ impl AudioInterface for OhAudio { self.processor.process(recv_data); } - fn pre_receive(&mut self, start_addr: u64, sh_header: &ShmemStreamHeader) { - self.processor.preprocess(start_addr, sh_header); - } - fn receive(&mut self, recv_data: &StreamData) -> i32 { self.processor.process(recv_data) } @@ -384,6 +644,74 @@ impl AudioInterface for OhAudio { fn destroy(&mut self) { self.processor.destroy(); } + + fn get_status(&self) -> AudioStatus { + self.processor.get_status() + } +} + +pub struct OhAudioVolume { + shm_dev: Arc>, + ohos_vol: RwLock, + ohos_vol_max: u32, + ohos_vol_min: u32, +} + +// SAFETY: all unsafe fields are protected by lock +unsafe impl Send for OhAudioVolume {} +// SAFETY: all unsafe fields are protected by lock +unsafe impl Sync for OhAudioVolume {} + +impl GuestVolumeNotifier for OhAudioVolume { + fn notify(&self, vol: u32) { + *self.ohos_vol.write().unwrap() = self.to_guest_vol(vol); + self.shm_dev + .lock() + .unwrap() + .trigger_msix(IVSHMEM_VOLUME_SYNC_VECTOR); + } +} + +impl AudioExtension for OhAudioVolume { + fn get_host_volume(&self) -> u32 { + *self.ohos_vol.read().unwrap() + } + + fn set_host_volume(&self, vol: u32) { + set_ohos_volume(self.to_host_vol(vol)); + } +} + +impl OhAudioVolume { + pub fn new(shm_dev: Arc>) -> Arc { + let vol = Arc::new(Self { + shm_dev, + ohos_vol: RwLock::new(0), + ohos_vol_max: get_ohos_volume_max(), + ohos_vol_min: get_ohos_volume_min(), + }); + *vol.ohos_vol.write().unwrap() = vol.to_guest_vol(get_ohos_volume()); + register_guest_volume_notifier(vol.clone()); + vol + } + + fn to_guest_vol(&self, h_vol: u32) -> u32 { + if self.ohos_vol_max > self.ohos_vol_min { + return SCREAM_MAX_VOLUME * h_vol / (self.ohos_vol_max - self.ohos_vol_min); + } + 0 + } + + fn to_host_vol(&self, v_vol: u32) -> u32 { + if v_vol == 0 || self.ohos_vol_max <= self.ohos_vol_min { + return 0; + } + let res = (self.ohos_vol_max - self.ohos_vol_min) * v_vol / SCREAM_MAX_VOLUME + 1; + if res > self.ohos_vol_max { + return self.ohos_vol_max; + } + res + } } #[cfg(test)] diff --git a/devices/src/misc/scream/pulseaudio.rs b/devices/src/misc/scream/pulseaudio.rs index c42fabacd70c40ebeb2d0f212601507be41580ad..e0d1dc26553698553fe5c4954fd60bde3e6e955a 100644 --- a/devices/src/misc/scream/pulseaudio.rs +++ b/devices/src/misc/scream/pulseaudio.rs @@ -22,7 +22,7 @@ use pulse::{ time::MicroSeconds, }; -use super::{AudioInterface, AUDIO_SAMPLE_RATE_44KHZ}; +use super::{AudioInterface, AudioStatus, AUDIO_SAMPLE_RATE_44KHZ}; use crate::misc::scream::{ScreamDirection, ShmemStreamFmt, StreamData, TARGET_LATENCY_MS}; const MAX_LATENCY_MS: u32 = 100; @@ -84,8 +84,8 @@ impl PulseStreamData { // Set buffer size for requested latency. let buffer_attr = BufferAttr { - maxlength: ss.usec_to_bytes(MicroSeconds(MAX_LATENCY_MS as u64 * 1000)) as u32, - tlength: ss.usec_to_bytes(MicroSeconds(TARGET_LATENCY_MS as u64 * 1000)) as u32, + maxlength: ss.usec_to_bytes(MicroSeconds(u64::from(MAX_LATENCY_MS) * 1000)) as u32, + tlength: ss.usec_to_bytes(MicroSeconds(u64::from(TARGET_LATENCY_MS) * 1000)) as u32, prebuf: std::u32::MAX, minreq: std::u32::MAX, fragsize: std::u32::MAX, @@ -123,8 +123,15 @@ impl PulseStreamData { } } - fn transfer_channel_map(&mut self, format: &ShmemStreamFmt) { + fn transfer_channel_map(&mut self, format: &ShmemStreamFmt) -> bool { self.channel_map.init(); + let channels = format.channels; + if channels <= Map::CHANNELS_MAX { + self.channel_map.set_len(channels); + } else { + error!("invalid channels {}", channels); + return false; + } self.channel_map.set_len(format.channels); let map: &mut [Position] = self.channel_map.get_mut(); // In Windows, the channel mask shows as following figure. @@ -149,6 +156,7 @@ impl PulseStreamData { *item = Position::FrontCenter; } } + true } fn check_fmt_update(&mut self, recv_data: &StreamData) { @@ -181,8 +189,8 @@ impl PulseStreamData { self.channel_map.init_mono(); } else if recv_data.fmt.channels == 2 { self.channel_map.init_stereo(); - } else { - self.transfer_channel_map(&recv_data.fmt); + } else if !self.transfer_channel_map(&recv_data.fmt) { + return; } if !self.channel_map.is_valid() { @@ -198,9 +206,10 @@ impl PulseStreamData { if self.ss.rate > 0 { // Sample spec has changed, so the playback buffer size for the requested latency must // be recalculated as well. - self.buffer_attr.tlength = - self.ss - .usec_to_bytes(MicroSeconds(self.latency as u64 * 1000)) as u32; + self.buffer_attr.tlength = self + .ss + .usec_to_bytes(MicroSeconds(u64::from(self.latency) * 1000)) + as u32; self.simple = Simple::new( None, @@ -215,9 +224,9 @@ impl PulseStreamData { .map_or_else( |_| { warn!( - "Unable to open PulseAudio with sample rate {}, sample size {} and channels {}", - self.ss.rate, recv_data.fmt.size, recv_data.fmt.channels - ); + "Unable to open PulseAudio with sample rate {}, sample size {} and channels {}", + self.ss.rate, recv_data.fmt.size, recv_data.fmt.channels + ); None }, Some, @@ -289,6 +298,14 @@ impl AudioInterface for PulseStreamData { } self.simple = None; } + + fn get_status(&self) -> AudioStatus { + if self.simple.is_some() { + AudioStatus::Started + } else { + AudioStatus::Ready + } + } } #[cfg(test)] @@ -309,7 +326,8 @@ mod tests { // set 8: BC, 6: FLC, 4: BL, 2: FC, 0: FL test_data.fmt.channels = 5; test_data.fmt.channel_map = 0b1_0101_0101; - pulse.transfer_channel_map(&test_data.fmt); + let ret = pulse.transfer_channel_map(&test_data.fmt); + assert_eq!(ret, true); assert_eq!(pulse.channel_map.len(), 5); let map = pulse.channel_map.get_mut(); @@ -322,7 +340,8 @@ mod tests { // The first 12 bits are set to 1. test_data.fmt.channels = 12; test_data.fmt.channel_map = 0b1111_1111_1111; - pulse.transfer_channel_map(&test_data.fmt); + let ret = pulse.transfer_channel_map(&test_data.fmt); + assert_eq!(ret, true); assert_eq!(pulse.channel_map.len(), 12); let map = pulse.channel_map.get_mut(); diff --git a/devices/src/pci/bus.rs b/devices/src/pci/bus.rs index d493e181271db1ba0d42847743e79247accafc83..e4f730ab752dd2c083c04a0d2ebda8a3964f2d0f 100644 --- a/devices/src/pci/bus.rs +++ b/devices/src/pci/bus.rs @@ -10,11 +10,10 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use std::collections::HashMap; use std::sync::atomic::{AtomicU16, Ordering}; use std::sync::{Arc, Mutex, Weak}; -use anyhow::{bail, Context, Result}; +use anyhow::{Context, Result}; use log::debug; use super::{ @@ -22,21 +21,19 @@ use super::{ hotplug::HotplugOps, PciDevOps, PciIntxState, }; -use crate::MsiIrqManager; +use crate::pci::{to_pcidevops, RootPort}; +use crate::{ + convert_bus_mut, convert_bus_ref, convert_device_mut, convert_device_ref, Bus, BusBase, Device, + MsiIrqManager, MUT_ROOT_PORT, PCI_BUS_DEVICE, ROOT_PORT, +}; use address_space::Region; +use util::gen_base_func; -type DeviceBusInfo = (Arc>, Arc>); +type DeviceBusInfo = (Arc>, Arc>); /// PCI bus structure. pub struct PciBus { - /// Bus name - pub name: String, - /// Devices attached to the bus. - pub devices: HashMap>>, - /// Child buses of the bus. - pub child_buses: Vec>>, - /// Pci bridge which the bus originates from. - pub parent_bridge: Option>>, + pub base: BusBase, /// IO region which the parent bridge manages. #[cfg(target_arch = "x86_64")] pub io_region: Region, @@ -49,6 +46,42 @@ pub struct PciBus { pub msi_irq_manager: Option>, } +/// Convert from Arc> to &mut PciBus. +#[macro_export] +macro_rules! MUT_PCI_BUS { + ($trait_bus:expr, $lock_bus: ident, $struct_bus: ident) => { + convert_bus_mut!($trait_bus, $lock_bus, $struct_bus, PciBus); + }; +} + +/// Convert from Arc> to &PciBus. +#[macro_export] +macro_rules! PCI_BUS { + ($trait_bus:expr, $lock_bus: ident, $struct_bus: ident) => { + convert_bus_ref!($trait_bus, $lock_bus, $struct_bus, PciBus); + }; +} + +impl Bus for PciBus { + gen_base_func!(bus_base, bus_base_mut, BusBase, base); + + fn reset(&self) -> Result<()> { + for dev in self.child_devices().values() { + PCI_BUS_DEVICE!(dev, locked_dev, pci_dev); + pci_dev + .reset(false) + .with_context(|| format!("Fail to reset pci dev {}", pci_dev.name()))?; + + if let Some(bus) = pci_dev.child_bus() { + MUT_PCI_BUS!(bus, locked_bus, pci_bus); + pci_bus.reset().with_context(|| "Fail to reset child bus")?; + } + } + + Ok(()) + } +} + impl PciBus { /// Create new bus entity. /// @@ -63,10 +96,7 @@ impl PciBus { mem_region: Region, ) -> Self { Self { - name, - devices: HashMap::new(), - child_buses: Vec::new(), - parent_bridge: None, + base: BusBase::new(name), #[cfg(target_arch = "x86_64")] io_region, mem_region, @@ -95,9 +125,9 @@ impl PciBus { /// /// * `bus_num` - The bus number. /// * `devfn` - Slot number << 3 | Function number. - pub fn get_device(&self, bus_num: u8, devfn: u8) -> Option>> { - if let Some(dev) = self.devices.get(&devfn) { - return Some((*dev).clone()); + pub fn get_device(&self, bus_num: u8, devfn: u8) -> Option>> { + if let Some(dev) = self.child_dev(u64::from(devfn)) { + return Some(dev.clone()); } debug!("Can't find device {}:{}", bus_num, devfn); None @@ -122,15 +152,18 @@ impl PciBus { /// /// * `bus` - Bus to find from. /// * `bus_number` - The bus number. - pub fn find_bus_by_num(bus: &Arc>, bus_num: u8) -> Option>> { - let locked_bus = bus.lock().unwrap(); - if locked_bus.number(SECONDARY_BUS_NUM as usize) == bus_num { - return Some((*bus).clone()); + pub fn find_bus_by_num(bus: &Arc>, bus_num: u8) -> Option>> { + PCI_BUS!(bus, locked_bus, pci_bus); + if pci_bus.number(SECONDARY_BUS_NUM as usize) == bus_num { + return Some(bus.clone()); } - if locked_bus.in_range(bus_num) { - for sub_bus in &locked_bus.child_buses { - if let Some(b) = PciBus::find_bus_by_num(sub_bus, bus_num) { - return Some(b); + if pci_bus.in_range(bus_num) { + for dev in pci_bus.child_devices().values() { + let child_bus = dev.lock().unwrap().child_bus(); + if let Some(sub_bus) = child_bus { + if let Some(b) = PciBus::find_bus_by_num(&sub_bus, bus_num) { + return Some(b); + } } } } @@ -143,14 +176,20 @@ impl PciBus { /// /// * `bus` - Bus to find from. /// * `name` - Bus name. - pub fn find_bus_by_name(bus: &Arc>, bus_name: &str) -> Option>> { + pub fn find_bus_by_name( + bus: &Arc>, + bus_name: &str, + ) -> Option>> { let locked_bus = bus.lock().unwrap(); - if locked_bus.name.as_str() == bus_name { - return Some((*bus).clone()); + if locked_bus.name().as_str() == bus_name { + return Some(bus.clone()); } - for sub_bus in &locked_bus.child_buses { - if let Some(b) = PciBus::find_bus_by_name(sub_bus, bus_name) { - return Some(b); + for dev in locked_bus.child_devices().values() { + let child_bus = dev.lock().unwrap().child_bus(); + if let Some(sub_bus) = child_bus { + if let Some(b) = PciBus::find_bus_by_name(&sub_bus, bus_name) { + return Some(b); + } } } None @@ -160,20 +199,22 @@ impl PciBus { /// /// # Arguments /// - /// * `pci_bus` - On which bus to find. + /// * `bus` - On which bus to find. /// * `name` - Device name. - pub fn find_attached_bus(pci_bus: &Arc>, name: &str) -> Option { - // Device is attached in pci_bus. - let locked_bus = pci_bus.lock().unwrap(); - for dev in locked_bus.devices.values() { + pub fn find_attached_bus(bus: &Arc>, name: &str) -> Option { + // Device is attached in bus. + let locked_bus = bus.lock().unwrap(); + for dev in locked_bus.child_devices().values() { if dev.lock().unwrap().name() == name { - return Some((pci_bus.clone(), dev.clone())); + return Some((bus.clone(), dev.clone())); } - } - // Find in child bus. - for bus in &locked_bus.child_buses { - if let Some(found) = PciBus::find_attached_bus(bus, name) { - return Some(found); + + // Find in child bus. + let child_bus = dev.lock().unwrap().child_bus(); + if let Some(sub_bus) = child_bus { + if let Some(found) = PciBus::find_attached_bus(&sub_bus, name) { + return Some(found); + } } } None @@ -185,39 +226,17 @@ impl PciBus { /// /// * `bus` - Bus to detach from. /// * `dev` - Device attached to the bus. - pub fn detach_device(bus: &Arc>, dev: &Arc>) -> Result<()> { - let mut dev_locked = dev.lock().unwrap(); - dev_locked + pub fn detach_device(bus: &Arc>, dev: &Arc>) -> Result<()> { + PCI_BUS_DEVICE!(dev, locked_dev, pci_dev); + pci_dev .unrealize() - .with_context(|| format!("Failed to unrealize device {}", dev_locked.name()))?; + .with_context(|| format!("Failed to unrealize device {}", pci_dev.name()))?; - let devfn = dev_locked.pci_base().devfn; + let devfn = u64::from(pci_dev.pci_base().devfn); let mut locked_bus = bus.lock().unwrap(); - if locked_bus.devices.get(&devfn).is_some() { - locked_bus.devices.remove(&devfn); - } else { - bail!("Device {} not found in the bus", dev_locked.name()); - } - - Ok(()) - } - - pub fn reset(&mut self) -> Result<()> { - for (_id, pci_dev) in self.devices.iter() { - pci_dev - .lock() - .unwrap() - .reset(false) - .with_context(|| "Fail to reset pci dev")?; - } - - for child_bus in self.child_buses.iter_mut() { - child_bus - .lock() - .unwrap() - .reset() - .with_context(|| "Fail to reset child bus")?; - } + locked_bus + .detach_child(devfn) + .with_context(|| format!("Device {} not found in the bus", pci_dev.name()))?; Ok(()) } @@ -232,23 +251,16 @@ impl PciBus { } fn get_bridge_control_reg(&self, offset: usize, data: &mut [u8]) { - if self.parent_bridge.is_none() { - return; + if let Some(parent_bridge) = self.parent_device() { + let bridge = parent_bridge.upgrade().unwrap(); + MUT_ROOT_PORT!(bridge, locked_bridge, rootport); + rootport.read_config(offset, data); } - - self.parent_bridge - .as_ref() - .unwrap() - .upgrade() - .unwrap() - .lock() - .unwrap() - .read_config(offset, data); } pub fn generate_dev_id(&self, devfn: u8) -> u16 { let bus_num = self.number(SECONDARY_BUS_NUM as usize); - ((bus_num as u16) << 8) | (devfn as u16) + (u16::from(bus_num) << 8) | u16::from(devfn) } pub fn update_dev_id(&self, devfn: u8, dev_id: &Arc) { @@ -256,11 +268,11 @@ impl PciBus { } pub fn get_msi_irq_manager(&self) -> Option> { - match &self.parent_bridge { + match self.parent_device().as_ref() { Some(parent_bridge) => { - let parent_bridge = parent_bridge.upgrade().unwrap(); - let locked_parent_bridge = parent_bridge.lock().unwrap(); - locked_parent_bridge.get_msi_irq_manager() + let bridge = parent_bridge.upgrade().unwrap(); + ROOT_PORT!(bridge, locked_bridge, rootport); + rootport.get_msi_irq_manager() } None => self.msi_irq_manager.clone(), } @@ -269,181 +281,83 @@ impl PciBus { #[cfg(test)] mod tests { - use anyhow::Result; - use super::*; use crate::pci::bus::PciBus; - use crate::pci::config::{PciConfig, PCI_CONFIG_SPACE_SIZE}; - use crate::pci::root_port::RootPort; - use crate::pci::{PciDevBase, PciHost}; - use crate::{Device, DeviceBase}; - use address_space::{AddressSpace, Region}; - - #[derive(Clone)] - struct PciDevice { - base: PciDevBase, - } - - impl Device for PciDevice { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } - - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base - } - } - - impl PciDevOps for PciDevice { - fn pci_base(&self) -> &PciDevBase { - &self.base - } - - fn pci_base_mut(&mut self) -> &mut PciDevBase { - &mut self.base - } - - fn write_config(&mut self, offset: usize, data: &[u8]) { - #[allow(unused_variables)] - self.base.config.write( - offset, - data, - 0, - #[cfg(target_arch = "x86_64")] - None, - None, - ); - } - - fn realize(mut self) -> Result<()> { - let devfn = self.base.devfn; - self.init_write_mask(false)?; - self.init_write_clear_mask(false)?; - - let dev = Arc::new(Mutex::new(self)); - dev.lock() - .unwrap() - .base - .parent_bus - .upgrade() - .unwrap() - .lock() - .unwrap() - .devices - .insert(devfn, dev.clone()); - Ok(()) - } - - fn unrealize(&mut self) -> Result<()> { - Ok(()) - } - } - - pub fn create_pci_host() -> Arc> { - #[cfg(target_arch = "x86_64")] - let sys_io = AddressSpace::new( - Region::init_container_region(1 << 16, "sysio"), - "sysio", - None, - ) - .unwrap(); - let sys_mem = AddressSpace::new( - Region::init_container_region(u64::max_value(), "sysmem"), - "sysmem", - None, - ) - .unwrap(); - Arc::new(Mutex::new(PciHost::new( - #[cfg(target_arch = "x86_64")] - &sys_io, - &sys_mem, - (0xB000_0000, 0x1000_0000), - (0xC000_0000, 0x3000_0000), - #[cfg(target_arch = "aarch64")] - (0xF000_0000, 0x1000_0000), - #[cfg(target_arch = "aarch64")] - (512 << 30, 512 << 30), - 16, - ))) - } + use crate::pci::host::tests::create_pci_host; + use crate::pci::root_port::{RootPort, RootPortConfig}; + use crate::pci::tests::TestPciDevice; + use crate::pci::{clean_pcidevops_type, register_pcidevops_type}; #[test] fn test_find_attached_bus() { let pci_host = create_pci_host(); let locked_pci_host = pci_host.lock().unwrap(); - let root_bus = Arc::downgrade(&locked_pci_host.root_bus); - - let root_port = RootPort::new("pcie.1".to_string(), 8, 0, root_bus.clone(), false); + let root_bus = Arc::downgrade(&locked_pci_host.child_bus().unwrap()); + let root_port_config = RootPortConfig { + addr: (1, 0), + id: "pcie.1".to_string(), + ..Default::default() + }; + let root_port = RootPort::new(root_port_config, root_bus.clone()); root_port.realize().unwrap(); // Test device is attached to the root bus. - let pci_dev = PciDevice { - base: PciDevBase { - base: DeviceBase::new("test1".to_string(), false), - config: PciConfig::new(PCI_CONFIG_SPACE_SIZE, 0), - devfn: 10, - parent_bus: root_bus.clone(), - }, - }; + let pci_dev = TestPciDevice::new("test1", 10, root_bus); pci_dev.realize().unwrap(); // Test device is attached to the root port. - let bus = PciBus::find_bus_by_name(&locked_pci_host.root_bus, "pcie.1").unwrap(); - let pci_dev = PciDevice { - base: PciDevBase { - base: DeviceBase::new("test2".to_string(), false), - config: PciConfig::new(PCI_CONFIG_SPACE_SIZE, 0), - devfn: 12, - parent_bus: Arc::downgrade(&bus), - }, - }; + let bus = + PciBus::find_bus_by_name(&locked_pci_host.child_bus().unwrap(), "pcie.1").unwrap(); + let pci_dev = TestPciDevice::new("test2", 12, Arc::downgrade(&bus)); pci_dev.realize().unwrap(); - let info = PciBus::find_attached_bus(&locked_pci_host.root_bus, "test0"); + let info = PciBus::find_attached_bus(&locked_pci_host.child_bus().unwrap(), "test0"); assert!(info.is_none()); - let info = PciBus::find_attached_bus(&locked_pci_host.root_bus, "test1"); + let info = PciBus::find_attached_bus(&locked_pci_host.child_bus().unwrap(), "test1"); assert!(info.is_some()); let (bus, dev) = info.unwrap(); - assert_eq!(bus.lock().unwrap().name, "pcie.0"); + assert_eq!(bus.lock().unwrap().name(), "pcie.0"); assert_eq!(dev.lock().unwrap().name(), "test1"); - let info = PciBus::find_attached_bus(&locked_pci_host.root_bus, "test2"); + let info = PciBus::find_attached_bus(&locked_pci_host.child_bus().unwrap(), "test2"); assert!(info.is_some()); let (bus, dev) = info.unwrap(); - assert_eq!(bus.lock().unwrap().name, "pcie.1"); + assert_eq!(bus.lock().unwrap().name(), "pcie.1"); assert_eq!(dev.lock().unwrap().name(), "test2"); } #[test] fn test_detach_device() { + register_pcidevops_type::().unwrap(); + let pci_host = create_pci_host(); let locked_pci_host = pci_host.lock().unwrap(); - let root_bus = Arc::downgrade(&locked_pci_host.root_bus); + let root_bus = Arc::downgrade(&locked_pci_host.child_bus().unwrap()); - let root_port = RootPort::new("pcie.1".to_string(), 8, 0, root_bus.clone(), false); + let root_port_config = RootPortConfig { + id: "pcie.1".to_string(), + addr: (1, 0), + ..Default::default() + }; + let root_port = RootPort::new(root_port_config, root_bus.clone()); root_port.realize().unwrap(); - let bus = PciBus::find_bus_by_name(&locked_pci_host.root_bus, "pcie.1").unwrap(); - let pci_dev = PciDevice { - base: PciDevBase { - base: DeviceBase::new("test1".to_string(), false), - config: PciConfig::new(PCI_CONFIG_SPACE_SIZE, 0), - devfn: 0, - parent_bus: Arc::downgrade(&bus), - }, - }; - let dev = Arc::new(Mutex::new(pci_dev.clone())); - let dev_ops: Arc> = dev; + let bus = + PciBus::find_bus_by_name(&locked_pci_host.child_bus().unwrap(), "pcie.1").unwrap(); + let pci_dev = TestPciDevice::new("test1", 0, Arc::downgrade(&bus)); + let dev_ops: Arc> = Arc::new(Mutex::new(pci_dev.clone())); pci_dev.realize().unwrap(); - let info = PciBus::find_attached_bus(&locked_pci_host.root_bus, "test1"); + let info = PciBus::find_attached_bus(&locked_pci_host.child_bus().unwrap(), "test1"); assert!(info.is_some()); let res = PciBus::detach_device(&bus, &dev_ops); assert!(res.is_ok()); - let info = PciBus::find_attached_bus(&locked_pci_host.root_bus, "test1"); + let info = PciBus::find_attached_bus(&locked_pci_host.child_bus().unwrap(), "test1"); assert!(info.is_none()); + + clean_pcidevops_type(); } } diff --git a/devices/src/pci/config.rs b/devices/src/pci/config.rs index 16c9b6a53c052aa8ceb65d0840daf0d0e151c4c8..db84bef81cf0ac2bc8b275839308861a1cc8d0c5 100644 --- a/devices/src/pci/config.rs +++ b/devices/src/pci/config.rs @@ -14,7 +14,7 @@ use std::collections::HashSet; use std::sync::{Arc, Mutex}; use anyhow::{anyhow, Context, Result}; -use log::{error, warn}; +use log::{error, info, warn}; use crate::pci::intx::Intx; use crate::pci::msix::{Msix, MSIX_TABLE_ENTRY_SIZE}; @@ -22,6 +22,7 @@ use crate::pci::{ le_read_u16, le_read_u32, le_read_u64, le_write_u16, le_write_u32, le_write_u64, pci_ext_cap_next, PciBus, PciError, BDF_FUNC_SHIFT, }; +use crate::{convert_bus_ref, Bus, PCI_BUS}; use address_space::Region; use util::num_ops::ranges_overlap; @@ -382,6 +383,8 @@ pub enum PcieDevType { /// Configuration space of PCI/PCIe device. #[derive(Clone)] pub struct PciConfig { + /// Device number and function number. + pub devfn: u8, /// Configuration space data. pub config: Vec, /// Mask of writable bits. @@ -411,7 +414,7 @@ impl PciConfig { /// /// * `config_size` - Configuration size in bytes. /// * `nr_bar` - Number of BARs. - pub fn new(config_size: usize, nr_bar: u8) -> Self { + pub fn new(devfn: u8, config_size: usize, nr_bar: u8) -> Self { let mut bars = Vec::new(); for _ in 0..nr_bar as usize { bars.push(Bar { @@ -425,15 +428,16 @@ impl PciConfig { } PciConfig { + devfn, config: vec![0; config_size], write_mask: vec![0; config_size], write_clear_mask: vec![0; config_size], bars, - last_cap_end: PCI_CONFIG_HEAD_END as u16, + last_cap_end: u16::from(PCI_CONFIG_HEAD_END), last_ext_cap_offset: 0, last_ext_cap_end: PCI_CONFIG_SPACE_SIZE as u16, msix: None, - pci_express_cap_offset: PCI_CONFIG_HEAD_END as u16, + pci_express_cap_offset: u16::from(PCI_CONFIG_HEAD_END), intx: None, } } @@ -733,7 +737,7 @@ impl PciConfig { return BAR_SPACE_UNMAPPED; } let bar_val = le_read_u32(&self.config, offset).unwrap(); - return (bar_val & IO_BASE_ADDR_MASK) as u64; + return u64::from(bar_val & IO_BASE_ADDR_MASK); } if command & COMMAND_MEMORY_SPACE == 0 { @@ -743,7 +747,7 @@ impl PciConfig { RegionType::Io => BAR_SPACE_UNMAPPED, RegionType::Mem32Bit => { let bar_val = le_read_u32(&self.config, offset).unwrap(); - (bar_val & MEM_BASE_ADDR_MASK as u32) as u64 + u64::from(bar_val & MEM_BASE_ADDR_MASK as u32) } RegionType::Mem64Bit => { let bar_val = le_read_u64(&self.config, offset).unwrap(); @@ -804,8 +808,8 @@ impl PciConfig { /// # Arguments /// /// * `bus` - The bus which region registered. - pub fn unregister_bars(&mut self, bus: &Arc>) -> Result<()> { - let locked_bus = bus.lock().unwrap(); + pub fn unregister_bars(&mut self, bus: &Arc>) -> Result<()> { + PCI_BUS!(bus, locked_bus, pci_bus); for bar in self.bars.iter_mut() { if bar.address == BAR_SPACE_UNMAPPED || bar.size == 0 { continue; @@ -815,7 +819,7 @@ impl PciConfig { { #[cfg(target_arch = "x86_64")] if let Some(region) = bar.region.as_ref() { - locked_bus + pci_bus .io_region .delete_subregion(region) .with_context(|| "Failed to unregister io bar")?; @@ -823,7 +827,7 @@ impl PciConfig { } _ => { if let Some(region) = bar.region.as_ref() { - locked_bus + pci_bus .mem_region .delete_subregion(region) .with_context(|| "Failed to unregister mem bar")?; @@ -904,7 +908,11 @@ impl PciConfig { } } - trace::pci_update_mappings_del(id, self.bars[id].address, self.bars[id].size); + info!( + "pci dev {} delete bar {} mapping: addr 0x{:X} size {}", + self.devfn, id, self.bars[id].address, self.bars[id].size + ); + self.bars[id].address = BAR_SPACE_UNMAPPED; } @@ -938,7 +946,11 @@ impl PciConfig { } } - trace::pci_update_mappings_add(id, self.bars[id].address, self.bars[id].size); + info!( + "pci dev {} update bar {} mapping: addr 0x{:X} size {}", + self.devfn, id, new_addr, self.bars[id].size + ); + self.bars[id].address = new_addr; } } @@ -1001,7 +1013,7 @@ impl PciConfig { le_write_u32( &mut self.config, offset, - id as u32 | (version << PCI_EXT_CAP_VER_SHIFT), + u32::from(id) | (version << PCI_EXT_CAP_VER_SHIFT), )?; if self.last_ext_cap_offset != 0 { let old_value = le_read_u32(&self.config, self.last_ext_cap_offset as usize)?; @@ -1028,7 +1040,7 @@ impl PciConfig { self.add_pci_cap(CapId::Pcie as u8, PCI_EXP_VER2_SIZEOF as usize)?; self.pci_express_cap_offset = cap_offset as u16; let mut offset: usize = cap_offset + PcieCap::CapReg as usize; - let pci_type = (dev_type << PCI_EXP_FLAGS_TYPE_SHIFT) as u16 & PCI_EXP_FLAGS_TYPE; + let pci_type = u16::from(dev_type << PCI_EXP_FLAGS_TYPE_SHIFT) & PCI_EXP_FLAGS_TYPE; le_write_u16( &mut self.config, offset, @@ -1053,7 +1065,7 @@ impl PciConfig { | PCI_EXP_LNKCAP_ASPMS_0S | PCI_EXP_LNKCAP_LBNC | PCI_EXP_LNKCAP_DLLLARC - | ((port_num as u32) << PCI_EXP_LNKCAP_PN_SHIFT), + | (u32::from(port_num) << PCI_EXP_LNKCAP_PN_SHIFT), )?; offset = cap_offset + PcieCap::LinkStat as usize; le_write_u16( @@ -1073,7 +1085,7 @@ impl PciConfig { | PCI_EXP_SLTCAP_PIP | PCI_EXP_SLTCAP_HPS | PCI_EXP_SLTCAP_HPC - | ((slot as u32) << PCI_EXP_SLTCAP_PSN_SHIFT), + | (u32::from(slot) << PCI_EXP_SLTCAP_PSN_SHIFT), )?; offset = cap_offset + PcieCap::SlotCtl as usize; le_write_u16( @@ -1172,8 +1184,8 @@ impl PciConfig { fn validate_bar_size(&self, bar_type: RegionType, size: u64) -> Result<()> { if !size.is_power_of_two() || (bar_type == RegionType::Io && size < MINIMUM_BAR_SIZE_FOR_PIO as u64) - || (bar_type == RegionType::Mem32Bit && size > u32::MAX as u64) - || (bar_type == RegionType::Io && size > u16::MAX as u64) + || (bar_type == RegionType::Mem32Bit && size > u64::from(u32::MAX)) + || (bar_type == RegionType::Io && size > u64::from(u16::MAX)) { return Err(anyhow!(PciError::InvalidConf( "Bar size of type ".to_string() + &bar_type.to_string(), @@ -1196,6 +1208,10 @@ impl PciConfig { let max_vector = table_len / MSIX_TABLE_ENTRY_SIZE as usize; vector_nr < max_vector as u32 } + + pub fn bus_maser_enable(&self) -> bool { + self.config[COMMAND as usize] as u16 & COMMAND_BUS_MASTER != 0 + } } #[cfg(test)] @@ -1208,7 +1224,7 @@ mod tests { #[test] fn test_find_pci_cap() { - let mut pci_config = PciConfig::new(PCI_CONFIG_SPACE_SIZE, 3); + let mut pci_config = PciConfig::new(0, PCI_CONFIG_SPACE_SIZE, 3); let offset = pci_config.find_pci_cap(MSIX_CAP_ID); assert_eq!(offset, 0xff); @@ -1238,7 +1254,7 @@ mod tests { write: Arc::new(write_ops), }; let region = Region::init_io_region(8192, region_ops.clone(), "io"); - let mut pci_config = PciConfig::new(PCI_CONFIG_SPACE_SIZE, 3); + let mut pci_config = PciConfig::new(0, PCI_CONFIG_SPACE_SIZE, 3); #[cfg(target_arch = "x86_64")] assert!(pci_config @@ -1264,7 +1280,7 @@ mod tests { le_write_u32( &mut pci_config.config, BAR_0 as usize, - IO_BASE_ADDR_MASK | BAR_IO_SPACE as u32, + IO_BASE_ADDR_MASK | u32::from(BAR_IO_SPACE), ) .unwrap(); le_write_u32( @@ -1276,7 +1292,7 @@ mod tests { le_write_u64( &mut pci_config.config, BAR_0 as usize + 2 * REG_SIZE, - MEM_BASE_ADDR_MASK | (BAR_MEM_64BIT | BAR_PREFETCH) as u64, + MEM_BASE_ADDR_MASK | u64::from(BAR_MEM_64BIT | BAR_PREFETCH), ) .unwrap(); @@ -1286,7 +1302,7 @@ mod tests { { // I/O space access is enabled. le_write_u16(&mut pci_config.config, COMMAND as usize, COMMAND_IO_SPACE).unwrap(); - assert_eq!(pci_config.get_bar_address(0), IO_BASE_ADDR_MASK as u64); + assert_eq!(pci_config.get_bar_address(0), u64::from(IO_BASE_ADDR_MASK)); } assert_eq!(pci_config.get_bar_address(1), BAR_SPACE_UNMAPPED); assert_eq!(pci_config.get_bar_address(2), BAR_SPACE_UNMAPPED); @@ -1301,7 +1317,7 @@ mod tests { assert_eq!(pci_config.get_bar_address(0), BAR_SPACE_UNMAPPED); assert_eq!( pci_config.get_bar_address(1), - (MEM_BASE_ADDR_MASK as u32) as u64 + u64::from(MEM_BASE_ADDR_MASK as u32) ); assert_eq!(pci_config.get_bar_address(2), MEM_BASE_ADDR_MASK); } @@ -1315,7 +1331,7 @@ mod tests { write: Arc::new(write_ops), }; let region = Region::init_io_region(8192, region_ops, "io"); - let mut pci_config = PciConfig::new(PCI_CONFIG_SPACE_SIZE, 6); + let mut pci_config = PciConfig::new(0, PCI_CONFIG_SPACE_SIZE, 6); #[cfg(target_arch = "x86_64")] assert!(pci_config @@ -1332,14 +1348,14 @@ mod tests { le_write_u32( &mut pci_config.config, BAR_0 as usize, - 2048_u32 | BAR_IO_SPACE as u32, + 2048_u32 | u32::from(BAR_IO_SPACE), ) .unwrap(); le_write_u32(&mut pci_config.config, BAR_0 as usize + REG_SIZE, 2048).unwrap(); le_write_u32( &mut pci_config.config, BAR_0 as usize + 2 * REG_SIZE, - 2048_u32 | BAR_MEM_64BIT as u32 | BAR_PREFETCH as u32, + 2048_u32 | u32::from(BAR_MEM_64BIT) | u32::from(BAR_PREFETCH), ) .unwrap(); le_write_u16( @@ -1389,14 +1405,14 @@ mod tests { le_write_u32( &mut pci_config.config, BAR_0 as usize, - 4096_u32 | BAR_IO_SPACE as u32, + 4096_u32 | u32::from(BAR_IO_SPACE), ) .unwrap(); le_write_u32(&mut pci_config.config, BAR_0 as usize + REG_SIZE, 4096).unwrap(); le_write_u32( &mut pci_config.config, BAR_0 as usize + 2 * REG_SIZE, - 4096_u32 | BAR_MEM_64BIT as u32 | BAR_PREFETCH as u32, + 4096_u32 | u32::from(BAR_MEM_64BIT) | u32::from(BAR_PREFETCH), ) .unwrap(); pci_config @@ -1412,7 +1428,7 @@ mod tests { #[test] fn test_add_pci_cap() { - let mut pci_config = PciConfig::new(PCI_CONFIG_SPACE_SIZE, 2); + let mut pci_config = PciConfig::new(0, PCI_CONFIG_SPACE_SIZE, 2); // Overflow. assert!(pci_config @@ -1424,12 +1440,12 @@ mod tests { // Capbility size is not multiple of DWORD. pci_config.add_pci_cap(0x12, 10).unwrap(); - assert_eq!(pci_config.last_cap_end, PCI_CONFIG_HEAD_END as u16 + 12); + assert_eq!(pci_config.last_cap_end, u16::from(PCI_CONFIG_HEAD_END) + 12); } #[test] fn test_add_pcie_ext_cap() { - let mut pci_config = PciConfig::new(PCIE_CONFIG_SPACE_SIZE, 2); + let mut pci_config = PciConfig::new(0, PCIE_CONFIG_SPACE_SIZE, 2); // Overflow. assert!(pci_config @@ -1450,7 +1466,7 @@ mod tests { #[test] fn test_get_ext_cap_size() { - let mut pcie_config = PciConfig::new(PCIE_CONFIG_SPACE_SIZE, 3); + let mut pcie_config = PciConfig::new(0, PCIE_CONFIG_SPACE_SIZE, 3); let offset1 = pcie_config.add_pcie_ext_cap(1, 0x10, 1).unwrap(); let offset2 = pcie_config.add_pcie_ext_cap(1, 0x40, 1).unwrap(); pcie_config.add_pcie_ext_cap(1, 0x20, 1).unwrap(); @@ -1463,7 +1479,7 @@ mod tests { #[test] fn test_reset_common_regs() { - let mut pcie_config = PciConfig::new(PCIE_CONFIG_SPACE_SIZE, 3); + let mut pcie_config = PciConfig::new(0, PCIE_CONFIG_SPACE_SIZE, 3); pcie_config.init_common_write_mask().unwrap(); pcie_config.init_common_write_clear_mask().unwrap(); @@ -1488,7 +1504,7 @@ mod tests { write: Arc::new(write_ops), }; let region = Region::init_io_region(4096, region_ops, "io"); - let mut pci_config = PciConfig::new(PCI_CONFIG_SPACE_SIZE, 3); + let mut pci_config = PciConfig::new(0, PCI_CONFIG_SPACE_SIZE, 3); // bar is unmapped #[cfg(target_arch = "x86_64")] @@ -1510,7 +1526,7 @@ mod tests { #[cfg(target_arch = "x86_64")] io_region.clone(), mem_region.clone(), - ))); + ))) as Arc>; assert!(pci_config.unregister_bars(&bus).is_ok()); @@ -1530,14 +1546,14 @@ mod tests { le_write_u32( &mut pci_config.config, BAR_0 as usize, - 2048 | BAR_IO_SPACE as u32, + 2048 | u32::from(BAR_IO_SPACE), ) .unwrap(); le_write_u32(&mut pci_config.config, BAR_0 as usize + REG_SIZE, 2048).unwrap(); le_write_u32( &mut pci_config.config, BAR_0 as usize + 2 * REG_SIZE, - 2048 | BAR_MEM_64BIT as u32 | BAR_PREFETCH as u32, + 2048 | u32::from(BAR_MEM_64BIT) | u32::from(BAR_PREFETCH), ) .unwrap(); le_write_u16( diff --git a/devices/src/pci/demo_device/dpy_device.rs b/devices/src/pci/demo_device/dpy_device.rs index bde922f680f1597bb1e7a5806e6c750cf06b68c3..8248b9dd3c78fc4302aa3f66671ff3f01d6e962d 100644 --- a/devices/src/pci/demo_device/dpy_device.rs +++ b/devices/src/pci/demo_device/dpy_device.rs @@ -28,7 +28,7 @@ use log::error; use once_cell::sync::Lazy; use super::DeviceTypeOperation; -use address_space::{AddressSpace, GuestAddress}; +use address_space::{AddressAttr, AddressSpace, GuestAddress}; use ui::{ console::{ register_display, DisplayChangeListener, DisplayChangeListenerOperations, DisplayMouse, @@ -128,8 +128,8 @@ impl DisplayChangeListenerOperations for DpyInterface { } let mut i = 0; - let mut offset = y * stride + x * bpp as i32 / 8; - let count = w * bpp as i32 / 8; + let mut offset = y * stride + x * i32::from(bpp) / 8; + let count = w * i32::from(bpp) / 8; while i < h { error!( "update from {} to {}, before is {}", @@ -227,6 +227,7 @@ impl DeviceTypeOperation for DemoDisplay { &mut buf.as_slice(), address_space::GuestAddress(mem_addr), buf.len() as u64, + AddressAttr::Ram, ); } diff --git a/devices/src/pci/demo_device/gpu_device.rs b/devices/src/pci/demo_device/gpu_device.rs index 7abef71259d06af6027fdbf86441ae9c0fadbf15..8b7c1775d7f7a24222f872ac521fd890ac9b5289 100644 --- a/devices/src/pci/demo_device/gpu_device.rs +++ b/devices/src/pci/demo_device/gpu_device.rs @@ -29,7 +29,7 @@ use byteorder::{ByteOrder, LittleEndian}; use log::info; use super::DeviceTypeOperation; -use address_space::{AddressSpace, GuestAddress}; +use address_space::{AddressAttr, AddressSpace, GuestAddress}; use ui::{ console::{ console_close, console_init, display_cursor_define, display_graphic_update, @@ -196,8 +196,12 @@ impl DeviceTypeOperation for DemoGpu { let mem_addr = LittleEndian::read_u64(data); // Event Type. let mut buf: Vec = vec![]; - self.sys_mem - .read(&mut buf, address_space::GuestAddress(mem_addr), 21)?; + self.sys_mem.read( + &mut buf, + address_space::GuestAddress(mem_addr), + 21, + AddressAttr::Ram, + )?; let event_type = GpuEvent::from(buf[0]); let x = LittleEndian::read_u32(&buf[1..5]); let y = LittleEndian::read_u32(&buf[5..9]); diff --git a/devices/src/pci/demo_device/kbd_pointer_device.rs b/devices/src/pci/demo_device/kbd_pointer_device.rs index ce55498718b6ea0d14744175a388890b74107d87..c4c036bade115e32b3a0d2cb80a7c09899f041c0 100644 --- a/devices/src/pci/demo_device/kbd_pointer_device.rs +++ b/devices/src/pci/demo_device/kbd_pointer_device.rs @@ -22,7 +22,7 @@ use byteorder::{ByteOrder, LittleEndian}; use once_cell::sync::Lazy; use super::DeviceTypeOperation; -use address_space::{AddressSpace, GuestAddress}; +use address_space::{AddressAttr, AddressSpace, GuestAddress}; use ui::input::{register_keyboard, register_pointer, Axis, InputType, KeyboardOpts, PointerOpts}; static MEM_ADDR: Lazy>> = Lazy::new(|| { @@ -51,12 +51,36 @@ impl MemSpace { bail!("No memory allocated!") } }; - sys_mem.write_object(&(msg.event_type as u8), address_space::GuestAddress(addr))?; - sys_mem.write_object(&msg.keycode, address_space::GuestAddress(addr + 1))?; - sys_mem.write_object(&msg.down, address_space::GuestAddress(addr + 3))?; - sys_mem.write_object(&msg.button, address_space::GuestAddress(addr + 4))?; - sys_mem.write_object(&msg.x, address_space::GuestAddress(addr + 8))?; - sys_mem.write_object(&msg.y, address_space::GuestAddress(addr + 12))?; + sys_mem.write_object( + &(msg.event_type as u8), + address_space::GuestAddress(addr), + AddressAttr::Ram, + )?; + sys_mem.write_object( + &msg.keycode, + address_space::GuestAddress(addr + 1), + AddressAttr::Ram, + )?; + sys_mem.write_object( + &msg.down, + address_space::GuestAddress(addr + 3), + AddressAttr::Ram, + )?; + sys_mem.write_object( + &msg.button, + address_space::GuestAddress(addr + 4), + AddressAttr::Ram, + )?; + sys_mem.write_object( + &msg.x, + address_space::GuestAddress(addr + 8), + AddressAttr::Ram, + )?; + sys_mem.write_object( + &msg.y, + address_space::GuestAddress(addr + 12), + AddressAttr::Ram, + )?; Ok(()) } @@ -94,7 +118,7 @@ impl KeyboardOpts for TestPciKbd { let msg = PointerMessage { event_type: InputEvent::KbdEvent, keycode, - down: down as u8, + down: u8::from(down), ..Default::default() }; MEM_ADDR.lock().unwrap().send_kbdmouse_message(&msg) diff --git a/devices/src/pci/demo_device/mod.rs b/devices/src/pci/demo_device/mod.rs index 778953667f7000cafd2367bd9fbf87cb51623e94..8f6fe60224058f93f1c2c91f63d8f2efac51ab03 100644 --- a/devices/src/pci/demo_device/mod.rs +++ b/devices/src/pci/demo_device/mod.rs @@ -33,31 +33,49 @@ pub mod dpy_device; pub mod gpu_device; pub mod kbd_pointer_device; -use std::{ - sync::Mutex, - sync::{ - atomic::{AtomicU16, Ordering}, - Arc, Weak, - }, -}; +use std::sync::atomic::{AtomicBool, AtomicU16, Ordering}; +use std::sync::{Arc, Mutex, Weak}; -use anyhow::{bail, Result}; +use anyhow::Result; +use clap::Parser; use log::error; -use crate::pci::demo_device::{ - dpy_device::DemoDisplay, gpu_device::DemoGpu, kbd_pointer_device::DemoKbdMouse, +use crate::pci::config::{ + PciConfig, RegionType, DEVICE_ID, HEADER_TYPE, HEADER_TYPE_ENDPOINT, PCIE_CONFIG_SPACE_SIZE, + SUB_CLASS_CODE, VENDOR_ID, }; -use crate::pci::{ - config::{ - PciConfig, RegionType, DEVICE_ID, HEADER_TYPE, HEADER_TYPE_ENDPOINT, - PCIE_CONFIG_SPACE_SIZE, SUB_CLASS_CODE, VENDOR_ID, - }, - init_msix, le_write_u16, PciBus, PciDevOps, +use crate::pci::demo_device::{ + base_device::BaseDevice, dpy_device::DemoDisplay, gpu_device::DemoGpu, + kbd_pointer_device::DemoKbdMouse, }; -use crate::pci::{demo_device::base_device::BaseDevice, PciDevBase}; -use crate::{Device, DeviceBase}; +use crate::pci::{init_msix, le_write_u16, PciBus, PciDevBase, PciDevOps}; +use crate::{convert_bus_ref, Bus, Device, DeviceBase, PCI_BUS}; use address_space::{AddressSpace, GuestAddress, Region, RegionOps}; -use machine_manager::config::DemoDevConfig; +use machine_manager::config::{get_pci_df, valid_id}; +use util::gen_base_func; + +/// Config struct for `demo_dev`. +/// Contains demo_dev device's attr. +#[derive(Parser, Debug, Clone)] +#[command(no_binary_name(true))] +pub struct DemoDevConfig { + #[arg(long, value_parser = ["pcie-demo-dev"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] + pub id: String, + #[arg(long)] + pub bus: String, + #[arg(long, value_parser = get_pci_df)] + pub addr: (u8, u8), + // Different device implementations can be configured based on this parameter + #[arg(long, alias = "device_type")] + pub device_type: Option, + #[arg(long, alias = "bar_num", default_value = "0")] + pub bar_num: u8, + // Every bar has the same size just for simplification. + #[arg(long, alias = "bar_size", default_value = "0")] + pub bar_size: u64, +} pub struct DemoDev { base: PciDevBase, @@ -71,25 +89,26 @@ impl DemoDev { pub fn new( cfg: DemoDevConfig, devfn: u8, - _sys_mem: Arc, - parent_bus: Weak>, + sys_mem: Arc, + parent_bus: Weak>, ) -> Self { // You can choose different device function based on the parameter of device_type. - let device: Arc> = match cfg.device_type.as_str() { - "demo-gpu" => Arc::new(Mutex::new(DemoGpu::new(_sys_mem, cfg.id.clone()))), - "demo-input" => Arc::new(Mutex::new(DemoKbdMouse::new(_sys_mem))), - "demo-display" => Arc::new(Mutex::new(DemoDisplay::new(_sys_mem))), + let device_type = cfg.device_type.clone().unwrap_or_default(); + let device: Arc> = match device_type.as_str() { + "demo-gpu" => Arc::new(Mutex::new(DemoGpu::new(sys_mem, cfg.id.clone()))), + "demo-input" => Arc::new(Mutex::new(DemoKbdMouse::new(sys_mem))), + "demo-display" => Arc::new(Mutex::new(DemoDisplay::new(sys_mem))), _ => Arc::new(Mutex::new(BaseDevice::new())), }; DemoDev { base: PciDevBase { - base: DeviceBase::new(cfg.id.clone(), false), - config: PciConfig::new(PCIE_CONFIG_SPACE_SIZE, cfg.bar_num), + base: DeviceBase::new(cfg.id.clone(), false, Some(parent_bus)), + config: PciConfig::new(devfn, PCIE_CONFIG_SPACE_SIZE, cfg.bar_num), devfn, - parent_bus, + bme: Arc::new(AtomicBool::new(false)), }, cmd_cfg: cfg, - mem_region: Region::init_container_region(u32::MAX as u64, "DemoDev"), + mem_region: Region::init_container_region(u64::from(u32::MAX), "DemoDev"), dev_id: Arc::new(AtomicU16::new(0)), device, } @@ -108,19 +127,6 @@ impl DemoDev { Ok(()) } - fn attach_to_parent_bus(self) -> Result<()> { - let parent_bus = self.base.parent_bus.upgrade().unwrap(); - let mut locked_parent_bus = parent_bus.lock().unwrap(); - if locked_parent_bus.devices.get(&self.base.devfn).is_some() { - bail!("device already existed"); - } - let devfn = self.base.devfn; - let demo_pci_dev = Arc::new(Mutex::new(self)); - locked_parent_bus.devices.insert(devfn, demo_pci_dev); - - Ok(()) - } - fn register_data_handling_bar(&mut self) -> Result<()> { let device = self.device.clone(); let write_ops = move |data: &[u8], addr: GuestAddress, offset: u64| -> bool { @@ -155,7 +161,7 @@ impl DemoDev { self.mem_region.clone(), RegionType::Mem64Bit, false, - (self.cmd_cfg.bar_size * self.cmd_cfg.bar_num as u64).next_power_of_two(), + (self.cmd_cfg.bar_size * u64::from(self.cmd_cfg.bar_num)).next_power_of_two(), )?; Ok(()) @@ -170,26 +176,13 @@ const DEVICE_ID_DEMO: u16 = 0xBEEF; const CLASS_CODE_DEMO: u16 = 0xEE; impl Device for DemoDev { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } - - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base - } -} - -impl PciDevOps for DemoDev { - fn pci_base(&self) -> &PciDevBase { - &self.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); - fn pci_base_mut(&mut self) -> &mut PciDevBase { - &mut self.base + fn reset(&mut self, _reset_child_device: bool) -> Result<()> { + self.base.config.reset_common_regs() } - /// Realize PCI/PCIe device. - fn realize(mut self) -> Result<()> { + fn realize(mut self) -> Result>> { self.init_pci_config()?; if self.cmd_cfg.bar_num > 0 { init_msix(&mut self.base, 0, 1, self.dev_id.clone(), None, None)?; @@ -198,19 +191,27 @@ impl PciDevOps for DemoDev { self.register_data_handling_bar()?; self.device.lock().unwrap().realize()?; - self.attach_to_parent_bus()?; - Ok(()) + let devfn = u64::from(self.base.devfn); + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + let mut locked_bus = parent_bus.lock().unwrap(); + let demo_pci_dev = Arc::new(Mutex::new(self)); + locked_bus.attach_child(devfn, demo_pci_dev.clone())?; + + Ok(demo_pci_dev) } - /// Unrealize PCI/PCIe device. fn unrealize(&mut self) -> Result<()> { self.device.lock().unwrap().unrealize() } +} + +impl PciDevOps for DemoDev { + gen_base_func!(pci_base, pci_base_mut, PciDevBase, base); /// write the pci configuration space fn write_config(&mut self, offset: usize, data: &[u8]) { - let parent_bus = self.base.parent_bus.upgrade().unwrap(); - let parent_bus_locked = parent_bus.lock().unwrap(); + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + PCI_BUS!(parent_bus, locked_bus, pci_bus); self.base.config.write( offset, @@ -218,14 +219,9 @@ impl PciDevOps for DemoDev { self.dev_id.load(Ordering::Acquire), #[cfg(target_arch = "x86_64")] None, - Some(&parent_bus_locked.mem_region), + Some(&pci_bus.mem_region), ); } - - /// Reset device - fn reset(&mut self, _reset_child_device: bool) -> Result<()> { - self.base.config.reset_common_regs() - } } pub trait DeviceTypeOperation: Send { @@ -234,3 +230,29 @@ pub trait DeviceTypeOperation: Send { fn realize(&mut self) -> Result<()>; fn unrealize(&mut self) -> Result<()>; } + +#[cfg(test)] +mod tests { + use super::*; + use machine_manager::config::str_slip_to_clap; + #[test] + fn test_parse_demo_dev() { + // Test1: Right. + let demo_cmd1 = "pcie-demo-dev,bus=pcie.0,addr=0x4,id=test_0,device_type=demo-gpu,bar_num=3,bar_size=4096"; + let result = DemoDevConfig::try_parse_from(str_slip_to_clap(demo_cmd1, true, false)); + assert!(result.is_ok()); + let demo_cfg = result.unwrap(); + assert_eq!(demo_cfg.id, "test_0".to_string()); + assert_eq!(demo_cfg.device_type, Some("demo-gpu".to_string())); + assert_eq!(demo_cfg.bar_num, 3); + assert_eq!(demo_cfg.bar_size, 4096); + + // Test2: Default bar_num/bar_size. + let demo_cmd2 = "pcie-demo-dev,bus=pcie.0,addr=4.0,id=test_0,device_type=demo-gpu"; + let result = DemoDevConfig::try_parse_from(str_slip_to_clap(demo_cmd2, true, false)); + assert!(result.is_ok()); + let demo_cfg = result.unwrap(); + assert_eq!(demo_cfg.bar_num, 0); + assert_eq!(demo_cfg.bar_size, 0); + } +} diff --git a/devices/src/pci/host.rs b/devices/src/pci/host.rs index b6a44dabda1aaa7bcdcb02108b27a53fade9deec..7f3b3d78f6be07e974ff727bc5609e40f88d717e 100644 --- a/devices/src/pci/host.rs +++ b/devices/src/pci/host.rs @@ -16,11 +16,11 @@ use anyhow::{Context, Result}; #[cfg(target_arch = "aarch64")] use crate::pci::PCI_INTR_BASE; -use crate::pci::{bus::PciBus, PciDevOps, PCI_PIN_NUM, PCI_SLOT_MAX}; +use crate::pci::{bus::PciBus, to_pcidevops, PCI_PIN_NUM, PCI_SLOT_MAX}; #[cfg(target_arch = "x86_64")] use crate::pci::{le_read_u32, le_write_u32}; use crate::sysbus::{SysBusDevBase, SysBusDevOps}; -use crate::{Device, DeviceBase}; +use crate::{Device, DeviceBase, PCI_BUS_DEVICE}; use acpi::{ AmlActiveLevel, AmlAddressSpaceDecode, AmlAnd, AmlArg, AmlBuilder, AmlCacheable, AmlCreateDWordField, AmlDWord, AmlDWordDesc, AmlDevice, AmlEdgeLevel, AmlEisaId, AmlElse, @@ -34,6 +34,7 @@ use acpi::{AmlIoDecode, AmlIoResource}; #[cfg(target_arch = "aarch64")] use acpi::{AmlOne, AmlQWordDesc}; use address_space::{AddressSpace, GuestAddress, RegionOps}; +use util::gen_base_func; #[cfg(target_arch = "x86_64")] const CONFIG_ADDRESS_ENABLE_MASK: u32 = 0x8000_0000; @@ -53,7 +54,6 @@ const ECAM_OFFSET_MASK: u64 = 0xfff; #[derive(Clone)] pub struct PciHost { base: SysBusDevBase, - pub root_bus: Arc>, #[cfg(target_arch = "x86_64")] config_addr: u32, pcie_ecam_range: (u64, u64), @@ -95,9 +95,10 @@ impl PciHost { io_region, mem_region, ); + let mut base = SysBusDevBase::default(); + base.base.child = Some(Arc::new(Mutex::new(root_bus))); PciHost { - base: SysBusDevBase::default(), - root_bus: Arc::new(Mutex::new(root_bus)), + base, #[cfg(target_arch = "x86_64")] config_addr: 0, pcie_ecam_range, @@ -110,14 +111,21 @@ impl PciHost { } } - pub fn find_device(&self, bus_num: u8, devfn: u8) -> Option>> { - let locked_root_bus = self.root_bus.lock().unwrap(); + pub fn find_device(&self, bus_num: u8, devfn: u8) -> Option>> { + let root_bus = self.child_bus().unwrap(); + let locked_root_bus = root_bus.lock().unwrap(); if bus_num == 0 { - return locked_root_bus.get_device(0, devfn); + let dev = locked_root_bus.child_dev(u64::from(devfn))?; + return Some(dev.clone()); } - for bus in &locked_root_bus.child_buses { - if let Some(b) = PciBus::find_bus_by_num(bus, bus_num) { - return b.lock().unwrap().get_device(bus_num, devfn); + + for dev in locked_root_bus.child_devices().values() { + let child_bus = dev.lock().unwrap().child_bus(); + if let Some(bus) = child_bus { + if let Some(b) = PciBus::find_bus_by_num(&bus, bus_num) { + let dev = b.lock().unwrap().child_dev(u64::from(devfn))?.clone(); + return Some(dev); + } } } None @@ -196,7 +204,8 @@ impl PciHost { match locked_hb.find_device(bus_num, devfn) { Some(dev) => { offset &= PIO_OFFSET_MASK; - dev.lock().unwrap().read_config(offset as usize, data); + PCI_BUS_DEVICE!(dev, locked_dev, pci_dev); + pci_dev.read_config(offset as usize, data); } None => { for d in data.iter_mut() { @@ -218,7 +227,8 @@ impl PciHost { let devfn = ((offset >> PIO_DEVFN_SHIFT) & CONFIG_DEVFN_MASK) as u8; if let Some(dev) = locked_hb.find_device(bus_num, devfn) { offset &= PIO_OFFSET_MASK; - dev.lock().unwrap().write_config(offset as usize, data); + PCI_BUS_DEVICE!(dev, locked_dev, pci_dev); + pci_dev.write_config(offset as usize, data); } true }; @@ -231,23 +241,23 @@ impl PciHost { } impl Device for PciHost { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base + fn reset(&mut self, _reset_child_device: bool) -> Result<()> { + let root_bus = self.child_bus().unwrap(); + for dev in root_bus.lock().unwrap().child_devices().values() { + PCI_BUS_DEVICE!(dev, locked_dev, pci_dev); + pci_dev + .reset(true) + .with_context(|| "Fail to reset pci device under pci host")?; + } + + Ok(()) } } impl SysBusDevOps for PciHost { - fn sysbusdev_base(&self) -> &SysBusDevBase { - &self.base - } - - fn sysbusdev_base_mut(&mut self) -> &mut SysBusDevBase { - &mut self.base - } + gen_base_func!(sysbusdev_base, sysbusdev_base_mut, SysBusDevBase, base); fn read(&mut self, data: &mut [u8], _base: GuestAddress, offset: u64) -> bool { let bus_num = ((offset as u32 >> ECAM_BUS_SHIFT) & CONFIG_BUS_MASK) as u8; @@ -255,9 +265,9 @@ impl SysBusDevOps for PciHost { match self.find_device(bus_num, devfn) { Some(dev) => { let addr: usize = (offset & ECAM_OFFSET_MASK) as usize; - let dev_name = &dev.lock().unwrap().pci_base().base.id.clone(); - trace::pci_read_config(dev_name, addr, data); - dev.lock().unwrap().read_config(addr, data); + PCI_BUS_DEVICE!(dev, locked_dev, pci_dev); + trace::pci_read_config(&pci_dev.name(), addr, data); + pci_dev.read_config(addr, data); } None => { for d in data.iter_mut() { @@ -274,26 +284,14 @@ impl SysBusDevOps for PciHost { match self.find_device(bus_num, devfn) { Some(dev) => { let addr: usize = (offset & ECAM_OFFSET_MASK) as usize; - let dev_name = &dev.lock().unwrap().pci_base().base.id.clone(); - trace::pci_write_config(dev_name, addr, data); - dev.lock().unwrap().write_config(addr, data); + PCI_BUS_DEVICE!(dev, locked_dev, pci_dev); + trace::pci_write_config(&pci_dev.name(), addr, data); + pci_dev.write_config(addr, data); true } None => true, } } - - fn reset(&mut self) -> Result<()> { - for (_id, pci_dev) in self.root_bus.lock().unwrap().devices.iter_mut() { - pci_dev - .lock() - .unwrap() - .reset(true) - .with_context(|| "Fail to reset pci device under pci host")?; - } - - Ok(()) - } } #[cfg(target_arch = "x86_64")] @@ -398,8 +396,8 @@ fn build_prt_for_aml(pci_bus: &mut AmlDevice, irq: i32) { (0..PCI_PIN_NUM).for_each(|pin| { let gsi = (pin + slot) % PCI_PIN_NUM; let mut pkg = AmlPackage::new(4); - pkg.append_child(AmlDWord((slot as u32) << 16 | 0xFFFF)); - pkg.append_child(AmlDWord(pin as u32)); + pkg.append_child(AmlDWord(u32::from(slot) << 16 | 0xFFFF)); + pkg.append_child(AmlDWord(u32::from(pin))); pkg.append_child(AmlName(format!("GSI{}", gsi))); pkg.append_child(AmlZero); prt_pkg.append_child(pkg); @@ -421,7 +419,7 @@ fn build_prt_for_aml(pci_bus: &mut AmlDevice, irq: i32) { AmlEdgeLevel::Level, AmlActiveLevel::High, AmlIntShare::Exclusive, - vec![irqs as u32], + vec![u32::from(irqs)], )); gsi.append_child(AmlNameDecl::new("_PRS", crs)); let mut crs = AmlResTemplate::new(); @@ -430,7 +428,7 @@ fn build_prt_for_aml(pci_bus: &mut AmlDevice, irq: i32) { AmlEdgeLevel::Level, AmlActiveLevel::High, AmlIntShare::Exclusive, - vec![irqs as u32], + vec![u32::from(irqs)], )); gsi.append_child(AmlNameDecl::new("_CRS", crs)); let method = AmlMethod::new("_SRS", 1, false); @@ -542,87 +540,17 @@ impl AmlBuilder for PciHost { #[cfg(test)] pub mod tests { + #[cfg(target_arch = "x86_64")] use byteorder::{ByteOrder, LittleEndian}; use super::*; use crate::pci::bus::PciBus; - use crate::pci::config::{PciConfig, PCI_CONFIG_SPACE_SIZE, SECONDARY_BUS_NUM}; - use crate::pci::root_port::RootPort; - use crate::pci::{PciDevBase, Result}; - use crate::{Device, DeviceBase}; + use crate::pci::config::SECONDARY_BUS_NUM; + use crate::pci::root_port::{RootPort, RootPortConfig}; + use crate::pci::tests::TestPciDevice; + use crate::pci::{clean_pcidevops_type, register_pcidevops_type, PciDevOps}; use address_space::Region; - struct PciDevice { - base: PciDevBase, - } - - impl Device for PciDevice { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } - - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base - } - } - - impl PciDevOps for PciDevice { - fn pci_base(&self) -> &PciDevBase { - &self.base - } - - fn pci_base_mut(&mut self) -> &mut PciDevBase { - &mut self.base - } - - fn init_write_mask(&mut self, _is_bridge: bool) -> Result<()> { - let mut offset = 0_usize; - while offset < self.base.config.config.len() { - LittleEndian::write_u32( - &mut self.base.config.write_mask[offset..offset + 4], - 0xffff_ffff, - ); - offset += 4; - } - Ok(()) - } - - fn init_write_clear_mask(&mut self, _is_bridge: bool) -> Result<()> { - Ok(()) - } - - fn write_config(&mut self, offset: usize, data: &[u8]) { - #[allow(unused_variables)] - self.base.config.write( - offset, - data, - 0, - #[cfg(target_arch = "x86_64")] - None, - None, - ); - } - - fn realize(mut self) -> Result<()> { - let devfn = self.base.devfn; - self.init_write_mask(false)?; - self.init_write_clear_mask(false)?; - - let dev = Arc::new(Mutex::new(self)); - dev.lock() - .unwrap() - .base - .parent_bus - .upgrade() - .unwrap() - .lock() - .unwrap() - .devices - .insert(devfn, dev.clone()); - Ok(()) - } - } - pub fn create_pci_host() -> Arc> { #[cfg(target_arch = "x86_64")] let sys_io = AddressSpace::new( @@ -654,11 +582,18 @@ pub mod tests { #[test] #[cfg(target_arch = "x86_64")] fn test_pio_ops() { + register_pcidevops_type::().unwrap(); + let pci_host = create_pci_host(); - let root_bus = Arc::downgrade(&pci_host.lock().unwrap().root_bus); + let root_bus = Arc::downgrade(&pci_host.lock().unwrap().child_bus().unwrap()); let pio_addr_ops = PciHost::build_pio_addr_ops(pci_host.clone()); let pio_data_ops = PciHost::build_pio_data_ops(pci_host.clone()); - let root_port = RootPort::new("pcie.1".to_string(), 8, 0, root_bus, false); + let root_port_config = RootPortConfig { + addr: (1, 0), + id: "pcie.1".to_string(), + ..Default::default() + }; + let root_port = RootPort::new(root_port_config, root_bus.clone()); root_port.realize().unwrap(); let mut data = [0_u8; 4]; @@ -675,7 +610,6 @@ pub mod tests { assert_eq!(buf, data); // Non-DWORD access on CONFIG_ADDR - let mut config = [0_u8; 4]; (pio_addr_ops.read)(&mut config, GuestAddress(0), 0); let data = [0x12, 0x34]; @@ -727,39 +661,48 @@ pub mod tests { let mut buf = [0_u8; 4]; (pio_data_ops.read)(&mut buf, GuestAddress(0), 0); assert_eq!(buf, [0xff_u8; 4]); + + clean_pcidevops_type(); } #[test] fn test_mmio_ops() { + register_pcidevops_type::().unwrap(); + register_pcidevops_type::().unwrap(); + let pci_host = create_pci_host(); - let root_bus = Arc::downgrade(&pci_host.lock().unwrap().root_bus); + let root_bus = pci_host.lock().unwrap().child_bus().unwrap(); + let weak_root_bus = Arc::downgrade(&root_bus); let mmconfig_region_ops = PciHost::build_mmconfig_ops(pci_host.clone()); - let mut root_port = RootPort::new("pcie.1".to_string(), 8, 0, root_bus.clone(), false); + let root_port_config = RootPortConfig { + addr: (1, 0), + id: "pcie.1".to_string(), + ..Default::default() + }; + let mut root_port = RootPort::new(root_port_config, weak_root_bus.clone()); root_port.write_config(SECONDARY_BUS_NUM as usize, &[1]); root_port.realize().unwrap(); - let mut root_port = RootPort::new("pcie.2".to_string(), 16, 0, root_bus, false); + let root_port_config = RootPortConfig { + addr: (2, 0), + id: "pcie.2".to_string(), + ..Default::default() + }; + let mut root_port = RootPort::new(root_port_config, weak_root_bus); root_port.write_config(SECONDARY_BUS_NUM as usize, &[2]); root_port.realize().unwrap(); - let bus = PciBus::find_bus_by_name(&pci_host.lock().unwrap().root_bus, "pcie.2").unwrap(); - let pci_dev = PciDevice { - base: PciDevBase { - base: DeviceBase::new("PCI device".to_string(), false), - config: PciConfig::new(PCI_CONFIG_SPACE_SIZE, 0), - devfn: 8, - parent_bus: Arc::downgrade(&bus), - }, - }; + let bus = PciBus::find_bus_by_name(&root_bus, "pcie.2").unwrap(); + let pci_dev = TestPciDevice::new("PCI device", 8, Arc::downgrade(&bus)); pci_dev.realize().unwrap(); - let addr: u64 = 8_u64 << ECAM_DEVFN_SHIFT | SECONDARY_BUS_NUM as u64; + let addr: u64 = 8_u64 << ECAM_DEVFN_SHIFT | u64::from(SECONDARY_BUS_NUM); let data = [1_u8]; (mmconfig_region_ops.write)(&data, GuestAddress(0), addr); let mut buf = [0_u8]; (mmconfig_region_ops.read)(&mut buf, GuestAddress(0), addr); assert_eq!(buf, data); - let addr: u64 = 16_u64 << ECAM_DEVFN_SHIFT | SECONDARY_BUS_NUM as u64; + let addr: u64 = 16_u64 << ECAM_DEVFN_SHIFT | u64::from(SECONDARY_BUS_NUM); let data = [2_u8]; (mmconfig_region_ops.write)(&data, GuestAddress(0), addr); let mut buf = [0_u8]; @@ -778,5 +721,7 @@ pub mod tests { let mut buf = [0_u8; 2]; (mmconfig_region_ops.read)(&mut buf, GuestAddress(0), addr); assert_eq!(buf, data); + + clean_pcidevops_type(); } } diff --git a/devices/src/pci/hotplug.rs b/devices/src/pci/hotplug.rs index df5bfca390ad1c5abbe295b271fd2a50ccc62579..7f13de932233210c61b7b8b063d4a0368b66b677 100644 --- a/devices/src/pci/hotplug.rs +++ b/devices/src/pci/hotplug.rs @@ -14,18 +14,19 @@ use std::sync::{Arc, Mutex}; use anyhow::{bail, Context, Result}; -use crate::pci::{PciBus, PciDevOps}; +use crate::pci::PciBus; +use crate::{convert_bus_ref, Bus, Device, PCI_BUS}; pub trait HotplugOps: Send { /// Plug device, usually called when hot plug device in device_add. - fn plug(&mut self, dev: &Arc>) -> Result<()>; + fn plug(&mut self, dev: &Arc>) -> Result<()>; /// Unplug device request, usually called when hot unplug device in device_del. /// Only send unplug request to the guest OS, without actually removing the device. - fn unplug_request(&mut self, dev: &Arc>) -> Result<()>; + fn unplug_request(&mut self, dev: &Arc>) -> Result<()>; /// Remove the device. - fn unplug(&mut self, dev: &Arc>) -> Result<()>; + fn unplug(&mut self, dev: &Arc>) -> Result<()>; } /// Plug the device into the bus. @@ -40,14 +41,14 @@ pub trait HotplugOps: Send { /// Return Error if /// * No hot plug controller found. /// * Device plug failed. -pub fn handle_plug(bus: &Arc>, dev: &Arc>) -> Result<()> { - let locked_bus = bus.lock().unwrap(); - if let Some(hpc) = locked_bus.hotplug_controller.as_ref() { +pub fn handle_plug(bus: &Arc>, dev: &Arc>) -> Result<()> { + PCI_BUS!(bus, locked_bus, pci_bus); + if let Some(hpc) = pci_bus.hotplug_controller.as_ref() { hpc.upgrade().unwrap().lock().unwrap().plug(dev) } else { bail!( "No hot plug controller found for bus {} when plug", - locked_bus.name + pci_bus.name() ); } } @@ -65,18 +66,18 @@ pub fn handle_plug(bus: &Arc>, dev: &Arc>) -> /// * No hot plug controller found. /// * Device unplug request failed. pub fn handle_unplug_pci_request( - bus: &Arc>, - dev: &Arc>, + bus: &Arc>, + dev: &Arc>, ) -> Result<()> { - let locked_bus = bus.lock().unwrap(); - let hpc = locked_bus + PCI_BUS!(bus, locked_bus, pci_bus); + let hpc = pci_bus .hotplug_controller .as_ref() .cloned() .with_context(|| { format!( "No hot plug controller found for bus {} when unplug request", - locked_bus.name + pci_bus.name() ) })?; // No need to hold the lock. diff --git a/devices/src/pci/intx.rs b/devices/src/pci/intx.rs index b0624fe02c64fe39c755674dcfa4af6aa25fdaf8..9b62e0bda81d03fa2c61ff9738b20688fb6b6ef0 100644 --- a/devices/src/pci/intx.rs +++ b/devices/src/pci/intx.rs @@ -15,8 +15,10 @@ use std::sync::{Arc, Mutex, Weak}; use anyhow::Result; use log::error; +use super::{PciDevOps, RootPort}; use crate::interrupt_controller::LineIrqManager; use crate::pci::{swizzle_map_irq, PciBus, PciConfig, INTERRUPT_PIN, PCI_PIN_NUM}; +use crate::{convert_bus_ref, convert_device_ref, Bus, PCI_BUS, ROOT_PORT}; pub type InterruptHandler = Box Result<()> + Send + Sync>; @@ -119,7 +121,7 @@ impl Intx { pub fn init_intx( name: String, config: &mut PciConfig, - parent_bus: Weak>, + parent_bus: Weak>, devfn: u8, ) -> Result<()> { if config.config[INTERRUPT_PIN as usize] == 0 { @@ -129,24 +131,24 @@ pub fn init_intx( return Ok(()); } - let (irq, intx_state) = if let Some(pci_bus) = parent_bus.upgrade() { - let locked_pci_bus = pci_bus.lock().unwrap(); + let (irq, intx_state) = if let Some(bus) = parent_bus.upgrade() { + PCI_BUS!(bus, locked_bus, pci_bus); let pin = config.config[INTERRUPT_PIN as usize] - 1; - let (irq, intx_state) = match &locked_pci_bus.parent_bridge { + let (irq, intx_state) = match &pci_bus.parent_device() { Some(parent_bridge) => { let parent_bridge = parent_bridge.upgrade().unwrap(); - let locked_parent_bridge = parent_bridge.lock().unwrap(); + ROOT_PORT!(parent_bridge, locked_bridge, bridge); ( - swizzle_map_irq(locked_parent_bridge.pci_base().devfn, pin), - locked_parent_bridge.get_intx_state(), + swizzle_map_irq(bridge.pci_base().devfn, pin), + bridge.get_intx_state(), ) } None => { - if locked_pci_bus.intx_state.is_some() { + if pci_bus.intx_state.is_some() { ( swizzle_map_irq(devfn, pin), - Some(locked_pci_bus.intx_state.as_ref().unwrap().clone()), + Some(pci_bus.intx_state.as_ref().unwrap().clone()), ) } else { (std::u32::MAX, None) diff --git a/devices/src/pci/mod.rs b/devices/src/pci/mod.rs index 0c8c0af3ee314be2a4b0bcec5211333ff67e10af..2866b3c440a43740fb7d70d367f8fa8b994dc4a5 100644 --- a/devices/src/pci/mod.rs +++ b/devices/src/pci/mod.rs @@ -28,22 +28,28 @@ pub use error::PciError; pub use host::PciHost; pub use intx::{init_intx, InterruptHandler, PciIntxState}; pub use msix::{init_msix, MsiVector}; -pub use root_port::RootPort; +pub use root_port::{RootPort, RootPortConfig}; -use std::{ - mem::size_of, - sync::{Arc, Mutex, Weak}, -}; +use std::any::{Any, TypeId}; +use std::collections::HashMap; +use std::mem::size_of; +use std::sync::atomic::AtomicBool; +use std::sync::{Arc, Mutex, Weak}; use anyhow::{bail, Result}; use byteorder::{ByteOrder, LittleEndian}; +#[cfg(feature = "scream")] +use crate::misc::ivshmem::Ivshmem; +#[cfg(feature = "pvpanic")] +use crate::misc::pvpanic::PvPanicPci; +use crate::pci::config::{HEADER_TYPE, HEADER_TYPE_MULTIFUNC, MAX_FUNC}; +use crate::usb::xhci::xhci_pci::XhciPciDevice; use crate::{ - pci::config::{HEADER_TYPE, HEADER_TYPE_MULTIFUNC, MAX_FUNC}, - MsiIrqManager, + convert_bus_ref, convert_device_ref, Bus, Device, DeviceBase, MsiIrqManager, PCI_BUS, ROOT_PORT, }; -use crate::{Device, DeviceBase}; -use util::AsAny; +#[cfg(feature = "demo_device")] +use demo_device::DemoDev; const BDF_FUNC_SHIFT: u8 = 3; pub const PCI_SLOT_MAX: u8 = 32; @@ -138,11 +144,11 @@ pub struct PciDevBase { pub config: PciConfig, /// Devfn. pub devfn: u8, - /// Primary Bus. - pub parent_bus: Weak>, + /// Bus master enable. + pub bme: Arc, } -pub trait PciDevOps: Device + Send + AsAny { +pub trait PciDevOps: Device + Send { /// Get base property of pci device. fn pci_base(&self) -> &PciDevBase; @@ -169,14 +175,6 @@ pub trait PciDevOps: Device + Send + AsAny { Ok(()) } - /// Realize PCI/PCIe device. - fn realize(self) -> Result<()>; - - /// Unrealize PCI/PCIe device. - fn unrealize(&mut self) -> Result<()> { - bail!("Unrealize of the pci device is not implemented"); - } - /// Configuration space read. /// /// # Arguments @@ -207,35 +205,23 @@ pub trait PciDevOps: Device + Send + AsAny { /// Device id to send MSI/MSI-X. fn set_dev_id(&self, bus_num: u8, devfn: u8) -> u16 { let bus_shift: u16 = 8; - ((bus_num as u16) << bus_shift) | (devfn as u16) - } - - /// Reset device - fn reset(&mut self, _reset_child_device: bool) -> Result<()> { - Ok(()) + (u16::from(bus_num) << bus_shift) | u16::from(devfn) } /// Get the path of the PCI bus where the device resides. - fn get_parent_dev_path(&self, parent_bus: Arc>) -> String { - let locked_parent_bus = parent_bus.lock().unwrap(); - let parent_dev_path = if locked_parent_bus.name.eq("pcie.0") { + fn get_parent_dev_path(&self, parent_bus: Arc>) -> String { + PCI_BUS!(parent_bus, locked_bus, pci_bus); + + if pci_bus.name().eq("pcie.0") { String::from("/pci@ffffffffffffffff") } else { // This else branch will not be executed currently, // which is mainly to be compatible with new PCI bridge devices. // unwrap is safe because pci bus under root port will not return null. - locked_parent_bus - .parent_bridge - .as_ref() - .unwrap() - .upgrade() - .unwrap() - .lock() - .unwrap() - .get_dev_path() - .unwrap() - }; - parent_dev_path + let parent_bridge = pci_bus.parent_device().unwrap().upgrade().unwrap(); + ROOT_PORT!(parent_bridge, locked_bridge, rootport); + rootport.get_dev_path().unwrap() + } } /// Fill the device path according to parent device path and device function. @@ -270,6 +256,72 @@ pub trait PciDevOps: Device + Send + AsAny { } } +pub type ToPciDevOpsFunc = fn(&mut dyn Any) -> &mut dyn PciDevOps; + +static mut PCIDEVOPS_HASHMAP: Option> = None; + +pub fn convert_to_pcidevops(item: &mut dyn Any) -> &mut dyn PciDevOps { + // SAFETY: The typeid of `T` is the typeid recorded in the hashmap. The target structure type of + // the conversion is its own structure type, so the conversion result will definitely not be `None`. + let t = item.downcast_mut::().unwrap(); + t as &mut dyn PciDevOps +} + +pub fn register_pcidevops_type() -> Result<()> { + let type_id = TypeId::of::(); + // SAFETY: PCIDEVOPS_HASHMAP will be built in `type_init` function sequentially in the main thread. + // And will not be changed after `type_init`. + unsafe { + if PCIDEVOPS_HASHMAP.is_none() { + PCIDEVOPS_HASHMAP = Some(HashMap::new()); + } + let types = PCIDEVOPS_HASHMAP.as_mut().unwrap(); + if types.get(&type_id).is_some() { + bail!("Type Id {:?} has been registered.", type_id); + } + types.insert(type_id, convert_to_pcidevops::); + } + + Ok(()) +} + +pub fn devices_register_pcidevops_type() -> Result<()> { + #[cfg(feature = "scream")] + register_pcidevops_type::()?; + #[cfg(feature = "pvpanic")] + register_pcidevops_type::()?; + register_pcidevops_type::()?; + #[cfg(feature = "demo_device")] + register_pcidevops_type::()?; + register_pcidevops_type::() +} + +#[cfg(test)] +pub fn clean_pcidevops_type() { + unsafe { + PCIDEVOPS_HASHMAP = None; + } +} + +pub fn to_pcidevops(dev: &mut dyn Device) -> Option<&mut dyn PciDevOps> { + // SAFETY: PCIDEVOPS_HASHMAP has been built. And this function is called without changing hashmap. + unsafe { + let types = PCIDEVOPS_HASHMAP.as_mut().unwrap(); + let func = types.get(&dev.device_type_id())?; + let pcidev = func(dev.as_any_mut()); + Some(pcidev) + } +} + +/// Convert from Arc> to &mut dyn PciDevOps. +#[macro_export] +macro_rules! PCI_BUS_DEVICE { + ($trait_device:expr, $lock_device: ident, $trait_pcidevops: ident) => { + let mut $lock_device = $trait_device.lock().unwrap(); + let $trait_pcidevops = to_pcidevops(&mut *$lock_device).unwrap(); + }; +} + /// Init multifunction for pci devices. /// /// # Arguments @@ -282,12 +334,12 @@ pub fn init_multifunction( multifunction: bool, config: &mut [u8], devfn: u8, - parent_bus: Weak>, + parent_bus: Weak>, ) -> Result<()> { let mut header_type = - le_read_u16(config, HEADER_TYPE as usize)? & (!HEADER_TYPE_MULTIFUNC as u16); + le_read_u16(config, HEADER_TYPE as usize)? & u16::from(!HEADER_TYPE_MULTIFUNC); if multifunction { - header_type |= HEADER_TYPE_MULTIFUNC as u16; + header_type |= u16::from(HEADER_TYPE_MULTIFUNC); } le_write_u16(config, HEADER_TYPE as usize, header_type)?; @@ -297,24 +349,21 @@ pub fn init_multifunction( // leave the bit to 0. let slot = pci_slot(devfn); let bus = parent_bus.upgrade().unwrap(); - let locked_bus = bus.lock().unwrap(); + PCI_BUS!(bus, locked_bus, pci_bus); if pci_func(devfn) != 0 { - let pci_dev = locked_bus.devices.get(&pci_devfn(slot, 0)); - if pci_dev.is_none() { + let dev = pci_bus.child_dev(u64::from(pci_devfn(slot, 0))); + if dev.is_none() { return Ok(()); } let mut data = vec![0_u8; 2]; - pci_dev - .unwrap() - .lock() - .unwrap() - .read_config(HEADER_TYPE as usize, data.as_mut_slice()); - if LittleEndian::read_u16(&data) & HEADER_TYPE_MULTIFUNC as u16 == 0 { + PCI_BUS_DEVICE!(dev.unwrap(), locked_dev, pci_dev); + pci_dev.read_config(HEADER_TYPE as usize, data.as_mut_slice()); + if LittleEndian::read_u16(&data) & u16::from(HEADER_TYPE_MULTIFUNC) == 0 { // Function 0 should set multifunction bit. bail!( "PCI: single function device can't be populated in bus {} function {}.{}", - &locked_bus.name, + &pci_bus.name(), slot, devfn & 0x07 ); @@ -328,7 +377,10 @@ pub fn init_multifunction( // If function 0 is set to single function, the rest function should be None. for func in 1..MAX_FUNC { - if locked_bus.devices.get(&pci_devfn(slot, func)).is_some() { + if pci_bus + .child_dev(u64::from(pci_devfn(slot, func))) + .is_some() + { bail!( "PCI: {}.0 indicates single function, but {}.{} is already populated", slot, @@ -344,15 +396,88 @@ pub fn init_multifunction( /// PCI-to-PCI bridge specification 9.1: Interrupt routing. pub fn swizzle_map_irq(devfn: u8, pin: u8) -> u32 { let pci_slot = devfn >> 3 & 0x1f; - ((pci_slot + pin) % PCI_PIN_NUM) as u32 + u32::from((pci_slot + pin) % PCI_PIN_NUM) } #[cfg(test)] mod tests { + use super::*; + use crate::pci::config::{PciConfig, PCI_CONFIG_SPACE_SIZE}; use crate::DeviceBase; use address_space::{AddressSpace, Region}; + use util::gen_base_func; - use super::*; + #[derive(Clone)] + pub struct TestPciDevice { + base: PciDevBase, + } + + impl TestPciDevice { + pub fn new(name: &str, devfn: u8, parent_bus: Weak>) -> Self { + Self { + base: PciDevBase { + base: DeviceBase::new(name.to_string(), false, Some(parent_bus)), + config: PciConfig::new(devfn, PCI_CONFIG_SPACE_SIZE, 0), + devfn, + bme: Arc::new(AtomicBool::new(false)), + }, + } + } + } + + impl Device for TestPciDevice { + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); + + fn realize(mut self) -> Result>> { + let devfn = u64::from(self.base.devfn); + self.init_write_mask(false)?; + self.init_write_clear_mask(false)?; + + let dev = Arc::new(Mutex::new(self)); + let parent_bus = dev.lock().unwrap().parent_bus().unwrap().upgrade().unwrap(); + parent_bus + .lock() + .unwrap() + .attach_child(devfn, dev.clone())?; + + Ok(dev) + } + + fn unrealize(&mut self) -> Result<()> { + Ok(()) + } + } + + impl PciDevOps for TestPciDevice { + gen_base_func!(pci_base, pci_base_mut, PciDevBase, base); + + fn write_config(&mut self, offset: usize, data: &[u8]) { + self.base.config.write( + offset, + data, + 0, + #[cfg(target_arch = "x86_64")] + None, + None, + ); + } + + fn init_write_mask(&mut self, _is_bridge: bool) -> Result<()> { + let mut offset = 0_usize; + while offset < self.base.config.config.len() { + LittleEndian::write_u32( + &mut self.base.config.write_mask[offset..offset + 4], + 0xffff_ffff, + ); + offset += 4; + } + Ok(()) + } + + fn init_write_clear_mask(&mut self, _is_bridge: bool) -> Result<()> { + Ok(()) + } + } #[test] fn test_le_write_u16_01() { @@ -395,57 +520,20 @@ mod tests { #[test] fn set_dev_id() { - struct PciDev { - base: PciDevBase, - } - - impl Device for PciDev { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } - - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base - } - } - - impl PciDevOps for PciDev { - fn pci_base(&self) -> &PciDevBase { - &self.base - } - - fn pci_base_mut(&mut self) -> &mut PciDevBase { - &mut self.base - } - - fn write_config(&mut self, _offset: usize, _data: &[u8]) {} - - fn realize(self) -> Result<()> { - Ok(()) - } - } - let sys_mem = AddressSpace::new( Region::init_container_region(u64::max_value(), "sysmem"), "sysmem", None, ) .unwrap(); - let parent_bus: Arc> = Arc::new(Mutex::new(PciBus::new( + let parent_bus = Arc::new(Mutex::new(PciBus::new( String::from("test bus"), #[cfg(target_arch = "x86_64")] Region::init_container_region(1 << 16, "parent_bus"), sys_mem.root().clone(), - ))); - - let dev = PciDev { - base: PciDevBase { - base: DeviceBase::new("PCI device".to_string(), false), - config: PciConfig::new(1, 1), - devfn: 0, - parent_bus: Arc::downgrade(&parent_bus), - }, - }; + ))) as Arc>; + + let dev = TestPciDevice::new("PCI device", 0, Arc::downgrade(&parent_bus)); assert_eq!(dev.set_dev_id(1, 2), 258); } } diff --git a/devices/src/pci/msix.rs b/devices/src/pci/msix.rs index 451a66384d2a3de5baf6adde934db134ca082c75..23562999f813825c7f0274fb4b43ff0aa140db65 100644 --- a/devices/src/pci/msix.rs +++ b/devices/src/pci/msix.rs @@ -21,9 +21,10 @@ use vmm_sys_util::eventfd::EventFd; use crate::pci::config::{CapId, RegionType, MINIMUM_BAR_SIZE_FOR_MMIO}; use crate::pci::{ - le_read_u16, le_read_u32, le_read_u64, le_write_u16, le_write_u32, le_write_u64, PciDevBase, + le_read_u16, le_read_u32, le_read_u64, le_write_u16, le_write_u32, le_write_u64, PciBus, + PciDevBase, }; -use crate::MsiIrqManager; +use crate::{convert_bus_ref, MsiIrqManager, PCI_BUS}; use address_space::{GuestAddress, Region, RegionOps}; use migration::{ DeviceStateDesc, FieldDesc, MigrationError, MigrationHook, MigrationManager, StateTransfer, @@ -187,7 +188,7 @@ impl Msix { fn is_vector_pending(&self, vector: u16) -> bool { let offset: usize = vector as usize / 64; - let pending_bit: u64 = 1 << (vector as u64 % 64); + let pending_bit: u64 = 1 << (u64::from(vector) % 64); let value = le_read_u64(&self.pba, offset).unwrap(); if value & pending_bit > 0 { return true; @@ -197,14 +198,14 @@ impl Msix { fn set_pending_vector(&mut self, vector: u16) { let offset: usize = vector as usize / 64; - let pending_bit: u64 = 1 << (vector as u64 % 64); + let pending_bit: u64 = 1 << (u64::from(vector) % 64); let old_val = le_read_u64(&self.pba, offset).unwrap(); le_write_u64(&mut self.pba, offset, old_val | pending_bit).unwrap(); } fn clear_pending_vector(&mut self, vector: u16) { let offset: usize = vector as usize / 64; - let pending_bit: u64 = !(1 << (vector as u64 % 64)); + let pending_bit: u64 = !(1 << (u64::from(vector) % 64)); let old_val = le_read_u64(&self.pba, offset).unwrap(); le_write_u64(&mut self.pba, offset, old_val & pending_bit).unwrap(); } @@ -230,7 +231,7 @@ impl Msix { msg_data: entry.data, masked: false, #[cfg(target_arch = "aarch64")] - dev_id: self.dev_id.load(Ordering::Acquire) as u32, + dev_id: u32::from(self.dev_id.load(Ordering::Acquire)), }; let irq_manager = self.msi_irq_manager.as_ref().unwrap(); @@ -260,7 +261,7 @@ impl Msix { msg_data: entry.data, masked: false, #[cfg(target_arch = "aarch64")] - dev_id: self.dev_id.load(Ordering::Acquire) as u32, + dev_id: u32::from(self.dev_id.load(Ordering::Acquire)), }; let irq_manager = self.msi_irq_manager.as_ref().unwrap(); @@ -300,7 +301,12 @@ impl Msix { let cloned_msix = msix.clone(); let table_read = move |data: &mut [u8], _addr: GuestAddress, offset: u64| -> bool { - if offset as usize + data.len() > cloned_msix.lock().unwrap().table.len() { + let offset = offset as usize; + if offset + .checked_add(data.len()) + .filter(|&sum| sum <= cloned_msix.lock().unwrap().table.len()) + .is_none() + { error!( "It's forbidden to read out of the msix table(size: {}), with offset of {} and size of {}", cloned_msix.lock().unwrap().table.len(), @@ -309,13 +315,17 @@ impl Msix { ); return false; } - let offset = offset as usize; data.copy_from_slice(&cloned_msix.lock().unwrap().table[offset..(offset + data.len())]); true }; let cloned_msix = msix.clone(); let table_write = move |data: &[u8], _addr: GuestAddress, offset: u64| -> bool { - if offset as usize + data.len() > cloned_msix.lock().unwrap().table.len() { + let offset = offset as usize; + if offset + .checked_add(data.len()) + .filter(|&sum| sum <= cloned_msix.lock().unwrap().table.len()) + .is_none() + { error!( "It's forbidden to write out of the msix table(size: {}), with offset of {} and size of {}", cloned_msix.lock().unwrap().table.len(), @@ -327,13 +337,14 @@ impl Msix { let mut locked_msix = cloned_msix.lock().unwrap(); let vector: u16 = offset as u16 / MSIX_TABLE_ENTRY_SIZE; let was_masked: bool = locked_msix.is_vector_masked(vector); - let offset = offset as usize; locked_msix.table[offset..(offset + 4)].copy_from_slice(data); let is_masked: bool = locked_msix.is_vector_masked(vector); - if was_masked != is_masked && locked_msix.update_irq_routing(vector, is_masked).is_err() - { - return false; + if was_masked != is_masked { + if let Err(e) = locked_msix.update_irq_routing(vector, is_masked) { + error!("Failed to update irq routing: {:?}", e); + return false; + } } // Clear the pending vector just when it is pending. Otherwise, it @@ -356,7 +367,12 @@ impl Msix { let cloned_msix = msix.clone(); let pba_read = move |data: &mut [u8], _addr: GuestAddress, offset: u64| -> bool { - if offset as usize + data.len() > cloned_msix.lock().unwrap().pba.len() { + let offset = offset as usize; + if offset + .checked_add(data.len()) + .filter(|&sum| sum <= cloned_msix.lock().unwrap().pba.len()) + .is_none() + { error!( "Fail to read msi pba, illegal data length {}, offset {}", data.len(), @@ -364,7 +380,6 @@ impl Msix { ); return false; } - let offset = offset as usize; data.copy_from_slice(&cloned_msix.lock().unwrap().pba[offset..(offset + data.len())]); true }; @@ -406,11 +421,11 @@ impl Msix { msg_data: msg.data, masked: false, #[cfg(target_arch = "aarch64")] - dev_id: dev_id as u32, + dev_id: u32::from(dev_id), }; let irq_manager = self.msi_irq_manager.as_ref().unwrap(); - if let Err(e) = irq_manager.trigger(None, msix_vector, dev_id as u32) { + if let Err(e) = irq_manager.trigger(None, msix_vector, u32::from(dev_id)) { error!("Send msix error: {:?}", e); }; } @@ -516,7 +531,7 @@ impl MigrationHook for Msix { msg_data: msg.data, masked: false, #[cfg(target_arch = "aarch64")] - dev_id: self.dev_id.load(Ordering::Acquire) as u32, + dev_id: u32::from(self.dev_id.load(Ordering::Acquire)), }; let irq_manager = self.msi_irq_manager.as_ref().unwrap(); irq_manager.allocate_irq(msi_vector)?; @@ -552,8 +567,8 @@ pub fn init_msix( offset_opt: Option<(u32, u32)>, ) -> Result<()> { let config = &mut pcidev_base.config; - let parent_bus = &pcidev_base.parent_bus; - if vector_nr == 0 || vector_nr > MSIX_TABLE_SIZE_MAX as u32 + 1 { + let parent_bus = pcidev_base.base.parent.as_ref().unwrap(); + if vector_nr == 0 || vector_nr > u32::from(MSIX_TABLE_SIZE_MAX) + 1 { bail!( "invalid msix vectors, which should be in [1, {}]", MSIX_TABLE_SIZE_MAX + 1 @@ -569,8 +584,8 @@ pub fn init_msix( MSIX_CAP_FUNC_MASK | MSIX_CAP_ENABLE, )?; offset = msix_cap_offset + MSIX_CAP_TABLE as usize; - let table_size = vector_nr * MSIX_TABLE_ENTRY_SIZE as u32; - let pba_size = ((round_up(vector_nr as u64, 64).unwrap() / 64) * 8) as u32; + let table_size = vector_nr * u32::from(MSIX_TABLE_ENTRY_SIZE); + let pba_size = ((round_up(u64::from(vector_nr), 64).unwrap() / 64) * 8) as u32; let (table_offset, pba_offset) = offset_opt.unwrap_or((0, table_size)); if ranges_overlap( table_offset as usize, @@ -586,9 +601,9 @@ pub fn init_msix( offset = msix_cap_offset + MSIX_CAP_PBA as usize; le_write_u32(&mut config.config, offset, pba_offset | bar_id as u32)?; - let msi_irq_manager = if let Some(pci_bus) = parent_bus.upgrade() { - let locked_pci_bus = pci_bus.lock().unwrap(); - locked_pci_bus.get_msi_irq_manager() + let msi_irq_manager = if let Some(bus) = parent_bus.upgrade() { + PCI_BUS!(bus, locked_bus, pci_bus); + pci_bus.get_msi_irq_manager() } else { error!("Msi irq controller is none"); None @@ -606,19 +621,19 @@ pub fn init_msix( msix.clone(), region, dev_id, - table_offset as u64, - pba_offset as u64, + u64::from(table_offset), + u64::from(pba_offset), )?; } else { - let mut bar_size = ((table_size + pba_size) as u64).next_power_of_two(); + let mut bar_size = u64::from(table_size + pba_size).next_power_of_two(); bar_size = max(bar_size, MINIMUM_BAR_SIZE_FOR_MMIO as u64); let region = Region::init_container_region(bar_size, "Msix_region"); Msix::register_memory_region( msix.clone(), ®ion, dev_id, - table_offset as u64, - pba_offset as u64, + u64::from(table_offset), + u64::from(pba_offset), )?; config.register_bar(bar_id, region, RegionType::Mem32Bit, false, bar_size)?; } @@ -633,27 +648,29 @@ pub fn init_msix( #[cfg(test)] mod tests { - use std::sync::Weak; + use std::sync::atomic::AtomicBool; use super::*; - use crate::{ - pci::config::{PciConfig, PCI_CONFIG_SPACE_SIZE}, - DeviceBase, - }; + use crate::pci::config::{PciConfig, PCI_CONFIG_SPACE_SIZE}; + use crate::pci::host::tests::create_pci_host; + use crate::{Device, DeviceBase}; #[test] fn test_init_msix() { + let pci_host = create_pci_host(); + let locked_pci_host = pci_host.lock().unwrap(); + let root_bus = Arc::downgrade(&locked_pci_host.child_bus().unwrap()); let mut base = PciDevBase { - base: DeviceBase::new("msix".to_string(), false), - config: PciConfig::new(PCI_CONFIG_SPACE_SIZE, 2), + base: DeviceBase::new("msix".to_string(), false, Some(root_bus)), + config: PciConfig::new(1, PCI_CONFIG_SPACE_SIZE, 2), devfn: 1, - parent_bus: Weak::new(), + bme: Arc::new(AtomicBool::new(false)), }; // Too many vectors. assert!(init_msix( &mut base, 0, - MSIX_TABLE_SIZE_MAX as u32 + 2, + u32::from(MSIX_TABLE_SIZE_MAX) + 2, Arc::new(AtomicU16::new(0)), None, None, @@ -666,7 +683,7 @@ mod tests { init_msix(&mut base, 1, 2, Arc::new(AtomicU16::new(0)), None, None).unwrap(); let pci_config = base.config; let msix_cap_start = 64_u8; - assert_eq!(pci_config.last_cap_end, 64 + MSIX_CAP_SIZE as u16); + assert_eq!(pci_config.last_cap_end, 64 + u16::from(MSIX_CAP_SIZE)); // Capabilities pointer assert_eq!(pci_config.config[0x34], msix_cap_start); assert_eq!( @@ -690,7 +707,7 @@ mod tests { fn test_mask_vectors() { let nr_vector = 2_u32; let mut msix = Msix::new( - nr_vector * MSIX_TABLE_ENTRY_SIZE as u32, + nr_vector * u32::from(MSIX_TABLE_ENTRY_SIZE), 64, 64, Arc::new(AtomicU16::new(0)), @@ -712,7 +729,7 @@ mod tests { #[test] fn test_pending_vectors() { let mut msix = Msix::new( - MSIX_TABLE_ENTRY_SIZE as u32, + u32::from(MSIX_TABLE_ENTRY_SIZE), 64, 64, Arc::new(AtomicU16::new(0)), @@ -728,7 +745,7 @@ mod tests { #[test] fn test_get_message() { let mut msix = Msix::new( - MSIX_TABLE_ENTRY_SIZE as u32, + u32::from(MSIX_TABLE_ENTRY_SIZE), 64, 64, Arc::new(AtomicU16::new(0)), @@ -746,11 +763,14 @@ mod tests { #[test] fn test_write_config() { + let pci_host = create_pci_host(); + let locked_pci_host = pci_host.lock().unwrap(); + let root_bus = Arc::downgrade(&locked_pci_host.child_bus().unwrap()); let mut base = PciDevBase { - base: DeviceBase::new("msix".to_string(), false), - config: PciConfig::new(PCI_CONFIG_SPACE_SIZE, 2), + base: DeviceBase::new("msix".to_string(), false, Some(root_bus)), + config: PciConfig::new(1, PCI_CONFIG_SPACE_SIZE, 2), devfn: 1, - parent_bus: Weak::new(), + bme: Arc::new(AtomicBool::new(false)), }; init_msix(&mut base, 0, 2, Arc::new(AtomicU16::new(0)), None, None).unwrap(); let msix = base.config.msix.as_ref().unwrap(); diff --git a/devices/src/pci/root_port.rs b/devices/src/pci/root_port.rs index 40130277353f04d82c0d323b09e84310dc97530f..fdf4ef47f80b7b49a03f0cc5afc30361a6d54bfe 100644 --- a/devices/src/pci/root_port.rs +++ b/devices/src/pci/root_port.rs @@ -10,10 +10,11 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use std::sync::atomic::{AtomicU16, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU16, Ordering}; use std::sync::{Arc, Mutex, Weak}; use anyhow::{anyhow, bail, Context, Result}; +use clap::{ArgAction, Parser}; use log::{error, info}; use once_cell::sync::OnceCell; @@ -33,23 +34,49 @@ use crate::pci::config::{BRIDGE_CONTROL, BRIDGE_CTL_SEC_BUS_RESET}; use crate::pci::hotplug::HotplugOps; use crate::pci::intx::init_intx; use crate::pci::msix::init_msix; -use crate::pci::{init_multifunction, PciDevBase, PciError, PciIntxState, INTERRUPT_PIN}; use crate::pci::{ - le_read_u16, le_write_clear_value_u16, le_write_set_value_u16, le_write_u16, PciDevOps, + init_multifunction, le_read_u16, le_write_clear_value_u16, le_write_set_value_u16, + le_write_u16, to_pcidevops, PciDevBase, PciDevOps, PciError, PciIntxState, INTERRUPT_PIN, +}; +use crate::{ + convert_bus_mut, convert_bus_ref, Bus, Device, DeviceBase, MsiIrqManager, MUT_PCI_BUS, PCI_BUS, + PCI_BUS_DEVICE, }; -use crate::{Device, DeviceBase, MsiIrqManager}; use address_space::Region; +use machine_manager::config::{get_pci_df, parse_bool, valid_id}; use machine_manager::qmp::qmp_channel::send_device_deleted_msg; use migration::{ DeviceStateDesc, FieldDesc, MigrationError, MigrationHook, MigrationManager, StateTransfer, }; use migration_derive::{ByteCode, Desc}; -use util::{byte_code::ByteCode, num_ops::ranges_overlap}; +use util::byte_code::ByteCode; +use util::gen_base_func; +use util::num_ops::{ranges_overlap, str_to_num}; const DEVICE_ID_RP: u16 = 0x000c; static FAST_UNPLUG_FEATURE: OnceCell = OnceCell::new(); +/// Basic information of RootPort like port number. +#[derive(Parser, Debug, Clone, Default)] +#[command(no_binary_name(true))] +pub struct RootPortConfig { + #[arg(long, value_parser = ["pcie-root-port"])] + pub classtype: String, + #[arg(long, value_parser = str_to_num::)] + pub port: u8, + #[arg(long, value_parser = valid_id)] + pub id: String, + #[arg(long)] + pub bus: String, + #[arg(long, value_parser = get_pci_df)] + pub addr: (u8, u8), + #[arg(long, default_value = "off", value_parser = parse_bool, action = ArgAction::Append)] + pub multifunction: bool, + #[arg(long, default_value = "0")] + pub chassis: u8, +} + /// Device state root port. #[repr(C)] #[derive(Copy, Clone, Desc, ByteCode)] @@ -67,7 +94,6 @@ struct RootPortState { pub struct RootPort { base: PciDevBase, port_num: u8, - sec_bus: Arc>, #[cfg(target_arch = "x86_64")] io_region: Region, mem_region: Region, @@ -81,41 +107,35 @@ impl RootPort { /// /// # Arguments /// - /// * `name` - Root port name. - /// * `devfn` - Device number << 3 | Function number. - /// * `port_num` - Root port number. + /// * `cfg` - Root port config. /// * `parent_bus` - Weak reference to the parent bus. - pub fn new( - name: String, - devfn: u8, - port_num: u8, - parent_bus: Weak>, - multifunction: bool, - ) -> Self { + pub fn new(cfg: RootPortConfig, parent_bus: Weak>) -> Self { + let devfn = cfg.addr.0 << 3 | cfg.addr.1; #[cfg(target_arch = "x86_64")] let io_region = Region::init_container_region(1 << 16, "RootPortIo"); let mem_region = Region::init_container_region(u64::max_value(), "RootPortMem"); - let sec_bus = Arc::new(Mutex::new(PciBus::new( - name.clone(), + let child_bus = Arc::new(Mutex::new(PciBus::new( + cfg.id.clone(), #[cfg(target_arch = "x86_64")] io_region.clone(), mem_region.clone(), ))); + let mut dev_base = DeviceBase::new(cfg.id, true, Some(parent_bus)); + dev_base.child = Some(child_bus); Self { base: PciDevBase { - base: DeviceBase::new(name, true), - config: PciConfig::new(PCIE_CONFIG_SPACE_SIZE, 2), + base: dev_base, + config: PciConfig::new(devfn, PCIE_CONFIG_SPACE_SIZE, 2), devfn, - parent_bus, + bme: Arc::new(AtomicBool::new(false)), }, - port_num, - sec_bus, + port_num: cfg.port, #[cfg(target_arch = "x86_64")] io_region, mem_region, dev_id: Arc::new(AtomicU16::new(0)), - multifunction, + multifunction: cfg.multifunction, hpev_notified: false, } } @@ -161,7 +181,7 @@ impl RootPort { if locked_msix.enabled { locked_msix.notify(0, self.dev_id.load(Ordering::Acquire)); } else if self.base.config.config[INTERRUPT_PIN as usize] != 0 { - intx.lock().unwrap().notify(self.hpev_notified as u8); + intx.lock().unwrap().notify(u8::from(self.hpev_notified)); } } @@ -203,10 +223,11 @@ impl RootPort { // Store device in a temp vector and unlock the bus. // If the device unrealize called when the bus is locked, a deadlock occurs. // This is because the device unrealize also requires the bus lock. - let devices = self.sec_bus.lock().unwrap().devices.clone(); + let bus = self.child_bus().unwrap(); + let devices = bus.lock().unwrap().child_devices(); for dev in devices.values() { - let mut locked_dev = dev.lock().unwrap(); - if let Err(e) = locked_dev.unrealize() { + PCI_BUS_DEVICE!(dev, locked_dev, pci_dev); + if let Err(e) = pci_dev.unrealize() { error!("{}", format!("{:?}", e)); error!("Failed to unrealize device {}.", locked_dev.name()); } @@ -215,20 +236,17 @@ impl RootPort { // Send QMP event for successful hot unplugging. send_device_deleted_msg(&locked_dev.name()); } - self.sec_bus.lock().unwrap().devices.clear(); + bus.lock().unwrap().bus_base_mut().children.clear(); } fn register_region(&mut self) { + let bus = self.parent_bus().unwrap().upgrade().unwrap(); + PCI_BUS!(bus, locked_bus, pci_bus); + let command: u16 = le_read_u16(&self.base.config.config, COMMAND as usize).unwrap(); if command & COMMAND_IO_SPACE != 0 { #[cfg(target_arch = "x86_64")] - if let Err(e) = self - .base - .parent_bus - .upgrade() - .unwrap() - .lock() - .unwrap() + if let Err(e) = pci_bus .io_region .add_subregion(self.io_region.clone(), 0) .with_context(|| "Failed to add IO container region.") @@ -237,13 +255,7 @@ impl RootPort { } } if command & COMMAND_MEMORY_SPACE != 0 { - if let Err(e) = self - .base - .parent_bus - .upgrade() - .unwrap() - .lock() - .unwrap() + if let Err(e) = pci_bus .mem_region .add_subregion(self.mem_region.clone(), 0) .with_context(|| "Failed to add memory container region.") @@ -266,7 +278,7 @@ impl RootPort { (cap_offset + PCI_EXP_SLTSTA) as usize, ) .unwrap(); - let val: u16 = data[0] as u16 + ((data[1] as u16) << 8); + let val: u16 = u16::from(data[0]) + (u16::from(data[1]) << 8); if (val & !old_status & PCI_EXP_SLOTSTA_EVENTS) != 0 { let tmpstat = (status & !PCI_EXP_SLOTSTA_EVENTS) | (old_status & PCI_EXP_SLOTSTA_EVENTS); @@ -327,25 +339,41 @@ impl RootPort { } impl Device for RootPort { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base - } -} - -impl PciDevOps for RootPort { - fn pci_base(&self) -> &PciDevBase { - &self.base - } + /// Only set slot status to on, and no other device reset actions are implemented. + fn reset(&mut self, reset_child_device: bool) -> Result<()> { + if reset_child_device { + let child_bus = self.child_bus().unwrap(); + MUT_PCI_BUS!(child_bus, locked_child_bus, child_pci_bus); + child_pci_bus + .reset() + .with_context(|| "Fail to reset child_bus in root port")?; + } else { + let cap_offset = self.base.config.pci_express_cap_offset; + le_write_u16( + &mut self.base.config.config, + (cap_offset + PCI_EXP_SLTSTA) as usize, + PCI_EXP_SLTSTA_PDS, + )?; + le_write_u16( + &mut self.base.config.config, + (cap_offset + PCI_EXP_SLTCTL) as usize, + !PCI_EXP_SLTCTL_PCC | PCI_EXP_SLTCTL_PWR_IND_ON, + )?; + le_write_u16( + &mut self.base.config.config, + (cap_offset + PCI_EXP_LNKSTA) as usize, + PCI_EXP_LNKSTA_DLLLA, + )?; + } - fn pci_base_mut(&mut self) -> &mut PciDevBase { - &mut self.base + self.base.config.reset_bridge_regs()?; + self.base.config.reset() } - fn realize(mut self) -> Result<()> { + fn realize(mut self) -> Result>> { + let parent_bus = self.parent_bus().unwrap(); self.init_write_mask(true)?; self.init_write_clear_mask(true)?; @@ -361,7 +389,7 @@ impl PciDevOps for RootPort { self.multifunction, config_space, self.base.devfn, - self.base.parent_bus.clone(), + parent_bus.clone(), )?; #[cfg(target_arch = "aarch64")] @@ -373,57 +401,71 @@ impl PciDevOps for RootPort { PcieDevType::RootPort as u8, )?; - self.dev_id.store(self.base.devfn as u16, Ordering::SeqCst); + self.dev_id + .store(u16::from(self.base.devfn), Ordering::SeqCst); init_msix(&mut self.base, 0, 1, self.dev_id.clone(), None, None)?; init_intx( self.name(), &mut self.base.config, - self.base.parent_bus.clone(), + parent_bus.clone(), self.base.devfn, )?; - let parent_bus = self.base.parent_bus.upgrade().unwrap(); - let mut locked_parent_bus = parent_bus.lock().unwrap(); + let arc_parent_bus = parent_bus.upgrade().unwrap(); + MUT_PCI_BUS!(arc_parent_bus, locked_parent_bus, parent_pci_bus); + let child_bus = self.child_bus().unwrap(); + MUT_PCI_BUS!(child_bus, locked_child_bus, child_pci_bus); #[cfg(target_arch = "x86_64")] - locked_parent_bus + parent_pci_bus .io_region - .add_subregion(self.sec_bus.lock().unwrap().io_region.clone(), 0) + .add_subregion(child_pci_bus.io_region.clone(), 0) .with_context(|| "Failed to register subregion in I/O space.")?; - locked_parent_bus + parent_pci_bus .mem_region - .add_subregion(self.sec_bus.lock().unwrap().mem_region.clone(), 0) + .add_subregion(child_pci_bus.mem_region.clone(), 0) .with_context(|| "Failed to register subregion in memory space.")?; + drop(locked_child_bus); let name = self.name(); let root_port = Arc::new(Mutex::new(self)); - #[allow(unused_mut)] - let mut locked_root_port = root_port.lock().unwrap(); - locked_root_port.sec_bus.lock().unwrap().parent_bridge = - Some(Arc::downgrade(&root_port) as Weak>); - locked_root_port.sec_bus.lock().unwrap().hotplug_controller = + let locked_root_port = root_port.lock().unwrap(); + let child_bus = locked_root_port.child_bus().unwrap(); + MUT_PCI_BUS!(child_bus, locked_child_bus, child_pci_bus); + child_pci_bus.base.parent = Some(Arc::downgrade(&root_port) as Weak>); + child_pci_bus.hotplug_controller = Some(Arc::downgrade(&root_port) as Weak>); - let pci_device = locked_parent_bus.devices.get(&locked_root_port.base.devfn); - if pci_device.is_none() { - locked_parent_bus - .child_buses - .push(locked_root_port.sec_bus.clone()); - locked_parent_bus - .devices - .insert(locked_root_port.base.devfn, root_port.clone()); - } else { - bail!( - "Devfn {:?} has been used by {:?}", - locked_root_port.base.devfn, - pci_device.unwrap().lock().unwrap().name() - ); - } + parent_pci_bus.attach_child(u64::from(locked_root_port.base.devfn), root_port.clone())?; // Need to drop locked_root_port in order to register root_port instance. drop(locked_root_port); - MigrationManager::register_device_instance(RootPortState::descriptor(), root_port, &name); + MigrationManager::register_device_instance( + RootPortState::descriptor(), + root_port.clone(), + &name, + ); - Ok(()) + Ok(root_port) } +} + +/// Convert from Arc> to &mut RootPort. +#[macro_export] +macro_rules! MUT_ROOT_PORT { + ($trait_device:expr, $lock_device: ident, $struct_device: ident) => { + convert_device_mut!($trait_device, $lock_device, $struct_device, RootPort); + }; +} + +/// Convert from Arc> to &RootPort. +#[macro_export] +macro_rules! ROOT_PORT { + ($trait_device:expr, $lock_device: ident, $struct_device: ident) => { + convert_device_ref!($trait_device, $lock_device, $struct_device, RootPort); + }; +} + +impl PciDevOps for RootPort { + gen_base_func!(pci_base, pci_base_mut, PciDevBase, base); fn write_config(&mut self, offset: usize, data: &[u8]) { let size = data.len(); @@ -501,39 +543,8 @@ impl PciDevOps for RootPort { self.do_unplug(offset, data, old_ctl, old_status); } - /// Only set slot status to on, and no other device reset actions are implemented. - fn reset(&mut self, reset_child_device: bool) -> Result<()> { - if reset_child_device { - self.sec_bus - .lock() - .unwrap() - .reset() - .with_context(|| "Fail to reset sec_bus in root port")?; - } else { - let cap_offset = self.base.config.pci_express_cap_offset; - le_write_u16( - &mut self.base.config.config, - (cap_offset + PCI_EXP_SLTSTA) as usize, - PCI_EXP_SLTSTA_PDS, - )?; - le_write_u16( - &mut self.base.config.config, - (cap_offset + PCI_EXP_SLTCTL) as usize, - !PCI_EXP_SLTCTL_PCC | PCI_EXP_SLTCTL_PWR_IND_ON, - )?; - le_write_u16( - &mut self.base.config.config, - (cap_offset + PCI_EXP_LNKSTA) as usize, - PCI_EXP_LNKSTA_DLLLA, - )?; - } - - self.base.config.reset_bridge_regs()?; - self.base.config.reset() - } - fn get_dev_path(&self) -> Option { - let parent_bus = self.base.parent_bus.upgrade().unwrap(); + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); let parent_dev_path = self.get_parent_dev_path(parent_bus); let dev_path = self.populate_dev_path(parent_dev_path, self.base.devfn, "/pci-bridge@"); Some(dev_path) @@ -556,11 +567,13 @@ impl PciDevOps for RootPort { } impl HotplugOps for RootPort { - fn plug(&mut self, dev: &Arc>) -> Result<()> { + fn plug(&mut self, dev: &Arc>) -> Result<()> { if !dev.lock().unwrap().hotpluggable() { bail!("Don't support hot-plug!"); } - let devfn = dev.lock().unwrap().pci_base().devfn; + PCI_BUS_DEVICE!(dev, locked_dev, pci_dev); + let devfn = pci_dev.pci_base().devfn; + drop(locked_dev); // Only if devfn is equal to 0, hot plugging is supported. if devfn != 0 { return Err(anyhow!(PciError::HotplugUnsupported(devfn))); @@ -582,7 +595,7 @@ impl HotplugOps for RootPort { Ok(()) } - fn unplug_request(&mut self, dev: &Arc>) -> Result<()> { + fn unplug_request(&mut self, dev: &Arc>) -> Result<()> { let pcie_cap_offset = self.base.config.pci_express_cap_offset; let sltctl = le_read_u16( &self.base.config.config, @@ -597,7 +610,9 @@ impl HotplugOps for RootPort { if !dev.lock().unwrap().hotpluggable() { bail!("Don't support hot-unplug request!"); } - let devfn = dev.lock().unwrap().pci_base().devfn; + PCI_BUS_DEVICE!(dev, locked_dev, pci_dev); + let devfn = pci_dev.pci_base().devfn; + drop(locked_dev); if devfn != 0 { return self.unplug(dev); } @@ -635,14 +650,15 @@ impl HotplugOps for RootPort { Ok(()) } - fn unplug(&mut self, dev: &Arc>) -> Result<()> { + fn unplug(&mut self, dev: &Arc>) -> Result<()> { if !dev.lock().unwrap().hotpluggable() { bail!("Don't support hot-unplug!"); } - let devfn = dev.lock().unwrap().pci_base().devfn; - let mut locked_dev = dev.lock().unwrap(); - locked_dev.unrealize()?; - self.sec_bus.lock().unwrap().devices.remove(&devfn); + PCI_BUS_DEVICE!(dev, locked_dev, pci_dev); + let devfn = u64::from(pci_dev.pci_base().devfn); + pci_dev.unrealize()?; + let child_bus = self.child_bus().unwrap(); + child_bus.lock().unwrap().detach_child(devfn)?; Ok(()) } } @@ -688,43 +704,46 @@ impl MigrationHook for RootPort {} #[cfg(test)] mod tests { use super::*; - use crate::pci::host::tests::create_pci_host; + use crate::{convert_device_mut, pci::host::tests::create_pci_host, MUT_ROOT_PORT}; #[test] fn test_read_config() { let pci_host = create_pci_host(); - let root_bus = Arc::downgrade(&pci_host.lock().unwrap().root_bus); - let root_port = RootPort::new("pcie.1".to_string(), 8, 0, root_bus, false); + let root_bus = Arc::downgrade(&pci_host.lock().unwrap().child_bus().unwrap()); + let root_port_config = RootPortConfig { + addr: (1, 0), + id: "pcie.1".to_string(), + ..Default::default() + }; + let root_port = RootPort::new(root_port_config, root_bus.clone()); root_port.realize().unwrap(); - let root_port = pci_host.lock().unwrap().find_device(0, 8).unwrap(); + let dev = pci_host.lock().unwrap().find_device(0, 8).unwrap(); let mut buf = [1_u8; 4]; - root_port - .lock() - .unwrap() - .read_config(PCIE_CONFIG_SPACE_SIZE - 1, &mut buf); + MUT_ROOT_PORT!(dev, locked_dev, root_port); + root_port.read_config(PCIE_CONFIG_SPACE_SIZE - 1, &mut buf); assert_eq!(buf, [1_u8; 4]); } #[test] fn test_write_config() { let pci_host = create_pci_host(); - let root_bus = Arc::downgrade(&pci_host.lock().unwrap().root_bus); - let root_port = RootPort::new("pcie.1".to_string(), 8, 0, root_bus, false); + let root_bus = Arc::downgrade(&pci_host.lock().unwrap().child_bus().unwrap()); + let root_port_config = RootPortConfig { + addr: (1, 0), + id: "pcie.1".to_string(), + ..Default::default() + }; + let root_port = RootPort::new(root_port_config, root_bus.clone()); root_port.realize().unwrap(); - let root_port = pci_host.lock().unwrap().find_device(0, 8).unwrap(); + let dev = pci_host.lock().unwrap().find_device(0, 8).unwrap(); + MUT_ROOT_PORT!(dev, locked_dev, root_port); // Invalid write. let data = [1_u8; 4]; - root_port - .lock() - .unwrap() - .write_config(PCIE_CONFIG_SPACE_SIZE - 1, &data); + root_port.write_config(PCIE_CONFIG_SPACE_SIZE - 1, &data); let mut buf = [0_u8]; - root_port - .lock() - .unwrap() - .read_config(PCIE_CONFIG_SPACE_SIZE - 1, &mut buf); + root_port.read_config(PCIE_CONFIG_SPACE_SIZE - 1, &mut buf); assert_eq!(buf, [0_u8]); } } diff --git a/devices/src/scsi/bus.rs b/devices/src/scsi/bus.rs index 9546743508485ef6cbe56c1cadbdc79a7c2c4f9d..ceaa1e5bb50bbba608a82065c8ad2bb7f03fe088 100644 --- a/devices/src/scsi/bus.rs +++ b/devices/src/scsi/bus.rs @@ -11,7 +11,6 @@ // See the Mulan PSL v2 for more details. use std::cmp; -use std::collections::HashMap; use std::io::Write; use std::sync::{Arc, Mutex}; @@ -24,8 +23,9 @@ use crate::ScsiDisk::{ SCSI_DISK_DEFAULT_BLOCK_SIZE_SHIFT, SCSI_DISK_F_DPOFUA, SCSI_DISK_F_REMOVABLE, SCSI_TYPE_DISK, SCSI_TYPE_ROM, SECTOR_SHIFT, }; +use crate::{convert_bus_ref, convert_device_ref, Bus, BusBase, Device, SCSI_DEVICE}; use util::aio::{AioCb, AioReqResult, Iovec}; -use util::AsAny; +use util::{gen_base_func, AsAny}; /// Scsi Operation code. pub const TEST_UNIT_READY: u8 = 0x00; @@ -260,13 +260,13 @@ pub const MODE_PAGE_TO_PROTECT: u8 = 0x1d; pub const MODE_PAGE_CAPABILITIES: u8 = 0x2a; pub const MODE_PAGE_ALLS: u8 = 0x3f; -pub const SCSI_MAX_INQUIRY_LEN: u32 = 256; +pub const SCSI_MAX_INQUIRY_LEN: u64 = 256; pub const SCSI_INQUIRY_PRODUCT_MAX_LEN: usize = 16; pub const SCSI_INQUIRY_VENDOR_MAX_LEN: usize = 8; pub const SCSI_INQUIRY_VERSION_MAX_LEN: usize = 4; pub const SCSI_INQUIRY_VPD_SERIAL_NUMBER_MAX_LEN: usize = 32; -const SCSI_TARGET_INQUIRY_LEN: u32 = 36; +const SCSI_TARGET_INQUIRY_LEN: u64 = 36; /// | bit7 - bit 5 | bit 4 - bit 0 | /// | Peripheral Qualifier | Peripheral Device Type | @@ -356,6 +356,11 @@ const GC_FC_CORE: u16 = 0x0001; /// The medium may be removed from the device. const GC_FC_REMOVABLE_MEDIUM: u16 = 0x0003; +// BusBase.Children uses `u64` for device's unique address. We use bits [32-39] in `u64` +// to represent the target number and bits[0-15] in `u64` to the lun number. +const TARGET_ID_SHIFT: u64 = 32; +const LUN_ID_MASK: u64 = 0xFFFF; + #[derive(Clone, PartialEq, Eq)] pub enum ScsiXferMode { /// TEST_UNIT_READY, ... @@ -366,18 +371,36 @@ pub enum ScsiXferMode { ScsiXferToDev, } +// Convert from (target, lun) to unique address in BusBase. +pub fn get_scsi_key(target: u8, lun: u16) -> u64 { + u64::from(target) << TARGET_ID_SHIFT | u64::from(lun) +} + +// Convert from unique address in BusBase to (target, lun). +fn parse_scsi_key(key: u64) -> (u8, u16) { + ((key >> TARGET_ID_SHIFT) as u8, (key & LUN_ID_MASK) as u16) +} + pub struct ScsiBus { - /// Bus name. - pub name: String, - /// Scsi Devices attached to the bus. - pub devices: HashMap<(u8, u16), Arc>>, + pub base: BusBase, +} + +impl Bus for ScsiBus { + gen_base_func!(bus_base, bus_base_mut, BusBase, base); +} + +/// Convert from Arc> to &ScsiBus. +#[macro_export] +macro_rules! SCSI_BUS { + ($trait_bus:expr, $lock_bus: ident, $struct_bus: ident) => { + convert_bus_ref!($trait_bus, $lock_bus, $struct_bus, ScsiBus); + }; } impl ScsiBus { pub fn new(bus_name: String) -> ScsiBus { ScsiBus { - name: bus_name, - devices: HashMap::new(), + base: BusBase::new(bus_name), } } @@ -385,9 +408,9 @@ impl ScsiBus { /// If the device requested by the target number and the lun number is non-existen, /// return the first device in ScsiBus's devices list. It's OK because we will not /// use this "random" device, we will just use it to prove that the target is existen. - pub fn get_device(&self, target: u8, lun: u16) -> Option>> { - if let Some(dev) = self.devices.get(&(target, lun)) { - return Some((*dev).clone()); + pub fn get_device(&self, target: u8, lun: u16) -> Option>> { + if let Some(device) = self.child_dev(get_scsi_key(target, lun)) { + return Some(device.clone()); } // If lun device requested in CDB's LUNS bytes is not found, it may be a target request. @@ -396,11 +419,11 @@ impl ScsiBus { // is non-existent. So, we should find if there exists a lun which has the same id with // target id in CBD's LUNS bytes. And, if there exist two or more luns which have the same // target id, just return the first one is OK enough. - for (id, device) in self.devices.iter() { - let (target_id, lun_id) = id; - if *target_id == target { - trace::scsi_bus_get_device(*target_id, lun, *lun_id); - return Some((*device).clone()); + for (key, device) in self.child_devices() { + let (target_id, lun_id) = parse_scsi_key(key); + if target_id == target { + trace::scsi_bus_get_device(target_id, lun, lun_id); + return Some(device.clone()); } } @@ -413,7 +436,7 @@ impl ScsiBus { fn scsi_bus_parse_req_cdb( cdb: [u8; SCSI_CMD_BUF_SIZE], - dev: Arc>, + dev: Arc>, ) -> Option { let op = cdb[0]; let len = scsi_cdb_length(&cdb); @@ -423,15 +446,15 @@ fn scsi_bus_parse_req_cdb( // When CDB's Group Code is vendor specific or reserved, len/xfer/lba will be negative. // So, don't need to check again after checking in cdb length. - let xfer = scsi_cdb_xfer(&cdb, dev); - let lba = scsi_cdb_lba(&cdb); + let xfer = scsi_cdb_xfer(&cdb, dev) as u64; + let lba = scsi_cdb_lba(&cdb) as u64; Some(ScsiCommand { buf: cdb, op, len: len as u32, - xfer: xfer as u32, - lba: lba as u64, + xfer, + lba, mode: scsi_cdb_xfer_mode(&cdb), }) } @@ -445,7 +468,7 @@ pub struct ScsiCommand { /// Length of CDB. pub len: u32, /// Transfer length. - pub xfer: u32, + pub xfer: u64, /// Logical Block Address. pub lba: u64, /// Transfer direction. @@ -492,7 +515,7 @@ pub struct ScsiRequest { pub iovec: Vec, // Provided buffer's length. pub datalen: u32, - pub dev: Arc>, + pub dev: Arc>, // Upper level request which contains this ScsiRequest. pub upper_req: Box, } @@ -503,18 +526,17 @@ impl ScsiRequest { req_lun: u16, iovec: Vec, datalen: u32, - scsidevice: Arc>, + device: Arc>, upper_req: Box, ) -> Result { - let cmd = scsi_bus_parse_req_cdb(cdb, scsidevice.clone()).with_context(|| "Error cdb!")?; + let cmd = scsi_bus_parse_req_cdb(cdb, device.clone()).with_context(|| "Error cdb!")?; let op = cmd.op; let opstype = scsi_operation_type(op); if op == WRITE_10 || op == READ_10 { - let dev_lock = scsidevice.lock().unwrap(); - let disk_size = dev_lock.disk_sectors << SECTOR_SHIFT; - let disk_type = dev_lock.scsi_type; - drop(dev_lock); + SCSI_DEVICE!(device, locked_dev, scsi_dev); + let disk_size = scsi_dev.disk_sectors << SECTOR_SHIFT; + let disk_type = scsi_dev.scsi_type; let offset_shift = match disk_type { SCSI_TYPE_DISK => SCSI_DISK_DEFAULT_BLOCK_SIZE_SHIFT, _ => SCSI_CDROM_DEFAULT_BLOCK_SIZE_SHIFT, @@ -525,7 +547,7 @@ impl ScsiRequest { .with_context(|| "Too large offset IO!")?; offset - .checked_add(datalen as u64) + .checked_add(u64::from(datalen)) .filter(|&off| off <= disk_size) .with_context(|| { format!( @@ -541,7 +563,7 @@ impl ScsiRequest { opstype, iovec, datalen, - dev: scsidevice, + dev: device, upper_req, }) } @@ -550,14 +572,14 @@ impl ScsiRequest { let mode = self.cmd.mode.clone(); let op = self.cmd.op; let dev = self.dev.clone(); - let locked_dev = dev.lock().unwrap(); + SCSI_DEVICE!(dev, locked_dev, scsi_dev); // SAFETY: the block_backend is assigned after device realized. - let block_backend = locked_dev.block_backend.as_ref().unwrap(); + let block_backend = scsi_dev.block_backend.as_ref().unwrap(); let mut locked_backend = block_backend.lock().unwrap(); let s_req = Arc::new(Mutex::new(self)); let scsicompletecb = ScsiCompleteCb { req: s_req.clone() }; - let offset_bits = match locked_dev.scsi_type { + let offset_bits = match scsi_dev.scsi_type { SCSI_TYPE_DISK => SCSI_DISK_DEFAULT_BLOCK_SIZE_SHIFT, _ => SCSI_CDROM_DEFAULT_BLOCK_SIZE_SHIFT, }; @@ -630,8 +652,8 @@ impl ScsiRequest { Ok(Vec::new()) } TEST_UNIT_READY => { - let dev_lock = self.dev.lock().unwrap(); - if dev_lock.block_backend.is_none() { + SCSI_DEVICE!(self.dev, locked_dev, scsi_dev); + if scsi_dev.block_backend.is_none() { Err(anyhow!("No scsi backend!")) } else { Ok(Vec::new()) @@ -666,7 +688,9 @@ impl ScsiRequest { let mut not_supported_flag = false; let mut sense = None; let mut status = GOOD; - let found_lun = self.dev.lock().unwrap().config.lun; + SCSI_DEVICE!(self.dev, locked_dev, scsi_dev); + let found_lun = scsi_dev.dev_cfg.lun; + drop(locked_dev); // Requested lun id is not equal to found device id means it may be a target request. // REPORT LUNS is also a target request command. @@ -766,22 +790,22 @@ fn scsi_cdb_length(cdb: &[u8; SCSI_CMD_BUF_SIZE]) -> i32 { } } -fn scsi_cdb_xfer(cdb: &[u8; SCSI_CMD_BUF_SIZE], dev: Arc>) -> i32 { - let dev_lock = dev.lock().unwrap(); - let block_size = dev_lock.block_size as i32; - drop(dev_lock); +pub fn scsi_cdb_xfer(cdb: &[u8; SCSI_CMD_BUF_SIZE], dev: Arc>) -> i64 { + SCSI_DEVICE!(dev, locked_dev, scsi_dev); + let block_size = scsi_dev.block_size as i64; + drop(locked_dev); - let mut xfer = match cdb[0] >> 5 { + let mut xfer: i64 = match cdb[0] >> 5 { // Group Code | Transfer length. | // 000b | Byte[4]. | // 001b | Bytes[7-8]. | // 010b | Bytes[7-8]. | // 100b | Bytes[10-13]. | // 101b | Bytes[6-9]. | - 0 => cdb[4] as i32, - 1 | 2 => BigEndian::read_u16(&cdb[7..]) as i32, - 4 => BigEndian::read_u32(&cdb[10..]) as i32, - 5 => BigEndian::read_u32(&cdb[6..]) as i32, + 0 => i64::from(cdb[4]), + 1 | 2 => i64::from(BigEndian::read_u16(&cdb[7..])), + 4 => i64::from(BigEndian::read_u32(&cdb[10..])), + 5 => i64::from(BigEndian::read_u32(&cdb[6..])), _ => -1, }; @@ -795,14 +819,16 @@ fn scsi_cdb_xfer(cdb: &[u8; SCSI_CMD_BUF_SIZE], dev: Arc>) -> WRITE_6 | READ_6 => { // length 0 means 256 blocks. if xfer == 0 { + // Safety: block_size is 2048 or 512. xfer = 256 * block_size; } } WRITE_10 | WRITE_12 | WRITE_16 | READ_10 | READ_12 | READ_16 => { + // Safety: xfer is less than u32::max now. xfer *= block_size; } INQUIRY => { - xfer = i32::from(cdb[4]) | i32::from(cdb[3]) << 8; + xfer = i64::from(cdb[4]) | i64::from(cdb[3]) << 8; } _ => {} } @@ -817,8 +843,8 @@ fn scsi_cdb_lba(cdb: &[u8; SCSI_CMD_BUF_SIZE]) -> i64 { // 010b | Bytes[2-5]. | // 100b | Bytes[2-9]. | // 101b | Bytes[2-5]. | - 0 => (BigEndian::read_u32(&cdb[0..]) & 0x1fffff) as i64, - 1 | 2 | 5 => BigEndian::read_u32(&cdb[2..]) as i64, + 0 => i64::from(BigEndian::read_u32(&cdb[0..]) & 0x1fffff), + 1 | 2 | 5 => i64::from(BigEndian::read_u32(&cdb[2..])), 4 => BigEndian::read_u64(&cdb[2..]) as i64, _ => -1, } @@ -879,28 +905,27 @@ fn scsi_cdb_xfer_mode(cdb: &[u8; SCSI_CMD_BUF_SIZE]) -> ScsiXferMode { /// VPD: Vital Product Data. fn scsi_command_emulate_vpd_page( cmd: &ScsiCommand, - dev: &Arc>, + dev: &Arc>, ) -> Result> { let buflen: usize; let mut outbuf: Vec = vec![0; 4]; - - let dev_lock = dev.lock().unwrap(); + SCSI_DEVICE!(dev, locked_dev, scsi_dev); let page_code = cmd.buf[2]; - outbuf[0] = dev_lock.scsi_type as u8 & 0x1f; + outbuf[0] = scsi_dev.scsi_type as u8 & 0x1f; outbuf[1] = page_code; match page_code { 0x00 => { // Supported VPD Pages. outbuf.push(0_u8); - if !dev_lock.state.serial.is_empty() { + if !scsi_dev.state.serial.is_empty() { // 0x80: Unit Serial Number. outbuf.push(0x80); } // 0x83: Device Identification. outbuf.push(0x83); - if dev_lock.scsi_type == SCSI_TYPE_DISK { + if scsi_dev.scsi_type == SCSI_TYPE_DISK { // 0xb0: Block Limits. outbuf.push(0xb0); // 0xb1: Block Device Characteristics. @@ -912,20 +937,20 @@ fn scsi_command_emulate_vpd_page( } 0x80 => { // Unit Serial Number. - let len = dev_lock.state.serial.len(); + let len = scsi_dev.state.serial.len(); if len == 0 { bail!("Missed serial number!"); } let l = cmp::min(SCSI_INQUIRY_VPD_SERIAL_NUMBER_MAX_LEN, len); - let mut serial_vec = dev_lock.state.serial.as_bytes().to_vec(); + let mut serial_vec = scsi_dev.state.serial.as_bytes().to_vec(); serial_vec.truncate(l); outbuf.append(&mut serial_vec); buflen = outbuf.len(); } 0x83 => { // Device Identification. - let mut len: u8 = dev_lock.state.device_id.len() as u8; + let mut len: u8 = scsi_dev.state.device_id.len() as u8; if len > (255 - 8) { len = 255 - 8; } @@ -937,7 +962,7 @@ fn scsi_command_emulate_vpd_page( // len: identifier length. outbuf.append(&mut [0x2_u8, 0_u8, 0_u8, len].to_vec()); - let mut device_id_vec = dev_lock.state.device_id.as_bytes().to_vec(); + let mut device_id_vec = scsi_dev.state.device_id.as_bytes().to_vec(); device_id_vec.truncate(len as usize); outbuf.append(&mut device_id_vec); } @@ -945,7 +970,7 @@ fn scsi_command_emulate_vpd_page( } 0xb0 => { // Block Limits. - if dev_lock.scsi_type == SCSI_TYPE_ROM { + if scsi_dev.scsi_type == SCSI_TYPE_ROM { bail!("Invalid scsi type: SCSI_TYPE_ROM !"); } outbuf.resize(64, 0); @@ -969,7 +994,7 @@ fn scsi_command_emulate_vpd_page( outbuf[4] = 1; let max_xfer_length: u32 = u32::MAX / 512; BigEndian::write_u32(&mut outbuf[8..12], max_xfer_length); - BigEndian::write_u64(&mut outbuf[36..44], max_xfer_length as u64); + BigEndian::write_u64(&mut outbuf[36..44], u64::from(max_xfer_length)); buflen = outbuf.len(); } 0xb1 => { @@ -1065,7 +1090,7 @@ fn scsi_command_emulate_target_inquiry(lun: u16, cmd: &ScsiCommand) -> Result>, + dev: &Arc>, ) -> Result> { // Byte1 bit0: EVPD(enable vital product data). if cmd.buf[1] == 0x1 { @@ -1079,26 +1104,26 @@ fn scsi_command_emulate_inquiry( let buflen = cmp::min(cmd.xfer, SCSI_MAX_INQUIRY_LEN); let mut outbuf: Vec = vec![0; SCSI_MAX_INQUIRY_LEN as usize]; - let dev_lock = dev.lock().unwrap(); + SCSI_DEVICE!(dev, locked_dev, scsi_dev); - outbuf[0] = (dev_lock.scsi_type & 0x1f) as u8; - outbuf[1] = match dev_lock.state.features & SCSI_DISK_F_REMOVABLE { + outbuf[0] = (scsi_dev.scsi_type & 0x1f) as u8; + outbuf[1] = match scsi_dev.state.features & SCSI_DISK_F_REMOVABLE { 1 => 0x80, _ => 0, }; - let product_bytes = dev_lock.state.product.as_bytes(); + let product_bytes = scsi_dev.state.product.as_bytes(); let product_len = cmp::min(product_bytes.len(), SCSI_INQUIRY_PRODUCT_MAX_LEN); - let vendor_bytes = dev_lock.state.vendor.as_bytes(); + let vendor_bytes = scsi_dev.state.vendor.as_bytes(); let vendor_len = cmp::min(vendor_bytes.len(), SCSI_INQUIRY_VENDOR_MAX_LEN); - let version_bytes = dev_lock.state.version.as_bytes(); + let version_bytes = scsi_dev.state.version.as_bytes(); let vension_len = cmp::min(version_bytes.len(), SCSI_INQUIRY_VERSION_MAX_LEN); outbuf[16..16 + product_len].copy_from_slice(product_bytes); outbuf[8..8 + vendor_len].copy_from_slice(vendor_bytes); outbuf[32..32 + vension_len].copy_from_slice(version_bytes); - drop(dev_lock); + drop(locked_dev); // outbuf: // Byte2: Version. @@ -1121,17 +1146,17 @@ fn scsi_command_emulate_inquiry( fn scsi_command_emulate_read_capacity_10( cmd: &ScsiCommand, - dev: &Arc>, + dev: &Arc>, ) -> Result> { if cmd.buf[8] & 1 == 0 && cmd.lba != 0 { // PMI(Partial Medium Indicator) bail!("Invalid scsi cmd READ_CAPACITY_10!"); } - let dev_lock = dev.lock().unwrap(); - let block_size = dev_lock.block_size; + SCSI_DEVICE!(dev, locked_dev, scsi_dev); + let block_size = scsi_dev.block_size; let mut outbuf: Vec = vec![0; 8]; - let mut nb_sectors = cmp::min(dev_lock.disk_sectors as u32, u32::MAX); + let mut nb_sectors = cmp::min(scsi_dev.disk_sectors as u32, u32::MAX); nb_sectors /= block_size / DEFAULT_SECTOR_SIZE; nb_sectors -= 1; @@ -1146,18 +1171,18 @@ fn scsi_command_emulate_read_capacity_10( fn scsi_command_emulate_mode_sense( cmd: &ScsiCommand, - dev: &Arc>, + dev: &Arc>, ) -> Result> { // disable block descriptors(DBD) bit. let mut dbd: bool = cmd.buf[1] & 0x8 != 0; let page_code = cmd.buf[2] & 0x3f; let page_control = (cmd.buf[2] & 0xc0) >> 6; let mut outbuf: Vec = vec![0]; - let dev_lock = dev.lock().unwrap(); + SCSI_DEVICE!(dev, locked_dev, scsi_dev); let mut dev_specific_parameter: u8 = 0; - let mut nb_sectors = dev_lock.disk_sectors as u32; - let scsi_type = dev_lock.scsi_type; - let block_size = dev_lock.block_size; + let mut nb_sectors = scsi_dev.disk_sectors as u32; + let scsi_type = scsi_dev.scsi_type; + let block_size = scsi_dev.block_size; nb_sectors /= block_size / DEFAULT_SECTOR_SIZE; trace::scsi_emulate_mode_sense( @@ -1171,17 +1196,17 @@ fn scsi_command_emulate_mode_sense( // Device specific paramteter field for direct access block devices: // Bit 7: WP(Write Protect); bit 4: DPOFUA; if scsi_type == SCSI_TYPE_DISK { - if dev_lock.state.features & (1 << SCSI_DISK_F_DPOFUA) != 0 { + if scsi_dev.state.features & (1 << SCSI_DISK_F_DPOFUA) != 0 { dev_specific_parameter = 0x10; } - if dev_lock.config.read_only { + if scsi_dev.drive_cfg.readonly { // Readonly. dev_specific_parameter |= 0x80; } } else { dbd = true; } - drop(dev_lock); + drop(locked_dev); if cmd.op == MODE_SENSE { outbuf.resize(4, 0); @@ -1357,12 +1382,12 @@ fn scsi_command_emulate_mode_sense_page( fn scsi_command_emulate_report_luns( cmd: &ScsiCommand, - dev: &Arc>, + dev: &Arc>, ) -> Result> { - let dev_lock = dev.lock().unwrap(); + SCSI_DEVICE!(dev, locked_dev, scsi_dev); // Byte 0-3: Lun List Length. Byte 4-7: Reserved. let mut outbuf: Vec = vec![0; 8]; - let target = dev_lock.config.target; + let target = scsi_dev.dev_cfg.target; if cmd.xfer < 16 { bail!("scsi REPORT LUNS xfer {} too short!", cmd.xfer); @@ -1376,27 +1401,24 @@ fn scsi_command_emulate_report_luns( ); } - let scsi_bus = dev_lock.parent_bus.upgrade().unwrap(); - let scsi_bus_clone = scsi_bus.lock().unwrap(); - - drop(dev_lock); + let bus = scsi_dev.parent_bus().unwrap().upgrade().unwrap(); + SCSI_BUS!(bus, locked_bus, scsi_bus); + drop(locked_dev); - for (_pos, device) in scsi_bus_clone.devices.iter() { - let device_lock = device.lock().unwrap(); - if device_lock.config.target != target { - drop(device_lock); + for device in scsi_bus.child_devices().values() { + SCSI_DEVICE!(device, locked_dev, scsi_dev); + if scsi_dev.dev_cfg.target != target { continue; } let len = outbuf.len(); - if device_lock.config.lun < 256 { + if scsi_dev.dev_cfg.lun < 256 { outbuf.push(0); - outbuf.push(device_lock.config.lun as u8); + outbuf.push(scsi_dev.dev_cfg.lun as u8); } else { - outbuf.push(0x40 | ((device_lock.config.lun >> 8) & 0xff) as u8); - outbuf.push((device_lock.config.lun & 0xff) as u8); + outbuf.push(0x40 | ((scsi_dev.dev_cfg.lun >> 8) & 0xff) as u8); + outbuf.push((scsi_dev.dev_cfg.lun & 0xff) as u8); } outbuf.resize(len + 8, 0); - drop(device_lock); } let len: u32 = outbuf.len() as u32 - 8; @@ -1406,20 +1428,19 @@ fn scsi_command_emulate_report_luns( fn scsi_command_emulate_service_action_in_16( cmd: &ScsiCommand, - dev: &Arc>, + dev: &Arc>, ) -> Result> { // Read Capacity(16) Command. // Byte 0: Operation Code(0x9e) // Byte 1: bit0 - bit4: Service Action(0x10), bit 5 - bit 7: Reserved. if cmd.buf[1] & 0x1f == SUBCODE_READ_CAPACITY_16 { - let dev_lock = dev.lock().unwrap(); - let block_size = dev_lock.block_size; + SCSI_DEVICE!(dev, locked_dev, scsi_dev); + let block_size = scsi_dev.block_size; let mut outbuf: Vec = vec![0; 32]; - let mut nb_sectors = dev_lock.disk_sectors; - nb_sectors /= (block_size / DEFAULT_SECTOR_SIZE) as u64; + let mut nb_sectors = scsi_dev.disk_sectors; + nb_sectors /= u64::from(block_size / DEFAULT_SECTOR_SIZE); nb_sectors -= 1; - - drop(dev_lock); + drop(locked_dev); // Byte[0-7]: Returned Logical BLock Address(the logical block address of the last logical // block). @@ -1439,7 +1460,7 @@ fn scsi_command_emulate_service_action_in_16( fn scsi_command_emulate_read_disc_information( cmd: &ScsiCommand, - dev: &Arc>, + dev: &Arc>, ) -> Result> { // Byte1: Bits[0-2]: Data type. // Data Type | Returned Data. | @@ -1453,9 +1474,11 @@ fn scsi_command_emulate_read_disc_information( if data_type != 0 { bail!("Unsupported read disc information data type {}!", data_type); } - if dev.lock().unwrap().scsi_type != SCSI_TYPE_ROM { + SCSI_DEVICE!(dev, locked_dev, scsi_dev); + if scsi_dev.scsi_type != SCSI_TYPE_ROM { bail!("Read disc information command is only for scsi multi-media device!"); } + drop(locked_dev); // Outbuf: // Bytes[0-1]: Disc Information Length(32). @@ -1503,7 +1526,7 @@ const RT_RAW_TOC: u8 = 0x0010; fn scsi_command_emulate_read_toc( cmd: &ScsiCommand, - dev: &Arc>, + dev: &Arc>, ) -> Result> { // Byte1: Bit1: MSF.(MSF: Minute, Second, Frame) // MSF = 1: the address fields in some returned data formats shall be in MSF form. @@ -1517,7 +1540,8 @@ fn scsi_command_emulate_read_toc( match format { RT_FORMATTED_TOC => { - let nb_sectors = dev.lock().unwrap().disk_sectors as u32; + SCSI_DEVICE!(dev, locked_dev, scsi_dev); + let nb_sectors = scsi_dev.disk_sectors as u32; let mut buf = cdrom_read_formatted_toc(nb_sectors, msf, track_number)?; outbuf.append(&mut buf); } @@ -1538,11 +1562,11 @@ fn scsi_command_emulate_read_toc( fn scsi_command_emulate_get_configuration( _cmd: &ScsiCommand, - dev: &Arc>, + dev: &Arc>, ) -> Result> { - let dev_lock = dev.lock().unwrap(); - if dev_lock.scsi_type != SCSI_TYPE_ROM { - bail!("Invalid scsi type {}", dev_lock.scsi_type); + SCSI_DEVICE!(dev, locked_dev, scsi_dev); + if scsi_dev.scsi_type != SCSI_TYPE_ROM { + bail!("Invalid scsi type {}", scsi_dev.scsi_type); } // 8 bytes(Feature Header) + 12 bytes(Profile List Feature) + @@ -1555,7 +1579,7 @@ fn scsi_command_emulate_get_configuration( // Bytes[4-5]: Reserved. // Bytes[6-7]: Current Profile. BigEndian::write_u32(&mut outbuf[0..4], 36); - let current = if dev_lock.disk_sectors > CD_MAX_SECTORS as u64 { + let current = if scsi_dev.disk_sectors > u64::from(CD_MAX_SECTORS) { GC_PROFILE_DVD_ROM } else { GC_PROFILE_CD_ROM @@ -1578,9 +1602,9 @@ fn scsi_command_emulate_get_configuration( outbuf[10] = 0x03; outbuf[11] = 8; BigEndian::write_u16(&mut outbuf[12..14], GC_PROFILE_CD_ROM); - outbuf[14] |= (current == GC_PROFILE_CD_ROM) as u8; + outbuf[14] |= u8::from(current == GC_PROFILE_CD_ROM); BigEndian::write_u16(&mut outbuf[16..18], GC_PROFILE_DVD_ROM); - outbuf[18] |= (current == GC_PROFILE_DVD_ROM) as u8; + outbuf[18] |= u8::from(current == GC_PROFILE_DVD_ROM); // Bytes[8-n]: Feature Descriptor(s): // Bytes[20-31]: Feature 1: Core Feature: @@ -1616,14 +1640,14 @@ fn scsi_command_emulate_get_configuration( fn scsi_command_emulate_get_event_status_notification( cmd: &ScsiCommand, - dev: &Arc>, + dev: &Arc>, ) -> Result> { // Byte4: Notification Class Request. let notification_class_request = cmd.buf[4]; - let dev_lock = dev.lock().unwrap(); + SCSI_DEVICE!(dev, locked_dev, scsi_dev); - if dev_lock.scsi_type != SCSI_TYPE_ROM { - bail!("Invalid scsi type {}", dev_lock.scsi_type); + if scsi_dev.scsi_type != SCSI_TYPE_ROM { + bail!("Invalid scsi type {}", scsi_dev.scsi_type); } // Byte1: Bit0: Polled. diff --git a/devices/src/scsi/disk.rs b/devices/src/scsi/disk.rs index 6c4652ed8527a841d9261435975b3b311e878f4c..a3106890b56ebad25593fc82a4dba5dcd31e9a95 100644 --- a/devices/src/scsi/disk.rs +++ b/devices/src/scsi/disk.rs @@ -11,16 +11,20 @@ // See the Mulan PSL v2 for more details. use std::collections::HashMap; -use std::sync::{Arc, Mutex, Weak}; +use std::sync::{Arc, Mutex}; use anyhow::{bail, Result}; +use clap::Parser; -use crate::ScsiBus::{aio_complete_cb, ScsiBus, ScsiCompleteCb}; +use crate::ScsiBus::{aio_complete_cb, ScsiCompleteCb}; use crate::{Device, DeviceBase}; use block_backend::{create_block_backend, BlockDriverOps, BlockProperty}; -use machine_manager::config::{DriveFile, ScsiDevConfig, VmConfig}; +use machine_manager::config::{valid_id, DriveConfig, DriveFile, VmConfig}; use machine_manager::event_loop::EventLoop; use util::aio::{Aio, AioEngine, WriteZeroesState}; +use util::gen_base_func; + +use super::bus::ScsiBus; /// SCSI DEVICE TYPES. pub const SCSI_TYPE_DISK: u32 = 0x00; @@ -57,6 +61,49 @@ pub const SCSI_DISK_DEFAULT_BLOCK_SIZE: u32 = 1 << SCSI_DISK_DEFAULT_BLOCK_SIZE_ pub const SCSI_CDROM_DEFAULT_BLOCK_SIZE_SHIFT: u32 = 11; pub const SCSI_CDROM_DEFAULT_BLOCK_SIZE: u32 = 1 << SCSI_CDROM_DEFAULT_BLOCK_SIZE_SHIFT; +// Stratovirt uses scsi mod in only virtio-scsi and usb-storage. Scsi's channel/target/lun +// of usb-storage are both 0. Scsi's channel/target/lun of virtio-scsi is no more than 0/255/16383. +// Set valid range of channel/target according to the range of virtio-scsi as 0/255. +// +// For stratovirt doesn't support `Flat space addressing format`(14 bits for lun) and only supports +// `peripheral device addressing format`(8 bits for lun) now, lun should be less than 255(2^8 - 1) temporarily. +const SCSI_MAX_CHANNEL: i64 = 0; +const SCSI_MAX_TARGET: i64 = 255; +const SUPPORT_SCSI_MAX_LUN: i64 = 255; + +#[derive(Parser, Clone, Debug, Default)] +#[command(no_binary_name(true))] +pub struct ScsiDevConfig { + #[arg(long, value_parser = ["scsi-cd", "scsi-hd"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] + pub id: String, + #[arg(long, value_parser = valid_scsi_bus)] + pub bus: String, + /// Scsi four level hierarchical address(host, channel, target, lun). + #[arg(long, default_value = "0", value_parser = clap::value_parser!(u8).range(..=SCSI_MAX_CHANNEL))] + pub channel: u8, + #[arg(long, alias = "scsi-id", value_parser = clap::value_parser!(u8).range(..=SCSI_MAX_TARGET))] + pub target: u8, + #[arg(long, value_parser = clap::value_parser!(u16).range(..=SUPPORT_SCSI_MAX_LUN))] + pub lun: u16, + #[arg(long)] + pub drive: String, + #[arg(long)] + pub serial: Option, + #[arg(long)] + pub bootindex: Option, +} + +// Scsi device should has bus named as "$parent_cntlr_name.0". +fn valid_scsi_bus(bus: &str) -> Result { + let strs = bus.split('.').collect::>(); + if strs.len() != 2 || strs[1] != "0" { + bail!("Invalid scsi bus {}", bus); + } + Ok(bus.to_string()) +} + #[derive(Clone, Default)] pub struct ScsiDevState { /// Features which the scsi device supports. @@ -87,19 +134,79 @@ impl ScsiDevState { } impl Device for ScsiDevice { - fn device_base(&self) -> &DeviceBase { - &self.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base); + + fn realize(mut self) -> Result>> { + match self.scsi_type { + SCSI_TYPE_DISK => { + self.block_size = SCSI_DISK_DEFAULT_BLOCK_SIZE; + self.state.product = "STRA HARDDISK".to_string(); + } + SCSI_TYPE_ROM => { + self.block_size = SCSI_CDROM_DEFAULT_BLOCK_SIZE; + self.state.product = "STRA CDROM".to_string(); + } + _ => { + bail!("Scsi type {} does not support now", self.scsi_type); + } + } + + if let Some(serial) = &self.dev_cfg.serial { + self.state.serial = serial.clone(); + } + + let drive_files = self.drive_files.lock().unwrap(); + // File path can not be empty string. And it has also been checked in command parsing by using `Clap`. + let file = VmConfig::fetch_drive_file(&drive_files, &self.drive_cfg.path_on_host)?; + + let alignments = VmConfig::fetch_drive_align(&drive_files, &self.drive_cfg.path_on_host)?; + self.req_align = alignments.0; + self.buf_align = alignments.1; + let drive_id = VmConfig::get_drive_id(&drive_files, &self.drive_cfg.path_on_host)?; + drop(drive_files); + + let mut thread_pool = None; + if self.drive_cfg.aio != AioEngine::Off { + thread_pool = Some(EventLoop::get_ctx(None).unwrap().thread_pool.clone()); + } + let aio = Aio::new(Arc::new(aio_complete_cb), self.drive_cfg.aio, thread_pool)?; + let conf = BlockProperty { + id: drive_id, + format: self.drive_cfg.format, + iothread: self.iothread.clone(), + direct: self.drive_cfg.direct, + req_align: self.req_align, + buf_align: self.buf_align, + discard: false, + write_zeroes: WriteZeroesState::Off, + l2_cache_size: self.drive_cfg.l2_cache_size, + refcount_cache_size: self.drive_cfg.refcount_cache_size, + }; + let backend = create_block_backend(file, aio, conf)?; + let disk_size = backend.lock().unwrap().disk_size()?; + self.block_backend = Some(backend); + self.disk_sectors = disk_size >> SECTOR_SHIFT; - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base + let dev = Arc::new(Mutex::new(self)); + Ok(dev) } } +/// Convert from Arc> to &ScsiDevice. +#[macro_export] +macro_rules! SCSI_DEVICE { + ($trait_device:expr, $lock_device: ident, $struct_device: ident) => { + convert_device_ref!($trait_device, $lock_device, $struct_device, ScsiDevice); + }; +} + +#[derive(Default)] pub struct ScsiDevice { pub base: DeviceBase, /// Configuration of the scsi device. - pub config: ScsiDevConfig, + pub dev_cfg: ScsiDevConfig, + /// Configuration of the scsi device's drive. + pub drive_cfg: DriveConfig, /// State of the scsi device. pub state: ScsiDevState, /// Block backend opened by scsi device. @@ -114,12 +221,11 @@ pub struct ScsiDevice { pub block_size: u32, /// Scsi device type. pub scsi_type: u32, - /// Scsi Bus attached to. - pub parent_bus: Weak>, /// Drive backend files. drive_files: Arc>>, /// Aio context. pub aio: Option>>>, + pub iothread: Option, } // SAFETY: the devices attached in one scsi controller will process IO in the same thread. @@ -129,76 +235,87 @@ unsafe impl Sync for ScsiDevice {} impl ScsiDevice { pub fn new( - config: ScsiDevConfig, - scsi_type: u32, + dev_cfg: ScsiDevConfig, + drive_cfg: DriveConfig, drive_files: Arc>>, + iothread: Option, + scsi_bus: Arc>, ) -> ScsiDevice { - ScsiDevice { - base: DeviceBase::new(config.id.clone(), false), - config, + let scsi_type = match dev_cfg.classtype.as_str() { + "scsi-hd" => SCSI_TYPE_DISK, + _ => SCSI_TYPE_ROM, + }; + + let mut scsi_dev = ScsiDevice { + base: DeviceBase::new(dev_cfg.id.clone(), false, None), + dev_cfg, + drive_cfg, state: ScsiDevState::new(), - block_backend: None, req_align: 1, buf_align: 1, - disk_sectors: 0, - block_size: 0, scsi_type, - parent_bus: Weak::new(), drive_files, - aio: None, - } + iothread, + ..Default::default() + }; + scsi_dev.set_parent_bus(scsi_bus); + scsi_dev } +} - pub fn realize(&mut self, iothread: Option) -> Result<()> { - match self.scsi_type { - SCSI_TYPE_DISK => { - self.block_size = SCSI_DISK_DEFAULT_BLOCK_SIZE; - self.state.product = "STRA HARDDISK".to_string(); - } - SCSI_TYPE_ROM => { - self.block_size = SCSI_CDROM_DEFAULT_BLOCK_SIZE; - self.state.product = "STRA CDROM".to_string(); - } - _ => { - bail!("Scsi type {} does not support now", self.scsi_type); - } - } - - if let Some(serial) = &self.config.serial { - self.state.serial = serial.clone(); - } +#[cfg(test)] +mod tests { + use super::*; + use machine_manager::config::str_slip_to_clap; - let drive_files = self.drive_files.lock().unwrap(); - // File path can not be empty string. And it has also been checked in CmdParser::parse. - let file = VmConfig::fetch_drive_file(&drive_files, &self.config.path_on_host)?; + #[test] + fn test_scsi_device_cmdline_parser() { + // Test1: Right. + let cmdline1 = "scsi-hd,bus=scsi0.0,scsi-id=0,lun=0,drive=drive-0-0-0-0,id=scsi0-0-0-0,serial=123456,bootindex=1"; + let config = + ScsiDevConfig::try_parse_from(str_slip_to_clap(cmdline1, true, false)).unwrap(); + assert_eq!(config.id, "scsi0-0-0-0"); + assert_eq!(config.bus, "scsi0.0"); + assert_eq!(config.target, 0); + assert_eq!(config.lun, 0); + assert_eq!(config.drive, "drive-0-0-0-0"); + assert_eq!(config.serial.unwrap(), "123456"); + assert_eq!(config.bootindex.unwrap(), 1); - let alignments = VmConfig::fetch_drive_align(&drive_files, &self.config.path_on_host)?; - self.req_align = alignments.0; - self.buf_align = alignments.1; - let drive_id = VmConfig::get_drive_id(&drive_files, &self.config.path_on_host)?; + // Test2: Default value. + let cmdline2 = "scsi-cd,bus=scsi0.0,scsi-id=0,lun=0,drive=drive-0-0-0-0,id=scsi0-0-0-0"; + let config = + ScsiDevConfig::try_parse_from(str_slip_to_clap(cmdline2, true, false)).unwrap(); + assert_eq!(config.channel, 0); + assert_eq!(config.serial, None); + assert_eq!(config.bootindex, None); - let mut thread_pool = None; - if self.config.aio_type != AioEngine::Off { - thread_pool = Some(EventLoop::get_ctx(None).unwrap().thread_pool.clone()); - } - let aio = Aio::new(Arc::new(aio_complete_cb), self.config.aio_type, thread_pool)?; - let conf = BlockProperty { - id: drive_id, - format: self.config.format, - iothread, - direct: self.config.direct, - req_align: self.req_align, - buf_align: self.buf_align, - discard: false, - write_zeroes: WriteZeroesState::Off, - l2_cache_size: self.config.l2_cache_size, - refcount_cache_size: self.config.refcount_cache_size, - }; - let backend = create_block_backend(file, aio, conf)?; - let disk_size = backend.lock().unwrap().disk_size()?; - self.block_backend = Some(backend); - self.disk_sectors = disk_size >> SECTOR_SHIFT; + // Test3: Illegal value. + let cmdline3 = "scsi-hd,bus=scsi0.0,scsi-id=256,lun=0,drive=drive-0-0-0-0,id=scsi0-0-0-0"; + let result = ScsiDevConfig::try_parse_from(str_slip_to_clap(cmdline3, true, false)); + assert!(result.is_err()); + let cmdline3 = "scsi-hd,bus=scsi0.0,scsi-id=0,lun=256,drive=drive-0-0-0-0,id=scsi0-0-0-0"; + let result = ScsiDevConfig::try_parse_from(str_slip_to_clap(cmdline3, true, false)); + assert!(result.is_err()); + let cmdline3 = "illegal,bus=scsi0.0,scsi-id=0,lun=0,drive=drive-0-0-0-0,id=scsi0-0-0-0"; + let result = ScsiDevConfig::try_parse_from(str_slip_to_clap(cmdline3, true, false)); + assert!(result.is_err()); - Ok(()) + // Test4: Missing necessary parameters. + let cmdline4 = "scsi-hd,scsi-id=0,lun=0,drive=drive-0-0-0-0,id=scsi0-0-0-0"; + let result = ScsiDevConfig::try_parse_from(str_slip_to_clap(cmdline4, true, false)); + assert!(result.is_err()); + let cmdline4 = "scsi-hd,bus=scsi0.0,lun=0,drive=drive-0-0-0-0,id=scsi0-0-0-0"; + let result = ScsiDevConfig::try_parse_from(str_slip_to_clap(cmdline4, true, false)); + assert!(result.is_err()); + let cmdline4 = "scsi-hd,bus=scsi0.0,scsi-id=0,drive=drive-0-0-0-0,id=scsi0-0-0-0"; + let result = ScsiDevConfig::try_parse_from(str_slip_to_clap(cmdline4, true, false)); + assert!(result.is_err()); + let cmdline4 = "scsi-hd,bus=scsi0.0,scsi-id=0,lun=0,id=scsi0-0-0-0"; + let result = ScsiDevConfig::try_parse_from(str_slip_to_clap(cmdline4, true, false)); + assert!(result.is_err()); + let cmdline4 = "scsi-hd,bus=scsi0.0,scsi-id=0,lun=0,drive=drive-0-0-0-0"; + let result = ScsiDevConfig::try_parse_from(str_slip_to_clap(cmdline4, true, false)); + assert!(result.is_err()); } } diff --git a/devices/src/smbios/smbios_table.rs b/devices/src/smbios/smbios_table.rs index 9211bd8817153b7ebd17db9eb32d25549950f87d..2751e6c9312785775b411a7335a50702d542216e 100644 --- a/devices/src/smbios/smbios_table.rs +++ b/devices/src/smbios/smbios_table.rs @@ -656,9 +656,9 @@ impl SmbiosTable { fn build_type0(&mut self, type0: SmbiosType0Config) { let mut table0: SmbiosType0Table = SmbiosType0Table::new(); - if let Some(vender) = type0.vender { + if let Some(vendor) = type0.vendor { table0.header.vendor_idx = table0.str_index + 1; - table0.set_str(vender); + table0.set_str(vendor); } if let Some(version) = type0.version { @@ -866,11 +866,11 @@ impl SmbiosTable { table4.header.core_count = mach_cfg.nr_cores; table4.header.core_enabled = mach_cfg.nr_cores; - table4.header.core_count2 = (mach_cfg.nr_cores as u16).to_le_bytes(); - table4.header.core_enabled2 = (mach_cfg.nr_cores as u16).to_le_bytes(); + table4.header.core_count2 = u16::from(mach_cfg.nr_cores).to_le_bytes(); + table4.header.core_enabled2 = u16::from(mach_cfg.nr_cores).to_le_bytes(); table4.header.thread_count = mach_cfg.nr_threads; - table4.header.thread_count2 = (mach_cfg.nr_threads as u16).to_le_bytes(); + table4.header.thread_count2 = u16::from(mach_cfg.nr_threads).to_le_bytes(); table4.finish(); self.entries.append(&mut table4.header.as_bytes().to_vec()); @@ -946,7 +946,7 @@ impl SmbiosTable { let start_kb = start / 1024; let end_kb = (start + size - 1) / 1024; - if start_kb < u32::MAX as u64 && end_kb < u32::MAX as u64 { + if start_kb < u64::from(u32::MAX) && end_kb < u64::from(u32::MAX) { table19.header.starting_address = (start_kb as u32).to_le_bytes(); table19.header.ending_address = (end_kb as u32).to_le_bytes(); } else { @@ -994,7 +994,7 @@ impl SmbiosTable { let smbios_sockets = mach_cfg.nr_cpus / (mach_cfg.nr_cores * mach_cfg.nr_threads); for i in 0..smbios_sockets { - self.build_type4(smbios.type4.clone(), i as u16, mach_cfg); + self.build_type4(smbios.type4.clone(), u16::from(i), mach_cfg); } let mem_num = ((mach_cfg.mem_config.mem_size + 16 * GB_SIZE - 1) / (16 * GB_SIZE)) as u16; self.build_type16(mach_cfg.mem_config.mem_size, mem_num); diff --git a/devices/src/sysbus/error.rs b/devices/src/sysbus/error.rs index da1c890988096cc50dbaf19d22b86aa91a71d1ae..1d233b6be21582140b3bb94d5197eaa562ba185b 100644 --- a/devices/src/sysbus/error.rs +++ b/devices/src/sysbus/error.rs @@ -19,4 +19,6 @@ pub enum SysBusError { #[from] source: address_space::error::AddressSpaceError, }, + #[error("Failed to register region in {0} space: offset={1:#x},size={2:#x}")] + AddRegionErr(&'static str, u64, u64), } diff --git a/devices/src/sysbus/mod.rs b/devices/src/sysbus/mod.rs index 36003cd1bbe829ce3eb4a584151c72ccb015e3a9..01de827dd7337eefe2ebc83b8c3041c4c0f02c8f 100644 --- a/devices/src/sysbus/mod.rs +++ b/devices/src/sysbus/mod.rs @@ -14,16 +14,31 @@ pub mod error; pub use error::SysBusError; +use std::any::{Any, TypeId}; +use std::collections::HashMap; use std::fmt; use std::sync::{Arc, Mutex}; use anyhow::{bail, Context, Result}; use vmm_sys_util::eventfd::EventFd; -use crate::{Device, DeviceBase, IrqState, LineIrqManager, TriggerMode}; +#[cfg(target_arch = "x86_64")] +use crate::acpi::cpu_controller::CpuController; +use crate::acpi::ged::Ged; +#[cfg(target_arch = "aarch64")] +use crate::acpi::power::PowerDev; +#[cfg(all(feature = "ramfb", target_arch = "aarch64"))] +use crate::legacy::Ramfb; +#[cfg(target_arch = "x86_64")] +use crate::legacy::{FwCfgIO, RTC}; +#[cfg(target_arch = "aarch64")] +use crate::legacy::{FwCfgMem, PL011, PL031}; +use crate::legacy::{PFlash, Serial}; +use crate::pci::PciHost; +use crate::{Bus, BusBase, Device, DeviceBase, IrqState, LineIrqManager, TriggerMode}; use acpi::{AmlBuilder, AmlScope}; use address_space::{AddressSpace, GuestAddress, Region, RegionIoEventFd, RegionOps}; -use util::AsAny; +use util::gen_base_func; // Now that the serial device use a hardcoded IRQ number (4), and the starting // free IRQ number can be 5. @@ -39,10 +54,12 @@ pub const IRQ_BASE: i32 = 32; pub const IRQ_MAX: i32 = 191; pub struct SysBus { + pub base: BusBase, + // Record the largest key used in the BTreemap of the busbase(children field). + max_key: u64, #[cfg(target_arch = "x86_64")] pub sys_io: Arc, pub sys_mem: Arc, - pub devices: Vec>>, pub free_irqs: (i32, i32), pub min_free_irq: i32, pub mmio_region: (u64, u64), @@ -75,10 +92,11 @@ impl SysBus { mmio_region: (u64, u64), ) -> Self { Self { + base: BusBase::new("sysbus".to_string()), + max_key: 0, #[cfg(target_arch = "x86_64")] sys_io: sys_io.clone(), sys_mem: sys_mem.clone(), - devices: Vec::new(), free_irqs, min_free_irq: free_irqs.0, mmio_region, @@ -104,84 +122,74 @@ impl SysBus { } } - pub fn attach_device( - &mut self, - dev: &Arc>, - region_base: u64, - region_size: u64, - region_name: &str, - ) -> Result<()> { - let region_ops = self.build_region_ops(dev); - let region = Region::init_io_region(region_size, region_ops, region_name); - let locked_dev = dev.lock().unwrap(); - - region.set_ioeventfds(&locked_dev.ioeventfds()); - match locked_dev.sysbusdev_base().dev_type { - SysBusDevType::Serial if cfg!(target_arch = "x86_64") => { - #[cfg(target_arch = "x86_64")] - self.sys_io - .root() - .add_subregion(region, region_base) - .with_context(|| { - format!( - "Failed to register region in I/O space: offset={},size={}", - region_base, region_size - ) - })?; - } - SysBusDevType::FwCfg if cfg!(target_arch = "x86_64") => { + pub fn attach_device(&mut self, dev: &Arc>) -> Result<()> { + let res = dev.lock().unwrap().get_sys_resource().clone(); + let region_base = res.region_base; + let region_size = res.region_size; + let region_name = res.region_name; + + // region_base/region_size are both 0 means this device doesn't have its own memory layout. + // The normally allocated device region_base is above the `MEM_LAYOUT[LayoutEntryType::Mmio as usize].0`. + if region_base != 0 && region_size != 0 { + let region_ops = self.build_region_ops(dev); + let region = Region::init_io_region(region_size, region_ops, ®ion_name); + let locked_dev = dev.lock().unwrap(); + + region.set_ioeventfds(&locked_dev.ioeventfds()); + match locked_dev.sysbusdev_base().dev_type { #[cfg(target_arch = "x86_64")] - self.sys_io + SysBusDevType::Serial | SysBusDevType::FwCfg | SysBusDevType::Rtc => { + self.sys_io + .root() + .add_subregion(region, region_base) + .with_context(|| { + SysBusError::AddRegionErr("I/O", region_base, region_size) + })?; + } + _ => self + .sys_mem .root() .add_subregion(region, region_base) .with_context(|| { - format!( - "Failed to register region in I/O space: offset 0x{:x}, size {}", - region_base, region_size - ) - })?; + SysBusError::AddRegionErr("memory", region_base, region_size) + })?, } - SysBusDevType::Rtc if cfg!(target_arch = "x86_64") => { - #[cfg(target_arch = "x86_64")] - self.sys_io - .root() - .add_subregion(region, region_base) - .with_context(|| { - format!( - "Failed to register region in I/O space: offset 0x{:x}, size {}", - region_base, region_size - ) - })?; - } - _ => self - .sys_mem - .root() - .add_subregion(region, region_base) - .with_context(|| { - format!( - "Failed to register region in memory space: offset={},size={}", - region_base, region_size - ) - })?, } - self.devices.push(dev.clone()); + self.sysbus_attach_child(dev.clone())?; Ok(()) } - pub fn attach_dynamic_device( - &mut self, - dev: &Arc>, - ) -> Result<()> { - self.devices.push(dev.clone()); + pub fn sysbus_attach_child(&mut self, dev: Arc>) -> Result<()> { + self.attach_child(self.max_key, dev.clone())?; + // Note: Incrementally generate a number that has no substantive effect, and is only used for the + // key of Btreemap in the busbase(children field). + // The number of system-bus devices is limited, and it is also difficult to reach the `u64` range for + // hot-plug times. So, `u64` is currently sufficient for using and don't consider overflow issues for now. + self.max_key += 1; Ok(()) } } -#[derive(Copy, Clone)] +impl Bus for SysBus { + gen_base_func!(bus_base, bus_base_mut, BusBase, base); +} + +/// Convert from Arc> to &mut SysBus. +#[macro_export] +macro_rules! MUT_SYS_BUS { + ($trait_bus:expr, $lock_bus: ident, $struct_bus: ident) => { + convert_bus_mut!($trait_bus, $lock_bus, $struct_bus, SysBus); + }; +} + +#[derive(Clone)] pub struct SysRes { + // Note: region_base/region_size are both 0 means that this device doesn't have its own memory layout. + // The normally allocated device memory region is above the `MEM_LAYOUT[LayoutEntryType::Mmio as usize].0`. pub region_base: u64, pub region_size: u64, + pub region_name: String, pub irq: i32, } @@ -190,6 +198,7 @@ impl Default for SysRes { Self { region_base: 0, region_size: 0, + region_name: "".to_string(), irq: -1, } } @@ -243,15 +252,16 @@ impl SysBusDevBase { } } - pub fn set_sys(&mut self, irq: i32, region_base: u64, region_size: u64) { + pub fn set_sys(&mut self, irq: i32, region_base: u64, region_size: u64, region_name: &str) { self.res.irq = irq; self.res.region_base = region_base; self.res.region_size = region_size; + self.res.region_name = region_name.to_string(); } } /// Operations for sysbus devices. -pub trait SysBusDevOps: Device + Send + AmlBuilder + AsAny { +pub trait SysBusDevOps: Device + Send + AmlBuilder { fn sysbusdev_base(&self) -> &SysBusDevBase; fn sysbusdev_base_mut(&mut self) -> &mut SysBusDevBase; @@ -292,19 +302,22 @@ pub trait SysBusDevOps: Device + Send + AmlBuilder + AsAny { Ok(irq) } - fn get_sys_resource_mut(&mut self) -> Option<&mut SysRes> { - None + fn get_sys_resource(&mut self) -> &mut SysRes { + &mut self.sysbusdev_base_mut().res } fn set_sys_resource( &mut self, - sysbus: &mut SysBus, + sysbus: &Arc>, region_base: u64, region_size: u64, + region_name: &str, ) -> Result<()> { - let irq = self.get_irq(sysbus)?; + let mut locked_sysbus = sysbus.lock().unwrap(); + let irq = self.get_irq(&mut locked_sysbus)?; let interrupt_evt = self.sysbusdev_base().interrupt_evt.clone(); - let irq_manager = sysbus.irq_manager.clone(); + let irq_manager = locked_sysbus.irq_manager.clone(); + drop(locked_sysbus); self.sysbusdev_base_mut().irq_state = IrqState::new(irq as u32, interrupt_evt, irq_manager, TriggerMode::Edge); @@ -312,7 +325,7 @@ pub trait SysBusDevOps: Device + Send + AmlBuilder + AsAny { irq_state.register_irq()?; self.sysbusdev_base_mut() - .set_sys(irq, region_base, region_size); + .set_sys(irq, region_base, region_size, region_name); Ok(()) } @@ -326,19 +339,113 @@ pub trait SysBusDevOps: Device + Send + AmlBuilder + AsAny { ) }); } +} - fn reset(&mut self) -> Result<()> { - Ok(()) - } +/// Convert from Arc> to &mut dyn SysBusDevOps. +#[macro_export] +macro_rules! SYS_BUS_DEVICE { + ($trait_device:expr, $lock_device: ident, $trait_sysbusdevops: ident) => { + let mut $lock_device = $trait_device.lock().unwrap(); + let $trait_sysbusdevops = to_sysbusdevops(&mut *$lock_device).unwrap(); + }; } impl AmlBuilder for SysBus { fn aml_bytes(&self) -> Vec { let mut scope = AmlScope::new("_SB"); - self.devices.iter().for_each(|dev| { - scope.append(&dev.lock().unwrap().aml_bytes()); - }); + let child_devices = self.base.children.clone(); + for dev in child_devices.values() { + SYS_BUS_DEVICE!(dev, locked_dev, sysbusdev); + scope.append(&sysbusdev.aml_bytes()); + } scope.aml_bytes() } } + +pub type ToSysBusDevOpsFunc = fn(&mut dyn Any) -> &mut dyn SysBusDevOps; + +static mut SYSBUSDEVTYPE_HASHMAP: Option> = None; + +pub fn convert_to_sysbusdevops(item: &mut dyn Any) -> &mut dyn SysBusDevOps { + // SAFETY: The typeid of `T` is the typeid recorded in the hashmap. The target structure type of + // the conversion is its own structure type, so the conversion result will definitely not be `None`. + let t = item.downcast_mut::().unwrap(); + t as &mut dyn SysBusDevOps +} + +pub fn register_sysbusdevops_type() -> Result<()> { + let type_id = TypeId::of::(); + // SAFETY: SYSBUSDEVTYPE_HASHMAP will be built in `type_init` function sequentially in the main thread. + // And will not be changed after `type_init`. + unsafe { + if SYSBUSDEVTYPE_HASHMAP.is_none() { + SYSBUSDEVTYPE_HASHMAP = Some(HashMap::new()); + } + let types = SYSBUSDEVTYPE_HASHMAP.as_mut().unwrap(); + if types.get(&type_id).is_some() { + bail!("Type Id {:?} has been registered.", type_id); + } + types.insert(type_id, convert_to_sysbusdevops::); + } + + Ok(()) +} + +pub fn devices_register_sysbusdevops_type() -> Result<()> { + #[cfg(target_arch = "x86_64")] + { + register_sysbusdevops_type::()?; + register_sysbusdevops_type::()?; + register_sysbusdevops_type::()?; + } + #[cfg(target_arch = "aarch64")] + { + register_sysbusdevops_type::()?; + #[cfg(all(feature = "ramfb"))] + register_sysbusdevops_type::()?; + register_sysbusdevops_type::()?; + register_sysbusdevops_type::()?; + register_sysbusdevops_type::()?; + } + register_sysbusdevops_type::()?; + register_sysbusdevops_type::()?; + register_sysbusdevops_type::()?; + register_sysbusdevops_type::() +} + +pub fn to_sysbusdevops(dev: &mut dyn Device) -> Option<&mut dyn SysBusDevOps> { + // SAFETY: SYSBUSDEVTYPE_HASHMAP has been built. And this function is called without changing hashmap. + unsafe { + let types = SYSBUSDEVTYPE_HASHMAP.as_mut().unwrap(); + let func = types.get(&dev.device_type_id())?; + let sysbusdev = func(dev.as_any_mut()); + Some(sysbusdev) + } +} + +#[cfg(test)] +pub fn sysbus_init() -> Arc> { + let sys_mem = AddressSpace::new( + Region::init_container_region(u64::max_value(), "sys_mem"), + "sys_mem", + None, + ) + .unwrap(); + #[cfg(target_arch = "x86_64")] + let sys_io = AddressSpace::new( + Region::init_container_region(1 << 16, "sys_io"), + "sys_io", + None, + ) + .unwrap(); + let free_irqs: (i32, i32) = (IRQ_BASE, IRQ_MAX); + let mmio_region: (u64, u64) = (0x0A00_0000, 0x1000_0000); + Arc::new(Mutex::new(SysBus::new( + #[cfg(target_arch = "x86_64")] + &sys_io, + &sys_mem, + free_irqs, + mmio_region, + ))) +} diff --git a/devices/src/usb/camera.rs b/devices/src/usb/camera.rs index bd7ae9fb6dedc956e7bd67273302a51e3107d9ef..e05f89b91881afc72179cb829b3ee7dd1c9af26f 100644 --- a/devices/src/usb/camera.rs +++ b/devices/src/usb/camera.rs @@ -32,23 +32,19 @@ use crate::camera_backend::{ }; use crate::usb::config::*; use crate::usb::descriptor::*; -use crate::usb::{ - UsbDevice, UsbDeviceBase, UsbDeviceRequest, UsbEndpoint, UsbPacket, UsbPacketStatus, -}; +use crate::usb::{UsbDevice, UsbDeviceBase, UsbDeviceRequest, UsbPacket, UsbPacketStatus}; use machine_manager::config::camera::CameraDevConfig; use machine_manager::config::valid_id; use machine_manager::event_loop::{register_event_helper, unregister_event_helper}; +use machine_manager::notifier::{register_vm_pause_notifier, unregister_vm_pause_notifier}; use util::aio::{iov_discard_front_direct, Iovec}; use util::byte_code::ByteCode; +use util::gen_base_func; use util::loop_context::{ - read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, NotifierOperation, + create_new_eventfd, read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, + NotifierOperation, }; -// CRC16 of "STRATOVIRT" -const UVC_VENDOR_ID: u16 = 0xB74C; -// The first 4 chars of "VIDEO", 5 substitutes V. -const UVC_PRODUCT_ID: u16 = 0x51DE; - const INTERFACE_ID_CONTROL: u8 = 0; const INTERFACE_ID_STREAMING: u8 = 1; @@ -95,8 +91,10 @@ const FRAME_SIZE_1280_720: u32 = 1280 * 720 * 2; const USB_CAMERA_BUFFER_LEN: usize = 12 * 1024; #[derive(Parser, Debug, Clone)] -#[command(name = "usb_camera")] +#[command(no_binary_name(true))] pub struct UsbCameraConfig { + #[arg(long)] + pub classtype: String, #[arg(long, value_parser = valid_id)] pub id: String, #[arg(long)] @@ -116,6 +114,7 @@ pub struct UsbCamera { broken: Arc, // if the device broken or not iothread: Option, delete_evts: Vec, + notifier_id: u64, } #[derive(Debug)] @@ -433,8 +432,8 @@ fn gen_desc_device_camera(fmt_list: Vec) -> Result Result { - let camera = create_cam_backend(config.clone(), cameradev)?; + pub fn new(config: UsbCameraConfig, cameradev: CameraDevConfig, tokenid: u64) -> Result { + let camera = create_cam_backend(config.clone(), cameradev, tokenid)?; Ok(Self { base: UsbDeviceBase::new(config.id, USB_CAMERA_BUFFER_LEN), vs_control: VideoStreamingControl::default(), - camera_fd: Arc::new(EventFd::new(libc::EFD_NONBLOCK)?), + camera_fd: Arc::new(create_new_eventfd()?), camera_backend: camera, packet_list: Arc::new(Mutex::new(LinkedList::new())), payload: Arc::new(Mutex::new(UvcPayload::new())), @@ -516,6 +515,7 @@ impl UsbCamera { broken: Arc::new(AtomicBool::new(false)), iothread: config.iothread, delete_evts: Vec::new(), + notifier_id: 0, }) } @@ -752,13 +752,7 @@ impl UsbCamera { } impl UsbDevice for UsbCamera { - fn usb_device_base(&self) -> &UsbDeviceBase { - &self.base - } - - fn usb_device_base_mut(&mut self) -> &mut UsbDeviceBase { - &mut self.base - } + gen_base_func!(usb_device_base, usb_device_base_mut, UsbDeviceBase, base); fn realize(mut self) -> Result>> { let fmt_list = self.camera_backend.lock().unwrap().list_format()?; @@ -774,15 +768,27 @@ impl UsbDevice for UsbCamera { self.register_cb(); let camera = Arc::new(Mutex::new(self)); + let cloned_camera = camera.clone(); + let pause_notify = Arc::new(move |paused: bool| { + let locked_cam = cloned_camera.lock().unwrap(); + locked_cam.camera_backend.lock().unwrap().pause(paused); + }); + camera.lock().unwrap().notifier_id = register_vm_pause_notifier(pause_notify); + Ok(camera) } fn unrealize(&mut self) -> Result<()> { info!("Camera {} unrealize", self.device_id()); + self.unregister_camera_fd()?; self.camera_backend.lock().unwrap().reset(); + unregister_vm_pause_notifier(self.notifier_id); + self.notifier_id = 0; Ok(()) } + fn cancel_packet(&mut self, _packet: &Arc>) {} + fn reset(&mut self) { info!("Camera {} device reset", self.device_id()); self.base.addr = 0; @@ -809,7 +815,7 @@ impl UsbDevice for UsbCamera { } } Err(e) => { - warn!("Camera descriptor error {:?}", e); + warn!("Received incorrect USB Camera descriptor message: {:?}", e); locked_packet.status = UsbPacketStatus::Stall; return; } @@ -851,10 +857,6 @@ impl UsbDevice for UsbCamera { fn get_controller(&self) -> Option>> { None } - - fn get_wakeup_endpoint(&self) -> &UsbEndpoint { - self.base.get_endpoint(true, 1) - } } /// UVC payload @@ -884,7 +886,12 @@ impl UvcPayload { let mut frame_data_size = iov_size; let header_len = self.header.len(); // Within the scope of the frame. - if self.frame_offset + frame_data_size as usize >= current_frame_size { + if self + .frame_offset + .checked_add(frame_data_size as usize) + .with_context(|| "get_frame_data_size: invalid frame data")? + >= current_frame_size + { if self.frame_offset > current_frame_size { bail!( "Invalid frame offset {} {}", @@ -895,7 +902,12 @@ impl UvcPayload { frame_data_size = (current_frame_size - self.frame_offset) as u64; } // Within the scope of the payload. - if self.payload_offset + frame_data_size as usize >= MAX_PAYLOAD as usize { + if self + .payload_offset + .checked_add(frame_data_size as usize) + .with_context(|| "get_frame_data_size: invalid payload data")? + >= MAX_PAYLOAD as usize + { if self.payload_offset > MAX_PAYLOAD as usize { bail!( "Invalid payload offset {} {}", @@ -903,10 +915,15 @@ impl UvcPayload { MAX_PAYLOAD ); } - frame_data_size = MAX_PAYLOAD as u64 - self.payload_offset as u64; + frame_data_size = u64::from(MAX_PAYLOAD) - self.payload_offset as u64; } // payload start, reserve the header. - if self.payload_offset == 0 && frame_data_size + header_len as u64 > iov_size { + if self.payload_offset == 0 + && frame_data_size + .checked_add(header_len as u64) + .with_context(|| "get_frame_data_size: invalid header_len")? + > iov_size + { if iov_size <= header_len as u64 { bail!("Invalid iov size {}", iov_size); } @@ -996,7 +1013,7 @@ impl CameraIoHandler { // Payload start, add header. pkt.transfer_packet(&mut locked_payload.header, header_len); locked_payload.payload_offset += header_len; - iovecs = iov_discard_front_direct(&mut pkt.iovecs, pkt.actual_length as u64) + iovecs = iov_discard_front_direct(&mut pkt.iovecs, u64::from(pkt.actual_length)) .with_context(|| format!("Invalid iov size {}", pkt_size))?; } let copied = locked_camera.get_frame( @@ -1076,7 +1093,7 @@ fn gen_fmt_desc(fmt_list: Vec) -> Result body.push(Arc::new(UsbDescOther { data })); } - header_struct.wTotalLength = header_struct.bLength as u16 + header_struct.wTotalLength = u16::from(header_struct.bLength) + body.clone().iter().fold(0, |len, x| len + x.data.len()) as u16; let mut vec = header_struct.as_bytes().to_vec(); @@ -1090,7 +1107,10 @@ fn gen_fmt_desc(fmt_list: Vec) -> Result fn gen_intface_header_desc(fmt_num: u8) -> VsDescInputHeader { VsDescInputHeader { - bLength: 0xd + fmt_num, + bLength: 0xd_u8.checked_add(fmt_num).unwrap_or_else(|| { + error!("gen_intface_header_desc: too large fmt num"); + u8::MAX + }), bDescriptorType: CS_INTERFACE, bDescriptorSubtype: VS_INPUT_HEADER, bNumFormats: fmt_num, @@ -1107,8 +1127,8 @@ fn gen_intface_header_desc(fmt_num: u8) -> VsDescInputHeader { fn gen_fmt_header(fmt: &CameraFormatList) -> Result> { let bits_per_pixel = match fmt.format { - FmtType::Yuy2 | FmtType::Rgb565 => 0x10, - FmtType::Nv12 => 0xc, + FmtType::Yuy2 | FmtType::Rgb565 => 0x10_u8, + FmtType::Nv12 => 0xc_u8, _ => 0, }; let header = match fmt.format { diff --git a/devices/src/usb/config.rs b/devices/src/usb/config.rs index 3dda38b45fb7a09755bccca4b62522804a252344..aef89932a7a14ad4d202c152795692c2cf96bb8f 100644 --- a/devices/src/usb/config.rs +++ b/devices/src/usb/config.rs @@ -155,6 +155,8 @@ pub const USB_RECIPIENT_DEVICE: u8 = 0; pub const USB_RECIPIENT_INTERFACE: u8 = 1; pub const USB_RECIPIENT_ENDPOINT: u8 = 2; pub const USB_RECIPIENT_OTHER: u8 = 3; +pub const USB_TYPE_MASK: u8 = 3 << 5; +pub const USB_RECIPIENT_MASK: u8 = 0x1F; /// USB device request combination pub const USB_DEVICE_IN_REQUEST: u8 = @@ -206,11 +208,14 @@ pub const USB_DT_DEBUG: u8 = 10; pub const USB_DT_INTERFACE_ASSOCIATION: u8 = 11; pub const USB_DT_BOS: u8 = 15; pub const USB_DT_DEVICE_CAPABILITY: u8 = 16; +pub const USB_DT_PIPE_USAGE: u8 = 36; pub const USB_DT_ENDPOINT_COMPANION: u8 = 48; /// USB SuperSpeed Device Capability. pub const USB_SS_DEVICE_CAP: u8 = 0x3; +pub const USB_SS_DEVICE_SPEED_SUPPORTED_HIGH: u16 = 1 << 2; pub const USB_SS_DEVICE_SPEED_SUPPORTED_SUPER: u16 = 1 << 3; +pub const USB_SS_DEVICE_FUNCTIONALITY_SUPPORT_HIGH: u8 = 2; pub const USB_SS_DEVICE_FUNCTIONALITY_SUPPORT_SUPER: u8 = 3; /// USB Descriptor size @@ -218,8 +223,10 @@ pub const USB_DT_DEVICE_SIZE: u8 = 18; pub const USB_DT_CONFIG_SIZE: u8 = 9; pub const USB_DT_INTERFACE_SIZE: u8 = 9; pub const USB_DT_ENDPOINT_SIZE: u8 = 7; +pub const USB_DT_DEVICE_QUALIFIER_SIZE: u8 = 10; pub const USB_DT_BOS_SIZE: u8 = 5; pub const USB_DT_SS_CAP_SIZE: u8 = 10; +pub const USB_DT_PIPE_USAGE_SIZE: u8 = 4; pub const USB_DT_SS_EP_COMP_SIZE: u8 = 6; /// USB Endpoint Descriptor @@ -236,8 +243,27 @@ pub const USB_CONFIGURATION_ATTR_ONE: u8 = 1 << 7; pub const USB_CONFIGURATION_ATTR_SELF_POWER: u8 = 1 << 6; pub const USB_CONFIGURATION_ATTR_REMOTE_WAKEUP: u8 = 1 << 5; -// USB Class +/// USB Class pub const USB_CLASS_HID: u8 = 3; pub const USB_CLASS_MASS_STORAGE: u8 = 8; pub const USB_CLASS_VIDEO: u8 = 0xe; pub const USB_CLASS_MISCELLANEOUS: u8 = 0xef; + +/// USB Subclass +pub const USB_SUBCLASS_BOOT: u8 = 0x01; +pub const USB_SUBCLASS_SCSI: u8 = 0x06; + +/// USB Interface Protocol +pub const USB_IFACE_PROTOCOL_KEYBOARD: u8 = 0x01; +pub const USB_IFACE_PROTOCOL_BOT: u8 = 0x50; +pub const USB_IFACE_PROTOCOL_UAS: u8 = 0x62; + +/// CRC16 of "STRATOVIRT" +pub const USB_VENDOR_ID_STRATOVIRT: u16 = 0xB74C; + +/// USB Product IDs +pub const USB_PRODUCT_ID_UVC: u16 = 0x0001; +pub const USB_PRODUCT_ID_KEYBOARD: u16 = 0x0002; +pub const USB_PRODUCT_ID_STORAGE: u16 = 0x0003; +pub const USB_PRODUCT_ID_TABLET: u16 = 0x0004; +pub const USB_PRODUCT_ID_UAS: u16 = 0x0005; diff --git a/devices/src/usb/descriptor.rs b/devices/src/usb/descriptor.rs index bdb30f771daa150f3137733f984504d9129c8a1a..55858cd4dc7d2a7d9043a9849a9ad882b151fe2a 100644 --- a/devices/src/usb/descriptor.rs +++ b/devices/src/usb/descriptor.rs @@ -132,7 +132,7 @@ impl ByteCode for UsbBOSDescriptor {} #[allow(non_snake_case)] #[repr(C, packed)] #[derive(Copy, Clone, Debug, Default)] -struct UsbSuperSpeedCapDescriptor { +pub struct UsbSuperSpeedCapDescriptor { pub bLength: u8, pub bDescriptorType: u8, pub bDevCapabilityType: u8, @@ -165,11 +165,6 @@ pub struct UsbDescDevice { pub configs: Vec>, } -/// USB device qualifier descriptor. -pub struct UsbDescDeviceQualifier { - pub qualifier_desc: UsbDeviceQualifierDescriptor, -} - /// USB config descriptor. pub struct UsbDescConfig { pub config_desc: UsbConfigDescriptor, @@ -221,24 +216,24 @@ pub struct UsbDescEndpoint { /// USB Descriptor. pub struct UsbDescriptor { pub device_desc: Option>, - pub device_qualifier_desc: Option>, pub configuration_selected: Option>, pub interfaces: Vec>>, pub altsetting: Vec, pub interface_number: u32, pub strings: Vec, + pub capabilities: Vec, } impl UsbDescriptor { pub fn new() -> Self { Self { device_desc: None, - device_qualifier_desc: None, configuration_selected: None, interfaces: vec![None; USB_MAX_INTERFACES as usize], altsetting: vec![0; USB_MAX_INTERFACES as usize], interface_number: 0, strings: Vec::new(), + capabilities: Vec::new(), } } @@ -264,7 +259,7 @@ impl UsbDescriptor { let mut ifs = self.get_interfaces_descriptor(conf.interfaces.as_ref())?; config_desc.wTotalLength = - config_desc.bLength as u16 + iads.len() as u16 + ifs.len() as u16; + u16::from(config_desc.bLength) + iads.len() as u16 + ifs.len() as u16; let mut buf = config_desc.as_bytes().to_vec(); buf.append(&mut iads); @@ -349,11 +344,25 @@ impl UsbDescriptor { } fn get_device_qualifier_descriptor(&self) -> Result> { - if let Some(desc) = self.device_qualifier_desc.as_ref() { - Ok(desc.qualifier_desc.as_bytes().to_vec()) - } else { - bail!("Device qualifier descriptor not found"); + if self.device_desc.is_none() { + bail!("device qualifier descriptor not found"); } + + // SAFETY: device_desc has just been checked + let device_desc = &self.device_desc.as_ref().unwrap().device_desc; + let device_qualifier_desc = UsbDeviceQualifierDescriptor { + bLength: USB_DT_DEVICE_QUALIFIER_SIZE, + bDescriptorType: USB_DT_DEVICE_QUALIFIER, + bcdUSB: device_desc.bcdUSB, + bDeviceClass: device_desc.bDeviceClass, + bDeviceSubClass: device_desc.bDeviceSubClass, + bDeviceProtocol: device_desc.bDeviceProtocol, + bMaxPacketSize0: device_desc.bMaxPacketSize0, + bNumConfigurations: device_desc.bNumConfigurations, + bReserved: 0, + }; + + Ok(device_qualifier_desc.as_bytes().to_vec()) } fn get_debug_descriptor(&self) -> Result> { @@ -362,25 +371,32 @@ impl UsbDescriptor { } fn get_bos_descriptor(&self, speed: u32) -> Result> { - let mut total = USB_DT_BOS_SIZE as u16; + let mut total = u16::from(USB_DT_BOS_SIZE); let mut cap = Vec::new(); let mut cap_num = 0; if speed == USB_SPEED_SUPER { - let super_cap = UsbSuperSpeedCapDescriptor { - bLength: USB_DT_SS_CAP_SIZE, - bDescriptorType: USB_DT_DEVICE_CAPABILITY, - bDevCapabilityType: USB_SS_DEVICE_CAP, - bmAttributes: 0, - wSpeedsSupported: USB_SS_DEVICE_SPEED_SUPPORTED_SUPER, - bFunctionalitySupport: USB_SS_DEVICE_FUNCTIONALITY_SUPPORT_SUPER, - bU1DevExitLat: 0xa, - wU2DevExitLat: 0x20, + let default_cap = if self.capabilities.is_empty() { + vec![UsbSuperSpeedCapDescriptor { + bLength: USB_DT_SS_CAP_SIZE, + bDescriptorType: USB_DT_DEVICE_CAPABILITY, + bDevCapabilityType: USB_SS_DEVICE_CAP, + bmAttributes: 0, + wSpeedsSupported: USB_SS_DEVICE_SPEED_SUPPORTED_SUPER, + bFunctionalitySupport: USB_SS_DEVICE_FUNCTIONALITY_SUPPORT_SUPER, + bU1DevExitLat: 0xa, + wU2DevExitLat: 0x20, + }] + } else { + Vec::new() }; - let mut super_buf = super_cap.as_bytes().to_vec(); - cap_num += 1; - total += super_buf.len() as u16; - cap.append(&mut super_buf); + + for desc in default_cap.iter().chain(self.capabilities.iter()) { + let mut super_buf = (*desc).as_bytes().to_vec(); + cap_num += 1; + total += super_buf.len() as u16; + cap.append(&mut super_buf); + } } let bos = UsbBOSDescriptor { @@ -400,8 +416,8 @@ impl UsbDescriptor { for i in 0..conf.iad_desc.len() { let ifaces = &conf.iad_desc[i].as_ref().itfs; for iface in ifaces { - if iface.interface_desc.bInterfaceNumber == nif as u8 - && iface.interface_desc.bAlternateSetting == alt as u8 + if u32::from(iface.interface_desc.bInterfaceNumber) == nif + && u32::from(iface.interface_desc.bAlternateSetting) == alt { return Some(iface.clone()); } @@ -409,8 +425,8 @@ impl UsbDescriptor { } for i in 0..conf.interfaces.len() { let iface = conf.interfaces[i].as_ref(); - if iface.interface_desc.bInterfaceNumber == nif as u8 - && iface.interface_desc.bAlternateSetting == alt as u8 + if u32::from(iface.interface_desc.bInterfaceNumber) == nif + && u32::from(iface.interface_desc.bAlternateSetting) == alt { return Some(conf.interfaces[i].clone()); } @@ -436,6 +452,9 @@ pub trait UsbDescriptorOps { /// Set interface descriptor with the Interface and Alternate Setting. fn set_interface_descriptor(&mut self, index: u32, v: u32) -> Result<()>; + /// Set super speed capability descriptors. + fn set_capability_descriptors(&mut self, caps: Vec); + /// Init all endpoint descriptors and reset the USB endpoint. fn init_endpoint(&mut self) -> Result<()>; @@ -476,7 +495,7 @@ impl UsbDescriptorOps for UsbDeviceBase { for i in 0..num as usize { if desc.configs[i].config_desc.bConfigurationValue == v { self.descriptor.interface_number = - desc.configs[i].config_desc.bNumInterfaces as u32; + u32::from(desc.configs[i].config_desc.bNumInterfaces); self.descriptor.configuration_selected = Some(desc.configs[i].clone()); found = true; } @@ -510,6 +529,10 @@ impl UsbDescriptorOps for UsbDeviceBase { Ok(()) } + fn set_capability_descriptors(&mut self, caps: Vec) { + self.descriptor.capabilities = caps; + } + fn init_endpoint(&mut self) -> Result<()> { self.reset_usb_endpoint(); for i in 0..self.descriptor.interface_number { diff --git a/devices/src/usb/hid.rs b/devices/src/usb/hid.rs index 5706cdda715c91690f5fb3eb50b9dc15c2ef37df..513005ef8d71ae8e99016551d79f26b3ee7ff824 100644 --- a/devices/src/usb/hid.rs +++ b/devices/src/usb/hid.rs @@ -264,7 +264,7 @@ impl Hid { self.num -= 1; let keycode = self.keyboard.keycodes[slot as usize]; let key = keycode & 0x7f; - let index = key | ((self.keyboard.modifiers as u32 & (1 << 8)) >> 1); + let index = key | ((u32::from(self.keyboard.modifiers) & (1 << 8)) >> 1); let hid_code = HID_CODE[index as usize]; self.keyboard.modifiers &= !(1 << 8); trace::usb_convert_to_hid_code(&hid_code, &index, &key); diff --git a/devices/src/usb/keyboard.rs b/devices/src/usb/keyboard.rs index 73dc8ed9e50ddf5ccd9791fa2a75f6a6306a7221..418d6043464a898b92105f1cd3c16c899053fadd 100644 --- a/devices/src/usb/keyboard.rs +++ b/devices/src/usb/keyboard.rs @@ -22,14 +22,14 @@ use super::descriptor::{ UsbDescriptorOps, UsbDeviceDescriptor, UsbEndpointDescriptor, UsbInterfaceDescriptor, }; use super::hid::{Hid, HidType, QUEUE_LENGTH, QUEUE_MASK}; -use super::xhci::xhci_controller::XhciDevice; +use super::xhci::xhci_controller::{endpoint_number_to_id, XhciDevice}; use super::{config::*, USB_DEVICE_BUFFER_DEFAULT_LEN}; use super::{ - notify_controller, UsbDevice, UsbDeviceBase, UsbDeviceRequest, UsbEndpoint, UsbPacket, - UsbPacketStatus, + notify_controller, UsbDevice, UsbDeviceBase, UsbDeviceRequest, UsbPacket, UsbPacketStatus, }; use machine_manager::config::valid_id; use ui::input::{register_keyboard, unregister_keyboard, KeyboardOpts}; +use util::gen_base_func; /// Keyboard device descriptor static DESC_DEVICE_KEYBOARD: Lazy> = Lazy::new(|| { @@ -38,7 +38,7 @@ static DESC_DEVICE_KEYBOARD: Lazy> = Lazy::new(|| { bLength: USB_DT_DEVICE_SIZE, bDescriptorType: USB_DT_DEVICE, idVendor: 0x0627, - idProduct: 0x0001, + idProduct: USB_PRODUCT_ID_KEYBOARD, bcdDevice: 0, iManufacturer: STR_MANUFACTURER_INDEX, iProduct: STR_PRODUCT_KEYBOARD_INDEX, @@ -76,8 +76,8 @@ static DESC_IFACE_KEYBOARD: Lazy> = Lazy::new(|| { bAlternateSetting: 0, bNumEndpoints: 1, bInterfaceClass: USB_CLASS_HID, - bInterfaceSubClass: 1, - bInterfaceProtocol: 1, + bInterfaceSubClass: USB_SUBCLASS_BOOT, + bInterfaceProtocol: USB_IFACE_PROTOCOL_KEYBOARD, iInterface: 0, }, other_desc: vec![Arc::new(UsbDescOther { @@ -121,8 +121,10 @@ const DESC_STRINGS: [&str; 5] = [ ]; #[derive(Parser, Clone, Debug, Default)] -#[command(name = "usb_keyboard")] +#[command(no_binary_name(true))] pub struct UsbKeyboardConfig { + #[arg(long)] + pub classtype: String, #[arg(long, value_parser = valid_id)] id: String, #[arg(long)] @@ -150,14 +152,14 @@ impl KeyboardOpts for UsbKeyboardAdapter { let mut scan_codes = Vec::new(); let mut keycode = keycode; if keycode & SCANCODE_GREY != 0 { - scan_codes.push(SCANCODE_EMUL0 as u32); + scan_codes.push(u32::from(SCANCODE_EMUL0)); keycode &= !SCANCODE_GREY; } if !down { keycode |= SCANCODE_UP; } - scan_codes.push(keycode as u32); + scan_codes.push(u32::from(keycode)); let mut locked_kbd = self.usb_kbd.lock().unwrap(); if scan_codes.len() as u32 + locked_kbd.hid.num > QUEUE_LENGTH { @@ -172,7 +174,9 @@ impl KeyboardOpts for UsbKeyboardAdapter { } drop(locked_kbd); let clone_kbd = self.usb_kbd.clone(); - notify_controller(&(clone_kbd as Arc>)) + // Wakeup endpoint. + let ep_id = endpoint_number_to_id(true, 1); + notify_controller(&(clone_kbd as Arc>), ep_id) } } @@ -187,13 +191,7 @@ impl UsbKeyboard { } impl UsbDevice for UsbKeyboard { - fn usb_device_base(&self) -> &UsbDeviceBase { - &self.base - } - - fn usb_device_base_mut(&mut self) -> &mut UsbDeviceBase { - &mut self.base - } + gen_base_func!(usb_device_base, usb_device_base_mut, UsbDeviceBase, base); fn realize(mut self) -> Result>> { self.base.reset_usb_endpoint(); @@ -217,6 +215,8 @@ impl UsbDevice for UsbKeyboard { Ok(()) } + fn cancel_packet(&mut self, _packet: &Arc>) {} + fn reset(&mut self) { info!("Keyboard device reset"); self.base.remote_wakeup = 0; @@ -237,7 +237,10 @@ impl UsbDevice for UsbKeyboard { } } Err(e) => { - warn!("Keyboard descriptor error {:?}", e); + warn!( + "Received incorrect USB Keyboard descriptor message: {:?}", + e + ); locked_packet.status = UsbPacketStatus::Stall; return; } @@ -258,8 +261,4 @@ impl UsbDevice for UsbKeyboard { fn get_controller(&self) -> Option>> { self.cntlr.clone() } - - fn get_wakeup_endpoint(&self) -> &UsbEndpoint { - self.base.get_endpoint(true, 1) - } } diff --git a/devices/src/usb/mod.rs b/devices/src/usb/mod.rs index f44d8b5077c08de7c9ce43d12c4ddc6ff6c1ad43..110152b84ff12b189f9dae00aa5036e8b4f455cf 100644 --- a/devices/src/usb/mod.rs +++ b/devices/src/usb/mod.rs @@ -20,6 +20,8 @@ pub mod hid; pub mod keyboard; pub mod storage; pub mod tablet; +#[cfg(feature = "usb_uas")] +pub mod uas; #[cfg(feature = "usb_host")] pub mod usbhost; pub mod xhci; @@ -50,9 +52,10 @@ const USB_MAX_ADDRESS: u8 = 127; pub const USB_DEVICE_BUFFER_DEFAULT_LEN: usize = 4096; /// USB packet return status. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq)] pub enum UsbPacketStatus { Success, + #[default] NoDev, Nak, Stall, @@ -60,15 +63,9 @@ pub enum UsbPacketStatus { IoError, } -impl Default for UsbPacketStatus { - fn default() -> Self { - Self::NoDev - } -} - /// USB request used to transfer to USB device. #[repr(C)] -#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] +#[derive(Copy, Clone, PartialEq, Eq, Default)] pub struct UsbDeviceRequest { pub request_type: u8, pub request: u8, @@ -79,8 +76,68 @@ pub struct UsbDeviceRequest { impl ByteCode for UsbDeviceRequest {} +impl std::fmt::Debug for UsbDeviceRequest { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("UsbDeviceRequest") + .field("request_type", &parse_request_type(self.request_type)) + .field("request", &parse_request(self.request)) + .field("value", &self.value) + .field("index", &self.index) + .field("length", &self.length) + .finish() + } +} + +fn parse_request_type(request_type: u8) -> String { + let mut ret = "".to_string(); + + match request_type & USB_DIRECTION_DEVICE_TO_HOST { + USB_DIRECTION_DEVICE_TO_HOST => ret.push_str("IN"), + _ => ret.push_str("OUT"), + } + + ret.push(' '); + + match request_type & USB_TYPE_MASK { + USB_TYPE_STANDARD => ret.push_str("STD"), + USB_TYPE_CLASS => ret.push_str("CLASS"), + USB_TYPE_VENDOR => ret.push_str("VEND"), + _ => ret.push_str("RSVD"), + } + + ret.push(' '); + + match request_type & USB_RECIPIENT_MASK { + USB_RECIPIENT_DEVICE => ret.push_str("DEV"), + USB_RECIPIENT_INTERFACE => ret.push_str("IFACE"), + USB_RECIPIENT_ENDPOINT => ret.push_str("EP"), + _ => ret.push_str("OTHER"), + } + + ret +} + +fn parse_request(request: u8) -> String { + match request { + USB_REQUEST_GET_STATUS => "GET STAT".to_string(), + USB_REQUEST_CLEAR_FEATURE => "CLR FEAT".to_string(), + USB_REQUEST_SET_FEATURE => "SET FEAT".to_string(), + USB_REQUEST_SET_ADDRESS => "SET ADDR".to_string(), + USB_REQUEST_GET_DESCRIPTOR => "GET DESC".to_string(), + USB_REQUEST_SET_DESCRIPTOR => "SET DESC".to_string(), + USB_REQUEST_GET_CONFIGURATION => "GET CONF".to_string(), + USB_REQUEST_SET_CONFIGURATION => "SET CONF".to_string(), + USB_REQUEST_GET_INTERFACE => "GET IFACE".to_string(), + USB_REQUEST_SET_INTERFACE => "SET IFACE".to_string(), + USB_REQUEST_SYNCH_FRAME => "SYN FRAME".to_string(), + USB_REQUEST_SET_SEL => "SET SEL".to_string(), + USB_REQUEST_SET_ISOCH_DELAY => "SET ISO DEL".to_string(), + _ => format!("UNKNOWN {}", request), + } +} + /// The data transmission channel. -#[derive(Default, Clone)] +#[derive(Default, Clone, Copy)] pub struct UsbEndpoint { pub ep_number: u8, pub in_direction: bool, @@ -108,7 +165,7 @@ impl UsbEndpoint { _ => 1, }; - self.max_packet_size = size as u32 * micro_frames; + self.max_packet_size = u32::from(size) * micro_frames; } } @@ -134,7 +191,7 @@ pub struct UsbDeviceBase { impl UsbDeviceBase { pub fn new(id: String, data_buf_len: usize) -> Self { let mut dev = UsbDeviceBase { - base: DeviceBase::new(id, false), + base: DeviceBase::new(id, false, None), port: None, speed: 0, addr: 0, @@ -212,9 +269,9 @@ impl UsbDeviceBase { packet: &mut UsbPacket, device_req: &UsbDeviceRequest, ) -> Result { - let value = device_req.value as u32; - let index = device_req.index as u32; - let length = device_req.length as u32; + let value = u32::from(device_req.value); + let index = u32::from(device_req.index); + let length = u32::from(device_req.length); match device_req.request_type { USB_DEVICE_IN_REQUEST => match device_req.request { USB_REQUEST_GET_DESCRIPTOR => { @@ -241,7 +298,7 @@ impl UsbDeviceBase { .as_ref() .with_context(|| "Device descriptor not found")?; desc.configs - .get(0) + .first() .with_context(|| "Config descriptor not found")? .clone() }; @@ -349,6 +406,10 @@ pub trait UsbDevice: Send + Sync { fn unrealize(&mut self) -> Result<()> { Ok(()) } + + /// Cancel specified USB packet. + fn cancel_packet(&mut self, packet: &Arc>); + /// Handle the attach ops when attach device to controller. fn handle_attach(&mut self) -> Result<()> { let usb_dev = self.usb_device_base_mut(); @@ -366,9 +427,6 @@ pub trait UsbDevice: Send + Sync { /// Get the controller which the USB device attached. fn get_controller(&self) -> Option>>; - /// Get the endpoint to wakeup. - fn get_wakeup_endpoint(&self) -> &UsbEndpoint; - /// Set the attached USB port. fn set_usb_port(&mut self, port: Option>>) { let usb_dev = self.usb_device_base_mut(); @@ -381,7 +439,7 @@ pub trait UsbDevice: Send + Sync { locked_packet.status = UsbPacketStatus::Success; let ep_nr = locked_packet.ep_number; drop(locked_packet); - debug!("handle packet endpointer number {}", ep_nr); + debug!("handle packet endpoint number {}", ep_nr); if ep_nr == 0 { if let Err(e) = self.do_parameter(packet) { error!("Failed to handle control packet {:?}", e); @@ -445,7 +503,7 @@ pub trait UsbDevice: Send + Sync { } /// Notify controller to process data request. -pub fn notify_controller(dev: &Arc>) -> Result<()> { +pub fn notify_controller(dev: &Arc>, ep_id: u8) -> Result<()> { let locked_dev = dev.lock().unwrap(); let xhci = if let Some(cntlr) = &locked_dev.get_controller() { cntlr.upgrade().unwrap() @@ -460,7 +518,6 @@ pub fn notify_controller(dev: &Arc>) -> Result<()> { }; let slot_id = usb_dev.addr; let wakeup = usb_dev.remote_wakeup & USB_DEVICE_REMOTE_WAKEUP == USB_DEVICE_REMOTE_WAKEUP; - let ep = locked_dev.get_wakeup_endpoint().clone(); // Drop the small lock. drop(locked_dev); let mut locked_xhci = xhci.lock().unwrap(); @@ -477,7 +534,7 @@ pub fn notify_controller(dev: &Arc>) -> Result<()> { locked_xhci.port_notify(&usb_port, PORTSC_PLC)?; } } - if let Err(e) = locked_xhci.wakeup_endpoint(slot_id as u32, &ep) { + if let Err(e) = locked_xhci.wakeup_endpoint(u32::from(slot_id), u32::from(ep_id), 0) { error!("Failed to wakeup endpoint {:?}", e); } Ok(()) @@ -491,9 +548,9 @@ pub trait TransferOps: Send + Sync { /// Usb packet used for device transfer data. #[derive(Default)] pub struct UsbPacket { - /// USB packet unique identifier. + /// Unique number for packet tracking. pub packet_id: u32, - /// USB packet id. + /// USB packet id (direction of the transfer). pub pid: u32, pub is_async: bool, pub iovecs: Vec, @@ -505,10 +562,12 @@ pub struct UsbPacket { pub actual_length: u32, /// Endpoint number. pub ep_number: u8, - /// Transfer for complete packet. - pub xfer_ops: Option>>, /// Stream id. pub stream: u32, + /// Transfer for complete packet. + pub xfer_ops: Option>>, + /// Target USB device for this packet. + pub target_dev: Option>>, } impl std::fmt::Display for UsbPacket { @@ -526,8 +585,10 @@ impl UsbPacket { packet_id: u32, pid: u32, ep_number: u8, + stream: u32, iovecs: Vec, xfer_ops: Option>>, + target_dev: Option>>, ) -> Self { Self { packet_id, @@ -538,8 +599,9 @@ impl UsbPacket { status: UsbPacketStatus::Success, actual_length: 0, ep_number, + stream, xfer_ops, - stream: 0, + target_dev, } } @@ -563,7 +625,8 @@ impl UsbPacket { } let cnt = min(iov.iov_len as usize, len - copied); let tmp = &vec[copied..(copied + cnt)]; - if let Err(e) = mem_from_buf(tmp, iov.iov_base) { + // SAFETY: iovecs is generated by address_space and len is not less than tmp's. + if let Err(e) = unsafe { mem_from_buf(tmp, iov.iov_base) } { error!("Failed to write mem: {:?}", e); } copied += cnt; @@ -578,7 +641,8 @@ impl UsbPacket { } let cnt = min(iov.iov_len as usize, len - copied); let tmp = &mut vec[copied..(copied + cnt)]; - if let Err(e) = mem_to_buf(tmp, iov.iov_base) { + // SAFETY: iovecs is generation by address_space and len is not less than tmp's. + if let Err(e) = unsafe { mem_to_buf(tmp, iov.iov_base) } { error!("Failed to read mem {:?}", e); } copied += cnt; @@ -606,7 +670,7 @@ mod tests { let buf = [0_u8; 10]; let hva = buf.as_ptr() as u64; let mut packet = UsbPacket::default(); - packet.pid = USB_TOKEN_IN as u32; + packet.pid = u32::from(USB_TOKEN_IN); packet.iovecs.push(Iovec::new(hva, 4)); packet.iovecs.push(Iovec::new(hva + 4, 2)); let mut data: Vec = vec![1, 2, 3, 4, 5, 6]; @@ -620,7 +684,7 @@ mod tests { let buf = [0_u8; 10]; let hva = buf.as_ptr() as u64; let mut packet = UsbPacket::default(); - packet.pid = USB_TOKEN_IN as u32; + packet.pid = u32::from(USB_TOKEN_IN); packet.iovecs.push(Iovec::new(hva, 4)); let mut data: Vec = vec![1, 2, 3, 4, 5, 6]; @@ -634,7 +698,7 @@ mod tests { let buf = [0_u8; 10]; let hva = buf.as_ptr() as u64; let mut packet = UsbPacket::default(); - packet.pid = USB_TOKEN_IN as u32; + packet.pid = u32::from(USB_TOKEN_IN); packet.iovecs.push(Iovec::new(hva, 4)); let mut data: Vec = vec![1, 2, 3, 4, 5, 6]; @@ -648,7 +712,7 @@ mod tests { let buf = [0_u8; 10]; let hva = buf.as_ptr() as u64; let mut packet = UsbPacket::default(); - packet.pid = USB_TOKEN_IN as u32; + packet.pid = u32::from(USB_TOKEN_IN); packet.iovecs.push(Iovec::new(hva, 10)); let mut data: Vec = vec![1, 2, 3, 4, 5, 6]; @@ -662,7 +726,7 @@ mod tests { let buf: [u8; 10] = [1, 2, 3, 4, 5, 6, 0, 0, 0, 0]; let hva = buf.as_ptr() as u64; let mut packet = UsbPacket::default(); - packet.pid = USB_TOKEN_OUT as u32; + packet.pid = u32::from(USB_TOKEN_OUT); packet.iovecs.push(Iovec::new(hva, 4)); packet.iovecs.push(Iovec::new(hva + 4, 2)); @@ -677,7 +741,7 @@ mod tests { let buf: [u8; 10] = [1, 2, 3, 4, 5, 6, 0, 0, 0, 0]; let hva = buf.as_ptr() as u64; let mut packet = UsbPacket::default(); - packet.pid = USB_TOKEN_OUT as u32; + packet.pid = u32::from(USB_TOKEN_OUT); packet.iovecs.push(Iovec::new(hva, 4)); packet.iovecs.push(Iovec::new(hva + 4, 2)); @@ -692,7 +756,7 @@ mod tests { let buf: [u8; 10] = [1, 2, 3, 4, 5, 6, 0, 0, 0, 0]; let hva = buf.as_ptr() as u64; let mut packet = UsbPacket::default(); - packet.pid = USB_TOKEN_OUT as u32; + packet.pid = u32::from(USB_TOKEN_OUT); packet.iovecs.push(Iovec::new(hva, 4)); let mut data = [0_u8; 10]; @@ -706,7 +770,7 @@ mod tests { let buf: [u8; 10] = [1, 2, 3, 4, 5, 6, 0, 0, 0, 0]; let hva = buf.as_ptr() as u64; let mut packet = UsbPacket::default(); - packet.pid = USB_TOKEN_OUT as u32; + packet.pid = u32::from(USB_TOKEN_OUT); packet.iovecs.push(Iovec::new(hva, 6)); let mut data = [0_u8; 2]; diff --git a/devices/src/usb/storage.rs b/devices/src/usb/storage.rs index b227ecc088a75d7054aee88fee8bf26fa09a3d9d..aaa04aa08d79e54fc77964fbdc0ed137eea20fec 100644 --- a/devices/src/usb/storage.rs +++ b/devices/src/usb/storage.rs @@ -17,6 +17,7 @@ use std::{ use anyhow::{anyhow, bail, Context, Result}; use byteorder::{ByteOrder, LittleEndian}; +use clap::Parser; use log::{error, info, warn}; use once_cell::sync::Lazy; @@ -26,15 +27,16 @@ use super::descriptor::{ }; use super::xhci::xhci_controller::XhciDevice; use super::{config::*, USB_DEVICE_BUFFER_DEFAULT_LEN}; -use super::{UsbDevice, UsbDeviceBase, UsbDeviceRequest, UsbEndpoint, UsbPacket, UsbPacketStatus}; -use crate::{ - ScsiBus::{ - ScsiBus, ScsiRequest, ScsiRequestOps, ScsiSense, ScsiXferMode, EMULATE_SCSI_OPS, GOOD, - SCSI_CMD_BUF_SIZE, - }, - ScsiDisk::{ScsiDevice, SCSI_TYPE_DISK, SCSI_TYPE_ROM}, +use super::{UsbDevice, UsbDeviceBase, UsbDeviceRequest, UsbPacket, UsbPacketStatus}; +use crate::ScsiBus::{ + get_scsi_key, ScsiBus, ScsiRequest, ScsiRequestOps, ScsiSense, ScsiXferMode, EMULATE_SCSI_OPS, + GOOD, SCSI_CMD_BUF_SIZE, }; -use machine_manager::config::{DriveFile, UsbStorageConfig}; +use crate::ScsiDisk::{ScsiDevConfig, ScsiDevice}; +use crate::{Bus, Device}; +use machine_manager::config::{DriveConfig, DriveFile}; +use util::aio::AioEngine; +use util::gen_base_func; // Storage device descriptor static DESC_DEVICE_STORAGE: Lazy> = Lazy::new(|| { @@ -43,7 +45,7 @@ static DESC_DEVICE_STORAGE: Lazy> = Lazy::new(|| { bLength: USB_DT_DEVICE_SIZE, bDescriptorType: USB_DT_DEVICE, idVendor: USB_STORAGE_VENDOR_ID, - idProduct: 0x0001, + idProduct: USB_PRODUCT_ID_STORAGE, bcdDevice: 0, iManufacturer: STR_MANUFACTURER_INDEX, iProduct: STR_PRODUCT_STORAGE_INDEX, @@ -82,8 +84,8 @@ static DESC_IFACE_STORAGE: Lazy> = Lazy::new(|| { bAlternateSetting: 0, bNumEndpoints: 2, bInterfaceClass: USB_CLASS_MASS_STORAGE, - bInterfaceSubClass: 0x06, // SCSI - bInterfaceProtocol: 0x50, // Bulk-only + bInterfaceSubClass: USB_SUBCLASS_SCSI, + bInterfaceProtocol: USB_IFACE_PROTOCOL_BOT, iInterface: 0, }, other_desc: vec![], @@ -221,6 +223,21 @@ impl UsbStorageState { } } +#[derive(Parser, Clone, Debug)] +#[command(no_binary_name(true))] +pub struct UsbStorageConfig { + #[arg(long, value_parser = ["usb-storage"])] + pub classtype: String, + #[arg(long)] + pub id: String, + #[arg(long)] + pub drive: String, + #[arg(long)] + pub(super) bus: Option, + #[arg(long)] + pub(super) port: Option, +} + /// USB storage device. pub struct UsbStorage { base: UsbDeviceBase, @@ -228,7 +245,9 @@ pub struct UsbStorage { /// USB controller used to notify controller to transfer data. cntlr: Option>>, /// Configuration of the USB storage device. - pub config: UsbStorageConfig, + pub dev_cfg: UsbStorageConfig, + /// Configuration of the USB storage device's drive. + pub drive_cfg: DriveConfig, /// Scsi bus attached to this usb-storage device. scsi_bus: Arc>, /// Effective scsi backend. @@ -237,7 +256,9 @@ pub struct UsbStorage { // (usb-storage/scsi bus/scsi device) correspond one-to-one, add scsi device member here // for the execution efficiency (No need to find a unique device from the hash table of the // unique bus). - scsi_dev: Arc>, + scsi_dev: Option>>, + /// Drive backend files. + drive_files: Arc>>, } #[derive(Debug)] @@ -305,29 +326,62 @@ impl UsbMsdCsw { impl UsbStorage { pub fn new( - config: UsbStorageConfig, + dev_cfg: UsbStorageConfig, + drive_cfg: DriveConfig, drive_files: Arc>>, - ) -> Self { - let scsi_type = match &config.media as &str { - "disk" => SCSI_TYPE_DISK, - _ => SCSI_TYPE_ROM, - }; + ) -> Result { + if drive_cfg.aio != AioEngine::Off || drive_cfg.direct { + bail!("USB-storage: \"aio=off,direct=false\" must be configured."); + } - Self { - base: UsbDeviceBase::new(config.id.clone().unwrap(), USB_DEVICE_BUFFER_DEFAULT_LEN), + Ok(Self { + base: UsbDeviceBase::new(dev_cfg.id.clone(), USB_DEVICE_BUFFER_DEFAULT_LEN), state: UsbStorageState::new(), cntlr: None, - config: config.clone(), + dev_cfg, + drive_cfg, scsi_bus: Arc::new(Mutex::new(ScsiBus::new("".to_string()))), - scsi_dev: Arc::new(Mutex::new(ScsiDevice::new( - config.scsi_cfg, - scsi_type, - drive_files, - ))), - } + scsi_dev: None, + drive_files, + }) + } + + pub fn do_realize(&mut self) -> Result<()> { + self.base.reset_usb_endpoint(); + self.base.speed = USB_SPEED_HIGH; + let mut s: Vec = DESC_STRINGS.iter().map(|&s| s.to_string()).collect(); + let prefix = &s[STR_SERIAL_STORAGE_INDEX as usize]; + s[STR_SERIAL_STORAGE_INDEX as usize] = self.base.generate_serial_number(prefix); + self.base.init_descriptor(DESC_DEVICE_STORAGE.clone(), s)?; + + // NOTE: "aio=off,direct=false" must be configured and other aio/direct values are not + // supported. + let scsidev_classtype = match self.drive_cfg.media.as_str() { + "disk" => "scsi-hd".to_string(), + _ => "scsi-cd".to_string(), + }; + let scsi_dev_cfg = ScsiDevConfig { + classtype: scsidev_classtype, + drive: self.dev_cfg.drive.clone(), + ..Default::default() + }; + let scsi_device = ScsiDevice::new( + scsi_dev_cfg, + self.drive_cfg.clone(), + self.drive_files.clone(), + None, + self.scsi_bus.clone(), + ); + let realized_scsi = scsi_device.realize()?; + self.scsi_dev = Some(realized_scsi.clone()); + + self.scsi_bus + .lock() + .unwrap() + .attach_child(get_scsi_key(0, 0), realized_scsi) } - fn handle_control_packet(&mut self, packet: &mut UsbPacket, device_req: &UsbDeviceRequest) { + pub fn handle_control_packet(&mut self, packet: &mut UsbPacket, device_req: &UsbDeviceRequest) { match device_req.request_type { USB_ENDPOINT_OUT_REQUEST => { if device_req.request == USB_REQUEST_CLEAR_FEATURE { @@ -363,7 +417,7 @@ impl UsbStorage { match self.state.mode { UsbMsdMode::Cbw => { - if packet.get_iovecs_size() < CBW_SIZE as u64 { + if packet.get_iovecs_size() < u64::from(CBW_SIZE) { bail!("Bad CBW size {}", packet.get_iovecs_size()); } self.state.check_cdb_exist(false)?; @@ -417,7 +471,7 @@ impl UsbStorage { bail!("Not supported usb packet(Token_in and data_out)."); } UsbMsdMode::Csw => { - if packet.get_iovecs_size() < CSW_SIZE as u64 { + if packet.get_iovecs_size() < u64::from(CSW_SIZE) { bail!("Bad CSW size {}", packet.get_iovecs_size()); } self.state.check_cdb_exist(true)?; @@ -449,6 +503,7 @@ impl UsbStorage { self.state.check_cdb_exist(true)?; self.state.check_iovec_empty(true)?; + // Safety: iovecs are set in `setup_usb_packet` and iovec_len is no more than TRB_TR_LEN_MASK. let iovec_len = packet.get_iovecs_size() as u32; if iovec_len < self.state.cbw.data_len { bail!( @@ -480,12 +535,12 @@ impl UsbStorage { 0, packet.iovecs.clone(), self.state.iovec_len, - self.scsi_dev.clone(), + self.scsi_dev.as_ref().unwrap().clone(), csw, ) .with_context(|| "Error in creating scsirequest.")?; - if sreq.cmd.xfer > sreq.datalen && sreq.cmd.mode != ScsiXferMode::ScsiXferNone { + if sreq.cmd.xfer > u64::from(sreq.datalen) && sreq.cmd.mode != ScsiXferMode::ScsiXferNone { // Wrong USB packet which doesn't provide enough datain/dataout buffer. bail!( "command {:x} requested data's length({}), provided buffer length({})", @@ -511,37 +566,16 @@ impl UsbStorage { } impl UsbDevice for UsbStorage { - fn usb_device_base(&self) -> &UsbDeviceBase { - &self.base - } - - fn usb_device_base_mut(&mut self) -> &mut UsbDeviceBase { - &mut self.base - } + gen_base_func!(usb_device_base, usb_device_base_mut, UsbDeviceBase, base); fn realize(mut self) -> Result>> { - self.base.reset_usb_endpoint(); - self.base.speed = USB_SPEED_HIGH; - let mut s: Vec = DESC_STRINGS.iter().map(|&s| s.to_string()).collect(); - let prefix = &s[STR_SERIAL_STORAGE_INDEX as usize]; - s[STR_SERIAL_STORAGE_INDEX as usize] = self.base.generate_serial_number(prefix); - self.base.init_descriptor(DESC_DEVICE_STORAGE.clone(), s)?; - - // NOTE: "aio=off,direct=false" must be configured and other aio/direct values are not - // supported. - let mut locked_scsi_dev = self.scsi_dev.lock().unwrap(); - locked_scsi_dev.realize(None)?; - drop(locked_scsi_dev); - self.scsi_bus - .lock() - .unwrap() - .devices - .insert((0, 0), self.scsi_dev.clone()); - + self.do_realize()?; let storage: Arc> = Arc::new(Mutex::new(self)); Ok(storage) } + fn cancel_packet(&mut self, _packet: &Arc>) {} + fn reset(&mut self) { info!("Storage device reset"); self.base.remote_wakeup = 0; @@ -563,7 +597,7 @@ impl UsbDevice for UsbStorage { self.handle_control_packet(&mut locked_packet, device_req) } Err(e) => { - warn!("Storage descriptor error {:?}", e); + warn!("Received incorrect USB Storage descriptor message: {:?}", e); locked_packet.status = UsbPacketStatus::Stall; } } @@ -600,8 +634,4 @@ impl UsbDevice for UsbStorage { fn get_controller(&self) -> Option>> { self.cntlr.clone() } - - fn get_wakeup_endpoint(&self) -> &UsbEndpoint { - self.base.get_endpoint(true, 1) - } } diff --git a/devices/src/usb/tablet.rs b/devices/src/usb/tablet.rs index 67a15139a8ee30e7fa680c3a4f8d4b89474de07d..b9c54fc9735768223392c28ec83d2098a38c5128 100644 --- a/devices/src/usb/tablet.rs +++ b/devices/src/usb/tablet.rs @@ -23,10 +23,10 @@ use super::descriptor::{ UsbDescriptorOps, UsbDeviceDescriptor, UsbEndpointDescriptor, UsbInterfaceDescriptor, }; use super::hid::{Hid, HidType, QUEUE_LENGTH, QUEUE_MASK}; -use super::xhci::xhci_controller::XhciDevice; +use super::xhci::xhci_controller::{endpoint_number_to_id, XhciDevice}; use super::{ - config::*, notify_controller, UsbDevice, UsbDeviceBase, UsbDeviceRequest, UsbEndpoint, - UsbPacket, UsbPacketStatus, USB_DEVICE_BUFFER_DEFAULT_LEN, + config::*, notify_controller, UsbDevice, UsbDeviceBase, UsbDeviceRequest, UsbPacket, + UsbPacketStatus, USB_DEVICE_BUFFER_DEFAULT_LEN, }; use machine_manager::config::valid_id; use ui::input::{ @@ -34,6 +34,7 @@ use ui::input::{ INPUT_BUTTON_MASK, INPUT_BUTTON_WHEEL_DOWN, INPUT_BUTTON_WHEEL_LEFT, INPUT_BUTTON_WHEEL_RIGHT, INPUT_BUTTON_WHEEL_UP, }; +use util::gen_base_func; const INPUT_COORDINATES_MAX: u32 = 0x7fff; @@ -44,7 +45,7 @@ static DESC_DEVICE_TABLET: Lazy> = Lazy::new(|| { bLength: USB_DT_DEVICE_SIZE, bDescriptorType: USB_DT_DEVICE, idVendor: 0x0627, - idProduct: 0x0001, + idProduct: USB_PRODUCT_ID_TABLET, bcdDevice: 0, iManufacturer: STR_MANUFACTURER_INDEX, iProduct: STR_PRODUCT_TABLET_INDEX, @@ -114,8 +115,10 @@ const STR_SERIAL_TABLET_INDEX: u8 = 4; const DESC_STRINGS: [&str; 5] = ["", "StratoVirt", "StratoVirt USB Tablet", "HID Tablet", "2"]; #[derive(Parser, Clone, Debug, Default)] -#[command(name = "usb_tablet")] +#[command(no_binary_name(true))] pub struct UsbTabletConfig { + #[arg(long)] + pub classtype: String, #[arg(long, value_parser = valid_id)] id: String, #[arg(long)] @@ -226,18 +229,14 @@ impl PointerOpts for UsbTabletAdapter { locked_tablet.hid.num += 1; drop(locked_tablet); let clone_tablet = self.tablet.clone(); - notify_controller(&(clone_tablet as Arc>)) + // Wakeup endpoint. + let ep_id = endpoint_number_to_id(true, 1); + notify_controller(&(clone_tablet as Arc>), ep_id) } } impl UsbDevice for UsbTablet { - fn usb_device_base(&self) -> &UsbDeviceBase { - &self.base - } - - fn usb_device_base_mut(&mut self) -> &mut UsbDeviceBase { - &mut self.base - } + gen_base_func!(usb_device_base, usb_device_base_mut, UsbDeviceBase, base); fn realize(mut self) -> Result>> { self.base.reset_usb_endpoint(); @@ -260,6 +259,8 @@ impl UsbDevice for UsbTablet { Ok(()) } + fn cancel_packet(&mut self, _packet: &Arc>) {} + fn reset(&mut self) { info!("Tablet device reset"); self.base.remote_wakeup = 0; @@ -280,7 +281,7 @@ impl UsbDevice for UsbTablet { } } Err(e) => { - warn!("Tablet descriptor error {:?}", e); + warn!("Received incorrect USB Tablet descriptor message: {:?}", e); locked_packet.status = UsbPacketStatus::Stall; return; } @@ -301,8 +302,4 @@ impl UsbDevice for UsbTablet { fn get_controller(&self) -> Option>> { self.cntlr.clone() } - - fn get_wakeup_endpoint(&self) -> &UsbEndpoint { - self.base.get_endpoint(true, 1) - } } diff --git a/devices/src/usb/uas.rs b/devices/src/usb/uas.rs new file mode 100644 index 0000000000000000000000000000000000000000..301dff080518d086d588f79c0607f7852be1c32d --- /dev/null +++ b/devices/src/usb/uas.rs @@ -0,0 +1,1111 @@ +// Copyright (c) 2023 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::array; +use std::cmp::min; +use std::collections::HashMap; +use std::mem::size_of; +use std::sync::{Arc, Mutex, Weak}; + +use anyhow::{anyhow, bail, Context, Result}; +use clap::Parser; +use log::{debug, error, info, warn}; +use once_cell::sync::Lazy; +use strum::EnumCount; +use strum_macros::EnumCount; + +use super::config::*; +use super::descriptor::{ + UsbConfigDescriptor, UsbDescConfig, UsbDescDevice, UsbDescEndpoint, UsbDescIface, + UsbDescriptorOps, UsbDeviceDescriptor, UsbEndpointDescriptor, UsbInterfaceDescriptor, + UsbSuperSpeedCapDescriptor, UsbSuperSpeedEndpointCompDescriptor, +}; +use super::storage::{UsbStorage, UsbStorageConfig, GET_MAX_LUN, MASS_STORAGE_RESET}; +use super::xhci::xhci_controller::XhciDevice; +use super::{ + UsbDevice, UsbDeviceBase, UsbDeviceRequest, UsbPacket, UsbPacketStatus, + USB_DEVICE_BUFFER_DEFAULT_LEN, +}; +use crate::ScsiBus::{ + get_scsi_key, scsi_cdb_xfer, ScsiBus, ScsiRequest, ScsiRequestOps, ScsiSense, ScsiXferMode, + CHECK_CONDITION, EMULATE_SCSI_OPS, GOOD, SCSI_SENSE_INVALID_PARAM_VALUE, + SCSI_SENSE_INVALID_TAG, SCSI_SENSE_NO_SENSE, +}; +use crate::ScsiDisk::{ScsiDevConfig, ScsiDevice}; +use crate::{Bus, Device}; +use machine_manager::config::{DriveConfig, DriveFile}; +use util::gen_base_func; +use util::{aio::AioEngine, byte_code::ByteCode}; + +// Size of UasIUBody +const UAS_IU_BODY_SIZE: usize = 30; + +// Size of cdb in UAS Command IU +const UAS_COMMAND_CDB_SIZE: usize = 16; + +// UAS Pipe IDs +const UAS_PIPE_ID_COMMAND: u8 = 0x01; +const UAS_PIPE_ID_STATUS: u8 = 0x02; +const UAS_PIPE_ID_DATA_IN: u8 = 0x03; +const UAS_PIPE_ID_DATA_OUT: u8 = 0x04; + +// UAS Streams Attributes +const UAS_MAX_STREAMS_BM_ATTR: u8 = 4; +const UAS_MAX_STREAMS: usize = 1 << UAS_MAX_STREAMS_BM_ATTR; + +// UAS IU IDs +const UAS_IU_ID_COMMAND: u8 = 0x01; +const UAS_IU_ID_SENSE: u8 = 0x03; +const UAS_IU_ID_RESPONSE: u8 = 0x04; +const UAS_IU_ID_TASK_MGMT: u8 = 0x05; + +// UAS Response Codes +const UAS_RC_TMF_COMPLETE: u8 = 0x00; +const _UAS_RC_INVALID_IU: u8 = 0x02; +const UAS_RC_TMF_NOT_SUPPORTED: u8 = 0x04; +const _UAS_RC_TMF_FAILED: u8 = 0x05; +const _UAS_RC_TMF_SUCCEEDED: u8 = 0x08; +const _UAS_RC_INCORRECT_LUN: u8 = 0x09; +const _UAS_RC_OVERLAPPED_TAG: u8 = 0x0A; + +// UAS Task Management Functions +const UAS_TMF_ABORT_TASK: u8 = 0x01; +const _UAS_TMF_ABORT_TASK_SET: u8 = 0x02; +const _UAS_TMF_CLEAR_TASK_SET: u8 = 0x04; +const _UAS_TMF_LOGICAL_UNIT_RESET: u8 = 0x08; +const _UAS_TMF_I_T_NEXUS_RESET: u8 = 0x10; +const _UAS_TMF_CLEAR_ACA: u8 = 0x40; +const _UAS_TMF_QUERY_TASK: u8 = 0x80; +const _UAS_TMF_QUERY_TASK_SET: u8 = 0x81; +const _UAS_TMF_QUERY_ASYNC_EVENT: u8 = 0x82; + +// Interface alt settings +const UAS_ALT_SETTING_BOT: u8 = 0; +const UAS_ALT_SETTING_UAS: u8 = 1; + +#[derive(Parser, Clone, Debug)] +#[command(no_binary_name(true))] +pub struct UsbUasConfig { + #[arg(long, value_parser = ["usb-uas"])] + pub classtype: String, + #[arg(long)] + pub drive: String, + #[arg(long)] + pub id: String, + #[arg(long)] + bus: Option, + #[arg(long)] + port: Option, +} + +impl From for UsbStorageConfig { + fn from(uas_config: UsbUasConfig) -> Self { + Self { + classtype: uas_config.classtype, + id: String::new(), + drive: uas_config.drive, + bus: uas_config.bus, + port: uas_config.port, + } + } +} + +pub struct UsbUas { + base: UsbDeviceBase, + uas_config: UsbUasConfig, + scsi_bus: Arc>, + scsi_device: Option>>, + drive_cfg: DriveConfig, + drive_files: Arc>>, + commands: [Option; UAS_MAX_STREAMS + 1], + statuses: [Option>>; UAS_MAX_STREAMS + 1], + data: [Option>>; UAS_MAX_STREAMS + 1], + bot: UsbStorage, + is_bot: bool, +} + +#[derive(Debug, Default, EnumCount)] +enum UsbUasStringId { + #[default] + Invalid = 0, + Manufacturer = 1, + Product = 2, + SerialNumber = 3, + Configuration = 4, +} + +const UAS_DESC_STRINGS: [&str; UsbUasStringId::COUNT] = [ + "", + "StratoVirt", + "StratoVirt USB Uas", + "5", + "Super speed config (usb 3.0)", +]; + +struct UasRequest { + data: Option>>, + status: Arc>, + iu: UasIU, + completed: bool, +} + +impl ScsiRequestOps for UasRequest { + fn scsi_request_complete_cb( + &mut self, + scsi_status: u8, + scsi_sense: Option, + ) -> Result<()> { + let tag = u16::from_be(self.iu.header.tag); + let sense = scsi_sense.unwrap_or(SCSI_SENSE_NO_SENSE); + UsbUas::fill_sense(&mut self.status.lock().unwrap(), tag, sense, scsi_status); + self.complete(); + Ok(()) + } +} + +#[derive(Debug, PartialEq, Eq)] +enum UasPacketStatus { + Completed = 0, + Pending = 1, +} + +impl From for UasPacketStatus { + fn from(status: bool) -> Self { + match status { + true => Self::Completed, + false => Self::Pending, + } + } +} + +#[allow(non_snake_case)] +#[repr(C, packed)] +#[derive(Copy, Clone, Debug, Default)] +struct UsbPipeUsageDescriptor { + bLength: u8, + bDescriptorType: u8, + bPipeId: u8, + bReserved: u8, +} + +impl ByteCode for UsbPipeUsageDescriptor {} + +#[repr(C, packed)] +#[derive(Default, Clone, Copy)] +struct UasIUHeader { + id: u8, + reserved: u8, + tag: u16, +} + +#[repr(C, packed)] +#[derive(Default, Clone, Copy)] +struct UasIUCommand { + prio_task_attr: u8, // 6:3 priority, 2:0 task attribute + reserved_1: u8, + add_cdb_len: u8, + reserved_2: u8, + lun: u64, + cdb: [u8; UAS_COMMAND_CDB_SIZE], + add_cdb: [u8; 1], // not supported by stratovirt +} + +#[repr(C, packed)] +#[derive(Default, Clone, Copy)] +struct UasIUSense { + status_qualifier: u16, + status: u8, + reserved: [u8; 7], + sense_length: u16, + sense_data: [u8; 18], +} + +#[repr(C, packed)] +#[derive(Default, Clone, Copy)] +struct UasIUResponse { + add_response_info: [u8; 3], + response_code: u8, +} + +#[repr(C, packed)] +#[derive(Default, Clone, Copy)] +struct UasIUTaskManagement { + function: u8, + reserved: u8, + task_tag: u16, + lun: u64, +} + +#[repr(C, packed)] +#[derive(Clone, Copy)] +union UasIUBody { + command: UasIUCommand, + sense: UasIUSense, + response: UasIUResponse, + task_management: UasIUTaskManagement, + raw_data: [u8; UAS_IU_BODY_SIZE], +} + +impl Default for UasIUBody { + fn default() -> Self { + Self { + raw_data: [0; UAS_IU_BODY_SIZE], + } + } +} + +#[repr(C, packed)] +#[derive(Default, Clone, Copy)] +struct UasIU { + header: UasIUHeader, + body: UasIUBody, +} + +impl ByteCode for UasIU {} + +static DESC_DEVICE_UAS: Lazy> = Lazy::new(|| { + Arc::new(UsbDescDevice { + device_desc: UsbDeviceDescriptor { + bLength: USB_DT_DEVICE_SIZE, + bDescriptorType: USB_DT_DEVICE, + bcdUSB: 0x0300, + bDeviceClass: 0, + bDeviceSubClass: 0, + bDeviceProtocol: 0, + bMaxPacketSize0: 9, + idVendor: USB_VENDOR_ID_STRATOVIRT, + idProduct: USB_PRODUCT_ID_UAS, + bcdDevice: 0, + iManufacturer: UsbUasStringId::Manufacturer as u8, + iProduct: UsbUasStringId::Product as u8, + iSerialNumber: UsbUasStringId::SerialNumber as u8, + bNumConfigurations: 1, + }, + configs: vec![Arc::new(UsbDescConfig { + config_desc: UsbConfigDescriptor { + bLength: USB_DT_CONFIG_SIZE, + bDescriptorType: USB_DT_CONFIGURATION, + wTotalLength: 0, + bNumInterfaces: 1, + bConfigurationValue: 1, + iConfiguration: UsbUasStringId::Configuration as u8, + bmAttributes: USB_CONFIGURATION_ATTR_ONE | USB_CONFIGURATION_ATTR_SELF_POWER, + bMaxPower: 50, + }, + iad_desc: vec![], + interfaces: vec![DESC_IFACE_BOT.clone(), DESC_IFACE_UAS.clone()], + })], + }) +}); + +static DESC_IFACE_UAS: Lazy> = Lazy::new(|| { + Arc::new(UsbDescIface { + interface_desc: UsbInterfaceDescriptor { + bLength: USB_DT_INTERFACE_SIZE, + bDescriptorType: USB_DT_INTERFACE, + bInterfaceNumber: 0, + bAlternateSetting: UAS_ALT_SETTING_UAS, + bNumEndpoints: 4, + bInterfaceClass: USB_CLASS_MASS_STORAGE, + bInterfaceSubClass: USB_SUBCLASS_SCSI, + bInterfaceProtocol: USB_IFACE_PROTOCOL_UAS, + iInterface: 0, + }, + other_desc: vec![], + endpoints: vec![ + Arc::new(UsbDescEndpoint { + endpoint_desc: UsbEndpointDescriptor { + bLength: USB_DT_ENDPOINT_SIZE, + bDescriptorType: USB_DT_ENDPOINT, + bEndpointAddress: USB_DIRECTION_HOST_TO_DEVICE | UAS_PIPE_ID_COMMAND, + bmAttributes: USB_ENDPOINT_ATTR_BULK, + wMaxPacketSize: 1024, + bInterval: 0, + }, + extra: [ + UsbSuperSpeedEndpointCompDescriptor { + bLength: USB_DT_SS_EP_COMP_SIZE, + bDescriptorType: USB_DT_ENDPOINT_COMPANION, + bMaxBurst: 15, + bmAttributes: 0, + wBytesPerInterval: 0, + } + .as_bytes(), + UsbPipeUsageDescriptor { + bLength: USB_DT_PIPE_USAGE_SIZE, + bDescriptorType: USB_DT_PIPE_USAGE, + bPipeId: UAS_PIPE_ID_COMMAND, + bReserved: 0, + } + .as_bytes(), + ] + .concat() + .to_vec(), + }), + Arc::new(UsbDescEndpoint { + endpoint_desc: UsbEndpointDescriptor { + bLength: USB_DT_ENDPOINT_SIZE, + bDescriptorType: USB_DT_ENDPOINT, + bEndpointAddress: USB_DIRECTION_DEVICE_TO_HOST | UAS_PIPE_ID_STATUS, + bmAttributes: USB_ENDPOINT_ATTR_BULK, + wMaxPacketSize: 1024, + bInterval: 0, + }, + extra: [ + UsbSuperSpeedEndpointCompDescriptor { + bLength: USB_DT_SS_EP_COMP_SIZE, + bDescriptorType: USB_DT_ENDPOINT_COMPANION, + bMaxBurst: 15, + bmAttributes: UAS_MAX_STREAMS_BM_ATTR, + wBytesPerInterval: 0, + } + .as_bytes(), + UsbPipeUsageDescriptor { + bLength: USB_DT_PIPE_USAGE_SIZE, + bDescriptorType: USB_DT_PIPE_USAGE, + bPipeId: UAS_PIPE_ID_STATUS, + bReserved: 0, + } + .as_bytes(), + ] + .concat() + .to_vec(), + }), + Arc::new(UsbDescEndpoint { + endpoint_desc: UsbEndpointDescriptor { + bLength: USB_DT_ENDPOINT_SIZE, + bDescriptorType: USB_DT_ENDPOINT, + bEndpointAddress: USB_DIRECTION_DEVICE_TO_HOST | UAS_PIPE_ID_DATA_IN, + bmAttributes: USB_ENDPOINT_ATTR_BULK, + wMaxPacketSize: 1024, + bInterval: 0, + }, + extra: [ + UsbSuperSpeedEndpointCompDescriptor { + bLength: USB_DT_SS_EP_COMP_SIZE, + bDescriptorType: USB_DT_ENDPOINT_COMPANION, + bMaxBurst: 15, + bmAttributes: UAS_MAX_STREAMS_BM_ATTR, + wBytesPerInterval: 0, + } + .as_bytes(), + UsbPipeUsageDescriptor { + bLength: USB_DT_PIPE_USAGE_SIZE, + bDescriptorType: USB_DT_PIPE_USAGE, + bPipeId: UAS_PIPE_ID_DATA_IN, + bReserved: 0, + } + .as_bytes(), + ] + .concat() + .to_vec(), + }), + Arc::new(UsbDescEndpoint { + endpoint_desc: UsbEndpointDescriptor { + bLength: USB_DT_ENDPOINT_SIZE, + bDescriptorType: USB_DT_ENDPOINT, + bEndpointAddress: USB_DIRECTION_HOST_TO_DEVICE | UAS_PIPE_ID_DATA_OUT, + bmAttributes: USB_ENDPOINT_ATTR_BULK, + wMaxPacketSize: 1024, + bInterval: 0, + }, + extra: [ + UsbSuperSpeedEndpointCompDescriptor { + bLength: USB_DT_SS_EP_COMP_SIZE, + bDescriptorType: USB_DT_ENDPOINT_COMPANION, + bMaxBurst: 15, + bmAttributes: UAS_MAX_STREAMS_BM_ATTR, + wBytesPerInterval: 0, + } + .as_bytes(), + UsbPipeUsageDescriptor { + bLength: USB_DT_PIPE_USAGE_SIZE, + bDescriptorType: USB_DT_PIPE_USAGE, + bPipeId: UAS_PIPE_ID_DATA_OUT, + bReserved: 0, + } + .as_bytes(), + ] + .concat() + .to_vec(), + }), + ], + }) +}); + +static DESC_IFACE_BOT: Lazy> = Lazy::new(|| { + Arc::new(UsbDescIface { + interface_desc: UsbInterfaceDescriptor { + bLength: USB_DT_INTERFACE_SIZE, + bDescriptorType: USB_DT_INTERFACE, + bInterfaceNumber: 0, + bAlternateSetting: UAS_ALT_SETTING_BOT, + bNumEndpoints: 2, + bInterfaceClass: USB_CLASS_MASS_STORAGE, + bInterfaceSubClass: USB_SUBCLASS_SCSI, + bInterfaceProtocol: USB_IFACE_PROTOCOL_BOT, + iInterface: 0, + }, + other_desc: vec![], + endpoints: vec![ + Arc::new(UsbDescEndpoint { + endpoint_desc: UsbEndpointDescriptor { + bLength: USB_DT_ENDPOINT_SIZE, + bDescriptorType: USB_DT_ENDPOINT, + bEndpointAddress: USB_DIRECTION_DEVICE_TO_HOST | 0x01, + bmAttributes: USB_ENDPOINT_ATTR_BULK, + wMaxPacketSize: 1024, + bInterval: 0, + }, + extra: UsbSuperSpeedEndpointCompDescriptor { + bLength: USB_DT_SS_EP_COMP_SIZE, + bDescriptorType: USB_DT_ENDPOINT_COMPANION, + bMaxBurst: 15, + bmAttributes: 0, + wBytesPerInterval: 0, + } + .as_bytes() + .to_vec(), + }), + Arc::new(UsbDescEndpoint { + endpoint_desc: UsbEndpointDescriptor { + bLength: USB_DT_ENDPOINT_SIZE, + bDescriptorType: USB_DT_ENDPOINT, + bEndpointAddress: USB_DIRECTION_HOST_TO_DEVICE | 0x02, + bmAttributes: USB_ENDPOINT_ATTR_BULK, + wMaxPacketSize: 1024, + bInterval: 0, + }, + extra: UsbSuperSpeedEndpointCompDescriptor { + bLength: USB_DT_SS_EP_COMP_SIZE, + bDescriptorType: USB_DT_ENDPOINT_COMPANION, + bMaxBurst: 15, + bmAttributes: 0, + wBytesPerInterval: 0, + } + .as_bytes() + .to_vec(), + }), + ], + }) +}); + +static DESC_CAP_UAS: UsbSuperSpeedCapDescriptor = UsbSuperSpeedCapDescriptor { + bLength: USB_DT_SS_CAP_SIZE, + bDescriptorType: USB_DT_DEVICE_CAPABILITY, + bDevCapabilityType: USB_SS_DEVICE_CAP, + bmAttributes: 0, + wSpeedsSupported: USB_SS_DEVICE_SPEED_SUPPORTED_SUPER | USB_SS_DEVICE_SPEED_SUPPORTED_HIGH, + bFunctionalitySupport: USB_SS_DEVICE_FUNCTIONALITY_SUPPORT_HIGH, + bU1DevExitLat: 0xA, + wU2DevExitLat: 0x20, +}; + +fn complete_async_packet(packet: &Arc>) { + let locked_packet = packet.lock().unwrap(); + + if let Some(xfer_ops) = locked_packet.xfer_ops.as_ref() { + if let Some(xfer_ops) = xfer_ops.clone().upgrade() { + drop(locked_packet); + xfer_ops.lock().unwrap().submit_transfer(); + } + } +} + +impl UsbUas { + pub fn new( + uas_config: UsbUasConfig, + drive_cfg: DriveConfig, + drive_files: Arc>>, + ) -> Result { + if drive_cfg.aio != AioEngine::Off || drive_cfg.direct { + bail!("USB UAS: \"aio=off,direct=false\" must be configured."); + } + + Ok(Self { + base: UsbDeviceBase::new(uas_config.id.clone(), USB_DEVICE_BUFFER_DEFAULT_LEN), + uas_config: uas_config.clone(), + scsi_bus: Arc::new(Mutex::new(ScsiBus::new("".to_string()))), + scsi_device: None, + drive_cfg: drive_cfg.clone(), + drive_files: drive_files.clone(), + commands: array::from_fn(|_| None), + statuses: array::from_fn(|_| None), + data: array::from_fn(|_| None), + bot: UsbStorage::new(uas_config.into(), drive_cfg, drive_files)?, + is_bot: true, + }) + } + + fn cancel_io(&mut self) { + self.commands = array::from_fn(|_| None); + self.statuses = array::from_fn(|_| None); + self.data = array::from_fn(|_| None); + } + + /// Class (Mass Storage) specific requests. + fn handle_control_for_device(&mut self, packet: &mut UsbPacket, device_req: &UsbDeviceRequest) { + match device_req.request_type { + USB_ENDPOINT_OUT_REQUEST => { + if device_req.request == USB_REQUEST_CLEAR_FEATURE { + return; + } + } + USB_INTERFACE_CLASS_OUT_REQUEST => { + // NOTE: See USB Mass Storage Class specification: 3.1 Bulk-Only Mass Storage Reset + if device_req.request == MASS_STORAGE_RESET { + // Set storage state mode. + self.bot.handle_control_packet(packet, device_req); + self.cancel_io(); + return; + } + } + USB_INTERFACE_CLASS_IN_REQUEST => { + // NOTE: See USB Mass Storage Class specification: 3.2 Get Max LUN + if device_req.request == GET_MAX_LUN { + // Now only supports 1 LUN. + self.base.data_buf[0] = 0; + packet.actual_length = 1; + return; + } + } + _ => (), + } + + error!( + "UAS {} device unhandled control request {:?}.", + self.device_id(), + device_req + ); + packet.status = UsbPacketStatus::Stall; + } + + fn handle_iu_command( + &mut self, + iu: &UasIU, + mut uas_request: UasRequest, + ) -> Result { + // SAFETY: IU is guaranteed to be of type command. + let add_cdb_len = unsafe { iu.body.command.add_cdb_len }; + let tag = u16::from_be(iu.header.tag); + + if add_cdb_len > 0 { + Self::fill_fake_sense( + &mut uas_request.status.lock().unwrap(), + tag, + SCSI_SENSE_INVALID_PARAM_VALUE, + ); + uas_request.complete(); + bail!("additional cdb length is not supported"); + } + + if tag > UAS_MAX_STREAMS as u16 { + Self::fill_fake_sense( + &mut uas_request.status.lock().unwrap(), + tag, + SCSI_SENSE_INVALID_TAG, + ); + uas_request.complete(); + bail!("invalid tag {}", tag); + } + + let (scsi_iovec, scsi_iovec_size) = match &uas_request.data { + Some(data) => { + let mut locked_data = data.lock().unwrap(); + let iov_size = locked_data.get_iovecs_size() as u32; + locked_data.actual_length = iov_size; + (locked_data.iovecs.clone(), iov_size) + } + None => (Vec::new(), 0), + }; + + // SAFETY: IU is guaranteed to of type command. + let cdb = unsafe { iu.body.command.cdb }; + // SAFETY: IU is guaranteed to of type command. + let lun = unsafe { iu.body.command.lun } as u16; + trace::usb_uas_handle_iu_command(self.device_id(), cdb[0]); + let uas_request = Box::new(uas_request); + let scsi_request = ScsiRequest::new( + cdb, + lun, + scsi_iovec, + scsi_iovec_size, + self.scsi_device.as_ref().unwrap().clone(), + uas_request, + ) + .with_context(|| "failed to create SCSI request")?; + + if scsi_request.cmd.xfer > u64::from(scsi_request.datalen) + && scsi_request.cmd.mode != ScsiXferMode::ScsiXferNone + { + bail!( + "insufficient buffer provided (requested length {}, provided length {})", + scsi_request.cmd.xfer, + scsi_request.datalen + ); + } + + let scsi_request = match scsi_request.opstype { + EMULATE_SCSI_OPS => scsi_request.emulate_execute(), + _ => scsi_request.execute(), + } + .with_context(|| "failed to execute SCSI request")?; + + let upper_request = &mut scsi_request.lock().unwrap().upper_req; + let uas_request = upper_request + .as_mut() + .as_any_mut() + .downcast_mut::() + .unwrap(); + + Ok(uas_request.completed.into()) + } + + fn handle_iu_task_management( + &mut self, + iu: &UasIU, + mut uas_request: UasRequest, + ) -> Result { + let tag = u16::from_be(iu.header.tag); + + if tag > UAS_MAX_STREAMS as u16 { + Self::fill_fake_sense( + &mut uas_request.status.lock().unwrap(), + tag, + SCSI_SENSE_INVALID_TAG, + ); + uas_request.complete(); + bail!("invalid tag {}", tag); + } + + // SAFETY: IU is guaranteed to be of type task management. + let tmf = unsafe { iu.body.task_management.function }; + trace::usb_uas_handle_iu_task_management(self.device_id(), tmf, tag); + + match tmf { + UAS_TMF_ABORT_TASK => { + // SAFETY: IU is guaranteed to be of type task management. + let task_tag = unsafe { iu.body.task_management.task_tag } as usize; + self.commands[task_tag] = None; + self.statuses[task_tag] = None; + self.data[task_tag] = None; + trace::usb_uas_tmf_abort_task(self.device_id(), task_tag); + Self::fill_response( + &mut uas_request.status.lock().unwrap(), + tag, + UAS_RC_TMF_COMPLETE, + ); + } + _ => { + warn!("UAS {} device unsupported TMF {}.", self.device_id(), tmf); + Self::fill_response( + &mut uas_request.status.lock().unwrap(), + tag, + UAS_RC_TMF_NOT_SUPPORTED, + ); + } + }; + + uas_request.complete(); + Ok(UasPacketStatus::Completed) + } + + fn fill_response(packet: &mut UsbPacket, tag: u16, code: u8) { + let mut iu = UasIU::new(UAS_IU_ID_RESPONSE, tag); + iu.body.response.response_code = code; + let iu_len = size_of::() + size_of::(); + Self::fill_packet(packet, &mut iu, iu_len); + } + + fn fill_fake_sense(packet: &mut UsbPacket, tag: u16, sense: ScsiSense) { + let mut iu = UasIU::new(UAS_IU_ID_SENSE, tag); + // SAFETY: IU is guaranteed to be of type status. + let iu_sense = unsafe { &mut iu.body.sense }; + + iu_sense.status = CHECK_CONDITION; + iu_sense.status_qualifier = 0_u16.to_be(); + iu_sense.sense_length = 18_u16.to_be(); + iu_sense.sense_data[0] = 0x70; // Error code: current errors + iu_sense.sense_data[2] = sense.key; + iu_sense.sense_data[7] = 10; // Additional sense length: total length - 8 + iu_sense.sense_data[12] = sense.asc; + iu_sense.sense_data[13] = sense.ascq; + + let iu_len = size_of::() + size_of::(); + trace::usb_uas_fill_fake_sense(CHECK_CONDITION, iu_len, iu_sense.sense_length as usize); + Self::fill_packet(packet, &mut iu, iu_len); + } + + fn fill_sense(packet: &mut UsbPacket, tag: u16, sense: ScsiSense, status: u8) { + let mut iu = UasIU::new(UAS_IU_ID_SENSE, tag); + // SAFETY: IU is guaranteed to be of type status. + let iu_sense = unsafe { &mut iu.body.sense }; + + iu_sense.status = status; + iu_sense.status_qualifier = 0_u16.to_be(); + iu_sense.sense_length = 0_u16.to_be(); + + if status != GOOD { + iu_sense.sense_length = 18_u16.to_be(); + iu_sense.sense_data[0] = 0x71; // Error code: deferred errors + iu_sense.sense_data[2] = sense.key; + iu_sense.sense_data[7] = 10; // Additional sense length: total length - 8 + iu_sense.sense_data[12] = sense.asc; + iu_sense.sense_data[13] = sense.ascq; + } + + let sense_len = + size_of::() - iu_sense.sense_data.len() + iu_sense.sense_length as usize; + let iu_len = size_of::() + sense_len; + trace::usb_uas_fill_sense(status, iu_len, iu_sense.sense_length as usize); + Self::fill_packet(packet, &mut iu, iu_len); + } + + fn fill_packet(packet: &mut UsbPacket, iu: &mut UasIU, iu_len: usize) { + let iov_size = packet.get_iovecs_size() as usize; + let iu_len = min(iov_size, iu_len); + trace::usb_uas_fill_packet(iov_size); + packet.transfer_packet(iu.as_mut_bytes(), iu_len); + } + + fn try_start_next_transfer(&mut self, stream: usize) -> UasPacketStatus { + if self.commands[stream].is_none() { + debug!( + "UAS {} device no inflight command on stream {}.", + self.device_id(), + stream + ); + return UasPacketStatus::Pending; + } + + if self.statuses[stream].is_none() { + debug!( + "UAS {} device no inflight status on stream {}.", + self.device_id(), + stream + ); + return UasPacketStatus::Pending; + } + + // SAFETY: Command was checked to be Some. + let command = self.commands[stream].as_ref().unwrap(); + // SAFETY: IU is guaranteed to be of type command. + let cdb = unsafe { &command.body.command.cdb }; + let xfer_len = scsi_cdb_xfer(cdb, self.scsi_device.as_ref().unwrap().clone()); + trace::usb_uas_try_start_next_transfer(self.device_id(), xfer_len); + + if xfer_len == 0 { + return self.start_next_transfer(stream); + } + + if self.data[stream].is_some() { + self.start_next_transfer(stream) + } else { + debug!( + "UAS {} device no inflight data on stream {}.", + self.device_id(), + stream + ); + UasPacketStatus::Pending + } + } + + fn start_next_transfer(&mut self, stream: usize) -> UasPacketStatus { + trace::usb_uas_start_next_transfer(self.device_id(), stream); + // SAFETY: Status and command must have been checked in try_start_next_transfer. + let status = self.statuses[stream].take().unwrap(); + let command = self.commands[stream].take().unwrap(); + let mut uas_request = UasRequest::new(&status, &command); + uas_request.data = self.data[stream].take(); + + let result = match command.header.id { + UAS_IU_ID_COMMAND => self.handle_iu_command(&command, uas_request), + UAS_IU_ID_TASK_MGMT => self.handle_iu_task_management(&command, uas_request), + _ => Err(anyhow!("impossible command IU {}", command.header.id)), + }; + + match result { + Ok(result) => result, + Err(err) => { + error!("UAS {} device error: {:?}.", self.device_id(), err); + UasPacketStatus::Completed + } + } + } +} + +impl UsbDevice for UsbUas { + gen_base_func!(usb_device_base, usb_device_base_mut, UsbDeviceBase, base); + + fn realize(mut self) -> Result>> { + info!("UAS {} device realize.", self.device_id()); + self.base.reset_usb_endpoint(); + self.base.speed = USB_SPEED_SUPER; + let mut s: Vec = UAS_DESC_STRINGS.iter().map(|&s| s.to_string()).collect(); + let prefix = &s[UsbUasStringId::SerialNumber as usize]; + s[UsbUasStringId::SerialNumber as usize] = self.base.generate_serial_number(prefix); + self.base.init_descriptor(DESC_DEVICE_UAS.clone(), s)?; + self.base.set_capability_descriptors(vec![DESC_CAP_UAS]); + + // NOTE: "aio=off,direct=false" must be configured and other aio/direct values are not + // supported. + let scsidev_classtype = match self.drive_cfg.media.as_str() { + "disk" => "scsi-hd".to_string(), + _ => "scsi-cd".to_string(), + }; + let scsi_dev_cfg = ScsiDevConfig { + classtype: scsidev_classtype, + drive: self.uas_config.drive.clone(), + ..Default::default() + }; + let scsi_device = ScsiDevice::new( + scsi_dev_cfg, + self.drive_cfg.clone(), + self.drive_files.clone(), + None, + self.scsi_bus.clone(), + ); + let realized_scsi = scsi_device.realize()?; + self.scsi_device = Some(realized_scsi.clone()); + self.scsi_bus + .lock() + .unwrap() + .attach_child(get_scsi_key(0, 0), realized_scsi)?; + + self.bot.do_realize()?; + let uas = Arc::new(Mutex::new(self)); + Ok(uas) + } + + fn cancel_packet(&mut self, _packet: &Arc>) { + self.cancel_io(); + } + + fn reset(&mut self) { + info!("UAS {} device reset.", self.device_id()); + self.base.remote_wakeup = 0; + self.base.addr = 0; + self.cancel_io(); + // Reset storage state. + self.bot.reset(); + } + + fn handle_control(&mut self, packet: &Arc>, device_req: &UsbDeviceRequest) { + let mut locked_packet = packet.lock().unwrap(); + trace::usb_uas_handle_control( + locked_packet.packet_id, + self.device_id(), + device_req.as_bytes(), + ); + + if device_req.request_type == USB_INTERFACE_OUT_REQUEST + && device_req.request == USB_REQUEST_SET_INTERFACE + { + self.is_bot = device_req.value != UAS_ALT_SETTING_UAS as u16; + } + + match self + .base + .handle_control_for_descriptor(&mut locked_packet, device_req) + { + Ok(handled) => { + if handled { + debug!( + "UAS {} device control handled by descriptor, return directly.", + self.device_id() + ); + return; + } + + self.handle_control_for_device(&mut locked_packet, device_req); + } + Err(err) => { + warn!( + "{} received incorrect UAS descriptor message: {:?}", + self.device_id(), + err + ); + locked_packet.status = UsbPacketStatus::Stall; + } + } + } + + fn handle_data(&mut self, packet: &Arc>) { + if self.is_bot { + return self.bot.handle_data(packet); + } + + let locked_packet = packet.lock().unwrap(); + let stream = locked_packet.stream as usize; + let ep_number = locked_packet.ep_number; + let packet_id = locked_packet.packet_id; + trace::usb_uas_handle_data(self.device_id(), ep_number, stream); + drop(locked_packet); + + if stream > UAS_MAX_STREAMS || ep_number != UAS_PIPE_ID_COMMAND && stream == 0 { + warn!("UAS {} device invalid stream {}.", self.device_id(), stream); + packet.lock().unwrap().status = UsbPacketStatus::Stall; + return; + } + + // NOTE: The architecture of this device is rather simple: it first waits for all of the + // required USB packets to arrive, and only then creates and sends an actual UAS request. + // The number of USB packets differs from 2 to 3 and depends on whether the command involves + // data transfers or not. Since the packets arrive in arbitrary order, some of them may be + // queued asynchronously. Note that the command packet is always completed right away. For + // all the other types of packets, their asynchronous status is determined by the return + // value of try_start_next_transfer(). All the asynchronously queued packets will be + // completed in scsi_request_complete_cb() callback. + match ep_number { + UAS_PIPE_ID_COMMAND => { + let mut locked_packet = packet.lock().unwrap(); + let mut iu = UasIU::default(); + let iov_size = locked_packet.get_iovecs_size() as usize; + let iu_len = min(iov_size, size_of::()); + locked_packet.transfer_packet(iu.as_mut_bytes(), iu_len); + let stream = u16::from_be(iu.header.tag) as usize; + + if self.commands[stream].is_some() { + warn!( + "UAS {} device multiple command packets on stream {}.", + self.device_id(), + stream + ); + packet.lock().unwrap().status = UsbPacketStatus::Stall; + return; + } + + trace::usb_uas_command_received(packet_id, self.device_id()); + self.commands[stream] = Some(iu); + self.try_start_next_transfer(stream); + trace::usb_uas_command_completed(packet_id, self.device_id()); + } + UAS_PIPE_ID_STATUS => { + if self.statuses[stream].is_some() { + warn!( + "UAS {} device multiple status packets on stream {}.", + self.device_id(), + stream + ); + packet.lock().unwrap().status = UsbPacketStatus::Stall; + return; + } + + trace::usb_uas_status_received(packet_id, self.device_id()); + self.statuses[stream] = Some(Arc::clone(packet)); + let result = self.try_start_next_transfer(stream); + + match result { + UasPacketStatus::Completed => { + trace::usb_uas_status_completed(packet_id, self.device_id()) + } + UasPacketStatus::Pending => { + packet.lock().unwrap().is_async = true; + trace::usb_uas_status_queued_async(packet_id, self.device_id()); + } + } + } + UAS_PIPE_ID_DATA_OUT | UAS_PIPE_ID_DATA_IN => { + if self.data[stream].is_some() { + warn!( + "UAS {} device multiple data packets on stream {}.", + self.device_id(), + stream + ); + packet.lock().unwrap().status = UsbPacketStatus::Stall; + return; + } + + trace::usb_uas_data_received(packet_id, self.device_id()); + self.data[stream] = Some(Arc::clone(packet)); + let result = self.try_start_next_transfer(stream); + + match result { + UasPacketStatus::Completed => { + trace::usb_uas_data_completed(packet_id, self.device_id()) + } + UasPacketStatus::Pending => { + packet.lock().unwrap().is_async = true; + trace::usb_uas_data_queued_async(packet_id, self.device_id()); + } + } + } + _ => { + error!( + "UAS {} device bad endpoint number {}.", + self.device_id(), + ep_number + ); + } + } + } + + fn set_controller(&mut self, _cntlr: std::sync::Weak>) {} + + fn get_controller(&self) -> Option>> { + None + } +} + +impl UasRequest { + fn new(status: &Arc>, iu: &UasIU) -> Self { + Self { + data: None, + status: Arc::clone(status), + iu: *iu, + completed: false, + } + } + + fn complete(&mut self) { + let status = &self.status; + let status_async = status.lock().unwrap().is_async; + + // NOTE: Due to the specifics of this device, it waits for all of the required USB packets + // to arrive before starting an actual transfer. Therefore, some packets may arrive earlier + // than others, and they won't be completed right away (except for the command packets), but + // rather queued asynchronously. A certain packet may also be async if it was the last to + // arrive, but UasRequest didn't complete right away. + if status_async { + complete_async_packet(status); + } + + if let Some(data) = &self.data { + let data_async = data.lock().unwrap().is_async; + + if data_async { + complete_async_packet(data); + } + } + + self.completed = true; + } +} + +impl UasIUHeader { + fn new(id: u8, tag: u16) -> Self { + UasIUHeader { + id, + reserved: 0, + tag: tag.to_be(), + } + } +} + +impl UasIU { + fn new(id: u8, tag: u16) -> Self { + Self { + header: UasIUHeader::new(id, tag), + body: UasIUBody::default(), + } + } +} diff --git a/devices/src/usb/usbhost/host_usblib.rs b/devices/src/usb/usbhost/host_usblib.rs index 2ee7a9deac50b7fda7154bda0db5143772d0cfdc..8c5542afb26a708b42af12913d8af874b1e9353a 100644 --- a/devices/src/usb/usbhost/host_usblib.rs +++ b/devices/src/usb/usbhost/host_usblib.rs @@ -11,11 +11,16 @@ // See the Mulan PSL v2 for more details. use std::{ + iter::Iterator, + os::unix::io::{AsRawFd, RawFd}, rc::Rc, + slice, sync::{Arc, Mutex}, }; -use libc::{c_int, c_uint, c_void, EPOLLIN, EPOLLOUT}; +use libc::{c_int, c_short, c_uint, c_void, EPOLLIN, EPOLLOUT}; +#[cfg(all(target_arch = "aarch64", target_env = "ohos"))] +use libusb1_sys::{constants::LIBUSB_SUCCESS, libusb_context, libusb_set_option}; use libusb1_sys::{ constants::{ LIBUSB_ERROR_ACCESS, LIBUSB_ERROR_BUSY, LIBUSB_ERROR_INTERRUPTED, @@ -25,7 +30,8 @@ use libusb1_sys::{ LIBUSB_TRANSFER_COMPLETED, LIBUSB_TRANSFER_ERROR, LIBUSB_TRANSFER_NO_DEVICE, LIBUSB_TRANSFER_STALL, LIBUSB_TRANSFER_TIMED_OUT, LIBUSB_TRANSFER_TYPE_ISOCHRONOUS, }, - libusb_get_pollfds, libusb_iso_packet_descriptor, libusb_pollfd, libusb_transfer, + libusb_free_pollfds, libusb_get_pollfds, libusb_iso_packet_descriptor, libusb_pollfd, + libusb_transfer, }; use log::error; use rusb::{Context, DeviceHandle, Error, Result, TransferType, UsbContext}; @@ -114,41 +120,28 @@ pub fn map_packet_status(status: i32) -> UsbPacketStatus { } } -pub fn get_libusb_pollfds(usbhost: Arc>) -> *const *mut libusb_pollfd { - // SAFETY: call C library of libusb to get pointer of poll fd. - unsafe { libusb_get_pollfds(usbhost.lock().unwrap().context.as_raw()) } -} - pub fn set_pollfd_notifiers( - poll: *const *mut libusb_pollfd, + pollfds: PollFds, notifiers: &mut Vec, handler: Rc, ) { - let mut i = 0; - // SAFETY: have checked whether the pointer is null before dereference it. - unsafe { - loop { - if (*poll.offset(i)).is_null() { - break; - }; - if (*(*poll.offset(i))).events as c_int == EPOLLIN { - notifiers.push(EventNotifier::new( - NotifierOperation::AddShared, - (*(*poll.offset(i))).fd, - None, - EventSet::IN, - vec![handler.clone()], - )); - } else if (*(*poll.offset(i))).events as c_int == EPOLLOUT { - notifiers.push(EventNotifier::new( - NotifierOperation::AddShared, - (*(*poll.offset(i))).fd, - None, - EventSet::OUT, - vec![handler.clone()], - )); - } - i += 1; + for pollfd in pollfds.iter() { + if i32::from(pollfd.events()) == EPOLLIN { + notifiers.push(EventNotifier::new( + NotifierOperation::AddShared, + pollfd.as_raw_fd(), + None, + EventSet::IN, + vec![handler.clone()], + )); + } else if i32::from(pollfd.events()) == EPOLLOUT { + notifiers.push(EventNotifier::new( + NotifierOperation::AddShared, + pollfd.as_raw_fd(), + None, + EventSet::OUT, + vec![handler.clone()], + )); } } } @@ -380,3 +373,103 @@ pub fn free_host_transfer(transfer: *mut libusb_transfer) { // SAFETY: have checked the validity of transfer before call libusb_free_transfer. unsafe { libusb1_sys::libusb_free_transfer(transfer) }; } + +#[cfg(all(target_arch = "aarch64", target_env = "ohos"))] +pub fn set_option(opt: u32) -> Result<()> { + // SAFETY: This function will only configure a specific option within libusb, null for ctx is valid. + let err = unsafe { + libusb_set_option( + std::ptr::null_mut() as *mut libusb_context, + opt, + std::ptr::null_mut() as *mut c_void, + ) + }; + if err != LIBUSB_SUCCESS { + return Err(from_libusb(err)); + } + + Ok(()) +} + +#[derive(Debug)] +pub struct PollFd { + fd: c_int, + events: c_short, +} + +impl PollFd { + unsafe fn from_raw(raw: *mut libusb_pollfd) -> Self { + Self { + fd: (*raw).fd, + events: (*raw).events, + } + } + + pub fn events(&self) -> c_short { + self.events + } +} + +impl AsRawFd for PollFd { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +pub struct PollFds { + poll_fds: *const *mut libusb_pollfd, +} + +impl PollFds { + pub unsafe fn new(usbhost: Arc>) -> Result { + let poll_fds = libusb_get_pollfds(usbhost.lock().unwrap().context.as_raw()); + if poll_fds.is_null() { + Err(Error::NotFound) + } else { + Ok(Self { poll_fds }) + } + } + + pub fn iter(&self) -> PollFdIter { + let mut len: usize = 0; + // SAFETY: self.poll_fds is acquired from libusb_get_pollfds which is guaranteed to be valid. + unsafe { + while !(*self.poll_fds.add(len)).is_null() { + len += 1; + } + PollFdIter { + fds: slice::from_raw_parts(self.poll_fds, len), + index: 0, + } + } + } +} + +impl Drop for PollFds { + fn drop(&mut self) { + // SAFETY: self.poll_fds is acquired from libusb_get_pollfds which is guaranteed to be valid. + unsafe { + libusb_free_pollfds(self.poll_fds); + } + } +} + +pub struct PollFdIter<'a> { + fds: &'a [*mut libusb_pollfd], + index: usize, +} + +impl<'a> Iterator for PollFdIter<'a> { + type Item = PollFd; + + fn next(&mut self) -> Option { + if self.index < self.fds.len() { + // SAFETY: self.fds is guaranteed to be valid. + let poll_fd = unsafe { PollFd::from_raw(self.fds[self.index]) }; + self.index += 1; + Some(poll_fd) + } else { + None + } + } +} diff --git a/devices/src/usb/usbhost/mod.rs b/devices/src/usb/usbhost/mod.rs index e34d1176c3d8272c523d13a9a5eac29c3119735a..3c94ceb14db5b166af98f32dc9ca88f625b943bb 100644 --- a/devices/src/usb/usbhost/mod.rs +++ b/devices/src/usb/usbhost/mod.rs @@ -11,6 +11,8 @@ // See the Mulan PSL v2 for more details. mod host_usblib; +#[cfg(all(target_arch = "aarch64", target_env = "ohos"))] +mod ohusb; use std::{ collections::LinkedList, @@ -20,7 +22,9 @@ use std::{ time::Duration, }; -use anyhow::{anyhow, bail, Result}; +#[cfg(not(all(target_arch = "aarch64", target_env = "ohos")))] +use anyhow::Context as anyhowContext; +use anyhow::{anyhow, Result}; use clap::Parser; use libc::c_int; use libusb1_sys::{ @@ -50,12 +54,13 @@ use machine_manager::{ event_loop::{register_event_helper, unregister_event_helper}, temp_cleaner::{ExitNotifier, TempCleaner}, }; -use util::{ - byte_code::ByteCode, - link_list::{List, Node}, - loop_context::{EventNotifier, EventNotifierHelper, NotifierCallback}, - num_ops::str_to_num, -}; +#[cfg(all(target_arch = "aarch64", target_env = "ohos"))] +use ohusb::OhUsbDev; +use util::byte_code::ByteCode; +use util::gen_base_func; +use util::link_list::{List, Node}; +use util::loop_context::{EventNotifier, EventNotifierHelper, NotifierCallback}; +use util::num_ops::str_to_num; const NON_ISO_PACKETS_NUMS: c_int = 0; const HANDLE_TIMEOUT_MS: u64 = 2; @@ -238,7 +243,7 @@ impl IsoTransfer { pub fn copy_data(&mut self, packet: Arc>, ep_max_packet_size: u32) -> bool { let mut lockecd_packet = packet.lock().unwrap(); let mut size: usize; - if lockecd_packet.pid == USB_TOKEN_OUT as u32 { + if lockecd_packet.pid == u32::from(USB_TOKEN_OUT) { size = lockecd_packet.get_iovecs_size() as usize; if size > ep_max_packet_size as usize { size = ep_max_packet_size as usize; @@ -357,10 +362,12 @@ impl IsoQueue { } #[derive(Parser, Clone, Debug, Default)] -#[command(name = "usb_host")] +#[command(no_binary_name(true))] pub struct UsbHostConfig { + #[arg(long)] + pub classtype: String, #[arg(long, value_parser = valid_id)] - id: String, + pub id: String, #[arg(long, default_value = "0")] hostbus: u8, #[arg(long, default_value = "0", value_parser = clap::value_parser!(u8).range(..=USBHOST_ADDR_MAX))] @@ -394,14 +401,14 @@ pub struct UsbHost { /// Configuration interface number. ifs_num: u8, ifs: [InterfaceStatus; USB_MAX_INTERFACES as usize], - /// Callback for release dev to Host after the vm exited. - exit: Option>, /// All pending asynchronous usb request. requests: Arc>>, /// ISO queues corresponding to all endpoints. iso_queues: Arc>>>>, iso_urb_frames: u32, iso_urb_count: u32, + #[cfg(all(target_arch = "aarch64", target_env = "ohos"))] + oh_dev: OhUsbDev, } // SAFETY: Send and Sync is not auto-implemented for util::link_list::List. @@ -412,6 +419,9 @@ unsafe impl Send for UsbHost {} impl UsbHost { pub fn new(config: UsbHostConfig) -> Result { + #[cfg(all(target_arch = "aarch64", target_env = "ohos"))] + let oh_dev = OhUsbDev::new(config.hostbus, config.hostaddr)?; + let mut context = Context::new()?; context.set_log_level(rusb::LogLevel::None); let iso_urb_frames = config.iso_urb_frames; @@ -427,14 +437,16 @@ impl UsbHost { ifs_num: 0, ifs: [InterfaceStatus::default(); USB_MAX_INTERFACES as usize], base: UsbDeviceBase::new(id, USB_HOST_BUFFER_LEN), - exit: None, requests: Arc::new(Mutex::new(List::new())), iso_queues: Arc::new(Mutex::new(LinkedList::new())), iso_urb_frames, iso_urb_count, + #[cfg(all(target_arch = "aarch64", target_env = "ohos"))] + oh_dev, }) } + #[cfg(not(all(target_arch = "aarch64", target_env = "ohos")))] fn find_libdev(&self) -> Option> { if self.config.vendorid != 0 && self.config.productid != 0 { self.find_dev_by_vendor_product() @@ -447,6 +459,7 @@ impl UsbHost { } } + #[cfg(not(all(target_arch = "aarch64", target_env = "ohos")))] fn find_dev_by_bus_addr(&self) -> Option> { self.context .devices() @@ -463,6 +476,7 @@ impl UsbHost { .unwrap_or_else(|| None) } + #[cfg(not(all(target_arch = "aarch64", target_env = "ohos")))] fn find_dev_by_vendor_product(&self) -> Option> { self.context .devices() @@ -480,6 +494,7 @@ impl UsbHost { .unwrap_or_else(|| None) } + #[cfg(not(all(target_arch = "aarch64", target_env = "ohos")))] fn find_dev_by_bus_port(&self) -> Option> { let hostport: Vec<&str> = self.config.hostport.as_ref().unwrap().split('.').collect(); let mut port: Vec = Vec::new(); @@ -543,13 +558,8 @@ impl UsbHost { } fn attach_kernel(&mut self) { - if self - .libdev - .as_ref() - .unwrap() - .active_config_descriptor() - .is_err() - { + if let Err(e) = self.libdev.as_ref().unwrap().active_config_descriptor() { + warn!("Failed to active config descriptor: {:?}.", e); return; } for i in 0..self.ifs_num { @@ -642,8 +652,7 @@ impl UsbHost { } } - fn open_and_init(&mut self) -> Result<()> { - self.handle = Some(self.libdev.as_ref().unwrap().open()?); + fn init_usbdev(&mut self) -> Result<()> { self.config.hostbus = self.libdev.as_ref().unwrap().bus_number(); self.config.hostaddr = self.libdev.as_ref().unwrap().address(); trace::usb_host_open_started(self.config.hostbus, self.config.hostaddr); @@ -654,25 +663,20 @@ impl UsbHost { self.ep_update(); - self.base.speed = self.libdev.as_ref().unwrap().speed() as u32 - 1; + match self.libdev.as_ref().unwrap().speed() as u32 { + 0 => { + return Err(anyhow!( + "Failed to realize usb host device due to unknown device speed." + )) + } + speed => self.base.speed = speed - 1, + }; + trace::usb_host_open_success(self.config.hostbus, self.config.hostaddr); Ok(()) } - fn register_exit(&mut self) { - let exit = self as *const Self as u64; - let exit_notifier = Arc::new(move || { - let usb_host = - // SAFETY: This callback is deleted after the device hot-unplug, so it is called only - // when the vm exits abnormally. - &mut unsafe { std::slice::from_raw_parts_mut(exit as *mut UsbHost, 1) }[0]; - usb_host.release_dev_to_host(); - }) as Arc; - self.exit = Some(exit_notifier.clone()); - TempCleaner::add_exit_notifier(self.device_id().to_string(), exit_notifier); - } - fn release_interfaces(&mut self) { for i in 0..self.ifs_num { if !self.ifs[i as usize].claimed { @@ -690,7 +694,8 @@ impl UsbHost { fn claim_interfaces(&mut self) -> UsbPacketStatus { self.base.altsetting = [0; USB_MAX_INTERFACES as usize]; - if self.detach_kernel().is_err() { + if let Err(e) = self.detach_kernel() { + error!("Failed to detach kernel for usbhost: {:?}.", e); return UsbPacketStatus::Stall; } @@ -766,7 +771,7 @@ impl UsbHost { .set_alternate_setting(iface as u8, alt as u8) { Ok(_) => { - self.base.altsetting[iface as usize] = alt as u32; + self.base.altsetting[iface as usize] = u32::from(alt); self.ep_update(); } Err(e) => { @@ -818,18 +823,14 @@ impl UsbHost { } pub fn abort_host_transfers(&mut self) -> Result<()> { - let mut locked_requests = self.requests.lock().unwrap(); - for _i in 0..locked_requests.len { - let mut node = locked_requests.pop_head().unwrap(); - node.value.abort_req(); - locked_requests.add_tail(node); + for req in self.requests.lock().unwrap().iter_mut() { + req.abort_req(); } - drop(locked_requests); // Max counts of uncompleted request to be handled. - let mut limit = 100; + let mut limit: i32 = 100; loop { - if self.requests.lock().unwrap().len == 0 { + if self.requests.lock().unwrap().is_empty() { return Ok(()); } let timeout = Some(Duration::from_millis(HANDLE_TIMEOUT_MS)); @@ -854,7 +855,7 @@ impl UsbHost { pub fn handle_iso_data_in(&mut self, packet: Arc>) { let cloned_packet = packet.clone(); let locked_packet = packet.lock().unwrap(); - let in_direction = locked_packet.pid == USB_TOKEN_IN as u32; + let in_direction = locked_packet.pid == u32::from(USB_TOKEN_IN); let iso_queue = if self.find_iso_queue(locked_packet.ep_number).is_some() { self.find_iso_queue(locked_packet.ep_number).unwrap() } else { @@ -888,7 +889,7 @@ impl UsbHost { let mut locked_iso_queue = iso_queue.lock().unwrap(); - let in_direction = locked_packet.pid == USB_TOKEN_IN as u32; + let in_direction = locked_packet.pid == u32::from(USB_TOKEN_IN); let ep = self .base .get_endpoint(in_direction, locked_packet.ep_number); @@ -980,6 +981,26 @@ impl UsbHost { locked_packet.is_async = true; } + + #[cfg(not(all(target_arch = "aarch64", target_env = "ohos")))] + fn open_usbdev(&mut self) -> Result<()> { + self.libdev = Some( + self.find_libdev() + .with_context(|| format!("Invalid USB host config: {:?}", self.config))?, + ); + self.handle = Some(self.libdev.as_ref().unwrap().open()?); + Ok(()) + } + + #[cfg(all(target_arch = "aarch64", target_env = "ohos"))] + fn open_usbdev(&mut self) -> Result<()> { + self.handle = Some( + self.oh_dev + .open(self.config.clone(), self.context.clone())?, + ); + self.libdev = Some(self.handle.as_ref().unwrap().device()); + Ok(()) + } } impl Drop for UsbHost { @@ -993,47 +1014,46 @@ impl EventNotifierHelper for UsbHost { let cloned_usbhost = usbhost.clone(); let mut notifiers = Vec::new(); - let poll = get_libusb_pollfds(usbhost); let timeout = Some(Duration::new(0, 0)); let handler: Rc = Rc::new(move |_, _fd: RawFd| { - cloned_usbhost - .lock() - .unwrap() - .context - .handle_events(timeout) + let ctx = cloned_usbhost.lock().unwrap().context.clone(); + ctx.handle_events(timeout) .unwrap_or_else(|e| error!("Failed to handle event: {:?}", e)); None }); - - set_pollfd_notifiers(poll, &mut notifiers, handler); + // SAFETY: The usbhost is guaranteed to be valid. + if let Ok(pollfds) = unsafe { PollFds::new(usbhost) } { + set_pollfd_notifiers(pollfds, &mut notifiers, handler); + } notifiers } } -impl UsbDevice for UsbHost { - fn usb_device_base(&self) -> &UsbDeviceBase { - &self.base - } +fn register_exit(usbhost: Arc>) { + let usbhost_cloned = usbhost.clone(); + let exit_notifier = Arc::new(move || { + usbhost_cloned.lock().unwrap().release_dev_to_host(); + }) as Arc; + TempCleaner::add_exit_notifier( + usbhost.lock().unwrap().device_id().to_string(), + exit_notifier, + ); +} - fn usb_device_base_mut(&mut self) -> &mut UsbDeviceBase { - &mut self.base - } +impl UsbDevice for UsbHost { + gen_base_func!(usb_device_base, usb_device_base_mut, UsbDeviceBase, base); fn realize(mut self) -> Result>> { - self.libdev = self.find_libdev(); - if self.libdev.is_none() { - bail!("Invalid USB host config: {:?}", self.config); - } - info!("Open and init usbhost device: {:?}", self.config); - self.open_and_init()?; + self.open_usbdev()?; + self.init_usbdev()?; let usbhost = Arc::new(Mutex::new(self)); let notifiers = EventNotifierHelper::internal_notifiers(usbhost.clone()); register_event_helper(notifiers, None, &mut usbhost.lock().unwrap().libevt)?; // UsbHost addr is changed after Arc::new, so so the registration must be here. - usbhost.lock().unwrap().register_exit(); + register_exit(usbhost.clone()); Ok(usbhost) } @@ -1045,6 +1065,8 @@ impl UsbDevice for UsbHost { Ok(()) } + fn cancel_packet(&mut self, _packet: &Arc>) {} + fn reset(&mut self) { info!("Usb Host device {} reset", self.device_id()); if self.handle.is_none() { @@ -1068,10 +1090,6 @@ impl UsbDevice for UsbHost { None } - fn get_wakeup_endpoint(&self) -> &UsbEndpoint { - self.base.get_endpoint(true, 1) - } - fn handle_control(&mut self, packet: &Arc>, device_req: &UsbDeviceRequest) { trace::usb_host_req_control(self.config.hostbus, self.config.hostaddr, device_req); let mut locked_packet = packet.lock().unwrap(); @@ -1292,7 +1310,7 @@ impl UsbDevice for UsbHost { } } -fn check_device_valid(device: &Device) -> bool { +pub fn check_device_valid(device: &Device) -> bool { let ddesc = match device.device_descriptor() { Ok(ddesc) => ddesc, Err(_) => return false, diff --git a/devices/src/usb/usbhost/ohusb.rs b/devices/src/usb/usbhost/ohusb.rs new file mode 100644 index 0000000000000000000000000000000000000000..0e56acef5827138883699185b93251c39f5fb9ab --- /dev/null +++ b/devices/src/usb/usbhost/ohusb.rs @@ -0,0 +1,78 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::fs::File; +use std::os::unix::io::{AsRawFd, FromRawFd}; +use std::ptr; + +use anyhow::{bail, Context as anyhowContext, Result}; +use libusb1_sys::constants::LIBUSB_OPTION_NO_DEVICE_DISCOVERY; +use log::info; +use rusb::{Context, DeviceHandle, UsbContext}; + +use super::host_usblib::set_option; +use super::{check_device_valid, UsbHostConfig}; +use util::ohos_binding::usb::*; + +pub struct OhUsbDev { + #[allow(dead_code)] + lib: OhUsb, + dev_file: File, +} + +impl OhUsbDev { + pub fn new(bus_num: u8, dev_addr: u8) -> Result { + // In combination with libusb_wrap_sys_device(), in order to access a device directly without prior device scanning on ohos. + set_option(LIBUSB_OPTION_NO_DEVICE_DISCOVERY)?; + + let mut ohusb_dev = OhusbDevice { + busNum: bus_num, + devAddr: dev_addr, + fd: -1, + }; + + let lib = OhUsb::new()?; + match lib.open_device(ptr::addr_of_mut!(ohusb_dev))? { + 0 => { + if ohusb_dev.fd < 0 { + bail!( + "Failed to open usb device due to invalid fd {}", + ohusb_dev.fd + ); + } + } + _ => bail!("Failed to open usb device"), + } + info!("OH USB: open_device: returned fd is {}", ohusb_dev.fd); + + Ok(Self { + lib, + // SAFETY: fd is passed from OH USB framework and we have checked the function return value. + // Now let's save it to rust File struct. + dev_file: unsafe { File::from_raw_fd(ohusb_dev.fd) }, + }) + } + + pub fn open(&mut self, cfg: UsbHostConfig, ctx: Context) -> Result> { + // SAFETY: The validation of fd is guaranteed by new function. + let handle = unsafe { + ctx.open_device_with_fd(self.dev_file.as_raw_fd()) + .with_context(|| format!("os last error: {:?}", std::io::Error::last_os_error()))? + }; + + if !check_device_valid(&handle.device()) { + bail!("Invalid USB host config: {:?}", cfg); + } + + Ok(handle) + } +} diff --git a/devices/src/usb/xhci/xhci_controller.rs b/devices/src/usb/xhci/xhci_controller.rs index d15fef77479b7875076806c95b7246809d11fca6..e437283b18c28d0e7eec19896b86327879997f6a 100644 --- a/devices/src/usb/xhci/xhci_controller.rs +++ b/devices/src/usb/xhci/xhci_controller.rs @@ -13,11 +13,11 @@ use std::collections::LinkedList; use std::mem::size_of; use std::slice::{from_raw_parts, from_raw_parts_mut}; -use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::sync::{Arc, Mutex, Weak}; use std::time::Duration; -use anyhow::{bail, Context, Result}; +use anyhow::{anyhow, bail, Context, Result}; use byteorder::{ByteOrder, LittleEndian}; use log::{error, info, warn}; @@ -25,13 +25,13 @@ use super::xhci_pci::XhciConfig; use super::xhci_regs::{XhciInterrupter, XhciOperReg}; use super::xhci_ring::{XhciCommandRing, XhciEventRingSeg, XhciTRB, XhciTransferRing}; use super::xhci_trb::{ - TRBCCode, TRBType, SETUP_TRB_TR_LEN, TRB_EV_ED, TRB_TR_DIR, TRB_TR_FRAMEID_MASK, + TRBCCode, TRBType, SETUP_TRB_TR_LEN, TRB_EV_ED, TRB_SIZE, TRB_TR_DIR, TRB_TR_FRAMEID_MASK, TRB_TR_FRAMEID_SHIFT, TRB_TR_IDT, TRB_TR_IOC, TRB_TR_ISP, TRB_TR_LEN_MASK, TRB_TR_SIA, TRB_TYPE_SHIFT, }; use crate::usb::{config::*, TransferOps}; -use crate::usb::{UsbDevice, UsbDeviceRequest, UsbEndpoint, UsbError, UsbPacket, UsbPacketStatus}; -use address_space::{AddressSpace, GuestAddress}; +use crate::usb::{UsbDevice, UsbDeviceRequest, UsbError, UsbPacket, UsbPacketStatus}; +use address_space::{AddressAttr, AddressSpace, GuestAddress}; use machine_manager::event_loop::EventLoop; const INVALID_SLOT_ID: u32 = 0; @@ -60,10 +60,16 @@ pub const SLOT_CONFIGURED: u32 = 3; const TRB_CR_BSR: u32 = 1 << 9; const TRB_CR_EPID_SHIFT: u32 = 16; const TRB_CR_EPID_MASK: u32 = 0x1f; +const TRB_CR_STREAMID_SHIFT: u32 = 16; +const TRB_CR_STREAMID_MASK: u32 = 0xffff; const TRB_CR_DC: u32 = 1 << 9; const TRB_CR_SLOTID_SHIFT: u32 = 24; const TRB_CR_SLOTID_MASK: u32 = 0xff; const COMMAND_LIMIT: u32 = 256; +const EP_CTX_MAX_PSTREAMS_SHIFT: u32 = 10; +const EP_CTX_MAX_PSTREAMS_MASK: u32 = 0xf; +const EP_CTX_LSA_SHIFT: u32 = 15; +const EP_CTX_LSA_MASK: u32 = 0x01; const EP_CTX_INTERVAL_SHIFT: u32 = 16; const EP_CTX_INTERVAL_MASK: u32 = 0xff; const EVENT_TRB_CCODE_SHIFT: u32 = 24; @@ -110,6 +116,17 @@ const EP_CONTEXT_EP_TYPE_MASK: u32 = 0x7; const EP_CONTEXT_EP_TYPE_SHIFT: u32 = 3; const ISO_BASE_TIME_INTERVAL: u64 = 125000; const MFINDEX_WRAP_NUM: u64 = 0x4000; +/// Stream Context. +const _STREAM_CTX_SCT_SHIFT: u32 = 1; +const _STREAM_CTX_SCT_MASK: u32 = 0x7; +const _STREAM_CTX_SCT_SECONDARY_TR: u32 = 0; +const _STREAM_CTX_SCT_PRIMARY_TR: u32 = 1; +const _STREAM_CTX_SCT_PRIMARY_SSA_8: u32 = 2; +const _STREAM_CTX_SCT_PRIMARY_SSA_16: u32 = 3; +const _STREAM_CTX_SCT_PRIMARY_SSA_32: u32 = 4; +const _STREAM_CTX_SCT_PRIMARY_SSA_64: u32 = 5; +const _STREAM_CTX_SCT_PRIMARY_SSA_128: u32 = 6; +const _STREAM_CTX_SCT_PRIMARY_SSA_256: u32 = 7; type DmaAddr = u64; @@ -121,26 +138,24 @@ pub struct XhciTransfer { complete: bool, slotid: u32, epid: u32, + streamid: u32, + ep_context: XhciEpContext, in_xfer: bool, iso_xfer: bool, timed_xfer: bool, running_retry: bool, running_async: bool, interrupter: Arc>, - ep_ring: Arc, - ep_type: EpType, - ep_state: Arc, mfindex_kick: u64, } impl XhciTransfer { fn new( - ep_info: (u32, u32, EpType), + ep_info: (u32, u32, u32), + ep_context: XhciEpContext, in_xfer: bool, td: Vec, intr: &Arc>, - ring: &Arc, - ep_state: &Arc, ) -> Self { XhciTransfer { packet: Arc::new(Mutex::new(UsbPacket::default())), @@ -149,15 +164,14 @@ impl XhciTransfer { complete: false, slotid: ep_info.0, epid: ep_info.1, + streamid: ep_info.2, + ep_context, in_xfer, iso_xfer: false, timed_xfer: false, running_retry: false, running_async: false, interrupter: intr.clone(), - ep_ring: ring.clone(), - ep_type: ep_info.2, - ep_state: ep_state.clone(), mfindex_kick: 0, } } @@ -171,18 +185,25 @@ impl XhciTransfer { if self.status == TRBCCode::Success { trace::usb_xhci_xfer_success(&self.packet.lock().unwrap().actual_length); self.submit_transfer()?; - self.ep_ring.refresh_dequeue_ptr()?; + let ring = self.ep_context.get_ring(self.streamid).with_context(|| { + format!( + "Failed to find Transfer Ring with Endpoint ID {}, Slot ID {}, Stream ID {}.", + self.epid, self.slotid, self.streamid + ) + })?; + ring.refresh_dequeue_ptr(*self.ep_context.output_ctx_addr.lock().unwrap())?; return Ok(()); } trace::usb_xhci_xfer_error(&self.packet.lock().unwrap().status); self.report_transfer_error()?; + let ep_type = self.ep_context.ep_type; - if self.ep_type == EpType::IsoIn || self.ep_type == EpType::IsoOut { + if ep_type == EpType::IsoIn || ep_type == EpType::IsoOut { return Ok(()); } // Set the endpoint state to halted if an error occurs in the packet. - set_ep_state_helper(&self.ep_ring, &self.ep_state, EP_HALTED)?; + self.ep_context.set_state(EP_HALTED, Some(self.streamid))?; Ok(()) } @@ -191,6 +212,7 @@ impl XhciTransfer { pub fn submit_transfer(&mut self) -> Result<()> { // Event Data Transfer Length Accumulator. let mut edtla: u32 = 0; + let mut shortpkt = false; let mut left = self.packet.lock().unwrap().actual_length; for i in 0..self.td.len() { let trb = &self.td[i]; @@ -201,7 +223,9 @@ impl XhciTransfer { TRBType::TrData | TRBType::TrNormal | TRBType::TrIsoch => { if chunk > left { chunk = left; - self.status = TRBCCode::ShortPacket; + if self.status == TRBCCode::Success { + shortpkt = true; + } } left -= chunk; edtla = edtla.checked_add(chunk).with_context(|| @@ -217,16 +241,25 @@ impl XhciTransfer { } } if (trb.control & TRB_TR_IOC == TRB_TR_IOC) - || (self.status == TRBCCode::ShortPacket - && (trb.control & TRB_TR_ISP == TRB_TR_ISP)) + || (shortpkt && (trb.control & TRB_TR_ISP == TRB_TR_ISP)) + || (self.status != TRBCCode::Success && left == 0) { - self.send_transfer_event(trb, chunk, &mut edtla)?; + self.send_transfer_event(trb, chunk, &mut edtla, shortpkt)?; + if self.status != TRBCCode::Success { + return Ok(()); + } } } Ok(()) } - fn send_transfer_event(&self, trb: &XhciTRB, transferred: u32, edtla: &mut u32) -> Result<()> { + fn send_transfer_event( + &self, + trb: &XhciTRB, + transferred: u32, + edtla: &mut u32, + shortpkt: bool, + ) -> Result<()> { let trb_type = trb.get_type(); let mut evt = XhciEvent::new(TRBType::ErTransfer, TRBCCode::Success); evt.slot_id = self.slotid as u8; @@ -234,7 +267,15 @@ impl XhciTransfer { evt.length = (trb.status & TRB_TR_LEN_MASK) - transferred; evt.flags = 0; evt.ptr = trb.addr; - evt.ccode = self.status; + evt.ccode = if self.status == TRBCCode::Success { + if shortpkt { + TRBCCode::ShortPacket + } else { + TRBCCode::Success + } + } else { + self.status + }; if trb_type == TRBType::TrEvdata { evt.ptr = trb.parameter; evt.flags |= TRB_EV_ED; @@ -273,44 +314,67 @@ impl TransferOps for XhciTransfer { } /// Endpoint context which use the ring to transfer data. +#[derive(Clone)] pub struct XhciEpContext { epid: u32, enabled: bool, - ring: Arc, + ring: Option>, ep_type: EpType, - output_ctx_addr: Arc, + output_ctx_addr: Arc>, state: Arc, interval: u32, mfindex_last: u64, transfers: LinkedList>>, retry: Option>>, + mem: Arc, + max_pstreams: u32, + lsa: bool, + stream_array: Option, } impl XhciEpContext { pub fn new(mem: &Arc) -> Self { - let addr = Arc::new(AtomicU64::new(0)); Self { epid: 0, enabled: false, - ring: Arc::new(XhciTransferRing::new(mem, &addr)), + ring: None, ep_type: EpType::Invalid, - output_ctx_addr: addr, + output_ctx_addr: Arc::new(Mutex::new(GuestAddress(0))), state: Arc::new(AtomicU32::new(0)), interval: 0, mfindex_last: 0, transfers: LinkedList::new(), retry: None, + mem: Arc::clone(mem), + max_pstreams: 0, + lsa: false, + stream_array: None, } } /// Init the endpoint context used the context read from memory. - fn init_ctx(&mut self, output_ctx: DmaAddr, ctx: &XhciEpCtx) { + fn init_ctx(&mut self, output_ctx: DmaAddr, ctx: &XhciEpCtx) -> Result<()> { let dequeue: DmaAddr = addr64_from_u32(ctx.deq_lo & !0xf, ctx.deq_hi); self.ep_type = ((ctx.ep_info2 >> EP_TYPE_SHIFT) & EP_TYPE_MASK).into(); - self.output_ctx_addr.store(output_ctx, Ordering::SeqCst); - self.ring.init(dequeue); - self.ring.set_cycle_bit((ctx.deq_lo & 1) == 1); + *self.output_ctx_addr.lock().unwrap() = GuestAddress(output_ctx); + self.max_pstreams = (ctx.ep_info >> EP_CTX_MAX_PSTREAMS_SHIFT) & EP_CTX_MAX_PSTREAMS_MASK; + self.lsa = ((ctx.ep_info >> EP_CTX_LSA_SHIFT) & EP_CTX_LSA_MASK) != 0; self.interval = 1 << ((ctx.ep_info >> EP_CTX_INTERVAL_SHIFT) & EP_CTX_INTERVAL_MASK); + + if self.max_pstreams == 0 { + let ring = XhciTransferRing::new(&self.mem); + ring.init(dequeue); + ring.set_cycle_bit((ctx.deq_lo & 1) == 1); + self.ring = Some(Arc::new(ring)); + } else { + let stream_array = XhciStreamArray::new(&self.mem, self.max_pstreams); + stream_array + .init(dequeue) + .with_context(|| "Failed to initialize Stream Array.")?; + self.stream_array = Some(stream_array); + } + + Ok(()) } fn get_ep_state(&self) -> u32 { @@ -318,27 +382,66 @@ impl XhciEpContext { } fn set_ep_state(&self, state: u32) { - self.state.store(state, Ordering::SeqCst); + self.state.store(state, Ordering::Release); } /// Update the endpoint state and write the state to memory. - fn set_state(&mut self, state: u32) -> Result<()> { - set_ep_state_helper(&self.ring, &self.state, state) + fn set_state(&self, state: u32, stream_id: Option) -> Result<()> { + let mut ep_ctx = XhciEpCtx::default(); + let output_addr = self.output_ctx_addr.lock().unwrap(); + dma_read_u32(&self.mem, *output_addr, ep_ctx.as_mut_dwords())?; + ep_ctx.ep_info &= !EP_STATE_MASK; + ep_ctx.ep_info |= state; + dma_write_u32(&self.mem, *output_addr, ep_ctx.as_dwords())?; + drop(output_addr); + self.flush_dequeue_to_memory(stream_id)?; + self.set_ep_state(state); + trace::usb_xhci_set_state(self.epid, state); + Ok(()) } - /// Update the dequeue pointer in endpoint context. + /// Update the dequeue pointer in endpoint or stream context. /// If dequeue is None, only flush the dequeue pointer to memory. - fn update_dequeue(&mut self, mem: &Arc, dequeue: Option) -> Result<()> { - let mut ep_ctx = XhciEpCtx::default(); - let output_addr = self.output_ctx_addr.load(Ordering::Acquire); - dma_read_u32(mem, GuestAddress(output_addr), ep_ctx.as_mut_dwords())?; + fn update_dequeue(&self, dequeue: Option, stream_id: u32) -> Result<()> { if let Some(dequeue) = dequeue { - self.ring.init(dequeue & EP_CTX_TR_DEQUEUE_POINTER_MASK); - self.ring - .set_cycle_bit((dequeue & EP_CTX_DCS) == EP_CTX_DCS); - } - self.ring.update_dequeue_to_ctx(&mut ep_ctx); - dma_write_u32(mem, GuestAddress(output_addr), ep_ctx.as_dwords())?; + let ring = self.get_ring(stream_id).with_context(|| { + format!( + "Failed to find Transfer Ring for Endpoint {}, Stream ID {}.", + self.epid, stream_id + ) + })?; + ring.init(dequeue & EP_CTX_TR_DEQUEUE_POINTER_MASK); + ring.set_cycle_bit((dequeue & EP_CTX_DCS) == EP_CTX_DCS); + trace::usb_xhci_update_dequeue(self.epid, dequeue, stream_id); + } + + self.flush_dequeue_to_memory(Some(stream_id))?; + Ok(()) + } + + /// Flush the dequeue pointer to the memory. + /// Stream Endpoints flush ring dequeue to both Endpoint and Stream context. + fn flush_dequeue_to_memory(&self, stream_id: Option) -> Result<()> { + let mut ep_ctx = XhciEpCtx::default(); + let output_addr = self.output_ctx_addr.lock().unwrap(); + dma_read_u32(&self.mem, *output_addr, ep_ctx.as_mut_dwords())?; + + if self.max_pstreams == 0 { + let ring = self.get_ring(0)?; + ring.update_dequeue_to_ctx(&mut ep_ctx.as_mut_dwords()[2..]); + } else if let Some(stream_id) = stream_id { + let mut stream_ctx = XhciStreamCtx::default(); + let stream = self.get_stream(stream_id)?; + let locked_stream = stream.lock().unwrap(); + let output_addr = locked_stream.dequeue; + let ring = locked_stream.ring.as_ref(); + dma_read_u32(&self.mem, output_addr, stream_ctx.as_mut_dwords())?; + ring.update_dequeue_to_ctx(stream_ctx.as_mut_dwords()); + ring.update_dequeue_to_ctx(&mut ep_ctx.as_mut_dwords()[2..]); + dma_write_u32(&self.mem, output_addr, stream_ctx.as_dwords())?; + } + + dma_write_u32(&self.mem, *output_addr, ep_ctx.as_dwords())?; Ok(()) } @@ -352,23 +455,72 @@ impl XhciEpContext { } self.transfers = undo; } -} -fn set_ep_state_helper( - ring: &Arc, - ep_state: &Arc, - state: u32, -) -> Result<()> { - let mem = &ring.mem; - let mut ep_ctx = XhciEpCtx::default(); - let output_addr = ring.output_ctx_addr.load(Ordering::Acquire); - dma_read_u32(mem, GuestAddress(output_addr), ep_ctx.as_mut_dwords())?; - ep_ctx.ep_info &= !EP_STATE_MASK; - ep_ctx.ep_info |= state; - ring.update_dequeue_to_ctx(&mut ep_ctx); - dma_write_u32(mem, GuestAddress(output_addr), ep_ctx.as_dwords())?; - ep_state.store(state, Ordering::SeqCst); - Ok(()) + /// Find and return a stream corresponding to the specified Stream ID. + /// Returns error if there is no stream support or LSA is not enabled. + fn get_stream(&self, stream_id: u32) -> Result>> { + let stream_arr = self + .stream_array + .as_ref() + .ok_or_else(|| anyhow!("Endpoint {} does not support streams.", self.epid))?; + + if !self.lsa { + bail!("Only Linear Streams Array (LSA) is supported."); + } + + let XhciStreamArray(pstreams) = &stream_arr; + let pstreams_num = pstreams.len() as u32; + + if stream_id >= pstreams_num || stream_id == 0 { + bail!( + "Stream ID {} is either invalid or reserved, max number of streams is {}.", + stream_id, + pstreams_num + ); + } + + let stream_context = &pstreams[stream_id as usize]; + let mut locked_context = stream_context.lock().unwrap(); + locked_context.try_refresh()?; + trace::usb_xhci_get_stream(stream_id, self.epid); + Ok(Arc::clone(stream_context)) + } + + /// Get a ring corresponding to the specified Stream ID if stream support is enabled, + /// return the standard Transfer Ring otherwise. + fn get_ring(&self, stream_id: u32) -> Result> { + if self.max_pstreams == 0 { + Ok(Arc::clone(self.ring.as_ref().ok_or_else(|| { + anyhow!( + "Failed to get the Transfer Ring for Endpoint {} without streams.", + self.epid + ) + })?)) + } else { + let stream = self.get_stream(stream_id).with_context(|| { + format!( + "Failed to find Stream Context with Stream ID {}.", + stream_id + ) + })?; + let locked_stream = stream.lock().unwrap(); + trace::usb_xhci_get_ring(self.epid, stream_id); + Ok(Arc::clone(&locked_stream.ring)) + } + } + + /// Reset all streams on this Endpoint. + fn reset_streams(&self) -> Result<()> { + let stream_arr = self.stream_array.as_ref().ok_or_else(|| { + anyhow!( + "Endpoint {} does not support streams, reset aborted.", + self.epid + ) + })?; + stream_arr.reset(); + trace::usb_xhci_reset_streams(self.epid); + Ok(()) + } } /// Endpoint type, including control, bulk, interrupt and isochronous. @@ -404,7 +556,7 @@ impl From for EpType { pub struct XhciSlot { pub enabled: bool, pub addressed: bool, - pub slot_ctx_addr: u64, + pub slot_ctx_addr: GuestAddress, pub usb_port: Option>>, pub endpoints: Vec, } @@ -419,7 +571,7 @@ impl XhciSlot { XhciSlot { enabled: false, addressed: false, - slot_ctx_addr: 0, + slot_ctx_addr: GuestAddress(0), usb_port: None, endpoints: eps, } @@ -428,18 +580,14 @@ impl XhciSlot { /// Get the slot context from the memory. fn get_slot_ctx(&self, mem: &Arc) -> Result { let mut slot_ctx = XhciSlotCtx::default(); - dma_read_u32( - mem, - GuestAddress(self.slot_ctx_addr), - slot_ctx.as_mut_dwords(), - )?; + dma_read_u32(mem, self.slot_ctx_addr, slot_ctx.as_mut_dwords())?; Ok(slot_ctx) } /// Get the slot state in slot context. fn get_slot_state_in_context(&self, mem: &Arc) -> Result { // Table 4-1: Device Slot State Code Definitions. - if self.slot_ctx_addr == 0 { + if self.slot_ctx_addr == GuestAddress(0) { return Ok(SLOT_DISABLED_ENABLED); } let slot_ctx = self.get_slot_ctx(mem)?; @@ -530,8 +678,8 @@ impl XhciEvent { XhciTRB { parameter: self.ptr, status: self.length | (self.ccode as u32) << EVENT_TRB_CCODE_SHIFT, - control: (self.slot_id as u32) << EVENT_TRB_SLOT_ID_SHIFT - | (self.ep_id as u32) << EVENT_TRB_EP_ID_SHIFT + control: u32::from(self.slot_id) << EVENT_TRB_SLOT_ID_SHIFT + | u32::from(self.ep_id) << EVENT_TRB_EP_ID_SHIFT | self.flags | (self.trb_type as u32) << TRB_TYPE_SHIFT, addr: 0, @@ -677,6 +825,96 @@ pub trait DwordOrder: Default + Copy + Send + Sync { } } +#[repr(transparent)] +#[derive(Clone)] +pub struct XhciStreamArray(Vec>>); + +impl XhciStreamArray { + fn new(mem: &Arc, max_pstreams: u32) -> Self { + let pstreams_num = 1 << (max_pstreams + 1); + let pstreams = (0..pstreams_num) + .map(|_| Arc::new(Mutex::new(XhciStreamContext::new(mem)))) + .collect(); + XhciStreamArray(pstreams) + } + + fn init(&self, mut dequeue: u64) -> Result<()> { + for stream_context in self.0.iter() { + stream_context.lock().unwrap().init(dequeue)?; + dequeue += std::mem::size_of::() as u64; + } + + Ok(()) + } + + fn reset(&self) { + for stream_context in self.0.iter() { + stream_context.lock().unwrap().reset(); + } + } +} + +#[derive(Clone)] +pub struct XhciStreamContext { + /// Memory address space. + mem: Arc, + /// Dequeue pointer. + dequeue: GuestAddress, + /// Transfer Ring (no Secondary Streams for now). + ring: Arc, + /// Whether the context is up to date after reset. + needs_refresh: bool, +} + +impl XhciStreamContext { + fn new(mem: &Arc) -> Self { + Self { + mem: Arc::clone(mem), + dequeue: GuestAddress(0), + ring: Arc::new(XhciTransferRing::new(mem)), + needs_refresh: true, + } + } + + fn init(&mut self, addr: u64) -> Result<()> { + self.dequeue = GuestAddress(addr); + self.refresh()?; + Ok(()) + } + + fn try_refresh(&mut self) -> Result<()> { + if self.needs_refresh { + self.refresh()?; + } + + Ok(()) + } + + fn refresh(&mut self) -> Result<()> { + let mut stream_ctx = XhciStreamCtx::default(); + dma_read_u32(&self.mem, self.dequeue, stream_ctx.as_mut_dwords())?; + let dequeue = addr64_from_u32(stream_ctx.deq_lo & !0xf, stream_ctx.deq_hi); + self.ring.init(dequeue); + self.needs_refresh = false; + Ok(()) + } + + fn reset(&mut self) { + self.needs_refresh = true; + } +} + +#[repr(C, packed)] +#[derive(Debug, Default, Clone, Copy)] +pub struct XhciStreamCtx { + pub deq_lo: u32, + pub deq_hi: u32, + pub stopped_edtla: u32, + pub reserved: u32, +} + +impl DwordOrder for XhciStreamCtx {} + /// Xhci controller device. pub struct XhciDevice { pub numports_2: u8, @@ -691,10 +929,15 @@ pub struct XhciDevice { mfindex_start: Duration, timer_id: Option, packet_count: u32, + bme: Arc, } impl XhciDevice { - pub fn new(mem_space: &Arc, config: &XhciConfig) -> Arc> { + pub fn new( + mem_space: &Arc, + config: &XhciConfig, + bme: &Arc, + ) -> Arc> { let mut p2 = XHCI_DEFAULT_PORT; let mut p3 = XHCI_DEFAULT_PORT; if config.p2.is_some() { @@ -738,6 +981,7 @@ impl XhciDevice { mem_space: mem_space.clone(), mfindex_start: EventLoop::get_ctx(None).unwrap().get_virtual_clock(), timer_id: None, + bme: bme.clone(), }; let xhci = Arc::new(Mutex::new(xhci)); let clone_xhci = xhci.clone(); @@ -807,14 +1051,14 @@ impl XhciDevice { pub fn stop(&mut self) { trace::usb_xhci_stop(); self.oper.set_usb_status_flag(USB_STS_HCH); - self.oper.cmd_ring_ctrl &= !(CMD_RING_CTRL_CRR as u64); + self.oper.cmd_ring_ctrl &= !u64::from(CMD_RING_CTRL_CRR); } pub fn running(&self) -> bool { self.oper.get_usb_status() & USB_STS_HCH != USB_STS_HCH } - pub fn host_controller_error(&mut self) { + pub fn host_controller_error(&self) { error!("Xhci host controller error!"); self.oper.set_usb_status_flag(USB_STS_HCE) } @@ -883,11 +1127,17 @@ impl XhciDevice { } trace::usb_xhci_port_notify(&locked_port.port_id, &flag); locked_port.portsc |= flag; + if !self.running() { return Ok(()); } + self.check_bme_valid().map_err(|e| { + self.host_controller_error(); + e + })?; + let mut evt = XhciEvent::new(TRBType::ErPortStatusChange, TRBCCode::Success); - evt.ptr = ((locked_port.port_id as u32) << PORT_EVENT_ID_SHIFT) as u64; + evt.ptr = u64::from(u32::from(locked_port.port_id) << PORT_EVENT_ID_SHIFT); self.intrs[0].lock().unwrap().send_event(&evt)?; Ok(()) } @@ -961,6 +1211,9 @@ impl XhciDevice { /// Control plane pub fn handle_command(&mut self) -> Result<()> { + // The caller will set HCE if handle_command() returns error. + self.check_bme_valid()?; + self.oper.start_cmd_ring(); let mut slot_id: u32 = 0; let mut event = XhciEvent::new(TRBType::ErCommandComplete, TRBCCode::Success); @@ -1028,7 +1281,10 @@ impl XhciDevice { slot_id = self.get_slot_id(&mut event, &trb); if slot_id != 0 { let ep_id = trb.control >> TRB_CR_EPID_SHIFT & TRB_CR_EPID_MASK; - event.ccode = self.set_tr_dequeue_pointer(slot_id, ep_id, &trb)?; + let stream_id = + trb.status >> TRB_CR_STREAMID_SHIFT & TRB_CR_STREAMID_MASK; + event.ccode = + self.set_tr_dequeue_pointer(slot_id, ep_id, stream_id, &trb)?; } } TRBType::CrResetDevice => { @@ -1071,7 +1327,7 @@ impl XhciDevice { self.slots[(slot_id - 1) as usize].enabled = false; self.slots[(slot_id - 1) as usize].addressed = false; self.slots[(slot_id - 1) as usize].usb_port = None; - self.slots[(slot_id - 1) as usize].slot_ctx_addr = 0; + self.slots[(slot_id - 1) as usize].slot_ctx_addr = GuestAddress(0); Ok(TRBCCode::Success) } @@ -1138,7 +1394,7 @@ impl XhciDevice { let mut locked_port = usb_port.lock().unwrap(); locked_port.slot_id = slot_id; self.slots[(slot_id - 1) as usize].usb_port = Some(usb_port.clone()); - self.slots[(slot_id - 1) as usize].slot_ctx_addr = octx; + self.slots[(slot_id - 1) as usize].slot_ctx_addr = GuestAddress(octx); let dev = locked_port.dev.as_ref().unwrap(); dev.lock().unwrap().reset(); if bsr { @@ -1189,13 +1445,16 @@ impl XhciDevice { index: 0, length: 0, }; + let target_dev = Arc::downgrade(dev) as Weak>; let packet_id = self.generate_packet_id(); let p = Arc::new(Mutex::new(UsbPacket::new( packet_id, - USB_TOKEN_OUT as u32, + u32::from(USB_TOKEN_OUT), + 0, 0, Vec::new(), None, + Some(target_dev), ))); trace::usb_handle_control(&locked_dev.usb_device_base().base.id, &device_req); locked_dev.handle_control(&p, &device_req); @@ -1204,8 +1463,11 @@ impl XhciDevice { fn get_device_context_addr(&self, slot_id: u32) -> Result { self.oper .dcbaap - .checked_add((8 * slot_id) as u64) - .with_context(|| UsbError::MemoryAccessOverflow(self.oper.dcbaap, (8 * slot_id) as u64)) + .raw_value() + .checked_add(u64::from(8 * slot_id)) + .with_context(|| { + UsbError::MemoryAccessOverflow(self.oper.dcbaap.raw_value(), u64::from(8 * slot_id)) + }) } fn configure_endpoint(&mut self, slot_id: u32, trb: &XhciTRB) -> Result { @@ -1234,7 +1496,7 @@ impl XhciDevice { slot_ctx.set_slot_state(SLOT_ADDRESSED); dma_write_u32( &self.mem_space, - GuestAddress(self.slots[(slot_id - 1) as usize].slot_ctx_addr), + self.slots[(slot_id - 1) as usize].slot_ctx_addr, slot_ctx.as_dwords(), )?; Ok(TRBCCode::Success) @@ -1264,7 +1526,7 @@ impl XhciDevice { } if ictl_ctx.add_flags & (1 << i) == 1 << i { self.disable_endpoint(slot_id, i)?; - self.enable_endpoint(slot_id, i, ictx, octx)?; + self.enable_endpoint(slot_id, i, ictx, octx.raw_value())?; } } // From section 4.6.6 Configure Endpoint of the spec: @@ -1289,7 +1551,7 @@ impl XhciDevice { slot_ctx.set_slot_state(SLOT_CONFIGURED); slot_ctx.set_context_entry(enabled_ep_idx); } - dma_write_u32(&self.mem_space, GuestAddress(octx), slot_ctx.as_dwords())?; + dma_write_u32(&self.mem_space, octx, slot_ctx.as_dwords())?; Ok(TRBCCode::Success) } @@ -1328,14 +1590,10 @@ impl XhciDevice { islot_ctx.as_mut_dwords(), )?; let mut slot_ctx = XhciSlotCtx::default(); - dma_read_u32( - &self.mem_space, - GuestAddress(octx), - slot_ctx.as_mut_dwords(), - )?; + dma_read_u32(&self.mem_space, octx, slot_ctx.as_mut_dwords())?; slot_ctx.set_max_exit_latency(islot_ctx.get_max_exit_latency()); slot_ctx.set_interrupter_target(islot_ctx.get_interrupter_target()); - dma_write_u32(&self.mem_space, GuestAddress(octx), slot_ctx.as_dwords())?; + dma_write_u32(&self.mem_space, octx, slot_ctx.as_dwords())?; } if ictl_ctx.add_flags & 0x2 == 0x2 { // Default control endpoint context. @@ -1349,21 +1607,16 @@ impl XhciDevice { iep_ctx.as_mut_dwords(), )?; let mut ep_ctx = XhciEpCtx::default(); - dma_read_u32( - &self.mem_space, - GuestAddress( - // It is safe to use plus here because we previously verify the address. - octx + EP_CTX_OFFSET, - ), - ep_ctx.as_mut_dwords(), - )?; + let ep_ctx_addr = octx.checked_add(EP_CTX_OFFSET).with_context(|| { + format!( + "Endpoint Context access overflow, addr {:x} size {:x}", + octx.raw_value(), + EP_CTX_OFFSET + ) + })?; + dma_read_u32(&self.mem_space, ep_ctx_addr, ep_ctx.as_mut_dwords())?; ep_ctx.set_max_packet_size(iep_ctx.get_max_packet_size()); - dma_write_u32( - &self.mem_space, - // It is safe to use plus here because we previously verify the address. - GuestAddress(octx + EP_CTX_OFFSET), - ep_ctx.as_dwords(), - )?; + dma_write_u32(&self.mem_space, ep_ctx_addr, ep_ctx.as_dwords())?; } Ok(TRBCCode::Success) } @@ -1372,11 +1625,7 @@ impl XhciDevice { trace::usb_xhci_reset_device(&slot_id); let mut slot_ctx = XhciSlotCtx::default(); let octx = self.slots[(slot_id - 1) as usize].slot_ctx_addr; - dma_read_u32( - &self.mem_space, - GuestAddress(octx), - slot_ctx.as_mut_dwords(), - )?; + dma_read_u32(&self.mem_space, octx, slot_ctx.as_mut_dwords())?; let slot_state = (slot_ctx.dev_state >> SLOT_STATE_SHIFT) & SLOT_STATE_MASK; if slot_state != SLOT_ADDRESSED && slot_state != SLOT_CONFIGURED @@ -1391,7 +1640,7 @@ impl XhciDevice { slot_ctx.set_slot_state(SLOT_DEFAULT); slot_ctx.set_context_entry(1); slot_ctx.set_usb_device_address(0); - dma_write_u32(&self.mem_space, GuestAddress(octx), slot_ctx.as_dwords())?; + dma_write_u32(&self.mem_space, octx, slot_ctx.as_dwords())?; Ok(TRBCCode::Success) } @@ -1403,7 +1652,7 @@ impl XhciDevice { output_ctx: DmaAddr, ) -> Result { trace::usb_xhci_enable_endpoint(&slot_id, &ep_id); - let entry_offset = (ep_id - 1) as u64 * EP_INPUT_CTX_ENTRY_SIZE; + let entry_offset = u64::from(ep_id - 1) * EP_INPUT_CTX_ENTRY_SIZE; let mut ep_ctx = XhciEpCtx::default(); dma_read_u32( &self.mem_space, @@ -1417,7 +1666,7 @@ impl XhciDevice { epctx.epid = ep_id; epctx.enabled = true; // It is safe to use plus here because we previously verify the address on the outer layer. - epctx.init_ctx(output_ctx + EP_CTX_OFFSET + entry_offset, &ep_ctx); + epctx.init_ctx(output_ctx + EP_CTX_OFFSET + entry_offset, &ep_ctx)?; epctx.set_ep_state(EP_RUNNING); ep_ctx.ep_info &= !EP_STATE_MASK; ep_ctx.ep_info |= EP_RUNNING; @@ -1443,8 +1692,8 @@ impl XhciDevice { } self.cancel_all_ep_transfers(slot_id, ep_id, TRBCCode::Invalid)?; let epctx = &mut self.slots[(slot_id - 1) as usize].endpoints[(ep_id - 1) as usize]; - if self.oper.dcbaap != 0 { - epctx.set_state(EP_DISABLED)?; + if self.oper.dcbaap.raw_value() != 0 { + epctx.set_state(EP_DISABLED, None)?; } epctx.enabled = false; Ok(TRBCCode::Success) @@ -1480,7 +1729,11 @@ impl XhciDevice { slot_id, ep_id )); } - self.slots[(slot_id - 1) as usize].endpoints[(ep_id - 1) as usize].set_state(EP_STOPPED)?; + let epctx = &mut self.slots[(slot_id - 1) as usize].endpoints[(ep_id - 1) as usize]; + epctx.set_state(EP_STOPPED, None)?; + if epctx.max_pstreams != 0 { + epctx.reset_streams()?; + } Ok(TRBCCode::Success) } @@ -1511,7 +1764,7 @@ impl XhciDevice { let epctx = &mut slot.endpoints[(ep_id - 1) as usize]; if let Some(port) = &slot.usb_port { if port.lock().unwrap().dev.is_some() { - epctx.set_state(EP_STOPPED)?; + epctx.set_state(EP_STOPPED, None)?; } else { error!("Failed to found usb device"); return Ok(TRBCCode::UsbTransactionError); @@ -1520,6 +1773,9 @@ impl XhciDevice { error!("Failed to found port"); return Ok(TRBCCode::UsbTransactionError); } + if epctx.max_pstreams != 0 { + epctx.reset_streams()?; + } Ok(TRBCCode::Success) } @@ -1527,6 +1783,7 @@ impl XhciDevice { &mut self, slotid: u32, epid: u32, + streamid: u32, trb: &XhciTRB, ) -> Result { trace::usb_xhci_set_tr_dequeue(&slotid, &epid, &trb.parameter); @@ -1551,12 +1808,12 @@ impl XhciDevice { ); return Ok(TRBCCode::ContextStateError); } - epctx.update_dequeue(&self.mem_space, Some(trb.parameter))?; + epctx.update_dequeue(Some(trb.parameter), streamid)?; Ok(TRBCCode::Success) } /// Data plane - pub(crate) fn kick_endpoint(&mut self, slot_id: u32, ep_id: u32) -> Result<()> { + pub(crate) fn kick_endpoint(&mut self, slot_id: u32, ep_id: u32, stream_id: u32) -> Result<()> { let epctx = match self.get_endpoint_ctx(slot_id, ep_id) { Ok(epctx) => epctx, Err(e) => { @@ -1566,6 +1823,13 @@ impl XhciDevice { } }; + let ring = epctx.get_ring(stream_id).with_context(|| { + format!( + "Failed to kick Endpoint {}, no Transfer ring found on Stream ID {}", + ep_id, stream_id + ) + })?; + // If the device has been detached, but the guest has not been notified. // In this case, the Transaction Error is reported when the TRB processed. // Therefore, don't continue here. @@ -1573,11 +1837,16 @@ impl XhciDevice { return Ok(()); } - trace::usb_xhci_ep_kick(&slot_id, &ep_id, &epctx.ring.get_dequeue_ptr()); + self.check_bme_valid().map_err(|e| { + self.host_controller_error(); + e + })?; + + trace::usb_xhci_ep_kick(&slot_id, &ep_id, &ring.get_dequeue_ptr()); if self.slots[(slot_id - 1) as usize].endpoints[(ep_id - 1) as usize] .retry .is_some() - && !self.endpoint_retry_transfer(slot_id, ep_id)? + && !self.endpoint_retry_transfer(slot_id, ep_id, stream_id)? { // Return directly to retry again at the next kick. return Ok(()); @@ -1588,19 +1857,17 @@ impl XhciDevice { info!("xhci: endpoint halted"); return Ok(()); } - epctx.set_state(EP_RUNNING)?; - let ep_state = epctx.state.clone(); + epctx.set_state(EP_RUNNING, Some(stream_id))?; const KICK_LIMIT: u32 = 256; let mut count = 0; - let ring = epctx.ring.clone(); loop { let epctx = &mut self.slots[(slot_id - 1) as usize].endpoints[(ep_id - 1) as usize]; - let td = match epctx.ring.fetch_td()? { + let td = match ring.fetch_td()? { Some(td) => { trace::usb_xhci_unimplemented(&format!( "fetch transfer trb {:?} ring dequeue {:?}", td, - epctx.ring.get_dequeue_ptr(), + ring.get_dequeue_ptr(), )); td } @@ -1613,7 +1880,7 @@ impl XhciDevice { let mut evt = XhciEvent::new(TRBType::ErTransfer, ccode); evt.slot_id = slot_id as u8; evt.ep_id = ep_id as u8; - evt.ptr = epctx.ring.dequeue.load(Ordering::Acquire); + evt.ptr = ring.get_dequeue_ptr().raw_value(); if let Err(e) = self.intrs[0].lock().unwrap().send_event(&evt) { error!("Failed to send event: {:?}", e); } @@ -1623,14 +1890,17 @@ impl XhciDevice { } }; let in_xfer = transfer_in_direction(ep_id as u8, &td, epctx.ep_type); + let mut epctx = epctx.clone(); + // NOTE: It is necessary to clear the transfer list here because otherwise it would + // result in an infinite cycle of destructor calls, leading to a stack overflow. + epctx.transfers.clear(); // NOTE: Only support primary interrupter now. let xfer = Arc::new(Mutex::new(XhciTransfer::new( - (slot_id, ep_id, epctx.ep_type), + (slot_id, ep_id, stream_id), + epctx, in_xfer, td, &self.intrs[0], - &ring, - &ep_state, ))); let packet = match self.setup_usb_packet(&xfer) { Ok(pkt) => pkt, @@ -1646,7 +1916,7 @@ impl XhciDevice { self.endpoint_do_transfer(&mut locked_xfer)?; let epctx = &mut self.slots[(slot_id - 1) as usize].endpoints[(ep_id - 1) as usize]; if locked_xfer.complete { - epctx.update_dequeue(&self.mem_space, None)?; + epctx.update_dequeue(None, stream_id)?; } else { epctx.transfers.push_back(xfer.clone()); } @@ -1695,7 +1965,12 @@ impl XhciDevice { /// Return Ok(true) if retry is done. /// Return Ok(false) if packet is need to retry again. /// Return Err() if retry failed. - fn endpoint_retry_transfer(&mut self, slot_id: u32, ep_id: u32) -> Result { + fn endpoint_retry_transfer( + &mut self, + slot_id: u32, + ep_id: u32, + stream_id: u32, + ) -> Result { let slot = &mut self.slots[(slot_id - 1) as usize]; // Safe because the retry is checked in the outer function call. let xfer = slot.endpoints[(ep_id - 1) as usize] @@ -1727,7 +2002,7 @@ impl XhciDevice { let epctx = &mut self.slots[(slot_id - 1) as usize].endpoints[(ep_id - 1) as usize]; if locked_xfer.complete { drop(locked_xfer); - epctx.update_dequeue(&self.mem_space, None)?; + epctx.update_dequeue(None, stream_id)?; epctx.flush_transfer(); } epctx.retry = None; @@ -1800,16 +2075,17 @@ impl XhciDevice { let epctx = &self.slots[(xfer.slotid - 1) as usize].endpoints[(xfer.epid - 1) as usize]; if xfer.td[0].control & TRB_TR_SIA != 0 { - let asap = ((mfindex as u32 + epctx.interval - 1) & !(epctx.interval - 1)) as u64; - if asap >= epctx.mfindex_last && asap <= epctx.mfindex_last + epctx.interval as u64 * 4 + let asap = u64::from((mfindex as u32 + epctx.interval - 1) & !(epctx.interval - 1)); + if asap >= epctx.mfindex_last + && asap <= epctx.mfindex_last + u64::from(epctx.interval) * 4 { - xfer.mfindex_kick = epctx.mfindex_last + epctx.interval as u64; + xfer.mfindex_kick = epctx.mfindex_last + u64::from(epctx.interval); } else { xfer.mfindex_kick = asap; } } else { xfer.mfindex_kick = - (((xfer.td[0].control >> TRB_TR_FRAMEID_SHIFT) & TRB_TR_FRAMEID_MASK) as u64) << 3; + u64::from((xfer.td[0].control >> TRB_TR_FRAMEID_SHIFT) & TRB_TR_FRAMEID_MASK) << 3; xfer.mfindex_kick |= mfindex & !(MFINDEX_WRAP_NUM - 1); if xfer.mfindex_kick + 0x100 < mfindex { xfer.mfindex_kick += MFINDEX_WRAP_NUM; @@ -1834,10 +2110,10 @@ impl XhciDevice { } }; let ep_state = epctx.get_ep_state(); - if ep_state == EP_STOPPED && ep_state == EP_ERROR { + if ep_state == EP_STOPPED || ep_state == EP_ERROR { return; } - if let Err(e) = locked_xhci.kick_endpoint(slotid, epid) { + if let Err(e) = locked_xhci.kick_endpoint(slotid, epid, 0) { error!("Failed to kick endpoint: {:?}", e); } }); @@ -1924,9 +2200,21 @@ impl XhciDevice { || trb_type == TRBType::TrNormal || trb_type == TRBType::TrIsoch { - let chunk = trb.status & TRB_TR_LEN_MASK; + let trb_len = trb.status & TRB_TR_LEN_MASK; + + // According to xHCI Spec 3.2.7/6.4.1/4.9.1, zero-length packet is required + // when exact multiple of max packet size is transferred, it is essential + // for proper stream termination in bulk/interrupt transfers, skip TRB + // submission for zero-byte transfers may end up in Device hanging, waiting + // for status phase completion. zero-length packet usually comes with the + // default address 0. In this case, the correct way to handle it is to skip + // the address translation and leave the iovec empty. + if trb_len == 0 { + continue; + } + let dma_addr = if trb.control & TRB_TR_IDT == TRB_TR_IDT { - if chunk > 8 && locked_xfer.in_xfer { + if trb_len > 8 && locked_xfer.in_xfer { bail!("Invalid immediate data TRB"); } trb.addr @@ -1934,15 +2222,35 @@ impl XhciDevice { trb.parameter }; - self.mem_space - .get_address_map(GuestAddress(dma_addr), chunk as u64, &mut vec)?; + self.mem_space.get_address_map( + &None, + GuestAddress(dma_addr), + u64::from(trb_len), + &mut vec, + )?; } } + let target_dev = + if let Ok(target_dev) = self.get_usb_dev(locked_xfer.slotid, locked_xfer.epid) { + Some(Arc::downgrade(&target_dev) as Weak>) + } else { + None + }; + let packet_id = self.generate_packet_id(); let (_, ep_number) = endpoint_id_to_number(locked_xfer.epid as u8); + let stream = locked_xfer.streamid; let xfer_ops = Arc::downgrade(xfer) as Weak>; - let packet = UsbPacket::new(packet_id, dir as u32, ep_number, vec, Some(xfer_ops)); + let packet = UsbPacket::new( + packet_id, + u32::from(dir), + ep_number, + stream, + vec, + Some(xfer_ops), + target_dev, + ); Ok(Arc::new(Mutex::new(packet))) } @@ -2019,6 +2327,15 @@ impl XhciDevice { if report != TRBCCode::Invalid { xfer.status = report; xfer.submit_transfer()?; + let locked_packet = xfer.packet.lock().unwrap(); + + if let Some(usb_dev) = locked_packet.target_dev.as_ref() { + if let Some(usb_dev) = usb_dev.clone().upgrade() { + drop(locked_packet); + let mut locked_usb_dev = usb_dev.lock().unwrap(); + locked_usb_dev.cancel_packet(&xfer.packet); + } + } } xfer.running_async = false; killed = 1; @@ -2027,7 +2344,7 @@ impl XhciDevice { if xfer.running_retry { if report != TRBCCode::Invalid { xfer.status = report; - xfer.report_transfer_error()?; + xfer.submit_transfer()?; } let epctx = &mut self.slots[(slotid - 1) as usize].endpoints[(ep_id - 1) as usize]; epctx.retry = None; @@ -2039,32 +2356,49 @@ impl XhciDevice { } /// Used for device to wakeup endpoint - pub fn wakeup_endpoint(&mut self, slot_id: u32, ep: &UsbEndpoint) -> Result<()> { - let ep_id = endpoint_number_to_id(ep.in_direction, ep.ep_number); - if let Err(e) = self.get_endpoint_ctx(slot_id, ep_id as u32) { + pub fn wakeup_endpoint(&mut self, slot_id: u32, ep_id: u32, stream_id: u32) -> Result<()> { + if let Err(e) = self.get_endpoint_ctx(slot_id, ep_id) { trace::usb_xhci_unimplemented(&format!( "Invalid slot id or ep id, maybe device not activated, {:?}", e )); return Ok(()); } - self.kick_endpoint(slot_id, ep_id as u32)?; + self.kick_endpoint(slot_id, ep_id, stream_id)?; Ok(()) } pub(crate) fn reset_event_ring(&mut self, idx: u32) -> Result<()> { let mut locked_intr = self.intrs[idx as usize].lock().unwrap(); - if locked_intr.erstsz == 0 || locked_intr.erstba == 0 { - locked_intr.er_start = 0; + if locked_intr.erstsz == 0 || locked_intr.erstba.raw_value() == 0 { + locked_intr.er_start = GuestAddress(0); locked_intr.er_size = 0; return Ok(()); } + + self.check_bme_valid().map_err(|e| { + self.host_controller_error(); + e + })?; + let mut seg = XhciEventRingSeg::new(&self.mem_space); seg.fetch_event_ring_seg(locked_intr.erstba)?; if seg.size < 16 || seg.size > 4096 { bail!("Invalid segment size {}", seg.size); } - locked_intr.er_start = addr64_from_u32(seg.addr_lo, seg.addr_hi); + + // GPAChecked: the event ring must locate in guest ram. + let base_addr = GuestAddress(addr64_from_u32(seg.addr_lo, seg.addr_hi)); + // SAFETY: seg size is a 16 bit register, will not overflow. + let er_len = seg.size * TRB_SIZE; + if !self + .mem_space + .address_in_memory(base_addr, u64::from(er_len)) + { + bail!("The event ring does not locate in guest ram"); + } + + locked_intr.er_start = base_addr; locked_intr.er_size = seg.size; locked_intr.er_ep_idx = 0; locked_intr.er_pcs = true; @@ -2111,6 +2445,13 @@ impl XhciDevice { } None } + + fn check_bme_valid(&self) -> Result<()> { + if !self.bme.load(Ordering::SeqCst) { + bail!("BME is cleared.") + } + Ok(()) + } } fn usb_packet_status_to_trb_code(status: UsbPacketStatus) -> Result { @@ -2133,12 +2474,14 @@ pub fn dma_read_bytes( mut buf: &mut [u8], ) -> Result<()> { let len = buf.len() as u64; - addr_space.read(&mut buf, addr, len).with_context(|| { - format!( - "Failed to read dma memory at gpa=0x{:x} len=0x{:x}", - addr.0, len - ) - })?; + addr_space + .read(&mut buf, addr, len, AddressAttr::Ram) + .with_context(|| { + format!( + "Failed to read dma memory at gpa=0x{:x} len=0x{:x}", + addr.0, len + ) + })?; Ok(()) } @@ -2148,12 +2491,14 @@ pub fn dma_write_bytes( mut buf: &[u8], ) -> Result<()> { let len = buf.len() as u64; - addr_space.write(&mut buf, addr, len).with_context(|| { - format!( - "Failed to write dma memory at gpa=0x{:x} len=0x{:x}", - addr.0, len - ) - })?; + addr_space + .write(&mut buf, addr, len, AddressAttr::Ram) + .with_context(|| { + format!( + "Failed to write dma memory at gpa=0x{:x} len=0x{:x}", + addr.0, len + ) + })?; Ok(()) } @@ -2195,7 +2540,7 @@ pub fn dma_write_u32( } fn addr64_from_u32(low: u32, high: u32) -> u64 { - (((high << 16) as u64) << 16) | low as u64 + (u64::from(high) << 32) | u64::from(low) } // | ep id | < = > | ep direction | ep number | @@ -2206,7 +2551,7 @@ fn endpoint_id_to_number(ep_id: u8) -> (bool, u8) { (ep_id & 1 == 1, ep_id >> 1) } -fn endpoint_number_to_id(in_direction: bool, ep_number: u8) -> u8 { +pub fn endpoint_number_to_id(in_direction: bool, ep_number: u8) -> u8 { if ep_number == 0 { // Control endpoint. 1 diff --git a/devices/src/usb/xhci/xhci_pci.rs b/devices/src/usb/xhci/xhci_pci.rs index 0b3373cd2bffe5ed8f894f1a706894a7f3e22e67..493a501b03b54d35ee83f8d54fe8202b904b235f 100644 --- a/devices/src/usb/xhci/xhci_pci.rs +++ b/devices/src/usb/xhci/xhci_pci.rs @@ -14,7 +14,7 @@ use std::cmp::max; use std::os::unix::io::AsRawFd; use std::os::unix::prelude::RawFd; use std::rc::Rc; -use std::sync::atomic::{AtomicU16, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU16, Ordering}; use std::sync::{Arc, Mutex, Weak}; use anyhow::{bail, Context, Result}; @@ -34,12 +34,14 @@ use crate::pci::config::{ }; use crate::pci::{init_intx, init_msix, le_write_u16, PciBus, PciDevBase, PciDevOps}; use crate::usb::UsbDevice; -use crate::{Device, DeviceBase}; +use crate::{convert_bus_ref, Bus, Device, DeviceBase, PCI_BUS}; use address_space::{AddressRange, AddressSpace, Region, RegionIoEventFd}; use machine_manager::config::{get_pci_df, valid_id}; use machine_manager::event_loop::register_event_helper; +use util::gen_base_func; use util::loop_context::{ - read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, NotifierOperation, + create_new_eventfd, read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, + NotifierOperation, }; /// 5.2 PCI Configuration Registers(USB) @@ -67,8 +69,10 @@ const XHCI_MSIX_PBA_OFFSET: u32 = 0x3800; /// XHCI controller configuration. #[derive(Parser, Clone, Debug, Default)] -#[command(name = "nec-usb-xhci")] +#[command(no_binary_name(true))] pub struct XhciConfig { + #[arg(long)] + pub classtype: String, #[arg(long, value_parser = valid_id)] id: Option, #[arg(long)] @@ -104,23 +108,24 @@ impl XhciPciDevice { pub fn new( config: &XhciConfig, devfn: u8, - parent_bus: Weak>, + parent_bus: Weak>, mem_space: &Arc, ) -> Self { + let bme = Arc::new(AtomicBool::new(false)); Self { base: PciDevBase { - base: DeviceBase::new(config.id.clone().unwrap(), true), - config: PciConfig::new(PCI_CONFIG_SPACE_SIZE, 1), + base: DeviceBase::new(config.id.clone().unwrap(), true, Some(parent_bus)), + config: PciConfig::new(devfn, PCI_CONFIG_SPACE_SIZE, 1), devfn, - parent_bus, + bme: bme.clone(), }, - xhci: XhciDevice::new(mem_space, config), + xhci: XhciDevice::new(mem_space, config, &bme), dev_id: Arc::new(AtomicU16::new(0)), mem_region: Region::init_container_region( - XHCI_PCI_CONFIG_LENGTH as u64, + u64::from(XHCI_PCI_CONFIG_LENGTH), "XhciPciContainer", ), - doorbell_fd: Arc::new(EventFd::new(libc::EFD_NONBLOCK).unwrap()), + doorbell_fd: Arc::new(create_new_eventfd().unwrap()), delete_evts: Vec::new(), iothread: config.iothread.clone(), } @@ -128,57 +133,57 @@ impl XhciPciDevice { fn mem_region_init(&mut self) -> Result<()> { let cap_region = Region::init_io_region( - XHCI_PCI_CAP_LENGTH as u64, + u64::from(XHCI_PCI_CAP_LENGTH), build_cap_ops(&self.xhci), "XhciPciCapRegion", ); self.mem_region - .add_subregion(cap_region, XHCI_PCI_CAP_OFFSET as u64) + .add_subregion(cap_region, u64::from(XHCI_PCI_CAP_OFFSET)) .with_context(|| "Failed to register cap region.")?; let mut oper_region = Region::init_io_region( - XHCI_PCI_OPER_LENGTH as u64, + u64::from(XHCI_PCI_OPER_LENGTH), build_oper_ops(&self.xhci), "XhciPciOperRegion", ); oper_region.set_access_size(4); self.mem_region - .add_subregion(oper_region, XHCI_PCI_OPER_OFFSET as u64) + .add_subregion(oper_region, u64::from(XHCI_PCI_OPER_OFFSET)) .with_context(|| "Failed to register oper region.")?; let port_num = self.xhci.lock().unwrap().usb_ports.len(); for i in 0..port_num { let port = &self.xhci.lock().unwrap().usb_ports[i]; let port_region = Region::init_io_region( - XHCI_PCI_PORT_LENGTH as u64, + u64::from(XHCI_PCI_PORT_LENGTH), build_port_ops(port), "XhciPciPortRegion", ); - let offset = (XHCI_PCI_PORT_OFFSET + XHCI_PCI_PORT_LENGTH * i as u32) as u64; + let offset = u64::from(XHCI_PCI_PORT_OFFSET + XHCI_PCI_PORT_LENGTH * i as u32); self.mem_region .add_subregion(port_region, offset) .with_context(|| "Failed to register port region.")?; } let mut runtime_region = Region::init_io_region( - XHCI_PCI_RUNTIME_LENGTH as u64, + u64::from(XHCI_PCI_RUNTIME_LENGTH), build_runtime_ops(&self.xhci), "XhciPciRuntimeRegion", ); runtime_region.set_access_size(4); self.mem_region - .add_subregion(runtime_region, XHCI_PCI_RUNTIME_OFFSET as u64) + .add_subregion(runtime_region, u64::from(XHCI_PCI_RUNTIME_OFFSET)) .with_context(|| "Failed to register runtime region.")?; let doorbell_region = Region::init_io_region( - XHCI_PCI_DOORBELL_LENGTH as u64, + u64::from(XHCI_PCI_DOORBELL_LENGTH), build_doorbell_ops(&self.xhci), "XhciPciDoorbellRegion", ); doorbell_region.set_ioeventfds(&self.ioeventfds()); self.mem_region - .add_subregion(doorbell_region, XHCI_PCI_DOORBELL_OFFSET as u64) + .add_subregion(doorbell_region, u64::from(XHCI_PCI_DOORBELL_OFFSET)) .with_context(|| "Failed to register doorbell region.")?; Ok(()) } @@ -234,25 +239,16 @@ impl XhciPciDevice { } impl Device for XhciPciDevice { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base - } -} - -impl PciDevOps for XhciPciDevice { - fn pci_base(&self) -> &PciDevBase { - &self.base - } - - fn pci_base_mut(&mut self) -> &mut PciDevBase { - &mut self.base + fn reset(&mut self, _reset_child_device: bool) -> Result<()> { + self.xhci.lock().unwrap().reset(); + self.base.config.reset()?; + self.base.bme.store(false, Ordering::SeqCst); + Ok(()) } - fn realize(mut self) -> Result<()> { + fn realize(mut self) -> Result>> { self.init_write_mask(false)?; self.init_write_clear_mask(false)?; le_write_u16( @@ -281,7 +277,8 @@ impl PciDevOps for XhciPciDevice { PCI_SERIAL_BUS_RELEASE_VERSION_3_0; self.base.config.config[PCI_FRAME_LENGTH_ADJUSTMENT as usize] = PCI_NO_FRAME_LENGTH_TIMING_CAP; - self.dev_id.store(self.base.devfn as u16, Ordering::SeqCst); + self.dev_id + .store(u16::from(self.base.devfn), Ordering::SeqCst); self.mem_region_init()?; let handler = Arc::new(Mutex::new(DoorbellHandler::new( @@ -305,14 +302,15 @@ impl PciDevOps for XhciPciDevice { Some((XHCI_MSIX_TABLE_OFFSET, XHCI_MSIX_PBA_OFFSET)), )?; + let parent_bus = self.parent_bus().unwrap(); init_intx( self.name(), &mut self.base.config, - self.base.parent_bus.clone(), + parent_bus, self.base.devfn, )?; - let mut mem_region_size = (XHCI_PCI_CONFIG_LENGTH as u64).next_power_of_two(); + let mut mem_region_size = u64::from(XHCI_PCI_CONFIG_LENGTH).next_power_of_two(); mem_region_size = max(mem_region_size, MINIMUM_BAR_SIZE_FOR_MMIO as u64); self.base.config.register_bar( 0_usize, @@ -322,7 +320,7 @@ impl PciDevOps for XhciPciDevice { mem_region_size, )?; - let devfn = self.base.devfn; + let devfn = u64::from(self.base.devfn); // It is safe to unwrap, because it is initialized in init_msix. let cloned_msix = self.base.config.msix.as_ref().unwrap().clone(); let cloned_intx = self.base.config.intx.as_ref().unwrap().clone(); @@ -345,47 +343,39 @@ impl PciDevOps for XhciPciDevice { })); let dev = Arc::new(Mutex::new(self)); // Attach to the PCI bus. - let pci_bus = dev.lock().unwrap().base.parent_bus.upgrade().unwrap(); - let mut locked_pci_bus = pci_bus.lock().unwrap(); - let pci_device = locked_pci_bus.devices.get(&devfn); - if pci_device.is_none() { - locked_pci_bus.devices.insert(devfn, dev); - } else { - bail!( - "Devfn {:?} has been used by {:?}", - &devfn, - pci_device.unwrap().lock().unwrap().name() - ); - } - Ok(()) + let bus = dev.lock().unwrap().parent_bus().unwrap().upgrade().unwrap(); + bus.lock().unwrap().attach_child(devfn, dev.clone())?; + Ok(dev) } fn unrealize(&mut self) -> Result<()> { trace::usb_xhci_exit(); Ok(()) } +} + +impl PciDevOps for XhciPciDevice { + gen_base_func!(pci_base, pci_base_mut, PciDevBase, base); fn write_config(&mut self, offset: usize, data: &[u8]) { - let parent_bus = self.base.parent_bus.upgrade().unwrap(); - let locked_parent_bus = parent_bus.lock().unwrap(); - locked_parent_bus.update_dev_id(self.base.devfn, &self.dev_id); + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + PCI_BUS!(parent_bus, locked_bus, pci_bus); + pci_bus.update_dev_id(self.base.devfn, &self.dev_id); self.base.config.write( offset, data, self.dev_id.clone().load(Ordering::Acquire), #[cfg(target_arch = "x86_64")] - Some(&locked_parent_bus.io_region), - Some(&locked_parent_bus.mem_region), + Some(&pci_bus.io_region), + Some(&pci_bus.mem_region), ); - } - fn reset(&mut self, _reset_child_device: bool) -> Result<()> { - self.xhci.lock().unwrap().reset(); - - self.base.config.reset()?; - - Ok(()) + // Make sure synchronize with memory or I/O access. + let _locked_xhci = self.xhci.lock().unwrap(); + self.base + .bme + .store(self.base.config.bus_maser_enable(), Ordering::SeqCst); } } diff --git a/devices/src/usb/xhci/xhci_regs.rs b/devices/src/usb/xhci/xhci_regs.rs index 102b87f1dd8c660109a1856e9e558aa57f3aa9a6..c027e0d33751255cf65c4eb7a6a31fbbbcbfa3e1 100644 --- a/devices/src/usb/xhci/xhci_regs.rs +++ b/devices/src/usb/xhci/xhci_regs.rs @@ -103,6 +103,9 @@ const XHCI_INTR_REG_SHIFT: u64 = 5; /// Doorbell Register Bit Field. /// DB Target. const DB_TARGET_MASK: u32 = 0xff; +/// DB Stream. +const DB_STREAM_ID_SHIFT: u32 = 16; +const DB_STREAM_ID_MASK: u32 = 0xffff; /// Port Registers. const XHCI_PORTSC: u64 = 0x0; const XHCI_PORTPMSC: u64 = 0x4; @@ -121,7 +124,7 @@ pub struct XhciOperReg { /// Command Ring Control pub cmd_ring_ctrl: u64, /// Device Context Base Address Array Pointer - pub dcbaap: u64, + pub dcbaap: GuestAddress, /// Configure pub config: u32, } @@ -132,13 +135,13 @@ impl XhciOperReg { self.set_usb_status(USB_STS_HCH); self.dev_notify_ctrl = 0; self.cmd_ring_ctrl = 0; - self.dcbaap = 0; + self.dcbaap = GuestAddress(0); self.config = 0; } /// Run the command ring. pub fn start_cmd_ring(&mut self) { - self.cmd_ring_ctrl |= CMD_RING_CTRL_CRR as u64; + self.cmd_ring_ctrl |= u64::from(CMD_RING_CTRL_CRR); } pub fn set_usb_cmd(&mut self, value: u32) { @@ -157,7 +160,7 @@ impl XhciOperReg { self.usb_status.load(Ordering::Acquire) } - pub fn set_usb_status_flag(&mut self, value: u32) { + pub fn set_usb_status_flag(&self, value: u32) { self.usb_status.fetch_or(value, Ordering::SeqCst); } @@ -180,12 +183,12 @@ pub struct XhciInterrupter { /// Event Ring Segment Table Size pub erstsz: u32, /// Event Ring Segment Table Base Address - pub erstba: u64, + pub erstba: GuestAddress, /// Event Ring Dequeue Pointer - pub erdp: u64, + pub erdp: GuestAddress, /// Event Ring Producer Cycle State pub er_pcs: bool, - pub er_start: u64, + pub er_start: GuestAddress, pub er_size: u32, pub er_ep_idx: u32, } @@ -206,10 +209,10 @@ impl XhciInterrupter { iman: 0, imod: 0, erstsz: 0, - erstba: 0, - erdp: 0, + erstba: GuestAddress(0), + erdp: GuestAddress(0), er_pcs: true, - er_start: 0, + er_start: GuestAddress(0), er_size: 0, er_ep_idx: 0, } @@ -232,10 +235,10 @@ impl XhciInterrupter { self.iman = 0; self.imod = 0; self.erstsz = 0; - self.erstba = 0; - self.erdp = 0; + self.erstba = GuestAddress(0); + self.erdp = GuestAddress(0); self.er_pcs = true; - self.er_start = 0; + self.er_start = GuestAddress(0); self.er_size = 0; self.er_ep_idx = 0; } @@ -244,25 +247,27 @@ impl XhciInterrupter { pub fn send_event(&mut self, evt: &XhciEvent) -> Result<()> { let er_end = self .er_start - .checked_add((TRB_SIZE * self.er_size) as u64) - .ok_or(UsbError::MemoryAccessOverflow( - self.er_start, - (TRB_SIZE * self.er_size) as u64, - ))?; + .checked_add(u64::from(TRB_SIZE * self.er_size)) + .ok_or_else(|| { + UsbError::MemoryAccessOverflow( + self.er_start.raw_value(), + u64::from(TRB_SIZE * self.er_size), + ) + })?; if self.erdp < self.er_start || self.erdp >= er_end { bail!( - "DMA out of range, erdp {} er_start {:x} er_size {}", - self.erdp, - self.er_start, + "DMA out of range, erdp {:x} er_start {:x} er_size {}", + self.erdp.raw_value(), + self.er_start.raw_value(), self.er_size ); } - let dp_idx = (self.erdp - self.er_start) / TRB_SIZE as u64; - if ((self.er_ep_idx + 2) % self.er_size) as u64 == dp_idx { + let dp_idx = (self.erdp.raw_value() - self.er_start.raw_value()) / u64::from(TRB_SIZE); + if u64::from((self.er_ep_idx + 2) % self.er_size) == dp_idx { debug!("Event ring full error, idx {}", dp_idx); let event = XhciEvent::new(TRBType::ErHostController, TRBCCode::EventRingFullError); self.write_event(&event)?; - } else if ((self.er_ep_idx + 1) % self.er_size) as u64 == dp_idx { + } else if u64::from((self.er_ep_idx + 1) % self.er_size) == dp_idx { debug!("Event Ring full, drop Event."); } else { self.write_event(evt)?; @@ -272,10 +277,11 @@ impl XhciInterrupter { } fn send_intr(&mut self) { - let pending = read_u32(self.erdp, 0) & ERDP_EHB == ERDP_EHB; - let mut erdp_low = read_u32(self.erdp, 0); + let erdp = self.erdp.raw_value(); + let pending = read_u32(erdp, 0) & ERDP_EHB == ERDP_EHB; + let mut erdp_low = read_u32(erdp, 0); erdp_low |= ERDP_EHB; - self.erdp = write_u64_low(self.erdp, erdp_low); + self.erdp = GuestAddress(write_u64_low(erdp, erdp_low)); self.iman |= IMAN_IP; self.enable_intr(); if pending { @@ -331,11 +337,13 @@ impl XhciInterrupter { fn write_trb(&mut self, trb: &XhciTRB) -> Result<()> { let addr = self .er_start - .checked_add((TRB_SIZE * self.er_ep_idx) as u64) - .ok_or(UsbError::MemoryAccessOverflow( - self.er_start, - (TRB_SIZE * self.er_ep_idx) as u64, - ))?; + .checked_add(u64::from(TRB_SIZE * self.er_ep_idx)) + .ok_or_else(|| { + UsbError::MemoryAccessOverflow( + self.er_start.raw_value(), + u64::from(TRB_SIZE * self.er_ep_idx), + ) + })?; let cycle = trb.control as u8; // Toggle the cycle bit to avoid driver read it. let control = if trb.control & TRB_C == TRB_C { @@ -347,10 +355,10 @@ impl XhciInterrupter { LittleEndian::write_u64(&mut buf, trb.parameter); LittleEndian::write_u32(&mut buf[8..], trb.status); LittleEndian::write_u32(&mut buf[12..], control); - dma_write_bytes(&self.mem, GuestAddress(addr), &buf)?; + dma_write_bytes(&self.mem, addr, &buf)?; // Write the cycle bit at last. fence(Ordering::SeqCst); - dma_write_bytes(&self.mem, GuestAddress(addr + 12), &[cycle])?; + dma_write_bytes(&self.mem, addr.unchecked_add(12), &[cycle])?; Ok(()) } } @@ -368,7 +376,7 @@ pub fn build_cap_ops(xhci_dev: &Arc>) -> RegionOps { XHCI_VERSION << hci_version_offset | XHCI_CAP_LENGTH } XHCI_CAP_REG_HCSPARAMS1 => { - (max_ports as u32) << CAP_HCSP_NP_SHIFT + u32::from(max_ports) << CAP_HCSP_NP_SHIFT | max_intrs << CAP_HCSP_NI_SHIFT | (locked_dev.slots.len() as u32) << CAP_HCSP_NDS_SHIFT } @@ -378,7 +386,9 @@ pub fn build_cap_ops(xhci_dev: &Arc>) -> RegionOps { } XHCI_CAP_REG_HCSPARAMS3 => 0x0, XHCI_CAP_REG_HCCPARAMS1 => { - 0x8 << CAP_HCCP_EXCP_SHIFT | (0 << CAP_HCCP_MPSAS_SHIFT) | CAP_HCCP_AC64 + // The offset of the first extended capability is (base) + (0x8 << 2) + // The primary stream array size is 1 << (0x7 + 1) + 0x8 << CAP_HCCP_EXCP_SHIFT | (0x7 << CAP_HCCP_MPSAS_SHIFT) | CAP_HCCP_AC64 } XHCI_CAP_REG_DBOFF => XHCI_OFF_DOORBELL, XHCI_CAP_REG_RTSOFF => XHCI_OFF_RUNTIME, @@ -387,18 +397,18 @@ pub fn build_cap_ops(xhci_dev: &Arc>) -> RegionOps { 0x20 => { CAP_EXT_USB_REVISION_2_0 << CAP_EXT_REVISION_SHIFT | 0x4 << CAP_EXT_NEXT_CAP_POINTER_SHIFT - | CAP_EXT_CAP_ID_SUPPORT_PROTOCOL as u32 + | u32::from(CAP_EXT_CAP_ID_SUPPORT_PROTOCOL) } 0x24 => CAP_EXT_USB_NAME_STRING, - 0x28 => ((locked_dev.numports_2 as u32) << 8) | 1, + 0x28 => (u32::from(locked_dev.numports_2) << 8) | 1, 0x2c => 0x0, // Extended capabilities (USB 3.0) 0x30 => { CAP_EXT_USB_REVISION_3_0 << CAP_EXT_REVISION_SHIFT - | CAP_EXT_CAP_ID_SUPPORT_PROTOCOL as u32 + | u32::from(CAP_EXT_CAP_ID_SUPPORT_PROTOCOL) } 0x34 => CAP_EXT_USB_NAME_STRING, - 0x38 => ((locked_dev.numports_3 as u32) << 8) | (locked_dev.numports_2 + 1) as u32, + 0x38 => (u32::from(locked_dev.numports_3) << 8) | u32::from(locked_dev.numports_2 + 1), 0x3c => 0x0, _ => { error!("Failed to read xhci cap: not implemented"); @@ -443,8 +453,8 @@ pub fn build_oper_ops(xhci_dev: &Arc>) -> RegionOps { // Table 5-24 shows read CRP always returns 0. 0 } - XHCI_OPER_REG_DCBAAP_LO => read_u32(locked_xhci.oper.dcbaap, 0), - XHCI_OPER_REG_DCBAAP_HI => read_u32(locked_xhci.oper.dcbaap, 1), + XHCI_OPER_REG_DCBAAP_LO => read_u32(locked_xhci.oper.dcbaap.raw_value(), 0), + XHCI_OPER_REG_DCBAAP_HI => read_u32(locked_xhci.oper.dcbaap.raw_value(), 1), XHCI_OPER_REG_CONFIG => locked_xhci.oper.config, _ => { error!( @@ -506,7 +516,7 @@ pub fn build_oper_ops(xhci_dev: &Arc>) -> RegionOps { write_u64_low(locked_xhci.oper.cmd_ring_ctrl, crc_lo); } XHCI_OPER_REG_CMD_RING_CTRL_HI => { - let crc_hi = (value as u64) << 32; + let crc_hi = u64::from(value) << 32; let mut crc_lo = read_u32(locked_xhci.oper.cmd_ring_ctrl, 0); if crc_lo & (CMD_RING_CTRL_CA | CMD_RING_CTRL_CS) != 0 && (crc_lo & CMD_RING_CTRL_CRR) == CMD_RING_CTRL_CRR @@ -518,17 +528,19 @@ pub fn build_oper_ops(xhci_dev: &Arc>) -> RegionOps { error!("Failed to send event: {:?}", e); } } else { - let addr = (crc_hi | crc_lo as u64) & XHCI_CRCR_CRP_MASK; + let addr = (crc_hi | u64::from(crc_lo)) & XHCI_CRCR_CRP_MASK; locked_xhci.cmd_ring.init(addr); } crc_lo &= !(CMD_RING_CTRL_CA | CMD_RING_CTRL_CS); locked_xhci.oper.cmd_ring_ctrl = write_u64_low(crc_hi, crc_lo); } XHCI_OPER_REG_DCBAAP_LO => { - locked_xhci.oper.dcbaap = write_u64_low(locked_xhci.oper.dcbaap, value & 0xffffffc0) + let dcbaap = write_u64_low(locked_xhci.oper.dcbaap.raw_value(), value & 0xffffffc0); + locked_xhci.oper.dcbaap = GuestAddress(dcbaap); } XHCI_OPER_REG_DCBAAP_HI => { - locked_xhci.oper.dcbaap = write_u64_high(locked_xhci.oper.dcbaap, value) + let dcbaap = write_u64_high(locked_xhci.oper.dcbaap.raw_value(), value); + locked_xhci.oper.dcbaap = GuestAddress(dcbaap); } XHCI_OPER_REG_CONFIG => locked_xhci.oper.config = value & 0xff, _ => { @@ -572,10 +584,10 @@ pub fn build_runtime_ops(xhci_dev: &Arc>) -> RegionOps { XHCI_INTR_REG_IMAN => locked_intr.iman, XHCI_INTR_REG_IMOD => locked_intr.imod, XHCI_INTR_REG_ERSTSZ => locked_intr.erstsz, - XHCI_INTR_REG_ERSTBA_LO => read_u32(locked_intr.erstba, 0), - XHCI_INTR_REG_ERSTBA_HI => read_u32(locked_intr.erstba, 1), - XHCI_INTR_REG_ERDP_LO => read_u32(locked_intr.erdp, 0), - XHCI_INTR_REG_ERDP_HI => read_u32(locked_intr.erdp, 1), + XHCI_INTR_REG_ERSTBA_LO => read_u32(locked_intr.erstba.raw_value(), 0), + XHCI_INTR_REG_ERSTBA_HI => read_u32(locked_intr.erstba.raw_value(), 1), + XHCI_INTR_REG_ERDP_LO => read_u32(locked_intr.erdp.raw_value(), 0), + XHCI_INTR_REG_ERDP_HI => read_u32(locked_intr.erdp.raw_value(), 1), _ => { error!( "Invalid offset {:x} for reading interrupter registers.", @@ -618,10 +630,12 @@ pub fn build_runtime_ops(xhci_dev: &Arc>) -> RegionOps { XHCI_INTR_REG_IMOD => locked_intr.imod = value, XHCI_INTR_REG_ERSTSZ => locked_intr.erstsz = value & 0xffff, XHCI_INTR_REG_ERSTBA_LO => { - locked_intr.erstba = write_u64_low(locked_intr.erstba, value & 0xffffffc0); + let erstba = write_u64_low(locked_intr.erstba.raw_value(), value & 0xffffffc0); + locked_intr.erstba = GuestAddress(erstba); } XHCI_INTR_REG_ERSTBA_HI => { - locked_intr.erstba = write_u64_high(locked_intr.erstba, value); + let erstba = GuestAddress(write_u64_high(locked_intr.erstba.raw_value(), value)); + locked_intr.erstba = erstba; drop(locked_intr); if let Err(e) = xhci.reset_event_ring(idx) { error!("Failed to reset event ring: {:?}", e); @@ -631,29 +645,31 @@ pub fn build_runtime_ops(xhci_dev: &Arc>) -> RegionOps { // ERDP_EHB is write 1 clear. let mut erdp_lo = value & !ERDP_EHB; if value & ERDP_EHB != ERDP_EHB { - let erdp_old = read_u32(locked_intr.erdp, 0); + let erdp_old = read_u32(locked_intr.erdp.raw_value(), 0); erdp_lo |= erdp_old & ERDP_EHB; } - locked_intr.erdp = write_u64_low(locked_intr.erdp, erdp_lo); + let erdp = write_u64_low(locked_intr.erdp.raw_value(), erdp_lo); + locked_intr.erdp = GuestAddress(erdp); if value & ERDP_EHB == ERDP_EHB { let erdp = locked_intr.erdp; let er_end = if let Some(addr) = locked_intr .er_start - .checked_add((TRB_SIZE * locked_intr.er_size) as u64) + .checked_add(u64::from(TRB_SIZE * locked_intr.er_size)) { addr } else { error!( "Memory access overflow, addr {:x} offset {:x}", - locked_intr.er_start, - (TRB_SIZE * locked_intr.er_size) as u64 + locked_intr.er_start.raw_value(), + u64::from(TRB_SIZE * locked_intr.er_size) ); return false; }; if erdp >= locked_intr.er_start && erdp < er_end - && (erdp - locked_intr.er_start) / TRB_SIZE as u64 - != locked_intr.er_ep_idx as u64 + && (erdp.raw_value() - locked_intr.er_start.raw_value()) + / u64::from(TRB_SIZE) + != u64::from(locked_intr.er_ep_idx) { drop(locked_intr); xhci.intrs[idx as usize].lock().unwrap().send_intr(); @@ -661,7 +677,8 @@ pub fn build_runtime_ops(xhci_dev: &Arc>) -> RegionOps { } } XHCI_INTR_REG_ERDP_HI => { - locked_intr.erdp = write_u64_high(locked_intr.erdp, value); + let erdp = write_u64_high(locked_intr.erdp.raw_value(), value); + locked_intr.erdp = GuestAddress(erdp); } _ => { error!( @@ -692,18 +709,19 @@ pub fn build_doorbell_ops(xhci_dev: &Arc>) -> RegionOps { if !read_data_u32(data, &mut value) { return false; } - if !xhci.lock().unwrap().running() { + let mut xhci = xhci.lock().unwrap(); + if !xhci.running() { error!("Failed to write doorbell, XHCI is not running"); return false; } - let mut xhci = xhci.lock().unwrap(); let slot_id = (offset >> 2) as u32; if slot_id == 0 { error!("Invalid slot id 0 !"); return false; } else { let ep_id = value & DB_TARGET_MASK; - if let Err(e) = xhci.kick_endpoint(slot_id, ep_id) { + let stream_id = (value >> DB_STREAM_ID_SHIFT) & DB_STREAM_ID_MASK; + if let Err(e) = xhci.kick_endpoint(slot_id, ep_id, stream_id) { error!("Failed to kick endpoint: {:?}", e); xhci.host_controller_error(); return false; diff --git a/devices/src/usb/xhci/xhci_ring.rs b/devices/src/usb/xhci/xhci_ring.rs index f37135721740fe68aea05329c4579aafd5ed980e..35cda571b5bf6b29f415ec3f3d4e8d54eed67c91 100644 --- a/devices/src/usb/xhci/xhci_ring.rs +++ b/devices/src/usb/xhci/xhci_ring.rs @@ -10,10 +10,10 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use std::sync::atomic::{fence, AtomicBool, AtomicU64, Ordering}; -use std::sync::Arc; +use std::sync::atomic::{fence, AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; -use anyhow::{bail, Context, Result}; +use anyhow::{bail, Result}; use byteorder::{ByteOrder, LittleEndian}; use log::debug; @@ -54,7 +54,7 @@ impl XhciTRB { #[derive(Clone)] pub struct XhciCommandRing { mem: Arc, - pub dequeue: u64, + pub dequeue: GuestAddress, /// Consumer Cycle State pub ccs: bool, } @@ -63,13 +63,13 @@ impl XhciCommandRing { pub fn new(mem: &Arc) -> Self { Self { mem: mem.clone(), - dequeue: 0, + dequeue: GuestAddress(0), ccs: true, } } pub fn init(&mut self, addr: u64) { - self.dequeue = addr; + self.dequeue = GuestAddress(addr); self.ccs = true; } @@ -87,7 +87,7 @@ impl XhciCommandRing { } fence(Ordering::Acquire); let mut trb = read_trb(&self.mem, self.dequeue)?; - trb.addr = self.dequeue; + trb.addr = self.dequeue.raw_value(); trb.ccs = self.ccs; let trb_type = trb.get_type(); debug!("Fetch TRB: type {:?} trb {:?}", trb_type, trb); @@ -96,14 +96,20 @@ impl XhciCommandRing { if link_cnt > TRB_LINK_LIMIT { bail!("TRB reach link limit"); } - self.dequeue = trb.parameter; + self.dequeue = GuestAddress(trb.parameter); if trb.control & TRB_LK_TC == TRB_LK_TC { self.ccs = !self.ccs; } } else { - self.dequeue = self.dequeue.checked_add(TRB_SIZE as u64).ok_or( - UsbError::MemoryAccessOverflow(self.dequeue, TRB_SIZE as u64), - )?; + self.dequeue = self + .dequeue + .checked_add(u64::from(TRB_SIZE)) + .ok_or_else(|| { + UsbError::MemoryAccessOverflow( + self.dequeue.raw_value(), + u64::from(TRB_SIZE), + ) + })?; return Ok(Some(trb)); } } @@ -113,33 +119,31 @@ impl XhciCommandRing { /// XHCI Transfer Ring pub struct XhciTransferRing { pub mem: Arc, - pub dequeue: AtomicU64, + pub dequeue: Mutex, /// Consumer Cycle State pub ccs: AtomicBool, - pub output_ctx_addr: Arc, } impl XhciTransferRing { - pub fn new(mem: &Arc, addr: &Arc) -> Self { + pub fn new(mem: &Arc) -> Self { Self { mem: mem.clone(), - dequeue: AtomicU64::new(0), + dequeue: Mutex::new(GuestAddress(0)), ccs: AtomicBool::new(true), - output_ctx_addr: addr.clone(), } } pub fn init(&self, addr: u64) { - self.set_dequeue_ptr(addr); + self.set_dequeue_ptr(GuestAddress(addr)); self.set_cycle_bit(true); } - pub fn get_dequeue_ptr(&self) -> u64 { - self.dequeue.load(Ordering::Acquire) + pub fn get_dequeue_ptr(&self) -> GuestAddress { + *self.dequeue.lock().unwrap() } - pub fn set_dequeue_ptr(&self, addr: u64) { - self.dequeue.store(addr, Ordering::SeqCst); + pub fn set_dequeue_ptr(&self, addr: GuestAddress) { + *self.dequeue.lock().unwrap() = addr } pub fn get_cycle_bit(&self) -> bool { @@ -167,7 +171,7 @@ impl XhciTransferRing { } fence(Ordering::Acquire); let mut trb = read_trb(&self.mem, dequeue)?; - trb.addr = dequeue; + trb.addr = dequeue.raw_value(); trb.ccs = ccs; trace::usb_xhci_fetch_trb(&dequeue, &trb.parameter, &trb.status, &trb.control); let trb_type = trb.get_type(); @@ -176,15 +180,15 @@ impl XhciTransferRing { if link_cnt > TRB_LINK_LIMIT { bail!("TRB link over limit"); } - dequeue = trb.parameter; + dequeue = GuestAddress(trb.parameter); if trb.control & TRB_LK_TC == TRB_LK_TC { ccs = !ccs; } } else { td.push(trb); - dequeue = dequeue - .checked_add(TRB_SIZE as u64) - .ok_or(UsbError::MemoryAccessOverflow(dequeue, TRB_SIZE as u64))?; + dequeue = dequeue.checked_add(u64::from(TRB_SIZE)).ok_or_else(|| { + UsbError::MemoryAccessOverflow(dequeue.raw_value(), u64::from(TRB_SIZE)) + })?; if trb_type == TRBType::TrSetup { ctrl_td = true; } else if trb_type == TRBType::TrStatus { @@ -202,25 +206,24 @@ impl XhciTransferRing { } /// Refresh dequeue pointer to output context. - pub fn refresh_dequeue_ptr(&self) -> Result<()> { + pub fn refresh_dequeue_ptr(&self, output_ctx_addr: GuestAddress) -> Result<()> { let mut ep_ctx = XhciEpCtx::default(); - let output_addr = self.output_ctx_addr.load(Ordering::Acquire); - dma_read_u32(&self.mem, GuestAddress(output_addr), ep_ctx.as_mut_dwords())?; - self.update_dequeue_to_ctx(&mut ep_ctx); - dma_write_u32(&self.mem, GuestAddress(output_addr), ep_ctx.as_dwords())?; + dma_read_u32(&self.mem, output_ctx_addr, ep_ctx.as_mut_dwords())?; + self.update_dequeue_to_ctx(&mut ep_ctx.as_mut_dwords()[2..]); + dma_write_u32(&self.mem, output_ctx_addr, ep_ctx.as_dwords())?; Ok(()) } - pub fn update_dequeue_to_ctx(&self, ep_ctx: &mut XhciEpCtx) { - let dequeue = self.get_dequeue_ptr(); - ep_ctx.deq_lo = dequeue as u32 | self.get_cycle_bit() as u32; - ep_ctx.deq_hi = (dequeue >> 32) as u32; + pub fn update_dequeue_to_ctx(&self, ctx: &mut [u32]) { + let dequeue = self.get_dequeue_ptr().raw_value(); + ctx[0] = dequeue as u32 | u32::from(self.get_cycle_bit()); + ctx[1] = (dequeue >> 32) as u32; } } -fn read_trb(mem: &Arc, addr: u64) -> Result { +fn read_trb(mem: &Arc, addr: GuestAddress) -> Result { let mut buf = [0; TRB_SIZE as usize]; - dma_read_bytes(mem, GuestAddress(addr), &mut buf)?; + dma_read_bytes(mem, addr, &mut buf)?; let trb = XhciTRB { parameter: LittleEndian::read_u64(&buf), status: LittleEndian::read_u32(&buf[8..]), @@ -231,12 +234,12 @@ fn read_trb(mem: &Arc, addr: u64) -> Result { Ok(trb) } -fn read_cycle_bit(mem: &Arc, addr: u64) -> Result { +fn read_cycle_bit(mem: &Arc, addr: GuestAddress) -> Result { let addr = addr .checked_add(12) - .with_context(|| format!("Ring address overflow, {:x}", addr))?; + .ok_or_else(|| UsbError::MemoryAccessOverflow(addr.raw_value(), 12))?; let mut buf = [0]; - dma_read_u32(mem, GuestAddress(addr), &mut buf)?; + dma_read_u32(mem, addr, &mut buf)?; Ok(buf[0] & TRB_C == TRB_C) } @@ -261,9 +264,9 @@ impl XhciEventRingSeg { } /// Fetch the event ring segment. - pub fn fetch_event_ring_seg(&mut self, addr: u64) -> Result<()> { + pub fn fetch_event_ring_seg(&mut self, addr: GuestAddress) -> Result<()> { let mut buf = [0_u8; TRB_SIZE as usize]; - dma_read_bytes(&self.mem, GuestAddress(addr), &mut buf)?; + dma_read_bytes(&self.mem, addr, &mut buf)?; self.addr_lo = LittleEndian::read_u32(&buf); self.addr_hi = LittleEndian::read_u32(&buf[4..]); self.size = LittleEndian::read_u32(&buf[8..]); diff --git a/docs/config_guidebook.md b/docs/config_guidebook.md index bf205fc184ccfc87c0a4c90db23de7cedd563898..483cf9c4b20c2afa42766fd2300190035a773bd4 100644 --- a/docs/config_guidebook.md +++ b/docs/config_guidebook.md @@ -285,6 +285,14 @@ The SMBIOS specification defines the data structures and information that will e ``` +### 1.12 Hardware Signature +This option is used for configuring ACPI Hardware Signature, which is used for VM S4 state. It's an 32 bit integer. For more information, please refer to https://uefi.org/htmlspecs/ACPI_Spec_6_4_html/05_ACPI_Software_Programming_Model/ACPI_Software_Programming_Model.html#firmware-acpi-control-structure-facs-table. + +```shell +# cmdline +-hardware-signature 1 +``` + ## 2. Device Configuration For machine type "microvm", only virtio-mmio and legacy devices are supported. @@ -444,6 +452,8 @@ Eight properties are supported for virtio-net-device or virtio-net-pci. * id: unique net device id. * iothread: indicate which iothread will be used, if not specified the main thread will be used. It has no effect when vhost is set. +* rx-iothread: set the receiving task in this iothread, if not specified the former parameter iothread will be used. +* tx-iothread: set the sending task in this iothread, if not specified the former parameter iothread will be used. * netdev: netdev of net device. * vhost: whether to run as a vhost-net device. * vhostfd: the file descriptor of opened tap device. @@ -462,10 +472,10 @@ is a single function device, the function number should be set to zero. ```shell # virtio mmio net device -netdev tap,id=,ifname= --device virtio-net-device,id=,netdev=[,iothread=][,mac=] +-device virtio-net-device,id=,netdev=[,iothread=][,rx-iothread=][,tx-iothread=][,mac=] # virtio pci net device -netdev tap,id=,ifname=[,queues=] --device virtio-net-pci,id=,netdev=,bus=,addr=<0x2>[,multifunction={on|off}][,iothread=][,mac=][,mq={on|off}][,queue-size=] +-device virtio-net-pci,id=,netdev=,bus=,addr=<0x2>[,multifunction={on|off}][,iothread=][,rx-iothread=][,tx-iothread=][,mac=][,mq={on|off}][,queue-size=] ``` StratoVirt also supports vhost-net to get a higher performance in network. It can be set by @@ -929,6 +939,22 @@ Note: Please see the [4. Build with features](docs/build_guide.md) if you want to enable usb-host. +#### 2.13.7 USB Uas +USB Mass Storage Device that is based on the USB Attached Scsi (UAS) protocol. It should be attached to USB controller. + +Three properties can be set for USB Uas. + +* id: unique device id. +* file: the path of backend image file. +* media: the media type of storage. Possible values are `disk` or `cdrom`. If not set, default is `disk`. + +```shell +-device usb-uas,drive=,id= +-drive id=,file=[,media={disk|cdrom}],aio=off,direct=false +``` + +Note: "aio=off,direct=false" must be configured and other aio/direct values are not supported. + ### 2.14 Virtio Scsi Controller Virtio Scsi controller is a pci device which can be attached scsi device. @@ -938,7 +964,7 @@ Six properties can be set for Virtio-Scsi controller. * bus: bus number of the device. * addr: including slot number and function number. * iothread: indicate which iothread will be used, if not specified the main thread will be used. (optional) -* num-queues: the optional num-queues attribute controls the number of request queues to be used for the scsi controller. If not set, the default block queue number is 1. The max queues number supported is no more than 32. (optional) +* num-queues: the optional num-queues attribute controls the number of request queues to be used for the scsi controller. If not set, the default queue number is the smaller one of vCPU count and the max queues number (e.g, min(vcpu_count, 32)). The max queues number supported is no more than 32. (optional) * queue-size: the optional virtqueue size for all the queues. Configuration range is (2, 1024] and queue size must be power of 2. Default queue size is 256. ```shell -device virtio-scsi-pci,id=,bus=,addr=<0x3>[,multifunction={on|off}][,iothread=][,num-queues=][,queue-size=] @@ -1179,16 +1205,41 @@ Sample Configuration: Please see the [4. Build with features](docs/build_guide.md) if you want to enable pvpanic. +### 2.22 virtio-input +virtio-input is a virtualized input device can be used to create human interface devices such as tablet, mouse. + +Five properties are supported for virtio-input. +* id: unique device id. +* evdev: the path of character evdev device in host. + +For virtio-input-pci, two more properties are required. +* bus: name of bus which to attach. +* addr: including slot number and function number. the first number represents slot number +of device and the second one represents function number of it. As virtio pci input device is a +single function device, the function number should be set to zero. + +Sample Configuration: +```shell +# virtio mmio input device +-device virtio-input-device,id=,evdev= +# virtio pci input device +-device virtio-input-pci,id=,evdev=,bus=,addr=<0x1>[,multifunction=on|off] +``` + +Note: +1. Only host evdev passthrough supported. + ## 3. Trace -Users can specify the configuration file which lists events to trace. +Users can specify a configuration file which lists the traces that needs to be enabled, or specify the trace type that needs to be enabled. Setting both file and type is also allowed, so that traces with the specified type and traces listed in the file will all be enabled. One property can be set: -* events: file lists events to trace. +* file: specify the file containing the traces that needs to be enabled. +* type: specify the traces type that needs to be enabled. ```shell --trace file= +-trace file=|type= ``` ## 4. Seccomp diff --git a/docs/hisysevent.md b/docs/hisysevent.md new file mode 100644 index 0000000000000000000000000000000000000000..b92ce4fe4d2c41328110ad5a5cb9018d50ebc078 --- /dev/null +++ b/docs/hisysevent.md @@ -0,0 +1,40 @@ +# HiSysEvent + +HiSysEvent(https://gitee.com/openharmony/hiviewdfx_hisysevent) is a tool in open- +harmonyOS to recode important information of key processes during system running, +helping locate faults and do some data analytics. + +This document describes the way to how to use hisysevent in StratoVirt. + +## Add Event + +### Modify configuration file + +First, you need to modify or creat toml file in the event/event_info directory +to add a new event in order to generate the event function. For example: + +```toml +[[events]] +name = "example" +event_type = "Behavior" +args = "example_bool: bool, example_str: String, example_integer: u64, example_array: &[u8]" +enable = true +``` + +In the above configuration, "name" is used to represent the only event, and +duplication is not allowed; "event_type" is one of four event type defined +by openharmonyOS: Fault, Statistic, Security and Behavior; "args" will be +formatted as arguments passed to hisysevent service in open-harmonyOS; +"enabled" indicates whether it is enabled during compilation. + +### Call event function + +Just call the event function where needed. +```rust +fn init_machine_ram(&self, sys_mem: &Arc, mem_size: u64) -> Result<()> { + hisysevent::example("true", "init_ram".to_string(), mem_size, &[0,1]); + let vm_ram = self.get_vm_ram(); + let layout_size = MEM_LAYOUT[LayoutEntryType::Mem as usize].1; + ...... +} +``` diff --git a/docs/qmp.md b/docs/qmp.md index e4837f58788aa8668f7e7e6e4ef5038283aa3c58..5e5e25e9f6e1cb09f6933f9b32022acd36b0b36b 100644 --- a/docs/qmp.md +++ b/docs/qmp.md @@ -534,6 +534,17 @@ Query the display image of virtiogpu. Currently only stdvm and gtk supports. <- { "return": { "fileDir": "/tmp/stratovirt-images", "isSuccess": true } } ``` +### query-workloads + +Query the workloads of the vm. + +#### Example + +```json +-> {"execute": "query-workloads", "arguments": {}} +<- {"return":[{"module":"scream-play","state":"Off"},{"module":"tap-0","state":"upload: 0 download: 0"}]} +``` + ### trace-get-state Query whether the trace state is enabled. diff --git a/docs/snapshot.md b/docs/snapshot.md index 6a9c97f73fc41d5db54f45fade6f88bbf20c72de..5e9c118b2577b0276249d699664ef2fba2732fc6 100644 --- a/docs/snapshot.md +++ b/docs/snapshot.md @@ -25,7 +25,7 @@ $ ncat -U path/to/socket {"return":{}} ``` -When VM is in paused state, is's safe to take a snapshot of the VM into the specified directory with QMP. +When VM is in paused state, it's safe to take a snapshot of the VM into the specified directory with QMP. ```shell $ ncat -U path/to/socket {"QMP":{"version":{"StratoVirt":{"micro":1,"minor":0,"major":0},"package":""},"capabilities":[]}} diff --git a/docs/stratovirt-img.md b/docs/stratovirt-img.md index e44be8d67703f5bf6595be58e72660e17b56def2..ec956c1d847db05d3378f9d6457a8a5411a2dc46 100644 --- a/docs/stratovirt-img.md +++ b/docs/stratovirt-img.md @@ -35,6 +35,16 @@ stratovirt-img create -f qcow2 -o cluster-size=65536 img_path img_size Note: 1. The cluster size can be only be set for `qcow2` or default to 65536. 2. Disk format is default to raw. +## Info + +Query the information of virtual disk. + +Sample Configuration: + +```shell +stratovirt-img info img_path +``` + ## Check Check if there are some mistakes on the image and choose to fix. @@ -81,13 +91,14 @@ Operating internal snapshot for disk, it is only supported by qcow2. Command syntax: ```shell -snapshot [-l | -a snapshot_name | -c snapshot_name | -d snapshot_name] img_path +snapshot [-l | -a snapshot_name | -c snapshot_name | -d snapshot_name | -r old_snapshot_name new_snapshot_name] img_path ``` - -a snapshot_name: applies a snapshot (revert disk to saved state). - -c snapshot_name: creates a snapshot. - -d snapshot_name: deletes a snapshot. - -l: lists all snapshots in the given image. +- -r old_snapshot_name new_snapshot_name: change the name from 'old_Snapshot_name' to 'new_Snapshot_name'. Sample Configuration: @@ -96,6 +107,30 @@ stratovirt-img snapshot -c snapshot_name img_path stratovirt-img snapshot -a snapshot_name img_path stratovirt-img snapshot -d snapshot_name img_path stratovirt-img snapshot -l img_path +stratovirt-img snapshot -r old_snapshot_name new_snapshot_name img_path ``` Note: The internal snapshot is not supported by raw. + +## Convert + +Convert the disk image to a new disk image using new format. +Command syntax: + +```shell +convert [ -f input_fmt | -O output_fmt | -S sparse_size ] input_filename output_filename +``` + +- -f fmt: Input image format. +- -O output_fmt: Output image format. +- -S sparse_size: the consecutive number of bytes that must contain only zeroes to create a sparse image during conversion. Unit: sector(512 bytes). Default is 8. +- input_filename: name of the input file using *input_fmt* image format. +- output_filename: name of the output file using *output_fmt* image format. + +Sample Configuration: + +```shell +stratovirt-img convert -f qcow2 -O raw qcow2_img_path raw_img_path +``` + +Note: Only qcow2 image to raw image conversion is supported currently. diff --git a/docs/trace.md b/docs/trace.md index bf1c24aaeb07b3d7c5f9afa25099ba95c76717e7..d4623fa6a97b10904eb501c399d9a159e60d75f0 100644 --- a/docs/trace.md +++ b/docs/trace.md @@ -70,7 +70,7 @@ of settings. ### log StratoVirt supports outputting trace to the log file at trace level. Turn on -the **trace_to_logger** feature to use is. +the **trace_to_logger** feature to use it. ### Ftrace @@ -80,7 +80,7 @@ suited for performance issues. It can be enabled by turning on the **trace_to_ftrace** feature during compilation. StratoVirt use ftrace by writing trace data to ftrace marker, and developers can -read trace records from trace file under mounted ftrace director, +read trace records from trace file under mounted ftrace directory, e.g. /sys/kernel/debug/tracing/trace. ### HiTraceMeter diff --git a/hisysevent/Cargo.toml b/hisysevent/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..11b3b69a0f368a15b2a7281cd385279be2445f36 --- /dev/null +++ b/hisysevent/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "hisysevent" +version = "2.4.0" +authors = ["Huawei StratoVirt Team"] +edition = "2021" +license = "Mulan PSL v2" +description = "Provide hisysevent infrastructure of hmos for StratoVirt" + +[dependencies] +log = "0.4" +lazy_static = "1.4.0" +anyhow = "1.0" +libloading = "0.7.4" +code_generator = { path = "code_generator" } + +[features] +hisysevent = [] diff --git a/machine_manager/src/config/ramfb.rs b/hisysevent/build.rs similarity index 55% rename from machine_manager/src/config/ramfb.rs rename to hisysevent/build.rs index 8473c1df9e57062dd337d985c73eb1414718b375..89563d8528aec20cdfe28ce0643cbce0c243b192 100644 --- a/machine_manager/src/config/ramfb.rs +++ b/hisysevent/build.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Huawei Technologies Co.,Ltd. All rights reserved. +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. // // StratoVirt is licensed under Mulan PSL v2. // You can use this software according to the terms and conditions of the Mulan @@ -10,15 +10,9 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use anyhow::Result; - -use crate::config::CmdParser; - -pub fn parse_ramfb(cfg_args: &str) -> Result { - let mut cmd_parser = CmdParser::new("ramfb"); - cmd_parser.push("").push("install").push("id"); - cmd_parser.parse(cfg_args)?; - - let install = cmd_parser.get_value::("install")?.unwrap_or(false); - Ok(install) +fn main() { + println!( + "cargo:rerun-if-changed={}/hisysevent", + std::env::var("CARGO_MANIFEST_DIR").unwrap() + ); } diff --git a/hisysevent/code_generator/Cargo.toml b/hisysevent/code_generator/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..9c1fc3c178e422264cfbb83f5388f9c8b8138e6d --- /dev/null +++ b/hisysevent/code_generator/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "code_generator" +version = "2.4.0" +authors = ["Huawei StratoVirt Team"] +edition = "2021" +license = "Mulan PSL v2" + +[lib] +name = "code_generator" +proc-macro = true + +[dependencies] +proc-macro2 = "1.0" +quote = "1.0" +regex = "1" +serde = { version = "1.0", features = ["derive"] } +syn = "2.0.18" +toml = "0.7" diff --git a/hisysevent/code_generator/src/lib.rs b/hisysevent/code_generator/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..8d7c8d068b1658706895cf2ec0463feb66c71ab5 --- /dev/null +++ b/hisysevent/code_generator/src/lib.rs @@ -0,0 +1,202 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{fs, io::Read}; + +use proc_macro::TokenStream; +use quote::quote; +use regex::Regex; +use serde::Deserialize; +use syn::{parse_str, Expr, Ident, Type}; + +const EVENT_DIR_NAME: &str = "event_info"; + +#[derive(Debug, Deserialize)] +struct EventDesc { + name: String, + event_type: String, + args: String, + enabled: bool, +} + +#[derive(Debug, Deserialize)] +struct HiSysEventConf { + events: Option>, +} + +fn get_event_desc() -> HiSysEventConf { + let event_dir_path = format!( + "{}/{}", + std::env::var("CARGO_MANIFEST_DIR").unwrap(), + EVENT_DIR_NAME + ); + let paths = fs::read_dir(event_dir_path).unwrap(); + let mut desc = String::new(); + + for path in paths { + let file_path = path.unwrap().path(); + let file_name = file_path.to_str().unwrap(); + if file_name.ends_with(".toml") { + let mut file = fs::File::open(file_path).unwrap(); + file.read_to_string(&mut desc).unwrap(); + } + } + match toml::from_str::(&desc) { + Ok(ret) => ret, + Err(e) => panic!("Failed to parse event info : {}", e), + } +} + +fn is_slice(arg_type: &str) -> bool { + let regex = Regex::new(r"\[([^\[\]]*)\]").unwrap(); + let match_texts = regex + .captures_iter(arg_type) + .map(|mat| mat.get(1).map_or("", |m| m.as_str())); + match match_texts.count() { + 0 => false, + 1 => true, + _ => panic!("The format of parameter type: {} is wrong!", arg_type), + } +} + +fn capitalize(s: &str) -> String { + if s.is_empty() { + return String::new(); + } + + let mut chars = s.chars().collect::>(); + if chars[0].is_alphabetic() { + chars[0] = chars[0] + .to_uppercase() + .collect::() + .chars() + .next() + .unwrap(); + } + chars.iter().collect() +} + +fn parse_param_type(arg_type: &str) -> String { + if is_slice(arg_type) { + let regex = Regex::new(r"\[([^\[\]]*)\]").unwrap(); + let match_texts: Vec<&str> = regex + .captures_iter(arg_type) + .map(|mat| mat.get(1).map_or("", |m| m.as_str())) + .collect(); + format!("Array{}", capitalize(match_texts[0])) + } else { + format!("Type{}", capitalize(arg_type)) + } +} + +fn generate_param_value(arg_type: &str, arg_value: &str) -> (Ident, Expr) { + let param_type: Ident; + let param_value: Expr; + if is_slice(arg_type) { + let trans_token = ".as_ptr() as *const std::ffi::c_int as *const ()"; + param_type = parse_str::("void_ptr_value").unwrap(); + param_value = parse_str::(format!("{}{}", arg_value, trans_token).as_str()).unwrap(); + } else if arg_type.contains("String") { + let cstr_arg = format!("std::ffi::CString::new({}).unwrap()", arg_value); + let trans_token = ".into_raw() as *const std::ffi::c_char"; + param_type = parse_str::("char_ptr_value").unwrap(); + param_value = parse_str::(format!("{}{}", cstr_arg, trans_token).as_str()).unwrap(); + } else { + param_type = parse_str::(format!("{}_value", arg_type).as_str()).unwrap(); + param_value = parse_str::(format!("{} as {}", arg_value, arg_type).as_str()).unwrap(); + } + (param_type, param_value) +} + +#[proc_macro] +pub fn gen_hisysevent_func(_input: TokenStream) -> TokenStream { + let events = match get_event_desc().events { + Some(events) => events, + None => return TokenStream::from(quote!()), + }; + let hisysevent_func = events.iter().map(|desc| { + if desc.name.trim().is_empty() { + panic!("Empty event name is unsupported!"); + } + let desc_name = desc.name.trim(); + let func_name = parse_str::(desc_name).unwrap(); + let event_name = desc_name; + let event_type = + parse_str::(format!("HiSysEventType::_{}", desc.event_type.trim()).as_str()) + .unwrap(); + + let func_args = match desc.args.is_empty() { + true => quote!(), + false => { + let split_args: Vec<&str> = desc.args.split(',').collect(); + let _args = split_args.iter().map(|arg| { + let (v, t) = arg.split_once(':').unwrap(); + let arg_name = parse_str::(v.trim()).unwrap(); + let arg_type = parse_str::(t.trim()).unwrap(); + quote!( + #arg_name: #arg_type, + ) + }); + quote! { #( #_args )* } + } + }; + + let param_body = { + let split_args: Vec<&str> = desc.args.split(',').collect(); + let _args = split_args.iter().map(|arg| { + let (v, t) = arg.split_once(':').unwrap(); + let param_name = v.trim(); + let param_type_str: String = parse_param_type(t.trim()); + let param_type_token = format!("EventParamType::_{}", param_type_str); + let param_type = parse_str::(param_type_token.as_str()).unwrap(); + let (elem_type, elem_value) = generate_param_value(t.trim(), v.trim()); + let param_size = if param_type_str.contains("Array") { + parse_str::(format!("{}.len()", v.trim()).as_str()).unwrap() + } else { + parse_str::("0").unwrap() + }; + + quote!( + EventParam { + param_name: #param_name, + param_type: #param_type, + param_value: EventParamValue{#elem_type: #elem_value}, + array_size: #param_size}, + ) + }); + quote! { #( #_args )* } + }; + + let func_body = match desc.enabled { + true => { + quote!( + #[cfg(all(target_env = "ohos", feature = "hisysevent"))] + { + let func = function!(); + let params: &[EventParam] = &[#param_body]; + write_to_hisysevent(func, #event_name, #event_type as std::ffi::c_int, params); + } + ) + } + false => quote!(), + }; + + quote!( + #[inline(always)] + pub fn #func_name(#func_args) { + #func_body + } + ) + }); + + TokenStream::from(quote! { #( #hisysevent_func )* }) +} diff --git a/hisysevent/event_info/example.toml b/hisysevent/event_info/example.toml new file mode 100644 index 0000000000000000000000000000000000000000..6a5e97ad68abfa1fd3649b66da30132f74fbd21f --- /dev/null +++ b/hisysevent/event_info/example.toml @@ -0,0 +1,5 @@ +[[events]] +name = "example" +event_type = "Behavior" +args = "example_bool: bool, example_str: String, example_integer: u32, example_array: &[u8]" +enabled = true diff --git a/hisysevent/event_info/misc.toml b/hisysevent/event_info/misc.toml new file mode 100644 index 0000000000000000000000000000000000000000..eeb4be8d0efab816a03baf719b14cccf23a1c911 --- /dev/null +++ b/hisysevent/event_info/misc.toml @@ -0,0 +1,5 @@ +[[events]] +name = "STRATOVIRT_PVPANIC" +event_type = "Fault" +args = "event: String" +enabled = true diff --git a/hisysevent/src/interface.rs b/hisysevent/src/interface.rs new file mode 100644 index 0000000000000000000000000000000000000000..5f6ecf64bed8d10fea6e6343dbf25ee87ee582e8 --- /dev/null +++ b/hisysevent/src/interface.rs @@ -0,0 +1,189 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::ffi::{c_char, c_int, c_uint, c_ulonglong, CString, OsStr}; + +use anyhow::{Context, Result}; +use lazy_static::lazy_static; +use libloading::os::unix::Symbol; +use libloading::Library; +use log::error; + +const MAX_PARAM_NAME_LENGTH: usize = 49; + +#[derive(Copy, Clone, Debug)] +pub enum HiSysEventType { + _Fault = 1, + _Statistic, + _Security, + _Behavior, +} + +#[derive(Copy, Clone, Debug)] +pub enum EventParamType { + // Invalid type. + _Invalid = 0, + _TypeBool, + _TypeI8, + _TypeU8, + _TypeI16, + _TypeU16, + _TypeI32, + _TypeU32, + _TypeI64, + _TypeU64, + _TypeF32, + _TypeF64, + _TypeString, + _ArrayBool, + _ArrayI8, + _ArrayU8, + _ArrayI16, + _ArrayU16, + _ArrayI32, + _ArrayU32, + _ArrayI64, + _ArrayU64, + _ArrayF32, + _ArrayF64, +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub union EventParamValue { + pub bool_value: bool, + pub i8_value: i8, + pub u8_value: u8, + pub i16_value: i16, + pub u16_value: u16, + pub i32_value: i32, + pub u32_value: u32, + pub i64_value: i64, + pub u64_value: u64, + pub f32_value: f32, + pub f64_value: f64, + // String. + pub char_ptr_value: *const c_char, + // Array. + pub void_ptr_value: *const (), +} + +pub struct EventParam<'a> { + pub param_name: &'a str, + pub param_type: EventParamType, + pub param_value: EventParamValue, + pub array_size: usize, +} + +#[repr(C)] +#[derive(Copy, Clone)] +struct EventParamWrapper { + pub param_name: [u8; MAX_PARAM_NAME_LENGTH], + pub param_type: c_int, + pub param_value: EventParamValue, + pub array_size: c_uint, +} + +lazy_static! { + static ref HISYSEVENT_FUNC_TABLE: HiSysEventFuncTable = + // SAFETY: The dynamic library should be always existing. + unsafe { + HiSysEventFuncTable::new(OsStr::new("libhisysevent.z.so")) + .map_err(|e| { + error!("failed to init HiSysEventFuncTable with error: {:?}", e); + e + }) + .unwrap() + }; +} + +macro_rules! get_libfn { + ( $lib: ident, $tname: ident, $fname: ident ) => { + $lib.get::<$tname>(stringify!($fname).as_bytes()) + .with_context(|| format!("failed to get function {}", stringify!($fname)))? + .into_raw() + }; +} + +type HiSysEventWriteWrapperFn = unsafe extern "C" fn( + func: *const c_char, + line: c_ulonglong, + domain: *const c_char, + name: *const c_char, + event_type: c_int, + params: *const EventParamWrapper, + size: c_uint, +) -> c_int; + +struct HiSysEventFuncTable { + pub hisysevent_write: Symbol, +} + +impl HiSysEventFuncTable { + unsafe fn new(library_name: &OsStr) -> Result { + let library = + Library::new(library_name).with_context(|| "failed to load hisysevent library")?; + + Ok(Self { + hisysevent_write: get_libfn!(library, HiSysEventWriteWrapperFn, HiSysEvent_Write), + }) + } +} + +fn format_param_array(event_params: &[EventParam]) -> Vec { + let mut params_wrapper: Vec = vec![]; + + for param in event_params { + let mut param_name = [0_u8; MAX_PARAM_NAME_LENGTH]; + let name = param.param_name.as_bytes(); + let end = std::cmp::min(name.len(), param_name.len()); + param_name[..end].copy_from_slice(&name[..end]); + params_wrapper.push(EventParamWrapper { + param_name, + param_type: param.param_type as i32, + param_value: param.param_value, + array_size: param.array_size as u32, + }); + } + + params_wrapper +} + +// Write system event. +pub(crate) fn write_to_hisysevent( + func_name: &str, + event_name: &str, + event_type: c_int, + event_params: &[EventParam], +) { + let func = CString::new(func_name).unwrap(); + let domain = CString::new("VM_ENGINE").unwrap(); + let event = CString::new(event_name).unwrap(); + + let params_wrapper = format_param_array(event_params); + + // SAFETY: Call hisysevent function, all parameters are just read. + let ret = unsafe { + (HISYSEVENT_FUNC_TABLE.hisysevent_write)( + func.as_ptr() as *const c_char, + line!() as c_ulonglong, + domain.as_ptr() as *const c_char, + event.as_ptr() as *const c_char, + event_type, + params_wrapper.as_ptr() as *const EventParamWrapper, + params_wrapper.len() as u32, + ) + }; + if ret != 0 { + error!("Failed to write event {} to hisysevent.", event_name); + } +} diff --git a/hisysevent/src/lib.rs b/hisysevent/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..6ab5515a20088d39cb26285e629bc3ed94d32ba8 --- /dev/null +++ b/hisysevent/src/lib.rs @@ -0,0 +1,34 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +#[cfg(all(target_env = "ohos", feature = "hisysevent"))] +mod interface; + +use code_generator::gen_hisysevent_func; + +#[cfg(all(target_env = "ohos", feature = "hisysevent"))] +use crate::interface::*; + +#[macro_export] +macro_rules! function { + () => {{ + fn hook() {} + fn type_name_of(_: T) -> &'static str { + std::any::type_name::() + } + let name = type_name_of(hook); + let off_set: usize = 6; // ::hook + &name[..name.len() - off_set] + }}; +} + +gen_hisysevent_func! {} diff --git a/hypervisor/Cargo.toml b/hypervisor/Cargo.toml index df70a5b234612ee284bfe7bfebe7301383387938..44308ef49c3b1fc0c1ec30d75019599c7aae43a7 100644 --- a/hypervisor/Cargo.toml +++ b/hypervisor/Cargo.toml @@ -8,11 +8,11 @@ license = "Mulan PSL v2" [dependencies] anyhow = "1.0" thiserror = "1.0" -kvm-bindings = { version = "0.6.0", features = ["fam-wrappers"] } -kvm-ioctls = "0.15.0" +kvm-bindings = { version = "0.7.0", features = ["fam-wrappers"] } +kvm-ioctls = "0.16.0" libc = "0.2" log = "0.4" -vmm-sys-util = "0.11.1" +vmm-sys-util = "0.12.1" address_space = { path = "../address_space" } cpu = { path = "../cpu" } devices = { path = "../devices" } @@ -21,3 +21,8 @@ migration = { path = "../migration" } migration_derive = { path = "../migration/migration_derive" } util = { path = "../util" } trace = { path = "../trace" } + +[features] +default = [] +vfio_device = [] +boot_time = [] diff --git a/hypervisor/src/kvm/aarch64/gicv2.rs b/hypervisor/src/kvm/aarch64/gicv2.rs index abe70dd4f0f54e07022ac51f2f672f1d65906d56..58758142d337087567978a937e837d277f08cea3 100644 --- a/hypervisor/src/kvm/aarch64/gicv2.rs +++ b/hypervisor/src/kvm/aarch64/gicv2.rs @@ -84,10 +84,10 @@ impl GICv2Access for KvmGICv2 { } fn vcpu_gicr_attr(&self, offset: u64, cpu: usize) -> u64 { - (((cpu as u64) << kvm_bindings::KVM_DEV_ARM_VGIC_CPUID_SHIFT as u64) + (((cpu as u64) << u64::from(kvm_bindings::KVM_DEV_ARM_VGIC_CPUID_SHIFT)) & kvm_bindings::KVM_DEV_ARM_VGIC_CPUID_MASK) - | ((offset << kvm_bindings::KVM_DEV_ARM_VGIC_OFFSET_SHIFT as u64) - & kvm_bindings::KVM_DEV_ARM_VGIC_OFFSET_MASK as u64) + | ((offset << u64::from(kvm_bindings::KVM_DEV_ARM_VGIC_OFFSET_SHIFT)) + & u64::from(kvm_bindings::KVM_DEV_ARM_VGIC_OFFSET_MASK)) } fn access_gic_distributor(&self, offset: u64, gicd_value: &mut u32, write: bool) -> Result<()> { @@ -122,7 +122,7 @@ impl GICv2Access for KvmGICv2 { KvmDevice::kvm_device_access( &self.fd, kvm_bindings::KVM_DEV_ARM_VGIC_GRP_CTRL, - kvm_bindings::KVM_DEV_ARM_VGIC_SAVE_PENDING_TABLES as u64, + u64::from(kvm_bindings::KVM_DEV_ARM_VGIC_SAVE_PENDING_TABLES), 0, true, ) diff --git a/hypervisor/src/kvm/aarch64/gicv3.rs b/hypervisor/src/kvm/aarch64/gicv3.rs index 942c2b6e66f7faed774b7b5bcbf83f5143768642..23d86c0dfe42184017d5e1f7859f5c383a4d1090 100644 --- a/hypervisor/src/kvm/aarch64/gicv3.rs +++ b/hypervisor/src/kvm/aarch64/gicv3.rs @@ -56,7 +56,7 @@ impl GICv3Access for KvmGICv3 { KvmDevice::kvm_device_check( &self.fd, kvm_bindings::KVM_DEV_ARM_VGIC_GRP_ADDR, - kvm_bindings::KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION as u64, + u64::from(kvm_bindings::KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION), ) .with_context(|| { "Multiple redistributors are acquired while KVM does not provide support." @@ -118,7 +118,7 @@ impl GICv3Access for KvmGICv3 { } fn vcpu_gicr_attr(&self, cpu: usize) -> u64 { - let clustersz = 16; + let clustersz = 16usize; let aff1 = (cpu / clustersz) as u64; let aff0 = (cpu % clustersz) as u64; @@ -128,6 +128,7 @@ impl GICv3Access for KvmGICv3 { let last = u64::from((self.vcpu_count - 1) == cpu as u64); + // Allow conversion of variables from i64 to u64. ((cpu_affid << 32) | (1 << 24) | (1 << 8) | (last << 4)) & kvm_bindings::KVM_DEV_ARM_VGIC_V3_MPIDR_MASK as u64 } @@ -196,7 +197,7 @@ impl GICv3Access for KvmGICv3 { KvmDevice::kvm_device_access( &self.fd, kvm_bindings::KVM_DEV_ARM_VGIC_GRP_CTRL, - kvm_bindings::KVM_DEV_ARM_VGIC_SAVE_PENDING_TABLES as u64, + u64::from(kvm_bindings::KVM_DEV_ARM_VGIC_SAVE_PENDING_TABLES), 0, true, ) @@ -259,7 +260,7 @@ impl GICv3ItsAccess for KvmGICv3Its { KvmDevice::kvm_device_access( &self.fd, kvm_bindings::KVM_DEV_ARM_VGIC_GRP_ITS_REGS, - attr as u64, + u64::from(attr), its_value as *const u64 as u64, write, ) @@ -267,9 +268,9 @@ impl GICv3ItsAccess for KvmGICv3Its { fn access_gic_its_tables(&self, save: bool) -> Result<()> { let attr = if save { - kvm_bindings::KVM_DEV_ARM_ITS_SAVE_TABLES as u64 + u64::from(kvm_bindings::KVM_DEV_ARM_ITS_SAVE_TABLES) } else { - kvm_bindings::KVM_DEV_ARM_ITS_RESTORE_TABLES as u64 + u64::from(kvm_bindings::KVM_DEV_ARM_ITS_RESTORE_TABLES) }; KvmDevice::kvm_device_access( &self.fd, diff --git a/hypervisor/src/kvm/aarch64/mod.rs b/hypervisor/src/kvm/aarch64/mod.rs index 072123600f54d8286a83d0ec8425182791ab3641..122cfac9cbe48570c32ef817a33f5109bd2a8413 100644 --- a/hypervisor/src/kvm/aarch64/mod.rs +++ b/hypervisor/src/kvm/aarch64/mod.rs @@ -95,7 +95,7 @@ impl KvmCpu { pub fn arch_init_pmu(&self) -> Result<()> { let pmu_attr = kvm_device_attr { group: KVM_ARM_VCPU_PMU_V3_CTRL, - attr: KVM_ARM_VCPU_PMU_V3_INIT as u64, + attr: u64::from(KVM_ARM_VCPU_PMU_V3_INIT), addr: 0, flags: 0, }; @@ -108,7 +108,7 @@ impl KvmCpu { let irq = PMU_INTR + PPI_BASE; let pmu_irq_attr = kvm_device_attr { group: KVM_ARM_VCPU_PMU_V3_CTRL, - attr: KVM_ARM_VCPU_PMU_V3_IRQ as u64, + attr: u64::from(KVM_ARM_VCPU_PMU_V3_IRQ), addr: &irq as *const u32 as u64, flags: 0, }; @@ -175,12 +175,13 @@ impl KvmCpu { if vcpu_config.sve { self.fd - .vcpu_finalize(&(kvm_bindings::KVM_ARM_VCPU_SVE as i32))?; + .vcpu_finalize(&(i32::try_from(kvm_bindings::KVM_ARM_VCPU_SVE)?))?; } - arch_cpu.lock().unwrap().mpidr = + arch_cpu.lock().unwrap().mpidr = u64::try_from( self.get_one_reg(KVM_REG_ARM_MPIDR_EL1) - .with_context(|| "Failed to get mpidr")? as u64; + .with_context(|| "Failed to get mpidr")?, + )?; arch_cpu.lock().unwrap().features = *vcpu_config; @@ -255,10 +256,11 @@ impl KvmCpu { ); } RegsIndex::VtimerCount => { - locked_arch_cpu.vtimer_cnt = self - .get_one_reg(KVM_REG_ARM_TIMER_CNT) - .with_context(|| "Failed to get virtual timer count")? - as u64; + locked_arch_cpu.vtimer_cnt = u64::try_from( + self.get_one_reg(KVM_REG_ARM_TIMER_CNT) + .with_context(|| "Failed to get virtual timer count")?, + )?; + locked_arch_cpu.vtimer_cnt_valid = true; } } @@ -270,7 +272,7 @@ impl KvmCpu { arch_cpu: Arc>, regs_index: RegsIndex, ) -> Result<()> { - let locked_arch_cpu = arch_cpu.lock().unwrap(); + let mut locked_arch_cpu = arch_cpu.lock().unwrap(); let apic_id = locked_arch_cpu.apic_id; match regs_index { RegsIndex::CoreRegs => { @@ -300,8 +302,14 @@ impl KvmCpu { } } RegsIndex::VtimerCount => { - self.set_one_reg(KVM_REG_ARM_TIMER_CNT, locked_arch_cpu.vtimer_cnt as u128) + if locked_arch_cpu.vtimer_cnt_valid { + self.set_one_reg( + KVM_REG_ARM_TIMER_CNT, + u128::from(locked_arch_cpu.vtimer_cnt), + ) .with_context(|| "Failed to set virtual timer count")?; + locked_arch_cpu.vtimer_cnt_valid = false; + } } } @@ -318,33 +326,34 @@ impl KvmCpu { fn get_core_regs(&self) -> Result { let mut core_regs = kvm_regs::default(); - core_regs.regs.sp = self.get_one_reg(Arm64CoreRegs::UserPTRegSp.into())? as u64; - core_regs.sp_el1 = self.get_one_reg(Arm64CoreRegs::KvmSpEl1.into())? as u64; - core_regs.regs.pstate = self.get_one_reg(Arm64CoreRegs::UserPTRegPState.into())? as u64; - core_regs.regs.pc = self.get_one_reg(Arm64CoreRegs::UserPTRegPc.into())? as u64; - core_regs.elr_el1 = self.get_one_reg(Arm64CoreRegs::KvmElrEl1.into())? as u64; + core_regs.regs.sp = u64::try_from(self.get_one_reg(Arm64CoreRegs::UserPTRegSp.into())?)?; + core_regs.sp_el1 = u64::try_from(self.get_one_reg(Arm64CoreRegs::KvmSpEl1.into())?)?; + core_regs.regs.pstate = + u64::try_from(self.get_one_reg(Arm64CoreRegs::UserPTRegPState.into())?)?; + core_regs.regs.pc = u64::try_from(self.get_one_reg(Arm64CoreRegs::UserPTRegPc.into())?)?; + core_regs.elr_el1 = u64::try_from(self.get_one_reg(Arm64CoreRegs::KvmElrEl1.into())?)?; - for i in 0..KVM_NR_REGS as usize { + for i in 0..usize::try_from(KVM_NR_REGS)? { core_regs.regs.regs[i] = - self.get_one_reg(Arm64CoreRegs::UserPTRegRegs(i).into())? as u64; + u64::try_from(self.get_one_reg(Arm64CoreRegs::UserPTRegRegs(i).into())?)?; } - for i in 0..KVM_NR_SPSR as usize { - core_regs.spsr[i] = self.get_one_reg(Arm64CoreRegs::KvmSpsr(i).into())? as u64; + for i in 0..usize::try_from(KVM_NR_SPSR)? { + core_regs.spsr[i] = u64::try_from(self.get_one_reg(Arm64CoreRegs::KvmSpsr(i).into())?)?; } // State save and restore is not supported for SVE for now, so we just skip it. if self.kvi.lock().unwrap().features[0] & (1 << kvm_bindings::KVM_ARM_VCPU_SVE) == 0 { - for i in 0..KVM_NR_FP_REGS as usize { + for i in 0..usize::try_from(KVM_NR_FP_REGS)? { core_regs.fp_regs.vregs[i] = self.get_one_reg(Arm64CoreRegs::UserFPSIMDStateVregs(i).into())?; } } core_regs.fp_regs.fpsr = - self.get_one_reg(Arm64CoreRegs::UserFPSIMDStateFpsr.into())? as u32; + u32::try_from(self.get_one_reg(Arm64CoreRegs::UserFPSIMDStateFpsr.into())?)?; core_regs.fp_regs.fpcr = - self.get_one_reg(Arm64CoreRegs::UserFPSIMDStateFpcr.into())? as u32; + u32::try_from(self.get_one_reg(Arm64CoreRegs::UserFPSIMDStateFpcr.into())?)?; Ok(core_regs) } @@ -358,29 +367,41 @@ impl KvmCpu { /// * `vcpu_fd` - the VcpuFd in KVM mod. /// * `core_regs` - kvm_regs state to be written. fn set_core_regs(&self, core_regs: kvm_regs) -> Result<()> { - self.set_one_reg(Arm64CoreRegs::UserPTRegSp.into(), core_regs.regs.sp as u128)?; - self.set_one_reg(Arm64CoreRegs::KvmSpEl1.into(), core_regs.sp_el1 as u128)?; + self.set_one_reg( + Arm64CoreRegs::UserPTRegSp.into(), + u128::from(core_regs.regs.sp), + )?; + self.set_one_reg(Arm64CoreRegs::KvmSpEl1.into(), u128::from(core_regs.sp_el1))?; self.set_one_reg( Arm64CoreRegs::UserPTRegPState.into(), - core_regs.regs.pstate as u128, + u128::from(core_regs.regs.pstate), + )?; + self.set_one_reg( + Arm64CoreRegs::UserPTRegPc.into(), + u128::from(core_regs.regs.pc), + )?; + self.set_one_reg( + Arm64CoreRegs::KvmElrEl1.into(), + u128::from(core_regs.elr_el1), )?; - self.set_one_reg(Arm64CoreRegs::UserPTRegPc.into(), core_regs.regs.pc as u128)?; - self.set_one_reg(Arm64CoreRegs::KvmElrEl1.into(), core_regs.elr_el1 as u128)?; - for i in 0..KVM_NR_REGS as usize { + for i in 0..usize::try_from(KVM_NR_REGS)? { self.set_one_reg( Arm64CoreRegs::UserPTRegRegs(i).into(), - core_regs.regs.regs[i] as u128, + u128::from(core_regs.regs.regs[i]), )?; } - for i in 0..KVM_NR_SPSR as usize { - self.set_one_reg(Arm64CoreRegs::KvmSpsr(i).into(), core_regs.spsr[i] as u128)?; + for i in 0..usize::try_from(KVM_NR_SPSR)? { + self.set_one_reg( + Arm64CoreRegs::KvmSpsr(i).into(), + u128::from(core_regs.spsr[i]), + )?; } // State save and restore is not supported for SVE for now, so we just skip it. if self.kvi.lock().unwrap().features[0] & (1 << kvm_bindings::KVM_ARM_VCPU_SVE) == 0 { - for i in 0..KVM_NR_FP_REGS as usize { + for i in 0..usize::try_from(KVM_NR_FP_REGS)? { self.set_one_reg( Arm64CoreRegs::UserFPSIMDStateVregs(i).into(), core_regs.fp_regs.vregs[i], @@ -390,11 +411,11 @@ impl KvmCpu { self.set_one_reg( Arm64CoreRegs::UserFPSIMDStateFpsr.into(), - core_regs.fp_regs.fpsr as u128, + u128::from(core_regs.fp_regs.fpsr), )?; self.set_one_reg( Arm64CoreRegs::UserFPSIMDStateFpcr.into(), - core_regs.fp_regs.fpcr as u128, + u128::from(core_regs.fp_regs.fpcr), )?; Ok(()) diff --git a/hypervisor/src/kvm/interrupt.rs b/hypervisor/src/kvm/interrupt.rs index ea9e7790c78d217585d4a5e117a03f570165670c..b067b17c3d598abe3b0a427c6a94854448e2bbdb 100644 --- a/hypervisor/src/kvm/interrupt.rs +++ b/hypervisor/src/kvm/interrupt.rs @@ -46,7 +46,7 @@ fn get_maximum_gsi_cnt(kvmfd: &Kvm) -> u32 { gsi_count = 0; } - gsi_count as u32 + u32::try_from(gsi_count).unwrap_or_default() } /// Return `IrqRouteEntry` according to gsi, irqchip kind and pin. @@ -180,7 +180,7 @@ impl IrqRouteTable { .find_next_zero(0) .with_context(|| "Failed to get new free gsi")?; self.gsi_bitmap.set(free_gsi)?; - Ok(free_gsi as u32) + Ok(u32::try_from(free_gsi)?) } /// Release gsi number to free. @@ -208,11 +208,12 @@ impl IrqRouteTable { trace::kvm_commit_irq_routing(); // SAFETY: data in `routes` is reliable. unsafe { + // layout is aligned, so casting of ptr is allowed. let irq_routing = std::alloc::alloc(layout) as *mut IrqRoute; if irq_routing.is_null() { bail!("Failed to alloc irq routing"); } - (*irq_routing).nr = routes.len() as u32; + (*irq_routing).nr = u32::try_from(routes.len())?; (*irq_routing).flags = 0; let entries: &mut [IrqRouteEntry] = (*irq_routing).entries.as_mut_slice(routes.len()); entries.copy_from_slice(&routes); @@ -236,7 +237,7 @@ mod tests { #[test] fn test_get_maximum_gsi_cnt() { - let kvm_hyp = KvmHypervisor::new().unwrap_or(KvmHypervisor::default()); + let kvm_hyp = KvmHypervisor::new().unwrap_or_default(); if kvm_hyp.vm_fd.is_none() { return; } @@ -245,14 +246,14 @@ mod tests { #[test] fn test_alloc_and_release_gsi() { - let kvm_hyp = KvmHypervisor::new().unwrap_or(KvmHypervisor::default()); + let kvm_hyp = KvmHypervisor::new().unwrap_or_default(); if kvm_hyp.vm_fd.is_none() { return; } - let irq_route_table = Mutex::new(IrqRouteTable::new(&kvm_hyp.fd.as_ref().unwrap())); + let irq_route_table = Mutex::new(IrqRouteTable::new(kvm_hyp.fd.as_ref().unwrap())); let irq_manager = Arc::new(KVMInterruptManager::new( true, - kvm_hyp.vm_fd.clone().unwrap(), + kvm_hyp.vm_fd.unwrap(), irq_route_table, )); let mut irq_route_table = irq_manager.irq_route_table.lock().unwrap(); diff --git a/hypervisor/src/kvm/listener.rs b/hypervisor/src/kvm/listener.rs index 713b72384605398c76beac6f5b3ca0fcdde660c8..1a6943ffb2e802e8bed7855d3cdf760d8f63af9d 100644 --- a/hypervisor/src/kvm/listener.rs +++ b/hypervisor/src/kvm/listener.rs @@ -22,7 +22,7 @@ use log::{debug, warn}; use crate::HypervisorError; use address_space::{ - AddressRange, AddressSpaceError, FlatRange, Listener, ListenerReqType, MemSlot, + AddressAttr, AddressRange, AddressSpaceError, FlatRange, Listener, ListenerReqType, MemSlot, RegionIoEventFd, RegionType, }; use util::{num_ops::round_down, unix::host_page_size}; @@ -93,7 +93,7 @@ impl KvmMemoryListener { for (index, slot) in slots.iter_mut().enumerate() { if slot.size == 0 { - slot.index = index as u32; + slot.index = u32::try_from(index)?; slot.guest_addr = guest_addr; slot.size = size; slot.host_addr = host_addr; @@ -179,12 +179,12 @@ impl KvmMemoryListener { return Ok(()); } - if flat_range.owner.region_type() != RegionType::Ram - && flat_range.owner.region_type() != RegionType::RomDevice - && flat_range.owner.region_type() != RegionType::RamDevice - { - return Ok(()); - } + let attr = match flat_range.owner.region_type() { + address_space::RegionType::Ram => AddressAttr::Ram, + address_space::RegionType::RamDevice => AddressAttr::RamDevice, + address_space::RegionType::RomDevice => AddressAttr::RomDevice, + _ => return Ok(()), + }; let (aligned_addr, aligned_size) = Self::align_mem_slot(flat_range.addr_range, host_page_size()) @@ -193,7 +193,8 @@ impl KvmMemoryListener { let align_adjust = aligned_addr.raw_value() - flat_range.addr_range.base.raw_value(); // `unwrap()` won't fail because Ram-type Region definitely has hva - let aligned_hva = flat_range.owner.get_host_address().unwrap() + // SAFETY: size has been checked. + let aligned_hva = unsafe { flat_range.owner.get_host_address(attr).unwrap() } + flat_range.offset_in_region + align_adjust; @@ -316,8 +317,8 @@ impl KvmMemoryListener { let ioctl_ret = if ioevtfd.data_match { let length = ioevtfd.addr_range.size; match length { - 2 => vm_fd.register_ioevent(&ioevtfd.fd, &io_addr, ioevtfd.data as u16), - 4 => vm_fd.register_ioevent(&ioevtfd.fd, &io_addr, ioevtfd.data as u32), + 2 => vm_fd.register_ioevent(&ioevtfd.fd, &io_addr, u16::try_from(ioevtfd.data)?), + 4 => vm_fd.register_ioevent(&ioevtfd.fd, &io_addr, u32::try_from(ioevtfd.data)?), 8 => vm_fd.register_ioevent(&ioevtfd.fd, &io_addr, ioevtfd.data), _ => bail!("Unexpected ioeventfd data length {}", length), } @@ -357,8 +358,8 @@ impl KvmMemoryListener { let ioctl_ret = if ioevtfd.data_match { let length = ioevtfd.addr_range.size; match length { - 2 => vm_fd.unregister_ioevent(&ioevtfd.fd, &io_addr, ioevtfd.data as u16), - 4 => vm_fd.unregister_ioevent(&ioevtfd.fd, &io_addr, ioevtfd.data as u32), + 2 => vm_fd.unregister_ioevent(&ioevtfd.fd, &io_addr, u16::try_from(ioevtfd.data)?), + 4 => vm_fd.unregister_ioevent(&ioevtfd.fd, &io_addr, u32::try_from(ioevtfd.data)?), 8 => vm_fd.unregister_ioevent(&ioevtfd.fd, &io_addr, ioevtfd.data), _ => bail!("Unexpected ioeventfd data length {}", length), } @@ -628,12 +629,12 @@ mod test { #[test] fn test_alloc_slot() { - let kvm_hyp = KvmHypervisor::new().unwrap_or(KvmHypervisor::default()); + let kvm_hyp = KvmHypervisor::new().unwrap_or_default(); if kvm_hyp.vm_fd.is_none() { return; } - let kml = KvmMemoryListener::new(4, kvm_hyp.vm_fd.clone(), kvm_hyp.mem_slots.clone()); + let kml = KvmMemoryListener::new(4, kvm_hyp.vm_fd.clone(), kvm_hyp.mem_slots); let host_addr = 0u64; assert_eq!(kml.get_free_slot(0, 100, host_addr).unwrap(), 0); @@ -652,12 +653,12 @@ mod test { #[test] fn test_add_del_ram_region() { - let kvm_hyp = KvmHypervisor::new().unwrap_or(KvmHypervisor::default()); + let kvm_hyp = KvmHypervisor::new().unwrap_or_default(); if kvm_hyp.vm_fd.is_none() { return; } - let kml = KvmMemoryListener::new(34, kvm_hyp.vm_fd.clone(), kvm_hyp.mem_slots.clone()); + let kml = KvmMemoryListener::new(34, kvm_hyp.vm_fd.clone(), kvm_hyp.mem_slots); let ram_size = host_page_size(); let ram_fr1 = create_ram_range(0, ram_size, 0); @@ -678,12 +679,12 @@ mod test { #[test] fn test_add_region_align() { - let kvm_hyp = KvmHypervisor::new().unwrap_or(KvmHypervisor::default()); + let kvm_hyp = KvmHypervisor::new().unwrap_or_default(); if kvm_hyp.vm_fd.is_none() { return; } - let kml = KvmMemoryListener::new(34, kvm_hyp.vm_fd.clone(), kvm_hyp.mem_slots.clone()); + let kml = KvmMemoryListener::new(34, kvm_hyp.vm_fd.clone(), kvm_hyp.mem_slots); // flat-range not aligned let page_size = host_page_size(); let ram_fr2 = create_ram_range(page_size, 2 * page_size, 1000); @@ -700,12 +701,12 @@ mod test { #[test] fn test_add_del_ioeventfd() { - let kvm_hyp = KvmHypervisor::new().unwrap_or(KvmHypervisor::default()); + let kvm_hyp = KvmHypervisor::new().unwrap_or_default(); if kvm_hyp.vm_fd.is_none() { return; } - let kml = KvmMemoryListener::new(34, kvm_hyp.vm_fd.clone(), kvm_hyp.mem_slots.clone()); + let kml = KvmMemoryListener::new(34, kvm_hyp.vm_fd.clone(), kvm_hyp.mem_slots); let evtfd = generate_region_ioeventfd(4, NoDatamatch); assert!(kml .handle_request(None, Some(&evtfd), ListenerReqType::AddIoeventfd) @@ -746,12 +747,12 @@ mod test { #[test] fn test_ioeventfd_with_data_match() { - let kvm_hyp = KvmHypervisor::new().unwrap_or(KvmHypervisor::default()); + let kvm_hyp = KvmHypervisor::new().unwrap_or_default(); if kvm_hyp.vm_fd.is_none() { return; } - let kml = KvmMemoryListener::new(34, kvm_hyp.vm_fd.clone(), kvm_hyp.mem_slots.clone()); + let kml = KvmMemoryListener::new(34, kvm_hyp.vm_fd.clone(), kvm_hyp.mem_slots); let evtfd_addr = 0x1000_u64; let mut evtfd = generate_region_ioeventfd(evtfd_addr, 64_u32); evtfd.addr_range.size = 3_u64; @@ -767,7 +768,7 @@ mod test { // Delete ioeventfd with wrong address will cause an error. let mut evtfd_to_del = evtfd.clone(); - evtfd_to_del.addr_range.base.0 = evtfd_to_del.addr_range.base.0 - 2; + evtfd_to_del.addr_range.base.0 -= 2; assert!(kml .handle_request(None, Some(&evtfd_to_del), ListenerReqType::DeleteIoeventfd) .is_err()); @@ -800,12 +801,12 @@ mod test { #[test] #[cfg(target_arch = "x86_64")] fn test_kvm_io_listener() { - let kvm_hyp = KvmHypervisor::new().unwrap_or(KvmHypervisor::default()); + let kvm_hyp = KvmHypervisor::new().unwrap_or_default(); if kvm_hyp.vm_fd.is_none() { return; } - let iol = KvmIoListener::new(kvm_hyp.vm_fd.clone()); + let iol = KvmIoListener::new(kvm_hyp.vm_fd); let evtfd = generate_region_ioeventfd(4, NoDatamatch); assert!(iol .handle_request(None, Some(&evtfd), ListenerReqType::AddIoeventfd) diff --git a/hypervisor/src/kvm/mod.rs b/hypervisor/src/kvm/mod.rs index 8648b9c219dc60b663aa223fbbff8ddc9b5319f9..16e7a479a23b9247d86db2291cc27a7f95d764b6 100644 --- a/hypervisor/src/kvm/mod.rs +++ b/hypervisor/src/kvm/mod.rs @@ -36,9 +36,11 @@ use anyhow::anyhow; use anyhow::{bail, Context, Result}; use kvm_bindings::kvm_userspace_memory_region as KvmMemSlot; use kvm_bindings::*; +#[cfg(feature = "vfio_device")] +use kvm_ioctls::DeviceFd; #[cfg(not(test))] use kvm_ioctls::VcpuExit; -use kvm_ioctls::{Cap, DeviceFd, Kvm, VcpuFd, VmFd}; +use kvm_ioctls::{Cap, Kvm, VcpuFd, VmFd}; use libc::{c_int, c_void, siginfo_t}; use log::{error, info, warn}; use vmm_sys_util::{ @@ -54,8 +56,6 @@ use crate::HypervisorError; #[cfg(target_arch = "aarch64")] use aarch64::cpu_caps::ArmCPUCaps as CPUCaps; use address_space::{AddressSpace, Listener}; -#[cfg(feature = "boot_time")] -use cpu::capture_boot_signal; #[cfg(target_arch = "aarch64")] use cpu::CPUFeatures; use cpu::{ @@ -135,6 +135,7 @@ impl KvmHypervisor { } fn create_memory_listener(&self) -> Arc> { + // Memslot will not exceed u32::MAX, so use as translate data type. Arc::new(Mutex::new(KvmMemoryListener::new( self.fd.as_ref().unwrap().get_nr_memslots() as u32, self.vm_fd.clone(), @@ -229,7 +230,7 @@ impl HypervisorOps for KvmHypervisor { .vm_fd .as_ref() .unwrap() - .create_vcpu(vcpu_id as u64) + .create_vcpu(u64::from(vcpu_id)) .with_context(|| "Create vcpu failed")?; Ok(Arc::new(KvmCpu::new( vcpu_id, @@ -259,6 +260,7 @@ impl HypervisorOps for KvmHypervisor { }) } + #[cfg(feature = "vfio_device")] fn create_vfio_device(&self) -> Option { let mut device = kvm_create_device { type_: kvm_device_type_KVM_DEV_TYPE_VFIO, @@ -297,7 +299,7 @@ impl MigrateOps for KvmHypervisor { self.vm_fd .as_ref() .unwrap() - .get_dirty_log(slot, mem_size as usize) + .get_dirty_log(slot, usize::try_from(mem_size)?) .with_context(|| { format!( "Failed to get dirty log, error is {}", @@ -424,7 +426,7 @@ impl KvmCpu { #[cfg(target_arch = "x86_64")] VcpuExit::IoOut(addr, data) => { #[cfg(feature = "boot_time")] - capture_boot_signal(addr as u64, data); + cpu::capture_boot_signal(u64::from(addr), data); vm.lock().unwrap().pio_out(u64::from(addr), data); } @@ -433,7 +435,7 @@ impl KvmCpu { } VcpuExit::MmioWrite(addr, data) => { #[cfg(all(target_arch = "aarch64", feature = "boot_time"))] - capture_boot_signal(addr, data); + cpu::capture_boot_signal(addr, data); vm.lock().unwrap().mmio_write(addr, data); } @@ -466,7 +468,7 @@ impl KvmCpu { return Ok(true); } else { error!( - "Vcpu{} received unexpected system event with type 0x{:x}, flags 0x{:x}", + "Vcpu{} received unexpected system event with type 0x{:x}, flags {:#x?}", cpu.id(), event, flags @@ -645,7 +647,9 @@ impl CPUHypervisorOps for KvmCpu { if *cpu_state.lock().unwrap() == CpuLifecycleState::Running { *cpu_state.lock().unwrap() = CpuLifecycleState::Paused; cvar.notify_one() - } else if *cpu_state.lock().unwrap() == CpuLifecycleState::Paused { + } else if *cpu_state.lock().unwrap() == CpuLifecycleState::Paused + && pause_signal.load(Ordering::SeqCst) + { return Ok(()); } @@ -662,10 +666,13 @@ impl CPUHypervisorOps for KvmCpu { } // It shall wait for the vCPU pause state from hypervisor exits. - loop { - if pause_signal.load(Ordering::SeqCst) { - break; + let mut sleep_times = 0u32; + while !pause_signal.load(Ordering::SeqCst) { + if sleep_times >= 5 { + bail!(CpuError::StopVcpu("timeout".to_string())); } + thread::sleep(Duration::from_millis(5)); + sleep_times += 1; } Ok(()) @@ -816,6 +823,10 @@ impl LineIrqManager for KVMInterruptManager { } impl MsiIrqManager for KVMInterruptManager { + fn irqfd_enable(&self) -> bool { + self.irqfd_cap + } + fn allocate_irq(&self, vector: MsiVector) -> Result { let mut locked_irq_route_table = self.irq_route_table.lock().unwrap(); let gsi = locked_irq_route_table.allocate_gsi().map_err(|e| { @@ -993,7 +1004,7 @@ mod test { #[cfg(target_arch = "x86_64")] #[test] fn test_x86_64_kvm_cpu() { - let kvm_hyp = KvmHypervisor::new().unwrap_or(KvmHypervisor::default()); + let kvm_hyp = KvmHypervisor::new().unwrap_or_default(); if kvm_hyp.vm_fd.is_none() { return; } @@ -1057,7 +1068,7 @@ mod test { vcpu_fd, )); let x86_cpu = Arc::new(Mutex::new(ArchCPU::new(0, 1))); - let cpu = CPU::new(hypervisor_cpu.clone(), 0, x86_cpu, vm.clone()); + let cpu = CPU::new(hypervisor_cpu.clone(), 0, x86_cpu, vm); // test `set_boot_config` function assert!(hypervisor_cpu .set_boot_config(cpu.arch().clone(), &cpu_config) @@ -1101,7 +1112,7 @@ mod test { #[test] #[allow(unused)] fn test_cpu_lifecycle_with_kvm() { - let kvm_hyp = KvmHypervisor::new().unwrap_or(KvmHypervisor::default()); + let kvm_hyp = KvmHypervisor::new().unwrap_or_default(); if kvm_hyp.vm_fd.is_none() { return; } @@ -1119,7 +1130,7 @@ mod test { hypervisor_cpu.clone(), 0, Arc::new(Mutex::new(ArchCPU::default())), - vm.clone(), + vm, ); let (cpu_state, _) = &*cpu.state; assert_eq!(*cpu_state.lock().unwrap(), CpuLifecycleState::Created); diff --git a/hypervisor/src/lib.rs b/hypervisor/src/lib.rs index 25fc90ef67066a466a1ed8269babfedf8c9bd104..a156d9af65dc93274c7c0b30f1bd020ca5c818b9 100644 --- a/hypervisor/src/lib.rs +++ b/hypervisor/src/lib.rs @@ -22,6 +22,7 @@ use std::any::Any; use std::sync::Arc; use anyhow::Result; +#[cfg(feature = "vfio_device")] use kvm_ioctls::DeviceFd; use address_space::AddressSpace; @@ -56,5 +57,6 @@ pub trait HypervisorOps: Send + Sync + Any { fn create_irq_manager(&mut self) -> Result; + #[cfg(feature = "vfio_device")] fn create_vfio_device(&self) -> Option; } diff --git a/hypervisor/src/test/mod.rs b/hypervisor/src/test/mod.rs index 8fce37ee16c9d9b09d40edd555f54b935d7a441b..99abb06ed1cb403f80adf2c01da64a9b880fa6a0 100644 --- a/hypervisor/src/test/mod.rs +++ b/hypervisor/src/test/mod.rs @@ -21,6 +21,7 @@ use std::thread; use std::time::Duration; use anyhow::{anyhow, Context, Result}; +#[cfg(feature = "vfio_device")] use kvm_ioctls::DeviceFd; use log::info; use vmm_sys_util::eventfd::EventFd; @@ -41,7 +42,7 @@ use devices::{pci::MsiVector, IrqManager, LineIrqManager, MsiIrqManager, Trigger use devices::{GICVersion, GICv3, ICGICConfig, InterruptController, GIC_IRQ_INTERNAL}; use machine_manager::machine::HypervisorType; use migration::{MigrateMemSlot, MigrateOps}; -use util::test_helper::{IntxInfo, MsixMsg, TEST_INTX_LIST, TEST_MSIX_LIST}; +use util::test_helper::{add_msix_msg, IntxInfo, TEST_INTX_LIST}; pub struct TestHypervisor {} @@ -115,6 +116,7 @@ impl HypervisorOps for TestHypervisor { }) } + #[cfg(feature = "vfio_device")] fn create_vfio_device(&self) -> Option { None } @@ -153,7 +155,7 @@ impl CPUHypervisorOps for TestCpu { ) -> Result<()> { #[cfg(target_arch = "aarch64")] { - arch_cpu.lock().unwrap().mpidr = self.id as u64; + arch_cpu.lock().unwrap().mpidr = u64::from(self.id); arch_cpu.lock().unwrap().set_core_reg(boot_config); } Ok(()) @@ -301,19 +303,6 @@ impl TestInterruptManager { pub fn arch_map_irq(&self, gsi: u32) -> u32 { gsi + GIC_IRQ_INTERNAL } - - pub fn add_msix_msg(addr: u64, data: u32) { - let new_msg = MsixMsg::new(addr, data); - let mut msix_list_lock = TEST_MSIX_LIST.lock().unwrap(); - - for msg in msix_list_lock.iter() { - if new_msg.addr == msg.addr && new_msg.data == msg.data { - return; - } - } - - msix_list_lock.push(new_msg); - } } impl LineIrqManager for TestInterruptManager { @@ -368,6 +357,10 @@ impl LineIrqManager for TestInterruptManager { } impl MsiIrqManager for TestInterruptManager { + fn irqfd_enable(&self) -> bool { + false + } + fn allocate_irq(&self, _vector: MsiVector) -> Result { Err(anyhow!( "Failed to allocate irq, mst doesn't support irq routing feature." @@ -399,9 +392,9 @@ impl MsiIrqManager for TestInterruptManager { _dev_id: u32, ) -> Result<()> { let data = vector.msg_data; - let mut addr: u64 = vector.msg_addr_hi as u64; - addr = (addr << 32) + vector.msg_addr_lo as u64; - TestInterruptManager::add_msix_msg(addr, data); + let mut addr: u64 = u64::from(vector.msg_addr_hi); + addr = (addr << 32) + u64::from(vector.msg_addr_lo); + add_msix_msg(addr, data); Ok(()) } diff --git a/image/src/cmdline.rs b/image/src/cmdline.rs index a2f6079bfdc20e097e4b07e0c85dedd3a60e3d34..c991366c65333f9a382f435461e944413b452111 100644 --- a/image/src/cmdline.rs +++ b/image/src/cmdline.rs @@ -70,16 +70,16 @@ impl ArgsParse { let mut pre_opt = (0, "".to_string()); for idx in 0..len { - let str = args[idx as usize].clone(); - if str.starts_with("-") && str.len() > 1 { - if pre_opt.1.len() != 0 { + let str = args[idx].clone(); + if str.starts_with('-') && str.len() > 1 { + if !pre_opt.1.is_empty() { bail!("missing argument for option '{}'", pre_opt.1); } let name = if str.starts_with("--") && str.len() > 2 { - (&str[2..]).to_string() - } else if str.starts_with("-") && str.len() > 1 { - (&str[1..]).to_string() + str[2..].to_string() + } else if str.starts_with('-') && str.len() > 1 { + str[1..].to_string() } else { bail!("unrecognized option '{}'", str); }; @@ -100,7 +100,7 @@ impl ArgsParse { continue; } - if pre_opt.0 + 1 == idx && pre_opt.1.len() != 0 { + if pre_opt.0 + 1 == idx && !pre_opt.1.is_empty() { let name = pre_opt.1.to_string(); let value = str.to_string(); if let Some(arg) = self.args.get_mut(&name) { @@ -117,14 +117,14 @@ impl ArgsParse { } } pre_opt = (0, "".to_string()); - } else if pre_opt.1.len() == 0 { + } else if pre_opt.1.is_empty() { self.free.push(str.to_string()); } else { bail!("unrecognized option '{}'", pre_opt.1); } } - if pre_opt.0 == 0 && pre_opt.1.len() != 0 { + if pre_opt.0 == 0 && !pre_opt.1.is_empty() { bail!("unrecognized option '{}'", pre_opt.1); } @@ -162,20 +162,16 @@ mod test { fn test_arg_parse() { let mut arg_parser = ArgsParse::create(vec!["q", "h", "help"], vec!["f"], vec!["o"]); let cmd_line = "-f qcow2 -q -h --help -o cluster_size=512 -o refcount_bits=16 img_path +1G"; - let cmd_args: Vec = cmd_line - .split(' ') - .into_iter() - .map(|str| str.to_string()) - .collect(); + let cmd_args: Vec = cmd_line.split(' ').map(|str| str.to_string()).collect(); let ret = arg_parser.parse(cmd_args); println!("{:?}", ret); assert!(ret.is_ok()); - assert_eq!(arg_parser.opt_present("f"), true); - assert_eq!(arg_parser.opt_present("q"), true); - assert_eq!(arg_parser.opt_present("h"), true); - assert_eq!(arg_parser.opt_present("help"), true); + assert!(arg_parser.opt_present("f")); + assert!(arg_parser.opt_present("q")); + assert!(arg_parser.opt_present("h")); + assert!(arg_parser.opt_present("help")); let values = arg_parser.opt_strs("o"); assert!(values.contains(&"cluster_size=512".to_string())); diff --git a/image/src/img.rs b/image/src/img.rs index 0e54d0b90c5e3c4f90e49a8012f14bdeed7c1779..bc5b58ebcad1b0ea233410fe49a344d209610145 100644 --- a/image/src/img.rs +++ b/image/src/img.rs @@ -23,12 +23,12 @@ use crate::{cmdline::ArgsParse, BINARY_NAME}; use block_backend::{ qcow2::{header::QcowHeader, InternalSnapshotOps, Qcow2Driver, SyncAioInfo}, raw::RawDriver, - BlockDriverOps, BlockProperty, CheckResult, CreateOptions, FIX_ERRORS, FIX_LEAKS, NO_FIX, - SECTOR_SIZE, + BlockAllocStatus, BlockDriverOps, BlockProperty, CheckResult, CreateOptions, ImageInfo, + FIX_ERRORS, FIX_LEAKS, NO_FIX, SECTOR_SIZE, }; use machine_manager::config::{memory_unit_conversion, DiskFormat}; use util::{ - aio::{Aio, AioEngine, WriteZeroesState}, + aio::{buffer_is_zero, Aio, AioEngine, Iovec, WriteZeroesState}, file::{lock_file, open_file, unlock_file}, }; @@ -37,10 +37,11 @@ enum SnapshotOperation { Delete, Apply, List, + Rename, } pub struct ImageFile { - file: File, + file: Arc, path: String, } @@ -58,7 +59,7 @@ impl ImageFile { })?; Ok(Self { - file, + file: Arc::new(file), path: path.to_string(), }) } @@ -86,7 +87,7 @@ impl ImageFile { detect_fmt: DiskFormat, ) -> Result { let real_fmt = match input_fmt { - Some(fmt) if fmt == DiskFormat::Raw => DiskFormat::Raw, + Some(DiskFormat::Raw) => DiskFormat::Raw, Some(fmt) => { if fmt != detect_fmt { bail!( @@ -112,20 +113,44 @@ impl Drop for ImageFile { } } +fn image_do_create(create_options: &CreateOptions, print_info: bool) -> Result<()> { + let path = create_options.path.clone(); + let file = Arc::new( + std::fs::OpenOptions::new() + .read(true) + .write(true) + .custom_flags(libc::O_CREAT | libc::O_TRUNC) + .mode(0o660) + .open(path)?, + ); + + let aio = Aio::new(Arc::new(SyncAioInfo::complete_func), AioEngine::Off, None)?; + + let mut driver: Box> = match create_options.conf.format { + DiskFormat::Raw => Box::new(RawDriver::new(file, aio, create_options.conf.clone())), + DiskFormat::Qcow2 => Box::new(Qcow2Driver::new(file, aio, create_options.conf.clone())?), + }; + let image_info = driver.as_mut().create_image(create_options)?; + + if print_info { + println!("Stratovirt-img: {}", image_info); + } + + Ok(()) +} + pub(crate) fn image_create(args: Vec) -> Result<()> { let mut create_options = CreateOptions::default(); let mut arg_parser = ArgsParse::create(vec!["h", "help"], vec!["f"], vec!["o"]); - arg_parser.parse(args.clone())?; + arg_parser.parse(args)?; if arg_parser.opt_present("h") || arg_parser.opt_present("help") { print_help(); return Ok(()); } - let mut disk_fmt = DiskFormat::Raw; - if let Some(fmt) = arg_parser.opt_str("f") { - disk_fmt = DiskFormat::from_str(&fmt)?; - }; + let fmt = arg_parser.opt_str("f").unwrap_or_else(|| "raw".to_string()); + create_options.conf.format = DiskFormat::from_str(&fmt)?; let extra_options = arg_parser.opt_strs("o"); for option in extra_options { @@ -164,29 +189,56 @@ pub(crate) fn image_create(args: Vec) -> Result<()> { } } - let path = create_options.path.clone(); - let file = std::fs::OpenOptions::new() - .read(true) - .write(true) - .custom_flags(libc::O_CREAT | libc::O_TRUNC) - .mode(0o660) - .open(path.clone())?; + image_do_create(&create_options, true)?; - let aio = Aio::new(Arc::new(SyncAioInfo::complete_func), AioEngine::Off, None)?; - let image_info = match disk_fmt { - DiskFormat::Raw => { - create_options.conf.format = DiskFormat::Raw; - let mut raw_driver = RawDriver::new(file, aio, create_options.conf.clone()); - raw_driver.create_image(&create_options)? + Ok(()) +} + +pub(crate) fn image_info(args: Vec) -> Result<()> { + if args.is_empty() { + bail!("Not enough arguments"); + } + let mut arg_parser = ArgsParse::create(vec!["h", "help"], vec![], vec![]); + arg_parser.parse(args)?; + + if arg_parser.opt_present("h") || arg_parser.opt_present("help") { + print_help(); + return Ok(()); + } + + // Parse the image path. + let len = arg_parser.free.len(); + let img_path = match len { + 0 => bail!("Image path is needed"), + 1 => arg_parser.free[0].clone(), + _ => { + let param = arg_parser.free[1].clone(); + bail!("Unexpected argument: {}", param); } + }; + + let aio = Aio::new(Arc::new(SyncAioInfo::complete_func), AioEngine::Off, None)?; + let image_file = ImageFile::create(&img_path, false)?; + let detect_fmt = image_file.detect_img_format()?; + let conf = BlockProperty { + format: detect_fmt, + ..Default::default() + }; + let mut driver: Box> = match detect_fmt { + DiskFormat::Raw => Box::new(RawDriver::new(image_file.file.clone(), aio, conf)), DiskFormat::Qcow2 => { - create_options.conf.format = DiskFormat::Qcow2; - let mut qcow2_driver = Qcow2Driver::new(file, aio, create_options.conf.clone())?; - qcow2_driver.create_image(&create_options)? + let mut qocw2_driver = Qcow2Driver::new(image_file.file.clone(), aio, conf.clone())?; + qocw2_driver.load_metadata(conf)?; + Box::new(qocw2_driver) } }; - println!("Stratovirt-img: {}", image_info); + let mut image_info = ImageInfo { + path: img_path, + ..Default::default() + }; + driver.query_image(&mut image_info)?; + print!("{}", image_info); Ok(()) } @@ -212,9 +264,9 @@ pub(crate) fn image_check(args: Vec) -> Result<()> { } if let Some(kind) = arg_parser.opt_str("r") { - if kind == "leaks".to_string() { + if kind == *"leaks" { fix |= FIX_LEAKS; - } else if kind == "all".to_string() { + } else if kind == *"all" { fix |= FIX_LEAKS; fix |= FIX_ERRORS; } else { @@ -242,14 +294,16 @@ pub(crate) fn image_check(args: Vec) -> Result<()> { let real_fmt = image_file.check_img_format(disk_fmt, detect_fmt)?; let mut check_res = CheckResult::default(); - let file = image_file.file.try_clone()?; + let file = image_file.file.clone(); match real_fmt { DiskFormat::Raw => { bail!("stratovirt-img: This image format does not support checks"); } DiskFormat::Qcow2 => { - let mut conf = BlockProperty::default(); - conf.format = DiskFormat::Qcow2; + let conf = BlockProperty { + format: DiskFormat::Qcow2, + ..Default::default() + }; let mut qcow2_driver = create_qcow2_driver_for_check(file, conf)?; let ret = qcow2_driver.check_image(&mut check_res, quite, fix); let check_message = check_res.collect_check_message(); @@ -291,14 +345,14 @@ pub(crate) fn image_resize(mut args: Vec) -> Result<()> { let image_file = ImageFile::create(&img_path, false)?; let detect_fmt = image_file.detect_img_format()?; let real_fmt = image_file.check_img_format(disk_fmt, detect_fmt)?; - - let mut conf = BlockProperty::default(); - conf.format = real_fmt; + let conf = BlockProperty { + format: real_fmt, + ..Default::default() + }; let mut driver: Box> = match real_fmt { - DiskFormat::Raw => Box::new(RawDriver::new(image_file.file.try_clone()?, aio, conf)), + DiskFormat::Raw => Box::new(RawDriver::new(image_file.file.clone(), aio, conf)), DiskFormat::Qcow2 => { - let mut qocw2_driver = - Qcow2Driver::new(image_file.file.try_clone()?, aio, conf.clone())?; + let mut qocw2_driver = Qcow2Driver::new(image_file.file.clone(), aio, conf.clone())?; qocw2_driver.load_metadata(conf)?; Box::new(qocw2_driver) } @@ -306,12 +360,12 @@ pub(crate) fn image_resize(mut args: Vec) -> Result<()> { let old_size = driver.disk_size()?; // Only expansion is supported currently. - let new_size = if size_str.starts_with("+") { + let new_size = if size_str.starts_with('+') { let size = memory_unit_conversion(&size_str, 1)?; old_size .checked_add(size) .ok_or_else(|| anyhow!("Disk size is too large for chosen offset"))? - } else if size_str.starts_with("-") { + } else if size_str.starts_with('-') { bail!("The shrink operation is not supported"); } else { let new_size = memory_unit_conversion(&size_str, 1)?; @@ -327,9 +381,256 @@ pub(crate) fn image_resize(mut args: Vec) -> Result<()> { Ok(()) } +// Default 4k(8 sectors) for sparse detection. +const DEFAULT_SPARSE_SIZE: u8 = 8; +// Default 2M buffer size for convert. +const DEFAULT_BUF_SIZE: usize = 1 << 21; + +#[derive(Debug, PartialEq, Eq)] +enum ConvertDataStatus { + Zero = 0, + Data, +} + +impl ConvertDataStatus { + fn from_bool(is_zero: bool) -> Self { + if is_zero { + ConvertDataStatus::Zero + } else { + ConvertDataStatus::Data + } + } +} + +// Describe a continuous address segment which has the same allocation status when converting. +#[derive(PartialEq, Eq)] +struct ConvertSeg { + start: u64, + len: u64, + status: ConvertDataStatus, +} + +impl ConvertSeg { + fn new(start: u64, len: u64, is_zero: bool) -> Self { + Self { + start, + len, + status: ConvertDataStatus::from_bool(is_zero), + } + } +} + +fn convert_do_read( + driver: &mut dyn BlockDriverOps<()>, + buf: &mut [u8], + len: u64, + offset: u64, +) -> Result<()> { + let iov = vec![Iovec { + iov_base: buf.as_ptr() as u64, + iov_len: len, + }]; + driver.read_vectored(iov, offset as usize, ()) +} + +fn convert_do_seg_write( + driver: &mut dyn BlockDriverOps<()>, + buf: &mut [u8], + data_seg: &ConvertSeg, +) -> Result<()> { + match data_seg.status { + ConvertDataStatus::Data => { + let iov = vec![Iovec { + iov_base: buf.as_ptr() as u64, + iov_len: data_seg.len, + }]; + driver.write_vectored(iov, data_seg.start as usize, ()) + } + ConvertDataStatus::Zero => { + // The output image is a newly created sparse file, which defaults to all holes. + // Sequential writing does not require writing zeroes again. Do nothing. + Ok(()) + } + } +} + +fn convert_do_write( + driver: &mut dyn BlockDriverOps<()>, + buf: &mut [u8], + buf_len: u64, + offset: u64, + sparse_size: u64, +) -> Result<()> { + let mut convert_seg = ConvertSeg::new(offset, 0, true); + let mut first_check = true; + + let mut pos = 0_u64; + while pos < buf_len { + let detect_size = std::cmp::min(sparse_size, buf_len - pos); + let detect_buf = &buf[pos as usize..(pos + detect_size) as usize]; + + // SAFETY: Buffer is a local `Vec` variable in `image_convert`, its base and len are both valid values. + let local_zero = unsafe { buffer_is_zero(detect_buf.as_ptr() as u64, detect_size) }; + + if first_check { + convert_seg.status = ConvertDataStatus::from_bool(local_zero); + first_check = false; + } + + if ConvertDataStatus::from_bool(local_zero) != convert_seg.status { + convert_do_seg_write( + driver, + &mut buf[(convert_seg.start - offset) as usize..pos as usize], + &convert_seg, + )?; + + convert_seg.status = ConvertDataStatus::from_bool(local_zero); + convert_seg.start = offset + pos; + convert_seg.len = detect_size; + + pos += detect_size; + continue; + } + + convert_seg.len += detect_size; + pos += detect_size; + } + + if !first_check { + convert_do_seg_write( + driver, + &mut buf[(convert_seg.start - offset) as usize..pos as usize], + &convert_seg, + )?; + } + Ok(()) +} + +fn image_create_blockdriver( + file: Arc, + format: DiskFormat, +) -> Result>> { + let conf = BlockProperty { + format, + ..Default::default() + }; + let aio = Aio::new(Arc::new(SyncAioInfo::complete_func), AioEngine::Off, None).unwrap(); + + match format { + DiskFormat::Qcow2 => { + let mut qcow2_driver = Qcow2Driver::new(file, aio, conf.clone())?; + qcow2_driver.load_metadata(conf)?; + Ok(Box::new(qcow2_driver)) + } + DiskFormat::Raw => Ok(Box::new(RawDriver::new(file, aio, conf))), + } +} + +pub(crate) fn image_convert(args: Vec) -> Result<()> { + let mut arg_parser = ArgsParse::create(vec!["h", "help"], vec!["f", "O", "S"], vec![]); + arg_parser.parse(args)?; + + if arg_parser.opt_present("h") || arg_parser.opt_present("help") { + print_help(); + return Ok(()); + } + + // Parse the image path. Command line should have 2 args at the end(input filename and output filename). + if arg_parser.free.len() != 2 { + bail!("Invalid input/output filenames."); + } + let input_path = arg_parser.free[0].clone(); + let output_path = arg_parser.free[1].clone(); + + // Check output image filename: to avoid accidentally deleting existing files, + // it is prohibited to use the name of an existing file as the output image name. + if std::path::Path::new(&output_path).exists() { + bail!( + "The file {} already exists, please use a different name.", + output_path + ); + } + + // Parse the input image format. If not set, it will detect the corresponding image file to determine the image format. + let input_fmt_str = arg_parser.opt_str("f").unwrap_or_default(); + let input_image_file = ImageFile::create(&input_path, true)?; + let detect_input_fmt = input_image_file.detect_img_format()?; + if !input_fmt_str.is_empty() { + let input_fmt = DiskFormat::from_str(&input_fmt_str)?; + if detect_input_fmt != input_fmt { + bail!("'{}' is not a {} file.", input_path, input_fmt_str); + } + } + + // Parse the output image format. Default is "raw". + let output_fmt_str = arg_parser.opt_str("O").unwrap_or_else(|| "raw".to_string()); + let output_fmt = DiskFormat::from_str(&output_fmt_str)?; + + // Parse the min sparse. + let min_sparse_str = arg_parser + .opt_str("S") + .unwrap_or_else(|| DEFAULT_SPARSE_SIZE.to_string()); + let min_sparse = min_sparse_str.parse::()? * SECTOR_SIZE; + + // Create input image driver. + let mut input_driver_box = + image_create_blockdriver(input_image_file.file.clone(), detect_input_fmt)?; + let input_driver = input_driver_box.as_mut(); + let size = input_driver.disk_size()?; + + // Create new output file. + let mut create_options = CreateOptions { + path: output_path.clone(), + img_size: size, + ..Default::default() + }; + create_options.conf.format = output_fmt; + image_do_create(&create_options, false)?; + + // Create output image driver. + let output_image_file = ImageFile::create(&output_path, false)?; + let mut output_driver_box = + image_create_blockdriver(output_image_file.file.clone(), output_fmt)?; + let output_driver = output_driver_box.as_mut(); + + // Convert from src image to dst image. + let mut offset = 0; + let mut buf = vec![0_u8; DEFAULT_BUF_SIZE]; + while offset < size { + // Get address segments with the same allocation status. + let (status, len) = input_driver.get_address_alloc_status(offset as u64, size - offset)?; + match status { + // `DATA` means that these data should be read from the src image and be written to the dst image. + BlockAllocStatus::DATA => { + let mut need_convert = len; + let mut cur_offset = offset; + while need_convert > 0 { + let data_len = std::cmp::min(DEFAULT_BUF_SIZE as u64, need_convert); + convert_do_read(input_driver, &mut buf, data_len, cur_offset)?; + convert_do_write(output_driver, &mut buf, data_len, cur_offset, min_sparse)?; + cur_offset += data_len; + need_convert -= data_len; + } + } + // `ZERO` means that these data can be treated as holes. + BlockAllocStatus::ZERO => { + // The output image is a newly created sparse file, which defaults to all holes. + // Sequential writing does not require writing zeroes again. Do nothing. + } + }; + + offset += len; + } + + Ok(()) +} + pub(crate) fn image_snapshot(args: Vec) -> Result<()> { - let mut arg_parser = - ArgsParse::create(vec!["l", "h", "help"], vec!["f", "c", "d", "a"], vec![]); + let mut arg_parser = ArgsParse::create( + vec!["l", "h", "help", "r"], + vec!["f", "c", "d", "a"], + vec![], + ); arg_parser.parse(args)?; if arg_parser.opt_present("h") || arg_parser.opt_present("help") { @@ -374,11 +675,25 @@ pub(crate) fn image_snapshot(args: Vec) -> Result<()> { snapshot_name = name; } - // Parse image path. let len = arg_parser.free.len(); + + // Rename snapshot name. + let mut old_snapshot_name = String::from(""); + let mut new_snapshot_name = String::from(""); + if arg_parser.opt_present("r") { + if len != 3 { + bail!("Invalid args number."); + } + snapshot_operation = Some(SnapshotOperation::Rename); + old_snapshot_name = arg_parser.free[0].clone(); + new_snapshot_name = arg_parser.free[1].clone(); + } + + // Parse image path. let path = match len { 0 => bail!("Image snapshot requires path"), 1 => arg_parser.free[0].clone(), + 3 => arg_parser.free[2].clone(), _ => { let param = arg_parser.free[1].clone(); bail!("Unexpected argument: {}", param); @@ -397,13 +712,15 @@ pub(crate) fn image_snapshot(args: Vec) -> Result<()> { } // Create qcow2 driver. - let mut qcow2_conf = BlockProperty::default(); - qcow2_conf.format = DiskFormat::Qcow2; - qcow2_conf.discard = true; - qcow2_conf.write_zeroes = WriteZeroesState::Unmap; + let qcow2_conf = BlockProperty { + format: DiskFormat::Qcow2, + discard: true, + write_zeroes: WriteZeroesState::Unmap, + ..Default::default() + }; let aio = Aio::new(Arc::new(SyncAioInfo::complete_func), AioEngine::Off, None).unwrap(); - let mut qcow2_driver = Qcow2Driver::new(image_file.file.try_clone()?, aio, qcow2_conf.clone())?; + let mut qcow2_driver = Qcow2Driver::new(image_file.file.clone(), aio, qcow2_conf.clone())?; qcow2_driver.load_metadata(qcow2_conf)?; match snapshot_operation { @@ -421,6 +738,9 @@ pub(crate) fn image_snapshot(args: Vec) -> Result<()> { Some(SnapshotOperation::Apply) => { qcow2_driver.apply_snapshot(snapshot_name)?; } + Some(SnapshotOperation::Rename) => { + qcow2_driver.rename_snapshot(old_snapshot_name, new_snapshot_name)?; + } None => return Ok(()), }; @@ -428,7 +748,7 @@ pub(crate) fn image_snapshot(args: Vec) -> Result<()> { } pub(crate) fn create_qcow2_driver_for_check( - file: File, + file: Arc, conf: BlockProperty, ) -> Result> { let aio = Aio::new(Arc::new(SyncAioInfo::complete_func), AioEngine::Off, None).unwrap(); @@ -474,9 +794,11 @@ Stratovirt disk image utility Command syntax: create [-f fmt] [-o options] filename [size] +info filename check [-r [leaks | all]] [-no_print_error] [-f fmt] filename resize [-f fmt] filename [+]size -snapshot [-l | -a snapshot | -c snapshot | -d snapshot] filename +snapshot [-l | -a snapshot | -c snapshot | -d snapshot | -r old_snapshot_name new_snapshot_name] filename +convert [-f input_fmt | -O output_fmt | -S sparse_size ] input_filename output_filename Command parameters: 'filename' is a disk image filename @@ -497,6 +819,14 @@ Parameters to snapshot subcommand: '-c' creates a snapshot '-d' deletes a snapshot '-l' lists all snapshots in the given image + '-r' change the name of a snapshot + +Parameters to convert subcommand: + '-f fmt' is input image format. + '-O output_fmt' is output image format. + '-S sparse_size' is the consecutive number of bytes that must contain only zeroes to create a sparse image during conversion. Unit: sector(512 bytes). Default is 8. + 'input_filename' is the name of the input file using *input_fmt* image format. + 'output_filename' is the name of the output file using *output_fmt* image format. "#, ); } @@ -515,6 +845,7 @@ mod test { }; use util::aio::Iovec; + const K: u64 = 1024; const M: u64 = 1024 * 1024; const G: u64 = 1024 * 1024 * 1024; @@ -522,7 +853,7 @@ mod test { pub header: QcowHeader, pub cluster_bits: u64, pub path: String, - pub file: File, + pub file: Arc, } impl TestQcow2Image { @@ -533,11 +864,8 @@ mod test { "-f qcow2 -o cluster_size={} -o refcount_bits={} {} {}", cluster_size, refcount_bits, path, img_size, ); - let create_args: Vec = create_str - .split(' ') - .into_iter() - .map(|str| str.to_string()) - .collect(); + let create_args: Vec = + create_str.split(' ').map(|str| str.to_string()).collect(); assert!(image_create(create_args).is_ok()); // Read header. @@ -545,32 +873,35 @@ mod test { let mut buf = vec![0; QcowHeader::len()]; assert!(file.read_at(&mut buf, 0).is_ok()); let header = QcowHeader::from_vec(&buf).unwrap(); - assert_eq!(header.cluster_bits as u64, cluster_bits); + assert_eq!(u64::from(header.cluster_bits), cluster_bits); Self { header, cluster_bits, path: path.to_string(), - file, + file: Arc::new(file), } } fn create_driver(&self) -> Qcow2Driver<()> { - let mut conf = BlockProperty::default(); - conf.format = DiskFormat::Qcow2; + let conf = BlockProperty { + format: DiskFormat::Qcow2, + ..Default::default() + }; let aio = Aio::new(Arc::new(SyncAioInfo::complete_func), AioEngine::Off, None).unwrap(); - let mut qcow2_driver = - Qcow2Driver::new(self.file.try_clone().unwrap(), aio, conf.clone()).unwrap(); + let mut qcow2_driver = Qcow2Driver::new(self.file.clone(), aio, conf.clone()).unwrap(); qcow2_driver.load_metadata(conf).unwrap(); qcow2_driver } fn create_driver_for_check(&self) -> Qcow2Driver<()> { - let file = self.file.try_clone().unwrap(); - let mut conf = BlockProperty::default(); - conf.format = DiskFormat::Qcow2; - let qcow2_driver = create_qcow2_driver_for_check(file, conf).unwrap(); - qcow2_driver + let file = self.file.clone(); + let conf = BlockProperty { + format: DiskFormat::Qcow2, + ..Default::default() + }; + + create_qcow2_driver_for_check(file, conf).unwrap() } fn read_data(&self, guest_offset: u64, buf: &Vec) -> Result<()> { @@ -610,7 +941,7 @@ mod test { } fn file_len(&mut self) -> u64 { - let file_len = self.file.seek(SeekFrom::End(0)).unwrap(); + let file_len = self.file.as_ref().seek(SeekFrom::End(0)).unwrap(); file_len } @@ -636,25 +967,21 @@ mod test { impl TestRawImage { fn create(path: String, img_size: String) -> Self { let create_str = format!("-f raw {} {}", path, img_size); - let create_args: Vec = create_str - .split(' ') - .into_iter() - .map(|str| str.to_string()) - .collect(); + let create_args: Vec = + create_str.split(' ').map(|str| str.to_string()).collect(); assert!(image_create(create_args).is_ok()); Self { path } } fn create_driver(&mut self) -> RawDriver<()> { - let mut conf = BlockProperty::default(); - conf.format = DiskFormat::Raw; + let conf = BlockProperty::default(); let aio = Aio::new(Arc::new(SyncAioInfo::complete_func), AioEngine::Off, None).unwrap(); let file = open_file(&self.path, false, false).unwrap(); - let raw_driver = RawDriver::new(file, aio, conf); - raw_driver + + RawDriver::new(Arc::new(file), aio, conf) } } @@ -744,11 +1071,8 @@ mod test { for case in test_case { let create_str = case.0.replace("img_path", path); println!("Create options: {}", create_str); - let create_args: Vec = create_str - .split(' ') - .into_iter() - .map(|str| str.to_string()) - .collect(); + let create_args: Vec = + create_str.split(' ').map(|str| str.to_string()).collect(); if case.1 { assert!(image_create(create_args).is_ok()); @@ -760,6 +1084,51 @@ mod test { assert!(remove_file(path).is_ok()); } + /// Test the function of query image. + /// TestStep: + /// 2. Query image info with different type. + /// Expect: + /// 1. Ihe invalid args will result in failure. + #[test] + fn test_args_parse_of_image_info() { + let path = "/tmp/test_args_parse_of_image_info.qcow2"; + let test_case = vec![ + ("img_path", true), + ("-f qcow2", false), + ("invalid_args", false), + ("img_path +1G", false), + ("-h", true), + ("--help", true), + ]; + + for case in test_case { + let cmd_str = case.0.replace("img_path", path); + let args: Vec = cmd_str.split(' ').map(|str| str.to_string()).collect(); + + // Query image info with type of qcow2. + assert!(image_create(vec![ + "-f".to_string(), + "qcow2".to_string(), + path.to_string(), + "+10M".to_string() + ]) + .is_ok()); + assert_eq!(image_info(args.clone()).is_ok(), case.1); + + // Query image info with type of raw. + assert!(image_create(vec![ + "-f".to_string(), + "raw".to_string(), + path.to_string(), + "+10M".to_string() + ]) + .is_ok()); + assert_eq!(image_info(args).is_ok(), case.1); + } + + assert!(remove_file(path).is_ok()); + } + /// Test the function of creating image. /// TestStep: /// 1. Create image with different cluster bits, image size and refcount bits. @@ -791,18 +1160,18 @@ mod test { let file_len = test_image.file_len(); let l1_size = test_image.header.l1_size; let reftable_clusters = test_image.header.refcount_table_clusters; - let reftable_size = reftable_clusters as u64 * cluster_size / ENTRY_SIZE; + let reftable_size = u64::from(reftable_clusters) * cluster_size / ENTRY_SIZE; let refblock_size = cluster_size / (refcount_bits / 8); assert_ne!(l1_size, 0); assert_ne!(reftable_clusters, 0); - assert!(l1_size as u64 * cluster_size * cluster_size / ENTRY_SIZE >= image_size); + assert!(u64::from(l1_size) * cluster_size * cluster_size / ENTRY_SIZE >= image_size); assert!(reftable_size * refblock_size * cluster_size >= file_len); - assert_eq!(test_image.header.cluster_bits as u64, cluster_bits); + assert_eq!(u64::from(test_image.header.cluster_bits), cluster_bits); assert_eq!(test_image.header.size, image_size); // Check refcount. - assert_eq!(test_image.check_image(false, 0), true); + assert!(test_image.check_image(false, 0)); } } @@ -822,20 +1191,13 @@ mod test { ("-f qcow2 path +1G", DiskFormat::Qcow2), ]; let check_str = format!("-f raw {}", path); - let check_args: Vec = check_str - .split(' ') - .into_iter() - .map(|str| str.to_string()) - .collect(); + let check_args: Vec = check_str.split(' ').map(|str| str.to_string()).collect(); for case in test_case { let create_str = case.0.replace("path", path); println!("stratovirt-img {}", create_str); - let create_args: Vec = create_str - .split(' ') - .into_iter() - .map(|str| str.to_string()) - .collect(); + let create_args: Vec = + create_str.split(' ').map(|str| str.to_string()).collect(); assert!(image_create(create_args).is_ok()); let image_file = ImageFile::create(path, false).unwrap(); @@ -875,18 +1237,13 @@ mod test { let create_string = create_str.replace("disk_fmt", case.0); let create_args: Vec = create_string .split(' ') - .into_iter() .map(|str| str.to_string()) .collect(); println!("Create args: {}", create_string); assert!(image_create(create_args.clone()).is_ok()); let check_str = case.1.replace("img_path", path); - let check_args: Vec = check_str - .split(' ') - .into_iter() - .map(|str| str.to_string()) - .collect(); + let check_args: Vec = check_str.split(' ').map(|str| str.to_string()).collect(); println!("Check args: {}", check_str); if case.2 { @@ -1148,7 +1505,7 @@ mod test { let l2_idx = qcow2_driver.table.get_l2_table_index(guest_offset) as usize; let cache_entry = qcow2_driver.get_table_cluster(guest_offset).unwrap(); let mut l2_entry = cache_entry.borrow_mut().get_entry_map(l2_idx).unwrap(); - l2_entry = l2_entry & !QCOW2_OFFSET_COPIED; + l2_entry &= !QCOW2_OFFSET_COPIED; assert!(cache_entry .borrow_mut() .set_entry_map(l2_idx, l2_entry) @@ -1488,9 +1845,15 @@ mod test { let test_case = [ ("qcow2", "-c snapshot0 img_path", true), ("qcow2", "-f qcow2 -l img_path", true), + ("qcow2", "-r old_snapshot_name img_path", false), ("qcow2", "-d snapshot0 img_path", false), ("qcow2", "-a snapshot0 img_path", false), ("qcow2", "-c snapshot0 -l img_path", false), + ( + "raw", + "-r old_snapshot_name new_snapshot_name img_path", + false, + ), ("raw", "-f qcow2 -l img_path", false), ("raw", "-l img_path", false), ]; @@ -1499,18 +1862,14 @@ mod test { let create_string = create_str.replace("disk_fmt", case.0); let create_args: Vec = create_string .split(' ') - .into_iter() .map(|str| str.to_string()) .collect(); println!("Create args: {}", create_string); assert!(image_create(create_args).is_ok()); let snapshot_str = case.1.replace("img_path", path); - let snapshot_args: Vec = snapshot_str - .split(' ') - .into_iter() - .map(|str| str.to_string()) - .collect(); + let snapshot_args: Vec = + snapshot_str.split(' ').map(|str| str.to_string()).collect(); let ret = image_snapshot(snapshot_args); if case.2 { assert!(ret.is_ok()); @@ -1542,10 +1901,10 @@ mod test { let quite = false; let fix = FIX_ERRORS | FIX_LEAKS; - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); let buf = vec![1_u8; cluster_size as usize]; assert!(test_image.write_data(0, &buf).is_ok()); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); // Create a snapshot named test_snapshot0 assert!(image_snapshot(vec![ @@ -1555,10 +1914,10 @@ mod test { ]) .is_ok()); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); let buf = vec![2_u8; cluster_size as usize]; assert!(test_image.write_data(0, &buf).is_ok()); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); // Create as snapshot named test_snapshot1. assert!(image_snapshot(vec![ @@ -1568,10 +1927,10 @@ mod test { ]) .is_ok()); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); let buf = vec![3_u8; cluster_size as usize]; assert!(test_image.write_data(0, &buf).is_ok()); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); // Apply snapshot named test_snapshot0. assert!(image_snapshot(vec![ @@ -1581,7 +1940,7 @@ mod test { ]) .is_ok()); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); let buf = vec![0_u8; cluster_size as usize]; assert!(test_image.read_data(0, &buf).is_ok()); for elem in buf { @@ -1589,7 +1948,7 @@ mod test { } let buf = vec![4_u8; cluster_size as usize]; assert!(test_image.write_data(0, &buf).is_ok()); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); // Apply snapshot named test_snapshot1 assert!(image_snapshot(vec![ @@ -1598,7 +1957,127 @@ mod test { path.to_string() ]) .is_ok()); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); + let buf = vec![0_u8; cluster_size as usize]; + assert!(test_image.read_data(0, &buf).is_ok()); + for elem in buf { + assert_eq!(elem, 2); + } + } + + /// Test the function of snapshot rename. + /// + /// TestStep: + /// 1. Create a new image. alloc a new cluster and write 1. + /// 2. Create snapshot named test_snapshot0, write 2 to the cluster. + /// 3. Create snapshot named test_snapshot1, write 3 to the cluster. + /// 4. Rename test_snapshot0 to test_snapshot0-new. + /// 5. Apply snapshot named test_snapshot0. + /// 6. Apply snapshot named test_snapshot0-new. + /// Expect: + /// 1. step 5 is failure and step 1/2/3/4/6 is success. + /// 2. The data read after snapshot apply is 2. + #[test] + fn test_snapshot_rename_basic() { + let path = "/tmp/test_snapshot_rename_basic.qcow2"; + let cluster_bits = 16; + let cluster_size = 1 << cluster_bits; + let refcount_bits = 16; + + // Create a new image. alloc a new cluster and write 1. + let test_image = TestQcow2Image::create(cluster_bits, refcount_bits, path, "+1G"); + let buf = vec![1_u8; cluster_size as usize]; + assert!(test_image.write_data(0, &buf).is_ok()); + + // Create snapshot named test_snapshot0, write 2 to the cluster. + assert!(image_snapshot(vec![ + "-c".to_string(), + "test_snapshot0".to_string(), + path.to_string() + ]) + .is_ok()); + let buf = vec![2_u8; cluster_size as usize]; + assert!(test_image.write_data(0, &buf).is_ok()); + + // Create snapshot named test_snapshot1, write 3 to the cluster. + assert!(image_snapshot(vec![ + "-c".to_string(), + "test_snapshot1".to_string(), + path.to_string() + ]) + .is_ok()); + let buf = vec![3_u8; cluster_size as usize]; + assert!(test_image.write_data(0, &buf).is_ok()); + + // Rename test_snapshot0 to test_snapshot1. + assert!(image_snapshot(vec![ + "-r".to_string(), + "test_snapshot0".to_string(), + "test_snapshot1".to_string(), + path.to_string() + ]) + .is_err()); + + // Rename test_snapshot0 to test_snapshot0-new. + assert!(image_snapshot(vec![ + "-r".to_string(), + "test_snapshot0".to_string(), + "test_snapshot0-new".to_string(), + path.to_string() + ]) + .is_ok()); + + // Apply snapshot named test_snapshot0. + assert!(image_snapshot(vec![ + "-a".to_string(), + "test_snapshot0".to_string(), + path.to_string() + ]) + .is_err()); + + // Apply snapshot named test_snapshot-new. + assert!(image_snapshot(vec![ + "-a".to_string(), + "test_snapshot0-new".to_string(), + path.to_string() + ]) + .is_ok()); + + // The data read after snapshot apply is 2. + let buf = vec![0_u8; cluster_size as usize]; + assert!(test_image.read_data(0, &buf).is_ok()); + for elem in buf { + assert_eq!(elem, 1); + } + + // Rename non-existed snapshot name. + assert!(image_snapshot(vec![ + "-r".to_string(), + "test_snapshot11111".to_string(), + "test_snapshot11111-new".to_string(), + path.to_string() + ]) + .is_err()); + + let buf = vec![4_u8; cluster_size as usize]; + assert!(test_image.write_data(0, &buf).is_ok()); + + // Rename test_snapshot1 to test_snapshot123. + assert!(image_snapshot(vec![ + "-r".to_string(), + "test_snapshot1".to_string(), + "test_snapshot123".to_string(), + path.to_string() + ]) + .is_ok()); + + // Apply snapshot named test_snapshot123 + assert!(image_snapshot(vec![ + "-a".to_string(), + "test_snapshot123".to_string(), + path.to_string() + ]) + .is_ok()); let buf = vec![0_u8; cluster_size as usize]; assert!(test_image.read_data(0, &buf).is_ok()); for elem in buf { @@ -1648,14 +2127,10 @@ mod test { // Apply resize operation. let cmd = cmd.replace("img_path", path); - let args: Vec = cmd - .split(' ') - .into_iter() - .map(|str| str.to_string()) - .collect(); + let args: Vec = cmd.split(' ').map(|str| str.to_string()).collect(); assert_eq!(image_resize(args).is_ok(), res); - assert!(remove_file(path.to_string()).is_ok()); + assert!(remove_file(path).is_ok()); } } @@ -1682,7 +2157,7 @@ mod test { let buf = vec![1; 10240]; let mut driver = test_image.create_driver(); while offset < 10 * M as usize { - assert!(image_write(&mut driver, offset as usize, &buf).is_ok()); + assert!(image_write(&mut driver, offset, &buf).is_ok()); offset += 10240; } drop(driver); @@ -1700,7 +2175,7 @@ mod test { let buf = vec![2; 10240]; let mut driver = test_image.create_driver(); while offset < (10 + 10) * M as usize { - assert!(image_write(&mut driver, offset as usize, &buf).is_ok()); + assert!(image_write(&mut driver, offset, &buf).is_ok()); offset += 10240; } @@ -1797,7 +2272,7 @@ mod test { let mut driver = test_image.create_driver(); let buf = vec![1; 1024 * 1024]; assert!(image_write(&mut driver, 0, &buf).is_ok()); - assert_eq!(driver.header.size, 1 * G); + assert_eq!(driver.header.size, G); drop(driver); let quite = false; let fix = FIX_ERRORS | FIX_LEAKS; @@ -1808,13 +2283,13 @@ mod test { path.to_string() ]) .is_ok()); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); let mut driver = test_image.create_driver(); let buf = vec![2; 1024 * 1024]; assert!(image_write(&mut driver, 0, &buf).is_ok()); - assert_eq!(driver.header.size, 1 * G); + assert_eq!(driver.header.size, G); drop(driver); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); assert!(image_resize(vec![ "-f".to_string(), @@ -1823,7 +2298,7 @@ mod test { "+20G".to_string(), ]) .is_ok()); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); let mut driver = test_image.create_driver(); let buf = vec![3; 1024 * 1024]; assert!(image_write(&mut driver, 20 * G as usize, &buf).is_ok()); @@ -1832,7 +2307,7 @@ mod test { assert!(vec_is_fill_with(&buf, 2)); assert_eq!(driver.header.size, 21 * G); drop(driver); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); assert!(image_snapshot(vec![ "-c".to_string(), @@ -1840,13 +2315,13 @@ mod test { path.to_string() ]) .is_ok()); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); let mut driver = test_image.create_driver(); let buf = vec![4; 1024 * 1024]; assert!(image_write(&mut driver, 20 * G as usize, &buf).is_ok()); assert_eq!(driver.header.size, 21 * G); drop(driver); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); assert!(image_resize(vec![ "-f".to_string(), @@ -1855,7 +2330,7 @@ mod test { "+10G".to_string(), ]) .is_ok()); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); let mut driver = test_image.create_driver(); let buf = vec![5; 1024 * 1024]; assert!(image_write(&mut driver, 30 * G as usize, &buf).is_ok()); @@ -1864,7 +2339,7 @@ mod test { assert!(vec_is_fill_with(&buf, 4)); assert_eq!(driver.header.size, 31 * G); drop(driver); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); assert!(image_snapshot(vec![ "-a".to_string(), @@ -1872,14 +2347,14 @@ mod test { path.to_string() ]) .is_ok()); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); let mut driver = test_image.create_driver(); let buf = vec![0; 1024 * 1024]; assert!(image_read(&mut driver, 0, &buf).is_ok()); assert!(vec_is_fill_with(&buf, 1)); - assert_eq!(driver.header.size, 1 * G); + assert_eq!(driver.header.size, G); drop(driver); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); assert!(image_snapshot(vec![ "-a".to_string(), @@ -1887,13 +2362,189 @@ mod test { path.to_string() ]) .is_ok()); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); let mut driver = test_image.create_driver(); let buf = vec![0; 1024 * 1024]; assert!(image_read(&mut driver, 20 * G as usize, &buf).is_ok()); assert!(vec_is_fill_with(&buf, 3)); assert_eq!(driver.header.size, 21 * G); drop(driver); - assert_eq!(test_image.check_image(quite, fix), true); + assert!(test_image.check_image(quite, fix)); + } + + /// Test image convert from qcow2 to raw. + /// + /// TestStep: + /// 1. Create a qcow2 image with size of 10G. + /// 2. Write data to this image in different position. + /// 3. Convert this qcow2 to raw. + /// 4. Read data in the same position. + #[test] + fn test_image_convert_from_qcow2_to_raw() { + let src_path = "/tmp/test_image_convert_src.qcow2"; + let dst_path = "/tmp/test_image_convert_dst.raw"; + let _ = remove_file(src_path.clone()); + let _ = remove_file(dst_path.clone()); + + let test_image = TestQcow2Image::create(16, 16, src_path, "10G"); + let mut src_driver = test_image.create_driver(); + + // Write 1M data(number 1) in offset 0. + let buf1 = vec![1_u8; 1 * M as usize]; + assert!(image_write(&mut src_driver, 0, &buf1).is_ok()); + // Write 1M data(number 2) in offset 5G. + let buf2 = vec![2_u8; 1 * M as usize]; + assert!(image_write(&mut src_driver, 5 * G as usize, &buf2).is_ok()); + // Write 1M data(number 3) in last 1M. + let buf3 = vec![3_u8; 1 * M as usize]; + assert!(image_write(&mut src_driver, 10 * G as usize - 1 * M as usize, &buf3).is_ok()); + // Write 1M data(number 0) in random offset (eg: 300M offset). + let buf4 = vec![0_u8; 1 * M as usize]; + assert!(image_write(&mut src_driver, 300 * M as usize, &buf4).is_ok()); + + drop(src_driver); + + assert!(image_convert(vec![ + "-f".to_string(), + "qcow2".to_string(), + "-O".to_string(), + "raw".to_string(), + src_path.to_string(), + dst_path.to_string() + ]) + .is_ok()); + + // Open the converted raw image. + let conf = BlockProperty::default(); + let aio = Aio::new(Arc::new(SyncAioInfo::complete_func), AioEngine::Off, None).unwrap(); + let file = open_file(dst_path, true, false).unwrap(); + let mut dst_driver = RawDriver::new(Arc::new(file), aio, conf); + + // Read 1M data in offset 0. + let buf = vec![0; 1 * M as usize]; + assert!(image_read(&mut dst_driver, 0, &buf).is_ok()); + assert_eq!(buf, buf1); + // Read 1M data in offset 5G. + assert!(image_read(&mut dst_driver, 5 * G as usize, &buf).is_ok()); + assert_eq!(buf, buf2); + // Read 1M data in last 1M. + assert!(image_read(&mut dst_driver, 10 * G as usize - 1 * M as usize, &buf).is_ok()); + assert_eq!(buf, buf3); + + let mut img_info = ImageInfo::default(); + assert!(dst_driver.query_image(&mut img_info).is_ok()); + assert_eq!(img_info.virtual_size, 10 * G); + // 1M data(number 0) in offset 300M will not consume space. + assert_eq!(img_info.actual_size, 3 * M); + + // Clean. + assert!(remove_file(dst_path.clone()).is_ok()); + } + + /// Test image convert parameters parsing. + /// + /// TestStep: + /// 1. Create a qcow2 image with size of 1G. + /// 2. Write two continuous `4k 0 buffer + 4k 1 buffer` to this image. + /// 3. Test default parameters. + /// 4. Test existed destination file. + /// 5. Test min sparse. + #[test] + fn test_image_convert_parameters_parsing() { + let src_path = "/tmp/test_image_convert_paring_src.qcow2"; + let dst_path = "/tmp/test_image_convert_paring_dst.raw"; + let _ = remove_file(src_path.clone()); + let _ = remove_file(dst_path.clone()); + + let test_image = TestQcow2Image::create(16, 16, src_path, "1G"); + let mut src_driver = test_image.create_driver(); + + // Write 4k data(number 0) in offset 16k. + assert!(image_write( + &mut src_driver, + 16 * K as usize, + &vec![0_u8; 4 * K as usize] + ) + .is_ok()); + // Write 4k data(number 1) in offset 20k. + assert!(image_write( + &mut src_driver, + 20 * K as usize, + &vec![1_u8; 4 * K as usize] + ) + .is_ok()); + // Write 4k data(number 0) in offset 24k. + assert!(image_write( + &mut src_driver, + 24 * K as usize, + &vec![0_u8; 4 * K as usize] + ) + .is_ok()); + // Write 4k data(number 1) in offset 28k. + assert!(image_write( + &mut src_driver, + 28 * K as usize, + &vec![1_u8; 4 * K as usize] + ) + .is_ok()); + + drop(src_driver); + + // test1: The default value of the parameters. + // `stratovirt-img convert src_path dst_path` + // Eq: `stratovirt-img convert -f qcow2 -O raw -S 8 src_path dst_path` + assert!(image_convert(vec![src_path.to_string(), dst_path.to_string()]).is_ok()); + // Check output format. + let output_image_file = ImageFile::create(&dst_path, true).unwrap(); + let detect_output_fmt = output_image_file.detect_img_format().unwrap(); + assert_eq!(detect_output_fmt, DiskFormat::Raw); + drop(output_image_file); + // Check sparse size by querying output file size. + let conf = BlockProperty::default(); + let aio = Aio::new(Arc::new(SyncAioInfo::complete_func), AioEngine::Off, None).unwrap(); + let file = open_file(dst_path, true, false).unwrap(); + let mut img_info = ImageInfo::default(); + let mut dst_driver = RawDriver::new(Arc::new(file), aio, conf); + assert!(dst_driver.query_image(&mut img_info).is_ok()); + // Raw file has allocated filled first part (sized 4k(host page size), see function `alloc_first_block`) for + // detecting the alignment length. + assert_eq!(img_info.actual_size, 4 * K + 8 * K); // 4K allocated filled first part + 8K data. + drop(dst_driver); + assert!(remove_file(dst_path.clone()).is_ok()); + + // test2: The destination file already exists. + let existed_raw = TestRawImage::create(dst_path.to_string(), "1G".to_string()); + assert!(image_convert(vec![ + "-f".to_string(), + "qcow2".to_string(), + "-O".to_string(), + "raw".to_string(), + src_path.to_string(), + dst_path.to_string() + ]) + .is_err()); + drop(existed_raw); + + // test3: Sparse test. + // stratovirt-img convert -f qcow2 -O raw -S 16 src_path dst_path + assert!(image_convert(vec![ + "-S".to_string(), + "16".to_string(), // 16 sectors.(8K) + src_path.to_string(), + dst_path.to_string() + ]) + .is_ok()); + // Query the image size. + let conf = BlockProperty::default(); + let aio = Aio::new(Arc::new(SyncAioInfo::complete_func), AioEngine::Off, None).unwrap(); + let file = open_file(dst_path, true, false).unwrap(); + let mut img_info = ImageInfo::default(); + let mut dst_driver = RawDriver::new(Arc::new(file), aio, conf); + assert!(dst_driver.query_image(&mut img_info).is_ok()); + // min_sparse is 16k. So, these continuous `4K 0 buffer + 4K 1 buffer` will be considered as all data buffer sized 8k. + // Will not create holes here. + assert_eq!(img_info.actual_size, 4 * K + 16 * K); + drop(dst_driver); + assert!(remove_file(dst_path.clone()).is_ok()); } } diff --git a/image/src/main.rs b/image/src/main.rs index 055ac211fb5d7df130ebb096e62244d7bddce4eb..341c0f0e862733352c265f89f1fc6f49c5d4c4b2 100644 --- a/image/src/main.rs +++ b/image/src/main.rs @@ -21,7 +21,8 @@ use std::{ use anyhow::{bail, Result}; use crate::img::{ - image_check, image_create, image_resize, image_snapshot, print_help, print_version, + image_check, image_convert, image_create, image_info, image_resize, image_snapshot, print_help, + print_version, }; const BINARY_NAME: &str = "stratovirt-img"; @@ -83,8 +84,10 @@ fn run(args: Vec) -> Result<()> { image_operation_matches!( opt.as_str(); ("create", image_create, cmd_args), + ("info", image_info, cmd_args), ("check", image_check, cmd_args), ("resize", image_resize, cmd_args), + ("convert", image_convert, cmd_args), ("snapshot", image_snapshot, cmd_args); ("-v" | "--version", print_version), ("-h" | "--help", print_help) diff --git a/license/Third_Party_Open_Source_Software_Notice.md b/license/Third_Party_Open_Source_Software_Notice.md index 71a4f7778231758128488e3db65c97ade14ccff6..721534eafa4b0ec554f308658a69a1bd67bcbe52 100644 --- a/license/Third_Party_Open_Source_Software_Notice.md +++ b/license/Third_Party_Open_Source_Software_Notice.md @@ -308,7 +308,7 @@ Copyright (c) David Tolnay License: MIT or Apache License Version 2.0 Please see above. -Software: vmm-sys-util 0.11.1 +Software: vmm-sys-util 0.12.1 Copyright notice: Copyright 2019 Intel Corporation. All Rights Reserved. Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. @@ -318,7 +318,7 @@ Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. Copyright 2018 The Chromium OS Authors. All rights reserved. Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. -License: Apache License Version 2.0 or BSD 3-Clause +License: BSD 3-Clause Please see above. Software: libusb1-sys 0.6.4 @@ -356,7 +356,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -Software: kvm-ioctls 0.13.0 +Software: kvm-ioctls 0.16.0 Copyright notice: Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. Portions Copyright 2017 The Chromium OS Authors. All rights reserved. @@ -365,7 +365,7 @@ Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. License: MIT or Apache License Version 2.0 Please see above. -Software: kvm-bindings 0.6.0 +Software: kvm-bindings 0.7.0 Copyright notice: Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. License: The APACHE 2.0 License diff --git a/machine/Cargo.toml b/machine/Cargo.toml index 2f677f8c1a9884124a3dd9ae06be4c3405531133..ecead47e132177d1497ce55cbaa995176b783779 100644 --- a/machine/Cargo.toml +++ b/machine/Cargo.toml @@ -10,7 +10,7 @@ description = "Emulation machines" log = "0.4" libc = "0.2" serde_json = "1.0" -vmm-sys-util = "0.11.1" +vmm-sys-util = "0.12.1" thiserror = "1.0" anyhow = "1.0" acpi = { path = "../acpi" } @@ -24,7 +24,7 @@ migration = { path = "../migration" } migration_derive = { path = "../migration/migration_derive" } util = { path = "../util" } virtio = { path = "../virtio" } -vfio = { path = "../vfio" } +vfio = { path = "../vfio" , optional = true } block_backend = { path = "../block_backend" } ui = { path = "../ui" } trace = { path = "../trace" } @@ -32,7 +32,7 @@ clap = { version = "=4.1.4", default-features = false, features = ["std", "deriv [features] default = [] -boot_time = ["cpu/boot_time"] +boot_time = ["cpu/boot_time", "hypervisor/boot_time"] scream = ["devices/scream", "machine_manager/scream"] scream_alsa = ["scream", "devices/scream_alsa", "machine_manager/scream_alsa"] scream_pulseaudio = ["scream", "devices/scream_pulseaudio","machine_manager/scream_pulseaudio"] @@ -50,3 +50,14 @@ vnc_auth = ["vnc"] ohui_srv = ["windows_emu_pid", "ui/ohui_srv", "machine_manager/ohui_srv", "virtio/ohui_srv"] ramfb = ["devices/ramfb", "machine_manager/ramfb"] virtio_gpu = ["virtio/virtio_gpu", "machine_manager/virtio_gpu"] +vfio_device = ["vfio", "hypervisor/vfio_device"] +usb_uas = ["devices/usb_uas"] +virtio_rng = ["virtio/virtio_rng"] +virtio_scsi = ["virtio/virtio_scsi"] +vhost_vsock = ["virtio/vhost_vsock"] +vhostuser_block = ["virtio/vhostuser_block"] +vhostuser_net = ["virtio/vhostuser_net"] +vhost_net = ["virtio/vhost_net"] +trace_to_logger = ["devices/trace_to_logger"] +trace_to_ftrace = ["devices/trace_to_ftrace"] +trace_to_hitrace = ["devices/trace_to_hitrace"] diff --git a/machine/src/aarch64/fdt.rs b/machine/src/aarch64/fdt.rs index 76815b1700c9f40f098c14fee8e4887e7c34b601..f015ad1ff8280a45ccce23119fc2c9e1183d869a 100644 --- a/machine/src/aarch64/fdt.rs +++ b/machine/src/aarch64/fdt.rs @@ -14,7 +14,8 @@ use anyhow::Result; use crate::MachineBase; use cpu::PMU_INTR; -use devices::sysbus::{SysBusDevType, SysRes}; +use devices::sysbus::{to_sysbusdevops, SysBusDevType, SysRes}; +use devices::{Bus, SYS_BUS_DEVICE}; use util::device_tree::{self, FdtBuilder}; /// Function that helps to generate arm pmu in device-tree. @@ -265,23 +266,24 @@ impl CompileFDTHelper for MachineBase { fdt.set_property_string("method", "hvc")?; fdt.end_node(psci_node_dep)?; - for dev in self.sysbus.devices.iter() { - let locked_dev = dev.lock().unwrap(); - match locked_dev.sysbusdev_base().dev_type { + let devices = self.sysbus.lock().unwrap().child_devices(); + for dev in devices.values() { + SYS_BUS_DEVICE!(dev, locked_dev, sysbusdev); + match sysbusdev.sysbusdev_base().dev_type { SysBusDevType::PL011 => { - generate_serial_device_node(fdt, &locked_dev.sysbusdev_base().res)? + generate_serial_device_node(fdt, &sysbusdev.sysbusdev_base().res)? } SysBusDevType::Rtc => { - generate_rtc_device_node(fdt, &locked_dev.sysbusdev_base().res)? + generate_rtc_device_node(fdt, &sysbusdev.sysbusdev_base().res)? } SysBusDevType::VirtioMmio => { - generate_virtio_devices_node(fdt, &locked_dev.sysbusdev_base().res)? + generate_virtio_devices_node(fdt, &sysbusdev.sysbusdev_base().res)? } SysBusDevType::FwCfg => { - generate_fwcfg_device_node(fdt, &locked_dev.sysbusdev_base().res)?; + generate_fwcfg_device_node(fdt, &sysbusdev.sysbusdev_base().res)?; } SysBusDevType::Flash => { - generate_flash_device_node(fdt, &locked_dev.sysbusdev_base().res)?; + generate_flash_device_node(fdt, &sysbusdev.sysbusdev_base().res)?; } _ => (), } @@ -309,7 +311,7 @@ impl CompileFDTHelper for MachineBase { let dist: u32 = if id as u32 == *i { 10 } else if let Some(distance) = distances.get(i) { - *distance as u32 + u32::from(*distance) } else { 20 }; diff --git a/machine/src/aarch64/micro.rs b/machine/src/aarch64/micro.rs index d7e1b1d4c0dd936e66f6fa8351d10fcea1323f02..b79f13b3e99f97c652fef1367fceaf2d47ff3dde 100644 --- a/machine/src/aarch64/micro.rs +++ b/machine/src/aarch64/micro.rs @@ -15,18 +15,18 @@ use std::sync::{Arc, Mutex}; use anyhow::{bail, Context, Result}; use crate::{micro_common::syscall::syscall_whitelist, MachineBase, MachineError}; -use crate::{LightMachine, MachineOps}; -use address_space::{AddressSpace, GuestAddress, Region}; +use crate::{register_shutdown_event, LightMachine, MachineOps}; +use address_space::{AddressAttr, AddressSpace, GuestAddress, Region}; use cpu::CPUTopology; -use devices::{legacy::PL031, ICGICConfig, ICGICv2Config, ICGICv3Config, GIC_IRQ_MAX}; +use devices::legacy::{PL011, PL031}; +use devices::{Device, ICGICConfig, ICGICv2Config, ICGICv3Config, GIC_IRQ_MAX}; use hypervisor::kvm::aarch64::*; -use machine_manager::config::{SerialConfig, VmConfig}; +use machine_manager::config::{MigrateMode, Param, SerialConfig, VmConfig}; use migration::{MigrationManager, MigrationStatus}; -use util::{ - device_tree::{self, CompileFDT, FdtBuilder}, - seccomp::{BpfRule, SeccompCmpOpt}, -}; -use virtio::VirtioMmioDevice; +use util::device_tree::{self, CompileFDT, FdtBuilder}; +use util::gen_base_func; +use util::seccomp::{BpfRule, SeccompCmpOpt}; +use virtio::{VirtioDevice, VirtioMmioDevice}; #[repr(usize)] pub enum LayoutEntryType { @@ -54,13 +54,7 @@ pub const MEM_LAYOUT: &[(u64, u64)] = &[ ]; impl MachineOps for LightMachine { - fn machine_base(&self) -> &MachineBase { - &self.base - } - - fn machine_base_mut(&mut self) -> &mut MachineBase { - &mut self.base - } + gen_base_func!(machine_base, machine_base_mut, MachineBase, base); fn init_machine_ram(&self, sys_mem: &Arc, mem_size: u64) -> Result<()> { let vm_ram = self.get_vm_ram(); @@ -107,41 +101,40 @@ impl MachineOps for LightMachine { self.base.irq_chip.as_ref().unwrap().realize()?; let irq_manager = locked_hypervisor.create_irq_manager()?; - self.base.sysbus.irq_manager = irq_manager.line_irq_manager; + self.base.sysbus.lock().unwrap().irq_manager = irq_manager.line_irq_manager; Ok(()) } fn add_rtc_device(&mut self) -> Result<()> { - PL031::realize( - PL031::default(), - &mut self.base.sysbus, + let pl031 = PL031::new( + &self.base.sysbus, MEM_LAYOUT[LayoutEntryType::Rtc as usize].0, MEM_LAYOUT[LayoutEntryType::Rtc as usize].1, - ) - .with_context(|| "Failed to realize pl031.") + )?; + pl031 + .realize() + .with_context(|| "Failed to realize pl031.")?; + Ok(()) } fn add_serial_device(&mut self, config: &SerialConfig) -> Result<()> { - use devices::legacy::PL011; - let region_base: u64 = MEM_LAYOUT[LayoutEntryType::Uart as usize].0; let region_size: u64 = MEM_LAYOUT[LayoutEntryType::Uart as usize].1; - - let pl011 = PL011::new(config.clone()).with_context(|| "Failed to create PL011")?; - pl011 - .realize( - &mut self.base.sysbus, - region_base, - region_size, - &self.base.boot_source, - ) - .with_context(|| "Failed to realize PL011") + let pl011 = PL011::new(config.clone(), &self.base.sysbus, region_base, region_size) + .with_context(|| "Failed to create PL011")?; + pl011.realize().with_context(|| "Failed to realize PL011")?; + let mut bs = self.base.boot_source.lock().unwrap(); + bs.kernel_cmdline.push(Param { + param_type: "earlycon".to_string(), + value: format!("pl011,mmio,0x{:08x}", region_base), + }); + Ok(()) } fn realize(vm: &Arc>, vm_config: &mut VmConfig) -> Result<()> { let mut locked_vm = vm.lock().unwrap(); - trace::sysbus(&locked_vm.base.sysbus); + trace::sysbus(&locked_vm.base.sysbus.lock().unwrap()); trace::vm_state(&locked_vm.base.vm_state); let topology = CPUTopology::new().set_topology(( @@ -160,8 +153,12 @@ impl MachineOps for LightMachine { vm_config.machine_config.nr_cpus, )?; - let boot_config = - locked_vm.load_boot_source(None, MEM_LAYOUT[LayoutEntryType::Mem as usize].0)?; + let migrate_info = locked_vm.get_migrate_info(); + let boot_config = if migrate_info.0 == MigrateMode::Unknown { + Some(locked_vm.load_boot_source(None, MEM_LAYOUT[LayoutEntryType::Mem as usize].0)?) + } else { + None + }; let cpu_config = locked_vm.load_cpu_features(vm_config)?; let hypervisor = locked_vm.base.hypervisor.clone(); @@ -186,20 +183,25 @@ impl MachineOps for LightMachine { locked_vm.add_devices(vm_config)?; trace::replaceable_info(&locked_vm.replaceable_info); - let mut fdt_helper = FdtBuilder::new(); - locked_vm - .generate_fdt_node(&mut fdt_helper) - .with_context(|| MachineError::GenFdtErr)?; - let fdt_vec = fdt_helper.finish()?; - locked_vm - .base - .sys_mem - .write( - &mut fdt_vec.as_slice(), - GuestAddress(boot_config.fdt_addr), - fdt_vec.len() as u64, - ) - .with_context(|| MachineError::WrtFdtErr(boot_config.fdt_addr, fdt_vec.len()))?; + if let Some(boot_cfg) = boot_config { + let mut fdt_helper = FdtBuilder::new(); + locked_vm + .generate_fdt_node(&mut fdt_helper) + .with_context(|| MachineError::GenFdtErr)?; + let fdt_vec = fdt_helper.finish()?; + locked_vm + .base + .sys_mem + .write( + &mut fdt_vec.as_slice(), + GuestAddress(boot_cfg.fdt_addr), + fdt_vec.len() as u64, + AddressAttr::Ram, + ) + .with_context(|| MachineError::WrtFdtErr(boot_cfg.fdt_addr, fdt_vec.len()))?; + } + register_shutdown_event(locked_vm.shutdown_req.clone(), vm.clone()) + .with_context(|| "Failed to register shutdown event")?; MigrationManager::register_vm_instance(vm.clone()); MigrationManager::register_migration_instance(locked_vm.base.migration_hypervisor.clone()); @@ -218,11 +220,12 @@ impl MachineOps for LightMachine { self.add_virtio_mmio_block(vm_config, cfg_args) } - fn realize_virtio_mmio_device( + fn add_virtio_mmio_device( &mut self, - dev: VirtioMmioDevice, + name: String, + device: Arc>, ) -> Result>> { - self.realize_virtio_mmio_device(dev) + self.add_virtio_mmio_device(name, device) } fn syscall_whitelist(&self) -> Vec { @@ -235,6 +238,7 @@ pub(crate) fn arch_ioctl_allow_list(bpf_rule: BpfRule) -> BpfRule { .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_ONE_REG() as u32) .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_DEVICE_ATTR() as u32) .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_REG_LIST() as u32) + .add_constraint(SeccompCmpOpt::Eq, 1, KVM_SET_ONE_REG() as u32) } pub(crate) fn arch_syscall_whitelist() -> Vec { @@ -257,14 +261,33 @@ trait CompileFDTHelper { impl CompileFDTHelper for LightMachine { fn generate_memory_node(&self, fdt: &mut FdtBuilder) -> Result<()> { - let mem_base = MEM_LAYOUT[LayoutEntryType::Mem as usize].0; - let mem_size = self.base.sys_mem.memory_end_address().raw_value() - - MEM_LAYOUT[LayoutEntryType::Mem as usize].0; - let node = "memory"; - let memory_node_dep = fdt.begin_node(node)?; - fdt.set_property_string("device_type", "memory")?; - fdt.set_property_array_u64("reg", &[mem_base, mem_size])?; - fdt.end_node(memory_node_dep) + if self.base.numa_nodes.is_none() { + let mem_base = MEM_LAYOUT[LayoutEntryType::Mem as usize].0; + let mem_size = self.base.sys_mem.memory_end_address().raw_value() + - MEM_LAYOUT[LayoutEntryType::Mem as usize].0; + let node = "memory"; + let memory_node_dep = fdt.begin_node(node)?; + fdt.set_property_string("device_type", "memory")?; + fdt.set_property_array_u64("reg", &[mem_base, mem_size])?; + fdt.end_node(memory_node_dep)?; + + return Ok(()); + } + + // Set NUMA node information. + let mut mem_base = MEM_LAYOUT[LayoutEntryType::Mem as usize].0; + for (id, node) in self.base.numa_nodes.as_ref().unwrap().iter().enumerate() { + let mem_size = node.1.size; + let node = format!("memory@{:x}", mem_base); + let memory_node_dep = fdt.begin_node(&node)?; + fdt.set_property_string("device_type", "memory")?; + fdt.set_property_array_u64("reg", &[mem_base, mem_size])?; + fdt.set_property_u32("numa-node-id", id as u32)?; + fdt.end_node(memory_node_dep)?; + mem_base += mem_size; + } + + Ok(()) } fn generate_chosen_node(&self, fdt: &mut FdtBuilder) -> Result<()> { @@ -282,7 +305,11 @@ impl CompileFDTHelper for LightMachine { match &boot_source.initrd { Some(initrd) => { fdt.set_property_u64("linux,initrd-start", initrd.initrd_addr)?; - fdt.set_property_u64("linux,initrd-end", initrd.initrd_addr + initrd.initrd_size)?; + let initrd_end = initrd + .initrd_addr + .checked_add(initrd.initrd_size) + .with_context(|| "initrd end overflow")?; + fdt.set_property_u64("linux,initrd-end", initrd_end)?; } None => {} } diff --git a/machine/src/aarch64/mod.rs b/machine/src/aarch64/mod.rs index ee107ad4991c54a546588f1a971063184b0780b1..3ffe949fe294ef86d09c4857269f5187e900a250 100644 --- a/machine/src/aarch64/mod.rs +++ b/machine/src/aarch64/mod.rs @@ -11,7 +11,7 @@ // See the Mulan PSL v2 for more details. pub mod micro; +pub mod pci_host_root; pub mod standard; mod fdt; -mod pci_host_root; diff --git a/machine/src/aarch64/pci_host_root.rs b/machine/src/aarch64/pci_host_root.rs index aec5402338b4cc2a90f7d859141c2fbe89bc5f52..11b9caa76f946fc67684e96b22489c01fc2dd424 100644 --- a/machine/src/aarch64/pci_host_root.rs +++ b/machine/src/aarch64/pci_host_root.rs @@ -10,18 +10,17 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use std::sync::{Arc, Mutex, Weak}; +use std::sync::{atomic::AtomicBool, Arc, Mutex, Weak}; use anyhow::Result; -use devices::pci::{ - config::{ - PciConfig, CLASS_CODE_HOST_BRIDGE, DEVICE_ID, PCI_CONFIG_SPACE_SIZE, PCI_VENDOR_ID_REDHAT, - REVISION_ID, SUB_CLASS_CODE, VENDOR_ID, - }, - le_write_u16, PciBus, PciDevBase, PciDevOps, +use devices::pci::config::{ + PciConfig, CLASS_CODE_HOST_BRIDGE, DEVICE_ID, PCI_CONFIG_SPACE_SIZE, PCI_VENDOR_ID_REDHAT, + REVISION_ID, SUB_CLASS_CODE, VENDOR_ID, }; -use devices::{Device, DeviceBase}; +use devices::pci::{le_write_u16, PciDevBase, PciDevOps}; +use devices::{Bus, Device, DeviceBase}; +use util::gen_base_func; const DEVICE_ID_PCIE_HOST: u16 = 0x0008; @@ -31,38 +30,22 @@ pub struct PciHostRoot { } impl PciHostRoot { - pub fn new(parent_bus: Weak>) -> Self { + pub fn new(parent_bus: Weak>) -> Self { Self { base: PciDevBase { - base: DeviceBase::new("PCI Host Root".to_string(), false), - config: PciConfig::new(PCI_CONFIG_SPACE_SIZE, 0), - parent_bus, + base: DeviceBase::new("PCI Host Root".to_string(), false, Some(parent_bus)), + config: PciConfig::new(0, PCI_CONFIG_SPACE_SIZE, 0), devfn: 0, + bme: Arc::new(AtomicBool::new(false)), }, } } } impl Device for PciHostRoot { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } - - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base - } -} - -impl PciDevOps for PciHostRoot { - fn pci_base(&self) -> &PciDevBase { - &self.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); - fn pci_base_mut(&mut self) -> &mut PciDevBase { - &mut self.base - } - - fn realize(mut self) -> Result<()> { + fn realize(mut self) -> Result>> { self.init_write_mask(false)?; self.init_write_clear_mask(false)?; @@ -83,14 +66,17 @@ impl PciDevOps for PciHostRoot { )?; le_write_u16(&mut self.base.config.config, REVISION_ID, 0)?; - let parent_bus = self.base.parent_bus.upgrade().unwrap(); - parent_bus - .lock() - .unwrap() - .devices - .insert(0, Arc::new(Mutex::new(self))); - Ok(()) + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + let mut locked_bus = parent_bus.lock().unwrap(); + let dev = Arc::new(Mutex::new(self)); + locked_bus.attach_child(0, dev.clone())?; + + Ok(dev) } +} + +impl PciDevOps for PciHostRoot { + gen_base_func!(pci_base, pci_base_mut, PciDevBase, base); fn write_config(&mut self, offset: usize, data: &[u8]) { self.base.config.write(offset, data, 0, None); diff --git a/machine/src/aarch64/standard.rs b/machine/src/aarch64/standard.rs index 8d38c4440d571b2862a137703b52fd9b979b7b19..8612b3f12cde56f9092047755a34b640f07dedb2 100644 --- a/machine/src/aarch64/standard.rs +++ b/machine/src/aarch64/standard.rs @@ -13,17 +13,18 @@ pub use crate::error::MachineError; use std::mem::size_of; -use std::ops::Deref; #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] use std::sync::RwLock; use std::sync::{Arc, Mutex}; use anyhow::{anyhow, bail, Context, Result}; -use log::{error, info, warn}; -use vmm_sys_util::eventfd::EventFd; +#[cfg(feature = "ramfb")] +use clap::Parser; +use super::pci_host_root::PciHostRoot; +use crate::standard_common::syscall::syscall_whitelist; use crate::standard_common::{AcpiBuilder, StdMachineOps}; -use crate::{MachineBase, MachineOps}; +use crate::{register_shutdown_event, MachineBase, MachineOps, StdMachine}; use acpi::{ processor_append_priv_res, AcpiGicCpu, AcpiGicDistributor, AcpiGicIts, AcpiGicRedistributor, AcpiSratGiccAffinity, AcpiSratMemoryAffinity, AcpiTable, AmlBuilder, AmlDevice, AmlInteger, @@ -38,37 +39,32 @@ use acpi::{ }; #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] use address_space::FileBackend; -use address_space::{AddressSpace, GuestAddress, Region}; +use address_space::{AddressAttr, AddressSpace, GuestAddress, Region}; use cpu::{CPUInterface, CPUTopology, CpuLifecycleState, PMU_INTR, PPI_BASE}; - -use super::pci_host_root::PciHostRoot; -use crate::standard_common::syscall::syscall_whitelist; use devices::acpi::ged::{acpi_dsdt_add_power_button, Ged, GedEvent}; use devices::acpi::power::PowerDev; -#[cfg(feature = "ramfb")] -use devices::legacy::Ramfb; use devices::legacy::{ FwCfgEntryType, FwCfgMem, FwCfgOps, LegacyError as DevErrorKind, PFlash, PL011, PL031, }; -use devices::pci::{PciDevOps, PciHost, PciIntxState}; -use devices::sysbus::SysBusDevType; -use devices::{ICGICConfig, ICGICv3Config, GIC_IRQ_MAX}; +#[cfg(feature = "ramfb")] +use devices::legacy::{Ramfb, RamfbConfig}; +use devices::pci::{PciBus, PciHost, PciIntxState}; +use devices::sysbus::{to_sysbusdevops, SysBusDevType}; +use devices::{ + convert_bus_mut, Device, ICGICConfig, ICGICv3Config, GIC_IRQ_MAX, MUT_PCI_BUS, SYS_BUS_DEVICE, +}; use hypervisor::kvm::aarch64::*; use hypervisor::kvm::*; #[cfg(feature = "ramfb")] -use machine_manager::config::parse_ramfb; -use machine_manager::config::ShutdownAction; +use machine_manager::config::str_slip_to_clap; #[cfg(feature = "gtk")] use machine_manager::config::UiContext; use machine_manager::config::{ - parse_incoming_uri, BootIndexInfo, MigrateMode, NumaNode, PFlashConfig, SerialConfig, VmConfig, + BootIndexInfo, DriveConfig, MigrateMode, NumaNode, Param, SerialConfig, VmConfig, }; use machine_manager::event; -use machine_manager::machine::{ - MachineExternalInterface, MachineInterface, MachineLifecycle, MachineTestInterface, - MigrateInterface, VmState, -}; -use machine_manager::qmp::{qmp_channel::QmpChannel, qmp_response::Response, qmp_schema}; +use machine_manager::machine::{MachineLifecycle, VmState}; +use machine_manager::qmp::{qmp_channel::QmpChannel, qmp_schema}; use migration::{MigrationManager, MigrationStatus}; #[cfg(feature = "gtk")] use ui::gtk::gtk_display_init; @@ -78,9 +74,9 @@ use ui::ohui_srv::{ohui_init, OhUiServer}; use ui::vnc::vnc_init; use util::byte_code::ByteCode; use util::device_tree::{self, CompileFDT, FdtBuilder}; -use util::loop_context::EventLoopManager; +use util::gen_base_func; +use util::loop_context::create_new_eventfd; use util::seccomp::{BpfRule, SeccompCmpOpt}; -use util::set_termi_canon_mode; /// The type of memory layout entry on aarch64 pub enum LayoutEntryType { @@ -134,31 +130,6 @@ const IRQ_MAP: &[(i32, i32)] = &[ (16, 19), // Pcie ]; -/// Standard machine structure. -pub struct StdMachine { - /// Machine base members. - base: MachineBase, - /// PCI/PCIe host bridge. - pci_host: Arc>, - /// VM power button, handle VM `Shutdown` event. - pub power_button: Arc, - /// Shutdown request, handle VM `shutdown` event. - shutdown_req: Arc, - /// Reset request, handle VM `Reset` event. - reset_req: Arc, - /// Pause request, handle VM `Pause` event. - pause_req: Arc, - /// Resume request, handle VM `Resume` event. - resume_req: Arc, - /// Device Tree Blob. - dtb_vec: Vec, - /// List contains the boot order of boot devices. - boot_order_list: Arc>>, - /// OHUI server - #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] - ohui_server: Option>, -} - impl StdMachine { pub fn new(vm_config: &VmConfig) -> Result { let free_irqs = ( @@ -183,23 +154,23 @@ impl StdMachine { IRQ_MAP[IrqEntryType::Pcie as usize].0, ))), power_button: Arc::new( - EventFd::new(libc::EFD_NONBLOCK) + create_new_eventfd() .with_context(|| MachineError::InitEventFdErr("power_button".to_string()))?, ), shutdown_req: Arc::new( - EventFd::new(libc::EFD_NONBLOCK) + create_new_eventfd() .with_context(|| MachineError::InitEventFdErr("shutdown_req".to_string()))?, ), reset_req: Arc::new( - EventFd::new(libc::EFD_NONBLOCK) + create_new_eventfd() .with_context(|| MachineError::InitEventFdErr("reset_req".to_string()))?, ), pause_req: Arc::new( - EventFd::new(libc::EFD_NONBLOCK) + create_new_eventfd() .with_context(|| MachineError::InitEventFdErr("pause_req".to_string()))?, ), resume_req: Arc::new( - EventFd::new(libc::EFD_NONBLOCK) + create_new_eventfd() .with_context(|| MachineError::InitEventFdErr("resume_req".to_string()))?, ), dtb_vec: Vec::new(), @@ -230,6 +201,7 @@ impl StdMachine { &mut locked_vm.dtb_vec.as_slice(), GuestAddress(fdt_addr), locked_vm.dtb_vec.len() as u64, + AddressAttr::Ram, ) .with_context(|| "Fail to write dtb into sysmem")?; @@ -255,25 +227,6 @@ impl StdMachine { Ok(()) } - pub fn handle_destroy_request(vm: &Arc>) -> Result<()> { - let locked_vm = vm.lock().unwrap(); - let vmstate = { - let state = locked_vm.base.vm_state.deref().0.lock().unwrap(); - *state - }; - - if !locked_vm.notify_lifecycle(vmstate, VmState::Shutdown) { - warn!("Failed to destroy guest, destroy continue."); - if locked_vm.shutdown_req.write(1).is_err() { - error!("Failed to send shutdown request.") - } - } - - info!("vm destroy"); - - Ok(()) - } - fn build_pptt_cores(&self, pptt: &mut AcpiTable, cluster_offset: u32, uid: &mut u32) { for core in 0..self.base.cpu_topo.cores { let mut priv_resources = vec![0; 3]; @@ -290,7 +243,7 @@ impl StdMachine { if self.base.cpu_topo.threads > 1 { let core_offset = pptt.table_len(); let core_hierarchy_node = - ProcessorHierarchyNode::new(0x0, cluster_offset, core as u32, 3); + ProcessorHierarchyNode::new(0x0, cluster_offset, u32::from(core), 3); pptt.append_child(&core_hierarchy_node.aml_bytes()); processor_append_priv_res(pptt, priv_resources); for _thread in 0..self.base.cpu_topo.threads { @@ -312,7 +265,7 @@ impl StdMachine { for cluster in 0..self.base.cpu_topo.clusters { let cluster_offset = pptt.table_len(); let cluster_hierarchy_node = - ProcessorHierarchyNode::new(0x0, socket_offset, cluster as u32, 0); + ProcessorHierarchyNode::new(0x0, socket_offset, u32::from(cluster), 0); pptt.append_child(&cluster_hierarchy_node.aml_bytes()); self.build_pptt_cores(pptt, cluster_offset as u32, uid); } @@ -325,7 +278,7 @@ impl StdMachine { pptt.append_child(&cache_hierarchy_node.aml_bytes()); let socket_offset = pptt.table_len(); - let socket_hierarchy_node = ProcessorHierarchyNode::new(0x1, 0, socket as u32, 1); + let socket_hierarchy_node = ProcessorHierarchyNode::new(0x1, 0, u32::from(socket), 1); pptt.append_child(&socket_hierarchy_node.aml_bytes()); processor_append_priv_res(pptt, priv_resources); @@ -337,8 +290,9 @@ impl StdMachine { if let Some(vcpu) = self.get_cpus().get(vcpu_index) { let (cpu_state, _) = vcpu.state(); let cpu_state = *cpu_state.lock().unwrap(); - if cpu_state != CpuLifecycleState::Paused { - self.pause(); + if cpu_state != CpuLifecycleState::Paused && !self.pause() { + self.notify_lifecycle(VmState::Paused, VmState::Running); + return None; } let value = match vcpu.hypervisor_cpu.get_one_reg(addr) { @@ -357,10 +311,13 @@ impl StdMachine { #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] fn add_ohui_server(&mut self, vm_config: &VmConfig) -> Result<()> { if let Some(dpy) = vm_config.display.as_ref() { - if !dpy.ohui_config.ohui { + if dpy.display_type != "ohui" { return Ok(()); } - self.ohui_server = Some(Arc::new(OhUiServer::new(dpy.get_ui_path())?)); + self.ohui_server = Some(Arc::new(OhUiServer::new( + dpy.get_ui_path(), + dpy.get_sock_path(), + )?)); } Ok(()) } @@ -368,7 +325,7 @@ impl StdMachine { impl StdMachineOps for StdMachine { fn init_pci_host(&self) -> Result<()> { - let root_bus = Arc::downgrade(&self.pci_host.lock().unwrap().root_bus); + let root_bus = Arc::downgrade(&self.pci_host.lock().unwrap().child_bus().unwrap()); let mmconfig_region_ops = PciHost::build_mmconfig_ops(self.pci_host.clone()); let mmconfig_region = Region::init_io_region( MEM_LAYOUT[LayoutEntryType::HighPcieEcam as usize].1, @@ -387,7 +344,8 @@ impl StdMachineOps for StdMachine { let pcihost_root = PciHostRoot::new(root_bus); pcihost_root .realize() - .with_context(|| "Failed to realize pcihost root device.") + .with_context(|| "Failed to realize pcihost root device.")?; + Ok(()) } fn add_fwcfg_device(&mut self, nr_cpus: u8) -> Result>>> { @@ -395,7 +353,12 @@ impl StdMachineOps for StdMachine { return Ok(None); } - let mut fwcfg = FwCfgMem::new(self.base.sys_mem.clone()); + let mut fwcfg = FwCfgMem::new( + self.base.sys_mem.clone(), + &self.base.sysbus, + MEM_LAYOUT[LayoutEntryType::FwCfg as usize].0, + MEM_LAYOUT[LayoutEntryType::FwCfg as usize].1, + )?; fwcfg .add_data_entry(FwCfgEntryType::NbCpus, nr_cpus.as_bytes().to_vec()) .with_context(|| DevErrorKind::AddEntryErr("NbCpus".to_string()))?; @@ -427,13 +390,9 @@ impl StdMachineOps for StdMachine { .add_file_entry("bios-geometry", bios_geometry) .with_context(|| DevErrorKind::AddEntryErr("bios-geometry".to_string()))?; - let fwcfg_dev = FwCfgMem::realize( - fwcfg, - &mut self.base.sysbus, - MEM_LAYOUT[LayoutEntryType::FwCfg as usize].0, - MEM_LAYOUT[LayoutEntryType::FwCfg as usize].1, - ) - .with_context(|| "Failed to realize fwcfg device")?; + let fwcfg_dev = fwcfg + .realize() + .with_context(|| "Failed to realize fwcfg device")?; self.base.fwcfg_dev = Some(fwcfg_dev.clone()); Ok(Some(fwcfg_dev)) @@ -441,13 +400,7 @@ impl StdMachineOps for StdMachine { } impl MachineOps for StdMachine { - fn machine_base(&self) -> &MachineBase { - &self.base - } - - fn machine_base_mut(&mut self) -> &mut MachineBase { - &mut self.base - } + gen_base_func!(machine_base, machine_base_mut, MachineBase, base); fn init_machine_ram(&self, sys_mem: &Arc, mem_size: u64) -> Result<()> { let vm_ram = self.get_vm_ram(); @@ -487,57 +440,56 @@ impl MachineOps for StdMachine { self.base.irq_chip = Some(locked_hypervisor.create_interrupt_controller(&intc_conf)?); self.base.irq_chip.as_ref().unwrap().realize()?; - let root_bus = &self.pci_host.lock().unwrap().root_bus; + let root_bus = &self.pci_host.lock().unwrap().child_bus().unwrap(); + MUT_PCI_BUS!(root_bus, locked_bus, root_pci_bus); let irq_manager = locked_hypervisor.create_irq_manager()?; - root_bus.lock().unwrap().msi_irq_manager = irq_manager.msi_irq_manager; + root_pci_bus.msi_irq_manager = irq_manager.msi_irq_manager; let line_irq_manager = irq_manager.line_irq_manager; if let Some(line_irq_manager) = line_irq_manager.clone() { let irq_state = Some(Arc::new(Mutex::new(PciIntxState::new( IRQ_MAP[IrqEntryType::Pcie as usize].0 as u32, line_irq_manager.clone(), )))); - root_bus.lock().unwrap().intx_state = irq_state; + root_pci_bus.intx_state = irq_state; } else { return Err(anyhow!( "Failed to create intx state: legacy irq manager is none." )); } - self.base.sysbus.irq_manager = line_irq_manager; + self.base.sysbus.lock().unwrap().irq_manager = line_irq_manager; Ok(()) } fn add_rtc_device(&mut self) -> Result<()> { - let rtc = PL031::default(); - PL031::realize( - rtc, - &mut self.base.sysbus, + let rtc = PL031::new( + &self.base.sysbus, MEM_LAYOUT[LayoutEntryType::Rtc as usize].0, MEM_LAYOUT[LayoutEntryType::Rtc as usize].1, - ) - .with_context(|| "Failed to realize PL031") + )?; + rtc.realize().with_context(|| "Failed to realize PL031")?; + Ok(()) } fn add_ged_device(&mut self) -> Result<()> { let battery_present = self.base.vm_config.lock().unwrap().machine_config.battery; - let ged = Ged::default(); - let ged_dev = ged - .realize( - &mut self.base.sysbus, - GedEvent::new(self.power_button.clone()), - battery_present, - MEM_LAYOUT[LayoutEntryType::Ged as usize].0, - MEM_LAYOUT[LayoutEntryType::Ged as usize].1, - ) - .with_context(|| "Failed to realize Ged")?; + let ged = Ged::new( + battery_present, + &self.base.sysbus, + MEM_LAYOUT[LayoutEntryType::Ged as usize].0, + MEM_LAYOUT[LayoutEntryType::Ged as usize].1, + GedEvent::new(self.power_button.clone()), + )?; + let ged_dev = ged.realize().with_context(|| "Failed to realize Ged")?; if battery_present { - let pdev = PowerDev::new(ged_dev); - pdev.realize( - &mut self.base.sysbus, + let pdev = PowerDev::new( + ged_dev, + &self.base.sysbus, MEM_LAYOUT[LayoutEntryType::PowerDev as usize].0, MEM_LAYOUT[LayoutEntryType::PowerDev as usize].1, - ) - .with_context(|| "Failed to realize PowerDev")?; + )?; + pdev.realize() + .with_context(|| "Failed to realize PowerDev")?; } Ok(()) } @@ -545,16 +497,15 @@ impl MachineOps for StdMachine { fn add_serial_device(&mut self, config: &SerialConfig) -> Result<()> { let region_base: u64 = MEM_LAYOUT[LayoutEntryType::Uart as usize].0; let region_size: u64 = MEM_LAYOUT[LayoutEntryType::Uart as usize].1; - - let pl011 = PL011::new(config.clone()).with_context(|| "Failed to create PL011")?; - pl011 - .realize( - &mut self.base.sysbus, - region_base, - region_size, - &self.base.boot_source, - ) - .with_context(|| "Failed to realize PL011") + let pl011 = PL011::new(config.clone(), &self.base.sysbus, region_base, region_size) + .with_context(|| "Failed to create PL011")?; + pl011.realize().with_context(|| "Failed to realize PL011")?; + let mut bs = self.base.boot_source.lock().unwrap(); + bs.kernel_cmdline.push(Param { + param_type: "earlycon".to_string(), + value: format!("pl011,mmio,0x{:08x}", region_base), + }); + Ok(()) } fn syscall_whitelist(&self) -> Vec { @@ -578,9 +529,8 @@ impl MachineOps for StdMachine { let nr_cpus = vm_config.machine_config.nr_cpus; let mut locked_vm = vm.lock().unwrap(); locked_vm.init_global_config(vm_config)?; - locked_vm - .register_shutdown_event(locked_vm.shutdown_req.clone(), vm.clone()) - .with_context(|| "Fail to register shutdown event")?; + register_shutdown_event(locked_vm.shutdown_req.clone(), vm.clone()) + .with_context(|| "Failed to register shutdown event")?; locked_vm .register_reset_event(locked_vm.reset_req.clone(), vm.clone()) .with_context(|| "Fail to register reset event")?; @@ -606,8 +556,16 @@ impl MachineOps for StdMachine { .with_context(|| MachineError::InitPCIeHostErr)?; let fwcfg = locked_vm.add_fwcfg_device(nr_cpus)?; - let boot_config = locked_vm - .load_boot_source(fwcfg.as_ref(), MEM_LAYOUT[LayoutEntryType::Mem as usize].0)?; + let migrate = locked_vm.get_migrate_info(); + let boot_config = + if migrate.0 == MigrateMode::Unknown { + Some(locked_vm.load_boot_source( + fwcfg.as_ref(), + MEM_LAYOUT[LayoutEntryType::Mem as usize].0, + )?) + } else { + None + }; let cpu_config = locked_vm.load_cpu_features(vm_config)?; let hypervisor = locked_vm.base.hypervisor.clone(); @@ -632,21 +590,24 @@ impl MachineOps for StdMachine { .add_devices(vm_config) .with_context(|| "Failed to add devices")?; - let mut fdt_helper = FdtBuilder::new(); - locked_vm - .generate_fdt_node(&mut fdt_helper) - .with_context(|| MachineError::GenFdtErr)?; - let fdt_vec = fdt_helper.finish()?; - locked_vm.dtb_vec = fdt_vec.clone(); - locked_vm - .base - .sys_mem - .write( - &mut fdt_vec.as_slice(), - GuestAddress(boot_config.fdt_addr), - fdt_vec.len() as u64, - ) - .with_context(|| MachineError::WrtFdtErr(boot_config.fdt_addr, fdt_vec.len()))?; + if let Some(boot_cfg) = boot_config { + let mut fdt_helper = FdtBuilder::new(); + locked_vm + .generate_fdt_node(&mut fdt_helper) + .with_context(|| MachineError::GenFdtErr)?; + let fdt_vec = fdt_helper.finish()?; + locked_vm.dtb_vec = fdt_vec.clone(); + locked_vm + .base + .sys_mem + .write( + &mut fdt_vec.as_slice(), + GuestAddress(boot_cfg.fdt_addr), + fdt_vec.len() as u64, + AddressAttr::Ram, + ) + .with_context(|| MachineError::WrtFdtErr(boot_cfg.fdt_addr, fdt_vec.len()))?; + } // If it is direct kernel boot mode, the ACPI can not be enabled. if let Some(fw_cfg) = fwcfg { @@ -670,10 +631,11 @@ impl MachineOps for StdMachine { .with_context(|| "Fail to init display")?; #[cfg(feature = "windows_emu_pid")] - locked_vm.watch_windows_emu_pid( + crate::watch_windows_emu_pid( vm_config, locked_vm.power_button.clone(), locked_vm.shutdown_req.clone(), + vm.clone(), ); MigrationManager::register_vm_config(locked_vm.get_vm_config()); @@ -685,25 +647,35 @@ impl MachineOps for StdMachine { Ok(()) } - fn add_pflash_device(&mut self, configs: &[PFlashConfig]) -> Result<()> { + fn add_pflash_device(&mut self, configs: &[DriveConfig]) -> Result<()> { let mut configs_vec = configs.to_vec(); - configs_vec.sort_by_key(|c| c.unit); + configs_vec.sort_by_key(|c| c.unit.unwrap()); let sector_len: u32 = 1024 * 256; let mut flash_base: u64 = MEM_LAYOUT[LayoutEntryType::Flash as usize].0; let flash_size: u64 = MEM_LAYOUT[LayoutEntryType::Flash as usize].1 / 2; for i in 0..=1 { let (fd, read_only) = if i < configs_vec.len() { let path = &configs_vec[i].path_on_host; - let read_only = configs_vec[i].read_only; + let read_only = configs_vec[i].readonly; let fd = self.fetch_drive_file(path)?; (Some(fd), read_only) } else { (None, false) }; - let pflash = PFlash::new(flash_size, &fd, sector_len, 4, 2, read_only) - .with_context(|| MachineError::InitPflashErr)?; - PFlash::realize(pflash, &mut self.base.sysbus, flash_base, flash_size, fd) + let pflash = PFlash::new( + flash_size, + fd, + sector_len, + 4, + 2, + read_only, + &self.base.sysbus, + flash_base, + ) + .with_context(|| MachineError::InitPflashErr)?; + pflash + .realize() .with_context(|| MachineError::RlzPflashErr)?; flash_base += flash_size; } @@ -718,7 +690,7 @@ impl MachineOps for StdMachine { #[cfg(any(feature = "gtk", all(target_env = "ohos", feature = "ohui_srv")))] match vm_config.display { #[cfg(feature = "gtk")] - Some(ref ds_cfg) if ds_cfg.gtk => { + Some(ref ds_cfg) if ds_cfg.display_type == "gtk" => { let ui_context = UiContext { vm_name: vm_config.guest_name.clone(), power_button: Some(self.power_button.clone()), @@ -731,7 +703,7 @@ impl MachineOps for StdMachine { } // OHUI server init. #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] - Some(ref ds_cfg) if ds_cfg.ohui_config.ohui => { + Some(ref ds_cfg) if ds_cfg.display_type == "ohui" => { ohui_init(self.ohui_server.as_ref().unwrap().clone(), ds_cfg) .with_context(|| "Failed to init OH UI server!")?; } @@ -747,15 +719,16 @@ impl MachineOps for StdMachine { #[cfg(feature = "ramfb")] fn add_ramfb(&mut self, cfg_args: &str) -> Result<()> { - let install = parse_ramfb(cfg_args)?; + let config = RamfbConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; let fwcfg_dev = self .get_fwcfg_dev() .with_context(|| "Ramfb device must be used UEFI to boot, please add pflash devices")?; let sys_mem = self.get_sys_mem(); - let mut ramfb = Ramfb::new(sys_mem.clone(), install); + let mut ramfb = Ramfb::new(sys_mem.clone(), &self.base.sysbus, config.install); ramfb.ramfb_state.setup(&fwcfg_dev)?; - ramfb.realize(&mut self.base.sysbus) + ramfb.realize()?; + Ok(()) } fn get_pci_host(&mut self) -> Result<&Arc>> { @@ -874,7 +847,7 @@ impl AcpiBuilder for StdMachine { dbg2.set_field(40, 1_u32); // Table 2. Debug Device Information structure format - let offset = 44; + let offset = 44_usize; // Revision dbg2.set_field(offset, 0_u8); // Length @@ -962,7 +935,7 @@ impl AcpiBuilder for StdMachine { // Mapping counts of Root Complex Node iort.set_field(80, 1_u32); // Mapping offset of Root Complex Node - iort.set_field(84, ROOT_COMPLEX_ENTRY_SIZE as u32); + iort.set_field(84, u32::from(ROOT_COMPLEX_ENTRY_SIZE)); // Cache of coherent device iort.set_field(88, 1_u32); // Memory flags of coherent device @@ -999,10 +972,11 @@ impl AcpiBuilder for StdMachine { spcr.set_field(52, 1_u8 << 3); // Irq number used by the UART let mut uart_irq: u32 = 0; - for dev in self.base.sysbus.devices.iter() { - let locked_dev = dev.lock().unwrap(); - if locked_dev.sysbusdev_base().dev_type == SysBusDevType::PL011 { - uart_irq = locked_dev.sysbusdev_base().irq_state.irq as _; + let devices = self.get_sysbus_devices(); + for dev in devices.values() { + SYS_BUS_DEVICE!(dev, locked_dev, sysbusdev); + if sysbusdev.sysbusdev_base().dev_type == SysBusDevType::PL011 { + uart_irq = sysbusdev.sysbusdev_base().irq_state.irq as _; break; } } @@ -1048,7 +1022,7 @@ impl AcpiBuilder for StdMachine { dsdt.append_child(sb_scope.aml_bytes().as_slice()); // 3. Info of devices attached to system bus. - dsdt.append_child(self.base.sysbus.aml_bytes().as_slice()); + dsdt.append_child(self.base.sysbus.lock().unwrap().aml_bytes().as_slice()); let dsdt_begin = StdMachine::add_table_to_loader(acpi_data, loader, &dsdt) .with_context(|| "Fail to add DSDT table to loader")?; @@ -1125,7 +1099,7 @@ impl AcpiBuilder for StdMachine { type_id: 3_u8, length: size_of::() as u8, proximity_domain, - process_uid: *cpu as u32, + process_uid: u32::from(*cpu), flags: 1, clock_domain: 0_u32, } @@ -1185,116 +1159,16 @@ impl AcpiBuilder for StdMachine { loader: &mut TableLoader, ) -> Result { let mut pptt = AcpiTable::new(*b"PPTT", 2, *b"STRATO", *b"VIRTPPTT", 1); - let mut uid = 0; + let mut uid = 0_u32; self.build_pptt_sockets(&mut pptt, &mut uid); let pptt_begin = StdMachine::add_table_to_loader(acpi_data, loader, &pptt) .with_context(|| "Fail to add PPTT table to loader")?; Ok(pptt_begin) } -} - -impl MachineLifecycle for StdMachine { - fn pause(&self) -> bool { - if self.notify_lifecycle(VmState::Running, VmState::Paused) { - event!(Stop); - true - } else { - false - } - } - - fn resume(&self) -> bool { - if !self.notify_lifecycle(VmState::Paused, VmState::Running) { - return false; - } - event!(Resume); - true - } - - fn destroy(&self) -> bool { - if self.shutdown_req.write(1).is_err() { - error!("Failed to send shutdown request."); - return false; - } - - true - } - - fn powerdown(&self) -> bool { - if self.power_button.write(1).is_err() { - error!("ARM standard vm write power button failed"); - return false; - } - true - } - - fn get_shutdown_action(&self) -> ShutdownAction { - self.base - .vm_config - .lock() - .unwrap() - .machine_config - .shutdown_action - } - - fn reset(&mut self) -> bool { - if self.reset_req.write(1).is_err() { - error!("ARM standard vm write reset req failed"); - return false; - } - true - } - - fn notify_lifecycle(&self, old: VmState, new: VmState) -> bool { - if let Err(e) = self.vm_state_transfer( - &self.base.cpus, - &self.base.irq_chip, - &mut self.base.vm_state.0.lock().unwrap(), - old, - new, - ) { - error!("VM state transfer failed: {:?}", e); - return false; - } - true - } -} - -impl MigrateInterface for StdMachine { - fn migrate(&self, uri: String) -> Response { - match parse_incoming_uri(&uri) { - Ok((MigrateMode::File, path)) => migration::snapshot(path), - Ok((MigrateMode::Unix, path)) => migration::migration_unix_mode(path), - Ok((MigrateMode::Tcp, path)) => migration::migration_tcp_mode(path), - _ => Response::create_error_response( - qmp_schema::QmpErrorClass::GenericError(format!("Invalid uri: {}", uri)), - None, - ), - } - } - - fn query_migrate(&self) -> Response { - migration::query_migrate() - } - - fn cancel_migrate(&self) -> Response { - migration::cancel_migrate() - } -} - -impl MachineInterface for StdMachine {} -impl MachineExternalInterface for StdMachine {} -impl MachineTestInterface for StdMachine {} - -impl EventLoopManager for StdMachine { - fn loop_should_exit(&self) -> bool { - let vmstate = self.base.vm_state.deref().0.lock().unwrap(); - *vmstate == VmState::Shutdown - } - fn loop_cleanup(&self) -> Result<()> { - set_termi_canon_mode().with_context(|| "Failed to set terminal to canonical mode")?; - Ok(()) + fn get_hardware_signature(&self) -> Option { + let vm_config = self.machine_base().vm_config.lock().unwrap(); + vm_config.hardware_signature } } @@ -1428,7 +1302,11 @@ impl CompileFDTHelper for StdMachine { match &boot_source.initrd { Some(initrd) => { fdt.set_property_u64("linux,initrd-start", initrd.initrd_addr)?; - fdt.set_property_u64("linux,initrd-end", initrd.initrd_addr + initrd.initrd_size)?; + let initrd_end = initrd + .initrd_addr + .checked_add(initrd.initrd_size) + .with_context(|| "initrd end overflow")?; + fdt.set_property_u64("linux,initrd-end", initrd_end)?; } None => {} } diff --git a/machine/src/lib.rs b/machine/src/lib.rs index ceeb39fb8de8fcc36cdcde4963523c9219b0d66d..bda1c9cc8af0e52ba9df84b19fb5772bebd3fa7e 100644 --- a/machine/src/lib.rs +++ b/machine/src/lib.rs @@ -20,6 +20,11 @@ pub mod x86_64; mod micro_common; pub use crate::error::MachineError; +#[cfg(feature = "usb_host")] +use machine_manager::{ + event, + qmp::{qmp_channel::QmpChannel, qmp_schema::UsbHostAddRes}, +}; pub use micro_common::LightMachine; pub use standard_common::StdMachine; @@ -27,87 +32,117 @@ use std::collections::{BTreeMap, HashMap}; use std::fs::{remove_file, File}; use std::net::TcpListener; use std::ops::Deref; +use std::os::unix::io::AsRawFd; use std::os::unix::net::UnixListener; +#[cfg(any(feature = "windows_emu_pid", feature = "vfio_device"))] use std::path::Path; +use std::rc::Rc; use std::sync::{Arc, Barrier, Condvar, Mutex, RwLock, Weak}; +#[cfg(feature = "usb_host")] +use std::thread; #[cfg(feature = "windows_emu_pid")] use std::time::Duration; +use std::u64; use anyhow::{anyhow, bail, Context, Result}; use clap::Parser; -use log::warn; -#[cfg(feature = "windows_emu_pid")] +use log::{error, info, warn}; +use vmm_sys_util::epoll::EventSet; use vmm_sys_util::eventfd::EventFd; #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] use address_space::FileBackend; -use address_space::{create_backend_mem, create_default_mem, AddressSpace, GuestAddress, Region}; +use address_space::{ + create_backend_mem, create_default_mem, AddressAttr, AddressSpace, GuestAddress, Region, +}; #[cfg(target_arch = "aarch64")] use cpu::CPUFeatures; use cpu::{ArchCPU, CPUBootConfig, CPUHypervisorOps, CPUInterface, CPUTopology, CpuTopology, CPU}; use devices::legacy::FwCfgOps; #[cfg(feature = "pvpanic")] -use devices::misc::pvpanic::PvPanicPci; +use devices::misc::pvpanic::{PvPanicPci, PvpanicDevConfig}; #[cfg(feature = "scream")] use devices::misc::scream::{Scream, ScreamConfig}; #[cfg(feature = "demo_device")] -use devices::pci::demo_device::DemoDev; -use devices::pci::{PciBus, PciDevOps, PciHost, RootPort}; +use devices::pci::demo_device::{DemoDev, DemoDevConfig}; +use devices::pci::{ + devices_register_pcidevops_type, register_pcidevops_type, PciBus, PciDevOps, PciHost, RootPort, + RootPortConfig, +}; use devices::smbios::smbios_table::{build_smbios_ep30, SmbiosTable}; use devices::smbios::{SMBIOS_ANCHOR_FILE, SMBIOS_TABLE_FILE}; -use devices::sysbus::{SysBus, SysBusDevOps, SysBusDevType}; +use devices::sysbus::{devices_register_sysbusdevops_type, to_sysbusdevops, SysBus, SysBusDevType}; #[cfg(feature = "usb_camera")] use devices::usb::camera::{UsbCamera, UsbCameraConfig}; use devices::usb::keyboard::{UsbKeyboard, UsbKeyboardConfig}; +use devices::usb::storage::{UsbStorage, UsbStorageConfig}; use devices::usb::tablet::{UsbTablet, UsbTabletConfig}; +#[cfg(feature = "usb_uas")] +use devices::usb::uas::{UsbUas, UsbUasConfig}; #[cfg(feature = "usb_host")] use devices::usb::usbhost::{UsbHost, UsbHostConfig}; use devices::usb::xhci::xhci_pci::{XhciConfig, XhciPciDevice}; -use devices::usb::{storage::UsbStorage, UsbDevice}; +use devices::usb::UsbDevice; #[cfg(target_arch = "aarch64")] use devices::InterruptController; -use devices::ScsiDisk::{ScsiDevice, SCSI_TYPE_DISK, SCSI_TYPE_ROM}; +use devices::{convert_bus_ref, Bus, Device, PCI_BUS, SYS_BUS_DEVICE}; +#[cfg(feature = "virtio_scsi")] +use devices::{ + ScsiBus::get_scsi_key, + ScsiDisk::{ScsiDevConfig, ScsiDevice}, +}; use hypervisor::{kvm::KvmHypervisor, test::TestHypervisor, HypervisorOps}; #[cfg(feature = "usb_camera")] use machine_manager::config::get_cameradev_by_id; -#[cfg(feature = "demo_device")] -use machine_manager::config::parse_demo_dev; -#[cfg(feature = "virtio_gpu")] -use machine_manager::config::parse_gpu; -#[cfg(feature = "pvpanic")] -use machine_manager::config::parse_pvpanic; -use machine_manager::config::parse_usb_storage; +#[cfg(feature = "vhostuser_net")] +use machine_manager::config::get_chardev_socket_path; use machine_manager::config::{ - complete_numa_node, get_multi_function, get_pci_bdf, parse_blk, parse_device_id, - parse_device_type, parse_fs, parse_net, parse_numa_distance, parse_numa_mem, parse_rng_dev, - parse_root_port, parse_scsi_controller, parse_scsi_device, parse_vfio, parse_vhost_user_blk, - parse_virtio_serial, parse_virtserialport, parse_vsock, str_slip_to_clap, BootIndexInfo, - BootSource, DriveFile, Incoming, MachineMemConfig, MigrateMode, NumaConfig, NumaDistance, - NumaNode, NumaNodes, PFlashConfig, PciBdf, SerialConfig, VfioConfig, VmConfig, FAST_UNPLUG_ON, - MAX_VIRTIO_QUEUE, + complete_numa_node, get_class_type, get_pci_bdf, get_value_of_parameter, parse_numa_distance, + parse_numa_mem, str_slip_to_clap, BootIndexInfo, BootSource, ConfigCheck, DriveConfig, + DriveFile, Incoming, MachineMemConfig, MigrateMode, NetworkInterfaceConfig, NumaNode, + NumaNodes, PciBdf, SerialConfig, VirtioSerialInfo, VirtioSerialPortCfg, VmConfig, + FAST_UNPLUG_ON, MAX_VIRTIO_QUEUE, }; use machine_manager::event_loop::EventLoop; -use machine_manager::machine::{HypervisorType, MachineInterface, VmState}; +use machine_manager::machine::{HypervisorType, MachineInterface, MachineLifecycle, VmState}; +use machine_manager::notifier::pause_notify; +use machine_manager::{check_arg_exist, check_arg_nonexist}; use migration::{MigrateOps, MigrationManager}; #[cfg(feature = "windows_emu_pid")] use ui::console::{get_run_stage, VmRunningStage}; +use util::arg_parser; use util::file::{clear_file, lock_file, unlock_file}; -use util::{ - arg_parser, - seccomp::{BpfRule, SeccompOpt, SyscallFilter}, +use util::loop_context::{ + gen_delete_notifiers, EventNotifier, NotifierCallback, NotifierOperation, }; -use vfio::{VfioDevice, VfioPciDevice, KVM_DEVICE_FD}; -#[cfg(feature = "virtio_gpu")] -use virtio::Gpu; +use util::seccomp::{BpfRule, SeccompOpt, SyscallFilter}; +#[cfg(feature = "vfio_device")] +use vfio::{vfio_register_pcidevops_type, VfioConfig, VfioDevice, VfioPciDevice, KVM_DEVICE_FD}; +#[cfg(feature = "virtio_scsi")] +use virtio::ScsiCntlr::{scsi_cntlr_create_scsi_bus, ScsiCntlr, ScsiCntlrConfig}; +#[cfg(any(feature = "vhost_vsock", feature = "vhost_net"))] +use virtio::VhostKern; +#[cfg(any(feature = "vhostuser_block", feature = "vhostuser_net"))] +use virtio::VhostUser; #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] use virtio::VirtioDeviceQuirk; use virtio::{ - balloon_allow_list, find_port_by_nr, get_max_nr, vhost, Balloon, BalloonConfig, Block, - BlockState, Rng, RngState, - ScsiCntlr::{scsi_cntlr_create_scsi_bus, ScsiCntlr}, - Serial, SerialPort, VhostKern, VhostUser, VirtioDevice, VirtioMmioDevice, VirtioMmioState, - VirtioNetState, VirtioPciDevice, VirtioSerialState, VIRTIO_TYPE_CONSOLE, + balloon_allow_list, find_port_by_nr, get_max_nr, vhost, virtio_register_pcidevops_type, + virtio_register_sysbusdevops_type, Balloon, BalloonConfig, Block, BlockState, Input, + InputConfig, Serial, SerialPort, VirtioBlkDevConfig, VirtioDevice, VirtioMmioDevice, + VirtioMmioState, VirtioNetState, VirtioPciDevice, VirtioSerialState, VIRTIO_TYPE_CONSOLE, }; +#[cfg(feature = "virtio_gpu")] +use virtio::{Gpu, GpuDevConfig}; +#[cfg(feature = "virtio_rng")] +use virtio::{Rng, RngConfig, RngState}; + +#[cfg(feature = "windows_emu_pid")] +const WINDOWS_EMU_PID_DEFAULT_INTERVAL: u64 = 4000; +#[cfg(feature = "windows_emu_pid")] +const WINDOWS_EMU_PID_SHUTDOWN_INTERVAL: u64 = 1000; +#[cfg(feature = "windows_emu_pid")] +const WINDOWS_EMU_PID_POWERDOWN_INTERVAL: u64 = 30000; /// Machine structure include base members. pub struct MachineBase { @@ -124,7 +159,7 @@ pub struct MachineBase { #[cfg(target_arch = "x86_64")] sys_io: Arc, /// System bus. - sysbus: SysBus, + sysbus: Arc>, /// VM running state. vm_state: Arc<(Mutex, Condvar)>, /// Vm boot_source config. @@ -160,12 +195,9 @@ impl MachineBase { vm_config.machine_config.nr_threads, vm_config.machine_config.max_cpus, ); - let machine_ram = Arc::new(Region::init_container_region( - u64::max_value(), - "MachineRam", - )); + let machine_ram = Arc::new(Region::init_container_region(u64::MAX, "MachineRam")); let sys_mem = AddressSpace::new( - Region::init_container_region(u64::max_value(), "SysMem"), + Region::init_container_region(u64::MAX, "SysMem"), "sys_mem", Some(machine_ram.clone()), ) @@ -209,7 +241,7 @@ impl MachineBase { sys_mem, #[cfg(target_arch = "x86_64")] sys_io, - sysbus, + sysbus: Arc::new(Mutex::new(sysbus)), vm_state: Arc::new((Mutex::new(VmState::Created), Condvar::new())), boot_source: Arc::new(Mutex::new(vm_config.clone().boot_source)), vm_config: Arc::new(Mutex::new(vm_config.clone())), @@ -239,7 +271,7 @@ impl MachineBase { let length = data.len() as u64; self.sys_io - .read(&mut data, GuestAddress(addr), length) + .read(&mut data, GuestAddress(addr), length, AddressAttr::MMIO) .is_ok() } @@ -248,27 +280,27 @@ impl MachineBase { use crate::x86_64::ich9_lpc::SLEEP_CTRL_OFFSET; let count = data.len() as u64; - if addr == SLEEP_CTRL_OFFSET as u64 { + if addr == u64::from(SLEEP_CTRL_OFFSET) { if let Err(e) = self.cpus[0].pause() { log::error!("Fail to pause bsp, {:?}", e); } } self.sys_io - .write(&mut data, GuestAddress(addr), count) + .write(&mut data, GuestAddress(addr), count, AddressAttr::MMIO) .is_ok() } fn mmio_read(&self, addr: u64, mut data: &mut [u8]) -> bool { let length = data.len() as u64; self.sys_mem - .read(&mut data, GuestAddress(addr), length) + .read(&mut data, GuestAddress(addr), length, AddressAttr::MMIO) .is_ok() } fn mmio_write(&self, addr: u64, mut data: &[u8]) -> bool { let count = data.len() as u64; self.sys_mem - .write(&mut data, GuestAddress(addr), count) + .write(&mut data, GuestAddress(addr), count, AddressAttr::MMIO) .is_ok() } } @@ -277,18 +309,18 @@ macro_rules! create_device_add_matches { ( $command:expr; $controller: expr; $(($($driver_name:tt)|+, $function_name:tt, $($arg:tt),*)),*; $(#[cfg($($features: tt)*)] - ($driver_name1:tt, $function_name1:tt, $($arg1:tt),*)),* + ($($driver_name1:tt)|+, $function_name1:tt, $($arg1:tt),*)),* ) => { match $command { $( $($driver_name)|+ => { - $controller.$function_name($($arg),*)?; + $controller.$function_name($($arg),*).with_context(|| format!("add {} fail.", $command))?; }, )* $( #[cfg($($features)*)] - $driver_name1 => { - $controller.$function_name1($($arg1),*)?; + $($driver_name1)|+ => { + $controller.$function_name1($($arg1),*).with_context(|| format!("add {} fail.", $command))?; }, )* _ => bail!("Unsupported device: {:?}", $command), @@ -296,7 +328,7 @@ macro_rules! create_device_add_matches { }; } -pub trait MachineOps { +pub trait MachineOps: MachineLifecycle { fn machine_base(&self) -> &MachineBase; fn machine_base_mut(&mut self) -> &mut MachineBase; @@ -383,12 +415,14 @@ pub trait MachineOps { } let zones = mem_config.mem_zones.as_ref().unwrap(); let mut offset = 0_u64; - for (_, node) in numa_nodes.as_ref().unwrap().iter().enumerate() { + for node in numa_nodes.as_ref().unwrap().iter() { for zone in zones.iter() { if zone.id.eq(&node.1.mem_dev) { let ram = create_backend_mem(zone, thread_num)?; root.add_subregion_not_update(ram, offset)?; - offset += zone.size; + offset = offset + .checked_add(zone.size) + .with_context(|| "total zone size overflow")?; break; } } @@ -409,6 +443,7 @@ pub trait MachineOps { sys_mem: &Arc, nr_cpus: u8, ) -> Result<()> { + trace::trace_scope_start!(init_memory); let migrate_info = self.get_migrate_info(); if migrate_info.0 != MigrateMode::File { self.create_machine_ram(mem_config, nr_cpus)?; @@ -453,7 +488,7 @@ pub trait MachineOps { #[cfg(target_arch = "aarch64")] let arch_cpu = ArchCPU::new(u32::from(vcpu_id)); #[cfg(target_arch = "x86_64")] - let arch_cpu = ArchCPU::new(u32::from(vcpu_id), u32::from(max_cpus)); + let arch_cpu = ArchCPU::new(u32::from(vcpu_id), max_cpus); let cpu = Arc::new(CPU::new( hypervisor_cpu, @@ -478,7 +513,7 @@ pub trait MachineOps { nr_cpus: u8, #[cfg(target_arch = "x86_64")] max_cpus: u8, topology: &CPUTopology, - boot_cfg: &CPUBootConfig, + boot_cfg: &Option, #[cfg(target_arch = "aarch64")] vcpu_cfg: &CPUFeatures, ) -> Result>> where @@ -571,23 +606,32 @@ pub trait MachineOps { /// # Arguments /// /// * `cfg_args` - Device configuration. + #[cfg(feature = "vhost_vsock")] fn add_virtio_vsock(&mut self, cfg_args: &str) -> Result<()> { - let device_cfg = parse_vsock(cfg_args)?; + let device_cfg = + VhostKern::VsockConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; let sys_mem = self.get_sys_mem().clone(); let vsock = Arc::new(Mutex::new(VhostKern::Vsock::new(&device_cfg, &sys_mem))); - match parse_device_type(cfg_args)?.as_str() { + match device_cfg.classtype.as_str() { "vhost-vsock-device" => { - let device = VirtioMmioDevice::new(&sys_mem, vsock.clone()); + check_arg_nonexist!( + ("bus", device_cfg.bus), + ("addr", device_cfg.addr), + ("multifunction", device_cfg.multifunction) + ); + let device = self + .add_virtio_mmio_device(device_cfg.id.clone(), vsock.clone()) + .with_context(|| MachineError::RlzVirtioMmioErr)?; MigrationManager::register_device_instance( VirtioMmioState::descriptor(), - self.realize_virtio_mmio_device(device) - .with_context(|| MachineError::RlzVirtioMmioErr)?, + device, &device_cfg.id, ); } _ => { - let bdf = get_pci_bdf(cfg_args)?; - let multi_func = get_multi_function(cfg_args)?; + check_arg_exist!(("bus", device_cfg.bus), ("addr", device_cfg.addr)); + let bdf = PciBdf::new(device_cfg.bus.clone().unwrap(), device_cfg.addr.unwrap()); + let multi_func = device_cfg.multifunction.unwrap_or_default(); self.add_virtio_pci_device(&device_cfg.id, &bdf, vsock.clone(), multi_func, true) .with_context(|| "Failed to add virtio pci vsock device")?; } @@ -602,9 +646,10 @@ pub trait MachineOps { Ok(()) } - fn realize_virtio_mmio_device( + fn add_virtio_mmio_device( &mut self, - _dev: VirtioMmioDevice, + _name: String, + _device: Arc>, ) -> Result>> { bail!("Virtio mmio devices not supported"); } @@ -662,27 +707,26 @@ pub trait MachineOps { } fn add_virtio_balloon(&mut self, vm_config: &mut VmConfig, cfg_args: &str) -> Result<()> { - if vm_config.dev_name.get("balloon").is_some() { + if vm_config.dev_name.contains_key("balloon") { bail!("Only one balloon device is supported for each vm."); } - let config = BalloonConfig::try_parse_from(str_slip_to_clap(cfg_args))?; + let config = BalloonConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; vm_config.dev_name.insert("balloon".to_string(), 1); let sys_mem = self.get_sys_mem(); let balloon = Arc::new(Mutex::new(Balloon::new(config.clone(), sys_mem.clone()))); Balloon::object_init(balloon.clone()); - match parse_device_type(cfg_args)?.as_str() { + match config.classtype.as_str() { "virtio-balloon-device" => { - if config.addr.is_some() || config.bus.is_some() || config.multifunction.is_some() { - bail!("virtio balloon device config is error!"); - } - let device = VirtioMmioDevice::new(sys_mem, balloon); - self.realize_virtio_mmio_device(device)?; + check_arg_nonexist!( + ("bus", config.bus), + ("addr", config.addr), + ("multifunction", config.multifunction) + ); + self.add_virtio_mmio_device(config.id.clone(), balloon)?; } _ => { - if config.addr.is_none() || config.bus.is_none() { - bail!("virtio balloon pci config is error!"); - } + check_arg_exist!(("bus", config.bus), ("addr", config.addr)); let bdf = PciBdf::new(config.bus.unwrap(), config.addr.unwrap()); let multi_func = config.multifunction.unwrap_or_default(); self.add_virtio_pci_device(&config.id, &bdf, balloon, multi_func, false) @@ -700,23 +744,34 @@ pub trait MachineOps { /// * `vm_config` - VM configuration. /// * `cfg_args` - Device configuration args. fn add_virtio_serial(&mut self, vm_config: &mut VmConfig, cfg_args: &str) -> Result<()> { - let serial_cfg = parse_virtio_serial(vm_config, cfg_args)?; - let sys_mem = self.get_sys_mem().clone(); + if vm_config.virtio_serial.is_some() { + bail!("Only one virtio serial device is supported"); + } + let mut serial_cfg = + VirtioSerialInfo::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + serial_cfg.auto_max_ports(); let serial = Arc::new(Mutex::new(Serial::new(serial_cfg.clone()))); - match parse_device_type(cfg_args)?.as_str() { + match serial_cfg.classtype.as_str() { "virtio-serial-device" => { - let device = VirtioMmioDevice::new(&sys_mem, serial.clone()); + check_arg_nonexist!( + ("bus", serial_cfg.bus), + ("addr", serial_cfg.addr), + ("multifunction", serial_cfg.multifunction) + ); + let device = self + .add_virtio_mmio_device(serial_cfg.id.clone(), serial.clone()) + .with_context(|| MachineError::RlzVirtioMmioErr)?; MigrationManager::register_device_instance( VirtioMmioState::descriptor(), - self.realize_virtio_mmio_device(device) - .with_context(|| MachineError::RlzVirtioMmioErr)?, + device, &serial_cfg.id, ); } _ => { - let bdf = serial_cfg.pci_bdf.unwrap(); - let multi_func = serial_cfg.multifunction; + check_arg_exist!(("bus", serial_cfg.bus), ("addr", serial_cfg.addr)); + let bdf = PciBdf::new(serial_cfg.bus.clone().unwrap(), serial_cfg.addr.unwrap()); + let multi_func = serial_cfg.multifunction.unwrap_or_default(); self.add_virtio_pci_device(&serial_cfg.id, &bdf, serial.clone(), multi_func, false) .with_context(|| "Failed to add virtio pci serial device")?; } @@ -728,6 +783,7 @@ pub trait MachineOps { &serial_cfg.id, ); + vm_config.virtio_serial = Some(serial_cfg); Ok(()) } @@ -744,11 +800,11 @@ pub trait MachineOps { .with_context(|| "No virtio serial device specified")?; let mut virtio_device = None; - if serial_cfg.pci_bdf.is_none() { + if serial_cfg.bus.is_none() { // Micro_vm. - for dev in self.get_sys_bus().devices.iter() { - let locked_busdev = dev.lock().unwrap(); - if locked_busdev.sysbusdev_base().dev_type == SysBusDevType::VirtioMmio { + for dev in self.get_sysbus_devices().values() { + SYS_BUS_DEVICE!(dev, locked_busdev, sysbusdev); + if sysbusdev.sysbusdev_base().dev_type == SysBusDevType::VirtioMmio { let virtio_mmio_dev = locked_busdev .as_any() .downcast_ref::() @@ -781,24 +837,31 @@ pub trait MachineOps { let mut virtio_dev_h = virtio_dev.lock().unwrap(); let serial = virtio_dev_h.as_any_mut().downcast_mut::().unwrap(); - let is_console = matches!(parse_device_type(cfg_args)?.as_str(), "virtconsole"); + let mut serialport_cfg = + VirtioSerialPortCfg::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; let free_port0 = find_port_by_nr(&serial.ports, 0).is_none(); // Note: port 0 is reserved for a virtconsole. let free_nr = get_max_nr(&serial.ports) + 1; - let serialport_cfg = - parse_virtserialport(vm_config, cfg_args, is_console, free_nr, free_port0)?; - if serialport_cfg.nr >= serial.max_nr_ports { + serialport_cfg.auto_nr(free_port0, free_nr, serial.max_nr_ports)?; + serialport_cfg.check()?; + if find_port_by_nr(&serial.ports, serialport_cfg.nr.unwrap()).is_some() { bail!( - "virtio serial port nr {} should be less than virtio serial's max_nr_ports {}", - serialport_cfg.nr, - serial.max_nr_ports + "Repetitive virtio serial port nr {}.", + serialport_cfg.nr.unwrap() ); } - if find_port_by_nr(&serial.ports, serialport_cfg.nr).is_some() { - bail!("Repetitive virtio serial port nr {}.", serialport_cfg.nr,); - } + let is_console = matches!(serialport_cfg.classtype.as_str(), "virtconsole"); + let chardev_cfg = vm_config + .chardev + .remove(&serialport_cfg.chardev) + .with_context(|| { + format!( + "Chardev {:?} not found or is in use", + &serialport_cfg.chardev + ) + })?; - let mut serial_port = SerialPort::new(serialport_cfg); + let mut serial_port = SerialPort::new(serialport_cfg, chardev_cfg); let port = Arc::new(Mutex::new(serial_port.clone())); serial_port.realize()?; if !is_console { @@ -815,20 +878,31 @@ pub trait MachineOps { /// /// * `vm_config` - VM configuration. /// * `cfg_args` - Device configuration arguments. + #[cfg(feature = "virtio_rng")] fn add_virtio_rng(&mut self, vm_config: &mut VmConfig, cfg_args: &str) -> Result<()> { - let rng_cfg = parse_rng_dev(vm_config, cfg_args)?; - let sys_mem = self.get_sys_mem(); - let rng_dev = Arc::new(Mutex::new(Rng::new(rng_cfg.clone()))); + let rng_cfg = RngConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + rng_cfg.bytes_per_sec()?; + let rngobj_cfg = vm_config + .object + .rng_object + .remove(&rng_cfg.rng) + .with_context(|| "Object for rng-random device not found")?; + let rng_dev = Arc::new(Mutex::new(Rng::new(rng_cfg.clone(), rngobj_cfg))); - match parse_device_type(cfg_args)?.as_str() { + match rng_cfg.classtype.as_str() { "virtio-rng-device" => { - let device = VirtioMmioDevice::new(sys_mem, rng_dev.clone()); - self.realize_virtio_mmio_device(device) + check_arg_nonexist!( + ("bus", rng_cfg.bus), + ("addr", rng_cfg.addr), + ("multifunction", rng_cfg.multifunction) + ); + self.add_virtio_mmio_device(rng_cfg.id.clone(), rng_dev.clone()) .with_context(|| "Failed to add virtio mmio rng device")?; } _ => { - let bdf = get_pci_bdf(cfg_args)?; - let multi_func = get_multi_function(cfg_args)?; + check_arg_exist!(("bus", rng_cfg.bus), ("addr", rng_cfg.addr)); + let bdf = PciBdf::new(rng_cfg.bus.clone().unwrap(), rng_cfg.addr.unwrap()); + let multi_func = rng_cfg.multifunction.unwrap_or_default(); self.add_virtio_pci_device(&rng_cfg.id, &bdf, rng_dev.clone(), multi_func, false) .with_context(|| "Failed to add pci rng device")?; } @@ -842,6 +916,35 @@ pub trait MachineOps { bail!("No pci host found"); } + /// Add virtio-input device + /// + /// # Arguments + /// + /// * `cfg_args` - Device configuration arguments. + fn add_virtio_input(&mut self, cfg_args: &str) -> Result<()> { + let cfg = InputConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + let dev = Arc::new(Mutex::new(Input::new(cfg.clone())?)); + match cfg.classtype.as_str() { + "virtio-input-device" => { + check_arg_nonexist!( + ("bus", cfg.bus), + ("addr", cfg.addr), + ("multifunction", cfg.multifunction) + ); + self.add_virtio_mmio_device(cfg.id.clone(), dev) + .with_context(|| "Failed to add virtio mmio input device")?; + } + _ => { + check_arg_exist!(("bus", cfg.bus), ("addr", cfg.addr)); + let bdf = PciBdf::new(cfg.bus.clone().unwrap(), cfg.addr.unwrap()); + let multi_func = cfg.multifunction.unwrap_or_default(); + self.add_virtio_pci_device(&cfg.id, &bdf, dev, multi_func, false) + .with_context(|| "Failed to add virtio pci input device")?; + } + } + Ok(()) + } + /// Add virtioFs device. /// /// # Arguments @@ -849,26 +952,43 @@ pub trait MachineOps { /// * 'vm_config' - VM configuration. /// * 'cfg_args' - Device configuration arguments. fn add_virtio_fs(&mut self, vm_config: &mut VmConfig, cfg_args: &str) -> Result<()> { - let dev_cfg = parse_fs(vm_config, cfg_args)?; - let id_clone = dev_cfg.id.clone(); + let dev_cfg = + vhost::user::FsConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + let char_dev = vm_config + .chardev + .remove(&dev_cfg.chardev) + .with_context(|| format!("Chardev {:?} not found or is in use", &dev_cfg.chardev))?; let sys_mem = self.get_sys_mem().clone(); if !vm_config.machine_config.mem_config.mem_share { bail!("When configuring the vhost-user-fs-device or vhost-user-fs-pci device, the memory must be shared."); } - match parse_device_type(cfg_args)?.as_str() { + let device = Arc::new(Mutex::new(vhost::user::Fs::new( + dev_cfg.clone(), + char_dev, + sys_mem, + ))); + match dev_cfg.classtype.as_str() { "vhost-user-fs-device" => { - let device = Arc::new(Mutex::new(vhost::user::Fs::new(dev_cfg, sys_mem.clone()))); - let virtio_mmio_device = VirtioMmioDevice::new(&sys_mem, device); - self.realize_virtio_mmio_device(virtio_mmio_device) + check_arg_nonexist!( + ("bus", dev_cfg.bus), + ("addr", dev_cfg.addr), + ("multifunction", dev_cfg.multifunction) + ); + self.add_virtio_mmio_device(dev_cfg.id.clone(), device) .with_context(|| "Failed to add vhost user fs device")?; } _ => { - let device = Arc::new(Mutex::new(vhost::user::Fs::new(dev_cfg, sys_mem))); - let bdf = get_pci_bdf(cfg_args)?; - let multi_func = get_multi_function(cfg_args)?; - self.add_virtio_pci_device(&id_clone, &bdf, device, multi_func, true) + check_arg_exist!(("bus", dev_cfg.bus), ("addr", dev_cfg.addr)); + let bdf = PciBdf::new(dev_cfg.bus.clone().unwrap(), dev_cfg.addr.unwrap()); + let multi_func = dev_cfg.multifunction.unwrap_or_default(); + let root_bus = self.get_pci_host()?.lock().unwrap().child_bus().unwrap(); + PCI_BUS!(root_bus, locked_bus, root_pci_bus); + let msi_irq_manager = root_pci_bus.msi_irq_manager.clone(); + drop(locked_bus); + let need_irqfd = msi_irq_manager.as_ref().unwrap().irqfd_enable(); + self.add_virtio_pci_device(&dev_cfg.id, &bdf, device, multi_func, need_irqfd) .with_context(|| "Failed to add pci fs device")?; } } @@ -876,8 +996,8 @@ pub trait MachineOps { Ok(()) } - fn get_sys_bus(&mut self) -> &SysBus { - &self.machine_base().sysbus + fn get_sysbus_devices(&self) -> BTreeMap>> { + self.machine_base().sysbus.lock().unwrap().child_devices() } fn get_fwcfg_dev(&mut self) -> Option>> { @@ -889,19 +1009,15 @@ pub trait MachineOps { } fn reset_all_devices(&mut self) -> Result<()> { - let sysbus = self.get_sys_bus(); - for dev in sysbus.devices.iter() { - dev.lock() - .unwrap() - .reset() - .with_context(|| "Fail to reset sysbus device")?; - } + let sysbus = self.machine_base().sysbus.clone(); + sysbus.lock().unwrap().reset()?; + // Todo: this logic will be deleted after deleting pci_host in machine struct. if let Ok(pci_host) = self.get_pci_host() { pci_host .lock() .unwrap() - .reset() + .reset(true) .with_context(|| "Fail to reset pci host")?; } @@ -932,7 +1048,9 @@ pub trait MachineOps { if name.is_empty() { bail!("Device id is empty"); } - if PciBus::find_attached_bus(&pci_host.lock().unwrap().root_bus, name).is_some() { + if PciBus::find_attached_bus(&pci_host.lock().unwrap().child_bus().unwrap(), name) + .is_some() + { bail!("Device id {} existed", name); } if self.check_id_existed_in_xhci(name).unwrap_or_default() { @@ -1025,11 +1143,10 @@ pub trait MachineOps { #[cfg(feature = "pvpanic")] fn add_pvpanic(&mut self, cfg_args: &str) -> Result<()> { - let bdf = get_pci_bdf(cfg_args)?; - let device_cfg = parse_pvpanic(cfg_args)?; - + let config = PvpanicDevConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + let bdf = PciBdf::new(config.bus.clone(), config.addr); let (devfn, parent_bus) = self.get_devfn_and_parent_bus(&bdf)?; - let pcidev = PvPanicPci::new(&device_cfg, devfn, parent_bus); + let pcidev = PvPanicPci::new(&config, devfn, parent_bus); pcidev .realize() .with_context(|| "Failed to realize pvpanic device")?; @@ -1043,26 +1160,38 @@ pub trait MachineOps { cfg_args: &str, hotplug: bool, ) -> Result<()> { - let bdf = get_pci_bdf(cfg_args)?; - let multi_func = get_multi_function(cfg_args)?; - let queues_auto = Some(VirtioPciDevice::virtio_pci_auto_queues_num( - 0, - vm_config.machine_config.nr_cpus, - MAX_VIRTIO_QUEUE, - )); - let device_cfg = parse_blk(vm_config, cfg_args, queues_auto)?; - if let Some(bootindex) = device_cfg.boot_index { + let mut device_cfg = + VirtioBlkDevConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + check_arg_exist!(("bus", device_cfg.bus), ("addr", device_cfg.addr)); + let bdf = PciBdf::new(device_cfg.bus.clone().unwrap(), device_cfg.addr.unwrap()); + let multi_func = device_cfg.multifunction.unwrap_or_default(); + if device_cfg.num_queues.is_none() { + let queues_auto = VirtioPciDevice::virtio_pci_auto_queues_num( + 0, + vm_config.machine_config.nr_cpus, + MAX_VIRTIO_QUEUE, + ); + device_cfg.num_queues = Some(queues_auto); + } + if let Some(bootindex) = device_cfg.bootindex { self.check_bootindex(bootindex) .with_context(|| "Fail to add virtio pci blk device for invalid bootindex")?; } + + let drive_cfg = vm_config + .drives + .remove(&device_cfg.drive) + .with_context(|| "No drive configured matched for blk device")?; + let device = Arc::new(Mutex::new(Block::new( device_cfg.clone(), + drive_cfg, self.get_drive_files(), ))); let pci_dev = self .add_virtio_pci_device(&device_cfg.id, &bdf, device.clone(), multi_func, false) .with_context(|| "Failed to add virtio pci device")?; - if let Some(bootindex) = device_cfg.boot_index { + if let Some(bootindex) = device_cfg.bootindex { // Eg: OpenFirmware device path(virtio-blk disk): // /pci@i0cf8/scsi@6[,3]/disk@0,0 // | | | | | @@ -1085,85 +1214,80 @@ pub trait MachineOps { Ok(()) } + #[cfg(feature = "virtio_scsi")] fn add_virtio_pci_scsi( &mut self, vm_config: &mut VmConfig, cfg_args: &str, hotplug: bool, ) -> Result<()> { - let bdf = get_pci_bdf(cfg_args)?; - let multi_func = get_multi_function(cfg_args)?; - let queues_auto = Some(VirtioPciDevice::virtio_pci_auto_queues_num( - 0, - vm_config.machine_config.nr_cpus, - MAX_VIRTIO_QUEUE, - )); - let device_cfg = parse_scsi_controller(cfg_args, queues_auto)?; + let mut device_cfg = + ScsiCntlrConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + let bdf = PciBdf::new(device_cfg.bus.clone(), device_cfg.addr); + let multi_func = device_cfg.multifunction.unwrap_or_default(); + if device_cfg.num_queues.is_none() { + let queues_auto = VirtioPciDevice::virtio_pci_auto_queues_num( + 0, + vm_config.machine_config.nr_cpus, + MAX_VIRTIO_QUEUE, + ); + device_cfg.num_queues = Some(u32::from(queues_auto)); + } let device = Arc::new(Mutex::new(ScsiCntlr::new(device_cfg.clone()))); let bus_name = format!("{}.0", device_cfg.id); scsi_cntlr_create_scsi_bus(&bus_name, &device)?; - let pci_dev = self - .add_virtio_pci_device(&device_cfg.id, &bdf, device.clone(), multi_func, false) + self.add_virtio_pci_device(&device_cfg.id, &bdf, device, multi_func, false) .with_context(|| "Failed to add virtio scsi controller")?; if !hotplug { self.reset_bus(&device_cfg.id)?; } - device.lock().unwrap().config.boot_prefix = pci_dev.lock().unwrap().get_dev_path(); Ok(()) } + #[cfg(feature = "virtio_scsi")] fn add_scsi_device(&mut self, vm_config: &mut VmConfig, cfg_args: &str) -> Result<()> { - let device_cfg = parse_scsi_device(vm_config, cfg_args)?; - let scsi_type = match parse_device_type(cfg_args)?.as_str() { - "scsi-hd" => SCSI_TYPE_DISK, - _ => SCSI_TYPE_ROM, - }; - if let Some(bootindex) = device_cfg.boot_index { + let device_cfg = ScsiDevConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + let drive_arg = vm_config + .drives + .remove(&device_cfg.drive) + .with_context(|| "No drive configured matched for scsi device")?; + if let Some(bootindex) = device_cfg.bootindex { self.check_bootindex(bootindex) .with_context(|| "Failed to add scsi device for invalid bootindex")?; } - let device = Arc::new(Mutex::new(ScsiDevice::new( - device_cfg.clone(), - scsi_type, - self.get_drive_files(), - ))); + // Bus name `$parent_cntlr_name.0` is checked when parsing by clap. + let cntlr = device_cfg.bus.split('.').collect::>()[0].to_string(); let pci_dev = self - .get_pci_dev_by_id_and_type(vm_config, Some(&device_cfg.cntlr), "virtio-scsi-pci") - .with_context(|| { - format!( - "Can not find scsi controller from pci bus {}", - device_cfg.cntlr - ) - })?; + .get_pci_dev_by_id_and_type(vm_config, Some(&cntlr), "virtio-scsi-pci") + .with_context(|| format!("Can not find scsi controller from pci bus {}", cntlr))?; let locked_pcidev = pci_dev.lock().unwrap(); let virtio_pcidev = locked_pcidev .as_any() .downcast_ref::() .unwrap(); + let prefix = virtio_pcidev.get_dev_path().unwrap(); let virtio_device = virtio_pcidev.get_virtio_device().lock().unwrap(); let cntlr = virtio_device.as_any().downcast_ref::().unwrap(); - let bus = cntlr.bus.as_ref().unwrap(); - if bus - .lock() - .unwrap() - .devices - .contains_key(&(device_cfg.target, device_cfg.lun)) - { - bail!("Wrong! Two scsi devices have the same scsi-id and lun"); + let key = get_scsi_key(device_cfg.target, device_cfg.lun); + if bus.lock().unwrap().child_dev(key).is_some() { + bail!("Wrong! Two scsi devices have the same scsi-id and lun!"); } let iothread = cntlr.config.iothread.clone(); - device.lock().unwrap().realize(iothread)?; - bus.lock() - .unwrap() - .devices - .insert((device_cfg.target, device_cfg.lun), device.clone()); - device.lock().unwrap().parent_bus = Arc::downgrade(bus); + let scsi_device = ScsiDevice::new( + device_cfg.clone(), + drive_arg, + self.get_drive_files(), + iothread, + bus.clone(), + ); + let device = scsi_device.realize()?; + bus.lock().unwrap().attach_child(key, device)?; - if let Some(bootindex) = device_cfg.boot_index { + if let Some(bootindex) = device_cfg.bootindex { // Eg: OpenFirmware device path(virtio-scsi disk): // /pci@i0cf8/scsi@7[,3]/channel@0/disk@2,3 // | | | | | | @@ -1171,7 +1295,6 @@ pub trait MachineOps { // | | | channel(unused, fixed 0). // | PCI slot,[function] holding SCSI controller. // PCI root as system bus port. - let prefix = cntlr.config.boot_prefix.as_ref().unwrap(); let dev_path = format! {"{}/channel@0/disk@{:x},{:x}", prefix, device_cfg.target, device_cfg.lun}; self.add_bootindex_devices(bootindex, &dev_path, &device_cfg.id); @@ -1185,66 +1308,118 @@ pub trait MachineOps { cfg_args: &str, hotplug: bool, ) -> Result<()> { - let bdf = get_pci_bdf(cfg_args)?; - let multi_func = get_multi_function(cfg_args)?; - let device_cfg = parse_net(vm_config, cfg_args)?; + let mut net_cfg = + NetworkInterfaceConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + net_cfg.auto_iothread(); + let netdev_cfg = vm_config + .netdevs + .remove(&net_cfg.netdev) + .with_context(|| format!("Netdev: {:?} not found for net device", &net_cfg.netdev))?; + check_arg_exist!(("bus", net_cfg.bus), ("addr", net_cfg.addr)); + let bdf = PciBdf::new(net_cfg.bus.clone().unwrap(), net_cfg.addr.unwrap()); + let multi_func = net_cfg.multifunction.unwrap_or_default(); + + #[cfg(all(not(feature = "vhost_net"), not(feature = "vhostuser_net")))] + let need_irqfd = false; + #[cfg(any(feature = "vhost_net", feature = "vhostuser_net"))] let mut need_irqfd = false; - let device: Arc> = if device_cfg.vhost_type.is_some() { - need_irqfd = true; - if device_cfg.vhost_type == Some(String::from("vhost-kernel")) { - Arc::new(Mutex::new(VhostKern::Net::new( - &device_cfg, - self.get_sys_mem(), - ))) + let device: Arc> = if netdev_cfg.vhost_type().is_some() { + if netdev_cfg.vhost_type().unwrap() == "vhost-kernel" { + #[cfg(not(feature = "vhost_net"))] + bail!("Unsupported Vhost_net"); + + #[cfg(feature = "vhost_net")] + { + need_irqfd = true; + Arc::new(Mutex::new(VhostKern::Net::new( + &net_cfg, + netdev_cfg, + self.get_sys_mem(), + ))) + } } else { - Arc::new(Mutex::new(VhostUser::Net::new( - &device_cfg, - self.get_sys_mem(), - ))) + #[cfg(not(feature = "vhostuser_net"))] + bail!("Unsupported Vhostuser_net"); + + #[cfg(feature = "vhostuser_net")] + { + need_irqfd = true; + let chardev = netdev_cfg.chardev.clone().with_context(|| { + format!("Chardev not configured for netdev {:?}", netdev_cfg.id) + })?; + let chardev_cfg = vm_config + .chardev + .remove(&chardev) + .with_context(|| format!("Chardev: {:?} not found for netdev", chardev))?; + let sock_path = get_chardev_socket_path(chardev_cfg)?; + Arc::new(Mutex::new(VhostUser::Net::new( + &net_cfg, + netdev_cfg, + sock_path, + self.get_sys_mem(), + ))) + } } } else { - let device = Arc::new(Mutex::new(virtio::Net::new(device_cfg.clone()))); + let device = Arc::new(Mutex::new(virtio::Net::new(net_cfg.clone(), netdev_cfg))); MigrationManager::register_device_instance( VirtioNetState::descriptor(), device.clone(), - &device_cfg.id, + &net_cfg.id, ); device }; - self.add_virtio_pci_device(&device_cfg.id, &bdf, device, multi_func, need_irqfd)?; + self.add_virtio_pci_device(&net_cfg.id, &bdf, device, multi_func, need_irqfd)?; if !hotplug { - self.reset_bus(&device_cfg.id)?; + self.reset_bus(&net_cfg.id)?; } Ok(()) } + #[cfg(feature = "vhostuser_block")] fn add_vhost_user_blk_pci( &mut self, vm_config: &mut VmConfig, cfg_args: &str, hotplug: bool, ) -> Result<()> { - let bdf = get_pci_bdf(cfg_args)?; - let multi_func = get_multi_function(cfg_args)?; - let queues_auto = Some(VirtioPciDevice::virtio_pci_auto_queues_num( - 0, - vm_config.machine_config.nr_cpus, - MAX_VIRTIO_QUEUE, - )); - let device_cfg = parse_vhost_user_blk(vm_config, cfg_args, queues_auto)?; + let mut device_cfg = VhostUser::VhostUserBlkDevConfig::try_parse_from(str_slip_to_clap( + cfg_args, true, false, + ))?; + check_arg_exist!(("bus", device_cfg.bus), ("addr", device_cfg.addr)); + let bdf = PciBdf::new(device_cfg.bus.clone().unwrap(), device_cfg.addr.unwrap()); + if device_cfg.num_queues.is_none() { + let queues_auto = VirtioPciDevice::virtio_pci_auto_queues_num( + 0, + vm_config.machine_config.nr_cpus, + MAX_VIRTIO_QUEUE, + ); + device_cfg.num_queues = Some(queues_auto); + } + let chardev_cfg = vm_config + .chardev + .remove(&device_cfg.chardev) + .with_context(|| { + format!( + "Chardev: {:?} not found for vhost user blk", + &device_cfg.chardev + ) + })?; + let device: Arc> = Arc::new(Mutex::new(VhostUser::Block::new( &device_cfg, + chardev_cfg, self.get_sys_mem(), ))); let pci_dev = self - .add_virtio_pci_device(&device_cfg.id, &bdf, device.clone(), multi_func, true) + .add_virtio_pci_device(&device_cfg.id, &bdf, device.clone(), false, true) .with_context(|| { format!( "Failed to add virtio pci device, device id: {}", &device_cfg.id ) })?; - if let Some(bootindex) = device_cfg.boot_index { + if let Some(bootindex) = device_cfg.bootindex { if let Some(dev_path) = pci_dev.lock().unwrap().get_dev_path() { self.add_bootindex_devices(bootindex, &dev_path, &device_cfg.id); } @@ -1255,65 +1430,61 @@ pub trait MachineOps { Ok(()) } + #[cfg(feature = "vhostuser_block")] fn add_vhost_user_blk_device( &mut self, vm_config: &mut VmConfig, cfg_args: &str, ) -> Result<()> { - let device_cfg = parse_vhost_user_blk(vm_config, cfg_args, None)?; + let device_cfg = VhostUser::VhostUserBlkDevConfig::try_parse_from(str_slip_to_clap( + cfg_args, true, false, + ))?; + check_arg_nonexist!(("bus", device_cfg.bus), ("addr", device_cfg.addr)); + let chardev_cfg = vm_config + .chardev + .remove(&device_cfg.chardev) + .with_context(|| { + format!( + "Chardev: {:?} not found for vhost user blk", + &device_cfg.chardev + ) + })?; let device: Arc> = Arc::new(Mutex::new(VhostUser::Block::new( &device_cfg, + chardev_cfg, self.get_sys_mem(), ))); - let virtio_mmio_device = VirtioMmioDevice::new(self.get_sys_mem(), device); - self.realize_virtio_mmio_device(virtio_mmio_device) + self.add_virtio_mmio_device(device_cfg.id.clone(), device) .with_context(|| "Failed to add vhost user block device")?; Ok(()) } + #[cfg(feature = "vfio_device")] + fn add_vfio_device(&mut self, cfg_args: &str, hotplug: bool) -> Result<()> { + let hypervisor = self.get_hypervisor(); + let locked_hypervisor = hypervisor.lock().unwrap(); + *KVM_DEVICE_FD.lock().unwrap() = locked_hypervisor.create_vfio_device(); - fn create_vfio_pci_device( - &mut self, - id: &str, - bdf: &PciBdf, - host: &str, - sysfsdev: &str, - multifunc: bool, - ) -> Result<()> { - let (devfn, parent_bus) = self.get_devfn_and_parent_bus(bdf)?; - let path = if !host.is_empty() { - format!("/sys/bus/pci/devices/{}", host) + let device_cfg = VfioConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + let bdf = PciBdf::new(device_cfg.bus.clone(), device_cfg.addr); + let multi_func = device_cfg.multifunction.unwrap_or_default(); + let (devfn, parent_bus) = self.get_devfn_and_parent_bus(&bdf)?; + let path = if device_cfg.host.is_some() { + format!("/sys/bus/pci/devices/{}", device_cfg.host.unwrap()) } else { - sysfsdev.to_string() + device_cfg.sysfsdev.unwrap() }; let device = VfioDevice::new(Path::new(&path), self.get_sys_mem()) .with_context(|| "Failed to create vfio device.")?; let vfio_pci = VfioPciDevice::new( device, devfn, - id.to_string(), + device_cfg.id.to_string(), parent_bus, - multifunc, + multi_func, self.get_sys_mem().clone(), ); VfioPciDevice::realize(vfio_pci).with_context(|| "Failed to realize vfio-pci device.")?; - Ok(()) - } - - fn add_vfio_device(&mut self, cfg_args: &str, hotplug: bool) -> Result<()> { - let hypervisor = self.get_hypervisor(); - let locked_hypervisor = hypervisor.lock().unwrap(); - *KVM_DEVICE_FD.lock().unwrap() = locked_hypervisor.create_vfio_device(); - let device_cfg: VfioConfig = parse_vfio(cfg_args)?; - let bdf = get_pci_bdf(cfg_args)?; - let multifunc = get_multi_function(cfg_args)?; - self.create_vfio_pci_device( - &device_cfg.id, - &bdf, - &device_cfg.host, - &device_cfg.sysfsdev, - multifunc, - )?; if !hotplug { self.reset_bus(&device_cfg.id)?; } @@ -1330,10 +1501,10 @@ pub trait MachineOps { #[cfg(feature = "virtio_gpu")] fn add_virtio_pci_gpu(&mut self, cfg_args: &str) -> Result<()> { - let bdf = get_pci_bdf(cfg_args)?; - let multi_func = get_multi_function(cfg_args)?; - let device_cfg = parse_gpu(cfg_args)?; - let device = Arc::new(Mutex::new(Gpu::new(device_cfg.clone()))); + let config = GpuDevConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + config.check(); + let bdf = PciBdf::new(config.bus.clone(), config.addr); + let device = Arc::new(Mutex::new(Gpu::new(config.clone()))); #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] if device.lock().unwrap().device_quirk() == Some(VirtioDeviceQuirk::VirtioGpuEnableBar0) @@ -1343,13 +1514,13 @@ pub trait MachineOps { device.lock().unwrap().set_bar0_fb(self.get_ohui_fb()); } - self.add_virtio_pci_device(&device_cfg.id, &bdf, device, multi_func, false)?; + self.add_virtio_pci_device(&config.id, &bdf, device, false, false)?; Ok(()) } - fn get_devfn_and_parent_bus(&mut self, bdf: &PciBdf) -> Result<(u8, Weak>)> { + fn get_devfn_and_parent_bus(&mut self, bdf: &PciBdf) -> Result<(u8, Weak>)> { let pci_host = self.get_pci_host()?; - let bus = pci_host.lock().unwrap().root_bus.clone(); + let bus = pci_host.lock().unwrap().child_bus().unwrap().clone(); let pci_bus = PciBus::find_bus_by_name(&bus, &bdf.bus); if pci_bus.is_none() { bail!("Parent bus :{} not found", &bdf.bus); @@ -1360,21 +1531,15 @@ pub trait MachineOps { } fn add_pci_root_port(&mut self, cfg_args: &str) -> Result<()> { - let bdf = get_pci_bdf(cfg_args)?; - let (devfn, parent_bus) = self.get_devfn_and_parent_bus(&bdf)?; - let device_cfg = parse_root_port(cfg_args)?; + let dev_cfg = RootPortConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + let bdf = PciBdf::new(dev_cfg.bus.clone(), dev_cfg.addr); + let (_, parent_bus) = self.get_devfn_and_parent_bus(&bdf)?; let pci_host = self.get_pci_host()?; - let bus = pci_host.lock().unwrap().root_bus.clone(); - if PciBus::find_bus_by_name(&bus, &device_cfg.id).is_some() { - bail!("ID {} already exists.", &device_cfg.id); + let bus = pci_host.lock().unwrap().child_bus().unwrap().clone(); + if PciBus::find_bus_by_name(&bus, &dev_cfg.id).is_some() { + bail!("ID {} already exists.", &dev_cfg.id); } - let rootport = RootPort::new( - device_cfg.id, - devfn, - device_cfg.port, - parent_bus, - device_cfg.multifunction, - ); + let rootport = RootPort::new(dev_cfg, parent_bus); rootport .realize() .with_context(|| "Failed to add pci root port")?; @@ -1411,25 +1576,25 @@ pub trait MachineOps { fn reset_bus(&mut self, dev_id: &str) -> Result<()> { let pci_host = self.get_pci_host()?; let locked_pci_host = pci_host.lock().unwrap(); - let bus = PciBus::find_attached_bus(&locked_pci_host.root_bus, dev_id) + let bus = PciBus::find_attached_bus(&locked_pci_host.child_bus().unwrap(), dev_id) .with_context(|| format!("Bus not found, dev id {}", dev_id))? .0; let locked_bus = bus.lock().unwrap(); - if locked_bus.name == "pcie.0" { + if locked_bus.name() == "pcie.0" { // No need to reset root bus return Ok(()); } let parent_bridge = locked_bus - .parent_bridge - .as_ref() + .parent_device() .with_context(|| format!("Parent bridge does not exist, dev id {}", dev_id))?; let dev = parent_bridge.upgrade().unwrap(); let locked_dev = dev.lock().unwrap(); let name = locked_dev.name(); drop(locked_dev); let mut devfn = None; - let locked_bus = locked_pci_host.root_bus.lock().unwrap(); - for (id, dev) in &locked_bus.devices { + let bus = locked_pci_host.child_bus().unwrap(); + let locked_bus = bus.lock().unwrap(); + for (id, dev) in &locked_bus.child_devices() { if dev.lock().unwrap().name() == name { devfn = Some(*id); break; @@ -1437,7 +1602,7 @@ pub trait MachineOps { } drop(locked_bus); // It's safe to call devfn.unwrap(), because the bus exists. - match locked_pci_host.find_device(0, devfn.unwrap()) { + match locked_pci_host.find_device(0, u8::try_from(devfn.unwrap())?) { Some(dev) => dev .lock() .unwrap() @@ -1477,46 +1642,47 @@ pub trait MachineOps { for numa in vm_config.numa_nodes.iter() { match numa.0.as_str() { "node" => { - let numa_config: NumaConfig = parse_numa_mem(numa.1.as_str())?; - if numa_nodes.contains_key(&numa_config.numa_id) { - bail!("Numa node id is repeated {}", numa_config.numa_id); + let node_config = parse_numa_mem(numa.1.as_str())?; + if numa_nodes.contains_key(&node_config.numa_id) { + bail!("Numa node id is repeated {}", node_config.numa_id); } let mut numa_node = NumaNode { - cpus: numa_config.cpus, - mem_dev: numa_config.mem_dev.clone(), + cpus: node_config.cpus, + mem_dev: node_config.mem_dev.clone(), ..Default::default() }; numa_node.size = vm_config .object .mem_object - .remove(&numa_config.mem_dev) + .remove(&node_config.mem_dev) .map(|mem_conf| mem_conf.size) .with_context(|| { format!( "Object for memory-backend {} config not found", - numa_config.mem_dev + node_config.mem_dev ) })?; - numa_nodes.insert(numa_config.numa_id, numa_node); + numa_nodes.insert(node_config.numa_id, numa_node); } "dist" => { - let dist: (u32, NumaDistance) = parse_numa_distance(numa.1.as_str())?; - if !numa_nodes.contains_key(&dist.0) { - bail!("Numa node id is not found {}", dist.0); + let dist_config = parse_numa_distance(numa.1.as_str())?; + if !numa_nodes.contains_key(&dist_config.numa_id) { + bail!("Numa node id is not found {}", dist_config.numa_id); } - if !numa_nodes.contains_key(&dist.1.destination) { - bail!("Numa node id is not found {}", dist.1.destination); + if !numa_nodes.contains_key(&dist_config.destination) { + bail!("Numa node id is not found {}", dist_config.destination); } - if let Some(n) = numa_nodes.get_mut(&dist.0) { - if n.distances.contains_key(&dist.1.destination) { + if let Some(n) = numa_nodes.get_mut(&dist_config.numa_id) { + if n.distances.contains_key(&dist_config.destination) { bail!( "Numa destination info {} repeat settings", - dist.1.destination + dist_config.destination ); } - n.distances.insert(dist.1.destination, dist.1.distance); + n.distances + .insert(dist_config.destination, dist_config.distance); } } _ => { @@ -1541,7 +1707,7 @@ pub trait MachineOps { /// /// * `cfg_args` - XHCI Configuration. fn add_usb_xhci(&mut self, cfg_args: &str) -> Result<()> { - let device_cfg = XhciConfig::try_parse_from(str_slip_to_clap(cfg_args))?; + let device_cfg = XhciConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; let bdf = PciBdf::new(device_cfg.bus.clone(), device_cfg.addr); let (devfn, parent_bus) = self.get_devfn_and_parent_bus(&bdf)?; @@ -1565,12 +1731,9 @@ pub trait MachineOps { cfg_args: &str, token_id: Option>>, ) -> Result<()> { - let config = ScreamConfig::try_parse_from(str_slip_to_clap(cfg_args))?; - let bdf = PciBdf { - bus: config.bus.clone(), - addr: config.addr, - }; - let (devfn, parent_bus) = self.get_devfn_and_parent_bus(&bdf)?; + let config = ScreamConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + let bdf = PciBdf::new(config.bus.clone(), config.addr); + let (_, parent_bus) = self.get_devfn_and_parent_bus(&bdf)?; let mem_cfg = vm_config .object @@ -1587,9 +1750,9 @@ pub trait MachineOps { bail!("Object for share config is not on"); } - let scream = Scream::new(mem_cfg.size, config, token_id); + let mut scream = Scream::new(mem_cfg.size, config, token_id)?; scream - .realize(devfn, parent_bus) + .realize(parent_bus) .with_context(|| "Failed to realize scream device") } @@ -1605,7 +1768,7 @@ pub trait MachineOps { vm_config: &VmConfig, id: Option<&str>, dev_type: &str, - ) -> Option>> { + ) -> Option>> { let (id_check, id_str) = if id.is_some() { (true, format! {"id={}", id.unwrap()}) } else { @@ -1621,9 +1784,10 @@ pub trait MachineOps { let bdf = get_pci_bdf(cfg_args).ok()?; let devfn = (bdf.addr.0 << 3) + bdf.addr.1; let pci_host = self.get_pci_host().ok()?; - let root_bus = pci_host.lock().unwrap().root_bus.clone(); - if let Some(pci_bus) = PciBus::find_bus_by_name(&root_bus, &bdf.bus) { - return pci_bus.lock().unwrap().get_device(0, devfn); + let root_bus = pci_host.lock().unwrap().child_bus().unwrap().clone(); + if let Some(bus) = PciBus::find_bus_by_name(&root_bus, &bdf.bus) { + PCI_BUS!(bus, locked_bus, pci_bus); + return pci_bus.get_device(0, devfn); } else { return None; } @@ -1686,16 +1850,18 @@ pub trait MachineOps { /// * `driver` - USB device class. /// * `cfg_args` - USB device Configuration. fn add_usb_device(&mut self, vm_config: &mut VmConfig, cfg_args: &str) -> Result<()> { - let usb_device = match parse_device_type(cfg_args)?.as_str() { + let usb_device = match get_class_type(cfg_args)?.as_str() { "usb-kbd" => { - let config = UsbKeyboardConfig::try_parse_from(str_slip_to_clap(cfg_args))?; + let config = + UsbKeyboardConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; let keyboard = UsbKeyboard::new(config); keyboard .realize() .with_context(|| "Failed to realize usb keyboard device")? } "usb-tablet" => { - let config = UsbTabletConfig::try_parse_from(str_slip_to_clap(cfg_args))?; + let config = + UsbTabletConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; let tablet = UsbTablet::new(config); tablet .realize() @@ -1703,7 +1869,12 @@ pub trait MachineOps { } #[cfg(feature = "usb_camera")] "usb-camera" => { - let config = UsbCameraConfig::try_parse_from(str_slip_to_clap(cfg_args))?; + let token_id = match self.get_token_id() { + Some(id) => *id.read().unwrap(), + None => 0, + }; + let config = + UsbCameraConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; let cameradev = get_cameradev_by_id(vm_config, config.cameradev.clone()) .with_context(|| { format!( @@ -1712,25 +1883,79 @@ pub trait MachineOps { ) })?; - let camera = UsbCamera::new(config, cameradev)?; + let camera = UsbCamera::new(config, cameradev, token_id)?; camera .realize() .with_context(|| "Failed to realize usb camera device")? } "usb-storage" => { - let device_cfg = parse_usb_storage(vm_config, cfg_args)?; - let storage = UsbStorage::new(device_cfg, self.get_drive_files()); + let device_cfg = + UsbStorageConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + let drive_cfg = vm_config + .drives + .remove(&device_cfg.drive) + .with_context(|| "No drive configured matched for usb storage device.")?; + let storage = UsbStorage::new(device_cfg, drive_cfg, self.get_drive_files())?; storage .realize() .with_context(|| "Failed to realize usb storage device")? } + #[cfg(feature = "usb_uas")] + "usb-uas" => { + let device_cfg = + UsbUasConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + let drive_cfg = vm_config + .drives + .remove(&device_cfg.drive) + .with_context(|| "No drive configured matched for usb uas device.")?; + let uas = UsbUas::new(device_cfg, drive_cfg, self.get_drive_files())?; + uas.realize() + .with_context(|| "Failed to realize usb uas device")? + } #[cfg(feature = "usb_host")] "usb-host" => { - let config = UsbHostConfig::try_parse_from(str_slip_to_clap(cfg_args))?; - let usbhost = UsbHost::new(config)?; - usbhost - .realize() - .with_context(|| "Failed to realize usb host device")? + let config = + UsbHostConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + let parent_dev = self + .get_pci_dev_by_id_and_type(vm_config, None, "nec-usb-xhci") + .with_context(|| "No nec-usb-xhci device found")?; + + let update_vm_config = self.get_vm_config(); + + thread::Builder::new() + .name("usb host initialization".to_string()) + .spawn(move || { + let dev_id = config.id.clone(); + match initialize_usb_host(config, parent_dev) { + Ok(_) => { + if QmpChannel::is_connected() { + let success_msg = UsbHostAddRes { + device: Some(dev_id), + state_msg: Some("Add usb host device success".to_string()), + }; + event!(UsbHostAddRes; success_msg); + } + } + Err(e) => { + error!("Usb host device initialization failed: {:?}", e); + let mut locked_vm_config = update_vm_config.lock().unwrap(); + locked_vm_config.del_device_by_id(dev_id.clone()); + if QmpChannel::is_connected() { + let fail_msg = UsbHostAddRes { + device: Some(dev_id), + state_msg: Some(format!( + "Usb host device initialization failed: {:?}", + e + )), + }; + event!(UsbHostAddRes; fail_msg); + } + } + } + }) + .with_context(|| "Failed to spawn usb host initializer thread")?; + + return Ok(()); } _ => bail!("Unknown usb device classes."), }; @@ -1768,7 +1993,7 @@ pub trait MachineOps { for dev in &cloned_vm_config.devices { let cfg_args = dev.1.as_str(); // Check whether the device id exists to ensure device uniqueness. - let id = parse_device_id(cfg_args)?; + let id = get_value_of_parameter("id", cfg_args)?; self.check_device_id_existed(&id) .with_context(|| format!("Failed to check device id: config {}", cfg_args))?; #[cfg(feature = "scream")] @@ -1778,28 +2003,36 @@ pub trait MachineOps { dev.0.as_str(); self; ("virtio-blk-device", add_virtio_mmio_block, vm_config, cfg_args), ("virtio-blk-pci", add_virtio_pci_blk, vm_config, cfg_args, false), - ("virtio-scsi-pci", add_virtio_pci_scsi, vm_config, cfg_args, false), - ("scsi-hd" | "scsi-cd", add_scsi_device, vm_config, cfg_args), ("virtio-net-device", add_virtio_mmio_net, vm_config, cfg_args), ("virtio-net-pci", add_virtio_pci_net, vm_config, cfg_args, false), ("pcie-root-port", add_pci_root_port, cfg_args), - ("vhost-vsock-pci" | "vhost-vsock-device", add_virtio_vsock, cfg_args), ("virtio-balloon-device" | "virtio-balloon-pci", add_virtio_balloon, vm_config, cfg_args), + ("virtio-input-device" | "virtio-input-pci", add_virtio_input, cfg_args), ("virtio-serial-device" | "virtio-serial-pci", add_virtio_serial, vm_config, cfg_args), ("virtconsole" | "virtserialport", add_virtio_serial_port, vm_config, cfg_args), - ("virtio-rng-device" | "virtio-rng-pci", add_virtio_rng, vm_config, cfg_args), - ("vfio-pci", add_vfio_device, cfg_args, false), - ("vhost-user-blk-device",add_vhost_user_blk_device, vm_config, cfg_args), - ("vhost-user-blk-pci",add_vhost_user_blk_pci, vm_config, cfg_args, false), ("vhost-user-fs-pci" | "vhost-user-fs-device", add_virtio_fs, vm_config, cfg_args), ("nec-usb-xhci", add_usb_xhci, cfg_args), - ("usb-kbd" | "usb-storage" | "usb-tablet" | "usb-camera" | "usb-host", add_usb_device, vm_config, cfg_args); + ("usb-kbd" | "usb-storage" | "usb-uas" | "usb-tablet" | "usb-camera" | "usb-host", add_usb_device, vm_config, cfg_args); + #[cfg(feature = "vhostuser_block")] + ("vhost-user-blk-device",add_vhost_user_blk_device, vm_config, cfg_args), + #[cfg(feature = "vhostuser_block")] + ("vhost-user-blk-pci",add_vhost_user_blk_pci, vm_config, cfg_args, false), + #[cfg(feature = "vhost_vsock")] + ("vhost-vsock-pci" | "vhost-vsock-device", add_virtio_vsock, cfg_args), + #[cfg(feature = "virtio_rng")] + ("virtio-rng-device" | "virtio-rng-pci", add_virtio_rng, vm_config, cfg_args), + #[cfg(feature = "vfio_device")] + ("vfio-pci", add_vfio_device, cfg_args, false), #[cfg(feature = "virtio_gpu")] ("virtio-gpu-pci", add_virtio_pci_gpu, cfg_args), + #[cfg(feature = "virtio_scsi")] + ("virtio-scsi-pci", add_virtio_pci_scsi, vm_config, cfg_args, false), + #[cfg(feature = "virtio_scsi")] + ("scsi-hd" | "scsi-cd", add_scsi_device, vm_config, cfg_args), #[cfg(feature = "ramfb")] ("ramfb", add_ramfb, cfg_args), #[cfg(feature = "demo_device")] - ("pcie-demo-dev", add_demo_dev, vm_config, cfg_args), + ("pcie-demo-dev", add_demo_dev, cfg_args), #[cfg(feature = "scream")] ("ivshmem-scream", add_ivshmem_scream, vm_config, cfg_args, token_id), #[cfg(feature = "pvpanic")] @@ -1814,7 +2047,7 @@ pub trait MachineOps { None } - fn add_pflash_device(&mut self, _configs: &[PFlashConfig]) -> Result<()> { + fn add_pflash_device(&mut self, _configs: &[DriveConfig]) -> Result<()> { bail!("Pflash device is not supported!"); } @@ -1827,17 +2060,16 @@ pub trait MachineOps { } #[cfg(feature = "demo_device")] - fn add_demo_dev(&mut self, vm_config: &mut VmConfig, cfg_args: &str) -> Result<()> { - let bdf = get_pci_bdf(cfg_args)?; - let (devfn, parent_bus) = self.get_devfn_and_parent_bus(&bdf)?; - - let demo_cfg = parse_demo_dev(vm_config, cfg_args.to_string()) + fn add_demo_dev(&mut self, cfg_args: &str) -> Result<()> { + let config = DemoDevConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false)) .with_context(|| "failed to parse cmdline for demo dev.")?; - + let bdf = PciBdf::new(config.bus.clone(), config.addr); + let (devfn, parent_bus) = self.get_devfn_and_parent_bus(&bdf)?; let sys_mem = self.get_sys_mem().clone(); - let demo_dev = DemoDev::new(demo_cfg, devfn, sys_mem, parent_bus); + let demo_dev = DemoDev::new(config, devfn, sys_mem, parent_bus); - demo_dev.realize() + demo_dev.realize()?; + Ok(()) } /// Return the syscall whitelist for seccomp. @@ -1872,7 +2104,7 @@ pub trait MachineOps { } /// Fetch a cloned file from drive backend files. - fn fetch_drive_file(&self, path: &str) -> Result { + fn fetch_drive_file(&self, path: &str) -> Result> { let files = self.get_drive_files(); let drive_files = files.lock().unwrap(); VmConfig::fetch_drive_file(&drive_files, path) @@ -2004,11 +2236,15 @@ pub trait MachineOps { ) -> Result<()> { EventLoop::get_ctx(None).unwrap().disable_clock(); - self.deactive_drive_files()?; - + // Deactive files so that VM on the other end can active files. + if MigrationManager::is_active() { + self.deactive_drive_files()?; + } for (cpu_index, cpu) in cpus.iter().enumerate() { if let Err(e) = cpu.pause() { - self.active_drive_files()?; + if MigrationManager::is_active() { + self.active_drive_files()?; + } return Err(anyhow!("Failed to pause vcpu{}, {:?}", cpu_index, e)); } } @@ -2019,6 +2255,9 @@ pub trait MachineOps { *vm_state = VmState::Paused; + // Notify VM paused. + pause_notify(true); + Ok(()) } @@ -2033,6 +2272,9 @@ pub trait MachineOps { self.active_drive_files()?; + // Notify VM resumed. + pause_notify(false); + for (cpu_index, cpu) in cpus.iter().enumerate() { if let Err(e) = cpu.resume() { self.deactive_drive_files()?; @@ -2119,6 +2361,51 @@ pub trait MachineOps { } } +fn register_shutdown_event( + shutdown_req: Arc, + vm: Arc>, +) -> Result<()> { + let shutdown_req_fd = shutdown_req.as_raw_fd(); + let shutdown_req_handler: Rc = Rc::new(move |_, _| { + let _ret = shutdown_req.read(); + if handle_destroy_request(&vm) { + Some(gen_delete_notifiers(&[shutdown_req_fd])) + } else { + warn!("Fail to shutdown VM, try again"); + if shutdown_req.write(1).is_err() { + error!("Failed to send shutdown request"); + } + None + } + }); + let notifier = EventNotifier::new( + NotifierOperation::AddShared, + shutdown_req_fd, + None, + EventSet::IN, + vec![shutdown_req_handler], + ); + EventLoop::update_event(vec![notifier], None) + .with_context(|| "Failed to register event notifier.") +} + +fn handle_destroy_request(vm: &Arc>) -> bool { + let locked_vm = vm.lock().unwrap(); + let vmstate: VmState = { + let state = locked_vm.machine_base().vm_state.deref().0.lock().unwrap(); + *state + }; + + if !locked_vm.notify_lifecycle(vmstate, VmState::Shutdown) { + return false; + } + + info!("vm destroy"); + EventLoop::get_ctx(None).unwrap().kick(); + + true +} + /// Normal run or resume virtual machine from migration/snapshot. /// /// # Arguments @@ -2209,19 +2496,34 @@ fn check_windows_emu_pid( pid_path: String, powerdown_req: Arc, shutdown_req: Arc, + vm: Arc>, ) { - let mut check_delay = Duration::from_millis(4000); + let mut check_delay = Duration::from_millis(WINDOWS_EMU_PID_DEFAULT_INTERVAL); if !Path::new(&pid_path).exists() { - log::info!("Detect windows emu exited, let VM exits now"); + info!("Detect emulator exited, let VM exits now"); + let locked_vm = vm.lock().unwrap(); + let mut vm_state = locked_vm.get_vm_state().deref().0.lock().unwrap(); + if *vm_state == VmState::Paused { + info!("VM state is paused, resume VM before exit"); + if let Err(e) = locked_vm.vm_resume(&locked_vm.machine_base().cpus, &mut vm_state) { + log::error!("Failed to resume VM when check windows emu pid: {:?}", e); + } + } + drop(vm_state); + drop(locked_vm); if get_run_stage() == VmRunningStage::Os { + // Wait 30s for windows normal exit. + check_delay = Duration::from_millis(WINDOWS_EMU_PID_POWERDOWN_INTERVAL); if let Err(e) = powerdown_req.write(1) { log::error!("Failed to send powerdown request after emu exits: {:?}", e); } - } else if let Err(e) = shutdown_req.write(1) { - log::error!("Failed to send shutdown request after emu exits: {:?}", e); + } else { + // Wait 1s for windows shutdown. + check_delay = Duration::from_millis(WINDOWS_EMU_PID_SHUTDOWN_INTERVAL); + if let Err(e) = shutdown_req.write(1) { + log::error!("Failed to send shutdown request after emu exits: {:?}", e); + } } - // Continue checking to prevent exit failed. - check_delay = Duration::from_millis(1000); } let check_emu_alive = Box::new(move || { @@ -2229,9 +2531,88 @@ fn check_windows_emu_pid( pid_path.clone(), powerdown_req.clone(), shutdown_req.clone(), + vm.clone(), ); }); EventLoop::get_ctx(None) .unwrap() .timer_add(check_emu_alive, check_delay); } + +/// When windows emu exits, stratovirt should exits too. +#[cfg(feature = "windows_emu_pid")] +pub(crate) fn watch_windows_emu_pid( + vm_config: &VmConfig, + power_button: Arc, + shutdown_req: Arc, + vm: Arc>, +) { + let emu_pid = vm_config.emulator_pid.as_ref(); + if emu_pid.is_none() { + return; + } + info!("Watching on emulator lifetime"); + let pid_path = "/proc/".to_owned() + emu_pid.unwrap(); + let check_delay = Duration::from_millis(WINDOWS_EMU_PID_DEFAULT_INTERVAL); + let check_emu_alive = Box::new(move || { + check_windows_emu_pid( + pid_path.clone(), + power_button.clone(), + shutdown_req.clone(), + vm.clone(), + ); + }); + EventLoop::get_ctx(None) + .unwrap() + .timer_add(check_emu_alive, check_delay); +} + +fn machine_register_pcidevops_type() -> Result<()> { + #[cfg(target_arch = "x86_64")] + { + register_pcidevops_type::()?; + register_pcidevops_type::()?; + } + #[cfg(target_arch = "aarch64")] + { + register_pcidevops_type::()?; + } + + Ok(()) +} + +pub fn type_init() -> Result<()> { + // Register all sysbus devices type. + virtio_register_sysbusdevops_type()?; + devices_register_sysbusdevops_type()?; + + // Register all pci devices type. + machine_register_pcidevops_type()?; + #[cfg(feature = "vfio_device")] + vfio_register_pcidevops_type()?; + virtio_register_pcidevops_type()?; + devices_register_pcidevops_type()?; + + Ok(()) +} + +#[cfg(feature = "usb_host")] +fn initialize_usb_host(config: UsbHostConfig, parent_dev: Arc>) -> Result<()> { + let usbhost = UsbHost::new(config).with_context(|| "Failed to create usb host device")?; + + let usbhost = usbhost + .realize() + .with_context(|| "Failed to realize usb host device")?; + + let parent = parent_dev.lock().unwrap(); + let xhci_pci = parent + .as_any() + .downcast_ref::() + .ok_or_else(|| anyhow!("Failed to downcast PciDevOps to XhciPciDevice"))?; + + xhci_pci + .attach_device(&usbhost) + .with_context(|| "Failed to attach usb host device to xhci controller")?; + + Ok(()) +} diff --git a/machine/src/micro_common/mod.rs b/machine/src/micro_common/mod.rs index b97f046ca35f0f9bd8aab02c07468c7548e83897..200028d568d0b0f827435933e7fc2315d0603953 100644 --- a/machine/src/micro_common/mod.rs +++ b/machine/src/micro_common/mod.rs @@ -38,7 +38,9 @@ use std::sync::{Arc, Mutex}; use std::vec::Vec; use anyhow::{anyhow, bail, Context, Result}; -use log::{error, info}; +use clap::Parser; +use log::error; +use vmm_sys_util::eventfd::EventFd; #[cfg(target_arch = "aarch64")] use crate::aarch64::micro::{LayoutEntryType, MEM_LAYOUT}; @@ -46,13 +48,18 @@ use crate::aarch64::micro::{LayoutEntryType, MEM_LAYOUT}; use crate::x86_64::micro::{LayoutEntryType, MEM_LAYOUT}; use crate::{MachineBase, MachineError, MachineOps}; use cpu::CpuLifecycleState; +#[cfg(target_arch = "x86_64")] +use devices::sysbus::SysBusDevOps; use devices::sysbus::{IRQ_BASE, IRQ_MAX}; +use devices::Device; +#[cfg(feature = "vhostuser_net")] +use machine_manager::config::get_chardev_socket_path; +#[cfg(target_arch = "x86_64")] +use machine_manager::config::Param; use machine_manager::config::{ - parse_blk, parse_incoming_uri, parse_net, BlkDevConfig, ConfigCheck, DiskFormat, MigrateMode, - NetworkInterfaceConfig, VmConfig, DEFAULT_VIRTQUEUE_SIZE, + parse_incoming_uri, str_slip_to_clap, ConfigCheck, DriveConfig, MigrateMode, NetDevcfg, + NetworkInterfaceConfig, VmConfig, }; -use machine_manager::event; -use machine_manager::event_loop::EventLoop; use machine_manager::machine::{ DeviceInterface, MachineAddressInterface, MachineExternalInterface, MachineInterface, MachineLifecycle, MigrateInterface, VmState, @@ -60,12 +67,18 @@ use machine_manager::machine::{ use machine_manager::qmp::{ qmp_channel::QmpChannel, qmp_response::Response, qmp_schema, qmp_schema::UpdateRegionArgument, }; +use machine_manager::{check_arg_nonexist, event}; use migration::MigrationManager; -use util::aio::WriteZeroesState; -use util::{loop_context::EventLoopManager, num_ops::str_to_num, set_termi_canon_mode}; +use util::loop_context::{create_new_eventfd, EventLoopManager}; +use util::{num_ops::str_to_num, set_termi_canon_mode}; +use virtio::device::block::VirtioBlkDevConfig; +#[cfg(feature = "vhost_net")] +use virtio::VhostKern; +#[cfg(feature = "vhostuser_net")] +use virtio::VhostUser; use virtio::{ - create_tap, qmp_balloon, qmp_query_balloon, Block, BlockState, Net, VhostKern, VhostUser, - VirtioDevice, VirtioMmioDevice, VirtioMmioState, VirtioNetState, + create_tap, qmp_balloon, qmp_query_balloon, Block, BlockState, Net, VirtioDevice, + VirtioMmioDevice, VirtioMmioState, VirtioNetState, }; // The replaceable block device maximum count. @@ -78,8 +91,9 @@ const MMIO_REPLACEABLE_NET_NR: usize = 2; struct MmioReplaceableConfig { // Device id. id: String, - // The dev_config of the related backend device. - dev_config: Arc, + // The config of the related backend device. + // Eg: Drive config of virtio mmio block. Netdev config of virtio mmio net. + back_config: Arc, } // The device information of replaceable device. @@ -132,6 +146,8 @@ pub struct LightMachine { pub(crate) base: MachineBase, // All replaceable device information. pub(crate) replaceable_info: MmioReplaceableInfo, + /// Shutdown request, handle VM `shutdown` event. + pub(crate) shutdown_req: Arc, } impl LightMachine { @@ -151,73 +167,72 @@ impl LightMachine { Ok(LightMachine { base, replaceable_info: MmioReplaceableInfo::new(), + shutdown_req: Arc::new( + create_new_eventfd() + .with_context(|| MachineError::InitEventFdErr("shutdown_req".to_string()))?, + ), }) } pub(crate) fn create_replaceable_devices(&mut self) -> Result<()> { - let mut rpl_devs: Vec = Vec::new(); for id in 0..MMIO_REPLACEABLE_BLK_NR { let block = Arc::new(Mutex::new(Block::new( - BlkDevConfig::default(), + VirtioBlkDevConfig::default(), + DriveConfig::default(), self.get_drive_files(), ))); - let virtio_mmio = VirtioMmioDevice::new(&self.base.sys_mem, block.clone()); - rpl_devs.push(virtio_mmio); - MigrationManager::register_device_instance( BlockState::descriptor(), - block, + block.clone(), + &id.to_string(), + ); + + let blk_mmio = self.add_virtio_mmio_device(id.to_string(), block.clone())?; + let info = MmioReplaceableDevInfo { + device: block, + id: id.to_string(), + used: false, + }; + self.replaceable_info.devices.lock().unwrap().push(info); + MigrationManager::register_transport_instance( + VirtioMmioState::descriptor(), + blk_mmio, &id.to_string(), ); } for id in 0..MMIO_REPLACEABLE_NET_NR { - let net = Arc::new(Mutex::new(Net::new(NetworkInterfaceConfig::default()))); - let virtio_mmio = VirtioMmioDevice::new(&self.base.sys_mem, net.clone()); - rpl_devs.push(virtio_mmio); - + let total_id = id + MMIO_REPLACEABLE_BLK_NR; + let net = Arc::new(Mutex::new(Net::new( + NetworkInterfaceConfig::default(), + NetDevcfg::default(), + ))); MigrationManager::register_device_instance( VirtioNetState::descriptor(), - net, - &id.to_string(), + net.clone(), + &total_id.to_string(), ); - } - - let mut region_base = self.base.sysbus.min_free_base; - let region_size = MEM_LAYOUT[LayoutEntryType::Mmio as usize].1; - for (id, dev) in rpl_devs.into_iter().enumerate() { - self.replaceable_info - .devices - .lock() - .unwrap() - .push(MmioReplaceableDevInfo { - device: dev.device.clone(), - id: id.to_string(), - used: false, - }); + let net_mmio = self.add_virtio_mmio_device(total_id.to_string(), net.clone())?; + let info = MmioReplaceableDevInfo { + device: net, + id: total_id.to_string(), + used: false, + }; + self.replaceable_info.devices.lock().unwrap().push(info); MigrationManager::register_transport_instance( VirtioMmioState::descriptor(), - VirtioMmioDevice::realize( - dev, - &mut self.base.sysbus, - region_base, - MEM_LAYOUT[LayoutEntryType::Mmio as usize].1, - #[cfg(target_arch = "x86_64")] - &self.base.boot_source, - ) - .with_context(|| MachineError::RlzVirtioMmioErr)?, - &id.to_string(), + net_mmio, + &total_id.to_string(), ); - region_base += region_size; } - self.base.sysbus.min_free_base = region_base; + Ok(()) } pub(crate) fn fill_replaceable_device( &mut self, id: &str, - dev_config: Arc, + dev_config: Vec>, index: usize, ) -> Result<()> { let mut replaceable_devices = self.replaceable_info.devices.lock().unwrap(); @@ -232,14 +247,14 @@ impl LightMachine { .device .lock() .unwrap() - .update_config(Some(dev_config.clone())) + .update_config(dev_config.clone()) .with_context(|| MachineError::UpdCfgErr(id.to_string()))?; } - self.add_replaceable_config(id, dev_config) + self.add_replaceable_config(id, dev_config[0].clone()) } - fn add_replaceable_config(&self, id: &str, dev_config: Arc) -> Result<()> { + fn add_replaceable_config(&self, id: &str, back_config: Arc) -> Result<()> { let mut configs_lock = self.replaceable_info.configs.lock().unwrap(); let limit = MMIO_REPLACEABLE_BLK_NR + MMIO_REPLACEABLE_NET_NR; if configs_lock.len() >= limit { @@ -254,7 +269,7 @@ impl LightMachine { let config = MmioReplaceableConfig { id: id.to_string(), - dev_config, + back_config, }; trace::mmio_replaceable_config(&config); @@ -262,21 +277,28 @@ impl LightMachine { Ok(()) } - fn add_replaceable_device(&self, id: &str, driver: &str, slot: usize) -> Result<()> { + fn add_replaceable_device( + &self, + args: Box, + slot: usize, + ) -> Result<()> { + let id = args.id; + let driver = args.driver; + // Find the configuration by id. let configs_lock = self.replaceable_info.configs.lock().unwrap(); - let mut dev_config = None; + let mut configs = Vec::new(); for config in configs_lock.iter() { if config.id == id { - dev_config = Some(config.dev_config.clone()); + configs.push(config.back_config.clone()); } } - if dev_config.is_none() { + if configs.is_empty() { bail!("Failed to find device configuration."); } // Sanity check for config, driver and slot. - let cfg_any = dev_config.as_ref().unwrap().as_any(); + let cfg_any = configs[0].as_any(); let index = if driver.contains("net") { if slot >= MMIO_REPLACEABLE_NET_NR { return Err(anyhow!(MachineError::RplDevLmtErr( @@ -284,9 +306,19 @@ impl LightMachine { MMIO_REPLACEABLE_NET_NR ))); } - if cfg_any.downcast_ref::().is_none() { + if cfg_any.downcast_ref::().is_none() { return Err(anyhow!(MachineError::DevTypeErr("net".to_string()))); } + let mut net_config = NetworkInterfaceConfig { + classtype: driver, + id: id.clone(), + netdev: args.chardev.with_context(|| "No chardev set")?, + mac: args.mac, + iothread: args.iothread, + ..Default::default() + }; + net_config.auto_iothread(); + configs.push(Arc::new(net_config)); slot + MMIO_REPLACEABLE_BLK_NR } else if driver.contains("blk") { if slot >= MMIO_REPLACEABLE_BLK_NR { @@ -295,9 +327,19 @@ impl LightMachine { MMIO_REPLACEABLE_BLK_NR ))); } - if cfg_any.downcast_ref::().is_none() { + if cfg_any.downcast_ref::().is_none() { return Err(anyhow!(MachineError::DevTypeErr("blk".to_string()))); } + let dev_config = VirtioBlkDevConfig { + classtype: driver, + id: id.clone(), + drive: args.drive.with_context(|| "No drive set")?, + bootindex: args.boot_index, + iothread: args.iothread, + serial: args.serial_num, + ..Default::default() + }; + configs.push(Arc::new(dev_config)); slot } else { bail!("Unsupported replaceable device type."); @@ -316,7 +358,7 @@ impl LightMachine { .device .lock() .unwrap() - .update_config(dev_config) + .update_config(configs) .with_context(|| MachineError::UpdCfgErr(id.to_string()))?; } Ok(()) @@ -328,8 +370,10 @@ impl LightMachine { let mut configs_lock = self.replaceable_info.configs.lock().unwrap(); for (index, config) in configs_lock.iter().enumerate() { if config.id == id { - if let Some(blkconf) = config.dev_config.as_any().downcast_ref::() { - self.unregister_drive_file(&blkconf.path_on_host)?; + if let Some(drive_config) = + config.back_config.as_any().downcast_ref::() + { + self.unregister_drive_file(&drive_config.path_on_host)?; } configs_lock.remove(index); is_exist = true; @@ -347,7 +391,7 @@ impl LightMachine { .device .lock() .unwrap() - .update_config(None) + .update_config(Vec::new()) .with_context(|| MachineError::UpdCfgErr(id.to_string()))?; } } @@ -363,22 +407,55 @@ impl LightMachine { vm_config: &mut VmConfig, cfg_args: &str, ) -> Result<()> { - let device_cfg = parse_net(vm_config, cfg_args)?; - if device_cfg.vhost_type.is_some() { - let device = if device_cfg.vhost_type == Some(String::from("vhost-kernel")) { - let net = Arc::new(Mutex::new(VhostKern::Net::new( - &device_cfg, - &self.base.sys_mem, - ))); - VirtioMmioDevice::new(&self.base.sys_mem, net) + let mut net_cfg = + NetworkInterfaceConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + net_cfg.auto_iothread(); + check_arg_nonexist!( + ("bus", net_cfg.bus), + ("addr", net_cfg.addr), + ("multifunction", net_cfg.multifunction) + ); + let netdev_cfg = vm_config + .netdevs + .remove(&net_cfg.netdev) + .with_context(|| format!("Netdev: {:?} not found for net device", &net_cfg.netdev))?; + if netdev_cfg.vhost_type().is_some() { + if netdev_cfg.vhost_type().unwrap() == "vhost-kernel" { + #[cfg(not(feature = "vhost_net"))] + bail!("Unsupported Vhost_Net"); + + #[cfg(feature = "vhost_net")] + { + let net = Arc::new(Mutex::new(VhostKern::Net::new( + &net_cfg, + netdev_cfg, + &self.base.sys_mem, + ))); + self.add_virtio_mmio_device(net_cfg.id.clone(), net)?; + } } else { - let net = Arc::new(Mutex::new(VhostUser::Net::new( - &device_cfg, - &self.base.sys_mem, - ))); - VirtioMmioDevice::new(&self.base.sys_mem, net) + #[cfg(not(feature = "vhostuser_net"))] + bail!("Unsupported Vhostuser_Net"); + + #[cfg(feature = "vhostuser_net")] + { + let chardev = netdev_cfg.chardev.clone().with_context(|| { + format!("Chardev not configured for netdev {:?}", netdev_cfg.id) + })?; + let chardev_cfg = vm_config + .chardev + .remove(&chardev) + .with_context(|| format!("Chardev: {:?} not found for netdev", chardev))?; + let sock_path = get_chardev_socket_path(chardev_cfg)?; + let net = Arc::new(Mutex::new(VhostUser::Net::new( + &net_cfg, + netdev_cfg, + sock_path, + &self.base.sys_mem, + ))); + self.add_virtio_mmio_device(net_cfg.id.clone(), net)?; + } }; - self.realize_virtio_mmio_device(device)?; } else { let index = MMIO_REPLACEABLE_BLK_NR + self.replaceable_info.net_count; if index >= MMIO_REPLACEABLE_BLK_NR + MMIO_REPLACEABLE_NET_NR { @@ -387,7 +464,9 @@ impl LightMachine { MMIO_REPLACEABLE_NET_NR ); } - self.fill_replaceable_device(&device_cfg.id, Arc::new(device_cfg.clone()), index)?; + let configs: Vec> = + vec![Arc::new(netdev_cfg), Arc::new(net_cfg.clone())]; + self.fill_replaceable_device(&net_cfg.id, configs, index)?; self.replaceable_info.net_count += 1; } Ok(()) @@ -398,7 +477,17 @@ impl LightMachine { vm_config: &mut VmConfig, cfg_args: &str, ) -> Result<()> { - let device_cfg = parse_blk(vm_config, cfg_args, None)?; + let device_cfg = + VirtioBlkDevConfig::try_parse_from(str_slip_to_clap(cfg_args, true, false))?; + check_arg_nonexist!( + ("bus", device_cfg.bus), + ("addr", device_cfg.addr), + ("multifunction", device_cfg.multifunction) + ); + let drive_cfg = vm_config + .drives + .remove(&device_cfg.drive) + .with_context(|| "No drive configured matched for blk device")?; if self.replaceable_info.block_count >= MMIO_REPLACEABLE_BLK_NR { bail!( "A maximum of {} block replaceable devices are supported.", @@ -406,28 +495,43 @@ impl LightMachine { ); } let index = self.replaceable_info.block_count; - self.fill_replaceable_device(&device_cfg.id, Arc::new(device_cfg.clone()), index)?; + let configs: Vec> = + vec![Arc::new(drive_cfg), Arc::new(device_cfg.clone())]; + self.fill_replaceable_device(&device_cfg.id, configs, index)?; self.replaceable_info.block_count += 1; Ok(()) } - pub(crate) fn realize_virtio_mmio_device( + pub(crate) fn add_virtio_mmio_device( &mut self, - dev: VirtioMmioDevice, + name: String, + device: Arc>, ) -> Result>> { - let region_base = self.base.sysbus.min_free_base; + let sys_mem = self.get_sys_mem().clone(); + let region_base = self.base.sysbus.lock().unwrap().min_free_base; let region_size = MEM_LAYOUT[LayoutEntryType::Mmio as usize].1; - let realized_virtio_mmio_device = VirtioMmioDevice::realize( - dev, - &mut self.base.sysbus, + let dev = VirtioMmioDevice::new( + &sys_mem, + name, + device, + &self.base.sysbus, region_base, region_size, - #[cfg(target_arch = "x86_64")] - &self.base.boot_source, - ) - .with_context(|| MachineError::RlzVirtioMmioErr)?; - self.base.sysbus.min_free_base += region_size; - Ok(realized_virtio_mmio_device) + )?; + let mmio_device = dev + .realize() + .with_context(|| MachineError::RlzVirtioMmioErr)?; + #[cfg(target_arch = "x86_64")] + { + let res = mmio_device.lock().unwrap().get_sys_resource().clone(); + let mut bs = self.base.boot_source.lock().unwrap(); + bs.kernel_cmdline.push(Param { + param_type: "virtio_mmio.device".to_string(), + value: format!("{}@0x{:08x}:{}", res.region_size, res.region_base, res.irq), + }); + } + self.base.sysbus.lock().unwrap().min_free_base += region_size; + Ok(mmio_device) } } @@ -451,18 +555,11 @@ impl MachineLifecycle for LightMachine { } fn destroy(&self) -> bool { - let vmstate = { - let state = self.base.vm_state.deref().0.lock().unwrap(); - *state - }; - - if !self.notify_lifecycle(vmstate, VmState::Shutdown) { + if self.shutdown_req.write(1).is_err() { + error!("Failed to send shutdown request."); return false; } - info!("vm destroy"); - EventLoop::get_ctx(None).unwrap().kick(); - true } @@ -500,12 +597,13 @@ impl MachineAddressInterface for LightMachine { #[cfg(target_arch = "x86_64")] fn pio_out(&self, addr: u64, mut data: &[u8]) -> bool { + use address_space::AddressAttr; use address_space::GuestAddress; let count = data.len() as u64; self.base .sys_io - .write(&mut data, GuestAddress(addr), count) + .write(&mut data, GuestAddress(addr), count, AddressAttr::MMIO) .is_ok() } @@ -545,7 +643,7 @@ impl DeviceInterface for LightMachine { for cpu_index in 0..cpu_topo.max_cpus { if cpu_topo.get_mask(cpu_index as usize) == 1 { let thread_id = cpus[cpu_index as usize].tid(); - let cpu_instance = cpu_topo.get_topo_instance_for_qmp(cpu_index as usize); + let cpu_instance = cpu_topo.get_topo_instance_for_qmp(cpu_index); let cpu_common = qmp_schema::CpuInfoCommon { current: true, qom_path: String::from("/machine/unattached/device[") @@ -586,10 +684,7 @@ impl DeviceInterface for LightMachine { for cpu_index in 0..self.base.cpu_topo.max_cpus { if self.base.cpu_topo.get_mask(cpu_index as usize) == 0 { - let cpu_instance = self - .base - .cpu_topo - .get_topo_instance_for_qmp(cpu_index as usize); + let cpu_instance = self.base.cpu_topo.get_topo_instance_for_qmp(cpu_index); let hotpluggable_cpu = qmp_schema::HotpluggableCPU { type_: cpu_type.clone(), vcpus_count: 1, @@ -598,10 +693,7 @@ impl DeviceInterface for LightMachine { }; hotplug_vec.push(serde_json::to_value(hotpluggable_cpu).unwrap()); } else { - let cpu_instance = self - .base - .cpu_topo - .get_topo_instance_for_qmp(cpu_index as usize); + let cpu_instance = self.base.cpu_topo.get_topo_instance_for_qmp(cpu_index); let hotpluggable_cpu = qmp_schema::HotpluggableCPU { type_: cpu_type.clone(), vcpus_count: 1, @@ -667,8 +759,8 @@ impl DeviceInterface for LightMachine { fn device_add(&mut self, args: Box) -> Response { // get slot of bus by addr or lun - let mut slot = 0; - if let Some(addr) = args.addr { + let mut slot = 0_usize; + if let Some(addr) = args.addr.clone() { if let Ok(num) = str_to_num::(&addr) { slot = num; } else { @@ -684,7 +776,7 @@ impl DeviceInterface for LightMachine { slot = lun + 1; } - match self.add_replaceable_device(&args.id, &args.driver, slot) { + match self.add_replaceable_device(args.clone(), slot) { Ok(()) => Response::create_empty_response(), Err(ref e) => { error!("{:?}", e); @@ -719,32 +811,22 @@ impl DeviceInterface for LightMachine { } fn blockdev_add(&self, args: Box) -> Response { - let read_only = args.read_only.unwrap_or(false); + let readonly = args.read_only.unwrap_or(false); let mut direct = true; if args.cache.is_some() && !args.cache.unwrap().direct.unwrap_or(true) { direct = false; } - let config = BlkDevConfig { + let config = DriveConfig { id: args.node_name.clone(), + drive_type: "none".to_string(), path_on_host: args.file.filename.clone(), - read_only, + readonly, direct, - serial_num: None, - iothread: None, - iops: None, - queues: 1, - boot_index: None, - chardev: None, - socket_path: None, aio: args.file.aio, - queue_size: DEFAULT_VIRTQUEUE_SIZE, - discard: false, - write_zeroes: WriteZeroesState::Off, - format: DiskFormat::Raw, - l2_cache_size: None, - refcount_cache_size: None, + ..Default::default() }; + if let Err(e) = config.check() { error!("{:?}", e); return Response::create_error_response( @@ -753,7 +835,7 @@ impl DeviceInterface for LightMachine { ); } // Register drive backend file for hotplugged drive. - if let Err(e) = self.register_drive_file(&config.id, &args.file.filename, read_only, direct) + if let Err(e) = self.register_drive_file(&config.id, &args.file.filename, readonly, direct) { error!("{:?}", e); return Response::create_error_response( @@ -783,18 +865,9 @@ impl DeviceInterface for LightMachine { } fn netdev_add(&mut self, args: Box) -> Response { - let mut config = NetworkInterfaceConfig { + let mut netdev_cfg = NetDevcfg { id: args.id.clone(), - host_dev_name: "".to_string(), - mac: None, - tap_fds: None, - vhost_type: None, - vhost_fds: None, - iothread: None, - queues: 2, - mq: false, - socket_path: None, - queue_size: DEFAULT_VIRTQUEUE_SIZE, + ..Default::default() }; if let Some(fds) = args.fds { @@ -806,7 +879,7 @@ impl DeviceInterface for LightMachine { }; if let Some(fd_num) = QmpChannel::get_fd(&netdev_fd) { - config.tap_fds = Some(vec![fd_num]); + netdev_cfg.tap_fds = Some(vec![fd_num]); } else { // try to convert string to RawFd let fd_num = match netdev_fd.parse::() { @@ -824,10 +897,10 @@ impl DeviceInterface for LightMachine { ); } }; - config.tap_fds = Some(vec![fd_num]); + netdev_cfg.tap_fds = Some(vec![fd_num]); } } else if let Some(if_name) = args.if_name { - config.host_dev_name = if_name.clone(); + netdev_cfg.ifname = if_name.clone(); if create_tap(None, Some(&if_name), 1).is_err() { return Response::create_error_response( qmp_schema::QmpErrorClass::GenericError( @@ -838,7 +911,7 @@ impl DeviceInterface for LightMachine { } } - match self.add_replaceable_config(&args.id, Arc::new(config)) { + match self.add_replaceable_config(&args.id, Arc::new(netdev_cfg)) { Ok(()) => Response::create_empty_response(), Err(ref e) => { error!("{:?}", e); diff --git a/machine/src/micro_common/syscall.rs b/machine/src/micro_common/syscall.rs index f3acec191bdaa0fcae4e023fbb81dbce762cea1e..6ae9a56a28ef7957c0dab9f1e2dc010c7aa73967 100644 --- a/machine/src/micro_common/syscall.rs +++ b/machine/src/micro_common/syscall.rs @@ -159,6 +159,7 @@ fn ioctl_allow_list() -> BpfRule { .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_API_VERSION() as u32) .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_MP_STATE() as u32) .add_constraint(SeccompCmpOpt::Eq, 1, KVM_SET_MP_STATE() as u32) + .add_constraint(SeccompCmpOpt::Eq, 1, KVM_SET_VCPU_EVENTS() as u32) .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_VCPU_EVENTS() as u32); arch_ioctl_allow_list(bpf_rule) } diff --git a/machine/src/standard_common/mod.rs b/machine/src/standard_common/mod.rs index 1ab0be07c6c74093538fed98b3c5296deb349ba7..7ac6bcbbcbe9273984930cdc707aff0285ea523a 100644 --- a/machine/src/standard_common/mod.rs +++ b/machine/src/standard_common/mod.rs @@ -12,11 +12,7 @@ pub mod syscall; -#[cfg(target_arch = "aarch64")] -pub use crate::aarch64::standard::StdMachine; pub use crate::error::MachineError; -#[cfg(target_arch = "x86_64")] -pub use crate::x86_64::standard::StdMachine; use std::mem::size_of; use std::ops::Deref; @@ -28,7 +24,9 @@ use std::sync::{Arc, Mutex}; use std::u64; use anyhow::{bail, Context, Result}; -use log::error; +use log::{error, warn}; +use serde_json::json; +use util::set_termi_canon_mode; use vmm_sys_util::epoll::EventSet; use vmm_sys_util::eventfd::EventFd; @@ -40,7 +38,7 @@ use crate::x86_64::ich9_lpc::{ }; #[cfg(target_arch = "x86_64")] use crate::x86_64::standard::{LayoutEntryType, MEM_LAYOUT}; -use crate::MachineOps; +use crate::{MachineBase, MachineOps}; #[cfg(target_arch = "x86_64")] use acpi::AcpiGenericAddress; use acpi::{ @@ -48,7 +46,8 @@ use acpi::{ ACPI_TABLE_LOADER_FILE, TABLE_CHECKSUM_OFFSET, }; use address_space::{ - AddressRange, FileBackend, GuestAddress, HostMemMapping, Region, RegionIoEventFd, RegionOps, + AddressAttr, AddressRange, FileBackend, GuestAddress, HostMemMapping, Region, RegionIoEventFd, + RegionOps, }; use block_backend::{qcow2::QCOW2_LIST, BlockStatus}; #[cfg(target_arch = "x86_64")] @@ -57,33 +56,77 @@ use devices::legacy::FwCfgOps; #[cfg(feature = "scream")] use devices::misc::scream::set_record_authority; use devices::pci::hotplug::{handle_plug, handle_unplug_pci_request}; -use devices::pci::PciBus; +use devices::pci::{PciBus, PciHost}; +use devices::Device; #[cfg(feature = "usb_camera")] use machine_manager::config::get_cameradev_config; -#[cfg(feature = "windows_emu_pid")] -use machine_manager::config::VmConfig; +#[cfg(target_arch = "aarch64")] +use machine_manager::config::ShutdownAction; use machine_manager::config::{ - get_chardev_config, get_netdev_config, memory_unit_conversion, ConfigCheck, DiskFormat, - DriveConfig, ExBool, NumaNode, NumaNodes, M, + get_chardev_config, get_netdev_config, memory_unit_conversion, parse_incoming_uri, + BootIndexInfo, ConfigCheck, DiskFormat, DriveConfig, ExBool, MigrateMode, NumaNode, NumaNodes, + M, }; +use machine_manager::event; use machine_manager::event_loop::EventLoop; use machine_manager::machine::{ - DeviceInterface, MachineAddressInterface, MachineLifecycle, VmState, + DeviceInterface, MachineAddressInterface, MachineExternalInterface, MachineInterface, + MachineLifecycle, MachineTestInterface, MigrateInterface, VmState, }; use machine_manager::qmp::qmp_schema::{BlockDevAddArgument, UpdateRegionArgument}; use machine_manager::qmp::{qmp_channel::QmpChannel, qmp_response::Response, qmp_schema}; +use machine_manager::state_query::query_workloads; #[cfg(feature = "gtk")] use ui::gtk::qmp_query_display_image; use ui::input::{input_button, input_move_abs, input_point_sync, key_event, Axis}; +#[cfg(all(target_env = "ohos", feature = "ohui_srv"))] +use ui::ohui_srv::OhUiServer; #[cfg(feature = "vnc")] use ui::vnc::qmp_query_vnc; use util::aio::{AioEngine, WriteZeroesState}; use util::byte_code::ByteCode; -use util::loop_context::{read_fd, EventNotifier, NotifierCallback, NotifierOperation}; +use util::loop_context::{ + create_new_eventfd, read_fd, EventLoopManager, EventNotifier, NotifierCallback, + NotifierOperation, +}; use virtio::{qmp_balloon, qmp_query_balloon}; const MAX_REGION_SIZE: u64 = 65536; +/// Standard machine structure. +pub struct StdMachine { + /// Machine base members. + pub(crate) base: MachineBase, + /// PCI/PCIe host bridge. + pub(crate) pci_host: Arc>, + /// Reset request, handle VM `Reset` event. + pub(crate) reset_req: Arc, + /// Shutdown request, handle VM `shutdown` event. + pub(crate) shutdown_req: Arc, + /// VM power button, handle VM `Shutdown` event. + pub(crate) power_button: Arc, + /// List contains the boot order of boot devices. + pub(crate) boot_order_list: Arc>>, + /// CPU Resize request, handle vm cpu hot(un)plug event. + #[cfg(target_arch = "x86_64")] + pub(crate) cpu_resize_req: Arc, + /// Cpu Controller. + #[cfg(target_arch = "x86_64")] + pub(crate) cpu_controller: Option>>, + /// Pause request, handle VM `Pause` event. + #[cfg(target_arch = "aarch64")] + pub(crate) pause_req: Arc, + /// Resume request, handle VM `Resume` event. + #[cfg(target_arch = "aarch64")] + pub(crate) resume_req: Arc, + /// Device Tree Blob. + #[cfg(target_arch = "aarch64")] + pub(crate) dtb_vec: Vec, + /// OHUI server + #[cfg(all(target_arch = "aarch64", target_env = "ohos", feature = "ohui_srv"))] + pub(crate) ohui_server: Option>, +} + pub(crate) trait StdMachineOps: AcpiBuilder + MachineOps { fn init_pci_host(&self) -> Result<()>; @@ -102,7 +145,8 @@ pub(crate) trait StdMachineOps: AcpiBuilder + MachineOps { let mut xsdt_entries = Vec::new(); - let facs_addr = Self::build_facs_table(&acpi_tables) + let facs_addr = self + .build_facs_table(&acpi_tables) .with_context(|| "Failed to build ACPI FACS table")?; let dsdt_addr = self @@ -265,7 +309,10 @@ pub(crate) trait StdMachineOps: AcpiBuilder + MachineOps { let reset_req_handler: Rc = Rc::new(move |_, _| { read_fd(reset_req_fd); if let Err(e) = StdMachine::handle_reset_request(&clone_vm) { - error!("Fail to reboot standard VM, {:?}", e); + warn!("Fail to reboot standard VM, {:?}, try again", e); + if reset_req.write(1).is_err() { + error!("Failed to send VM reset request"); + } } None @@ -281,6 +328,7 @@ pub(crate) trait StdMachineOps: AcpiBuilder + MachineOps { .with_context(|| "Failed to register event notifier.") } + #[cfg(target_arch = "aarch64")] fn register_pause_event( &self, pause_req: Arc, @@ -290,7 +338,10 @@ pub(crate) trait StdMachineOps: AcpiBuilder + MachineOps { let pause_req_handler: Rc = Rc::new(move |_, _| { let _ret = pause_req.read(); if !clone_vm.lock().unwrap().pause() { - error!("VM pause failed"); + warn!("VM pause failed, try again"); + if pause_req.write(1).is_err() { + error!("Failed to send VM pause request"); + } } None }); @@ -306,6 +357,7 @@ pub(crate) trait StdMachineOps: AcpiBuilder + MachineOps { .with_context(|| "Failed to register event notifier.") } + #[cfg(target_arch = "aarch64")] fn register_resume_event( &self, resume_req: Arc, @@ -330,33 +382,6 @@ pub(crate) trait StdMachineOps: AcpiBuilder + MachineOps { EventLoop::update_event(vec![notifier], None) .with_context(|| "Failed to register event notifier.") } - - fn register_shutdown_event( - &self, - shutdown_req: Arc, - clone_vm: Arc>, - ) -> Result<()> { - use util::loop_context::gen_delete_notifiers; - - let shutdown_req_fd = shutdown_req.as_raw_fd(); - let shutdown_req_handler: Rc = Rc::new(move |_, _| { - let _ret = shutdown_req.read(); - if StdMachine::handle_destroy_request(&clone_vm).is_ok() { - Some(gen_delete_notifiers(&[shutdown_req_fd])) - } else { - None - } - }); - let notifier = EventNotifier::new( - NotifierOperation::AddShared, - shutdown_req_fd, - None, - EventSet::IN, - vec![shutdown_req_handler], - ); - EventLoop::update_event(vec![notifier], None) - .with_context(|| "Failed to register event notifier.") - } } /// Trait that helps to build ACPI tables. @@ -385,12 +410,13 @@ pub(crate) trait AcpiBuilder { loader.add_cksum_entry( ACPI_TABLE_FILE, + // table_begin is much less than u32::MAX, will not overflow. table_begin + TABLE_CHECKSUM_OFFSET, table_begin, table_end - table_begin, )?; - Ok(table_begin as u64) + Ok(u64::from(table_begin)) } /// Build ACPI DSDT table, returns the offset of ACPI DSDT table in `acpi_data`. @@ -523,7 +549,7 @@ pub(crate) trait AcpiBuilder { { let mut mcfg = AcpiTable::new(*b"MCFG", 1, *b"STRATO", *b"VIRTMCFG", 1); // Bits 20~28 (totally 9 bits) in PCIE ECAM represents bus number. - let bus_number_mask = (1 << 9) - 1; + let bus_number_mask = (1u64 << 9) - 1; let ecam_addr: u64; let max_nr_bus: u64; #[cfg(target_arch = "x86_64")] @@ -544,8 +570,8 @@ pub(crate) trait AcpiBuilder { mcfg.append_child(ecam_addr.as_bytes()); // PCI Segment Group Number mcfg.append_child(0_u16.as_bytes()); - // Start Bus Number and End Bus Number - mcfg.append_child(&[0_u8, (max_nr_bus - 1) as u8]); + // Start Bus Number and End Bus Number. max_nr_bus is no less than 1. + mcfg.append_child(&[0_u8, u8::try_from(max_nr_bus - 1)?]); // Reserved mcfg.append_child(&[0_u8; 4]); @@ -557,11 +583,12 @@ pub(crate) trait AcpiBuilder { loader.add_cksum_entry( ACPI_TABLE_FILE, + // mcfg_begin is much less than u32::MAX, will not overflow. mcfg_begin + TABLE_CHECKSUM_OFFSET, mcfg_begin, mcfg_end - mcfg_begin, )?; - Ok(mcfg_begin as u64) + Ok(u64::from(mcfg_begin)) } /// Build ACPI FADT table, returns the offset of ACPI FADT table in `acpi_data`. @@ -586,31 +613,31 @@ pub(crate) trait AcpiBuilder { fadt.set_table_len(208_usize); // PM1A_EVENT bit, offset is 56. #[cfg(target_arch = "x86_64")] - fadt.set_field(56, 0x600); + fadt.set_field(56, 0x600_u32); // PM1A_CONTROL bit, offset is 64. #[cfg(target_arch = "x86_64")] - fadt.set_field(64, 0x604); + fadt.set_field(64, 0x604_u32); // PM_TMR_BLK bit, offset is 76. #[cfg(target_arch = "x86_64")] - fadt.set_field(76, 0x608); + fadt.set_field(76, 0x608_u32); // PM1_EVT_LEN, offset is 88. #[cfg(target_arch = "x86_64")] - fadt.set_field(88, 4); + fadt.set_field(88, 4_u8); // PM1_CNT_LEN, offset is 89. #[cfg(target_arch = "x86_64")] - fadt.set_field(89, 2); + fadt.set_field(89, 2_u8); // PM_TMR_LEN, offset is 91. #[cfg(target_arch = "x86_64")] - fadt.set_field(91, 4); + fadt.set_field(91, 4_u8); #[cfg(target_arch = "aarch64")] { - // FADT flag: enable HW_REDUCED_ACPI and LOW_POWER_S0_IDLE_CAPABLE bit on aarch64 plantform. - fadt.set_field(112, 1 << 21 | 1 << 20 | 1 << 10 | 1 << 8); + // FADT flag: enable HW_REDUCED_ACPI bit on aarch64 plantform. + fadt.set_field(112, 1_u32 << 20 | 1_u32 << 10 | 1_u32 << 8); // ARM Boot Architecture Flags fadt.set_field(129, 0x3_u16); } // FADT minor revision - fadt.set_field(131, 3); + fadt.set_field(131, 3_u8); // X_PM_TMR_BLK bit, offset is 208. #[cfg(target_arch = "x86_64")] fadt.append_child(&AcpiGenericAddress::new_io_address(0x608_u32).aml_bytes()); @@ -620,28 +647,28 @@ pub(crate) trait AcpiBuilder { #[cfg(target_arch = "x86_64")] { // FADT flag: disable HW_REDUCED_ACPI bit on x86 plantform. - fadt.set_field(112, 1 << 10 | 1 << 8); + fadt.set_field(112, 1_u32 << 10 | 1_u32 << 8); // Reset Register bit, offset is 116. fadt.set_field(116, 0x01_u8); fadt.set_field(117, 0x08_u8); - fadt.set_field(120, RST_CTRL_OFFSET as u64); + fadt.set_field(120, u64::from(RST_CTRL_OFFSET)); fadt.set_field(128, 0x0F_u8); // PM1a event register bit, offset is 148. fadt.set_field(148, 0x01_u8); fadt.set_field(149, 0x20_u8); - fadt.set_field(152, PM_EVENT_OFFSET as u64); + fadt.set_field(152, u64::from(PM_EVENT_OFFSET)); // PM1a control register bit, offset is 172. fadt.set_field(172, 0x01_u8); fadt.set_field(173, 0x10_u8); - fadt.set_field(176, PM_CTRL_OFFSET as u64); + fadt.set_field(176, u64::from(PM_CTRL_OFFSET)); // Sleep control register, offset is 244. fadt.set_field(244, 0x01_u8); fadt.set_field(245, 0x08_u8); - fadt.set_field(248, SLEEP_CTRL_OFFSET as u64); + fadt.set_field(248, u64::from(SLEEP_CTRL_OFFSET)); // Sleep status tegister, offset is 256. fadt.set_field(256, 0x01_u8); fadt.set_field(257, 0x08_u8); - fadt.set_field(260, SLEEP_CTRL_OFFSET as u64); + fadt.set_field(260, u64::from(SLEEP_CTRL_OFFSET)); } let mut locked_acpi_data = acpi_data.lock().unwrap(); @@ -656,10 +683,11 @@ pub(crate) trait AcpiBuilder { let facs_size = 4_u8; loader.add_pointer_entry( ACPI_TABLE_FILE, + // fadt_begin is much less than u32::MAX, will not overflow. fadt_begin + facs_offset, facs_size, ACPI_TABLE_FILE, - facs_addr as u32, + u32::try_from(facs_addr)?, )?; // xDSDT address field's offset in FADT. @@ -668,28 +696,33 @@ pub(crate) trait AcpiBuilder { let xdsdt_size = 8_u8; loader.add_pointer_entry( ACPI_TABLE_FILE, + // fadt_begin is much less than u32::MAX, will not overflow. fadt_begin + xdsdt_offset, xdsdt_size, ACPI_TABLE_FILE, - dsdt_addr as u32, + u32::try_from(dsdt_addr)?, )?; loader.add_cksum_entry( ACPI_TABLE_FILE, + // fadt_begin is much less than u32::MAX, will not overflow. fadt_begin + TABLE_CHECKSUM_OFFSET, fadt_begin, fadt_end - fadt_begin, )?; - Ok(fadt_begin as u64) + Ok(u64::from(fadt_begin)) } + /// Get the Hardware Signature used to build FACS table. + fn get_hardware_signature(&self) -> Option; + /// Build ACPI FACS table, returns the offset of ACPI FACS table in `acpi_data`. /// /// # Arguments /// /// `acpi_data` - Bytes streams that ACPI tables converts to. - fn build_facs_table(acpi_data: &Arc>>) -> Result + fn build_facs_table(&self, acpi_data: &Arc>>) -> Result where Self: Sized, { @@ -702,12 +735,21 @@ pub(crate) trait AcpiBuilder { // FACS table length. facs_data[4] = 0x40; + // FACS table Hardware Signature. + if let Some(signature) = self.get_hardware_signature() { + let signature = signature.as_bytes(); + facs_data[8] = signature[0]; + facs_data[9] = signature[1]; + facs_data[10] = signature[2]; + facs_data[11] = signature[3]; + } + let mut locked_acpi_data = acpi_data.lock().unwrap(); let facs_begin = locked_acpi_data.len() as u32; locked_acpi_data.extend(facs_data); drop(locked_acpi_data); - Ok(facs_begin as u64) + Ok(u64::from(facs_begin)) } /// Build ACPI SRAT CPU table. @@ -797,6 +839,7 @@ pub(crate) trait AcpiBuilder { { let mut xsdt = AcpiTable::new(*b"XSDT", 1, *b"STRATO", *b"VIRTXSDT", 1); + // usize is enough for storing table len. xsdt.set_table_len(xsdt.table_len() + size_of::() * xsdt_entries.len()); let mut locked_acpi_data = acpi_data.lock().unwrap(); @@ -812,22 +855,25 @@ pub(crate) trait AcpiBuilder { for entry in xsdt_entries { loader.add_pointer_entry( ACPI_TABLE_FILE, + // xsdt_begin is much less than u32::MAX, will not overflow. xsdt_begin + entry_offset, entry_size, ACPI_TABLE_FILE, - entry as u32, + u32::try_from(entry)?, )?; + // u32 is enough for storing offset. entry_offset += u32::from(entry_size); } loader.add_cksum_entry( ACPI_TABLE_FILE, + // xsdt_begin is much less than u32::MAX, will not overflow. xsdt_begin + TABLE_CHECKSUM_OFFSET, xsdt_begin, xsdt_end - xsdt_begin, )?; - Ok(xsdt_begin as u64) + Ok(u64::from(xsdt_begin)) } /// Build ACPI RSDP and add it to FwCfg as file-entry. @@ -853,7 +899,7 @@ pub(crate) trait AcpiBuilder { xsdt_offset, xsdt_size, ACPI_TABLE_FILE, - xsdt_addr as u32, + u32::try_from(xsdt_addr)?, )?; let cksum_offset = 8_u32; @@ -874,26 +920,6 @@ impl StdMachine { self.detach_usb_from_xhci_controller(&mut locked_vmconfig, id) } - /// When windows emu exits, stratovirt should exits too. - #[cfg(feature = "windows_emu_pid")] - pub(crate) fn watch_windows_emu_pid( - &self, - vm_config: &VmConfig, - power_button: Arc, - shutdown_req: Arc, - ) { - let emu_pid = vm_config.windows_emu_pid.as_ref(); - if emu_pid.is_none() { - return; - } - log::info!("Watching on windows emu lifetime"); - crate::check_windows_emu_pid( - "/proc/".to_owned() + emu_pid.unwrap(), - power_button, - shutdown_req, - ); - } - #[cfg(target_arch = "x86_64")] fn plug_cpu_device(&mut self, args: &qmp_schema::DeviceAddArgument) -> Result<()> { if self.get_numa_nodes().is_some() { @@ -915,6 +941,7 @@ impl StdMachine { bail!("Cpu-id {} already exist.", cpu_id) } if cpu_id >= max_cpus { + // max_cpus is no less than 1. bail!("Max cpu-id is {}", max_cpus - 1) } @@ -928,6 +955,114 @@ impl StdMachine { } } +impl MachineLifecycle for StdMachine { + fn pause(&self) -> bool { + if self.notify_lifecycle(VmState::Running, VmState::Paused) { + event!(Stop); + true + } else { + false + } + } + + fn resume(&self) -> bool { + if !self.notify_lifecycle(VmState::Paused, VmState::Running) { + return false; + } + event!(Resume); + true + } + + fn destroy(&self) -> bool { + if self.shutdown_req.write(1).is_err() { + error!("Failed to send shutdown request."); + return false; + } + + true + } + + #[cfg(target_arch = "aarch64")] + fn powerdown(&self) -> bool { + if self.power_button.write(1).is_err() { + error!("Standard vm write power button failed"); + return false; + } + true + } + + #[cfg(target_arch = "aarch64")] + fn get_shutdown_action(&self) -> ShutdownAction { + self.base + .vm_config + .lock() + .unwrap() + .machine_config + .shutdown_action + } + + fn reset(&mut self) -> bool { + if self.reset_req.write(1).is_err() { + error!("Standard vm write reset request failed"); + return false; + } + true + } + + fn notify_lifecycle(&self, old: VmState, new: VmState) -> bool { + if let Err(e) = self.vm_state_transfer( + &self.base.cpus, + #[cfg(target_arch = "aarch64")] + &self.base.irq_chip, + &mut self.base.vm_state.0.lock().unwrap(), + old, + new, + ) { + error!("VM state transfer failed: {:?}", e); + return false; + } + true + } +} + +impl MigrateInterface for StdMachine { + fn migrate(&self, uri: String) -> Response { + match parse_incoming_uri(&uri) { + Ok((MigrateMode::File, path)) => migration::snapshot(path), + Ok((MigrateMode::Unix, path)) => migration::migration_unix_mode(path), + Ok((MigrateMode::Tcp, path)) => migration::migration_tcp_mode(path), + _ => Response::create_error_response( + qmp_schema::QmpErrorClass::GenericError(format!("Invalid uri: {}", uri)), + None, + ), + } + } + + fn query_migrate(&self) -> Response { + migration::query_migrate() + } + + fn cancel_migrate(&self) -> Response { + migration::cancel_migrate() + } +} + +impl MachineInterface for StdMachine {} +impl MachineExternalInterface for StdMachine {} +impl MachineTestInterface for StdMachine {} + +impl EventLoopManager for StdMachine { + fn loop_should_exit(&self) -> bool { + let vmstate = self.base.vm_state.deref().0.lock().unwrap(); + *vmstate == VmState::Shutdown + } + + fn loop_cleanup(&self) -> Result<()> { + set_termi_canon_mode().with_context(|| "Failed to set terminal to canonical mode")?; + Ok(()) + } +} + impl MachineAddressInterface for StdMachine { #[cfg(target_arch = "x86_64")] fn pio_in(&self, addr: u64, data: &mut [u8]) -> bool { @@ -976,7 +1111,7 @@ impl DeviceInterface for StdMachine { for cpu_index in 0..cpu_topo.max_cpus { if cpu_topo.get_mask(cpu_index as usize) == 1 { let thread_id = cpus[cpu_index as usize].tid(); - let cpu_instance = cpu_topo.get_topo_instance_for_qmp(cpu_index as usize); + let cpu_instance = cpu_topo.get_topo_instance_for_qmp(cpu_index); let cpu_common = qmp_schema::CpuInfoCommon { current: true, qom_path: String::from("/machine/unattached/device[") @@ -1094,6 +1229,7 @@ impl DeviceInterface for StdMachine { ); } } + #[cfg(feature = "virtio_scsi")] "virtio-scsi-pci" => { let cfg_args = locked_vmconfig.add_device_config(args.as_ref()); if let Err(e) = self.add_virtio_pci_scsi(&mut vm_config_clone, &cfg_args, true) { @@ -1106,6 +1242,7 @@ impl DeviceInterface for StdMachine { ); } } + #[cfg(feature = "vhostuser_block")] "vhost-user-blk-pci" => { let cfg_args = locked_vmconfig.add_device_config(args.as_ref()); if let Err(e) = self.add_vhost_user_blk_pci(&mut vm_config_clone, &cfg_args, true) { @@ -1130,6 +1267,7 @@ impl DeviceInterface for StdMachine { ); } } + #[cfg(feature = "vfio_device")] "vfio-pci" => { let cfg_args = locked_vmconfig.add_device_config(args.as_ref()); if let Err(e) = self.add_vfio_device(&cfg_args, true) { @@ -1141,7 +1279,7 @@ impl DeviceInterface for StdMachine { ); } } - "usb-kbd" | "usb-tablet" | "usb-camera" | "usb-host" | "usb-storage" => { + "usb-kbd" | "usb-tablet" | "usb-camera" | "usb-host" | "usb-storage" | "usb-uas" => { let cfg_args = locked_vmconfig.add_device_config(args.as_ref()); if let Err(e) = self.add_usb_device(&mut vm_config_clone, &cfg_args) { error!("{:?}", e); @@ -1176,7 +1314,9 @@ impl DeviceInterface for StdMachine { // It's safe to call get_pci_host().unwrap() because it has been checked before. let locked_pci_host = self.get_pci_host().unwrap().lock().unwrap(); - if let Some((bus, dev)) = PciBus::find_attached_bus(&locked_pci_host.root_bus, &args.id) { + if let Some((bus, dev)) = + PciBus::find_attached_bus(&locked_pci_host.child_bus().unwrap(), &args.id) + { match handle_plug(&bus, &dev) { Ok(()) => Response::create_empty_response(), Err(e) => { @@ -1213,7 +1353,9 @@ impl DeviceInterface for StdMachine { }; let locked_pci_host = pci_host.lock().unwrap(); - if let Some((bus, dev)) = PciBus::find_attached_bus(&locked_pci_host.root_bus, &device_id) { + if let Some((bus, dev)) = + PciBus::find_attached_bus(&locked_pci_host.child_bus().unwrap(), &device_id) + { return match handle_unplug_pci_request(&bus, &dev) { Ok(()) => { let locked_dev = dev.lock().unwrap(); @@ -1272,7 +1414,7 @@ impl DeviceInterface for StdMachine { if let Err(e) = self.register_drive_file( &config.id, &args.file.filename, - config.read_only, + config.readonly, config.direct, ) { error!("{:?}", e); @@ -1485,6 +1627,7 @@ impl DeviceInterface for StdMachine { } for (i, data) in data.iter_mut().enumerate().take(std::mem::size_of::()) { + // i is less than 8, multiply will not overflow. *data = (self.head >> (8 * i)) as u8; } true @@ -1533,13 +1676,13 @@ impl DeviceInterface for StdMachine { let mut fd = None; if args.region_type.eq("rom_device_region") || args.region_type.eq("ram_device_region") { if let Some(file_name) = args.device_fd_path { - fd = Some( + fd = Some(Arc::new( std::fs::OpenOptions::new() .read(true) .write(true) .open(file_name) .unwrap(), - ); + )); } } @@ -1549,7 +1692,7 @@ impl DeviceInterface for StdMachine { region = Region::init_io_region(args.size, dummy_dev_ops, "UpdateRegionTest"); if args.ioeventfd.is_some() && args.ioeventfd.unwrap() { let ioeventfds = vec![RegionIoEventFd { - fd: Arc::new(EventFd::new(libc::EFD_NONBLOCK).unwrap()), + fd: Arc::new(create_new_eventfd().unwrap()), addr_range: AddressRange::from(( 0, args.ioeventfd_size.unwrap_or_default(), @@ -1603,6 +1746,12 @@ impl DeviceInterface for StdMachine { } }; + if i32::try_from(args.priority).is_err() { + return Response::create_error_response( + qmp_schema::QmpErrorClass::GenericError("priority illegal".to_string()), + None, + ); + } region.set_priority(args.priority as i32); if let Some(read_only) = args.romd { if region.set_rom_device_romd(read_only).is_err() { @@ -1672,12 +1821,7 @@ impl DeviceInterface for StdMachine { None, ); } - let drive_cfg = match self - .get_vm_config() - .lock() - .unwrap() - .add_block_drive(cmd_args[2]) - { + let drive_cfg = match self.get_vm_config().lock().unwrap().add_drive(cmd_args[2]) { Ok(cfg) => cfg, Err(ref e) => { return Response::create_error_response( @@ -1689,7 +1833,7 @@ impl DeviceInterface for StdMachine { if let Err(e) = self.register_drive_file( &drive_cfg.id, &drive_cfg.path_on_host, - drive_cfg.read_only, + drive_cfg.readonly, drive_cfg.direct, ) { error!("{:?}", e); @@ -1734,7 +1878,7 @@ impl DeviceInterface for StdMachine { } let qcow2_list = QCOW2_LIST.lock().unwrap(); - if qcow2_list.len() == 0 { + if qcow2_list.is_empty() { return Response::create_response( serde_json::to_value("There is no snapshot available.\r\n").unwrap(), None, @@ -1897,7 +2041,7 @@ impl DeviceInterface for StdMachine { match self .machine_base() .sys_mem - .read_object::(GuestAddress(gpa)) + .read_object::(GuestAddress(gpa), AddressAttr::Ram) { Ok(val) => { Response::create_response(serde_json::to_value(format!("{:X}", val)).unwrap(), None) @@ -1910,13 +2054,30 @@ impl DeviceInterface for StdMachine { ), } } + + fn query_workloads(&self) -> Response { + let workloads = query_workloads(); + + if !workloads.is_empty() { + let status = workloads + .iter() + .map(|(module, state)| json!({ "module": module, "state": state })) + .collect(); + + Response::create_response(serde_json::Value::Array(status), None) + } else { + Response::create_empty_response() + } + } } fn parse_blockdev(args: &BlockDevAddArgument) -> Result { let mut config = DriveConfig { id: args.node_name.clone(), + drive_type: "none".to_string(), + unit: None, path_on_host: args.file.filename.clone(), - read_only: args.read_only.unwrap_or(false), + readonly: args.read_only.unwrap_or(false), direct: true, iops: args.iops, aio: args.file.aio, diff --git a/machine/src/standard_common/syscall.rs b/machine/src/standard_common/syscall.rs index d665f314e4fbfaf8b48c992dec6ff79b5513d9e7..30e1998ea25c40bb9052044fe6abea0b05fc01ce 100644 --- a/machine/src/standard_common/syscall.rs +++ b/machine/src/standard_common/syscall.rs @@ -23,6 +23,7 @@ use util::v4l2::{ VIDIOC_G_FMT, VIDIOC_QBUF, VIDIOC_QUERYBUF, VIDIOC_QUERYCAP, VIDIOC_REQBUFS, VIDIOC_STREAMOFF, VIDIOC_STREAMON, VIDIOC_S_FMT, VIDIOC_S_PARM, }; +#[cfg(feature = "vfio_device")] use vfio::{ VFIO_CHECK_EXTENSION, VFIO_DEVICE_GET_INFO, VFIO_DEVICE_GET_REGION_INFO, VFIO_DEVICE_RESET, VFIO_DEVICE_SET_IRQS, VFIO_GET_API_VERSION, VFIO_GROUP_GET_DEVICE_FD, VFIO_GROUP_GET_STATUS, @@ -226,6 +227,18 @@ fn ioctl_allow_list() -> BpfRule { .add_constraint(SeccompCmpOpt::Eq, 1, TUNSETOFFLOAD() as u32) .add_constraint(SeccompCmpOpt::Eq, 1, TUNSETVNETHDRSZ() as u32) .add_constraint(SeccompCmpOpt::Eq, 1, TUNSETQUEUE() as u32) + .add_constraint(SeccompCmpOpt::Eq, 1, KVM_SET_GSI_ROUTING() as u32) + .add_constraint(SeccompCmpOpt::Eq, 1, KVM_IRQFD() as u32) + .add_constraint(SeccompCmpOpt::Eq, 1, KVM_CREATE_DEVICE() as u32) + .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_API_VERSION() as u32) + .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_MP_STATE() as u32) + .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_VCPU_EVENTS() as u32) + .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_DIRTY_LOG() as u32) + .add_constraint(SeccompCmpOpt::Eq, 1, KVM_SET_MP_STATE() as u32) + .add_constraint(SeccompCmpOpt::Eq, 1, KVM_SET_VCPU_EVENTS() as u32); + + #[cfg(feature = "vfio_device")] + let bpf_rule = bpf_rule .add_constraint(SeccompCmpOpt::Eq, 1, VFIO_DEVICE_SET_IRQS() as u32) .add_constraint(SeccompCmpOpt::Eq, 1, VFIO_GROUP_GET_STATUS() as u32) .add_constraint(SeccompCmpOpt::Eq, 1, VFIO_GET_API_VERSION() as u32) @@ -237,16 +250,7 @@ fn ioctl_allow_list() -> BpfRule { .add_constraint(SeccompCmpOpt::Eq, 1, VFIO_GROUP_GET_DEVICE_FD() as u32) .add_constraint(SeccompCmpOpt::Eq, 1, VFIO_DEVICE_GET_INFO() as u32) .add_constraint(SeccompCmpOpt::Eq, 1, VFIO_DEVICE_RESET() as u32) - .add_constraint(SeccompCmpOpt::Eq, 1, VFIO_DEVICE_GET_REGION_INFO() as u32) - .add_constraint(SeccompCmpOpt::Eq, 1, KVM_SET_GSI_ROUTING() as u32) - .add_constraint(SeccompCmpOpt::Eq, 1, KVM_IRQFD() as u32) - .add_constraint(SeccompCmpOpt::Eq, 1, KVM_CREATE_DEVICE() as u32) - .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_API_VERSION() as u32) - .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_MP_STATE() as u32) - .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_VCPU_EVENTS() as u32) - .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_DIRTY_LOG() as u32) - .add_constraint(SeccompCmpOpt::Eq, 1, KVM_SET_MP_STATE() as u32) - .add_constraint(SeccompCmpOpt::Eq, 1, KVM_SET_VCPU_EVENTS() as u32); + .add_constraint(SeccompCmpOpt::Eq, 1, VFIO_DEVICE_GET_REGION_INFO() as u32); #[cfg(feature = "usb_camera_v4l2")] let bpf_rule = bpf_rule diff --git a/machine/src/x86_64/ich9_lpc.rs b/machine/src/x86_64/ich9_lpc.rs index 0a93312ef6028cd96f2556e586ea145660b19828..4ce47680aefaeaf47d226b122703e4519ecdad0b 100644 --- a/machine/src/x86_64/ich9_lpc.rs +++ b/machine/src/x86_64/ich9_lpc.rs @@ -11,7 +11,7 @@ // See the Mulan PSL v2 for more details. use std::sync::{ - atomic::{AtomicU8, Ordering}, + atomic::{AtomicBool, AtomicU8, Ordering}, Arc, Mutex, Weak, }; @@ -26,9 +26,10 @@ use devices::pci::config::{ PciConfig, CLASS_CODE_ISA_BRIDGE, DEVICE_ID, HEADER_TYPE, HEADER_TYPE_BRIDGE, HEADER_TYPE_MULTIFUNC, PCI_CONFIG_SPACE_SIZE, SUB_CLASS_CODE, VENDOR_ID, }; -use devices::pci::{le_write_u16, le_write_u32, PciBus, PciDevBase, PciDevOps}; -use devices::{Device, DeviceBase}; +use devices::pci::{le_write_u16, le_write_u32, PciDevBase, PciDevOps}; +use devices::{Bus, Device, DeviceBase}; use util::byte_code::ByteCode; +use util::gen_base_func; use util::num_ops::ranges_overlap; const DEVICE_ID_INTEL_ICH9: u16 = 0x2918; @@ -56,17 +57,17 @@ pub struct LPCBridge { impl LPCBridge { pub fn new( - parent_bus: Weak>, + parent_bus: Weak>, sys_io: Arc, reset_req: Arc, shutdown_req: Arc, ) -> Result { Ok(Self { base: PciDevBase { - base: DeviceBase::new("ICH9 LPC bridge".to_string(), false), - config: PciConfig::new(PCI_CONFIG_SPACE_SIZE, 0), + base: DeviceBase::new("ICH9 LPC bridge".to_string(), false, Some(parent_bus)), + config: PciConfig::new(0x1F << 3, PCI_CONFIG_SPACE_SIZE, 0), devfn: 0x1F << 3, - parent_bus, + bme: Arc::new(AtomicBool::new(false)), }, sys_io, pm_timer: Arc::new(Mutex::new(AcpiPMTimer::new())), @@ -94,9 +95,10 @@ impl LPCBridge { self.base .config .read(PM_BASE_OFFSET as usize, pm_base_addr.as_mut_bytes()); - self.sys_io - .root() - .add_subregion(pmtmr_region, pm_base_addr as u64 + PM_TIMER_OFFSET as u64)?; + self.sys_io.root().add_subregion( + pmtmr_region, + u64::from(pm_base_addr) + u64::from(PM_TIMER_OFFSET), + )?; Ok(()) } @@ -143,7 +145,7 @@ impl LPCBridge { let rst_ctrl_region = Region::init_io_region(0x1, ops, "RstCtrlRegion"); self.sys_io .root() - .add_subregion(rst_ctrl_region, RST_CTRL_OFFSET as u64)?; + .add_subregion(rst_ctrl_region, u64::from(RST_CTRL_OFFSET))?; Ok(()) } @@ -170,7 +172,7 @@ impl LPCBridge { let sleep_reg_region = Region::init_io_region(0x1, ops, "SleepReg"); self.sys_io .root() - .add_subregion(sleep_reg_region, SLEEP_CTRL_OFFSET as u64)?; + .add_subregion(sleep_reg_region, u64::from(SLEEP_CTRL_OFFSET))?; Ok(()) } @@ -192,7 +194,7 @@ impl LPCBridge { let pm_evt_region = Region::init_io_region(0x4, ops, "PmEvtRegion"); self.sys_io .root() - .add_subregion(pm_evt_region, PM_EVENT_OFFSET as u64)?; + .add_subregion(pm_evt_region, u64::from(PM_EVENT_OFFSET))?; Ok(()) } @@ -222,32 +224,16 @@ impl LPCBridge { let pm_ctrl_region = Region::init_io_region(0x4, ops, "PmCtrl"); self.sys_io .root() - .add_subregion(pm_ctrl_region, PM_CTRL_OFFSET as u64)?; + .add_subregion(pm_ctrl_region, u64::from(PM_CTRL_OFFSET))?; Ok(()) } } impl Device for LPCBridge { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } - - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base - } -} - -impl PciDevOps for LPCBridge { - fn pci_base(&self) -> &PciDevBase { - &self.base - } - - fn pci_base_mut(&mut self) -> &mut PciDevBase { - &mut self.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); - fn realize(mut self) -> Result<()> { + fn realize(mut self) -> Result>> { self.init_write_mask(false)?; self.init_write_clear_mask(false)?; @@ -274,7 +260,7 @@ impl PciDevOps for LPCBridge { le_write_u16( &mut self.base.config.config, HEADER_TYPE as usize, - (HEADER_TYPE_BRIDGE | HEADER_TYPE_MULTIFUNC) as u16, + u16::from(HEADER_TYPE_BRIDGE | HEADER_TYPE_MULTIFUNC), )?; self.init_sleep_reg() @@ -288,16 +274,16 @@ impl PciDevOps for LPCBridge { self.init_pm_ctrl_reg() .with_context(|| "Fail to init IO region for PM control register")?; - let parent_bus = self.base.parent_bus.clone(); - parent_bus - .upgrade() - .unwrap() - .lock() - .unwrap() - .devices - .insert(0x1F << 3, Arc::new(Mutex::new(self))); - Ok(()) + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + let mut locked_bus = parent_bus.lock().unwrap(); + let dev = Arc::new(Mutex::new(self)); + locked_bus.attach_child(0x1F << 3, dev.clone())?; + Ok(dev) } +} + +impl PciDevOps for LPCBridge { + gen_base_func!(pci_base, pci_base_mut, PciDevBase, base); fn write_config(&mut self, offset: usize, data: &[u8]) { self.base.config.write(offset, data, 0, None, None); diff --git a/machine/src/x86_64/mch.rs b/machine/src/x86_64/mch.rs index 5c6f593ecb4195044cf582fa1b49a2750a62abca..570cdcb3e3be552bc87e88ccbdb221742a1eca63 100644 --- a/machine/src/x86_64/mch.rs +++ b/machine/src/x86_64/mch.rs @@ -10,7 +10,7 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use std::sync::{Arc, Mutex, Weak}; +use std::sync::{atomic::AtomicBool, Arc, Mutex, Weak}; use anyhow::{bail, Result}; use log::error; @@ -24,7 +24,8 @@ use devices::pci::{ }, le_read_u64, le_write_u16, PciBus, PciDevBase, PciDevOps, }; -use devices::{Device, DeviceBase}; +use devices::{convert_bus_ref, Bus, Device, DeviceBase, PCI_BUS}; +use util::gen_base_func; use util::num_ops::ranges_overlap; const DEVICE_ID_INTEL_Q35_MCH: u16 = 0x29c0; @@ -50,16 +51,16 @@ pub struct Mch { impl Mch { pub fn new( - parent_bus: Weak>, + parent_bus: Weak>, mmconfig_region: Region, mmconfig_ops: RegionOps, ) -> Self { Self { base: PciDevBase { - base: DeviceBase::new("Memory Controller Hub".to_string(), false), - config: PciConfig::new(PCI_CONFIG_SPACE_SIZE, 0), + base: DeviceBase::new("Memory Controller Hub".to_string(), false, Some(parent_bus)), + config: PciConfig::new(0, PCI_CONFIG_SPACE_SIZE, 0), devfn: 0, - parent_bus, + bme: Arc::new(AtomicBool::new(false)), }, mmconfig_region: Some(mmconfig_region), mmconfig_ops, @@ -85,27 +86,17 @@ impl Mch { } if let Some(region) = self.mmconfig_region.as_ref() { - self.base - .parent_bus - .upgrade() - .unwrap() - .lock() - .unwrap() - .mem_region - .delete_subregion(region)?; + let bus = self.parent_bus().unwrap().upgrade().unwrap(); + PCI_BUS!(bus, locked_bus, pci_bus); + pci_bus.mem_region.delete_subregion(region)?; self.mmconfig_region = None; } if enable == 0x1 { let region = Region::init_io_region(length, self.mmconfig_ops.clone(), "PcieXBar"); let base_addr: u64 = pciexbar & addr_mask; - self.base - .parent_bus - .upgrade() - .unwrap() - .lock() - .unwrap() - .mem_region - .add_subregion(region, base_addr)?; + let bus = self.parent_bus().unwrap().upgrade().unwrap(); + PCI_BUS!(bus, locked_bus, pci_bus); + pci_bus.mem_region.add_subregion(region, base_addr)?; } Ok(()) } @@ -121,25 +112,9 @@ impl Mch { } impl Device for Mch { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } - - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base - } -} - -impl PciDevOps for Mch { - fn pci_base(&self) -> &PciDevBase { - &self.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); - fn pci_base_mut(&mut self) -> &mut PciDevBase { - &mut self.base - } - - fn realize(mut self) -> Result<()> { + fn realize(mut self) -> Result>> { self.init_write_mask(false)?; self.init_write_clear_mask(false)?; @@ -159,16 +134,16 @@ impl PciDevOps for Mch { CLASS_CODE_HOST_BRIDGE, )?; - let parent_bus = self.base.parent_bus.clone(); - parent_bus - .upgrade() - .unwrap() - .lock() - .unwrap() - .devices - .insert(0, Arc::new(Mutex::new(self))); - Ok(()) + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + let mut locked_bus = parent_bus.lock().unwrap(); + let dev = Arc::new(Mutex::new(self)); + locked_bus.attach_child(0, dev.clone())?; + Ok(dev) } +} + +impl PciDevOps for Mch { + gen_base_func!(pci_base, pci_base_mut, PciDevBase, base); fn write_config(&mut self, offset: usize, data: &[u8]) { let old_pciexbar: u64 = le_read_u64(&self.base.config.config, PCIEXBAR as usize).unwrap(); diff --git a/machine/src/x86_64/micro.rs b/machine/src/x86_64/micro.rs index d8fb92e6167d69f4d81212865b3bf4c82df94881..ba7c8fe37e172bddbc3f2422440283efae44539b 100644 --- a/machine/src/x86_64/micro.rs +++ b/machine/src/x86_64/micro.rs @@ -14,18 +14,19 @@ use std::sync::{Arc, Mutex}; use anyhow::{bail, Context, Result}; -use crate::{ - micro_common::syscall::syscall_whitelist, LightMachine, MachineBase, MachineError, MachineOps, -}; +use crate::micro_common::syscall::syscall_whitelist; +use crate::{register_shutdown_event, LightMachine, MachineBase, MachineError, MachineOps}; use address_space::{AddressSpace, Region}; use cpu::{CPUBootConfig, CPUTopology}; -use devices::legacy::FwCfgOps; +use devices::legacy::{FwCfgOps, Serial, SERIAL_ADDR}; +use devices::Device; use hypervisor::kvm::x86_64::*; use hypervisor::kvm::*; -use machine_manager::config::{SerialConfig, VmConfig}; +use machine_manager::config::{MigrateMode, SerialConfig, VmConfig}; use migration::{MigrationManager, MigrationStatus}; +use util::gen_base_func; use util::seccomp::{BpfRule, SeccompCmpOpt}; -use virtio::VirtioMmioDevice; +use virtio::{VirtioDevice, VirtioMmioDevice}; #[repr(usize)] pub enum LayoutEntryType { @@ -47,13 +48,7 @@ pub const MEM_LAYOUT: &[(u64, u64)] = &[ ]; impl MachineOps for LightMachine { - fn machine_base(&self) -> &MachineBase { - &self.base - } - - fn machine_base_mut(&mut self) -> &mut MachineBase { - &mut self.base - } + gen_base_func!(machine_base, machine_base_mut, MachineBase, base); fn init_machine_ram(&self, sys_mem: &Arc, mem_size: u64) -> Result<()> { let vm_ram = self.get_vm_ram(); @@ -89,7 +84,7 @@ impl MachineOps for LightMachine { locked_hypervisor.create_interrupt_controller()?; let irq_manager = locked_hypervisor.create_irq_manager()?; - self.base.sysbus.irq_manager = irq_manager.line_irq_manager; + self.base.sysbus.lock().unwrap().irq_manager = irq_manager.line_irq_manager; Ok(()) } @@ -100,6 +95,7 @@ impl MachineOps for LightMachine { let boot_source = self.base.boot_source.lock().unwrap(); let initrd = boot_source.initrd.as_ref().map(|b| b.initrd_file.clone()); + // MEM_LAYOUT is defined statically, will not overflow. let gap_start = MEM_LAYOUT[LayoutEntryType::MemBelow4g as usize].0 + MEM_LAYOUT[LayoutEntryType::MemBelow4g as usize].1; let gap_end = MEM_LAYOUT[LayoutEntryType::MemAbove4g as usize].0; @@ -108,6 +104,7 @@ impl MachineOps for LightMachine { initrd, kernel_cmdline: boot_source.kernel_cmdline.to_string(), cpu_count: self.base.cpu_topo.nrcpus, + // gap_end is bigger than gap_start, as MEM_LAYOUT is defined statically. gap_range: (gap_start, gap_end - gap_start), ioapic_addr: MEM_LAYOUT[LayoutEntryType::IoApic as usize].0 as u32, lapic_addr: MEM_LAYOUT[LayoutEntryType::LocalApic as usize].0 as u32, @@ -134,14 +131,13 @@ impl MachineOps for LightMachine { } fn add_serial_device(&mut self, config: &SerialConfig) -> Result<()> { - use devices::legacy::{Serial, SERIAL_ADDR}; - let region_base: u64 = SERIAL_ADDR; let region_size: u64 = 8; - let serial = Serial::new(config.clone()); + let serial = Serial::new(config.clone(), &self.base.sysbus, region_base, region_size)?; serial - .realize(&mut self.base.sysbus, region_base, region_size) - .with_context(|| "Failed to realize serial device.") + .realize() + .with_context(|| "Failed to realize serial device.")?; + Ok(()) } fn realize(vm: &Arc>, vm_config: &mut VmConfig) -> Result<()> { @@ -174,7 +170,12 @@ impl MachineOps for LightMachine { locked_vm.add_devices(vm_config)?; trace::replaceable_info(&locked_vm.replaceable_info); - let boot_config = locked_vm.load_boot_source(None)?; + let migrate_info = locked_vm.get_migrate_info(); + let boot_config = if migrate_info.0 == MigrateMode::Unknown { + Some(locked_vm.load_boot_source(None)?) + } else { + None + }; let hypervisor = locked_vm.base.hypervisor.clone(); locked_vm.base.cpus.extend(::init_vcpu( vm.clone(), @@ -184,6 +185,8 @@ impl MachineOps for LightMachine { &topology, &boot_config, )?); + register_shutdown_event(locked_vm.shutdown_req.clone(), vm.clone()) + .with_context(|| "Failed to register shutdown event")?; MigrationManager::register_vm_instance(vm.clone()); let migration_hyp = locked_vm.base.migration_hypervisor.clone(); @@ -204,11 +207,12 @@ impl MachineOps for LightMachine { self.add_virtio_mmio_block(vm_config, cfg_args) } - fn realize_virtio_mmio_device( + fn add_virtio_mmio_device( &mut self, - dev: VirtioMmioDevice, + name: String, + device: Arc>, ) -> Result>> { - self.realize_virtio_mmio_device(dev) + self.add_virtio_mmio_device(name, device) } fn syscall_whitelist(&self) -> Vec { @@ -238,7 +242,6 @@ pub(crate) fn arch_ioctl_allow_list(bpf_rule: BpfRule) -> BpfRule { .add_constraint(SeccompCmpOpt::Eq, 1, KVM_SET_LAPIC() as u32) .add_constraint(SeccompCmpOpt::Eq, 1, KVM_GET_MSRS() as u32) .add_constraint(SeccompCmpOpt::Eq, 1, KVM_SET_MSRS() as u32) - .add_constraint(SeccompCmpOpt::Eq, 1, KVM_SET_VCPU_EVENTS() as u32) .add_constraint(SeccompCmpOpt::Eq, 1, KVM_SET_CPUID2() as u32) } diff --git a/machine/src/x86_64/mod.rs b/machine/src/x86_64/mod.rs index 47b4ecbe16395502a3619eaca576f769d090fe61..b3227f997a235f184f9d6cdfab8cab7505de5d0e 100644 --- a/machine/src/x86_64/mod.rs +++ b/machine/src/x86_64/mod.rs @@ -11,7 +11,6 @@ // See the Mulan PSL v2 for more details. pub mod ich9_lpc; +pub mod mch; pub mod micro; pub mod standard; - -mod mch; diff --git a/machine/src/x86_64/standard.rs b/machine/src/x86_64/standard.rs index 6499151bb1042dbc07f82e10bd846b92dd935394..e85042bf2f52da0e8521b17b20b2dc4d4838629d 100644 --- a/machine/src/x86_64/standard.rs +++ b/machine/src/x86_64/standard.rs @@ -12,19 +12,16 @@ use std::io::{Seek, SeekFrom}; use std::mem::size_of; -use std::ops::Deref; use std::sync::{Arc, Barrier, Mutex}; use anyhow::{bail, Context, Result}; -use log::{error, info, warn}; -use vmm_sys_util::eventfd::EventFd; use super::ich9_lpc; use super::mch::Mch; use crate::error::MachineError; use crate::standard_common::syscall::syscall_whitelist; use crate::standard_common::{AcpiBuilder, StdMachineOps}; -use crate::{MachineBase, MachineOps}; +use crate::{register_shutdown_event, MachineBase, MachineOps, StdMachine}; use acpi::{ AcpiIoApic, AcpiLocalApic, AcpiSratMemoryAffinity, AcpiSratProcessorAffinity, AcpiTable, AmlBuilder, AmlInteger, AmlNameDecl, AmlPackage, AmlScope, AmlScopeBuilder, TableLoader, @@ -39,29 +36,27 @@ use devices::legacy::{ error::LegacyError as DevErrorKind, FwCfgEntryType, FwCfgIO, FwCfgOps, PFlash, Serial, RTC, SERIAL_ADDR, }; -use devices::pci::{PciDevOps, PciHost}; +use devices::pci::{PciBus, PciHost}; +use devices::{convert_bus_mut, Device, MUT_PCI_BUS}; use hypervisor::kvm::x86_64::*; use hypervisor::kvm::*; #[cfg(feature = "gtk")] use machine_manager::config::UiContext; use machine_manager::config::{ - parse_incoming_uri, BootIndexInfo, MigrateMode, NumaNode, PFlashConfig, SerialConfig, VmConfig, + BootIndexInfo, DriveConfig, MigrateMode, NumaNode, SerialConfig, VmConfig, }; use machine_manager::event; -use machine_manager::machine::{ - MachineExternalInterface, MachineInterface, MachineLifecycle, MachineTestInterface, - MigrateInterface, VmState, -}; -use machine_manager::qmp::{qmp_channel::QmpChannel, qmp_response::Response, qmp_schema}; +use machine_manager::qmp::{qmp_channel::QmpChannel, qmp_schema}; use migration::{MigrationManager, MigrationStatus}; #[cfg(feature = "gtk")] use ui::gtk::gtk_display_init; #[cfg(feature = "vnc")] use ui::vnc::vnc_init; +use util::byte_code::ByteCode; +use util::gen_base_func; +use util::loop_context::create_new_eventfd; +use util::seccomp::BpfRule; use util::seccomp::SeccompCmpOpt; -use util::{ - byte_code::ByteCode, loop_context::EventLoopManager, seccomp::BpfRule, set_termi_canon_mode, -}; pub(crate) const VENDOR_ID_INTEL: u16 = 0x8086; const HOLE_640K_START: u64 = 0x000A_0000; @@ -111,26 +106,6 @@ const IRQ_MAP: &[(i32, i32)] = &[ (16, 19), // Pcie ]; -/// Standard machine structure. -pub struct StdMachine { - // Machine base members. - base: MachineBase, - /// PCI/PCIe host bridge. - pci_host: Arc>, - /// Reset request, handle VM `Reset` event. - reset_req: Arc, - /// Shutdown_req, handle VM 'ShutDown' event. - shutdown_req: Arc, - /// VM power button, handle VM `Powerdown` event. - power_button: Arc, - /// CPU Resize request, handle vm cpu hot(un)plug event. - cpu_resize_req: Arc, - /// List contains the boot order of boot devices. - boot_order_list: Arc>>, - /// Cpu Controller. - cpu_controller: Option>>, -} - impl StdMachine { pub fn new(vm_config: &VmConfig) -> Result { let free_irqs = ( @@ -155,20 +130,20 @@ impl StdMachine { IRQ_MAP[IrqEntryType::Pcie as usize].0, ))), reset_req: Arc::new( - EventFd::new(libc::EFD_NONBLOCK) + create_new_eventfd() .with_context(|| MachineError::InitEventFdErr("reset request".to_string()))?, ), shutdown_req: Arc::new( - EventFd::new(libc::EFD_NONBLOCK).with_context(|| { + create_new_eventfd().with_context(|| { MachineError::InitEventFdErr("shutdown request".to_string()) })?, ), power_button: Arc::new( - EventFd::new(libc::EFD_NONBLOCK) + create_new_eventfd() .with_context(|| MachineError::InitEventFdErr("power button".to_string()))?, ), cpu_resize_req: Arc::new( - EventFd::new(libc::EFD_NONBLOCK) + create_new_eventfd() .with_context(|| MachineError::InitEventFdErr("cpu resize".to_string()))?, ), boot_order_list: Arc::new(Mutex::new(Vec::new())), @@ -206,39 +181,20 @@ impl StdMachine { Ok(()) } - pub fn handle_destroy_request(vm: &Arc>) -> Result<()> { - let locked_vm = vm.lock().unwrap(); - let vmstate = { - let state = locked_vm.base.vm_state.deref().0.lock().unwrap(); - *state - }; - - if !locked_vm.notify_lifecycle(vmstate, VmState::Shutdown) { - warn!("Failed to destroy guest, destroy continue."); - if locked_vm.shutdown_req.write(1).is_err() { - error!("Failed to send shutdown request.") - } - } - - info!("vm destroy"); - - Ok(()) - } - fn init_ich9_lpc(&self, vm: Arc>) -> Result<()> { - let clone_vm = vm.clone(); - let root_bus = Arc::downgrade(&self.pci_host.lock().unwrap().root_bus); + let root_bus = Arc::downgrade(&self.pci_host.lock().unwrap().child_bus().unwrap()); let ich = ich9_lpc::LPCBridge::new( root_bus, self.base.sys_io.clone(), self.reset_req.clone(), self.shutdown_req.clone(), )?; - self.register_reset_event(self.reset_req.clone(), vm) + self.register_reset_event(self.reset_req.clone(), vm.clone()) .with_context(|| "Fail to register reset event in LPC")?; - self.register_shutdown_event(ich.shutdown_req.clone(), clone_vm) + register_shutdown_event(ich.shutdown_req.clone(), vm) .with_context(|| "Fail to register shutdown event in LPC")?; - ich.realize() + ich.realize()?; + Ok(()) } pub fn get_vcpu_reg_val(&self, _addr: u64, _vcpu: usize) -> Option { @@ -256,31 +212,25 @@ impl StdMachine { cpu_topology: CPUTopology, vm: Arc>, ) -> Result<()> { - let cpu_controller: CpuController = Default::default(); - let region_base: u64 = MEM_LAYOUT[LayoutEntryType::CpuController as usize].0; let region_size: u64 = MEM_LAYOUT[LayoutEntryType::CpuController as usize].1; let cpu_config = CpuConfig::new(boot_config, cpu_topology); let hotplug_cpu_req = Arc::new( - EventFd::new(libc::EFD_NONBLOCK) + create_new_eventfd() .with_context(|| MachineError::InitEventFdErr("hotplug cpu".to_string()))?, ); - + let cpu_controller = CpuController::new( + self.base.cpu_topo.max_cpus, + &self.base.sysbus, + region_base, + region_size, + cpu_config, + hotplug_cpu_req.clone(), + self.base.cpus.clone(), + )?; let realize_controller = cpu_controller - .realize( - &mut self.base.sysbus, - self.base.cpu_topo.max_cpus, - region_base, - region_size, - cpu_config, - hotplug_cpu_req.clone(), - ) + .realize() .with_context(|| "Failed to realize Cpu Controller")?; - - let mut lock_controller = realize_controller.lock().unwrap(); - lock_controller.set_boot_vcpu(self.base.cpus.clone())?; - drop(lock_controller); - self.register_hotplug_vcpu_event(hotplug_cpu_req, vm)?; self.cpu_controller = Some(realize_controller); Ok(()) @@ -289,7 +239,7 @@ impl StdMachine { impl StdMachineOps for StdMachine { fn init_pci_host(&self) -> Result<()> { - let root_bus = Arc::downgrade(&self.pci_host.lock().unwrap().root_bus); + let root_bus = Arc::downgrade(&self.pci_host.lock().unwrap().child_bus().unwrap()); let mmconfig_region_ops = PciHost::build_mmconfig_ops(self.pci_host.clone()); let mmconfig_region = Region::init_io_region( MEM_LAYOUT[LayoutEntryType::PcieEcam as usize].1, @@ -321,7 +271,8 @@ impl StdMachineOps for StdMachine { .with_context(|| "Failed to register CONFIG_DATA port in I/O space.")?; let mch = Mch::new(root_bus, mmconfig_region, mmconfig_region_ops); - mch.realize() + mch.realize()?; + Ok(()) } fn add_fwcfg_device( @@ -329,7 +280,7 @@ impl StdMachineOps for StdMachine { nr_cpus: u8, max_cpus: u8, ) -> Result>>> { - let mut fwcfg = FwCfgIO::new(self.base.sys_mem.clone()); + let mut fwcfg = FwCfgIO::new(self.base.sys_mem.clone(), &self.base.sysbus)?; fwcfg.add_data_entry(FwCfgEntryType::NbCpus, nr_cpus.as_bytes().to_vec())?; fwcfg.add_data_entry(FwCfgEntryType::MaxCpus, max_cpus.as_bytes().to_vec())?; fwcfg.add_data_entry(FwCfgEntryType::Irq0Override, 1_u32.as_bytes().to_vec())?; @@ -339,7 +290,8 @@ impl StdMachineOps for StdMachine { .add_file_entry("bootorder", boot_order) .with_context(|| DevErrorKind::AddEntryErr("bootorder".to_string()))?; - let fwcfg_dev = FwCfgIO::realize(fwcfg, &mut self.base.sysbus) + let fwcfg_dev = fwcfg + .realize() .with_context(|| "Failed to realize fwcfg device")?; self.base.fwcfg_dev = Some(fwcfg_dev.clone()); @@ -371,7 +323,7 @@ impl StdMachineOps for StdMachine { hypervisor, self.base.cpu_topo.max_cpus, )?; - vcpu.realize(boot_cfg, topology).with_context(|| { + vcpu.realize(&Some(boot_cfg), topology).with_context(|| { format!( "Failed to realize arch cpu register/features for CPU {}", vcpu_id @@ -413,13 +365,7 @@ impl StdMachineOps for StdMachine { } impl MachineOps for StdMachine { - fn machine_base(&self) -> &MachineBase { - &self.base - } - - fn machine_base_mut(&mut self) -> &mut MachineBase { - &mut self.base - } + gen_base_func!(machine_base, machine_base_mut, MachineBase, base); fn init_machine_ram(&self, sys_mem: &Arc, mem_size: u64) -> Result<()> { let ram = self.get_vm_ram(); @@ -454,10 +400,11 @@ impl MachineOps for StdMachine { let mut locked_hypervisor = hypervisor.lock().unwrap(); locked_hypervisor.create_interrupt_controller()?; - let root_bus = &self.pci_host.lock().unwrap().root_bus; + let child_bus = self.pci_host.lock().unwrap().child_bus().unwrap(); + MUT_PCI_BUS!(child_bus, locked_bus, pci_bus); let irq_manager = locked_hypervisor.create_irq_manager()?; - root_bus.lock().unwrap().msi_irq_manager = irq_manager.msi_irq_manager; - self.base.sysbus.irq_manager = irq_manager.line_irq_manager; + pci_bus.msi_irq_manager = irq_manager.msi_irq_manager; + self.base.sysbus.lock().unwrap().irq_manager = irq_manager.line_irq_manager; Ok(()) } @@ -466,6 +413,7 @@ impl MachineOps for StdMachine { let boot_source = self.base.boot_source.lock().unwrap(); let initrd = boot_source.initrd.as_ref().map(|b| b.initrd_file.clone()); + // MEM_LAYOUT is defined statically, will not overflow. let gap_start = MEM_LAYOUT[LayoutEntryType::MemBelow4g as usize].0 + MEM_LAYOUT[LayoutEntryType::MemBelow4g as usize].1; let gap_end = MEM_LAYOUT[LayoutEntryType::MemAbove4g as usize].0; @@ -474,6 +422,7 @@ impl MachineOps for StdMachine { initrd, kernel_cmdline: boot_source.kernel_cmdline.to_string(), cpu_count: self.base.cpu_topo.nrcpus, + // gap_end is bigger than gap_start, as MEM_LAYOUT is defined statically. gap_range: (gap_start, gap_end - gap_start), ioapic_addr: MEM_LAYOUT[LayoutEntryType::IoApic as usize].0 as u32, lapic_addr: MEM_LAYOUT[LayoutEntryType::LocalApic as usize].0 as u32, @@ -493,38 +442,40 @@ impl MachineOps for StdMachine { } fn add_rtc_device(&mut self, mem_size: u64) -> Result<()> { - let mut rtc = RTC::new().with_context(|| "Failed to create RTC device")?; + let mut rtc = RTC::new(&self.base.sysbus).with_context(|| "Failed to create RTC device")?; rtc.set_memory( mem_size, + // MEM_LAYOUT is defined statically, will not overflow. MEM_LAYOUT[LayoutEntryType::MemBelow4g as usize].0 + MEM_LAYOUT[LayoutEntryType::MemBelow4g as usize].1, ); - RTC::realize(rtc, &mut self.base.sysbus).with_context(|| "Failed to realize RTC device") + rtc.realize() + .with_context(|| "Failed to realize RTC device")?; + Ok(()) } fn add_ged_device(&mut self) -> Result<()> { - let ged = Ged::default(); let region_base: u64 = MEM_LAYOUT[LayoutEntryType::GedMmio as usize].0; let region_size: u64 = MEM_LAYOUT[LayoutEntryType::GedMmio as usize].1; - let ged_event = GedEvent::new(self.power_button.clone(), self.cpu_resize_req.clone()); - ged.realize( - &mut self.base.sysbus, - ged_event, + let ged = Ged::new( false, + &self.base.sysbus, region_base, region_size, - ) - .with_context(|| "Failed to realize Ged")?; + ged_event, + )?; + + ged.realize().with_context(|| "Failed to realize Ged")?; Ok(()) } fn add_serial_device(&mut self, config: &SerialConfig) -> Result<()> { let region_base: u64 = SERIAL_ADDR; let region_size: u64 = 8; - let serial = Serial::new(config.clone()); + let serial = Serial::new(config.clone(), &self.base.sysbus, region_base, region_size)?; serial - .realize(&mut self.base.sysbus, region_base, region_size) + .realize() .with_context(|| "Failed to realize serial device.")?; Ok(()) } @@ -536,7 +487,6 @@ impl MachineOps for StdMachine { fn realize(vm: &Arc>, vm_config: &mut VmConfig) -> Result<()> { let nr_cpus = vm_config.machine_config.nr_cpus; let max_cpus = vm_config.machine_config.max_cpus; - let clone_vm = vm.clone(); let mut locked_vm = vm.lock().unwrap(); locked_vm.init_global_config(vm_config)?; locked_vm.base.numa_nodes = locked_vm.add_numa_nodes(vm_config)?; @@ -554,12 +504,17 @@ impl MachineOps for StdMachine { .init_pci_host() .with_context(|| MachineError::InitPCIeHostErr)?; locked_vm - .init_ich9_lpc(clone_vm) + .init_ich9_lpc(vm.clone()) .with_context(|| "Fail to init LPC bridge")?; locked_vm.add_devices(vm_config)?; let fwcfg = locked_vm.add_fwcfg_device(nr_cpus, max_cpus)?; - let boot_config = locked_vm.load_boot_source(fwcfg.as_ref())?; + let migrate = locked_vm.get_migrate_info(); + let boot_config = if migrate.0 == MigrateMode::Unknown { + Some(locked_vm.load_boot_source(fwcfg.as_ref())?) + } else { + None + }; let topology = CPUTopology::new().set_topology(( vm_config.machine_config.nr_threads, vm_config.machine_config.nr_cores, @@ -575,7 +530,9 @@ impl MachineOps for StdMachine { &boot_config, )?); - locked_vm.init_cpu_controller(boot_config, topology, vm.clone())?; + if migrate.0 == MigrateMode::Unknown { + locked_vm.init_cpu_controller(boot_config.unwrap(), topology, vm.clone())?; + } if let Some(fw_cfg) = fwcfg { locked_vm @@ -610,10 +567,11 @@ impl MachineOps for StdMachine { .with_context(|| "Fail to init display")?; #[cfg(feature = "windows_emu_pid")] - locked_vm.watch_windows_emu_pid( + crate::watch_windows_emu_pid( vm_config, locked_vm.shutdown_req.clone(), locked_vm.shutdown_req.clone(), + vm.clone(), ); MigrationManager::register_vm_config(locked_vm.get_vm_config()); @@ -628,23 +586,26 @@ impl MachineOps for StdMachine { Ok(()) } - fn add_pflash_device(&mut self, configs: &[PFlashConfig]) -> Result<()> { + fn add_pflash_device(&mut self, configs: &[DriveConfig]) -> Result<()> { let mut configs_vec = configs.to_vec(); - configs_vec.sort_by_key(|c| c.unit); + configs_vec.sort_by_key(|c| c.unit.unwrap()); // The two PFlash devices locates below 4GB, this variable represents the end address // of current PFlash device. let mut flash_end: u64 = MEM_LAYOUT[LayoutEntryType::MemAbove4g as usize].0; for config in configs_vec { - let mut fd = self.fetch_drive_file(&config.path_on_host)?; - let pfl_size = fd.metadata().unwrap().len(); + let file = self.fetch_drive_file(&config.path_on_host)?; + let pfl_size = file.as_ref().metadata()?.len(); - if config.unit == 0 { + if config.unit.unwrap() == 0 { // According to the Linux/x86 boot protocol, the memory region of // 0x000000 - 0x100000 (1 MiB) is for BIOS usage. And the top 128 // KiB is for BIOS code which is stored in the first PFlash. let rom_base = 0xe0000; let rom_size = 0x20000; - fd.seek(SeekFrom::Start(pfl_size - rom_size))?; + let seek_start = pfl_size + .checked_sub(rom_size) + .with_context(|| "pflash file size less than rom size")?; + file.as_ref().seek(SeekFrom::Start(seek_start))?; let ram1 = Arc::new(HostMemMapping::new( GuestAddress(rom_base), @@ -656,35 +617,36 @@ impl MachineOps for StdMachine { false, )?); let rom_region = Region::init_ram_region(ram1, "PflashRam"); - rom_region.write(&mut fd, GuestAddress(rom_base), 0, rom_size)?; + rom_region.write(&mut file.as_ref(), GuestAddress(rom_base), 0, rom_size)?; rom_region.set_priority(10); self.base .sys_mem .root() .add_subregion(rom_region, rom_base)?; - fd.rewind()? + file.as_ref().rewind()? } let sector_len: u32 = 1024 * 4; - let backend = Some(fd); + let backend = Some(file); + let region_base = flash_end + .checked_sub(pfl_size) + .with_context(|| "flash end is less than flash size")?; let pflash = PFlash::new( pfl_size, - &backend, + backend, sector_len, 4_u32, 1_u32, - config.read_only, + config.readonly, + &self.base.sysbus, + region_base, ) .with_context(|| MachineError::InitPflashErr)?; - PFlash::realize( - pflash, - &mut self.base.sysbus, - flash_end - pfl_size, - pfl_size, - backend, - ) - .with_context(|| MachineError::RlzPflashErr)?; + pflash + .realize() + .with_context(|| MachineError::RlzPflashErr)?; + // sub has been checked above. flash_end -= pfl_size; } @@ -697,7 +659,7 @@ impl MachineOps for StdMachine { // GTK display init. #[cfg(feature = "gtk")] match vm_config.display { - Some(ref ds_cfg) if ds_cfg.gtk => { + Some(ref ds_cfg) if ds_cfg.display_type == "gtk" => { let ui_context = UiContext { vm_name: vm_config.guest_name.clone(), power_button: None, @@ -793,7 +755,7 @@ impl AcpiBuilder for StdMachine { dsdt.append_child(sb_scope.aml_bytes().as_slice()); // 2. Info of devices attached to system bus. - dsdt.append_child(self.base.sysbus.aml_bytes().as_slice()); + dsdt.append_child(self.base.sysbus.lock().unwrap().aml_bytes().as_slice()); // 3. Add _S5 sleep state. let mut package = AmlPackage::new(4); @@ -879,6 +841,7 @@ impl AcpiBuilder for StdMachine { node: &NumaNode, srat: &mut AcpiTable, ) -> u64 { + // MEM_LAYOUT is defined statically, will not overflow. let mem_below_4g = MEM_LAYOUT[LayoutEntryType::MemBelow4g as usize].0 + MEM_LAYOUT[LayoutEntryType::MemBelow4g as usize].1; let mem_above_4g = MEM_LAYOUT[LayoutEntryType::MemAbove4g as usize].0; @@ -971,91 +934,9 @@ impl AcpiBuilder for StdMachine { .with_context(|| "Fail to add SRAT table to loader")?; Ok(srat_begin) } -} - -impl MachineLifecycle for StdMachine { - fn pause(&self) -> bool { - if self.notify_lifecycle(VmState::Running, VmState::Paused) { - event!(Stop); - true - } else { - false - } - } - fn resume(&self) -> bool { - if !self.notify_lifecycle(VmState::Paused, VmState::Running) { - return false; - } - event!(Resume); - true - } - - fn destroy(&self) -> bool { - if self.shutdown_req.write(1).is_err() { - error!("Failed to send shutdown request."); - return false; - } - - true - } - - fn reset(&mut self) -> bool { - if self.reset_req.write(1).is_err() { - error!("X86 standard vm write reset request failed"); - return false; - } - true - } - - fn notify_lifecycle(&self, old: VmState, new: VmState) -> bool { - if let Err(e) = self.vm_state_transfer( - &self.base.cpus, - &mut self.base.vm_state.0.lock().unwrap(), - old, - new, - ) { - error!("VM state transfer failed: {:?}", e); - return false; - } - true - } -} - -impl MigrateInterface for StdMachine { - fn migrate(&self, uri: String) -> Response { - match parse_incoming_uri(&uri) { - Ok((MigrateMode::File, path)) => migration::snapshot(path), - Ok((MigrateMode::Unix, path)) => migration::migration_unix_mode(path), - Ok((MigrateMode::Tcp, path)) => migration::migration_tcp_mode(path), - _ => Response::create_error_response( - qmp_schema::QmpErrorClass::GenericError(format!("Invalid uri: {}", uri)), - None, - ), - } - } - - fn query_migrate(&self) -> Response { - migration::query_migrate() - } - - fn cancel_migrate(&self) -> Response { - migration::cancel_migrate() - } -} - -impl MachineInterface for StdMachine {} -impl MachineExternalInterface for StdMachine {} -impl MachineTestInterface for StdMachine {} - -impl EventLoopManager for StdMachine { - fn loop_should_exit(&self) -> bool { - let vmstate = self.base.vm_state.deref().0.lock().unwrap(); - *vmstate == VmState::Shutdown - } - - fn loop_cleanup(&self) -> Result<()> { - set_termi_canon_mode().with_context(|| "Failed to set terminal to canonical mode")?; - Ok(()) + fn get_hardware_signature(&self) -> Option { + let vm_config = self.machine_base().vm_config.lock().unwrap(); + vm_config.hardware_signature } } diff --git a/machine_manager/Cargo.toml b/machine_manager/Cargo.toml index 29e151bf7650f9aa532b362ee78015a2febbf3ed..ec787bc510c1f4eb388d618034d58f700841d431 100644 --- a/machine_manager/Cargo.toml +++ b/machine_manager/Cargo.toml @@ -12,7 +12,7 @@ regex = "1" log = "0.4" libc = "0.2" serde_json = "1.0" -vmm-sys-util = "0.11.1" +vmm-sys-util = "0.12.1" hex = "0.4.3" serde = { version = "1.0", features = ["derive"] } strum = "0.24.1" diff --git a/machine_manager/src/cmdline.rs b/machine_manager/src/cmdline.rs index 5b0d2751ee271880de21dcb980ba594b231f0523..b056b571b680ca2eb03c175dcb3034c5264c2005 100644 --- a/machine_manager/src/cmdline.rs +++ b/machine_manager/src/cmdline.rs @@ -11,9 +11,10 @@ // See the Mulan PSL v2 for more details. use anyhow::{bail, Context, Result}; +use clap::{ArgAction, Parser}; use crate::{ - config::{parse_trace_options, ChardevType, CmdParser, MachineType, VmConfig}, + config::{add_trace, str_slip_to_clap, ChardevType, MachineType, SocketType, VmConfig}, qmp::qmp_socket::QmpSocketPath, temp_cleaner::TempCleaner, }; @@ -37,7 +38,8 @@ use util::{ macro_rules! add_args_to_config { ( $x:tt, $z:expr, $s:tt ) => { if let Some(temp) = &$x { - $z.$s(temp)?; + $z.$s(temp) + .with_context(|| format!("Add args {:?} error.", temp))?; } }; ( $x:tt, $z:expr, $s:tt, vec ) => { @@ -64,7 +66,8 @@ macro_rules! add_args_to_config_multi { ( $x:tt, $z:expr, $s:tt ) => { if let Some(temps) = &$x { for temp in temps { - $z.$s(temp)?; + $z.$s(temp) + .with_context(|| format!("Add args {:?} error.", temp))?; } } }; @@ -246,6 +249,8 @@ pub fn create_args_parser<'a>() -> ArgParser<'a> { \n\t\tadd virtio pci balloon: -device virtio-balloon-pci,id=,bus=,addr=<0x4>[,deflate-on-oom=true|false][,free-page-reporting=true|false][,multifunction=on|off]; \ \n\t\tadd virtio mmio rng: -device virtio-rng-device,rng=,max-bytes=<1234>,period=<1000>; \ \n\t\tadd virtio pci rng: -device virtio-rng-pci,id=,rng=,max-bytes=<1234>,period=<1000>,bus=,addr=<0x1>[,multifunction=on|off]; \ + \n\t\tadd virtio mmio input: -device virtio-input-device,id=,evdev=; \ + \n\t\tadd virtio pci input: -device virtio-input-pci,id=,evdev=,bus=,addr=<0x1>[,multifunction=on|off]; \ \n\t\tadd pcie root port: -device pcie-root-port,id=,port=<0x1>,bus=,addr=<0x1>[,multifunction=on|off]; \ \n\t\tadd vfio pci: -device vfio-pci,id=,host=<0000:1a:00.3>,bus=,addr=<0x03>[,multifunction=on|off]; \ \n\t\tadd usb controller: -device nec-usb-xhci,id=,bus=,addr=<0xa>; \ @@ -446,8 +451,8 @@ pub fn create_args_parser<'a>() -> ArgParser<'a> { Arg::with_name("trace") .multiple(false) .long("trace") - .value_name("file=") - .help("specify the file lists trace state to enable") + .value_name("file=|type=") + .help("specify the trace state to enable") .takes_value(true), ) .arg( @@ -459,6 +464,15 @@ pub fn create_args_parser<'a>() -> ArgParser<'a> { .takes_values(true) .required(false), ) + .arg( + Arg::with_name("hardware-signature") + .multiple(false) + .long("hardware-signature") + .value_name("<32bit integer>") + .help("set ACPI Hardware Signature") + .takes_value(true) + .required(false), + ) .arg( Arg::with_name("smbios") .multiple(true) @@ -509,7 +523,7 @@ pub fn create_args_parser<'a>() -> ArgParser<'a> { .multiple(false) .long("windows_emu_pid") .value_name("pid") - .help("watch on the external windows emu pid") + .help("watch on the external emulator pid") .takes_value(true), ); @@ -576,6 +590,11 @@ pub fn create_vmconfig(args: &ArgMatches) -> Result { add_kernel_cmdline, vec ); + add_args_to_config!( + (args.value_of("hardware-signature")), + vm_cfg, + add_hw_signature + ); add_args_to_config_multi!((args.values_of("drive")), vm_cfg, add_drive); add_args_to_config_multi!((args.values_of("object")), vm_cfg, add_object); add_args_to_config_multi!((args.values_of("netdev")), vm_cfg, add_netdev); @@ -587,7 +606,7 @@ pub fn create_vmconfig(args: &ArgMatches) -> Result { add_args_to_config_multi!((args.values_of("cameradev")), vm_cfg, add_camera_backend); add_args_to_config_multi!((args.values_of("smbios")), vm_cfg, add_smbios); if let Some(opt) = args.value_of("trace") { - parse_trace_options(&opt)?; + add_trace(&opt)?; } // Check the mini-set for Vm to start is ok @@ -599,6 +618,28 @@ pub fn create_vmconfig(args: &ArgMatches) -> Result { Ok(vm_cfg) } +#[derive(Parser)] +#[command(no_binary_name(true))] +struct QmpConfig { + #[arg(long, alias = "classtype")] + uri: String, + #[arg(long, action = ArgAction::SetTrue, required = true)] + server: bool, + #[arg(long, action = ArgAction::SetTrue, required = true)] + nowait: bool, +} + +#[derive(Parser)] +#[command(no_binary_name(true))] +struct MonConfig { + #[arg(long, default_value = "")] + id: String, + #[arg(long, value_parser = ["control"])] + mode: String, + #[arg(long)] + chardev: String, +} + /// This function is to parse qmp socket path and type. /// /// # Arguments @@ -613,75 +654,34 @@ pub fn check_api_channel( vm_config: &mut VmConfig, ) -> Result> { let mut sock_paths = Vec::new(); - if let Some(qmp_config) = args.value_of("qmp") { - let mut cmd_parser = CmdParser::new("qmp"); - cmd_parser.push("").push("server").push("nowait"); - - cmd_parser.parse(&qmp_config)?; - if let Some(uri) = cmd_parser.get_value::("")? { - let sock_path = - QmpSocketPath::new(uri).with_context(|| "Failed to parse qmp socket path")?; - sock_paths.push(sock_path); - } else { - bail!("No uri found for qmp"); - } - if cmd_parser.get_value::("server")?.is_none() { - bail!("Argument \'server\' is needed for qmp"); - } - if cmd_parser.get_value::("nowait")?.is_none() { - bail!("Argument \'nowait\' is needed for qmp"); - } + if let Some(qmp_args) = args.value_of("qmp") { + let qmp_cfg = QmpConfig::try_parse_from(str_slip_to_clap(&qmp_args, true, false))?; + let sock_path = + QmpSocketPath::new(qmp_cfg.uri).with_context(|| "Failed to parse qmp socket path")?; + sock_paths.push(sock_path); } - if let Some(mon_config) = args.value_of("mon") { - let mut cmd_parser = CmdParser::new("monitor"); - cmd_parser.push("id").push("mode").push("chardev"); - cmd_parser.parse(&mon_config)?; - - let chardev = cmd_parser - .get_value::("chardev")? - .with_context(|| "Argument \'chardev\' is missing for \'mon\'")?; - - if let Some(mode) = cmd_parser.get_value::("mode")? { - if mode != *"control" { - bail!("Invalid \'mode\' parameter: {:?} for monitor", &mode); + if let Some(mon_args) = args.value_of("mon") { + let mon_cfg = MonConfig::try_parse_from(str_slip_to_clap(&mon_args, false, false))?; + let cfg = vm_config + .chardev + .remove(&mon_cfg.chardev) + .with_context(|| format!("No chardev found: {}", &mon_cfg.chardev))?; + let socket = cfg + .classtype + .socket_type() + .with_context(|| "Only chardev of unix-socket type can be used for monitor")?; + if let ChardevType::Socket { server, nowait, .. } = cfg.classtype { + if !server || !nowait { + bail!( + "Argument \'server\' and \'nowait\' are both required for chardev \'{}\'", + cfg.id() + ); } - } else { - bail!("Argument \'mode\' of \'mon\' should be set to \'control\'."); } - - if let Some(cfg) = vm_config.chardev.remove(&chardev) { - if let ChardevType::UnixSocket { - path, - server, - nowait, - } = cfg.backend - { - if !server || !nowait { - bail!( - "Argument \'server\' and \'nowait\' are both required for chardev \'{}\'", - path - ); - } - sock_paths.push(QmpSocketPath::Unix { path }); - } else if let ChardevType::TcpSocket { - host, - port, - server, - nowait, - } = cfg.backend - { - if !server || !nowait { - bail!( - "Argument \'server\' and \'nowait\' are both required for chardev \'{}:{}\'", - host, port - ); - } - sock_paths.push(QmpSocketPath::Tcp { host, port }); - } else { - bail!("Only chardev of unix-socket type can be used for monitor"); - } - } else { - bail!("No chardev found: {}", &chardev); + if let SocketType::Tcp { host, port } = socket { + sock_paths.push(QmpSocketPath::Tcp { host, port }); + } else if let SocketType::Unix { path } = socket { + sock_paths.push(QmpSocketPath::Unix { path }); } } diff --git a/machine_manager/src/config/boot_source.rs b/machine_manager/src/config/boot_source.rs index 6f7c227a44e834c6248a48aa53e71df3ef71420a..997b841044396e43edcf5a75ba67c1c263ab1095 100644 --- a/machine_manager/src/config/boot_source.rs +++ b/machine_manager/src/config/boot_source.rs @@ -239,8 +239,8 @@ mod tests { test_kernel_param.push(Param::from_str("maxcpus=8")); assert_eq!(test_kernel_param.params.len(), 6); - assert_eq!(test_kernel_param.contains("maxcpus"), true); - assert_eq!(test_kernel_param.contains("cpus"), false); + assert!(test_kernel_param.contains("maxcpus")); + assert!(!test_kernel_param.contains("cpus")); assert_eq!( test_kernel_param.to_string(), "reboot=k panic=1 pci=off nomodules 8250.nr_uarts=0 maxcpus=8" @@ -257,7 +257,7 @@ mod tests { initrd_file.set_len(100_u64).unwrap(); let mut vm_config = VmConfig::default(); assert!(vm_config.add_kernel(&kernel_path).is_ok()); - vm_config.add_kernel_cmdline(&vec![ + vm_config.add_kernel_cmdline(&[ String::from("console=ttyS0"), String::from("reboot=k"), String::from("panic=1"), diff --git a/machine_manager/src/config/camera.rs b/machine_manager/src/config/camera.rs index a5ed0702a23a59f16e10b2c57a425393f20898f8..b3964f4f2f3cd65e67af0958b5bffe42595851ad 100644 --- a/machine_manager/src/config/camera.rs +++ b/machine_manager/src/config/camera.rs @@ -22,8 +22,10 @@ use crate::{ }; #[derive(Parser, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -#[command(name = "camera device")] +#[command(no_binary_name(true))] pub struct CameraDevConfig { + #[arg(long)] + pub classtype: String, #[arg(long, value_parser = valid_id)] pub id: String, #[arg(long)] @@ -59,7 +61,7 @@ impl FromStr for CamBackendType { impl VmConfig { pub fn add_camera_backend(&mut self, camera_config: &str) -> Result<()> { let cfg = format!("cameradev,backend={}", camera_config); - let config = CameraDevConfig::try_parse_from(str_slip_to_clap(&cfg))?; + let config = CameraDevConfig::try_parse_from(str_slip_to_clap(&cfg, true, false))?; self.add_cameradev_with_config(config) } @@ -91,10 +93,10 @@ impl VmConfig { } pub fn del_cameradev_by_id(&mut self, id: &str) -> Result<()> { - if self.camera_backend.get(&id.to_string()).is_none() { + if !self.camera_backend.contains_key(id) { bail!("no cameradev with id {}", id); } - self.camera_backend.remove(&id.to_string()); + self.camera_backend.remove(id); Ok(()) } @@ -103,6 +105,7 @@ impl VmConfig { pub fn get_cameradev_config(args: qmp_schema::CameraDevAddArgument) -> Result { let path = args.path.with_context(|| "cameradev config path is null")?; let config = CameraDevConfig { + classtype: "cameradev".to_string(), id: args.id, path, backend: CamBackendType::from_str(&args.driver)?, diff --git a/machine_manager/src/config/chardev.rs b/machine_manager/src/config/chardev.rs index 943de721b009e601053827f3f361ab7a583b4faf..1eb0d489c67483c727c416a141072447a6197a32 100644 --- a/machine_manager/src/config/chardev.rs +++ b/machine_manager/src/config/chardev.rs @@ -14,248 +14,168 @@ use std::net::IpAddr; use std::str::FromStr; use anyhow::{anyhow, bail, Context, Result}; +use clap::{ArgAction, Parser, Subcommand}; use log::error; use serde::{Deserialize, Serialize}; -use super::{error::ConfigError, get_pci_bdf, pci_args_check, PciBdf}; -use crate::config::{ - check_arg_too_long, CmdParser, ConfigCheck, ExBool, VmConfig, MAX_PATH_LENGTH, -}; +use super::{error::ConfigError, str_slip_to_clap}; +use super::{get_pci_df, parse_bool}; +use crate::config::{valid_id, valid_path, valid_socket_path, ConfigCheck, VmConfig}; use crate::qmp::qmp_schema; -const MAX_GUEST_CID: u64 = 4_294_967_295; -const MIN_GUEST_CID: u64 = 3; - /// Default value of max ports for virtio-serial. const DEFAULT_SERIAL_PORTS_NUMBER: u32 = 31; -/// Character device options. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum ChardevType { - Stdio, - Pty, - UnixSocket { - path: String, - server: bool, - nowait: bool, - }, - TcpSocket { - host: String, - port: u16, - server: bool, - nowait: bool, - }, - File(String), -} - /// Config structure for virtio-serial-port. -#[derive(Debug, Clone)] -pub struct VirtioSerialPort { +#[derive(Parser, Debug, Clone)] +#[command(no_binary_name(true))] +pub struct VirtioSerialPortCfg { + #[arg(long, value_parser = ["virtconsole", "virtserialport"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] pub id: String, - pub chardev: ChardevConfig, - pub nr: u32, - pub is_console: bool, + #[arg(long)] + pub chardev: String, + #[arg(long)] + pub nr: Option, } -impl ConfigCheck for VirtioSerialPort { +impl ConfigCheck for VirtioSerialPortCfg { fn check(&self) -> Result<()> { - check_arg_too_long(&self.id, "chardev id") - } -} - -/// Config structure for character device. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChardevConfig { - pub id: String, - pub backend: ChardevType, -} - -impl ConfigCheck for ChardevConfig { - fn check(&self) -> Result<()> { - check_arg_too_long(&self.id, "chardev id")?; - match &self.backend { - ChardevType::UnixSocket { path, .. } => { - if path.len() > MAX_PATH_LENGTH { - return Err(anyhow!(ConfigError::StringLengthTooLong( - "unix-socket path".to_string(), - MAX_PATH_LENGTH - ))); - } - Ok(()) - } - ChardevType::TcpSocket { host, port, .. } => { - if *port == 0u16 { - return Err(anyhow!(ConfigError::InvalidParam( - "port".to_string(), - "tcp-socket".to_string() - ))); - } - let ip_address = IpAddr::from_str(host); - if ip_address.is_err() { - return Err(anyhow!(ConfigError::InvalidParam( - "host".to_string(), - "tcp-socket".to_string() - ))); - } - Ok(()) - } - ChardevType::File(path) => { - if path.len() > MAX_PATH_LENGTH { - return Err(anyhow!(ConfigError::StringLengthTooLong( - "file path".to_string(), - MAX_PATH_LENGTH - ))); - } - Ok(()) - } - _ => Ok(()), + if self.classtype != "virtconsole" && self.nr.unwrap() == 0 { + bail!("Port number 0 on virtio-serial devices reserved for virtconsole device."); } + + Ok(()) } } -fn check_chardev_fields( - dev_type: &str, - cmd_parser: &CmdParser, - supported_fields: &[&str], -) -> Result<()> { - for (field, value) in &cmd_parser.params { - let supported_field = supported_fields.contains(&field.as_str()); - if !supported_field && value.is_some() { +impl VirtioSerialPortCfg { + /// If nr is not set in command line. Configure incremental maximum value for virtconsole. + /// Configure incremental maximum value(except 0) for virtserialport. + pub fn auto_nr(&mut self, free_port0: bool, free_nr: u32, max_nr_ports: u32) -> Result<()> { + let free_console_nr = if free_port0 { 0 } else { free_nr }; + let auto_nr = match self.classtype.as_str() { + "virtconsole" => free_console_nr, + "virtserialport" => free_nr, + _ => bail!("Invalid classtype."), + }; + let nr = self.nr.unwrap_or(auto_nr); + if nr >= max_nr_ports { bail!( - "Chardev of type {} does not support \'{}\' argument", - dev_type, - field + "virtio serial port nr {} should be less than virtio serial's max_nr_ports {}", + nr, + max_nr_ports ); } - } - Ok(()) -} - -fn parse_stdio_chardev(chardev_id: String, cmd_parser: CmdParser) -> Result { - let supported_fields = ["", "id"]; - check_chardev_fields("stdio", &cmd_parser, &supported_fields)?; - Ok(ChardevConfig { - id: chardev_id, - backend: ChardevType::Stdio, - }) -} -fn parse_pty_chardev(chardev_id: String, cmd_parser: CmdParser) -> Result { - let supported_fields = ["", "id"]; - check_chardev_fields("pty", &cmd_parser, &supported_fields)?; - Ok(ChardevConfig { - id: chardev_id, - backend: ChardevType::Pty, - }) + self.nr = Some(nr); + Ok(()) + } } -fn parse_file_chardev(chardev_id: String, cmd_parser: CmdParser) -> Result { - let supported_fields = ["", "id", "path"]; - check_chardev_fields("file", &cmd_parser, &supported_fields)?; - - let path = cmd_parser - .get_value::("path")? - .with_context(|| ConfigError::FieldIsMissing("path".to_string(), "chardev".to_string()))?; - - let default_value = path.clone(); - let file_path = std::fs::canonicalize(path).map_or(default_value, |canonical_path| { - String::from(canonical_path.to_str().unwrap()) - }); - - Ok(ChardevConfig { - id: chardev_id, - backend: ChardevType::File(file_path), - }) +/// Config structure for character device. +#[derive(Parser, Debug, Clone, Serialize, Deserialize)] +#[command(no_binary_name(true))] +pub struct ChardevConfig { + #[command(subcommand)] + pub classtype: ChardevType, } -fn parse_socket_chardev(chardev_id: String, cmd_parser: CmdParser) -> Result { - let mut server_enabled = false; - let server = cmd_parser.get_value::("server")?; - if let Some(server) = server { - if server.ne("") { - bail!("No parameter needed for server"); +impl ChardevConfig { + pub fn id(&self) -> String { + match &self.classtype { + ChardevType::Stdio { id } => id, + ChardevType::Pty { id } => id, + ChardevType::Socket { id, .. } => id, + ChardevType::File { id, .. } => id, } - server_enabled = true; + .clone() } +} - let mut nowait_enabled = false; - let nowait = cmd_parser.get_value::("nowait")?; - if let Some(nowait) = nowait { - if nowait.ne("") { - bail!("No parameter needed for nowait"); +impl ConfigCheck for ChardevConfig { + fn check(&self) -> Result<()> { + if let ChardevType::Socket { .. } = self.classtype { + self.classtype.socket_type()?; } - nowait_enabled = true; - } - - let path = cmd_parser.get_value::("path")?; - if let Some(path) = path { - let supported_fields = ["", "id", "path", "server", "nowait"]; - check_chardev_fields("unix-socket", &cmd_parser, &supported_fields)?; - - let default_value = path.clone(); - let socket_path = std::fs::canonicalize(path).map_or(default_value, |canonical_path| { - String::from(canonical_path.to_str().unwrap()) - }); - - return Ok(ChardevConfig { - id: chardev_id, - backend: ChardevType::UnixSocket { - path: socket_path, - server: server_enabled, - nowait: nowait_enabled, - }, - }); - } - let port = cmd_parser.get_value::("port")?; - if let Some(port) = port { - let supported_fields = ["", "id", "host", "port", "server", "nowait"]; - check_chardev_fields("tcp-socket", &cmd_parser, &supported_fields)?; - - let host = cmd_parser.get_value::("host")?; - return Ok(ChardevConfig { - id: chardev_id, - backend: ChardevType::TcpSocket { - host: host.unwrap_or_else(|| String::from("0.0.0.0")), - port, - server: server_enabled, - nowait: nowait_enabled, - }, - }); + Ok(()) } +} - Err(anyhow!(ConfigError::InvalidParam( - "backend".to_string(), - "chardev".to_string() - ))) +/// Character device options. +#[derive(Subcommand, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ChardevType { + Stdio { + #[arg(long, value_parser = valid_id)] + id: String, + }, + Pty { + #[arg(long, value_parser = valid_id)] + id: String, + }, + // Unix Socket: use `path`. + // Tcp Socket: use `host` and `port`. + #[clap(group = clap::ArgGroup::new("unix-socket").args(&["host", "port"]).requires("port").multiple(true).conflicts_with("tcp-socket"))] + #[clap(group = clap::ArgGroup::new("tcp-socket").arg("path").conflicts_with("unix-socket"))] + Socket { + #[arg(long, value_parser = valid_id)] + id: String, + #[arg(long, value_parser = valid_socket_path)] + path: Option, + #[arg(long, value_parser = valid_host, default_value = "0.0.0.0")] + host: String, + #[arg(long, value_parser = clap::value_parser!(u16).range(1..))] + port: Option, + #[arg(long, action = ArgAction::SetTrue)] + server: bool, + #[arg(long, action = ArgAction::SetTrue)] + nowait: bool, + }, + File { + #[arg(long, value_parser = valid_id)] + id: String, + #[arg(long, value_parser = valid_path)] + path: String, + }, } -pub fn parse_chardev(chardev_config: &str) -> Result { - let mut cmd_parser = CmdParser::new("chardev"); - for field in ["", "id", "path", "host", "port", "server", "nowait"] { - cmd_parser.push(field); +impl ChardevType { + pub fn socket_type(&self) -> Result { + if let ChardevType::Socket { + path, host, port, .. + } = self + { + if path.is_some() && port.is_none() { + return Ok(SocketType::Unix { + path: path.clone().unwrap(), + }); + } else if port.is_some() && path.is_none() { + return Ok(SocketType::Tcp { + host: host.clone(), + port: (*port).unwrap(), + }); + } + } + bail!("Not socket type or invalid socket type"); } +} - cmd_parser.parse(chardev_config)?; - - let chardev_id = cmd_parser - .get_value::("id")? - .with_context(|| ConfigError::FieldIsMissing("id".to_string(), "chardev".to_string()))?; - - let backend = cmd_parser - .get_value::("")? - .with_context(|| ConfigError::InvalidParam("backend".to_string(), "chardev".to_string()))?; - - match backend.as_str() { - "stdio" => parse_stdio_chardev(chardev_id, cmd_parser), - "pty" => parse_pty_chardev(chardev_id, cmd_parser), - "file" => parse_file_chardev(chardev_id, cmd_parser), - "socket" => parse_socket_chardev(chardev_id, cmd_parser), - _ => Err(anyhow!(ConfigError::InvalidParam( - backend, - "chardev".to_string() - ))), +pub enum SocketType { + Unix { path: String }, + Tcp { host: String, port: u16 }, +} + +fn valid_host(host: &str) -> Result { + let ip_address = IpAddr::from_str(host); + if ip_address.is_err() { + return Err(anyhow!(ConfigError::InvalidParam( + "host".to_string(), + "tcp-socket".to_string() + ))); } + Ok(host.to_string()) } /// Get chardev config from qmp arguments. @@ -291,9 +211,11 @@ pub fn get_chardev_config(args: qmp_schema::CharDevAddArgument) -> Result Result Result { - if let Some(char_dev) = vm_config.chardev.remove(chardev) { - match char_dev.backend.clone() { - ChardevType::UnixSocket { - path, - server, - nowait, - } => { - if server || nowait { - bail!( - "Argument \'server\' or \'nowait\' is not need for chardev \'{}\'", - path - ); - } - Ok(path) - } - _ => { - bail!( - "Chardev {:?} backend should be unix-socket type.", - &char_dev.id - ); - } +pub fn get_chardev_socket_path(chardev: ChardevConfig) -> Result { + let id = chardev.id(); + if let ChardevType::Socket { + path, + server, + nowait, + .. + } = chardev.classtype + { + path.clone() + .with_context(|| format!("Chardev {:?} backend should be unix-socket type.", id))?; + if server || nowait { + bail!( + "Argument \'server\' or \'nowait\' is not need for chardev \'{}\'", + path.unwrap() + ); } - } else { - bail!("Chardev: {:?} not found for character device", &chardev); + return Ok(path.unwrap()); } -} - -pub fn parse_virtserialport( - vm_config: &mut VmConfig, - config_args: &str, - is_console: bool, - free_nr: u32, - free_port0: bool, -) -> Result { - let mut cmd_parser = CmdParser::new("virtserialport"); - cmd_parser.push("").push("id").push("chardev").push("nr"); - cmd_parser.parse(config_args)?; - - let chardev_name = cmd_parser - .get_value::("chardev")? - .with_context(|| { - ConfigError::FieldIsMissing("chardev".to_string(), "virtserialport".to_string()) - })?; - let id = cmd_parser.get_value::("id")?.with_context(|| { - ConfigError::FieldIsMissing("id".to_string(), "virtserialport".to_string()) - })?; - - let nr = cmd_parser - .get_value::("nr")? - .unwrap_or(if is_console && free_port0 { 0 } else { free_nr }); - - if nr == 0 && !is_console { - bail!("Port number 0 on virtio-serial devices reserved for virtconsole device."); - } - - if let Some(chardev) = vm_config.chardev.remove(&chardev_name) { - let port_cfg = VirtioSerialPort { - id, - chardev, - nr, - is_console, - }; - port_cfg.check()?; - return Ok(port_cfg); - } - bail!("Chardev {:?} not found or is in use", &chardev_name); + bail!("Chardev {:?} backend should be unix-socket type.", id); } impl VmConfig { /// Add chardev config to `VmConfig`. pub fn add_chardev(&mut self, chardev_config: &str) -> Result<()> { - let chardev = parse_chardev(chardev_config)?; + let chardev = ChardevConfig::try_parse_from(str_slip_to_clap(chardev_config, true, true))?; chardev.check()?; - let chardev_id = chardev.id.clone(); - if self.chardev.get(&chardev_id).is_none() { - self.chardev.insert(chardev_id, chardev); - } else { - bail!("Chardev {:?} has been added", &chardev_id); - } + self.add_chardev_with_config(chardev)?; Ok(()) } @@ -395,16 +264,11 @@ impl VmConfig { /// /// * `conf` - The chardev config to be added to the vm. pub fn add_chardev_with_config(&mut self, conf: ChardevConfig) -> Result<()> { - if let Err(e) = conf.check() { - bail!("Chardev config checking failed, {}", e.to_string()); - } - - let chardev_id = conf.id.clone(); - if self.chardev.get(&chardev_id).is_none() { - self.chardev.insert(chardev_id, conf); - } else { + let chardev_id = conf.id(); + if self.chardev.contains_key(&chardev_id) { bail!("Chardev {:?} has been added", chardev_id); } + self.chardev.insert(chardev_id, conf); Ok(()) } @@ -414,11 +278,9 @@ impl VmConfig { /// /// * `id` - The chardev id which is used to delete chardev config. pub fn del_chardev_by_id(&mut self, id: &str) -> Result<()> { - if self.chardev.get(id).is_some() { - self.chardev.remove(id); - } else { - bail!("Chardev {} not found", id); - } + self.chardev + .remove(id) + .with_context(|| format!("Chardev {} not found", id))?; Ok(()) } } @@ -458,189 +320,74 @@ impl VmConfig { } } -/// Config structure for virtio-vsock. -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -pub struct VsockConfig { - pub id: String, - pub guest_cid: u64, - pub vhost_fd: Option, -} - -impl ConfigCheck for VsockConfig { - fn check(&self) -> Result<()> { - check_arg_too_long(&self.id, "vsock id")?; - - if self.guest_cid < MIN_GUEST_CID || self.guest_cid >= MAX_GUEST_CID { - return Err(anyhow!(ConfigError::IllegalValue( - "Vsock guest-cid".to_string(), - MIN_GUEST_CID, - true, - MAX_GUEST_CID, - false, - ))); - } - - Ok(()) - } -} - -pub fn parse_vsock(vsock_config: &str) -> Result { - let mut cmd_parser = CmdParser::new("vhost-vsock"); - cmd_parser - .push("") - .push("id") - .push("bus") - .push("addr") - .push("multifunction") - .push("guest-cid") - .push("vhostfd"); - cmd_parser.parse(vsock_config)?; - pci_args_check(&cmd_parser)?; - let id = cmd_parser - .get_value::("id")? - .with_context(|| ConfigError::FieldIsMissing("id".to_string(), "vsock".to_string()))?; - - let guest_cid = cmd_parser.get_value::("guest-cid")?.with_context(|| { - ConfigError::FieldIsMissing("guest-cid".to_string(), "vsock".to_string()) - })?; - - let vhost_fd = cmd_parser.get_value::("vhostfd")?; - let vsock = VsockConfig { - id, - guest_cid, - vhost_fd, - }; - Ok(vsock) -} - -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Parser, Clone, Debug, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct VirtioSerialInfo { + #[arg(long, value_parser = ["virtio-serial-pci", "virtio-serial-device"])] + pub classtype: String, + #[arg(long, default_value = "", value_parser = valid_id)] pub id: String, - pub pci_bdf: Option, - pub multifunction: bool, + #[arg(long)] + pub bus: Option, + #[arg(long, value_parser = get_pci_df)] + pub addr: Option<(u8, u8)>, + #[arg(long, value_parser = parse_bool, action = ArgAction::Append)] + pub multifunction: Option, + #[arg(long, default_value = "31", value_parser = clap::value_parser!(u32).range(1..=DEFAULT_SERIAL_PORTS_NUMBER as i64))] pub max_ports: u32, } -impl ConfigCheck for VirtioSerialInfo { - fn check(&self) -> Result<()> { - check_arg_too_long(&self.id, "virtio-serial id")?; - - if self.max_ports < 1 || self.max_ports > DEFAULT_SERIAL_PORTS_NUMBER { - return Err(anyhow!(ConfigError::IllegalValue( - "Virtio-serial max_ports".to_string(), - 1, - true, - DEFAULT_SERIAL_PORTS_NUMBER as u64, - true - ))); - } - - Ok(()) - } -} - -pub fn parse_virtio_serial( - vm_config: &mut VmConfig, - serial_config: &str, -) -> Result { - let mut cmd_parser = CmdParser::new("virtio-serial"); - cmd_parser - .push("") - .push("id") - .push("bus") - .push("addr") - .push("multifunction") - .push("max_ports"); - cmd_parser.parse(serial_config)?; - pci_args_check(&cmd_parser)?; - - if vm_config.virtio_serial.is_some() { - bail!("Only one virtio serial device is supported"); - } - - let id = cmd_parser.get_value::("id")?.unwrap_or_default(); - let multifunction = cmd_parser - .get_value::("multifunction")? - .map_or(false, |switch| switch.into()); - let max_ports = cmd_parser - .get_value::("max_ports")? - .unwrap_or(DEFAULT_SERIAL_PORTS_NUMBER); - let virtio_serial = if serial_config.contains("-pci") { - let pci_bdf = get_pci_bdf(serial_config)?; - VirtioSerialInfo { - id, - pci_bdf: Some(pci_bdf), - multifunction, - max_ports, - } - } else { - VirtioSerialInfo { - id, - pci_bdf: None, - multifunction, +impl VirtioSerialInfo { + pub fn auto_max_ports(&mut self) { + if self.classtype == "virtio-serial-device" { // Micro_vm does not support multi-ports in virtio-serial-device. - max_ports: 1, + self.max_ports = 1; } - }; - virtio_serial.check()?; - vm_config.virtio_serial = Some(virtio_serial.clone()); - - Ok(virtio_serial) + } } #[cfg(test)] mod tests { use super::*; - use crate::config::parse_virtio_serial; fn test_mmio_console_config_cmdline_parser(chardev_cfg: &str, expected_chardev: ChardevType) { let mut vm_config = VmConfig::default(); - assert!(parse_virtio_serial(&mut vm_config, "virtio-serial-device").is_ok()); + let serial_cmd = "virtio-serial-device"; + let mut serial_cfg = + VirtioSerialInfo::try_parse_from(str_slip_to_clap(serial_cmd, true, false)).unwrap(); + serial_cfg.auto_max_ports(); + vm_config.virtio_serial = Some(serial_cfg.clone()); assert!(vm_config.add_chardev(chardev_cfg).is_ok()); - let virt_console = parse_virtserialport( - &mut vm_config, - "virtconsole,chardev=test_console,id=console1,nr=1", - true, - 0, - true, - ); - assert!(virt_console.is_ok()); - - let console_cfg = virt_console.unwrap(); - assert_eq!(console_cfg.id, "console1"); - assert_eq!(console_cfg.chardev.backend, expected_chardev); + let port_cmd = "virtconsole,chardev=test_console,id=console1,nr=0"; + let mut port_cfg = + VirtioSerialPortCfg::try_parse_from(str_slip_to_clap(port_cmd, true, false)).unwrap(); + assert!(port_cfg.auto_nr(true, 0, serial_cfg.max_ports).is_ok()); + let chardev = vm_config.chardev.remove(&port_cfg.chardev).unwrap(); + assert_eq!(port_cfg.id, "console1"); + assert_eq!(port_cfg.nr.unwrap(), 0); + assert_eq!(chardev.classtype, expected_chardev); + + // Error: VirtioSerialPortCfg.nr >= VirtioSerialInfo.max_nr_ports. + let port_cmd = "virtconsole,chardev=test_console,id=console1,nr=1"; + let mut port_cfg = + VirtioSerialPortCfg::try_parse_from(str_slip_to_clap(port_cmd, true, false)).unwrap(); + assert!(port_cfg.auto_nr(true, 0, serial_cfg.max_ports).is_err()); let mut vm_config = VmConfig::default(); - assert!( - parse_virtio_serial(&mut vm_config, "virtio-serial-device,bus=pcie.0,addr=0x1") - .is_err() - ); assert!(vm_config .add_chardev("sock,id=test_console,path=/path/to/socket") .is_err()); - - let mut vm_config = VmConfig::default(); - assert!(parse_virtio_serial(&mut vm_config, "virtio-serial-device").is_ok()); - assert!(vm_config - .add_chardev("socket,id=test_console,path=/path/to/socket,server,nowait") - .is_ok()); - let virt_console = parse_virtserialport( - &mut vm_config, - "virtconsole,chardev=test_console1,id=console1,nr=1", - true, - 0, - true, - ); - // test_console1 does not exist. - assert!(virt_console.is_err()); } #[test] fn test_mmio_console_config_cmdline_parser_1() { let chardev_cfg = "socket,id=test_console,path=/path/to/socket,server,nowait"; - let expected_chardev = ChardevType::UnixSocket { - path: "/path/to/socket".to_string(), + let expected_chardev = ChardevType::Socket { + id: "test_console".to_string(), + path: Some("/path/to/socket".to_string()), + host: "0.0.0.0".to_string(), + port: None, server: true, nowait: true, }; @@ -650,9 +397,11 @@ mod tests { #[test] fn test_mmio_console_config_cmdline_parser_2() { let chardev_cfg = "socket,id=test_console,host=127.0.0.1,port=9090,server,nowait"; - let expected_chardev = ChardevType::TcpSocket { + let expected_chardev = ChardevType::Socket { + id: "test_console".to_string(), + path: None, host: "127.0.0.1".to_string(), - port: 9090, + port: Some(9090), server: true, nowait: true, }; @@ -661,41 +410,34 @@ mod tests { fn test_pci_console_config_cmdline_parser(chardev_cfg: &str, expected_chardev: ChardevType) { let mut vm_config = VmConfig::default(); - let virtio_arg = "virtio-serial-pci,bus=pcie.0,addr=0x1.0x2"; - assert!(parse_virtio_serial(&mut vm_config, virtio_arg).is_ok()); + let serial_cmd = "virtio-serial-pci,bus=pcie.0,addr=0x1.0x2,multifunction=on"; + let mut serial_cfg = + VirtioSerialInfo::try_parse_from(str_slip_to_clap(serial_cmd, true, false)).unwrap(); + serial_cfg.auto_max_ports(); + vm_config.virtio_serial = Some(serial_cfg.clone()); assert!(vm_config.add_chardev(chardev_cfg).is_ok()); - let virt_console = parse_virtserialport( - &mut vm_config, - "virtconsole,chardev=test_console,id=console1,nr=1", - true, - 0, - true, - ); - assert!(virt_console.is_ok()); - let console_cfg = virt_console.unwrap(); - + let console_cmd = "virtconsole,chardev=test_console,id=console1,nr=1"; + let mut console_cfg = + VirtioSerialPortCfg::try_parse_from(str_slip_to_clap(console_cmd, true, false)) + .unwrap(); + assert!(console_cfg.auto_nr(true, 0, serial_cfg.max_ports).is_ok()); + let chardev = vm_config.chardev.remove(&console_cfg.chardev).unwrap(); assert_eq!(console_cfg.id, "console1"); let serial_info = vm_config.virtio_serial.clone().unwrap(); - assert!(serial_info.pci_bdf.is_some()); - let bdf = serial_info.pci_bdf.unwrap(); - assert_eq!(bdf.bus, "pcie.0"); - assert_eq!(bdf.addr, (1, 2)); - assert_eq!(console_cfg.chardev.backend, expected_chardev); - - let mut vm_config = VmConfig::default(); - assert!(parse_virtio_serial( - &mut vm_config, - "virtio-serial-pci,bus=pcie.0,addr=0x1.0x2,multifunction=on" - ) - .is_ok()); + assert_eq!(serial_info.bus.unwrap(), "pcie.0"); + assert_eq!(serial_info.addr.unwrap(), (1, 2)); + assert_eq!(chardev.classtype, expected_chardev); } #[test] fn test_pci_console_config_cmdline_parser_1() { let chardev_cfg = "socket,id=test_console,path=/path/to/socket,server,nowait"; - let expected_chardev = ChardevType::UnixSocket { - path: "/path/to/socket".to_string(), + let expected_chardev = ChardevType::Socket { + id: "test_console".to_string(), + path: Some("/path/to/socket".to_string()), + host: "0.0.0.0".to_string(), + port: None, server: true, nowait: true, }; @@ -705,36 +447,17 @@ mod tests { #[test] fn test_pci_console_config_cmdline_parser_2() { let chardev_cfg = "socket,id=test_console,host=127.0.0.1,port=9090,server,nowait"; - let expected_chardev = ChardevType::TcpSocket { + let expected_chardev = ChardevType::Socket { + id: "test_console".to_string(), + path: None, host: "127.0.0.1".to_string(), - port: 9090, + port: Some(9090), server: true, nowait: true, }; test_pci_console_config_cmdline_parser(chardev_cfg, expected_chardev) } - #[test] - fn test_vsock_config_cmdline_parser() { - let vsock_cfg_op = parse_vsock("vhost-vsock-device,id=test_vsock,guest-cid=3"); - assert!(vsock_cfg_op.is_ok()); - - let vsock_config = vsock_cfg_op.unwrap(); - assert_eq!(vsock_config.id, "test_vsock"); - assert_eq!(vsock_config.guest_cid, 3); - assert_eq!(vsock_config.vhost_fd, None); - assert!(vsock_config.check().is_ok()); - - let vsock_cfg_op = parse_vsock("vhost-vsock-device,id=test_vsock,guest-cid=3,vhostfd=4"); - assert!(vsock_cfg_op.is_ok()); - - let vsock_config = vsock_cfg_op.unwrap(); - assert_eq!(vsock_config.id, "test_vsock"); - assert_eq!(vsock_config.guest_cid, 3); - assert_eq!(vsock_config.vhost_fd, Some(4)); - assert!(vsock_config.check().is_ok()); - } - #[test] fn test_chardev_config_cmdline_parser() { let check_argument = |arg: String, expect: ChardevType| { @@ -744,17 +467,30 @@ mod tests { let device_id = "test_id"; if let Some(char_dev) = vm_config.chardev.remove(device_id) { - assert_eq!(char_dev.backend, expect); + assert_eq!(char_dev.classtype, expect); } else { assert!(false); } }; - check_argument("stdio,id=test_id".to_string(), ChardevType::Stdio); - check_argument("pty,id=test_id".to_string(), ChardevType::Pty); + check_argument( + "stdio,id=test_id".to_string(), + ChardevType::Stdio { + id: "test_id".to_string(), + }, + ); + check_argument( + "pty,id=test_id".to_string(), + ChardevType::Pty { + id: "test_id".to_string(), + }, + ); check_argument( "file,id=test_id,path=/some/file".to_string(), - ChardevType::File("/some/file".to_string()), + ChardevType::File { + id: "test_id".to_string(), + path: "/some/file".to_string(), + }, ); let extra_params = [ @@ -767,17 +503,22 @@ mod tests { for (param, server_state, nowait_state) in extra_params { check_argument( format!("{}{}", "socket,id=test_id,path=/path/to/socket", param), - ChardevType::UnixSocket { - path: "/path/to/socket".to_string(), + ChardevType::Socket { + id: "test_id".to_string(), + path: Some("/path/to/socket".to_string()), + host: "0.0.0.0".to_string(), + port: None, server: server_state, nowait: nowait_state, }, ); check_argument( format!("{}{}", "socket,id=test_id,port=9090", param), - ChardevType::TcpSocket { + ChardevType::Socket { + id: "test_id".to_string(), + path: None, host: "0.0.0.0".to_string(), - port: 9090, + port: Some(9090), server: server_state, nowait: nowait_state, }, @@ -787,9 +528,11 @@ mod tests { "{}{}", "socket,id=test_id,host=172.56.16.12,port=7070", param ), - ChardevType::TcpSocket { + ChardevType::Socket { + id: "test_id".to_string(), + path: None, host: "172.56.16.12".to_string(), - port: 7070, + port: Some(7070), server: server_state, nowait: nowait_state, }, diff --git a/machine_manager/src/config/demo_dev.rs b/machine_manager/src/config/demo_dev.rs deleted file mode 100644 index 10d21994e3c3c89e42e3ca7cc48449f66d975574..0000000000000000000000000000000000000000 --- a/machine_manager/src/config/demo_dev.rs +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (c) 2023 Huawei Technologies Co.,Ltd. All rights reserved. -// -// StratoVirt is licensed under Mulan PSL v2. -// You can use this software according to the terms and conditions of the Mulan -// PSL v2. -// You may obtain a copy of Mulan PSL v2 at: -// http://license.coscl.org.cn/MulanPSL2 -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. -// See the Mulan PSL v2 for more details. - -use anyhow::{bail, Result}; - -use super::{pci_args_check, CmdParser, VmConfig}; - -/// Config struct for `demo_dev`. -/// Contains demo_dev device's attr. -#[derive(Debug, Clone)] -pub struct DemoDevConfig { - pub id: String, - // Different device implementations can be configured based on this parameter - pub device_type: String, - pub bar_num: u8, - // Every bar has the same size just for simplification. - pub bar_size: u64, -} - -impl DemoDevConfig { - pub fn new() -> Self { - Self { - id: "".to_string(), - device_type: "".to_string(), - bar_num: 0, - bar_size: 0, - } - } -} - -impl Default for DemoDevConfig { - fn default() -> Self { - Self::new() - } -} - -pub fn parse_demo_dev(_vm_config: &mut VmConfig, args_str: String) -> Result { - let mut cmd_parser = CmdParser::new("demo-dev"); - cmd_parser - .push("") - .push("id") - .push("addr") - .push("device_type") - .push("bus") - .push("bar_num") - .push("bar_size"); - cmd_parser.parse(&args_str)?; - - pci_args_check(&cmd_parser)?; - - let mut demo_dev_cfg = DemoDevConfig::new(); - - if let Some(id) = cmd_parser.get_value::("id")? { - demo_dev_cfg.id = id; - } else { - bail!("No id configured for demo device"); - } - - if let Some(device_type) = cmd_parser.get_value::("device_type")? { - demo_dev_cfg.device_type = device_type; - } - - if let Some(bar_num) = cmd_parser.get_value::("bar_num")? { - demo_dev_cfg.bar_num = bar_num; - } - - // todo: support parsing hex num "0x**". It just supports decimal number now. - if let Some(bar_size) = cmd_parser.get_value::("bar_size")? { - demo_dev_cfg.bar_size = bar_size; - } - - Ok(demo_dev_cfg) -} - -#[cfg(test)] -mod tests { - use super::*; - #[test] - fn test_parse_demo_dev() { - let mut vm_config = VmConfig::default(); - let config_line = "-device pcie-demo-dev,bus=pcie.0,addr=4.0,id=test_0,device_type=demo-gpu,bar_num=3,bar_size=4096"; - let demo_cfg = parse_demo_dev(&mut vm_config, config_line.to_string()).unwrap(); - assert_eq!(demo_cfg.id, "test_0".to_string()); - assert_eq!(demo_cfg.device_type, "demo-gpu".to_string()); - assert_eq!(demo_cfg.bar_num, 3); - assert_eq!(demo_cfg.bar_size, 4096); - } -} diff --git a/machine_manager/src/config/devices.rs b/machine_manager/src/config/devices.rs index cf42739b46f044df8ef0474f1e1e10ff8b0747d2..e355b88f285739bbefee70fdf5b97885d7e600df 100644 --- a/machine_manager/src/config/devices.rs +++ b/machine_manager/src/config/devices.rs @@ -13,7 +13,7 @@ use anyhow::{Context, Result}; use regex::Regex; -use super::{CmdParser, VmConfig}; +use super::{get_class_type, VmConfig}; use crate::qmp::qmp_schema; impl VmConfig { @@ -117,7 +117,7 @@ impl VmConfig { } pub fn add_device(&mut self, device_config: &str) -> Result<()> { - let device_type = parse_device_type(device_config)?; + let device_type = get_class_type(device_config).with_context(|| "Missing driver field.")?; self.devices.push((device_type, device_config.to_string())); Ok(()) @@ -135,44 +135,3 @@ impl VmConfig { } } } - -pub fn parse_device_type(device_config: &str) -> Result { - let mut cmd_params = CmdParser::new("device"); - cmd_params.push(""); - cmd_params.get_parameters(device_config)?; - cmd_params - .get_value::("")? - .with_context(|| "Missing driver field.") -} - -pub fn parse_device_id(device_config: &str) -> Result { - let mut cmd_parser = CmdParser::new("device"); - cmd_parser.push("id"); - - cmd_parser.get_parameters(device_config)?; - if let Some(id) = cmd_parser.get_value::("id")? { - Ok(id) - } else { - Ok(String::new()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_device_id() { - let test_conf = "virtio-blk-device,drive=rootfs,id=blkid"; - let ret = parse_device_id(test_conf); - assert!(ret.is_ok()); - let id = ret.unwrap(); - assert_eq!("blkid", id); - - let test_conf = "virtio-blk-device,drive=rootfs"; - let ret = parse_device_id(test_conf); - assert!(ret.is_ok()); - let id = ret.unwrap(); - assert_eq!("", id); - } -} diff --git a/machine_manager/src/config/display.rs b/machine_manager/src/config/display.rs index 8f1f2e088d11d1089017b69c216ec96ae20ffc34..7369e4298794ee8fb10381a8f6451aaa04db4028 100644 --- a/machine_manager/src/config/display.rs +++ b/machine_manager/src/config/display.rs @@ -13,17 +13,15 @@ #[cfg(feature = "gtk")] use std::sync::Arc; +use anyhow::Result; #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] -use anyhow::Context; -use anyhow::{bail, Result}; +use anyhow::{bail, Context}; +use clap::{ArgAction, Parser}; use serde::{Deserialize, Serialize}; #[cfg(feature = "gtk")] use vmm_sys_util::eventfd::EventFd; -use crate::config::{CmdParser, ExBool, VmConfig}; - -#[cfg(all(target_env = "ohos", feature = "ohui_srv"))] -static DEFAULT_UI_PATH: &str = "/tmp/"; +use crate::config::{parse_bool, str_slip_to_clap, VmConfig}; /// Event fd related to power button in gtk. #[cfg(feature = "gtk")] @@ -41,88 +39,71 @@ pub struct UiContext { } #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -pub struct OhuiConfig { - /// Use OHUI. - pub ohui: bool, - /// Create the OHUI thread. - pub iothread: Option, - /// Confirm related files' path. - pub path: String, +fn get_dir_path(p: &str) -> Result { + if cfg!(debug_assertions) { + return Ok(p.to_string()); + } + + let path = std::fs::canonicalize(p) + .with_context(|| format!("Failed to get real directory path: {:?}", p))?; + if !path.exists() { + bail!( + "The defined directory {:?} path doesn't exist", + path.as_os_str() + ); + } + if !path.is_dir() { + bail!( + "The defined socks-path {:?} is not directory", + path.as_os_str() + ); + } + + Ok(path.to_str().unwrap().to_string()) } /// GTK and OHUI related configuration. -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +#[command(no_binary_name(true))] + pub struct DisplayConfig { - /// Create the GTK thread. - pub gtk: bool, + #[arg(long, alias = "classtype", value_parser = ["gtk", "ohui"])] + pub display_type: String, /// App name if configured. + #[arg(long)] pub app_name: Option, /// Keep the window fill the desktop. + #[arg(long, default_value = "off", action = ArgAction::Append, value_parser = parse_bool)] pub full_screen: bool, - /// Used for OHUI + /// Create the OHUI thread. + #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] + #[arg(long)] + pub iothread: Option, + /// Confirm socket path. Default socket path is "/tmp". #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] - pub ohui_config: OhuiConfig, + #[arg(long, alias = "socks-path", default_value = "/tmp/", value_parser = get_dir_path)] + pub sock_path: String, + /// Define the directory path for OHUI framebuffer and cursor. + #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] + #[arg(long, alias = "ui-path", default_value_if("display_type", "ohui", "/dev/shm/hwf/"), default_value = "/tmp/", value_parser = get_dir_path)] + pub ui_path: String, } #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] impl DisplayConfig { + pub fn get_sock_path(&self) -> String { + self.sock_path.clone() + } + pub fn get_ui_path(&self) -> String { - self.ohui_config.path.clone() + self.ui_path.clone() } } impl VmConfig { pub fn add_display(&mut self, vm_config: &str) -> Result<()> { - let mut cmd_parser = CmdParser::new("display"); - cmd_parser.push("").push("full-screen").push("app-name"); - #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] - cmd_parser.push("iothread").push("socks-path"); - cmd_parser.parse(vm_config)?; - let mut display_config = DisplayConfig::default(); - if let Some(str) = cmd_parser.get_value::("")? { - match str.as_str() { - "gtk" => display_config.gtk = true, - #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] - "ohui" => display_config.ohui_config.ohui = true, - _ => bail!("Unsupported device: {}", str), - } - } - if let Some(name) = cmd_parser.get_value::("app-name")? { - display_config.app_name = Some(name); - } - if let Some(default) = cmd_parser.get_value::("full-screen")? { - display_config.full_screen = default.into(); - } - - #[cfg(all(target_env = "ohos", feature = "ohui_srv"))] - if display_config.ohui_config.ohui { - if let Some(iothread) = cmd_parser.get_value::("iothread")? { - display_config.ohui_config.iothread = Some(iothread); - } - - if let Some(path) = cmd_parser.get_value::("socks-path")? { - let path = std::fs::canonicalize(path.clone()).with_context(|| { - format!("Failed to get real directory path: {:?}", path.clone()) - })?; - if !path.exists() { - bail!( - "The defined directory {:?} path doesn't exist", - path.as_os_str() - ); - } - if !path.is_dir() { - bail!( - "The defined socks-path {:?} is not directory", - path.as_os_str() - ); - } - display_config.ohui_config.path = path.to_str().unwrap().to_string(); - } else { - display_config.ohui_config.path = DEFAULT_UI_PATH.to_string(); - } - } - + let display_config = + DisplayConfig::try_parse_from(str_slip_to_clap(vm_config, true, false))?; self.display = Some(display_config); Ok(()) } @@ -141,29 +122,29 @@ mod tests { let config_line = "gtk"; assert!(vm_config.add_display(config_line).is_ok()); let display_config = vm_config.display.unwrap(); - assert_eq!(display_config.gtk, true); - assert_eq!(display_config.full_screen, false); + assert_eq!(display_config.display_type, "gtk"); + assert!(!display_config.full_screen); let mut vm_config = VmConfig::default(); let config_line = "gtk,full-screen=on"; assert!(vm_config.add_display(config_line).is_ok()); let display_config = vm_config.display.unwrap(); - assert_eq!(display_config.gtk, true); - assert_eq!(display_config.full_screen, true); + assert_eq!(display_config.display_type, "gtk"); + assert!(display_config.full_screen); let mut vm_config = VmConfig::default(); let config_line = "gtk,full-screen=off"; assert!(vm_config.add_display(config_line).is_ok()); let display_config = vm_config.display.unwrap(); - assert_eq!(display_config.gtk, true); - assert_eq!(display_config.full_screen, false); + assert_eq!(display_config.display_type, "gtk"); + assert!(!display_config.full_screen); let mut vm_config = VmConfig::default(); let config_line = "gtk,app-name=desktopappengine"; assert!(vm_config.add_display(config_line).is_ok()); let display_config = vm_config.display.unwrap(); - assert_eq!(display_config.gtk, true); - assert_eq!(display_config.full_screen, false); + assert_eq!(display_config.display_type, "gtk"); + assert!(!display_config.full_screen); assert_eq!( display_config.app_name, Some("desktopappengine".to_string()) diff --git a/machine_manager/src/config/drive.rs b/machine_manager/src/config/drive.rs index e88826a5342df292dc56c9f1f1740b547f2a4a70..a9a132eb2621b83f4f40ea588c16133f886fbcc3 100644 --- a/machine_manager/src/config/drive.rs +++ b/machine_manager/src/config/drive.rs @@ -10,31 +10,25 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. +use std::fmt::Display; use std::fs::{metadata, File}; use std::os::linux::fs::MetadataExt; use std::path::Path; use std::str::FromStr; +use std::sync::Arc; use anyhow::{anyhow, bail, Context, Result}; +use clap::{ArgAction, Parser}; use log::error; use serde::{Deserialize, Serialize}; -use super::{error::ConfigError, pci_args_check, M}; -use crate::config::{ - check_arg_too_long, get_chardev_socket_path, memory_unit_conversion, CmdParser, ConfigCheck, - ExBool, VmConfig, DEFAULT_VIRTQUEUE_SIZE, MAX_PATH_LENGTH, MAX_STRING_LENGTH, MAX_VIRTIO_QUEUE, -}; +use super::{error::ConfigError, parse_size, valid_id, valid_path}; +use crate::config::{parse_bool, str_slip_to_clap, ConfigCheck, VmConfig, MAX_STRING_LENGTH}; use util::aio::{aio_probe, AioEngine, WriteZeroesState}; -const MAX_SERIAL_NUM: usize = 20; const MAX_IOPS: u64 = 1_000_000; const MAX_UNIT_ID: usize = 2; -// Seg_max = queue_size - 2. So, size of each virtqueue for virtio-blk should be larger than 2. -const MIN_QUEUE_SIZE_BLK: u16 = 2; -// Max size of each virtqueue for virtio-blk. -const MAX_QUEUE_SIZE_BLK: u16 = 1024; - // L2 Cache max size is 32M. pub const MAX_L2_CACHE_SIZE: u64 = 32 * (1 << 20); // Refcount table cache max size is 32M. @@ -45,7 +39,7 @@ pub struct DriveFile { /// Drive id. pub id: String, /// The opened file. - pub file: File, + pub file: Arc, /// The num of drives share same file. pub count: u32, /// File path. @@ -60,29 +54,6 @@ pub struct DriveFile { pub buf_align: u32, } -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct BlkDevConfig { - pub id: String, - pub path_on_host: String, - pub read_only: bool, - pub direct: bool, - pub serial_num: Option, - pub iothread: Option, - pub iops: Option, - pub queues: u16, - pub boot_index: Option, - pub chardev: Option, - pub socket_path: Option, - pub aio: AioEngine, - pub queue_size: u16, - pub discard: bool, - pub write_zeroes: WriteZeroesState, - pub format: DiskFormat, - pub l2_cache_size: Option, - pub refcount_cache_size: Option, -} - #[derive(Debug, Clone)] pub struct BootIndexInfo { pub boot_index: u8, @@ -90,33 +61,9 @@ pub struct BootIndexInfo { pub dev_path: String, } -impl Default for BlkDevConfig { - fn default() -> Self { - BlkDevConfig { - id: "".to_string(), - path_on_host: "".to_string(), - read_only: false, - direct: true, - serial_num: None, - iothread: None, - iops: None, - queues: 1, - boot_index: None, - chardev: None, - socket_path: None, - aio: AioEngine::Native, - queue_size: DEFAULT_VIRTQUEUE_SIZE, - discard: false, - write_zeroes: WriteZeroesState::Off, - format: DiskFormat::Raw, - l2_cache_size: None, - refcount_cache_size: None, - } - } -} - -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] pub enum DiskFormat { + #[default] Raw, Qcow2, } @@ -133,53 +80,82 @@ impl FromStr for DiskFormat { } } -impl ToString for DiskFormat { - fn to_string(&self) -> String { - match *self { - DiskFormat::Raw => "raw".to_string(), - DiskFormat::Qcow2 => "qcow2".to_string(), +impl Display for DiskFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DiskFormat::Raw => write!(f, "raw"), + DiskFormat::Qcow2 => write!(f, "qcow2"), } } } -/// Config struct for `drive`. -/// Contains block device's attr. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] +fn valid_l2_cache_size(s: &str) -> Result { + let size = parse_size(s)?; + if size > MAX_L2_CACHE_SIZE { + return Err(anyhow!(ConfigError::IllegalValue( + "l2-cache-size".to_string(), + 0, + true, + MAX_L2_CACHE_SIZE, + true + ))); + } + Ok(size) +} + +fn valid_refcount_cache_size(s: &str) -> Result { + let size = parse_size(s)?; + if size > MAX_REFTABLE_CACHE_SIZE { + return Err(anyhow!(ConfigError::IllegalValue( + "refcount-cache-size".to_string(), + 0, + true, + MAX_REFTABLE_CACHE_SIZE, + true + ))); + } + Ok(size) +} + +/// Config struct for `drive`, including `block drive` and `pflash drive`. +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct DriveConfig { + #[arg(long, default_value = "")] pub id: String, + #[arg(long, alias = "if", default_value = "none", value_parser = ["none", "pflash"])] + pub drive_type: String, + #[arg(long, value_parser = clap::value_parser!(u8).range(..MAX_UNIT_ID as i64))] + pub unit: Option, + #[arg(long, alias = "file", value_parser = valid_path)] pub path_on_host: String, - pub read_only: bool, + #[arg(long, default_value = "off", value_parser = parse_bool, action = ArgAction::Append)] + pub readonly: bool, + #[arg(long, default_value = "true", value_parser = parse_bool, action = ArgAction::Append)] pub direct: bool, + #[arg(long, alias = "throttling.iops-total", value_parser = clap::value_parser!(u64).range(..=MAX_IOPS as u64))] pub iops: Option, + #[arg( + long, + default_value = "native", + default_value_if("direct", "false", "off"), + default_value_if("direct", "off", "off") + )] pub aio: AioEngine, + #[arg(long, default_value = "disk", value_parser = ["disk", "cdrom"])] pub media: String, + #[arg(long, default_value = "ignore", value_parser = parse_bool, action = ArgAction::Append)] pub discard: bool, + #[arg(long, alias = "detect-zeroes", default_value = "off")] pub write_zeroes: WriteZeroesState, + #[arg(long, default_value = "raw")] pub format: DiskFormat, + #[arg(long, value_parser = valid_l2_cache_size)] pub l2_cache_size: Option, + #[arg(long, value_parser = valid_refcount_cache_size)] pub refcount_cache_size: Option, } -impl Default for DriveConfig { - fn default() -> Self { - DriveConfig { - id: "".to_string(), - path_on_host: "".to_string(), - read_only: false, - direct: true, - iops: None, - aio: AioEngine::Native, - media: "disk".to_string(), - discard: false, - write_zeroes: WriteZeroesState::Off, - format: DiskFormat::Raw, - l2_cache_size: None, - refcount_cache_size: None, - } - } -} - impl DriveConfig { /// Check whether the drive file path on the host is valid. pub fn check_path(&self) -> Result<()> { @@ -222,383 +198,90 @@ impl DriveConfig { impl ConfigCheck for DriveConfig { fn check(&self) -> Result<()> { - check_arg_too_long(&self.id, "Drive id")?; + if self.drive_type == "pflash" { + self.unit.with_context(|| { + ConfigError::FieldIsMissing("unit".to_string(), "pflash".to_string()) + })?; + if self.format.to_string() != "raw" { + bail!("Only \'raw\' type of pflash is supported"); + } + } else { + if self.id.is_empty() { + return Err(anyhow!(ConfigError::FieldIsMissing( + "id".to_string(), + "blk".to_string() + ))); + } + valid_id(&self.id)?; + valid_path(&self.path_on_host)?; + if self.iops > Some(MAX_IOPS) { + return Err(anyhow!(ConfigError::IllegalValue( + "iops of block device".to_string(), + 0, + true, + MAX_IOPS, + true, + ))); + } + if self.l2_cache_size > Some(MAX_L2_CACHE_SIZE) { + return Err(anyhow!(ConfigError::IllegalValue( + "l2-cache-size".to_string(), + 0, + true, + MAX_L2_CACHE_SIZE, + true + ))); + } + if self.refcount_cache_size > Some(MAX_REFTABLE_CACHE_SIZE) { + return Err(anyhow!(ConfigError::IllegalValue( + "refcount-cache-size".to_string(), + 0, + true, + MAX_REFTABLE_CACHE_SIZE, + true + ))); + } - if self.path_on_host.len() > MAX_PATH_LENGTH { - return Err(anyhow!(ConfigError::StringLengthTooLong( - "Drive device path".to_string(), - MAX_PATH_LENGTH, - ))); - } - if self.iops > Some(MAX_IOPS) { - return Err(anyhow!(ConfigError::IllegalValue( - "iops of block device".to_string(), - 0, - true, - MAX_IOPS, - true, - ))); - } - if self.aio != AioEngine::Off { - if self.aio == AioEngine::Native && !self.direct { + if self.aio != AioEngine::Off { + if self.aio == AioEngine::Native && !self.direct { + return Err(anyhow!(ConfigError::InvalidParam( + "aio".to_string(), + "native aio type should be used with \"direct\" on".to_string(), + ))); + } + aio_probe(self.aio)?; + } else if self.direct { return Err(anyhow!(ConfigError::InvalidParam( "aio".to_string(), - "native aio type should be used with \"direct\" on".to_string(), + "low performance expected when use sync io with \"direct\" on".to_string(), ))); } - aio_probe(self.aio)?; - } else if self.direct { - return Err(anyhow!(ConfigError::InvalidParam( - "aio".to_string(), - "low performance expected when use sync io with \"direct\" on".to_string(), - ))); - } - - if !["disk", "cdrom"].contains(&self.media.as_str()) { - return Err(anyhow!(ConfigError::InvalidParam( - "media".to_string(), - "media should be \"disk\" or \"cdrom\"".to_string(), - ))); - } - - if self.l2_cache_size > Some(MAX_L2_CACHE_SIZE) { - return Err(anyhow!(ConfigError::IllegalValue( - "l2-cache-size".to_string(), - 0, - true, - MAX_L2_CACHE_SIZE, - true - ))); - } - if self.refcount_cache_size > Some(MAX_REFTABLE_CACHE_SIZE) { - return Err(anyhow!(ConfigError::IllegalValue( - "refcount-cache-size".to_string(), - 0, - true, - MAX_REFTABLE_CACHE_SIZE, - true - ))); } - Ok(()) - } -} - -impl ConfigCheck for BlkDevConfig { - fn check(&self) -> Result<()> { - check_arg_too_long(&self.id, "drive device id")?; - if self.serial_num.is_some() && self.serial_num.as_ref().unwrap().len() > MAX_SERIAL_NUM { - return Err(anyhow!(ConfigError::StringLengthTooLong( - "drive serial number".to_string(), - MAX_SERIAL_NUM, - ))); - } - - if self.iothread.is_some() && self.iothread.as_ref().unwrap().len() > MAX_STRING_LENGTH { - return Err(anyhow!(ConfigError::StringLengthTooLong( - "iothread name".to_string(), - MAX_STRING_LENGTH, - ))); - } - - if self.queues < 1 || self.queues > MAX_VIRTIO_QUEUE as u16 { - return Err(anyhow!(ConfigError::IllegalValue( - "number queues of block device".to_string(), - 1, - true, - MAX_VIRTIO_QUEUE as u64, - true, - ))); - } - - if self.queue_size <= MIN_QUEUE_SIZE_BLK || self.queue_size > MAX_QUEUE_SIZE_BLK { - return Err(anyhow!(ConfigError::IllegalValue( - "queue size of block device".to_string(), - MIN_QUEUE_SIZE_BLK as u64, - false, - MAX_QUEUE_SIZE_BLK as u64, - true - ))); - } - - if self.queue_size & (self.queue_size - 1) != 0 { - bail!("Queue size should be power of 2!"); - } - - let fake_drive = DriveConfig { - path_on_host: self.path_on_host.clone(), - direct: self.direct, - iops: self.iops, - aio: self.aio, - ..Default::default() - }; - fake_drive.check()?; #[cfg(not(test))] - if self.chardev.is_none() { - fake_drive.check_path()?; - } - - Ok(()) - } -} - -fn parse_drive(cmd_parser: CmdParser) -> Result { - let mut drive = DriveConfig::default(); - if let Some(fmt) = cmd_parser.get_value::("format")? { - drive.format = fmt; - } - - drive.id = cmd_parser - .get_value::("id")? - .with_context(|| ConfigError::FieldIsMissing("id".to_string(), "blk".to_string()))?; - drive.path_on_host = cmd_parser - .get_value::("file")? - .with_context(|| ConfigError::FieldIsMissing("file".to_string(), "blk".to_string()))?; - - if let Some(read_only) = cmd_parser.get_value::("readonly")? { - drive.read_only = read_only.into(); - } - if let Some(direct) = cmd_parser.get_value::("direct")? { - drive.direct = direct.into(); - } - drive.iops = cmd_parser.get_value::("throttling.iops-total")?; - drive.aio = cmd_parser.get_value::("aio")?.unwrap_or({ - if drive.direct { - AioEngine::Native - } else { - AioEngine::Off - } - }); - drive.media = cmd_parser - .get_value::("media")? - .unwrap_or_else(|| "disk".to_string()); - if let Some(discard) = cmd_parser.get_value::("discard")? { - drive.discard = discard.into(); - } - drive.write_zeroes = cmd_parser - .get_value::("detect-zeroes")? - .unwrap_or(WriteZeroesState::Off); - - if let Some(l2_cache) = cmd_parser.get_value::("l2-cache-size")? { - let sz = memory_unit_conversion(&l2_cache, M) - .with_context(|| format!("Invalid l2 cache size: {}", l2_cache))?; - drive.l2_cache_size = Some(sz); - } - if let Some(rc_cache) = cmd_parser.get_value::("refcount-cache-size")? { - let sz = memory_unit_conversion(&rc_cache, M) - .with_context(|| format!("Invalid refcount cache size: {}", rc_cache))?; - drive.refcount_cache_size = Some(sz); - } - - drive.check()?; - #[cfg(not(test))] - drive.check_path()?; - Ok(drive) -} - -pub fn parse_blk( - vm_config: &mut VmConfig, - drive_config: &str, - queues_auto: Option, -) -> Result { - let mut cmd_parser = CmdParser::new("virtio-blk"); - cmd_parser - .push("") - .push("id") - .push("bus") - .push("addr") - .push("multifunction") - .push("drive") - .push("bootindex") - .push("serial") - .push("iothread") - .push("num-queues") - .push("queue-size"); - - cmd_parser.parse(drive_config)?; - - pci_args_check(&cmd_parser)?; - - let mut blkdevcfg = BlkDevConfig::default(); - if let Some(boot_index) = cmd_parser.get_value::("bootindex")? { - blkdevcfg.boot_index = Some(boot_index); - } - - let blkdrive = cmd_parser - .get_value::("drive")? - .with_context(|| ConfigError::FieldIsMissing("drive".to_string(), "blk".to_string()))?; - - if let Some(iothread) = cmd_parser.get_value::("iothread")? { - blkdevcfg.iothread = Some(iothread); - } - - if let Some(serial) = cmd_parser.get_value::("serial")? { - blkdevcfg.serial_num = Some(serial); - } - - blkdevcfg.id = cmd_parser - .get_value::("id")? - .with_context(|| "No id configured for blk device")?; - - if let Some(queues) = cmd_parser.get_value::("num-queues")? { - blkdevcfg.queues = queues; - } else if let Some(queues) = queues_auto { - blkdevcfg.queues = queues; - } - - if let Some(queue_size) = cmd_parser.get_value::("queue-size")? { - blkdevcfg.queue_size = queue_size; - } - - let drive_arg = &vm_config - .drives - .remove(&blkdrive) - .with_context(|| "No drive configured matched for blk device")?; - blkdevcfg.path_on_host = drive_arg.path_on_host.clone(); - blkdevcfg.read_only = drive_arg.read_only; - blkdevcfg.direct = drive_arg.direct; - blkdevcfg.iops = drive_arg.iops; - blkdevcfg.aio = drive_arg.aio; - blkdevcfg.discard = drive_arg.discard; - blkdevcfg.write_zeroes = drive_arg.write_zeroes; - blkdevcfg.format = drive_arg.format; - blkdevcfg.l2_cache_size = drive_arg.l2_cache_size; - blkdevcfg.refcount_cache_size = drive_arg.refcount_cache_size; - blkdevcfg.check()?; - Ok(blkdevcfg) -} - -pub fn parse_vhost_user_blk( - vm_config: &mut VmConfig, - drive_config: &str, - queues_auto: Option, -) -> Result { - let mut cmd_parser = CmdParser::new("vhost-user-blk-pci"); - cmd_parser - .push("") - .push("id") - .push("bus") - .push("addr") - .push("num-queues") - .push("chardev") - .push("queue-size") - .push("bootindex"); - - cmd_parser.parse(drive_config)?; - - pci_args_check(&cmd_parser)?; - - let mut blkdevcfg = BlkDevConfig::default(); - - if let Some(boot_index) = cmd_parser.get_value::("bootindex")? { - blkdevcfg.boot_index = Some(boot_index); - } - - blkdevcfg.chardev = cmd_parser - .get_value::("chardev")? - .map(Some) - .with_context(|| { - ConfigError::FieldIsMissing("chardev".to_string(), "vhost-user-blk-pci".to_string()) - })?; - - blkdevcfg.id = cmd_parser - .get_value::("id")? - .with_context(|| "No id configured for blk device")?; - - if let Some(queues) = cmd_parser.get_value::("num-queues")? { - blkdevcfg.queues = queues; - } else if let Some(queues) = queues_auto { - blkdevcfg.queues = queues; - } + self.check_path()?; - if let Some(size) = cmd_parser.get_value::("queue-size")? { - blkdevcfg.queue_size = size; - } - - if let Some(chardev) = &blkdevcfg.chardev { - blkdevcfg.socket_path = Some(get_chardev_socket_path(chardev, vm_config)?); - } - blkdevcfg.check()?; - Ok(blkdevcfg) -} - -/// Config struct for `pflash`. -/// Contains pflash device's attr. -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -#[serde(deny_unknown_fields)] -pub struct PFlashConfig { - pub path_on_host: String, - pub read_only: bool, - pub unit: usize, -} - -impl ConfigCheck for PFlashConfig { - fn check(&self) -> Result<()> { - if self.path_on_host.len() > MAX_PATH_LENGTH { - return Err(anyhow!(ConfigError::StringLengthTooLong( - "drive device path".to_string(), - MAX_PATH_LENGTH, - ))); - } - - if self.unit >= MAX_UNIT_ID { - return Err(anyhow!(ConfigError::UnitIdError( - "PFlash unit id".to_string(), - self.unit, - MAX_UNIT_ID - 1 - ))); - } Ok(()) } } impl VmConfig { - /// Add '-drive ...' drive config to `VmConfig`. - pub fn add_drive(&mut self, drive_config: &str) -> Result<()> { - let mut cmd_parser = CmdParser::new("drive"); - cmd_parser.push("if"); - - cmd_parser.get_parameters(drive_config)?; - let drive_type = cmd_parser - .get_value::("if")? - .unwrap_or_else(|| "none".to_string()); - match drive_type.as_str() { + /// Add '-drive ...' drive config to `VmConfig`, including `block drive` and `pflash drive`. + pub fn add_drive(&mut self, drive_config: &str) -> Result { + let drive_cfg = DriveConfig::try_parse_from(str_slip_to_clap(drive_config, false, false))?; + drive_cfg.check()?; + match drive_cfg.drive_type.as_str() { "none" => { - self.add_block_drive(drive_config)?; + self.add_drive_with_config(drive_cfg.clone())?; } "pflash" => { - self.add_pflash(drive_config)?; + self.add_flashdev(drive_cfg.clone())?; } _ => { - bail!("Unknow 'if' argument: {:?}", drive_type.as_str()); + bail!("Unknow 'if' argument: {:?}", &drive_cfg.drive_type); } } - Ok(()) - } - - /// Add block drive config to vm and return the added drive config. - pub fn add_block_drive(&mut self, block_config: &str) -> Result { - let mut cmd_parser = CmdParser::new("drive"); - cmd_parser - .push("file") - .push("id") - .push("readonly") - .push("direct") - .push("format") - .push("if") - .push("throttling.iops-total") - .push("aio") - .push("media") - .push("discard") - .push("detect-zeroes") - .push("format") - .push("l2-cache-size") - .push("refcount-cache-size"); - - cmd_parser.parse(block_config)?; - let drive_cfg = parse_drive(cmd_parser)?; - self.add_drive_with_config(drive_cfg.clone())?; Ok(drive_cfg) } @@ -609,11 +292,10 @@ impl VmConfig { /// * `drive_conf` - The drive config to be added to the vm. pub fn add_drive_with_config(&mut self, drive_conf: DriveConfig) -> Result<()> { let drive_id = drive_conf.id.clone(); - if self.drives.get(&drive_id).is_none() { - self.drives.insert(drive_id, drive_conf); - } else { + if self.drives.contains_key(&drive_id) { bail!("Drive {} has been added", drive_id); } + self.drives.insert(drive_id, drive_conf); Ok(()) } @@ -623,7 +305,7 @@ impl VmConfig { /// /// * `drive_id` - Drive id. pub fn del_drive_by_id(&mut self, drive_id: &str) -> Result { - if self.drives.get(drive_id).is_some() { + if self.drives.contains_key(drive_id) { Ok(self.drives.remove(drive_id).unwrap().path_on_host) } else { bail!("Drive {} not found", drive_id); @@ -631,13 +313,13 @@ impl VmConfig { } /// Add new flash device to `VmConfig`. - fn add_flashdev(&mut self, pflash: PFlashConfig) -> Result<()> { + fn add_flashdev(&mut self, pflash: DriveConfig) -> Result<()> { if self.pflashs.is_some() { for pf in self.pflashs.as_ref().unwrap() { - if pf.unit == pflash.unit { + if pf.unit.unwrap() == pflash.unit.unwrap() { return Err(anyhow!(ConfigError::IdRepeat( "pflash".to_string(), - pf.unit.to_string() + pf.unit.unwrap().to_string() ))); } } @@ -647,147 +329,38 @@ impl VmConfig { } Ok(()) } - - /// Add '-pflash ...' pflash config to `VmConfig`. - pub fn add_pflash(&mut self, pflash_config: &str) -> Result<()> { - let mut cmd_parser = CmdParser::new("pflash"); - cmd_parser - .push("if") - .push("file") - .push("format") - .push("readonly") - .push("unit"); - - cmd_parser.parse(pflash_config)?; - - let mut pflash = PFlashConfig::default(); - - if let Some(format) = cmd_parser.get_value::("format")? { - if format.ne("raw") { - bail!("Only \'raw\' type of pflash is supported"); - } - } - pflash.path_on_host = cmd_parser.get_value::("file")?.with_context(|| { - ConfigError::FieldIsMissing("file".to_string(), "pflash".to_string()) - })?; - - if let Some(read_only) = cmd_parser.get_value::("readonly")? { - pflash.read_only = read_only.into(); - } - - pflash.unit = cmd_parser.get_value::("unit")?.with_context(|| { - ConfigError::FieldIsMissing("unit".to_string(), "pflash".to_string()) - })? as usize; - - pflash.check()?; - self.add_flashdev(pflash) - } } #[cfg(test)] mod tests { use super::*; - use crate::config::get_pci_bdf; #[test] - fn test_drive_config_cmdline_parser() { - let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_drive( - "id=rootfs,file=/path/to/rootfs,readonly=off,direct=on,throttling.iops-total=200" - ) - .is_ok()); - let blk_cfg_res = parse_blk( - &mut vm_config, - "virtio-blk-device,drive=rootfs,id=rootfs,iothread=iothread1,serial=111111,num-queues=4", - None, - ); - assert!(blk_cfg_res.is_ok()); - let blk_device_config = blk_cfg_res.unwrap(); - assert_eq!(blk_device_config.id, "rootfs"); - assert_eq!(blk_device_config.path_on_host, "/path/to/rootfs"); - assert_eq!(blk_device_config.direct, true); - assert_eq!(blk_device_config.read_only, false); - assert_eq!(blk_device_config.serial_num, Some(String::from("111111"))); - assert_eq!(blk_device_config.queues, 4); - + fn test_pflash_drive_config_cmdline_parser() { + // Test1: Right. let mut vm_config = VmConfig::default(); assert!(vm_config - .add_drive("id=rootfs,file=/path/to/rootfs,readonly=off,direct=on") - .is_ok()); - let blk_cfg_res = parse_blk( - &mut vm_config, - "virtio-blk-device,drive=rootfs1,id=rootfs1,iothread=iothread1,iops=200,serial=111111", - None, - ); - assert!(blk_cfg_res.is_err()); // Can not find drive named "rootfs1". - } - - #[test] - fn test_pci_block_config_cmdline_parser() { - let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_drive("id=rootfs,file=/path/to/rootfs,readonly=off,direct=on") - .is_ok()); - let blk_cfg = "virtio-blk-pci,id=rootfs,bus=pcie.0,addr=0x1.0x2,drive=rootfs,serial=111111,num-queues=4"; - let blk_cfg_res = parse_blk(&mut vm_config, blk_cfg, None); - assert!(blk_cfg_res.is_ok()); - let drive_configs = blk_cfg_res.unwrap(); - assert_eq!(drive_configs.id, "rootfs"); - assert_eq!(drive_configs.path_on_host, "/path/to/rootfs"); - assert_eq!(drive_configs.direct, true); - assert_eq!(drive_configs.read_only, false); - assert_eq!(drive_configs.serial_num, Some(String::from("111111"))); - assert_eq!(drive_configs.queues, 4); - - let pci_bdf = get_pci_bdf(blk_cfg); - assert!(pci_bdf.is_ok()); - let pci = pci_bdf.unwrap(); - assert_eq!(pci.bus, "pcie.0".to_string()); - assert_eq!(pci.addr, (1, 2)); - - // drive "rootfs" has been removed. - let blk_cfg_res = parse_blk(&mut vm_config, blk_cfg, None); - assert!(blk_cfg_res.is_err()); - - let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_drive("id=rootfs,file=/path/to/rootfs,readonly=off,direct=on") - .is_ok()); - let blk_cfg = - "virtio-blk-pci,id=blk1,bus=pcie.0,addr=0x1.0x2,drive=rootfs,multifunction=on"; - assert!(parse_blk(&mut vm_config, blk_cfg, None).is_ok()); - } - - #[test] - fn test_pflash_config_cmdline_parser() { - let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_drive("if=pflash,readonly=on,file=flash0.fd,unit=0") + .add_drive("if=pflash,readonly=on,file=flash0.fd,unit=0,format=raw") .is_ok()); assert!(vm_config.pflashs.is_some()); let pflash = vm_config.pflashs.unwrap(); assert!(pflash.len() == 1); let pflash_cfg = &pflash[0]; - assert_eq!(pflash_cfg.unit, 0); + assert_eq!(pflash_cfg.unit.unwrap(), 0); assert_eq!(pflash_cfg.path_on_host, "flash0.fd".to_string()); - assert_eq!(pflash_cfg.read_only, true); + assert!(pflash_cfg.readonly); + // Test2: Change parameters sequence. let mut vm_config = VmConfig::default(); assert!(vm_config .add_drive("readonly=on,file=flash0.fd,unit=0,if=pflash") .is_ok()); - let mut vm_config = VmConfig::default(); assert!(vm_config .add_drive("readonly=on,if=pflash,file=flash0.fd,unit=0") .is_ok()); - let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_drive("if=pflash,readonly=on,file=flash0.fd,unit=2") - .is_err()); - + // Test3: Add duplicate pflash. let mut vm_config = VmConfig::default(); assert!(vm_config .add_drive("if=pflash,readonly=on,file=flash0.fd,unit=0") @@ -795,52 +368,103 @@ mod tests { assert!(vm_config .add_drive("if=pflash,file=flash1.fd,unit=1") .is_ok()); + assert!(vm_config + .add_drive("if=pflash,file=flash1.fd,unit=1") + .is_err()); assert!(vm_config.pflashs.is_some()); let pflash = vm_config.pflashs.unwrap(); assert!(pflash.len() == 2); let pflash_cfg = &pflash[0]; - assert_eq!(pflash_cfg.unit, 0); + assert_eq!(pflash_cfg.unit.unwrap(), 0); assert_eq!(pflash_cfg.path_on_host, "flash0.fd".to_string()); - assert_eq!(pflash_cfg.read_only, true); + assert!(pflash_cfg.readonly); let pflash_cfg = &pflash[1]; - assert_eq!(pflash_cfg.unit, 1); + assert_eq!(pflash_cfg.unit.unwrap(), 1); assert_eq!(pflash_cfg.path_on_host, "flash1.fd".to_string()); - assert_eq!(pflash_cfg.read_only, false); - } + assert!(!pflash_cfg.readonly); - #[test] - fn test_drive_config_check() { - let mut drive_conf = DriveConfig::default(); - for _ in 0..MAX_STRING_LENGTH { - drive_conf.id += "A"; - } - assert!(drive_conf.check().is_ok()); + // Test4: Illegal parameters unit/format. + let mut vm_config = VmConfig::default(); + assert!(vm_config + .add_drive("if=pflash,readonly=on,file=flash0.fd,unit=2") + .is_err()); + assert!(vm_config + .add_drive("if=pflash,readonly=on,file=flash0.fd,unit=0,format=qcow2") + .is_err()); - // Overflow - drive_conf.id += "A"; - assert!(drive_conf.check().is_err()); + // Test5: Missing parameters file/unit. + let mut vm_config = VmConfig::default(); + assert!(vm_config.add_drive("if=pflash,readonly=on,unit=2").is_err()); + assert!(vm_config + .add_drive("if=pflash,readonly=on,file=flash0.fd") + .is_err()); + } - let mut drive_conf = DriveConfig::default(); - for _ in 0..MAX_PATH_LENGTH { - drive_conf.path_on_host += "A"; - } - assert!(drive_conf.check().is_ok()); + #[test] + fn test_block_drive_config_cmdline_parser() { + // Test1: Right. + let mut vm_config = VmConfig::default(); + assert!(vm_config + .add_drive("id=rootfs,file=/path/to/rootfs,format=qcow2,readonly=off,direct=on,throttling.iops-total=200,discard=unmap,detect-zeroes=unmap") + .is_ok()); + assert!(vm_config.drives.len() == 1); + let drive_cfg = &vm_config.drives.remove("rootfs").unwrap(); + + assert_eq!(drive_cfg.id, "rootfs"); + assert_eq!(drive_cfg.path_on_host, "/path/to/rootfs"); + assert_eq!(drive_cfg.format.to_string(), "qcow2"); + assert!(!drive_cfg.readonly); + assert!(drive_cfg.direct); + assert_eq!(drive_cfg.iops.unwrap(), 200); + assert!(drive_cfg.discard); + assert_eq!( + drive_cfg.write_zeroes, + WriteZeroesState::from_str("unmap").unwrap() + ); - // Overflow - drive_conf.path_on_host += "A"; - assert!(drive_conf.check().is_err()); + // Test2: Change parameters sequence. + let mut vm_config = VmConfig::default(); + assert!(vm_config + .add_drive("throttling.iops-total=200,file=/path/to/rootfs,format=qcow2,id=rootfs,readonly=off,direct=on,discard=unmap,detect-zeroes=unmap") + .is_ok()); - let mut drive_conf = DriveConfig::default(); - drive_conf.iops = Some(MAX_IOPS); - assert!(drive_conf.check().is_ok()); + // Test3: Add duplicate block drive config. + let mut vm_config = VmConfig::default(); + assert!(vm_config + .add_drive("id=rootfs,file=/path/to/rootfs,format=qcow2,readonly=off,direct=on") + .is_ok()); + assert!(vm_config + .add_drive("id=rootfs,file=/path/to/rootfs,format=qcow2,readonly=off,direct=on") + .is_err()); + let drive_cfg = &vm_config.drives.remove("rootfs"); + assert!(drive_cfg.is_some()); - let mut drive_conf = DriveConfig::default(); - drive_conf.iops = None; - assert!(drive_conf.check().is_ok()); + // Test4: Illegal parameters. + let mut vm_config = VmConfig::default(); + assert!(vm_config + .add_drive("id=rootfs,file=/path/to/rootfs,format=vhdx") + .is_err()); + assert!(vm_config + .add_drive("id=rootfs,if=illegal,file=/path/to/rootfs,format=vhdx") + .is_err()); + assert!(vm_config + .add_drive("id=rootfs,file=/path/to/rootfs,format=raw,throttling.iops-total=1000001") + .is_err()); + assert!(vm_config + .add_drive("id=rootfs,file=/path/to/rootfs,format=raw,media=illegal") + .is_err()); + assert!(vm_config + .add_drive("id=rootfs,file=/path/to/rootfs,format=raw,detect-zeroes=illegal") + .is_err()); - // Overflow - drive_conf.iops = Some(MAX_IOPS + 1); - assert!(drive_conf.check().is_err()); + // Test5: Missing parameters id/file. + let mut vm_config = VmConfig::default(); + assert!(vm_config + .add_drive("file=/path/to/rootfs,format=qcow2,readonly=off,direct=on,throttling.iops-total=200") + .is_err()); + assert!(vm_config + .add_drive("id=rootfs,format=qcow2,readonly=off,direct=on,throttling.iops-total=200") + .is_err()); } #[test] @@ -879,7 +503,7 @@ mod tests { let mut drive_conf = DriveConfig::default(); drive_conf.id = String::from(*id); assert!(vm_config.drives.get(*id).is_some()); - assert!(vm_config.del_drive_by_id(*id).is_ok()); + assert!(vm_config.del_drive_by_id(id).is_ok()); assert!(vm_config.drives.get(*id).is_none()); } } @@ -888,47 +512,47 @@ mod tests { fn test_drive_config_discard() { let mut vm_config = VmConfig::default(); let drive_conf = vm_config - .add_block_drive("id=rootfs,file=/path/to/rootfs,discard=ignore") + .add_drive("id=rootfs,file=/path/to/rootfs,discard=ignore") .unwrap(); - assert_eq!(drive_conf.discard, false); + assert!(!drive_conf.discard); let mut vm_config = VmConfig::default(); let drive_conf = vm_config - .add_block_drive("id=rootfs,file=/path/to/rootfs,discard=unmap") + .add_drive("id=rootfs,file=/path/to/rootfs,discard=unmap") .unwrap(); - assert_eq!(drive_conf.discard, true); + assert!(drive_conf.discard); let mut vm_config = VmConfig::default(); let ret = vm_config - .add_block_drive("id=rootfs,file=/path/to/rootfs,discard=invalid") + .add_drive("id=rootfs,file=/path/to/rootfs,discard=invalid") .is_err(); - assert_eq!(ret, true); + assert!(ret); } #[test] fn test_drive_config_write_zeroes() { let mut vm_config = VmConfig::default(); let drive_conf = vm_config - .add_block_drive("id=rootfs,file=/path/to/rootfs,detect-zeroes=off") + .add_drive("id=rootfs,file=/path/to/rootfs,detect-zeroes=off") .unwrap(); assert_eq!(drive_conf.write_zeroes, WriteZeroesState::Off); let mut vm_config = VmConfig::default(); let drive_conf = vm_config - .add_block_drive("id=rootfs,file=/path/to/rootfs,detect-zeroes=on") + .add_drive("id=rootfs,file=/path/to/rootfs,detect-zeroes=on") .unwrap(); assert_eq!(drive_conf.write_zeroes, WriteZeroesState::On); let mut vm_config = VmConfig::default(); let drive_conf = vm_config - .add_block_drive("id=rootfs,file=/path/to/rootfs,detect-zeroes=unmap") + .add_drive("id=rootfs,file=/path/to/rootfs,detect-zeroes=unmap") .unwrap(); assert_eq!(drive_conf.write_zeroes, WriteZeroesState::Unmap); let mut vm_config = VmConfig::default(); let ret = vm_config - .add_block_drive("id=rootfs,file=/path/to/rootfs,detect-zeroes=invalid") + .add_drive("id=rootfs,file=/path/to/rootfs,detect-zeroes=invalid") .is_err(); - assert_eq!(ret, true); + assert!(ret); } } diff --git a/machine_manager/src/config/fs.rs b/machine_manager/src/config/fs.rs deleted file mode 100644 index e1a16ab359e5ab664dedd3facae843f18abf76ae..0000000000000000000000000000000000000000 --- a/machine_manager/src/config/fs.rs +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright (c) 2022 Huawei Technologies Co.,Ltd. All rights reserved. -// -// StratoVirt is licensed under Mulan PSL v2. -// You can use this software according to the terms and conditions of the Mulan -// PSL v2. -// You may obtain a copy of Mulan PSL v2 at: -// http://license.coscl.org.cn/MulanPSL2 -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. -// See the Mulan PSL v2 for more details. - -use anyhow::{anyhow, bail, Context, Result}; - -use super::error::ConfigError; -use crate::config::{ - pci_args_check, ChardevType, CmdParser, ConfigCheck, VmConfig, MAX_SOCK_PATH_LENGTH, - MAX_STRING_LENGTH, MAX_TAG_LENGTH, -}; - -/// Config struct for `fs`. -/// Contains fs device's attr. -#[derive(Debug, Clone)] -pub struct FsConfig { - /// Device tag. - pub tag: String, - /// Device id. - pub id: String, - /// Char device sock path. - pub sock: String, -} - -impl Default for FsConfig { - fn default() -> Self { - FsConfig { - tag: "".to_string(), - id: "".to_string(), - sock: "".to_string(), - } - } -} - -impl ConfigCheck for FsConfig { - fn check(&self) -> Result<()> { - if self.tag.len() >= MAX_TAG_LENGTH { - return Err(anyhow!(ConfigError::StringLengthTooLong( - "fs device tag".to_string(), - MAX_TAG_LENGTH - 1, - ))); - } - - if self.id.len() >= MAX_STRING_LENGTH { - return Err(anyhow!(ConfigError::StringLengthTooLong( - "fs device id".to_string(), - MAX_STRING_LENGTH - 1, - ))); - } - - if self.sock.len() > MAX_SOCK_PATH_LENGTH { - return Err(anyhow!(ConfigError::StringLengthTooLong( - "fs sock path".to_string(), - MAX_SOCK_PATH_LENGTH, - ))); - } - - Ok(()) - } -} - -pub fn parse_fs(vm_config: &mut VmConfig, fs_config: &str) -> Result { - let mut cmd_parser = CmdParser::new("fs"); - cmd_parser - .push("") - .push("tag") - .push("id") - .push("chardev") - .push("bus") - .push("addr") - .push("multifunction"); - cmd_parser.parse(fs_config)?; - pci_args_check(&cmd_parser)?; - - let mut fs_cfg = FsConfig { - tag: cmd_parser.get_value::("tag")?.with_context(|| { - ConfigError::FieldIsMissing("tag".to_string(), "virtio-fs".to_string()) - })?, - id: cmd_parser.get_value::("id")?.with_context(|| { - ConfigError::FieldIsMissing("id".to_string(), "virtio-fs".to_string()) - })?, - ..Default::default() - }; - - if let Some(name) = cmd_parser.get_value::("chardev")? { - if let Some(char_dev) = vm_config.chardev.remove(&name) { - match &char_dev.backend { - ChardevType::UnixSocket { path, .. } => { - fs_cfg.sock = path.clone(); - } - _ => { - bail!("Chardev {:?} backend should be unix-socket type.", &name); - } - } - } else { - bail!("Chardev {:?} not found or is in use", &name); - } - } else { - return Err(anyhow!(ConfigError::FieldIsMissing( - "chardev".to_string(), - "virtio-fs".to_string() - ))); - } - fs_cfg.check()?; - - Ok(fs_cfg) -} diff --git a/machine_manager/src/config/gpu.rs b/machine_manager/src/config/gpu.rs deleted file mode 100644 index 56ab2842e9d6f28b80196cceb4a892b3685ecbc0..0000000000000000000000000000000000000000 --- a/machine_manager/src/config/gpu.rs +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright (c) 2022 Huawei Technologies Co.,Ltd. All rights reserved. -// -// StratoVirt is licensed under Mulan PSL v2. -// You can use this software according to the terms and conditions of the Mulan -// PSL v2. -// You may obtain a copy of Mulan PSL v2 at: -// http://license.coscl.org.cn/MulanPSL2 -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. -// See the Mulan PSL v2 for more details. - -use anyhow::{anyhow, Result}; -use log::warn; - -use super::{error::ConfigError, M}; -use crate::config::{check_arg_too_long, CmdParser, ConfigCheck}; - -/// The maximum number of outputs. -pub const VIRTIO_GPU_MAX_OUTPUTS: usize = 16; - -pub const VIRTIO_GPU_MAX_HOSTMEM: u64 = 256 * M; - -/// The bar0 size of enable_bar0 features -pub const VIRTIO_GPU_ENABLE_BAR0_SIZE: u64 = 64 * M; - -#[derive(Clone, Debug)] -pub struct GpuDevConfig { - pub id: String, - pub max_outputs: u32, - pub edid: bool, - pub xres: u32, - pub yres: u32, - pub max_hostmem: u64, - pub enable_bar0: bool, -} - -impl Default for GpuDevConfig { - fn default() -> Self { - GpuDevConfig { - id: "".to_string(), - max_outputs: 1, - edid: true, - xres: 1024, - yres: 768, - max_hostmem: VIRTIO_GPU_MAX_HOSTMEM, - enable_bar0: false, - } - } -} - -impl ConfigCheck for GpuDevConfig { - fn check(&self) -> Result<()> { - check_arg_too_long(&self.id, "id")?; - if self.max_outputs > VIRTIO_GPU_MAX_OUTPUTS as u32 || self.max_outputs == 0 { - return Err(anyhow!(ConfigError::IllegalValue( - "max_outputs".to_string(), - 0, - false, - VIRTIO_GPU_MAX_OUTPUTS as u64, - true - ))); - } - - if self.max_hostmem == 0 { - return Err(anyhow!(ConfigError::IllegalValueUnilateral( - "max_hostmem".to_string(), - true, - false, - 0 - ))); - } - - if self.max_hostmem < VIRTIO_GPU_MAX_HOSTMEM { - warn!( - "max_hostmem should >= {}, allocating less than it may cause \ - the GPU to fail to start or refresh.", - VIRTIO_GPU_MAX_HOSTMEM - ); - } - - Ok(()) - } -} - -pub fn parse_gpu(gpu_config: &str) -> Result { - let mut cmd_parser = CmdParser::new("virtio-gpu-pci"); - cmd_parser - .push("") - .push("id") - .push("max_outputs") - .push("edid") - .push("xres") - .push("yres") - .push("max_hostmem") - .push("bus") - .push("addr") - .push("enable_bar0"); - cmd_parser.parse(gpu_config)?; - - let mut gpu_cfg: GpuDevConfig = GpuDevConfig::default(); - if let Some(id) = cmd_parser.get_value::("id")? { - gpu_cfg.id = id; - } - if let Some(max_outputs) = cmd_parser.get_value::("max_outputs")? { - gpu_cfg.max_outputs = max_outputs; - } - if let Some(edid) = cmd_parser.get_value::("edid")? { - gpu_cfg.edid = edid; - } - if let Some(xres) = cmd_parser.get_value::("xres")? { - gpu_cfg.xres = xres; - } - if let Some(yres) = cmd_parser.get_value::("yres")? { - gpu_cfg.yres = yres; - } - if let Some(max_hostmem) = cmd_parser.get_value::("max_hostmem")? { - gpu_cfg.max_hostmem = max_hostmem; - } - if let Some(enable_bar0) = cmd_parser.get_value::("enable_bar0")? { - gpu_cfg.enable_bar0 = enable_bar0; - } - gpu_cfg.check()?; - - Ok(gpu_cfg) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_pci_gpu_config_cmdline_parser() { - let max_hostmem = VIRTIO_GPU_MAX_HOSTMEM + 1; - let gpu_cfg_cmdline = format!( - "{}{}", - "virtio-gpu-pci,id=gpu_1,bus=pcie.0,addr=0x4.0x0,\ - max_outputs=1,edid=true,xres=1024,yres=768,max_hostmem=", - max_hostmem.to_string() - ); - let gpu_cfg_ = parse_gpu(&gpu_cfg_cmdline); - assert!(gpu_cfg_.is_ok()); - let gpu_cfg = gpu_cfg_.unwrap(); - assert_eq!(gpu_cfg.id, "gpu_1"); - assert_eq!(gpu_cfg.max_outputs, 1); - assert_eq!(gpu_cfg.edid, true); - assert_eq!(gpu_cfg.xres, 1024); - assert_eq!(gpu_cfg.yres, 768); - assert_eq!(gpu_cfg.max_hostmem, max_hostmem); - - // max_outputs is illegal - let gpu_cfg_cmdline = format!( - "{}{}", - "virtio-gpu-pci,id=gpu_1,bus=pcie.0,addr=0x4.0x0,\ - max_outputs=17,edid=true,xres=1024,yres=768,max_hostmem=", - max_hostmem.to_string() - ); - let gpu_cfg_ = parse_gpu(&gpu_cfg_cmdline); - assert!(gpu_cfg_.is_err()); - - let gpu_cfg_cmdline = format!( - "{}{}", - "virtio-gpu-pci,id=gpu_1,bus=pcie.0,addr=0x4.0x0,\ - max_outputs=0,edid=true,xres=1024,yres=768,max_hostmem=", - max_hostmem.to_string() - ); - let gpu_cfg_ = parse_gpu(&gpu_cfg_cmdline); - assert!(gpu_cfg_.is_err()); - - // max_hostmem is illegal - let gpu_cfg_cmdline = "virtio-gpu-pci,id=gpu_1,bus=pcie.0,addr=0x4.0x0,\ - max_outputs=1,edid=true,xres=1024,yres=768,max_hostmem=0"; - let gpu_cfg_ = parse_gpu(&gpu_cfg_cmdline); - assert!(gpu_cfg_.is_err()); - } -} diff --git a/machine_manager/src/config/iothread.rs b/machine_manager/src/config/iothread.rs index ac3a0a9e5df3445de314e4ce22beeb97fec901f7..029d6583cc6f5df53858898baa92297c3a6c44ce 100644 --- a/machine_manager/src/config/iothread.rs +++ b/machine_manager/src/config/iothread.rs @@ -11,37 +11,29 @@ // See the Mulan PSL v2 for more details. use anyhow::{anyhow, Result}; +use clap::Parser; use serde::{Deserialize, Serialize}; -use super::error::ConfigError; -use crate::config::{check_arg_too_long, CmdParser, ConfigCheck, VmConfig}; +use super::{error::ConfigError, str_slip_to_clap, valid_id}; +use crate::config::VmConfig; const MAX_IOTHREAD_NUM: usize = 8; /// Config structure for iothread. -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct IothreadConfig { + #[arg(long, value_parser = ["iothread"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] pub id: String, } -impl ConfigCheck for IothreadConfig { - fn check(&self) -> Result<()> { - check_arg_too_long(&self.id, "iothread id") - } -} - impl VmConfig { /// Add new iothread device to `VmConfig`. pub fn add_iothread(&mut self, iothread_config: &str) -> Result<()> { - let mut cmd_parser = CmdParser::new("iothread"); - cmd_parser.push("").push("id"); - cmd_parser.parse(iothread_config)?; - - let mut iothread = IothreadConfig::default(); - if let Some(id) = cmd_parser.get_value::("id")? { - iothread.id = id; - } - iothread.check()?; + let iothread = + IothreadConfig::try_parse_from(str_slip_to_clap(iothread_config, true, false))?; if self.iothreads.is_some() { if self.iothreads.as_ref().unwrap().len() >= MAX_IOTHREAD_NUM { diff --git a/machine_manager/src/config/machine_config.rs b/machine_manager/src/config/machine_config.rs index 490ee9070ee1a16a82406ad1a2114480d375bb4b..3d277ba155d3c22e708aaa0d44a60c05a19ce36a 100644 --- a/machine_manager/src/config/machine_config.rs +++ b/machine_manager/src/config/machine_config.rs @@ -13,13 +13,14 @@ use std::str::FromStr; use anyhow::{anyhow, bail, Context, Result}; +use clap::{ArgAction, Parser}; use serde::{Deserialize, Serialize}; use super::error::ConfigError; -use crate::config::{ - check_arg_too_long, check_path_too_long, CmdParser, ConfigCheck, ExBool, IntegerList, VmConfig, - MAX_NODES, +use super::{ + get_value_of_parameter, parse_bool, parse_size, str_slip_to_clap, valid_id, valid_path, }; +use crate::config::{ConfigCheck, IntegerList, VmConfig, MAX_NODES}; use crate::machine::HypervisorType; const DEFAULT_CPUS: u8 = 1; @@ -30,8 +31,8 @@ const DEFAULT_CLUSTERS: u8 = 1; const DEFAULT_SOCKETS: u8 = 1; const DEFAULT_MAX_CPUS: u8 = 1; const DEFAULT_MEMSIZE: u64 = 256; -const MAX_NR_CPUS: u64 = 254; -const MIN_NR_CPUS: u64 = 1; +const MAX_NR_CPUS: u8 = 254; +const MIN_NR_CPUS: u8 = 1; const MAX_MEMSIZE: u64 = 549_755_813_888; const MIN_MEMSIZE: u64 = 134_217_728; pub const K: u64 = 1024; @@ -46,7 +47,7 @@ pub enum MachineType { } impl FromStr for MachineType { - type Err = (); + type Err = anyhow::Error; fn from_str(s: &str) -> std::result::Result { match s.to_lowercase().as_str() { @@ -56,7 +57,7 @@ impl FromStr for MachineType { "q35" => Ok(MachineType::StandardVm), #[cfg(target_arch = "aarch64")] "virt" => Ok(MachineType::StandardVm), - _ => Err(()), + _ => Err(anyhow!("Invalid machine type.")), } } } @@ -83,32 +84,37 @@ impl From for HostMemPolicy { } } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Parser, Clone, Debug, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct MemZoneConfig { + #[arg(long, alias = "classtype", value_parser = ["memory-backend-ram", "memory-backend-file", "memory-backend-memfd"])] + pub mem_type: String, + #[arg(long, value_parser = valid_id)] pub id: String, + #[arg(long, value_parser = parse_size)] pub size: u64, - pub host_numa_nodes: Option>, + // Note: + // `Clap` will incorrectly assume that we're trying to get multiple arguments since we got + // a `Vec` from parser function `get_host_nodes`. Generally, we should use `Box` or a `new struct type` + // to encapsulate this `Vec`. And fortunately, there's a trick (using full qualified path of Vec) + // to avoid the new type wrapper. See: github.com/clap-rs/clap/issues/4626. + #[arg(long, alias = "host-nodes", value_parser = get_host_nodes)] + pub host_numa_nodes: Option<::std::vec::Vec>, + #[arg(long, default_value = "default", value_parser=["default", "preferred", "bind", "interleave"])] pub policy: String, + #[arg(long, value_parser = valid_path)] pub mem_path: Option, + #[arg(long, default_value = "true", value_parser = parse_bool, action = ArgAction::Append)] pub dump_guest_core: bool, + #[arg(long, default_value = "off", value_parser = parse_bool, action = ArgAction::Append)] pub share: bool, + #[arg(long, alias = "mem-prealloc", default_value = "false", value_parser = parse_bool, action = ArgAction::Append)] pub prealloc: bool, - pub memfd: bool, } -impl Default for MemZoneConfig { - fn default() -> Self { - MemZoneConfig { - id: String::new(), - size: 0, - host_numa_nodes: None, - policy: String::from("bind"), - mem_path: None, - dump_guest_core: true, - share: false, - prealloc: false, - memfd: false, - } +impl MemZoneConfig { + pub fn memfd(&self) -> bool { + self.mem_type.eq("memory-backend-memfd") } } @@ -136,9 +142,14 @@ impl Default for MachineMemConfig { } } -#[derive(Clone, Debug, Serialize, Deserialize, Default)] +#[derive(Parser, Clone, Debug, Serialize, Deserialize, Default)] +#[command(no_binary_name(true))] pub struct CpuConfig { + #[arg(long, alias = "classtype", value_parser = ["host"])] + pub family: String, + #[arg(long, default_value = "off")] pub pmu: PmuConfig, + #[arg(long, default_value = "off")] pub sve: SveConfig, } @@ -149,6 +160,20 @@ pub enum PmuConfig { Off, } +impl FromStr for PmuConfig { + type Err = anyhow::Error; + + fn from_str(s: &str) -> std::result::Result { + match s { + "on" => Ok(PmuConfig::On), + "off" => Ok(PmuConfig::Off), + _ => Err(anyhow!( + "Invalid PMU option,must be one of \'on\" or \"off\"." + )), + } + } +} + #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Default)] pub enum SveConfig { On, @@ -156,6 +181,20 @@ pub enum SveConfig { Off, } +impl FromStr for SveConfig { + type Err = anyhow::Error; + + fn from_str(s: &str) -> std::result::Result { + match s { + "on" => Ok(SveConfig::On), + "off" => Ok(SveConfig::Off), + _ => Err(anyhow!( + "Invalid SVE option, must be one of \"on\" or \"off\"." + )), + } + } +} + #[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq, Eq, Default)] pub enum ShutdownAction { #[default] @@ -214,228 +253,203 @@ impl ConfigCheck for MachineConfig { } } -impl VmConfig { - /// Add argument `name` to `VmConfig`. - /// - /// # Arguments - /// - /// * `name` - The name `String` added to `VmConfig`. - pub fn add_machine(&mut self, mach_config: &str) -> Result<()> { - let mut cmd_parser = CmdParser::new("machine"); - cmd_parser - .push("") - .push("type") - .push("accel") - .push("usb") - .push("dump-guest-core") - .push("mem-share"); - #[cfg(target_arch = "aarch64")] - cmd_parser.push("gic-version"); - cmd_parser.parse(mach_config)?; +#[derive(Parser)] +#[command(no_binary_name(true))] +struct AccelConfig { + #[arg(long, alias = "classtype")] + hypervisor: HypervisorType, +} - #[cfg(target_arch = "aarch64")] - if let Some(gic_version) = cmd_parser.get_value::("gic-version")? { - if gic_version != 3 { - bail!("Unsupported gic version, only gicv3 is supported"); +#[derive(Parser)] +#[command(no_binary_name(true))] +struct MemSizeConfig { + #[arg(long, alias = "classtype", value_parser = parse_size)] + size: u64, +} + +#[derive(Parser)] +#[command(no_binary_name(true))] +struct MachineCmdConfig { + #[arg(long, aliases = ["classtype", "type"])] + mach_type: MachineType, + #[arg(long, default_value = "on", action = ArgAction::Append, value_parser = parse_bool)] + dump_guest_core: bool, + #[arg(long, default_value = "off", action = ArgAction::Append, value_parser = parse_bool)] + mem_share: bool, + #[arg(long, default_value = "kvm")] + accel: HypervisorType, + // The "usb" member is added for compatibility with libvirt and is currently not in use. + // It only supports configuration as "off". Currently, a `String` type is used to verify incoming values. + // When it will be used, it needs to be changed to a `bool` type. + #[arg(long, default_value = "off", value_parser = ["off"])] + usb: String, + #[cfg(target_arch = "aarch64")] + #[arg(long, default_value = "3", value_parser = clap::value_parser!(u8).range(3..=3))] + gic_version: u8, +} + +#[derive(Parser)] +#[command(no_binary_name(true))] +struct SmpConfig { + #[arg(long, alias = "classtype", value_parser = clap::value_parser!(u8).range(i64::from(MIN_NR_CPUS)..=i64::from(MAX_NR_CPUS)))] + cpus: u8, + #[arg(long, default_value = "0")] + maxcpus: u8, + #[arg(long, default_value = "0", value_parser = clap::value_parser!(u8).range(..i64::from(u8::MAX)))] + sockets: u8, + #[arg(long, default_value = "1", value_parser = clap::value_parser!(u8).range(1..i64::from(u8::MAX)))] + dies: u8, + #[arg(long, default_value = "1", value_parser = clap::value_parser!(u8).range(1..i64::from(u8::MAX)))] + clusters: u8, + #[arg(long, default_value = "0", value_parser = clap::value_parser!(u8).range(..i64::from(u8::MAX)))] + cores: u8, + #[arg(long, default_value = "0", value_parser = clap::value_parser!(u8).range(..i64::from(u8::MAX)))] + threads: u8, +} + +impl SmpConfig { + fn auto_adjust_topology(&mut self) -> Result<()> { + let mut max_cpus = self.maxcpus; + let mut sockets = self.sockets; + let mut cores = self.cores; + let mut threads = self.threads; + + if max_cpus == 0 { + let mut tmp_max = sockets + .checked_mul(self.dies) + .with_context(|| "Illegal smp config")?; + tmp_max = tmp_max + .checked_mul(self.clusters) + .with_context(|| "Illegal smp config")?; + tmp_max = tmp_max + .checked_mul(cores) + .with_context(|| "Illegal smp config")?; + tmp_max = tmp_max + .checked_mul(threads) + .with_context(|| "Illegal smp config")?; + + if tmp_max > 0 { + max_cpus = tmp_max; + } else { + max_cpus = self.cpus; } } - if let Some(accel) = cmd_parser.get_value::("accel")? { - // Libvirt checks the parameter types of 'kvm', 'kvm:tcg' and 'tcg'. - if accel.ne("kvm:tcg") && accel.ne("tcg") && accel.ne("kvm") && accel.ne("test") { - bail!("Only \'kvm\', \'kvm:tcg\', \'test\' and \'tcg\' are supported for \'accel\' of \'machine\'"); + if cores == 0 { + if sockets == 0 { + sockets = 1; } - - match accel.as_str() { - "test" => self.machine_config.hypervisor = HypervisorType::Test, - _ => self.machine_config.hypervisor = HypervisorType::Kvm, - }; - } - if let Some(usb) = cmd_parser.get_value::("usb")? { - if usb.into() { - bail!("Argument \'usb\' should be set to \'off\'"); + if threads == 0 { + threads = 1; + } + cores = max_cpus / (sockets * self.dies * self.clusters * threads); + } else if sockets == 0 { + if threads == 0 { + threads = 1; } + sockets = max_cpus / (self.dies * self.clusters * cores * threads); } - if let Some(mach_type) = cmd_parser - .get_value::("") - .with_context(|| "Unrecognized machine type")? - { - self.machine_config.mach_type = mach_type; + + if threads == 0 { + threads = max_cpus / (sockets * self.dies * self.clusters * cores); } - if let Some(mach_type) = cmd_parser - .get_value::("type") - .with_context(|| "Unrecognized machine type")? - { - self.machine_config.mach_type = mach_type; + + let min_max_cpus = std::cmp::max(self.cpus, MIN_NR_CPUS); + + if !(min_max_cpus..=MAX_NR_CPUS).contains(&max_cpus) { + return Err(anyhow!(ConfigError::IllegalValue( + "MAX CPU number".to_string(), + u64::from(min_max_cpus), + true, + u64::from(MAX_NR_CPUS), + true, + ))); } - if let Some(dump_guest) = cmd_parser.get_value::("dump-guest-core")? { - self.machine_config.mem_config.dump_guest_core = dump_guest.into(); + + if sockets * self.dies * self.clusters * cores * threads != max_cpus { + bail!("sockets * dies * clusters * cores * threads must be equal to max_cpus"); } - if let Some(mem_share) = cmd_parser.get_value::("mem-share")? { - self.machine_config.mem_config.mem_share = mem_share.into(); + + self.maxcpus = max_cpus; + self.sockets = sockets; + self.cores = cores; + self.threads = threads; + + Ok(()) + } +} + +impl VmConfig { + /// Add argument `name` to `VmConfig`. + /// + /// # Arguments + /// + /// * `name` - The name `String` added to `VmConfig`. + pub fn add_machine(&mut self, mach_config: &str) -> Result<()> { + let mut has_type_label = false; + if get_value_of_parameter("type", mach_config).is_ok() { + has_type_label = true; } + let mach_cfg = MachineCmdConfig::try_parse_from(str_slip_to_clap( + mach_config, + !has_type_label, + false, + ))?; + // TODO: The current "accel" configuration in "-machine" command line and "-accel" command line are not foolproof. + // Later parsing will overwrite first parsing. We will optimize this in the future. + self.machine_config.hypervisor = mach_cfg.accel; + self.machine_config.mach_type = mach_cfg.mach_type; + self.machine_config.mem_config.dump_guest_core = mach_cfg.dump_guest_core; + self.machine_config.mem_config.mem_share = mach_cfg.mem_share; Ok(()) } /// Add '-accel' accelerator config to `VmConfig`. pub fn add_accel(&mut self, accel_config: &str) -> Result<()> { - let mut cmd_parser = CmdParser::new("accel"); - cmd_parser.push(""); - cmd_parser.parse(accel_config)?; - - if let Some(accel) = cmd_parser - .get_value::("") - .with_context(|| "Only \'kvm\' and \'test\' is supported for \'accel\'")? - { - self.machine_config.hypervisor = accel; - } - + let accel_cfg = AccelConfig::try_parse_from(str_slip_to_clap(accel_config, true, false))?; + self.machine_config.hypervisor = accel_cfg.hypervisor; Ok(()) } /// Add '-m' memory config to `VmConfig`. pub fn add_memory(&mut self, mem_config: &str) -> Result<()> { - let mut cmd_parser = CmdParser::new("m"); - cmd_parser.push("").push("size"); - - cmd_parser.parse(mem_config)?; - - let mem = if let Some(mem_size) = cmd_parser.get_value::("")? { - memory_unit_conversion(&mem_size, M)? - } else if let Some(mem_size) = cmd_parser.get_value::("size")? { - memory_unit_conversion(&mem_size, M)? - } else { - return Err(anyhow!(ConfigError::FieldIsMissing( - "size".to_string(), - "memory".to_string() - ))); - }; - - self.machine_config.mem_config.mem_size = mem; + // Is there a "size=" prefix tag in the command line. + let mut has_size_label = false; + if get_value_of_parameter("size", mem_config).is_ok() { + has_size_label = true; + } + let mem_cfg = + MemSizeConfig::try_parse_from(str_slip_to_clap(mem_config, !has_size_label, false))?; + self.machine_config.mem_config.mem_size = mem_cfg.size; Ok(()) } /// Add '-smp' cpu config to `VmConfig`. pub fn add_cpu(&mut self, cpu_config: &str) -> Result<()> { - let mut cmd_parser = CmdParser::new("smp"); - cmd_parser - .push("") - .push("maxcpus") - .push("sockets") - .push("dies") - .push("clusters") - .push("cores") - .push("threads") - .push("cpus"); - - cmd_parser.parse(cpu_config)?; - - let cpu = if let Some(cpu) = cmd_parser.get_value::("")? { - cpu - } else if let Some(cpu) = cmd_parser.get_value::("cpus")? { - if cpu == 0 { - return Err(anyhow!(ConfigError::IllegalValue( - "cpu".to_string(), - 1, - true, - MAX_NR_CPUS, - true - ))); - } - cpu - } else { - return Err(anyhow!(ConfigError::FieldIsMissing( - "cpus".to_string(), - "smp".to_string() - ))); - }; - - let sockets = smp_read_and_check(&cmd_parser, "sockets", 0)?; - - let dies = smp_read_and_check(&cmd_parser, "dies", 1)?; - - let clusters = smp_read_and_check(&cmd_parser, "clusters", 1)?; - - let cores = smp_read_and_check(&cmd_parser, "cores", 0)?; - - let threads = smp_read_and_check(&cmd_parser, "threads", 0)?; - - let max_cpus = cmd_parser.get_value::("maxcpus")?.unwrap_or_default(); - - let (max_cpus, sockets, cores, threads) = - adjust_topology(cpu, max_cpus, sockets, dies, clusters, cores, threads); - - // limit cpu count - if !(MIN_NR_CPUS..=MAX_NR_CPUS).contains(&cpu) { - return Err(anyhow!(ConfigError::IllegalValue( - "CPU number".to_string(), - MIN_NR_CPUS, - true, - MAX_NR_CPUS, - true, - ))); - } - - if !(MIN_NR_CPUS..=MAX_NR_CPUS).contains(&max_cpus) { - return Err(anyhow!(ConfigError::IllegalValue( - "MAX CPU number".to_string(), - MIN_NR_CPUS, - true, - MAX_NR_CPUS, - true, - ))); + let mut has_cpus_label = false; + if get_value_of_parameter("cpus", cpu_config).is_ok() { + has_cpus_label = true; } - - if max_cpus < cpu { - return Err(anyhow!(ConfigError::IllegalValue( - "maxcpus".to_string(), - cpu, - true, - MAX_NR_CPUS, - true, - ))); - } - - if sockets * dies * clusters * cores * threads != max_cpus { - bail!("sockets * dies * clusters * cores * threads must be equal to max_cpus"); - } - - self.machine_config.nr_cpus = cpu as u8; - self.machine_config.nr_threads = threads as u8; - self.machine_config.nr_cores = cores as u8; - self.machine_config.nr_dies = dies as u8; - self.machine_config.nr_clusters = clusters as u8; - self.machine_config.nr_sockets = sockets as u8; - self.machine_config.max_cpus = max_cpus as u8; + let mut smp_cfg = + SmpConfig::try_parse_from(str_slip_to_clap(cpu_config, !has_cpus_label, false))?; + smp_cfg.auto_adjust_topology()?; + + self.machine_config.nr_cpus = smp_cfg.cpus; + self.machine_config.nr_threads = smp_cfg.threads; + self.machine_config.nr_cores = smp_cfg.cores; + self.machine_config.nr_dies = smp_cfg.dies; + self.machine_config.nr_clusters = smp_cfg.clusters; + self.machine_config.nr_sockets = smp_cfg.sockets; + self.machine_config.max_cpus = smp_cfg.maxcpus; Ok(()) } pub fn add_cpu_feature(&mut self, features: &str) -> Result<()> { - let mut cmd_parser = CmdParser::new("cpu"); - cmd_parser.push(""); - cmd_parser.push("pmu"); - cmd_parser.push("sve"); - cmd_parser.parse(features)?; - - // Check PMU when actually enabling PMU. - if let Some(k) = cmd_parser.get_value::("pmu")? { - self.machine_config.cpu_config.pmu = match k.as_ref() { - "on" => PmuConfig::On, - "off" => PmuConfig::Off, - _ => bail!("Invalid PMU option,must be one of \'on\" or \"off\"."), - } - } - - if let Some(k) = cmd_parser.get_value::("sve")? { - self.machine_config.cpu_config.sve = match k.as_ref() { - "on" => SveConfig::On, - "off" => SveConfig::Off, - _ => bail!("Invalid SVE option, must be one of \"on\" or \"off\"."), - } - } + let cpu_config = CpuConfig::try_parse_from(str_slip_to_clap(features, true, false))?; + self.machine_config.cpu_config = cpu_config; Ok(()) } @@ -458,152 +472,34 @@ impl VmConfig { self.machine_config.battery = true; true } -} - -impl VmConfig { - fn get_mem_zone_id(&self, cmd_parser: &CmdParser) -> Result { - if let Some(id) = cmd_parser.get_value::("id")? { - check_arg_too_long(&id, "id")?; - Ok(id) - } else { - Err(anyhow!(ConfigError::FieldIsMissing( - "id".to_string(), - "memory-backend-ram".to_string() - ))) - } - } - - fn get_mem_path(&self, cmd_parser: &CmdParser) -> Result> { - if let Some(path) = cmd_parser.get_value::("mem-path")? { - check_path_too_long(&path, "mem-path")?; - return Ok(Some(path)); - } - Ok(None) - } - fn get_mem_zone_size(&self, cmd_parser: &CmdParser) -> Result { - if let Some(mem) = cmd_parser.get_value::("size")? { - let size = memory_unit_conversion(&mem, M)?; - Ok(size) - } else { - Err(anyhow!(ConfigError::FieldIsMissing( - "size".to_string(), - "memory-backend-ram".to_string() - ))) - } - } - - fn get_mem_zone_host_nodes(&self, cmd_parser: &CmdParser) -> Result>> { - if let Some(mut host_nodes) = cmd_parser - .get_value::("host-nodes") - .with_context(|| { - ConfigError::ConvertValueFailed(String::from("u32"), "host-nodes".to_string()) - })? - .map(|v| v.0.iter().map(|e| *e as u32).collect::>()) - { - host_nodes.sort_unstable(); - if host_nodes[host_nodes.len() - 1] >= MAX_NODES { - return Err(anyhow!(ConfigError::IllegalValue( - "host_nodes".to_string(), - 0, - true, - MAX_NODES as u64, - false, - ))); - } - Ok(Some(host_nodes)) - } else { - Ok(None) - } - } - - fn get_mem_zone_policy(&self, cmd_parser: &CmdParser) -> Result { - let policy = cmd_parser - .get_value::("policy")? - .unwrap_or_else(|| "default".to_string()); - if HostMemPolicy::from(policy.clone()) == HostMemPolicy::NotSupported { - return Err(anyhow!(ConfigError::InvalidParam( - "policy".to_string(), - policy - ))); - } - Ok(policy) - } - - fn get_mem_share(&self, cmd_parser: &CmdParser) -> Result { - let share = cmd_parser - .get_value::("share")? - .unwrap_or_else(|| "off".to_string()); - - if share.eq("on") || share.eq("off") { - Ok(share.eq("on")) - } else { - Err(anyhow!(ConfigError::InvalidParam( - "share".to_string(), - share - ))) - } - } - - fn get_mem_dump(&self, cmd_parser: &CmdParser) -> Result { - if let Some(dump_guest) = cmd_parser.get_value::("dump-guest-core")? { - return Ok(dump_guest.into()); - } - Ok(true) - } - - fn get_mem_prealloc(&self, cmd_parser: &CmdParser) -> Result { - if let Some(mem_prealloc) = cmd_parser.get_value::("mem-prealloc")? { - return Ok(mem_prealloc.into()); - } - Ok(false) + pub fn add_hw_signature(&mut self, config: &str) -> Result<()> { + self.hardware_signature = Some(u32::from_str(config)?); + Ok(()) } +} +impl VmConfig { /// Convert memory zone cmdline to VM config /// /// # Arguments /// /// * `mem_zone` - The memory zone cmdline string. - /// * `mem_type` - The memory zone type - pub fn add_mem_zone(&mut self, mem_zone: &str, mem_type: String) -> Result { - let mut cmd_parser = CmdParser::new("mem_zone"); - cmd_parser - .push("") - .push("id") - .push("size") - .push("host-nodes") - .push("policy") - .push("share") - .push("mem-path") - .push("dump-guest-core") - .push("mem-prealloc"); - cmd_parser.parse(mem_zone)?; - - let zone_config = MemZoneConfig { - id: self.get_mem_zone_id(&cmd_parser)?, - size: self.get_mem_zone_size(&cmd_parser)?, - host_numa_nodes: self.get_mem_zone_host_nodes(&cmd_parser)?, - policy: self.get_mem_zone_policy(&cmd_parser)?, - dump_guest_core: self.get_mem_dump(&cmd_parser)?, - share: self.get_mem_share(&cmd_parser)?, - mem_path: self.get_mem_path(&cmd_parser)?, - prealloc: self.get_mem_prealloc(&cmd_parser)?, - memfd: mem_type.eq("memory-backend-memfd"), - }; + pub fn add_mem_zone(&mut self, mem_zone: &str) -> Result { + let zone_config = MemZoneConfig::try_parse_from(str_slip_to_clap(mem_zone, true, false))?; - if (zone_config.mem_path.is_none() && mem_type.eq("memory-backend-file")) - || (zone_config.mem_path.is_some() && mem_type.ne("memory-backend-file")) + if (zone_config.mem_path.is_none() && zone_config.mem_type.eq("memory-backend-file")) + || (zone_config.mem_path.is_some() && zone_config.mem_type.ne("memory-backend-file")) { - bail!("Object type: {} config path err", mem_type); + bail!("Object type: {} config path err", zone_config.mem_type); } - if self.object.mem_object.get(&zone_config.id).is_none() { - self.object - .mem_object - .insert(zone_config.id.clone(), zone_config.clone()); - } else { + if self.object.mem_object.contains_key(&zone_config.id) { bail!("Object: {} has been added", zone_config.id); } + self.object + .mem_object + .insert(zone_config.id.clone(), zone_config.clone()); if zone_config.host_numa_nodes.is_none() { return Ok(zone_config); @@ -624,62 +520,6 @@ impl VmConfig { } } -fn smp_read_and_check(cmd_parser: &CmdParser, name: &str, default_val: u64) -> Result { - if let Some(values) = cmd_parser.get_value::(name)? { - if values == 0 { - return Err(anyhow!(ConfigError::IllegalValue( - name.to_string(), - 1, - true, - u8::MAX as u64, - false - ))); - } - Ok(values) - } else { - Ok(default_val) - } -} - -fn adjust_topology( - cpu: u64, - mut max_cpus: u64, - mut sockets: u64, - dies: u64, - clusters: u64, - mut cores: u64, - mut threads: u64, -) -> (u64, u64, u64, u64) { - if max_cpus == 0 { - if sockets * dies * clusters * cores * threads > 0 { - max_cpus = sockets * dies * clusters * cores * threads; - } else { - max_cpus = cpu; - } - } - - if cores == 0 { - if sockets == 0 { - sockets = 1; - } - if threads == 0 { - threads = 1; - } - cores = max_cpus / (sockets * dies * clusters * threads); - } else if sockets == 0 { - if threads == 0 { - threads = 1; - } - sockets = max_cpus / (dies * clusters * cores * threads); - } - - if threads == 0 { - threads = max_cpus / (sockets * dies * clusters * cores); - } - - (max_cpus, sockets, cores, threads) -} - /// Convert memory units from GiB, Mib to Byte. /// /// # Arguments @@ -740,6 +580,34 @@ fn get_inner(outer: Option) -> Result { outer.with_context(|| ConfigError::IntegerOverflow("-m".to_string())) } +fn get_host_nodes(nodes: &str) -> Result> { + let mut host_nodes = IntegerList::from_str(nodes) + .with_context(|| { + ConfigError::ConvertValueFailed(String::from("u32"), "host-nodes".to_string()) + })? + .0 + .iter() + .map(|e| *e as u32) + .collect::>(); + + if host_nodes.is_empty() { + bail!("Got empty host nodes list!"); + } + + host_nodes.sort_unstable(); + if host_nodes[host_nodes.len() - 1] >= MAX_NODES { + return Err(anyhow!(ConfigError::IllegalValue( + "host_nodes".to_string(), + 0, + true, + u64::from(MAX_NODES), + false, + ))); + } + + Ok(host_nodes) +} + #[cfg(test)] mod tests { use super::*; @@ -777,9 +645,9 @@ mod tests { machine_config.nr_cpus = MIN_NR_CPUS as u8; machine_config.mem_config.mem_size = MIN_MEMSIZE - 1; - assert!(!machine_config.check().is_ok()); + assert!(machine_config.check().is_err()); machine_config.mem_config.mem_size = MAX_MEMSIZE + 1; - assert!(!machine_config.check().is_ok()); + assert!(machine_config.check().is_err()); machine_config.mem_config.mem_size = MIN_MEMSIZE; assert!(machine_config.check().is_ok()); @@ -976,17 +844,18 @@ mod tests { assert!(machine_cfg_ret.is_ok()); let machine_cfg = vm_config.machine_config; assert_eq!(machine_cfg.mach_type, MachineType::None); - assert_eq!(machine_cfg.mem_config.dump_guest_core, true); - assert_eq!(machine_cfg.mem_config.mem_share, true); + assert!(machine_cfg.mem_config.dump_guest_core); + assert!(machine_cfg.mem_config.mem_share); let mut vm_config = VmConfig::default(); - let memory_cfg_str = "type=none,dump-guest-core=off,mem-share=off,accel=kvm,usb=off"; + let memory_cfg_str = "none,dump-guest-core=off,mem-share=off,accel=kvm,usb=off"; let machine_cfg_ret = vm_config.add_machine(memory_cfg_str); assert!(machine_cfg_ret.is_ok()); let machine_cfg = vm_config.machine_config; assert_eq!(machine_cfg.mach_type, MachineType::None); - assert_eq!(machine_cfg.mem_config.dump_guest_core, false); - assert_eq!(machine_cfg.mem_config.mem_share, false); + assert_eq!(machine_cfg.hypervisor, HypervisorType::Kvm); + assert!(!machine_cfg.mem_config.dump_guest_core); + assert!(!machine_cfg.mem_config.mem_share); let mut vm_config = VmConfig::default(); let memory_cfg_str = "type=none,accel=kvm-tcg"; @@ -1047,10 +916,10 @@ mod tests { let mut vm_config = VmConfig::default(); let mem_prealloc = vm_config.machine_config.mem_config.mem_prealloc; // default value is false. - assert_eq!(mem_prealloc, false); + assert!(!mem_prealloc); vm_config.enable_mem_prealloc(); let mem_prealloc = vm_config.machine_config.mem_config.mem_prealloc; - assert_eq!(mem_prealloc, true); + assert!(mem_prealloc); } #[test] @@ -1089,10 +958,7 @@ mod tests { fn test_add_mem_zone() { let mut vm_config = VmConfig::default(); let zone_config_1 = vm_config - .add_mem_zone( - "-object memory-backend-ram,size=2G,id=mem1,host-nodes=1,policy=bind", - String::from("memory-backend-ram"), - ) + .add_mem_zone("memory-backend-ram,size=2G,id=mem1,host-nodes=1,policy=bind") .unwrap(); assert_eq!(zone_config_1.id, "mem1"); assert_eq!(zone_config_1.size, 2147483648); @@ -1100,38 +966,26 @@ mod tests { assert_eq!(zone_config_1.policy, "bind"); let zone_config_2 = vm_config - .add_mem_zone( - "-object memory-backend-ram,size=2G,id=mem2,host-nodes=1-2,policy=default", - String::from("memory-backend-ram"), - ) + .add_mem_zone("memory-backend-ram,size=2G,id=mem2,host-nodes=1-2,policy=default") .unwrap(); assert_eq!(zone_config_2.host_numa_nodes, Some(vec![1, 2])); let zone_config_3 = vm_config - .add_mem_zone( - "-object memory-backend-ram,size=2M,id=mem3,share=on", - String::from("memory-backend-ram"), - ) + .add_mem_zone("memory-backend-ram,size=2M,id=mem3,share=on") .unwrap(); assert_eq!(zone_config_3.size, 2 * 1024 * 1024); - assert_eq!(zone_config_3.share, true); + assert!(zone_config_3.share); let zone_config_4 = vm_config - .add_mem_zone( - "-object memory-backend-ram,size=2M,id=mem4", - String::from("memory-backend-ram"), - ) + .add_mem_zone("memory-backend-ram,size=2M,id=mem4") .unwrap(); - assert_eq!(zone_config_4.share, false); - assert_eq!(zone_config_4.memfd, false); + assert!(!zone_config_4.share); + assert!(!zone_config_4.memfd()); let zone_config_5 = vm_config - .add_mem_zone( - "-object memory-backend-memfd,size=2M,id=mem5", - String::from("memory-backend-memfd"), - ) + .add_mem_zone("memory-backend-memfd,size=2M,id=mem5") .unwrap(); - assert_eq!(zone_config_5.memfd, true); + assert!(zone_config_5.memfd()); } #[test] @@ -1155,15 +1009,44 @@ mod tests { assert!(vm_config.machine_config.cpu_config.pmu == PmuConfig::Off); vm_config.add_cpu_feature("host,pmu=off").unwrap(); assert!(vm_config.machine_config.cpu_config.pmu == PmuConfig::Off); - vm_config.add_cpu_feature("pmu=off").unwrap(); - assert!(vm_config.machine_config.cpu_config.pmu == PmuConfig::Off); vm_config.add_cpu_feature("host,pmu=on").unwrap(); assert!(vm_config.machine_config.cpu_config.pmu == PmuConfig::On); - vm_config.add_cpu_feature("pmu=on").unwrap(); - assert!(vm_config.machine_config.cpu_config.pmu == PmuConfig::On); - vm_config.add_cpu_feature("sve=on").unwrap(); + vm_config.add_cpu_feature("host,sve=on").unwrap(); assert!(vm_config.machine_config.cpu_config.sve == SveConfig::On); - vm_config.add_cpu_feature("sve=off").unwrap(); + vm_config.add_cpu_feature("host,sve=off").unwrap(); assert!(vm_config.machine_config.cpu_config.sve == SveConfig::Off); + + // Illegal cpu command lines: should set cpu family. + let result = vm_config.add_cpu_feature("pmu=off"); + assert!(result.is_err()); + let result = vm_config.add_cpu_feature("sve=on"); + assert!(result.is_err()); + + // Illegal parameters. + let result = vm_config.add_cpu_feature("host,sve1=on"); + assert!(result.is_err()); + + // Illegal values. + let result = vm_config.add_cpu_feature("host,sve=false"); + assert!(result.is_err()); + } + + #[test] + fn test_add_accel() { + let mut vm_config = VmConfig::default(); + let accel_cfg = "kvm"; + assert!(vm_config.add_accel(accel_cfg).is_ok()); + let machine_cfg = vm_config.machine_config; + assert_eq!(machine_cfg.hypervisor, HypervisorType::Kvm); + + let mut vm_config = VmConfig::default(); + let accel_cfg = "kvm:tcg"; + assert!(vm_config.add_accel(accel_cfg).is_ok()); + let machine_cfg = vm_config.machine_config; + assert_eq!(machine_cfg.hypervisor, HypervisorType::Kvm); + + let mut vm_config = VmConfig::default(); + let accel_cfg = "kvm1"; + assert!(vm_config.add_accel(accel_cfg).is_err()); } } diff --git a/machine_manager/src/config/mod.rs b/machine_manager/src/config/mod.rs index 8b2944d8cf8e2679453d2308b5aa7f1efe14a163..6f793d4e677a200d6b2cc60a0e1dfce5e200914a 100644 --- a/machine_manager/src/config/mod.rs +++ b/machine_manager/src/config/mod.rs @@ -20,79 +20,58 @@ pub mod vnc; mod boot_source; mod chardev; -#[cfg(feature = "demo_device")] -mod demo_dev; mod devices; mod drive; -mod fs; -#[cfg(feature = "virtio_gpu")] -mod gpu; mod incoming; mod iothread; mod machine_config; mod network; mod numa; mod pci; -#[cfg(feature = "pvpanic")] -mod pvpanic_pci; -#[cfg(all(feature = "ramfb", target_arch = "aarch64"))] -mod ramfb; mod rng; #[cfg(feature = "vnc_auth")] mod sasl_auth; -mod scsi; mod smbios; #[cfg(feature = "vnc_auth")] mod tls_creds; -mod usb; -mod vfio; pub use boot_source::*; #[cfg(feature = "usb_camera")] pub use camera::*; pub use chardev::*; -#[cfg(feature = "demo_device")] -pub use demo_dev::*; -pub use devices::*; #[cfg(any(feature = "gtk", feature = "ohui_srv"))] pub use display::*; pub use drive::*; pub use error::ConfigError; -pub use fs::*; -#[cfg(feature = "virtio_gpu")] -pub use gpu::*; pub use incoming::*; pub use iothread::*; pub use machine_config::*; pub use network::*; pub use numa::*; pub use pci::*; -#[cfg(feature = "pvpanic")] -pub use pvpanic_pci::*; -#[cfg(all(feature = "ramfb", target_arch = "aarch64"))] -pub use ramfb::*; pub use rng::*; #[cfg(feature = "vnc_auth")] pub use sasl_auth::*; -pub use scsi::*; pub use smbios::*; #[cfg(feature = "vnc_auth")] pub use tls_creds::*; -pub use usb::*; -pub use vfio::*; #[cfg(feature = "vnc")] pub use vnc::*; use std::collections::HashMap; -use std::fs::File; +use std::fs::{canonicalize, File}; use std::io::Read; +use std::os::unix::io::AsRawFd; +use std::path::Path; use std::str::FromStr; +use std::sync::Arc; use anyhow::{anyhow, bail, Context, Result}; -use log::error; +use clap::Parser; +use log::{error, info}; use serde::{Deserialize, Serialize}; -use trace::set_state_by_pattern; +use trace::{enable_state_by_type, set_state_by_pattern, TraceType}; #[cfg(target_arch = "aarch64")] use util::device_tree::{self, FdtBuilder}; use util::{ @@ -110,10 +89,22 @@ pub const MAX_SOCK_PATH_LENGTH: usize = 108; pub const MAX_VIRTIO_QUEUE: usize = 32; pub const FAST_UNPLUG_ON: &str = "1"; pub const FAST_UNPLUG_OFF: &str = "0"; -pub const MAX_TAG_LENGTH: usize = 36; pub const MAX_NODES: u32 = 128; /// Default virtqueue size for virtio devices excepts virtio-fs. pub const DEFAULT_VIRTQUEUE_SIZE: u16 = 256; +// Seg_max = queue_size - 2. So, size of each virtqueue for virtio-scsi/virtio-blk should be larger than 2. +pub const MIN_QUEUE_SIZE_BLOCK_DEVICE: u64 = 2; +// Max size of each virtqueue for virtio-scsi/virtio-blk. +pub const MAX_QUEUE_SIZE_BLOCK_DEVICE: u64 = 1024; +/// The bar0 size of enable_bar0 features +pub const VIRTIO_GPU_ENABLE_BAR0_SIZE: u64 = 64 * M; + +#[derive(Parser)] +#[command(no_binary_name(true))] +struct GlobalConfig { + #[arg(long, alias = "pcie-root-port.fast-unplug", value_parser = ["0", "1"])] + fast_unplug: Option, +} #[derive(Clone, Default, Debug, Serialize, Deserialize)] pub struct ObjectConfig { @@ -139,11 +130,12 @@ pub struct VmConfig { pub serial: Option, pub iothreads: Option>, pub object: ObjectConfig, - pub pflashs: Option>, + pub pflashs: Option>, pub dev_name: HashMap, pub global_config: HashMap, pub numa_nodes: Vec<(String, String)>, pub incoming: Option, + pub hardware_signature: Option, #[cfg(feature = "vnc")] pub vnc: Option, #[cfg(any(feature = "gtk", all(target_env = "ohos", feature = "ohui_srv")))] @@ -151,7 +143,7 @@ pub struct VmConfig { #[cfg(feature = "usb_camera")] pub camera_backend: HashMap, #[cfg(feature = "windows_emu_pid")] - pub windows_emu_pid: Option, + pub emulator_pid: Option, pub smbios: SmbiosConfig, } @@ -179,12 +171,12 @@ impl VmConfig { let mut stdio_count = 0; if let Some(serial) = self.serial.as_ref() { - if serial.chardev.backend == ChardevType::Stdio { + if let ChardevType::Stdio { .. } = serial.chardev.classtype { stdio_count += 1; } } for (_, char_dev) in self.chardev.clone() { - if char_dev.backend == ChardevType::Stdio { + if let ChardevType::Stdio { .. } = char_dev.classtype { stdio_count += 1; } } @@ -214,29 +206,24 @@ impl VmConfig { /// /// * `object_args` - The args of object. pub fn add_object(&mut self, object_args: &str) -> Result<()> { - let mut cmd_params = CmdParser::new("object"); - cmd_params.push(""); - - cmd_params.get_parameters(object_args)?; - let device_type = cmd_params - .get_value::("")? - .with_context(|| "Object type not specified")?; - match device_type.as_str() { + let object_type = + get_class_type(object_args).with_context(|| "Object type not specified")?; + match object_type.as_str() { "iothread" => { self.add_iothread(object_args) .with_context(|| "Failed to add iothread")?; } "rng-random" => { - let rng_cfg = parse_rng_obj(object_args)?; + let rng_cfg = + RngObjConfig::try_parse_from(str_slip_to_clap(object_args, true, false))?; let id = rng_cfg.id.clone(); - if self.object.rng_object.get(&id).is_none() { - self.object.rng_object.insert(id, rng_cfg); - } else { + if self.object.rng_object.contains_key(&id) { bail!("Object: {} has been added", id); } + self.object.rng_object.insert(id, rng_cfg); } "memory-backend-ram" | "memory-backend-file" | "memory-backend-memfd" => { - self.add_mem_zone(object_args, device_type)?; + self.add_mem_zone(object_args)?; } #[cfg(feature = "vnc_auth")] "tls-creds-x509" => { @@ -247,7 +234,7 @@ impl VmConfig { self.add_saslauth(object_args)?; } _ => { - bail!("Unknow object type: {:?}", &device_type); + bail!("Unknow object type: {:?}", &object_type); } } @@ -260,24 +247,18 @@ impl VmConfig { /// /// * `global_config` - The args of global config. pub fn add_global_config(&mut self, global_config: &str) -> Result<()> { - let mut cmd_parser = CmdParser::new("global"); - cmd_parser.push("pcie-root-port.fast-unplug"); - cmd_parser.parse(global_config)?; + let global_config = + GlobalConfig::try_parse_from(str_slip_to_clap(global_config, false, false))?; - if let Some(fast_unplug_value) = - cmd_parser.get_value::("pcie-root-port.fast-unplug")? - { - if fast_unplug_value != FAST_UNPLUG_ON && fast_unplug_value != FAST_UNPLUG_OFF { - bail!("The value of fast-unplug is invalid: {}", fast_unplug_value); - } + if let Some(fast_unplug_value) = global_config.fast_unplug { let fast_unplug_key = String::from("pcie-root-port.fast-unplug"); - if self.global_config.get(&fast_unplug_key).is_none() { - self.global_config - .insert(fast_unplug_key, fast_unplug_value); - } else { + if self.global_config.contains_key(&fast_unplug_key) { bail!("Global config {} has been added", fast_unplug_key); } + self.global_config + .insert(fast_unplug_key, fast_unplug_value); } + Ok(()) } @@ -289,9 +270,9 @@ impl VmConfig { #[cfg(feature = "windows_emu_pid")] pub fn add_windows_emu_pid(&mut self, windows_emu_pid: &str) -> Result<()> { if windows_emu_pid.is_empty() { - bail!("The arg of windows_emu_pid is empty!"); + bail!("The arg of emulator_pid is empty!"); } - self.windows_emu_pid = Some(windows_emu_pid.to_string()); + self.emulator_pid = Some(windows_emu_pid.to_string()); Ok(()) } @@ -326,7 +307,7 @@ impl VmConfig { } let drive_file = DriveFile { id: id.to_string(), - file, + file: Arc::new(file), count: 1, read_only, path: path.to_string(), @@ -334,6 +315,7 @@ impl VmConfig { req_align, buf_align, }; + info!("Open file {}, fd: {}", path, drive_file.file.as_raw_fd()); drive_files.insert(path.to_string(), drive_file); Ok(()) } @@ -358,12 +340,12 @@ impl VmConfig { } /// Get a file from drive file store. - pub fn fetch_drive_file(drive_files: &HashMap, path: &str) -> Result { + pub fn fetch_drive_file( + drive_files: &HashMap, + path: &str, + ) -> Result> { match drive_files.get(path) { - Some(drive_file) => drive_file - .file - .try_clone() - .with_context(|| format!("Failed to clone drive backend file {}", path)), + Some(drive_file) => Ok(drive_file.file.clone()), None => Err(anyhow!("The file {} is not in drive backend", path)), } } @@ -395,7 +377,7 @@ impl VmConfig { &mut drive_files, &drive.id, &drive.path_on_host, - drive.read_only, + drive.readonly, drive.direct, )?; } @@ -405,7 +387,7 @@ impl VmConfig { &mut drive_files, "", &pflash.path_on_host, - pflash.read_only, + pflash.readonly, false, )?; } @@ -436,163 +418,6 @@ pub trait ConfigCheck: AsAny + Send + Sync + std::fmt::Debug { fn check(&self) -> Result<()>; } -/// Struct `CmdParser` used to parse and check cmdline parameters to vm config. -pub struct CmdParser { - name: String, - params: HashMap>, -} - -impl CmdParser { - /// Allocates an empty `CmdParser`. - pub fn new(name: &str) -> Self { - CmdParser { - name: name.to_string(), - params: HashMap::>::new(), - } - } - - /// Push a new param field into `params`. - /// - /// # Arguments - /// - /// * `param_field`: The cmdline parameter field name. - pub fn push(&mut self, param_field: &str) -> &mut Self { - self.params.insert(param_field.to_string(), None); - - self - } - - /// Parse cmdline parameters string into `params`. - /// - /// # Arguments - /// - /// * `cmd_param`: The whole cmdline parameter string. - pub fn parse(&mut self, cmd_param: &str) -> Result<()> { - if cmd_param.starts_with(',') || cmd_param.ends_with(',') { - return Err(anyhow!(ConfigError::InvalidParam( - cmd_param.to_string(), - self.name.clone() - ))); - } - let param_items = cmd_param.split(',').collect::>(); - for (i, param_item) in param_items.iter().enumerate() { - if param_item.starts_with('=') || param_item.ends_with('=') { - return Err(anyhow!(ConfigError::InvalidParam( - param_item.to_string(), - self.name.clone() - ))); - } - let param = param_item.splitn(2, '=').collect::>(); - let (param_key, param_value) = match param.len() { - 1 => { - if i == 0 { - ("", param[0]) - } else { - (param[0], "") - } - } - 2 => (param[0], param[1]), - _ => { - return Err(anyhow!(ConfigError::InvalidParam( - param_item.to_string(), - self.name.clone() - ))); - } - }; - - if self.params.contains_key(param_key) { - let field_value = self.params.get_mut(param_key).unwrap(); - if field_value.is_none() { - *field_value = Some(String::from(param_value)); - } else { - return Err(anyhow!(ConfigError::FieldRepeat( - self.name.clone(), - param_key.to_string() - ))); - } - } else { - return Err(anyhow!(ConfigError::InvalidParam( - param[0].to_string(), - self.name.clone() - ))); - } - } - - Ok(()) - } - - /// Parse all cmdline parameters string into `params`. - /// - /// # Arguments - /// - /// * `cmd_param`: The whole cmdline parameter string. - fn get_parameters(&mut self, cmd_param: &str) -> Result<()> { - if cmd_param.starts_with(',') || cmd_param.ends_with(',') { - return Err(anyhow!(ConfigError::InvalidParam( - cmd_param.to_string(), - self.name.clone() - ))); - } - let param_items = cmd_param.split(',').collect::>(); - for param_item in param_items { - let param = param_item.splitn(2, '=').collect::>(); - let (param_key, param_value) = match param.len() { - 1 => ("", param[0]), - 2 => (param[0], param[1]), - _ => { - return Err(anyhow!(ConfigError::InvalidParam( - param_item.to_string(), - self.name.clone() - ))); - } - }; - - if self.params.contains_key(param_key) { - let field_value = self.params.get_mut(param_key).unwrap(); - if field_value.is_none() { - *field_value = Some(String::from(param_value)); - } else { - return Err(anyhow!(ConfigError::FieldRepeat( - self.name.clone(), - param_key.to_string() - ))); - } - } - } - - Ok(()) - } - - /// Get cmdline parameters value from param field name. - /// - /// # Arguments - /// - /// * `param_field`: The cmdline parameter field name. - pub fn get_value(&self, param_field: &str) -> Result> { - match self.params.get(param_field) { - Some(value) => { - let field_msg = if param_field.is_empty() { - &self.name - } else { - param_field - }; - - if let Some(raw_value) = value { - Ok(Some(raw_value.parse().map_err(|_| { - anyhow!(ConfigError::ConvertValueFailed( - field_msg.to_string(), - raw_value.clone() - )) - })?)) - } else { - Ok(None) - } - } - None => Ok(None), - } - } -} - /// This struct is a wrapper for `bool`. /// More switch string can be transferred to this structure. pub struct ExBool { @@ -619,13 +444,13 @@ impl From for bool { pub fn parse_bool(s: &str) -> Result { match s { - "on" => Ok(true), - "off" => Ok(false), + "true" | "on" | "yes" | "unmap" => Ok(true), + "false" | "off" | "no" | "ignore" => Ok(false), _ => Err(anyhow!("Unknow bool value {s}")), } } -fn enable_trace_state(path: &str) -> Result<()> { +fn enable_trace_state_from_file(path: &str) -> Result<()> { let mut file = File::open(path).with_context(|| format!("Failed to open {}", path))?; let mut buf = String::new(); file.read_to_string(&mut buf) @@ -644,15 +469,41 @@ fn enable_trace_state(path: &str) -> Result<()> { Ok(()) } -pub fn parse_trace_options(opt: &str) -> Result<()> { - let mut cmd_parser = CmdParser::new("trace"); - cmd_parser.push("file"); - cmd_parser.get_parameters(opt)?; +fn enable_trace_state_from_type(type_str: &str) -> Result<()> { + match type_str { + "events" => enable_state_by_type(TraceType::Event)?, + "scopes" => enable_state_by_type(TraceType::Scope)?, + "all" => { + enable_state_by_type(TraceType::Event)?; + enable_state_by_type(TraceType::Scope)?; + } + _ => bail!("Unknown trace type {}", type_str), + }; + + Ok(()) +} + +#[derive(Parser)] +#[command(no_binary_name(true))] +struct TraceConfig { + #[arg(long)] + file: Option, + #[arg(long, alias = "type")] + type_str: Option, +} - let path = cmd_parser - .get_value::("file")? - .with_context(|| "trace: trace file must be set.")?; - enable_trace_state(&path)?; +pub fn add_trace(opt: &str) -> Result<()> { + let trace_cfg = TraceConfig::try_parse_from(str_slip_to_clap(opt, false, false))?; + if trace_cfg.type_str.is_none() && trace_cfg.file.is_none() { + bail!("No type or file after -trace"); + } + + if let Some(type_str) = trace_cfg.type_str { + enable_trace_state_from_type(&type_str)?; + } + if let Some(file) = trace_cfg.file { + enable_trace_state_from_file(&file)?; + } Ok(()) } @@ -673,7 +524,7 @@ impl FromStr for UnsignedInteger { pub struct IntegerList(pub Vec); impl FromStr for IntegerList { - type Err = (); + type Err = anyhow::Error; fn from_str(s: &str) -> std::result::Result { let mut integer_list = Vec::new(); @@ -685,19 +536,22 @@ impl FromStr for IntegerList { for list in lists.iter() { let items: Vec<&str> = list.split('-').collect(); if items.len() > 2 { - return Err(()); + return Err(anyhow!( + "{} parameters connected by -, should be no more than 2.", + items.len() + )); } let start = items[0] .parse::() - .map_err(|e| error!("Invalid value {}, error is {:?}", items[0], e))?; + .map_err(|e| anyhow!("Invalid value {}, error is {:?}", items[0], e))?; integer_list.push(start); if items.len() == 2 { let end = items[1] .parse::() - .map_err(|e| error!("Invalid value {}, error is {:?}", items[1], e))?; + .map_err(|e| anyhow!("Invalid value {}, error is {:?}", items[1], e))?; if start >= end { - return Err(()); + return Err(anyhow!("start {} is bigger than end {}.", start, end)); } for i in start..end { @@ -730,128 +584,213 @@ pub fn check_path_too_long(arg: &str, name: &str) -> Result<()> { Ok(()) } -pub fn check_arg_nonexist(arg: Option, name: &str, device: &str) -> Result<()> { - arg.with_context(|| ConfigError::FieldIsMissing(name.to_string(), device.to_string()))?; +/// Make sure args are existed. +/// +/// arg_name: Name of arg. +/// arg_value: Value of arg. Should be Option<> class. +/// Eg: +/// check_arg_exist!(("id", id)); +/// check_arg_exist!(("bus", bus), ("addr", addr)); +#[macro_export] +macro_rules! check_arg_exist{ + ($(($arg_name:tt, $arg_value:expr)),*) => { + $($arg_value.clone().with_context(|| format!("Should set {}.", $arg_name))?;)* + } +} - Ok(()) +/// Make sure args are existed. +/// +/// arg_name: Name of arg. +/// arg_value: Value of arg. Should be Option<> class. +/// Eg: +/// check_arg_nonexist!(("id", id)); +/// check_arg_nonexist!(("bus", bus), ("addr", addr)); +#[macro_export] +macro_rules! check_arg_nonexist{ + ($(($arg_name:tt, $arg_value:expr)),*) => { + $($arg_value.clone().map_or(Some(0), |_| None).with_context(|| format!("Should not set {}", $arg_name))?;)* + } +} + +fn concat_classtype(args: &str, concat: bool) -> String { + if concat { + format!("classtype={}", args) + } else { + args.to_string() + } } /// Configure StratoVirt parameters in clap format. -pub fn str_slip_to_clap(args: &str) -> Vec { - let args_vecs = args.split([',', '=']).collect::>(); - let mut itr: Vec = Vec::with_capacity(args_vecs.len()); - for (cnt, param) in args_vecs.iter().enumerate() { - if cnt % 2 == 1 { - itr.push(format!("--{}", param)); - } else { - itr.push(param.to_string()); +/// +/// The first parameter will be parsed as the `binary name` unless Command::no_binary_name is used when using `clap`. +/// Stratovirt command line may use the first parameter as class type. +/// Eg: +/// 1. drive config: "-drive file=,if=pflash,unit=0" +/// This cmdline has no class type. +/// 2. device config: "-device virtio-balloon-pci,id=,bus=,addr=<0x4>" +/// This cmdline sets device type `virtio-balloon-pci` as the first parameter. +/// +/// Use first_pos_is_type to indicate whether the first parameter is a type class which needs a separate analysis. +/// Eg: +/// 1. drive config: "-drive file=,if=pflash,unit=0" +/// Set first_pos_is_type false for this cmdline has no class type. +/// 2. device config: "-device virtio-balloon-pci,id=,bus=,addr=<0x4>" +/// Set first_pos_is_type true for this cmdline has device type "virtio-balloon-pci" as the first parameter. +/// +/// Use first_pos_is_subcommand to indicate whether the first parameter is a subclass. +/// Eg: +/// Chardev has stdio/unix-socket/tcp-socket/pty/file classes. These classes have different configurations but will be stored +/// in the same `ChardevConfig` structure by using `enum`. So, we will use class type as a subcommand to indicate which subtype +/// will be used to store the configuration in enumeration type. Subcommand in `clap` doesn't need `--` in parameter. +/// 1. -serial file,path= +/// Set first_pos_is_subcommand true for first parameter `file` is the subclass type for chardev. +pub fn str_slip_to_clap( + args: &str, + first_pos_is_type: bool, + first_pos_is_subcommand: bool, +) -> Vec { + let mut subcommand = first_pos_is_subcommand; + let args_str = concat_classtype(args, first_pos_is_type && !subcommand); + let args_vecs = args_str.split([',']).collect::>(); + let mut itr: Vec = Vec::with_capacity(args_vecs.len() * 2); + for params in args_vecs { + let key_value = params.split(['=']).collect::>(); + // Command line like "key=value" will be converted to "--key value". + // Command line like "key" will be converted to "--key". + for (cnt, param) in key_value.iter().enumerate() { + if cnt % 2 == 0 { + if subcommand { + itr.push(param.to_string()); + subcommand = false; + } else { + itr.push(format!("--{}", param)); + } + } else { + itr.push(param.to_string()); + } } } itr } +/// Retrieve the value of the specified parameter from a string in the format "key=value". +pub fn get_value_of_parameter(parameter: &str, args_str: &str) -> Result { + let args_vecs = args_str.split([',']).collect::>(); + + for args in args_vecs { + let key_value = args.split(['=']).collect::>(); + if key_value.len() != 2 || key_value[0] != parameter { + continue; + } + if key_value[1].is_empty() { + bail!("Find empty arg {} in string {}.", key_value[0], args_str); + } + return Ok(key_value[1].to_string()); + } + + bail!("Cannot find {}'s value from string {}", parameter, args_str); +} + +pub fn get_class_type(args: &str) -> Result { + let args_str = concat_classtype(args, true); + get_value_of_parameter("classtype", &args_str) +} + pub fn valid_id(id: &str) -> Result { check_arg_too_long(id, "id")?; Ok(id.to_string()) } +// Virtio queue size must be power of 2 and in range [min_size, max_size]. +pub fn valid_virtqueue_size(size: u64, min_size: u64, max_size: u64) -> Result<()> { + if size < min_size || size > max_size { + return Err(anyhow!(ConfigError::IllegalValue( + "virtqueue size".to_string(), + min_size, + true, + max_size, + true + ))); + } + + if size & (size - 1) != 0 { + bail!("Virtqueue size should be power of 2!"); + } + + Ok(()) +} + +pub fn valid_path(path: &str) -> Result { + if path.len() > MAX_PATH_LENGTH { + return Err(anyhow!(ConfigError::StringLengthTooLong( + "path".to_string(), + MAX_PATH_LENGTH, + ))); + } + + let canonical_path = canonicalize(path).map_or(path.to_string(), |pathbuf| { + String::from(pathbuf.to_str().unwrap()) + }); + + Ok(canonical_path) +} + +pub fn valid_socket_path(sock_path: &str) -> Result { + if sock_path.len() > MAX_SOCK_PATH_LENGTH { + return Err(anyhow!(ConfigError::StringLengthTooLong( + "socket path".to_string(), + MAX_SOCK_PATH_LENGTH, + ))); + } + valid_path(sock_path) +} + +pub fn valid_dir(d: &str) -> Result { + let dir = String::from(d); + if !Path::new(&dir).is_dir() { + return Err(anyhow!(ConfigError::DirNotExist(dir))); + } + Ok(dir) +} + +pub fn valid_block_device_virtqueue_size(s: &str) -> Result { + let size: u64 = s.parse()?; + valid_virtqueue_size( + size, + MIN_QUEUE_SIZE_BLOCK_DEVICE + 1, + MAX_QUEUE_SIZE_BLOCK_DEVICE, + )?; + + Ok(size as u16) +} + +pub fn parse_size(s: &str) -> Result { + let size = memory_unit_conversion(s, M).with_context(|| format!("Invalid size: {}", s))?; + Ok(size) +} + #[cfg(test)] mod tests { use super::*; #[test] - fn test_cmd_parser() { - let mut cmd_parser = CmdParser::new("test"); - cmd_parser - .push("") - .push("id") - .push("path") - .push("num") - .push("test1") - .push("test2") - .push("test3") - .push("test4") - .push("test5") - .push("test6") - .push("test7"); - assert!(cmd_parser - .parse("socket,id=charconsole0,path=/tmp/console.sock,num=1,test1=true,test2=on,test3=yes,test4=false,test5=off,test6=no,test7=random") - .is_ok()); - assert_eq!( - cmd_parser.get_value::("").unwrap().unwrap(), - "socket".to_string() - ); - assert_eq!( - cmd_parser.get_value::("id").unwrap().unwrap(), - "charconsole0".to_string() - ); - assert_eq!( - cmd_parser.get_value::("path").unwrap().unwrap(), - "/tmp/console.sock".to_string() - ); - assert_eq!(cmd_parser.get_value::("num").unwrap().unwrap(), 1_u64); - assert_eq!(cmd_parser.get_value::("num").unwrap().unwrap(), 1_u32); - assert_eq!(cmd_parser.get_value::("num").unwrap().unwrap(), 1_u16); - assert_eq!(cmd_parser.get_value::("num").unwrap().unwrap(), 1_u8); - assert_eq!(cmd_parser.get_value::("num").unwrap().unwrap(), 1_i64); - assert_eq!(cmd_parser.get_value::("num").unwrap().unwrap(), 1_i32); - assert_eq!(cmd_parser.get_value::("num").unwrap().unwrap(), 1_i16); - assert_eq!(cmd_parser.get_value::("num").unwrap().unwrap(), 1_i8); - assert!(cmd_parser.get_value::("test1").unwrap().unwrap()); - assert!( - cmd_parser - .get_value::("test1") - .unwrap() - .unwrap() - .inner - ); - assert!( - cmd_parser - .get_value::("test2") - .unwrap() - .unwrap() - .inner - ); - assert!( - cmd_parser - .get_value::("test3") - .unwrap() - .unwrap() - .inner - ); - assert!(!cmd_parser.get_value::("test4").unwrap().unwrap()); - assert!( - !cmd_parser - .get_value::("test4") - .unwrap() - .unwrap() - .inner - ); - assert!( - !cmd_parser - .get_value::("test5") - .unwrap() - .unwrap() - .inner - ); - assert!( - !cmd_parser - .get_value::("test6") - .unwrap() - .unwrap() - .inner - ); - assert!(cmd_parser.get_value::("test7").is_err()); - assert!(cmd_parser.get_value::("test7").is_err()); - assert!(cmd_parser.get_value::("random").unwrap().is_none()); - assert!(cmd_parser.parse("random=false").is_err()); - } + fn test_add_trace() { + assert!(std::fs::File::create("/tmp/trace_file").is_ok()); - #[test] - fn test_parse_trace_options() { - assert!(parse_trace_options("fil=test_trace").is_err()); - assert!(parse_trace_options("file").is_err()); - assert!(parse_trace_options("file=test_trace").is_err()); + assert!(add_trace("file=/tmp/trace_file,type=all").is_ok()); + assert!(add_trace("fil=test_trace").is_err()); + assert!(add_trace("file").is_err()); + assert!(add_trace("file=test_trace").is_err()); + + assert!(add_trace("type=events").is_ok()); + assert!(add_trace("type=scopes").is_ok()); + assert!(add_trace("type=all").is_ok()); + assert!(add_trace("type=xxxxx").is_err()); + + assert!(add_trace("").is_err()); + assert!(add_trace("file=/tmp/trace_file,type=all").is_ok()); + + assert!(std::fs::remove_file("/tmp/trace_file").is_ok()); } #[test] @@ -886,4 +825,20 @@ mod tests { let res = vm_config.add_global_config("pcie-root-port.fast-unplug=1"); assert!(res.is_err()); } + + #[test] + fn test_get_value_of_parameter() { + let cmd = "scsi-hd,id=disk1,drive=scsi-drive-0"; + let id = get_value_of_parameter("id", cmd).unwrap(); + assert_eq!(id, "disk1"); + + let cmd = "id="; + assert!(get_value_of_parameter("id", cmd).is_err()); + + let cmd = "id"; + assert!(get_value_of_parameter("id", cmd).is_err()); + + let cmd = "scsi-hd,idxxx=disk1"; + assert!(get_value_of_parameter("id", cmd).is_err()); + } } diff --git a/machine_manager/src/config/network.rs b/machine_manager/src/config/network.rs index f7d79a13b83435f76684873b12af3d6f94f2a558..30f1ec807dd55b040b179368df9d941dbec059f2 100644 --- a/machine_manager/src/config/network.rs +++ b/machine_manager/src/config/network.rs @@ -13,40 +13,95 @@ use std::os::unix::io::RawFd; use anyhow::{anyhow, bail, Context, Result}; +use clap::{ArgAction, Parser}; use serde::{Deserialize, Serialize}; -use super::{error::ConfigError, pci_args_check}; -use crate::config::get_chardev_socket_path; -use crate::config::{ - check_arg_too_long, CmdParser, ConfigCheck, ExBool, VmConfig, DEFAULT_VIRTQUEUE_SIZE, - MAX_PATH_LENGTH, MAX_VIRTIO_QUEUE, -}; +use super::error::ConfigError; +use super::{get_pci_df, parse_bool, str_slip_to_clap, valid_id, valid_virtqueue_size}; +use crate::config::{ConfigCheck, VmConfig, DEFAULT_VIRTQUEUE_SIZE, MAX_VIRTIO_QUEUE}; use crate::qmp::{qmp_channel::QmpChannel, qmp_schema}; const MAC_ADDRESS_LENGTH: usize = 17; /// Max virtqueue size of each virtqueue. -pub const MAX_QUEUE_SIZE_NET: u16 = 4096; +const MAX_QUEUE_SIZE_NET: u64 = 4096; /// Max num of virtqueues. const MAX_QUEUE_PAIRS: usize = MAX_VIRTIO_QUEUE / 2; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Parser, Debug, Clone, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct NetDevcfg { + #[arg(long, alias="classtype", value_parser = ["tap", "vhost-user"])] + pub netdev_type: String, + #[arg(long, value_parser = valid_id)] pub id: String, + #[arg(long, aliases = ["fds", "fd"], use_value_delimiter = true, value_delimiter = ':')] pub tap_fds: Option>, - pub vhost_type: Option, + #[arg(long, alias = "vhost", default_value = "off", value_parser = parse_bool, action = ArgAction::Append)] + pub vhost_kernel: bool, + #[arg(long, aliases = ["vhostfds", "vhostfd"], use_value_delimiter = true, value_delimiter = ':')] pub vhost_fds: Option>, + #[arg(long, default_value = "", value_parser = valid_id)] pub ifname: String, + #[arg(long, default_value = "1", value_parser = parse_queues)] pub queues: u16, + #[arg(long)] pub chardev: Option, } +impl NetDevcfg { + pub fn vhost_type(&self) -> Option { + if self.vhost_kernel { + return Some("vhost-kernel".to_string()); + } + if self.netdev_type == "vhost-user" { + return Some("vhost-user".to_string()); + } + // Default: virtio net. + None + } + + fn auto_queues(&mut self) -> Result<()> { + if let Some(fds) = &self.tap_fds { + let fds_num = fds + .len() + .checked_mul(2) + .with_context(|| format!("Invalid fds number {}", fds.len()))? + as u16; + if fds_num > self.queues { + self.queues = fds_num; + } + } + if let Some(fds) = &self.vhost_fds { + let fds_num = fds + .len() + .checked_mul(2) + .with_context(|| format!("Invalid vhostfds number {}", fds.len()))? + as u16; + if fds_num > self.queues { + self.queues = fds_num; + } + } + Ok(()) + } +} + +fn parse_queues(q: &str) -> Result { + let queues = q + .parse::()? + .checked_mul(2) + .with_context(|| "Invalid 'queues' value")?; + is_netdev_queues_valid(queues)?; + Ok(queues) +} + impl Default for NetDevcfg { fn default() -> Self { NetDevcfg { + netdev_type: "".to_string(), id: "".to_string(), tap_fds: None, - vhost_type: None, + vhost_kernel: false, vhost_fds: None, ifname: "".to_string(), queues: 2, @@ -57,24 +112,18 @@ impl Default for NetDevcfg { impl ConfigCheck for NetDevcfg { fn check(&self) -> Result<()> { - check_arg_too_long(&self.id, "id")?; - check_arg_too_long(&self.ifname, "ifname")?; - - if let Some(vhost_type) = self.vhost_type.as_ref() { - if vhost_type != "vhost-kernel" && vhost_type != "vhost-user" { - return Err(anyhow!(ConfigError::UnknownVhostType)); - } + if self.vhost_kernel && self.netdev_type == "vhost-user" { + bail!("vhost-user netdev does not support 'vhost' option"); } - if !is_netdev_queues_valid(self.queues) { - return Err(anyhow!(ConfigError::IllegalValue( - "number queues of net device".to_string(), - 1, - true, - MAX_VIRTIO_QUEUE as u64 / 2, - true, - ))); + if self.vhost_fds.is_some() && self.vhost_type().is_none() { + bail!("Argument 'vhostfd' or 'vhostfds' are not needed for virtio-net device"); } + if self.tap_fds.is_none() && self.ifname.eq("") && self.netdev_type.ne("vhost-user") { + bail!("Tap device is missing, use \'ifname\' or \'fd\' to configure a tap device"); + } + + is_netdev_queues_valid(self.queues)?; Ok(()) } @@ -82,225 +131,89 @@ impl ConfigCheck for NetDevcfg { /// Config struct for network /// Contains network device config, such as `host_dev_name`, `mac`... -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Parser)] #[serde(deny_unknown_fields)] +#[command(no_binary_name(true))] pub struct NetworkInterfaceConfig { + #[arg(long, value_parser = ["virtio-net-pci", "virtio-net-device"])] + pub classtype: String, + #[arg(long, default_value = "", value_parser = valid_id)] pub id: String, - pub host_dev_name: String, + #[arg(long)] + pub netdev: String, + #[arg(long)] + pub bus: Option, + #[arg(long, value_parser = get_pci_df)] + pub addr: Option<(u8, u8)>, + #[arg(long, value_parser = parse_bool, action = ArgAction::Append)] + pub multifunction: Option, + #[arg(long, value_parser = valid_mac)] pub mac: Option, - pub tap_fds: Option>, - pub vhost_type: Option, - pub vhost_fds: Option>, + #[arg(long)] pub iothread: Option, - pub queues: u16, + #[arg(long)] + pub rx_iothread: Option, + #[arg(long)] + pub tx_iothread: Option, + #[arg(long, default_value="off", value_parser = parse_bool, action = ArgAction::Append)] pub mq: bool, - pub socket_path: Option, - /// All queues of a net device have the same queue size now. + // All queues of a net device have the same queue size now. + #[arg(long, default_value = "256", alias = "queue-size", value_parser = valid_network_queue_size)] pub queue_size: u16, + // MSI-X vectors the this network device has. This member isn't used now in stratovirt. + #[arg(long, default_value = "0")] + pub vectors: u16, } impl Default for NetworkInterfaceConfig { fn default() -> Self { NetworkInterfaceConfig { + classtype: "".to_string(), id: "".to_string(), - host_dev_name: "".to_string(), + netdev: "".to_string(), + bus: None, + addr: None, + multifunction: None, mac: None, - tap_fds: None, - vhost_type: None, - vhost_fds: None, iothread: None, - queues: 2, + rx_iothread: None, + tx_iothread: None, mq: false, - socket_path: None, queue_size: DEFAULT_VIRTQUEUE_SIZE, + vectors: 0, } } } -impl ConfigCheck for NetworkInterfaceConfig { - fn check(&self) -> Result<()> { - check_arg_too_long(&self.id, "id")?; - check_arg_too_long(&self.host_dev_name, "host dev name")?; - - if self.mac.is_some() && !check_mac_address(self.mac.as_ref().unwrap()) { - return Err(anyhow!(ConfigError::MacFormatError)); +impl NetworkInterfaceConfig { + pub fn auto_iothread(&mut self) { + // If rx_iothread or tx_iothread is not configured, the default iothread will be used. + if self.rx_iothread.is_none() { + self.rx_iothread.clone_from(&self.iothread); } - - if self.iothread.is_some() { - check_arg_too_long(self.iothread.as_ref().unwrap(), "iothread name")?; + if self.tx_iothread.is_none() { + self.tx_iothread.clone_from(&self.iothread); } - - if self.socket_path.is_some() && self.socket_path.as_ref().unwrap().len() > MAX_PATH_LENGTH - { - return Err(anyhow!(ConfigError::StringLengthTooLong( - "socket path".to_string(), - MAX_PATH_LENGTH - ))); - } - - if self.queue_size < DEFAULT_VIRTQUEUE_SIZE || self.queue_size > MAX_QUEUE_SIZE_NET { - return Err(anyhow!(ConfigError::IllegalValue( - "queue size of net device".to_string(), - DEFAULT_VIRTQUEUE_SIZE as u64, - true, - MAX_QUEUE_SIZE_NET as u64, - true - ))); - } - - if self.queue_size & (self.queue_size - 1) != 0 { - bail!("queue size of net device should be power of 2!"); - } - - Ok(()) - } -} - -fn parse_fds(cmd_parser: &CmdParser, name: &str) -> Result>> { - if let Some(fds) = cmd_parser.get_value::(name)? { - let mut raw_fds = Vec::new(); - for fd in fds.split(':').collect::>().iter() { - raw_fds.push( - (*fd) - .parse::() - .with_context(|| "Failed to parse fds")?, - ); - } - Ok(Some(raw_fds)) - } else { - Ok(None) } } -fn parse_netdev(cmd_parser: CmdParser) -> Result { - let mut net = NetDevcfg::default(); - let netdev_type = cmd_parser.get_value::("")?.unwrap_or_default(); - if netdev_type.ne("tap") && netdev_type.ne("vhost-user") { - bail!("Unsupported netdev type: {:?}", &netdev_type); - } - net.id = cmd_parser - .get_value::("id")? - .with_context(|| ConfigError::FieldIsMissing("id".to_string(), "netdev".to_string()))?; - if let Some(ifname) = cmd_parser.get_value::("ifname")? { - net.ifname = ifname; - } - if let Some(queue_pairs) = cmd_parser.get_value::("queues")? { - let queues = queue_pairs.checked_mul(2); - if queues.is_none() || !is_netdev_queues_valid(queues.unwrap()) { - return Err(anyhow!(ConfigError::IllegalValue( - "number queues of net device".to_string(), - 1, - true, - MAX_VIRTIO_QUEUE as u64 / 2, - true, - ))); - } - - net.queues = queues.unwrap(); - } +fn valid_network_queue_size(s: &str) -> Result { + let size: u64 = s.parse()?; + valid_virtqueue_size(size, u64::from(DEFAULT_VIRTQUEUE_SIZE), MAX_QUEUE_SIZE_NET)?; - if let Some(tap_fd) = parse_fds(&cmd_parser, "fd")? { - net.tap_fds = Some(tap_fd); - } else if let Some(tap_fds) = parse_fds(&cmd_parser, "fds")? { - net.tap_fds = Some(tap_fds); - } - if let Some(fds) = &net.tap_fds { - let fds_num = - fds.len() - .checked_mul(2) - .with_context(|| format!("Invalid fds number {}", fds.len()))? as u16; - if fds_num > net.queues { - net.queues = fds_num; - } - } + Ok(size as u16) +} - if let Some(vhost) = cmd_parser.get_value::("vhost")? { - if vhost.into() { - net.vhost_type = Some(String::from("vhost-kernel")); - } - } else if netdev_type.eq("vhost-user") { - net.vhost_type = Some(String::from("vhost-user")); - } - if let Some(chardev) = cmd_parser.get_value::("chardev")? { - net.chardev = Some(chardev); - } - if let Some(vhost_fd) = parse_fds(&cmd_parser, "vhostfd")? { - net.vhost_fds = Some(vhost_fd); - } else if let Some(vhost_fds) = parse_fds(&cmd_parser, "vhostfds")? { - net.vhost_fds = Some(vhost_fds); - } - if let Some(fds) = &net.vhost_fds { - let fds_num = fds - .len() - .checked_mul(2) - .with_context(|| format!("Invalid vhostfds number {}", fds.len()))? - as u16; - if fds_num > net.queues { - net.queues = fds_num; +impl ConfigCheck for NetworkInterfaceConfig { + fn check(&self) -> Result<()> { + if self.mac.is_some() && !check_mac_address(self.mac.as_ref().unwrap()) { + return Err(anyhow!(ConfigError::MacFormatError)); } - } - - if net.vhost_fds.is_some() && net.vhost_type.is_none() { - bail!("Argument \'vhostfd\' is not needed for virtio-net device"); - } - if net.tap_fds.is_none() && net.ifname.eq("") && netdev_type.ne("vhost-user") { - bail!("Tap device is missing, use \'ifname\' or \'fd\' to configure a tap device"); - } - net.check()?; + valid_network_queue_size(&self.queue_size.to_string())?; - Ok(net) -} - -pub fn parse_net(vm_config: &mut VmConfig, net_config: &str) -> Result { - let mut cmd_parser = CmdParser::new("virtio-net"); - cmd_parser - .push("") - .push("id") - .push("netdev") - .push("mq") - .push("vectors") - .push("bus") - .push("addr") - .push("multifunction") - .push("mac") - .push("iothread") - .push("queue-size"); - - cmd_parser.parse(net_config)?; - pci_args_check(&cmd_parser)?; - let mut netdevinterfacecfg = NetworkInterfaceConfig::default(); - - let netdev = cmd_parser - .get_value::("netdev")? - .with_context(|| ConfigError::FieldIsMissing("netdev".to_string(), "net".to_string()))?; - let netid = cmd_parser.get_value::("id")?.unwrap_or_default(); - - if let Some(mq) = cmd_parser.get_value::("mq")? { - netdevinterfacecfg.mq = mq.inner; - } - netdevinterfacecfg.iothread = cmd_parser.get_value::("iothread")?; - netdevinterfacecfg.mac = cmd_parser.get_value::("mac")?; - if let Some(queue_size) = cmd_parser.get_value::("queue-size")? { - netdevinterfacecfg.queue_size = queue_size; - } - - let netcfg = &vm_config - .netdevs - .remove(&netdev) - .with_context(|| format!("Netdev: {:?} not found for net device", &netdev))?; - netdevinterfacecfg.id = netid; - netdevinterfacecfg.host_dev_name = netcfg.ifname.clone(); - netdevinterfacecfg.tap_fds = netcfg.tap_fds.clone(); - netdevinterfacecfg.vhost_fds = netcfg.vhost_fds.clone(); - netdevinterfacecfg.vhost_type = netcfg.vhost_type.clone(); - netdevinterfacecfg.queues = netcfg.queues; - if let Some(chardev) = &netcfg.chardev { - netdevinterfacecfg.socket_path = Some(get_chardev_socket_path(chardev, vm_config)?); + Ok(()) } - - netdevinterfacecfg.check()?; - Ok(netdevinterfacecfg) } fn get_netdev_fd(fd_name: &str) -> Result { @@ -337,17 +250,12 @@ pub fn get_netdev_config(args: Box) -> Result) -> Result Result<()> { - let mut cmd_parser = CmdParser::new("netdev"); - cmd_parser - .push("") - .push("id") - .push("fd") - .push("fds") - .push("vhost") - .push("ifname") - .push("vhostfd") - .push("vhostfds") - .push("queues") - .push("chardev"); - - cmd_parser.parse(netdev_config)?; - let drive_cfg = parse_netdev(cmd_parser)?; - self.add_netdev_with_config(drive_cfg) + let mut netdev_cfg = + NetDevcfg::try_parse_from(str_slip_to_clap(netdev_config, true, false))?; + netdev_cfg.auto_queues()?; + netdev_cfg.check()?; + self.add_netdev_with_config(netdev_cfg) } pub fn add_netdev_with_config(&mut self, conf: NetDevcfg) -> Result<()> { let netdev_id = conf.id.clone(); - if self.netdevs.get(&netdev_id).is_none() { - self.netdevs.insert(netdev_id, conf); - } else { + if self.netdevs.contains_key(&netdev_id) { bail!("Netdev {:?} has been added", netdev_id); } + self.netdevs.insert(netdev_id, conf); Ok(()) } pub fn del_netdev_by_id(&mut self, id: &str) -> Result<()> { - if self.netdevs.get(id).is_some() { - self.netdevs.remove(id); - } else { - bail!("Netdev {} not found", id); - } + self.netdevs + .remove(id) + .with_context(|| format!("Netdev {} not found", id))?; + Ok(()) } } +fn valid_mac(mac: &str) -> Result { + if !check_mac_address(mac) { + return Err(anyhow!(ConfigError::MacFormatError)); + } + Ok(mac.to_string()) +} + fn check_mac_address(mac: &str) -> bool { if mac.len() != MAC_ADDRESS_LENGTH { return false; @@ -485,207 +369,136 @@ fn check_mac_address(mac: &str) -> bool { true } -fn is_netdev_queues_valid(queues: u16) -> bool { - queues >= 1 && queues <= MAX_VIRTIO_QUEUE as u16 +fn is_netdev_queues_valid(queues: u16) -> Result<()> { + if !(queues >= 2 && queues <= MAX_VIRTIO_QUEUE as u16) { + return Err(anyhow!(ConfigError::IllegalValue( + "number queues of net device".to_string(), + 1, + true, + MAX_QUEUE_PAIRS as u64, + true, + ))); + } + + Ok(()) } #[cfg(test)] mod tests { use super::*; - use crate::config::{get_pci_bdf, MAX_STRING_LENGTH}; #[test] - fn test_network_config_cmdline_parser() { + fn test_netdev_config_cmdline_parser() { let mut vm_config = VmConfig::default(); + + // Test1: Right. assert!(vm_config.add_netdev("tap,id=eth0,ifname=tap0").is_ok()); - let net_cfg_res = parse_net( - &mut vm_config, - "virtio-net-device,id=net0,netdev=eth0,iothread=iothread0", - ); - assert!(net_cfg_res.is_ok()); - let network_configs = net_cfg_res.unwrap(); - assert_eq!(network_configs.id, "net0"); - assert_eq!(network_configs.host_dev_name, "tap0"); - assert_eq!(network_configs.iothread, Some("iothread0".to_string())); - assert!(network_configs.mac.is_none()); - assert!(network_configs.tap_fds.is_none()); - assert!(network_configs.vhost_type.is_none()); - assert!(network_configs.vhost_fds.is_none()); + assert!(vm_config.add_netdev("tap,id=eth0,ifname=tap0").is_err()); + let netdev_cfg = vm_config.netdevs.get("eth0").unwrap(); + assert_eq!(netdev_cfg.id, "eth0"); + assert_eq!(netdev_cfg.ifname, "tap0"); + assert!(netdev_cfg.tap_fds.is_none()); + assert!(!netdev_cfg.vhost_kernel); + assert!(netdev_cfg.vhost_fds.is_none()); + assert_eq!(netdev_cfg.queues, 2); + assert!(netdev_cfg.vhost_type().is_none()); - let mut vm_config = VmConfig::default(); assert!(vm_config .add_netdev("tap,id=eth1,ifname=tap1,vhost=on,vhostfd=4") .is_ok()); - let net_cfg_res = parse_net( - &mut vm_config, - "virtio-net-device,id=net1,netdev=eth1,mac=12:34:56:78:9A:BC", - ); - assert!(net_cfg_res.is_ok()); - let network_configs = net_cfg_res.unwrap(); - assert_eq!(network_configs.id, "net1"); - assert_eq!(network_configs.host_dev_name, "tap1"); - assert_eq!(network_configs.mac, Some(String::from("12:34:56:78:9A:BC"))); - assert!(network_configs.tap_fds.is_none()); - assert_eq!( - network_configs.vhost_type, - Some(String::from("vhost-kernel")) - ); - assert_eq!(network_configs.vhost_fds, Some(vec![4])); + let netdev_cfg = vm_config.netdevs.get("eth1").unwrap(); + assert_eq!(netdev_cfg.ifname, "tap1"); + assert_eq!(netdev_cfg.vhost_type().unwrap(), "vhost-kernel"); + assert_eq!(netdev_cfg.vhost_fds, Some(vec![4])); - let mut vm_config = VmConfig::default(); - assert!(vm_config.add_netdev("tap,id=eth1,fd=35").is_ok()); - let net_cfg_res = parse_net(&mut vm_config, "virtio-net-device,id=net1,netdev=eth1"); - assert!(net_cfg_res.is_ok()); - let network_configs = net_cfg_res.unwrap(); - assert_eq!(network_configs.id, "net1"); - assert_eq!(network_configs.host_dev_name, ""); - assert_eq!(network_configs.tap_fds, Some(vec![35])); + assert!(vm_config.add_netdev("tap,id=eth2,fd=35").is_ok()); + let netdev_cfg = vm_config.netdevs.get("eth2").unwrap(); + assert_eq!(netdev_cfg.tap_fds, Some(vec![35])); - let mut vm_config = VmConfig::default(); assert!(vm_config - .add_netdev("tap,id=eth1,ifname=tap1,vhost=on,vhostfd=4") + .add_netdev("tap,id=eth3,ifname=tap0,queues=4") .is_ok()); - let net_cfg_res = parse_net( - &mut vm_config, - "virtio-net-device,id=net1,netdev=eth2,mac=12:34:56:78:9A:BC", - ); - assert!(net_cfg_res.is_err()); + let netdev_cfg = vm_config.netdevs.get("eth3").unwrap(); + assert_eq!(netdev_cfg.queues, 8); - let mut vm_config = VmConfig::default(); - assert!(vm_config.add_netdev("tap,id=eth1,fd=35").is_ok()); - let net_cfg_res = parse_net(&mut vm_config, "virtio-net-device,id=net1,netdev=eth3"); - assert!(net_cfg_res.is_err()); - - // multi queue testcases - let mut vm_config = VmConfig::default(); assert!(vm_config - .add_netdev("tap,id=eth0,ifname=tap0,queues=4") + .add_netdev("tap,id=eth4,fds=34:35:36:37:38") .is_ok()); - let net_cfg_res = parse_net( - &mut vm_config, - "virtio-net-device,id=net0,netdev=eth0,iothread=iothread0,mq=on,vectors=6", - ); - assert!(net_cfg_res.is_ok()); - let network_configs = net_cfg_res.unwrap(); - assert_eq!(network_configs.queues, 8); - assert_eq!(network_configs.mq, true); + let netdev_cfg = vm_config.netdevs.get("eth4").unwrap(); + assert_eq!(netdev_cfg.queues, 10); + assert_eq!(netdev_cfg.tap_fds, Some(vec![34, 35, 36, 37, 38])); - let mut vm_config = VmConfig::default(); assert!(vm_config - .add_netdev("tap,id=eth0,fds=34:35:36:37:38") + .add_netdev("tap,id=eth5,fds=34:35:36:37:38,vhost=on,vhostfds=39:40:41:42:43") .is_ok()); - let net_cfg_res = parse_net( - &mut vm_config, - "virtio-net-device,id=net0,netdev=eth0,iothread=iothread0,mq=off,vectors=12", - ); - assert!(net_cfg_res.is_ok()); - let network_configs = net_cfg_res.unwrap(); - assert_eq!(network_configs.queues, 10); - assert_eq!(network_configs.tap_fds, Some(vec![34, 35, 36, 37, 38])); - assert_eq!(network_configs.mq, false); + let netdev_cfg = vm_config.netdevs.get("eth5").unwrap(); + assert_eq!(netdev_cfg.queues, 10); + assert_eq!(netdev_cfg.vhost_fds, Some(vec![39, 40, 41, 42, 43])); - let mut vm_config = VmConfig::default(); + // Test2: Missing values assert!(vm_config - .add_netdev("tap,id=eth0,fds=34:35:36:37:38,vhost=on,vhostfds=39:40:41:42:43") - .is_ok()); - let net_cfg_res = parse_net( - &mut vm_config, - "virtio-net-device,id=net0,netdev=eth0,iothread=iothread0,mq=off,vectors=12", - ); - assert!(net_cfg_res.is_ok()); - let network_configs = net_cfg_res.unwrap(); - assert_eq!(network_configs.queues, 10); - assert_eq!(network_configs.vhost_fds, Some(vec![39, 40, 41, 42, 43])); - assert_eq!(network_configs.mq, false); + .add_netdev("tap,fds=34:35:36:37:38,vhost=on") + .is_err()); + + // Test3: Illegal values. + assert!(vm_config + .add_netdev("tap,id=eth10,fds=34:35:36:37:38,vhost=on,vhostfds=39,40,41,42,43") + .is_err()); + assert!(vm_config.add_netdev("tap,id=eth10,queues=0").is_err()); + assert!(vm_config.add_netdev("tap,id=eth10,queues=17").is_err()); } #[test] - fn test_pci_network_config_cmdline_parser() { + fn test_networkinterface_config_cmdline_parser() { + // Test1: Right. let mut vm_config = VmConfig::default(); - assert!(vm_config .add_netdev("tap,id=eth1,ifname=tap1,vhost=on,vhostfd=4") .is_ok()); + let net_cmd = + "virtio-net-pci,id=net1,netdev=eth1,bus=pcie.0,addr=0x1.0x2,mac=12:34:56:78:9A:BC,mq=on,vectors=6,queue-size=2048,multifunction=on"; let net_cfg = - "virtio-net-pci,id=net1,netdev=eth1,bus=pcie.0,addr=0x1.0x2,mac=12:34:56:78:9A:BC"; - let net_cfg_res = parse_net(&mut vm_config, net_cfg); - assert!(net_cfg_res.is_ok()); - let network_configs = net_cfg_res.unwrap(); - assert_eq!(network_configs.id, "net1"); - assert_eq!(network_configs.host_dev_name, "tap1"); - assert_eq!(network_configs.mac, Some(String::from("12:34:56:78:9A:BC"))); - assert!(network_configs.tap_fds.is_none()); - assert_eq!( - network_configs.vhost_type, - Some(String::from("vhost-kernel")) - ); - assert_eq!(network_configs.vhost_fds.unwrap()[0], 4); - let pci_bdf = get_pci_bdf(net_cfg); - assert!(pci_bdf.is_ok()); - let pci = pci_bdf.unwrap(); - assert_eq!(pci.bus, "pcie.0".to_string()); - assert_eq!(pci.addr, (1, 2)); - - let net_cfg_res = parse_net(&mut vm_config, net_cfg); - assert!(net_cfg_res.is_err()); - + NetworkInterfaceConfig::try_parse_from(str_slip_to_clap(net_cmd, true, false)).unwrap(); + assert_eq!(net_cfg.id, "net1"); + assert_eq!(net_cfg.netdev, "eth1"); + assert_eq!(net_cfg.bus.unwrap(), "pcie.0"); + assert_eq!(net_cfg.addr.unwrap(), (1, 2)); + assert_eq!(net_cfg.mac.unwrap(), "12:34:56:78:9A:BC"); + assert_eq!(net_cfg.vectors, 6); + assert!(net_cfg.mq); + assert_eq!(net_cfg.queue_size, 2048); + assert_eq!(net_cfg.multifunction, Some(true)); + let netdev_cfg = vm_config.netdevs.get(&net_cfg.netdev).unwrap(); + assert_eq!(netdev_cfg.vhost_type().unwrap(), "vhost-kernel"); + + // Test2: Default values. let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_netdev("tap,id=eth1,ifname=tap1,vhost=on,vhostfd=4") - .is_ok()); - let net_cfg = - "virtio-net-pci,id=net1,netdev=eth1,bus=pcie.0,addr=0x1.0x2,mac=12:34:56:78:9A:BC,multifunction=on"; - assert!(parse_net(&mut vm_config, net_cfg).is_ok()); - - // For vhost-user net assert!(vm_config.add_netdev("vhost-user,id=netdevid").is_ok()); - let net_cfg = + let net_cmd = "virtio-net-pci,id=netid,netdev=netdevid,bus=pcie.0,addr=0x2.0x0,mac=12:34:56:78:9A:BC"; - let net_cfg_res = parse_net(&mut vm_config, net_cfg); - assert!(net_cfg_res.is_ok()); - let network_configs = net_cfg_res.unwrap(); - assert_eq!(network_configs.id, "netid"); - assert_eq!(network_configs.vhost_type, Some("vhost-user".to_string())); - assert_eq!(network_configs.mac, Some("12:34:56:78:9A:BC".to_string())); - - assert!(vm_config - .add_netdev("vhost-user,id=netdevid2,chardev=chardevid2") - .is_ok()); let net_cfg = - "virtio-net-pci,id=netid2,netdev=netdevid2,bus=pcie.0,addr=0x2.0x0,mac=12:34:56:78:9A:BC"; - let net_cfg_res = parse_net(&mut vm_config, net_cfg); - assert!(net_cfg_res.is_err()); - } - - #[test] - fn test_netdev_config_check() { - let mut netdev_conf = NetDevcfg::default(); - for _ in 0..MAX_STRING_LENGTH { - netdev_conf.id += "A"; - } - assert!(netdev_conf.check().is_ok()); - - // Overflow - netdev_conf.id += "A"; - assert!(netdev_conf.check().is_err()); - - let mut netdev_conf = NetDevcfg::default(); - for _ in 0..MAX_STRING_LENGTH { - netdev_conf.ifname += "A"; - } - assert!(netdev_conf.check().is_ok()); - - // Overflow - netdev_conf.ifname += "A"; - assert!(netdev_conf.check().is_err()); - - let mut netdev_conf = NetDevcfg::default(); - netdev_conf.vhost_type = None; - assert!(netdev_conf.check().is_ok()); - netdev_conf.vhost_type = Some(String::from("vhost-kernel")); - assert!(netdev_conf.check().is_ok()); - netdev_conf.vhost_type = Some(String::from("vhost-")); - assert!(netdev_conf.check().is_err()); + NetworkInterfaceConfig::try_parse_from(str_slip_to_clap(net_cmd, true, false)).unwrap(); + assert_eq!(net_cfg.queue_size, 256); + assert!(!net_cfg.mq); + assert_eq!(net_cfg.vectors, 0); + let netdev_cfg = vm_config.netdevs.get(&net_cfg.netdev).unwrap(); + assert_eq!(netdev_cfg.vhost_type().unwrap(), "vhost-user"); + + // Test3: Missing Parameters. + let net_cmd = "virtio-net-pci,id=netid"; + let result = NetworkInterfaceConfig::try_parse_from(str_slip_to_clap(net_cmd, true, false)); + assert!(result.is_err()); + + // Test4: Illegal Parameters. + let net_cmd = "virtio-net-pci,id=netid,netdev=netdevid,mac=1:1:1"; + let result = NetworkInterfaceConfig::try_parse_from(str_slip_to_clap(net_cmd, true, false)); + assert!(result.is_err()); + let net_cmd = "virtio-net-pci,id=netid,netdev=netdevid,queue-size=128"; + let result = NetworkInterfaceConfig::try_parse_from(str_slip_to_clap(net_cmd, true, false)); + assert!(result.is_err()); + let net_cmd = "virtio-net-pci,id=netid,netdev=netdevid,queue-size=10240"; + let result = NetworkInterfaceConfig::try_parse_from(str_slip_to_clap(net_cmd, true, false)); + assert!(result.is_err()); } #[test] @@ -748,7 +561,7 @@ mod tests { let mut net_conf = NetDevcfg::default(); net_conf.id = String::from(*id); assert!(vm_config.netdevs.get(*id).is_some()); - assert!(vm_config.del_netdev_by_id(*id).is_ok()); + assert!(vm_config.del_netdev_by_id(id).is_ok()); assert!(vm_config.netdevs.get(*id).is_none()); } } @@ -811,9 +624,9 @@ mod tests { ..qmp_schema::NetDevAddArgument::default() }); let net_cfg = get_netdev_config(netdev).unwrap(); + assert_eq!(net_cfg.vhost_type().unwrap(), "vhost-kernel"); assert_eq!(net_cfg.tap_fds.unwrap()[0], 11); assert_eq!(net_cfg.vhost_fds.unwrap()[0], 21); - assert_eq!(net_cfg.vhost_type.unwrap(), "vhost-kernel"); } // Normal test with 'vhostfds'. @@ -831,9 +644,9 @@ mod tests { ..qmp_schema::NetDevAddArgument::default() }); let net_cfg = get_netdev_config(netdev).unwrap(); - assert_eq!(net_cfg.tap_fds.unwrap(), [11, 12, 13, 14]); - assert_eq!(net_cfg.vhost_fds.unwrap(), [21, 22, 23, 24]); - assert_eq!(net_cfg.vhost_type.unwrap(), "vhost-kernel"); + assert_eq!(net_cfg.vhost_type().unwrap(), "vhost-kernel"); + assert_eq!(net_cfg.tap_fds.unwrap(), vec![11, 12, 13, 14]); + assert_eq!(net_cfg.vhost_fds.unwrap(), vec![21, 22, 23, 24]); } let err_msgs = [ @@ -851,7 +664,7 @@ mod tests { queues: Some(u16::MAX), ..qmp_schema::NetDevAddArgument::default() }); - check_err_msg(netdev, &err_msgs[0]); + check_err_msg(netdev, err_msgs[0]); // Abnornal test with invalid 'queues': MAX_QUEUE_PAIRS + 1. let netdev = Box::new(qmp_schema::NetDevAddArgument { @@ -859,8 +672,7 @@ mod tests { ..qmp_schema::NetDevAddArgument::default() }); let err_msg = format!( - "The 'queues' {} is bigger than max queue num {}", - MAX_QUEUE_PAIRS + 1, + "number queues of net device must >= 1 and <= {}.", MAX_QUEUE_PAIRS ); check_err_msg(netdev, &err_msg); @@ -871,7 +683,7 @@ mod tests { vhostfds: Some("21:22:23:24".to_string()), ..qmp_schema::NetDevAddArgument::default() }); - check_err_msg(netdev, &err_msgs[1]); + check_err_msg(netdev, err_msgs[1]); // Abnornal test with 'fds' and 'vhostfd'. let netdev = Box::new(qmp_schema::NetDevAddArgument { @@ -879,7 +691,7 @@ mod tests { vhostfd: Some("21".to_string()), ..qmp_schema::NetDevAddArgument::default() }); - check_err_msg(netdev, &err_msgs[2]); + check_err_msg(netdev, err_msgs[2]); // Abnornal test with different num of 'fds' and 'vhostfds'. let netdev = Box::new(qmp_schema::NetDevAddArgument { @@ -887,7 +699,7 @@ mod tests { vhostfds: Some("21:22:23".to_string()), ..qmp_schema::NetDevAddArgument::default() }); - check_err_msg(netdev, &err_msgs[3]); + check_err_msg(netdev, err_msgs[3]); // Abnornal test with 'net_type=vhost-user'. let netdev = Box::new(qmp_schema::NetDevAddArgument { @@ -897,7 +709,7 @@ mod tests { net_type: Some("vhost-user".to_string()), ..qmp_schema::NetDevAddArgument::default() }); - check_err_msg(netdev, &err_msgs[4]); + check_err_msg(netdev, err_msgs[4]); // Abnornal test with 'fds/vhostfds' and no 'vhost'. let netdev = Box::new(qmp_schema::NetDevAddArgument { @@ -905,13 +717,13 @@ mod tests { vhostfds: Some("21:22:23:24".to_string()), ..qmp_schema::NetDevAddArgument::default() }); - check_err_msg(netdev, &err_msgs[5]); + check_err_msg(netdev, err_msgs[5]); // Abnornal test with all default value. let netdev = Box::new(qmp_schema::NetDevAddArgument { ..qmp_schema::NetDevAddArgument::default() }); - check_err_msg(netdev, &err_msgs[6]); + check_err_msg(netdev, err_msgs[6]); // Abnornal test with invalid fd value. let netdev = Box::new(qmp_schema::NetDevAddArgument { @@ -929,6 +741,7 @@ mod tests { fds: Some(fds.to_string()), ..qmp_schema::NetDevAddArgument::default() }); + // number queues of net device let err_msg = format!( "The num of fd {} is bigger than max queue num {}", MAX_QUEUE_PAIRS + 1, diff --git a/machine_manager/src/config/numa.rs b/machine_manager/src/config/numa.rs index a9a0bfa3c2b39234d7d1e3b8e1fd09a2c3f4c2e0..f628fef7976e2c7225c789816933f1c8b5b11fc8 100644 --- a/machine_manager/src/config/numa.rs +++ b/machine_manager/src/config/numa.rs @@ -12,29 +12,17 @@ use std::cmp::max; use std::collections::{BTreeMap, HashSet}; +use std::str::FromStr; -use anyhow::{anyhow, bail, Context, Result}; +use anyhow::{bail, Context, Result}; +use clap::Parser; use super::error::ConfigError; -use crate::config::{CmdParser, IntegerList, VmConfig, MAX_NODES}; +use super::{get_class_type, str_slip_to_clap}; +use crate::config::{IntegerList, VmConfig, MAX_NODES}; const MIN_NUMA_DISTANCE: u8 = 10; -#[derive(Default, Debug)] -pub struct NumaDistance { - pub destination: u32, - pub distance: u8, -} - -#[derive(Default, Debug)] -pub struct NumaConfig { - pub numa_id: u32, - pub cpus: Vec, - pub distances: Option>, - pub size: u64, - pub mem_dev: String, -} - #[derive(Default)] pub struct NumaNode { pub cpus: Vec, @@ -109,126 +97,83 @@ pub fn complete_numa_node(numa_nodes: &mut NumaNodes, nr_cpus: u8, mem_size: u64 Ok(()) } -/// Parse the NUMA node memory parameters. -/// -/// # Arguments -/// -/// * `numa_config` - The NUMA node configuration. -pub fn parse_numa_mem(numa_config: &str) -> Result { - let mut cmd_parser = CmdParser::new("numa"); - cmd_parser - .push("") - .push("nodeid") - .push("cpus") - .push("memdev"); - cmd_parser.parse(numa_config)?; - - let mut config: NumaConfig = NumaConfig::default(); - if let Some(node_id) = cmd_parser.get_value::("nodeid")? { - if node_id >= MAX_NODES { - return Err(anyhow!(ConfigError::IllegalValue( - "nodeid".to_string(), - 0, - true, - MAX_NODES as u64, - false, - ))); - } - config.numa_id = node_id; - } else { - return Err(anyhow!(ConfigError::FieldIsMissing( - "nodeid".to_string(), - "numa".to_string() - ))); - } - if let Some(mut cpus) = cmd_parser - .get_value::("cpus") +#[derive(Parser)] +#[command(no_binary_name(true))] +pub struct NumaNodeConfig { + #[arg(long, value_parser = ["node"])] + pub classtype: String, + #[arg(long, alias = "nodeid", value_parser = clap::value_parser!(u32).range(..MAX_NODES as i64))] + pub numa_id: u32, + #[arg(long, value_parser = get_cpus)] + pub cpus: ::std::vec::Vec, + #[arg(long, alias = "memdev")] + pub mem_dev: String, +} + +fn get_cpus(cpus_str: &str) -> Result> { + let mut cpus = IntegerList::from_str(cpus_str) .with_context(|| ConfigError::ConvertValueFailed(String::from("u8"), "cpus".to_string()))? - .map(|v| v.0.iter().map(|e| *e as u8).collect::>()) - { - cpus.sort_unstable(); - config.cpus = cpus; - } else { - return Err(anyhow!(ConfigError::FieldIsMissing( - "cpus".to_string(), - "numa".to_string() - ))); + .0 + .iter() + .map(|e| *e as u8) + .collect::>(); + + if cpus.is_empty() { + bail!("Got empty cpus list!"); } - config.mem_dev = cmd_parser - .get_value::("memdev")? - .with_context(|| ConfigError::FieldIsMissing("memdev".to_string(), "numa".to_string()))?; - Ok(config) + cpus.sort_unstable(); + + Ok(cpus) } -/// Parse the NUMA node distance parameters. +/// Parse the NUMA node memory parameters. /// /// # Arguments /// -/// * `numa_dist` - The NUMA node distance configuration. -pub fn parse_numa_distance(numa_dist: &str) -> Result<(u32, NumaDistance)> { - let mut cmd_parser = CmdParser::new("numa"); - cmd_parser.push("").push("src").push("dst").push("val"); - cmd_parser.parse(numa_dist)?; - - let mut dist: NumaDistance = NumaDistance::default(); - let numa_id = if let Some(src) = cmd_parser.get_value::("src")? { - if src >= MAX_NODES { - return Err(anyhow!(ConfigError::IllegalValue( - "src".to_string(), - 0, - true, - MAX_NODES as u64, - false, - ))); - } - src - } else { - return Err(anyhow!(ConfigError::FieldIsMissing( - "src".to_string(), - "numa".to_string() - ))); - }; - if let Some(dst) = cmd_parser.get_value::("dst")? { - if dst >= MAX_NODES { - return Err(anyhow!(ConfigError::IllegalValue( - "dst".to_string(), - 0, - true, - MAX_NODES as u64, - false, - ))); - } - dist.destination = dst; - } else { - return Err(anyhow!(ConfigError::FieldIsMissing( - "dst".to_string(), - "numa".to_string() - ))); - } - if let Some(val) = cmd_parser.get_value::("val")? { - if val < MIN_NUMA_DISTANCE { - bail!("NUMA distance shouldn't be less than 10"); - } - if numa_id == dist.destination && val != MIN_NUMA_DISTANCE { - bail!("Local distance of node {} should be 10.", numa_id); +/// * `numa_config` - The NUMA node configuration. +pub fn parse_numa_mem(numa_config: &str) -> Result { + let config = NumaNodeConfig::try_parse_from(str_slip_to_clap(numa_config, true, false))?; + Ok(config) +} + +#[derive(Parser)] +#[command(no_binary_name(true))] +pub struct NumaDistConfig { + #[arg(long, value_parser = ["dist"])] + pub classtype: String, + #[arg(long, alias = "src", value_parser = clap::value_parser!(u32).range(..MAX_NODES as i64))] + pub numa_id: u32, + #[arg(long, alias = "dst", value_parser = clap::value_parser!(u32).range(..MAX_NODES as i64))] + pub destination: u32, + #[arg(long, alias = "val", value_parser = clap::value_parser!(u8).range(MIN_NUMA_DISTANCE as i64..))] + pub distance: u8, +} + +impl NumaDistConfig { + fn check(&self) -> Result<()> { + if self.numa_id == self.destination && self.distance != MIN_NUMA_DISTANCE { + bail!("Local distance of node {} should be 10.", self.numa_id); } - if numa_id != dist.destination && val == MIN_NUMA_DISTANCE { + if self.numa_id != self.destination && self.distance == MIN_NUMA_DISTANCE { bail!( "Remote distance of node {} should be more than 10.", - numa_id + self.numa_id ); } - - dist.distance = val; - } else { - return Err(anyhow!(ConfigError::FieldIsMissing( - "val".to_string(), - "numa".to_string() - ))); + Ok(()) } +} - Ok((numa_id, dist)) +/// Parse the NUMA node distance parameters. +/// +/// # Arguments +/// +/// * `numa_dist` - The NUMA node distance configuration. +pub fn parse_numa_distance(numa_dist: &str) -> Result { + let dist_cfg = NumaDistConfig::try_parse_from(str_slip_to_clap(numa_dist, true, false))?; + dist_cfg.check()?; + Ok(dist_cfg) } impl VmConfig { @@ -238,13 +183,8 @@ impl VmConfig { /// /// * `numa_config` - The NUMA node configuration. pub fn add_numa(&mut self, numa_config: &str) -> Result<()> { - let mut cmd_params = CmdParser::new("numa"); - cmd_params.push(""); - - cmd_params.get_parameters(numa_config)?; - if let Some(numa_type) = cmd_params.get_value::("")? { - self.numa_nodes.push((numa_type, numa_config.to_string())); - } + let numa_type = get_class_type(numa_config).with_context(|| "Numa type not specified")?; + self.numa_nodes.push((numa_type, numa_config.to_string())); Ok(()) } @@ -258,20 +198,18 @@ mod tests { fn test_parse_numa_mem() { let mut vm_config = VmConfig::default(); assert!(vm_config - .add_numa("-numa node,nodeid=0,cpus=0-1,memdev=mem0") - .is_ok()); - assert!(vm_config - .add_numa("-numa node,nodeid=1,cpus=2-1,memdev=mem1") + .add_numa("node,nodeid=0,cpus=0-1,memdev=mem0") .is_ok()); assert!(vm_config - .add_numa("-numa node,nodeid=2,memdev=mem2") + .add_numa("node,nodeid=1,cpus=2-1,memdev=mem1") .is_ok()); - assert!(vm_config.add_numa("-numa node,nodeid=3,cpus=3-4").is_ok()); + assert!(vm_config.add_numa("node,nodeid=2,memdev=mem2").is_ok()); + assert!(vm_config.add_numa("node,nodeid=3,cpus=3-4").is_ok()); assert!(vm_config - .add_numa("-numa node,nodeid=0,cpus=[0-1:3-5],memdev=mem0") + .add_numa("node,nodeid=0,cpus=[0-1:3-5],memdev=mem0") .is_ok()); - let numa = vm_config.numa_nodes.get(0).unwrap(); + let numa = vm_config.numa_nodes.first().unwrap(); let numa_config = parse_numa_mem(numa.1.as_str()).unwrap(); assert_eq!(numa_config.cpus, vec![0, 1]); assert_eq!(numa_config.mem_dev, "mem0"); @@ -291,17 +229,17 @@ mod tests { #[test] fn test_parse_numa_distance() { let mut vm_config = VmConfig::default(); - assert!(vm_config.add_numa("-numa dist,src=0,dst=1,val=15").is_ok()); - assert!(vm_config.add_numa("-numa dist,dst=1,val=10").is_ok()); - assert!(vm_config.add_numa("-numa dist,src=0,val=10").is_ok()); - assert!(vm_config.add_numa("-numa dist,src=0,dst=1").is_ok()); - assert!(vm_config.add_numa("-numa dist,src=0,dst=1,val=10").is_ok()); - - let numa = vm_config.numa_nodes.get(0).unwrap(); - let dist = parse_numa_distance(numa.1.as_str()).unwrap(); - assert_eq!(dist.0, 0); - assert_eq!(dist.1.destination, 1); - assert_eq!(dist.1.distance, 15); + assert!(vm_config.add_numa("dist,src=0,dst=1,val=15").is_ok()); + assert!(vm_config.add_numa("dist,dst=1,val=10").is_ok()); + assert!(vm_config.add_numa("dist,src=0,val=10").is_ok()); + assert!(vm_config.add_numa("dist,src=0,dst=1").is_ok()); + assert!(vm_config.add_numa("dist,src=0,dst=1,val=10").is_ok()); + + let numa = vm_config.numa_nodes.first().unwrap(); + let dist_cfg = parse_numa_distance(numa.1.as_str()).unwrap(); + assert_eq!(dist_cfg.numa_id, 0); + assert_eq!(dist_cfg.destination, 1); + assert_eq!(dist_cfg.distance, 15); let numa = vm_config.numa_nodes.get(1).unwrap(); assert!(parse_numa_distance(numa.1.as_str()).is_err()); diff --git a/machine_manager/src/config/pci.rs b/machine_manager/src/config/pci.rs index e1ad4985a1143e9315c872b6a778d0d7d43eb6e5..6642f4f62bc4c6220e0986d22be5ecc4ddbc03be 100644 --- a/machine_manager/src/config/pci.rs +++ b/machine_manager/src/config/pci.rs @@ -13,9 +13,7 @@ use anyhow::{bail, Context, Result}; use serde::{Deserialize, Serialize}; -use super::error::ConfigError; -use super::{CmdParser, ConfigCheck, UnsignedInteger}; -use crate::config::{check_arg_too_long, ExBool}; +use super::get_value_of_parameter; use util::num_ops::str_to_num; /// Basic information of pci devices such as bus number, @@ -43,30 +41,6 @@ impl Default for PciBdf { } } -/// Basic information of RootPort like port number. -#[derive(Debug, Clone)] -pub struct RootPortConfig { - pub port: u8, - pub id: String, - pub multifunction: bool, -} - -impl ConfigCheck for RootPortConfig { - fn check(&self) -> Result<()> { - check_arg_too_long(&self.id, "root_port id") - } -} - -impl Default for RootPortConfig { - fn default() -> Self { - RootPortConfig { - port: 0, - id: "".to_string(), - multifunction: false, - } - } -} - pub fn get_pci_df(addr: &str) -> Result<(u8, u8)> { let addr_vec: Vec<&str> = addr.split('.').collect(); if addr_vec.len() > 2 { @@ -96,89 +70,15 @@ pub fn get_pci_df(addr: &str) -> Result<(u8, u8)> { } pub fn get_pci_bdf(pci_cfg: &str) -> Result { - let mut cmd_parser = CmdParser::new("bdf"); - cmd_parser.push("").push("bus").push("addr"); - cmd_parser.get_parameters(pci_cfg)?; - - let mut pci_bdf = PciBdf { - bus: cmd_parser - .get_value::("bus")? - .with_context(|| "Bus not specified for pci device")?, - ..Default::default() - }; - if let Some(addr) = cmd_parser.get_value::("addr")? { - pci_bdf.addr = get_pci_df(&addr).with_context(|| "Failed to get addr")?; - } else { - bail!("No addr found for pci device"); + let bus = get_value_of_parameter("bus", pci_cfg)?; + let addr_str = get_value_of_parameter("addr", pci_cfg)?; + if addr_str.is_empty() { + bail!("Invalid addr."); } - Ok(pci_bdf) -} - -pub fn get_multi_function(pci_cfg: &str) -> Result { - let mut cmd_parser = CmdParser::new("multifunction"); - cmd_parser.push("").push("multifunction"); - cmd_parser.get_parameters(pci_cfg)?; - - if let Some(multi_func) = cmd_parser - .get_value::("multifunction") - .with_context(|| "Failed to get multifunction parameter, please set on or off (default).")? - { - return Ok(multi_func.inner); - } - - Ok(false) -} - -pub fn parse_root_port(rootport_cfg: &str) -> Result { - let mut cmd_parser = CmdParser::new("pcie-root-port"); - cmd_parser - .push("") - .push("bus") - .push("addr") - .push("port") - .push("chassis") - .push("multifunction") - .push("id"); - cmd_parser.parse(rootport_cfg)?; - - let root_port = RootPortConfig { - port: cmd_parser - .get_value::("port")? - .with_context(|| { - ConfigError::FieldIsMissing("port".to_string(), "rootport".to_string()) - })? - .0 as u8, - id: cmd_parser.get_value::("id")?.with_context(|| { - ConfigError::FieldIsMissing("id".to_string(), "rootport".to_string()) - })?, - multifunction: cmd_parser - .get_value::("multifunction")? - .map_or(false, bool::from), - }; - - let _ = cmd_parser.get_value::("chassis")?; + let addr = get_pci_df(&addr_str).with_context(|| "Failed to get addr")?; + let pci_bdf = PciBdf::new(bus, addr); - root_port.check()?; - Ok(root_port) -} - -pub fn pci_args_check(cmd_parser: &CmdParser) -> Result<()> { - let device_type = cmd_parser.get_value::("")?; - let dev_type = device_type.unwrap(); - // Safe, because this function only be called when certain - // devices type are added. - if dev_type.ends_with("-device") { - if cmd_parser.get_value::("bus")?.is_some() { - bail!("virtio mmio device does not support bus arguments"); - } - if cmd_parser.get_value::("addr")?.is_some() { - bail!("virtio mmio device does not support addr arguments"); - } - if cmd_parser.get_value::("multifunction")?.is_some() { - bail!("virtio mmio device does not support multifunction arguments"); - } - } - Ok(()) + Ok(pci_bdf) } #[cfg(test)] @@ -241,26 +141,4 @@ mod tests { let pci_bdf = get_pci_bdf("virtio-balloon-device,addr=0x1.0x2"); assert!(pci_bdf.is_err()); } - - #[test] - fn test_get_multi_function() { - assert_eq!( - get_multi_function("virtio-balloon-device,bus=pcie.0,addr=0x1.0x2").unwrap(), - false - ); - assert_eq!( - get_multi_function("virtio-balloon-device,bus=pcie.0,addr=0x1.0x2,multifunction=on") - .unwrap(), - true - ); - assert_eq!( - get_multi_function("virtio-balloon-device,bus=pcie.0,addr=0x1.0x2,multifunction=off") - .unwrap(), - false - ); - assert!(get_multi_function( - "virtio-balloon-device,bus=pcie.0,addr=0x1.0x2,multifunction=close" - ) - .is_err()); - } } diff --git a/machine_manager/src/config/rng.rs b/machine_manager/src/config/rng.rs index b153dcf2eaa5dc46e590cc2f9aec9becbc29be29..78c3ef79cc8139d662319057d37dba8decbfa5ac 100644 --- a/machine_manager/src/config/rng.rs +++ b/machine_manager/src/config/rng.rs @@ -10,243 +10,18 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use anyhow::{anyhow, bail, Context, Result}; +use clap::Parser; use serde::{Deserialize, Serialize}; -use super::error::ConfigError; -use super::pci_args_check; -use crate::config::{CmdParser, ConfigCheck, VmConfig, MAX_PATH_LENGTH}; +use crate::config::{valid_id, valid_path}; -const MIN_BYTES_PER_SEC: u64 = 64; -const MAX_BYTES_PER_SEC: u64 = 1_000_000_000; - -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct RngObjConfig { + #[arg(long, value_parser = ["rng-random"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] pub id: String, + #[arg(long, value_parser = valid_path)] pub filename: String, } - -/// Config structure for virtio-rng. -#[derive(Debug, Clone, Default)] -pub struct RngConfig { - pub id: String, - pub random_file: String, - pub bytes_per_sec: Option, -} - -impl ConfigCheck for RngConfig { - fn check(&self) -> Result<()> { - if self.id.len() > MAX_PATH_LENGTH { - return Err(anyhow!(ConfigError::StringLengthTooLong( - "rng id".to_string(), - MAX_PATH_LENGTH - ))); - } - - if self.random_file.len() > MAX_PATH_LENGTH { - return Err(anyhow!(ConfigError::StringLengthTooLong( - "rng random file".to_string(), - MAX_PATH_LENGTH, - ))); - } - - if let Some(bytes_per_sec) = self.bytes_per_sec { - if !(MIN_BYTES_PER_SEC..=MAX_BYTES_PER_SEC).contains(&bytes_per_sec) { - return Err(anyhow!(ConfigError::IllegalValue( - "The bytes per second of rng device".to_string(), - MIN_BYTES_PER_SEC, - true, - MAX_BYTES_PER_SEC, - true, - ))); - } - } - - Ok(()) - } -} - -pub fn parse_rng_dev(vm_config: &mut VmConfig, rng_config: &str) -> Result { - let mut cmd_parser = CmdParser::new("rng"); - cmd_parser - .push("") - .push("id") - .push("bus") - .push("addr") - .push("multifunction") - .push("max-bytes") - .push("period") - .push("rng"); - - cmd_parser.parse(rng_config)?; - pci_args_check(&cmd_parser)?; - let mut rng_cfg = RngConfig::default(); - let rng = cmd_parser - .get_value::("rng")? - .with_context(|| ConfigError::FieldIsMissing("rng".to_string(), "rng".to_string()))?; - - rng_cfg.id = cmd_parser.get_value::("id")?.unwrap_or_default(); - - if let Some(max) = cmd_parser.get_value::("max-bytes")? { - if let Some(peri) = cmd_parser.get_value::("period")? { - let mul = max - .checked_mul(1000) - .with_context(|| format!("Illegal max-bytes arguments: {:?}", max))?; - let div = mul - .checked_div(peri) - .with_context(|| format!("Illegal period arguments: {:?}", peri))?; - rng_cfg.bytes_per_sec = Some(div); - } else { - bail!("Argument 'period' is missing"); - } - } else if cmd_parser.get_value::("period")?.is_some() { - bail!("Argument 'max-bytes' is missing"); - } - - rng_cfg.random_file = vm_config - .object - .rng_object - .remove(&rng) - .map(|rng_object| rng_object.filename) - .with_context(|| "Object for rng-random device not found")?; - - rng_cfg.check()?; - Ok(rng_cfg) -} - -pub fn parse_rng_obj(object_args: &str) -> Result { - let mut cmd_params = CmdParser::new("rng-object"); - cmd_params.push("").push("id").push("filename"); - - cmd_params.parse(object_args)?; - let id = cmd_params - .get_value::("id")? - .with_context(|| ConfigError::FieldIsMissing("id".to_string(), "rng-object".to_string()))?; - let filename = cmd_params - .get_value::("filename")? - .with_context(|| { - ConfigError::FieldIsMissing("filename".to_string(), "rng-object".to_string()) - })?; - let rng_obj_cfg = RngObjConfig { id, filename }; - - Ok(rng_obj_cfg) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::config::get_pci_bdf; - - #[test] - fn test_rng_config_cmdline_parser_01() { - let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_object("rng-random,id=objrng0,filename=/path/to/random_file") - .is_ok()); - let rng_config = parse_rng_dev(&mut vm_config, "virtio-rng-device,rng=objrng0"); - assert!(rng_config.is_ok()); - let config = rng_config.unwrap(); - assert_eq!(config.random_file, "/path/to/random_file"); - assert_eq!(config.bytes_per_sec, None); - - let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_object("rng-random,id=objrng0,filename=/path/to/random_file") - .is_ok()); - let rng_config = parse_rng_dev( - &mut vm_config, - "virtio-rng-device,rng=objrng0,max-bytes=1234,period=1000", - ); - assert!(rng_config.is_ok()); - let config = rng_config.unwrap(); - assert_eq!(config.random_file, "/path/to/random_file"); - assert_eq!(config.bytes_per_sec, Some(1234)); - } - - #[test] - fn test_rng_config_cmdline_parser_02() { - let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_object("rng-random,id=objrng0,filename=/path/to/random_file") - .is_ok()); - let rng_config = parse_rng_dev( - &mut vm_config, - "virtio-rng-device,rng=objrng0,max-bytes=63,period=1000", - ); - assert!(rng_config.is_err()); - - let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_object("rng-random,id=objrng0,filename=/path/to/random_file") - .is_ok()); - let rng_config = parse_rng_dev( - &mut vm_config, - "virtio-rng-device,rng=objrng0,max-bytes=64,period=1000", - ); - assert!(rng_config.is_ok()); - let config = rng_config.unwrap(); - assert_eq!(config.random_file, "/path/to/random_file"); - assert_eq!(config.bytes_per_sec, Some(64)); - - let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_object("rng-random,id=objrng0,filename=/path/to/random_file") - .is_ok()); - let rng_config = parse_rng_dev( - &mut vm_config, - "virtio-rng-device,rng=objrng0,max-bytes=1000000000,period=1000", - ); - assert!(rng_config.is_ok()); - let config = rng_config.unwrap(); - assert_eq!(config.random_file, "/path/to/random_file"); - assert_eq!(config.bytes_per_sec, Some(1000000000)); - - let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_object("rng-random,id=objrng0,filename=/path/to/random_file") - .is_ok()); - let rng_config = parse_rng_dev( - &mut vm_config, - "virtio-rng-device,rng=objrng0,max-bytes=1000000001,period=1000", - ); - assert!(rng_config.is_err()); - } - - #[test] - fn test_pci_rng_config_cmdline_parser() { - let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_object("rng-random,id=objrng0,filename=/path/to/random_file") - .is_ok()); - let rng_cfg = "virtio-rng-pci,rng=objrng0,bus=pcie.0,addr=0x1.0x3"; - let rng_config = parse_rng_dev(&mut vm_config, rng_cfg); - assert!(rng_config.is_ok()); - let config = rng_config.unwrap(); - assert_eq!(config.random_file, "/path/to/random_file"); - assert_eq!(config.bytes_per_sec, None); - let pci_bdf = get_pci_bdf(rng_cfg); - assert!(pci_bdf.is_ok()); - let pci = pci_bdf.unwrap(); - assert_eq!(pci.bus, "pcie.0".to_string()); - assert_eq!(pci.addr, (1, 3)); - - // object "objrng0" has been removed. - let rng_config = parse_rng_dev(&mut vm_config, rng_cfg); - assert!(rng_config.is_err()); - - let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_object("rng-random,id=objrng0,filename=/path/to/random_file") - .is_ok()); - let rng_cfg = "virtio-rng-device,rng=objrng0,bus=pcie.0,addr=0x1.0x3"; - let rng_config = parse_rng_dev(&mut vm_config, rng_cfg); - assert!(rng_config.is_err()); - - let mut vm_config = VmConfig::default(); - assert!(vm_config - .add_object("rng-random,id=objrng0,filename=/path/to/random_file") - .is_ok()); - let rng_cfg = "virtio-rng-pci,rng=objrng0,bus=pcie.0,addr=0x1.0x3,multifunction=on"; - assert!(parse_rng_dev(&mut vm_config, rng_cfg).is_ok()); - } -} diff --git a/machine_manager/src/config/sasl_auth.rs b/machine_manager/src/config/sasl_auth.rs index 506763adc18f181bdab1e1d7f28ded404d0296aa..f01699a19bc6249f00481c7a8234ea8e54af4bcf 100644 --- a/machine_manager/src/config/sasl_auth.rs +++ b/machine_manager/src/config/sasl_auth.rs @@ -10,44 +10,33 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, Result}; +use clap::Parser; use serde::{Deserialize, Serialize}; -use crate::config::{ - ConfigError, {CmdParser, VmConfig}, -}; +use crate::config::{str_slip_to_clap, valid_id, ConfigError, VmConfig}; -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct SaslAuthObjConfig { - /// Object Id. + #[arg(long, value_parser = ["authz-simple"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] pub id: String, /// Authentication User Name. + #[arg(long, default_value = "")] pub identity: String, } impl VmConfig { pub fn add_saslauth(&mut self, saslauth_config: &str) -> Result<()> { - let mut cmd_parser = CmdParser::new("authz-simple"); - cmd_parser.push("").push("id").push("identity"); - cmd_parser.parse(saslauth_config)?; - - let mut saslauth = SaslAuthObjConfig { - id: cmd_parser.get_value::("id")?.with_context(|| { - ConfigError::FieldIsMissing("id".to_string(), "vnc sasl_auth".to_string()) - })?, - ..Default::default() - }; - - if let Some(identity) = cmd_parser.get_value::("identity")? { - saslauth.identity = identity; - } - + let saslauth = + SaslAuthObjConfig::try_parse_from(str_slip_to_clap(saslauth_config, true, false))?; let id = saslauth.id.clone(); - if self.object.sasl_object.get(&id).is_none() { - self.object.sasl_object.insert(id, saslauth); - } else { + if self.object.sasl_object.contains_key(&id) { return Err(anyhow!(ConfigError::IdRepeat("saslauth".to_string(), id))); } + self.object.sasl_object.insert(id, saslauth); Ok(()) } @@ -73,7 +62,7 @@ mod tests { assert!(vm_config.add_object("authz-simple,id=authz0").is_ok()); assert!(vm_config.object.sasl_object.get(&id).is_some()); if let Some(obj_cfg) = vm_config.object.sasl_object.get(&id) { - assert!(obj_cfg.identity == "".to_string()); + assert!(obj_cfg.identity == *""); } } } diff --git a/machine_manager/src/config/scsi.rs b/machine_manager/src/config/scsi.rs deleted file mode 100644 index b73833bcde4b6af82f14dc1eca0d7cd8c1c80a89..0000000000000000000000000000000000000000 --- a/machine_manager/src/config/scsi.rs +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright (c) 2022 Huawei Technologies Co.,Ltd. All rights reserved. -// -// StratoVirt is licensed under Mulan PSL v2. -// You can use this software according to the terms and conditions of the Mulan -// PSL v2. -// You may obtain a copy of Mulan PSL v2 at: -// http://license.coscl.org.cn/MulanPSL2 -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. -// See the Mulan PSL v2 for more details. - -use anyhow::{anyhow, bail, Context, Result}; - -use super::{error::ConfigError, pci_args_check, DiskFormat}; -use crate::config::{ - check_arg_too_long, CmdParser, ConfigCheck, VmConfig, DEFAULT_VIRTQUEUE_SIZE, MAX_VIRTIO_QUEUE, -}; -use util::aio::AioEngine; - -/// According to Virtio Spec. -/// Max_channel should be 0. -/// Max_target should be less than or equal to 255. -pub const VIRTIO_SCSI_MAX_TARGET: u16 = 255; -/// Max_lun should be less than or equal to 16383 (2^14 - 1). -pub const VIRTIO_SCSI_MAX_LUN: u16 = 16383; - -/// Only support peripheral device addressing format(8 bits for lun) in stratovirt now. -/// So, max lun id supported is 255 (2^8 - 1). -const SUPPORT_SCSI_MAX_LUN: u16 = 255; - -// Seg_max = queue_size - 2. So, size of each virtqueue for virtio-scsi should be larger than 2. -const MIN_QUEUE_SIZE_SCSI: u16 = 2; -// Max size of each virtqueue for virtio-scsi. -const MAX_QUEUE_SIZE_SCSI: u16 = 1024; - -#[derive(Debug, Clone)] -pub struct ScsiCntlrConfig { - /// Virtio-scsi-pci device id. - pub id: String, - /// Thread name of io handler. - pub iothread: Option, - /// Number of scsi cmd queues. - pub queues: u32, - /// Boot path of this scsi controller. It's prefix of scsi device's boot path. - pub boot_prefix: Option, - /// Virtqueue size for all queues. - pub queue_size: u16, -} - -impl Default for ScsiCntlrConfig { - fn default() -> Self { - ScsiCntlrConfig { - id: "".to_string(), - iothread: None, - // At least 1 cmd queue. - queues: 1, - boot_prefix: None, - queue_size: DEFAULT_VIRTQUEUE_SIZE, - } - } -} - -impl ConfigCheck for ScsiCntlrConfig { - fn check(&self) -> Result<()> { - check_arg_too_long(&self.id, "virtio-scsi-pci device id")?; - - if self.iothread.is_some() { - check_arg_too_long(self.iothread.as_ref().unwrap(), "iothread name")?; - } - - if self.queues < 1 || self.queues > MAX_VIRTIO_QUEUE as u32 { - return Err(anyhow!(ConfigError::IllegalValue( - "queues number of scsi controller".to_string(), - 1, - true, - MAX_VIRTIO_QUEUE as u64, - true, - ))); - } - - if self.queue_size <= MIN_QUEUE_SIZE_SCSI || self.queue_size > MAX_QUEUE_SIZE_SCSI { - return Err(anyhow!(ConfigError::IllegalValue( - "virtqueue size of scsi controller".to_string(), - MIN_QUEUE_SIZE_SCSI as u64, - false, - MAX_QUEUE_SIZE_SCSI as u64, - true - ))); - } - - if self.queue_size & (self.queue_size - 1) != 0 { - bail!("Virtqueue size should be power of 2!"); - } - - Ok(()) - } -} - -pub fn parse_scsi_controller( - drive_config: &str, - queues_auto: Option, -) -> Result { - let mut cmd_parser = CmdParser::new("virtio-scsi-pci"); - cmd_parser - .push("") - .push("id") - .push("bus") - .push("addr") - .push("multifunction") - .push("iothread") - .push("num-queues") - .push("queue-size"); - - cmd_parser.parse(drive_config)?; - - pci_args_check(&cmd_parser)?; - - let mut cntlr_cfg = ScsiCntlrConfig::default(); - - if let Some(iothread) = cmd_parser.get_value::("iothread")? { - cntlr_cfg.iothread = Some(iothread); - } - - cntlr_cfg.id = cmd_parser.get_value::("id")?.with_context(|| { - ConfigError::FieldIsMissing("id".to_string(), "virtio scsi pci".to_string()) - })?; - - if let Some(queues) = cmd_parser.get_value::("num-queues")? { - cntlr_cfg.queues = queues; - } else if let Some(queues) = queues_auto { - cntlr_cfg.queues = queues as u32; - } - - if let Some(size) = cmd_parser.get_value::("queue-size")? { - cntlr_cfg.queue_size = size; - } - - cntlr_cfg.check()?; - Ok(cntlr_cfg) -} - -#[derive(Clone, Debug)] -pub struct ScsiDevConfig { - /// Scsi Device id. - pub id: String, - /// The image file path. - pub path_on_host: String, - /// Serial number of the scsi device. - pub serial: Option, - /// Scsi controller which the scsi device attaches to. - pub cntlr: String, - /// Scsi device can not do write operation. - pub read_only: bool, - /// If true, use direct access io. - pub direct: bool, - /// Async IO type. - pub aio_type: AioEngine, - /// Boot order. - pub boot_index: Option, - /// Scsi four level hierarchical address(host, channel, target, lun). - pub channel: u8, - pub target: u8, - pub lun: u16, - pub format: DiskFormat, - pub l2_cache_size: Option, - pub refcount_cache_size: Option, -} - -impl Default for ScsiDevConfig { - fn default() -> Self { - ScsiDevConfig { - id: "".to_string(), - path_on_host: "".to_string(), - serial: None, - cntlr: "".to_string(), - read_only: false, - direct: true, - aio_type: AioEngine::Native, - boot_index: None, - channel: 0, - target: 0, - lun: 0, - format: DiskFormat::Raw, - l2_cache_size: None, - refcount_cache_size: None, - } - } -} - -pub fn parse_scsi_device(vm_config: &mut VmConfig, drive_config: &str) -> Result { - let mut cmd_parser = CmdParser::new("scsi-device"); - cmd_parser - .push("") - .push("id") - .push("bus") - .push("scsi-id") - .push("lun") - .push("serial") - .push("bootindex") - .push("drive"); - - cmd_parser.parse(drive_config)?; - - let mut scsi_dev_cfg = ScsiDevConfig::default(); - - let scsi_drive = cmd_parser.get_value::("drive")?.with_context(|| { - ConfigError::FieldIsMissing("drive".to_string(), "scsi device".to_string()) - })?; - - if let Some(boot_index) = cmd_parser.get_value::("bootindex")? { - scsi_dev_cfg.boot_index = Some(boot_index); - } - - if let Some(serial) = cmd_parser.get_value::("serial")? { - scsi_dev_cfg.serial = Some(serial); - } - - scsi_dev_cfg.id = cmd_parser.get_value::("id")?.with_context(|| { - ConfigError::FieldIsMissing("id".to_string(), "scsi device".to_string()) - })?; - - if let Some(bus) = cmd_parser.get_value::("bus")? { - // Format "$parent_cntlr_name.0" is required by scsi bus. - let strs = bus.split('.').collect::>(); - if strs.len() != 2 || strs[1] != "0" { - bail!("Invalid scsi bus {}", bus); - } - scsi_dev_cfg.cntlr = strs[0].to_string(); - } else { - return Err(anyhow!(ConfigError::FieldIsMissing( - "bus".to_string(), - "scsi device".to_string() - ))); - } - - if let Some(target) = cmd_parser.get_value::("scsi-id")? { - if target > VIRTIO_SCSI_MAX_TARGET as u8 { - return Err(anyhow!(ConfigError::IllegalValue( - "scsi-id of scsi device".to_string(), - 0, - true, - VIRTIO_SCSI_MAX_TARGET as u64, - true, - ))); - } - scsi_dev_cfg.target = target; - } - - if let Some(lun) = cmd_parser.get_value::("lun")? { - // Do not support Flat space addressing format(14 bits for lun) in stratovirt now. - // We now support peripheral device addressing format(8 bits for lun). - // So, MAX_LUN should be less than 255(2^8 - 1) temporarily. - if lun > SUPPORT_SCSI_MAX_LUN { - return Err(anyhow!(ConfigError::IllegalValue( - "lun of scsi device".to_string(), - 0, - true, - SUPPORT_SCSI_MAX_LUN as u64, - true, - ))); - } - scsi_dev_cfg.lun = lun; - } - - let drive_arg = &vm_config - .drives - .remove(&scsi_drive) - .with_context(|| "No drive configured matched for scsi device")?; - scsi_dev_cfg.path_on_host = drive_arg.path_on_host.clone(); - scsi_dev_cfg.read_only = drive_arg.read_only; - scsi_dev_cfg.direct = drive_arg.direct; - scsi_dev_cfg.aio_type = drive_arg.aio; - scsi_dev_cfg.format = drive_arg.format; - scsi_dev_cfg.l2_cache_size = drive_arg.l2_cache_size; - scsi_dev_cfg.refcount_cache_size = drive_arg.refcount_cache_size; - - Ok(scsi_dev_cfg) -} diff --git a/machine_manager/src/config/smbios.rs b/machine_manager/src/config/smbios.rs index 2c8f0d95d1a88cfa4488350994d1292ef9c42ff7..75220f456fbe1427d89b463d0974f00135c77e02 100644 --- a/machine_manager/src/config/smbios.rs +++ b/machine_manager/src/config/smbios.rs @@ -12,74 +12,138 @@ use std::str::FromStr; -use anyhow::{bail, Context, Result}; +use anyhow::{anyhow, bail, Result}; +use clap::Parser; use serde::{Deserialize, Serialize}; -use crate::config::{CmdParser, VmConfig}; +use super::{get_value_of_parameter, str_slip_to_clap}; +use crate::config::VmConfig; -#[derive(Clone, Default, Debug, Serialize, Deserialize)] +#[derive(Parser, Clone, Default, Debug, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct SmbiosType0Config { - pub vender: Option, + #[arg(long, alias = "type", value_parser = ["0"])] + pub smbios_type: String, + #[arg(long)] + pub vendor: Option, + #[arg(long)] pub version: Option, + #[arg(long)] pub date: Option, + // Note: we don't set `ArgAction::Append` for `added`, so it cannot be specified + // from the command line, as command line will parse errors. + #[arg(long, default_value = "true")] pub added: bool, } -#[derive(Clone, Default, Debug, Serialize, Deserialize)] +#[derive(Parser, Clone, Default, Debug, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct SmbiosType1Config { + #[arg(long, alias = "type", value_parser = ["1"])] + pub smbios_type: String, + #[arg(long)] pub manufacturer: Option, + #[arg(long)] pub product: Option, + #[arg(long)] pub version: Option, + #[arg(long)] pub serial: Option, + #[arg(long)] pub sku: Option, + #[arg(long)] pub family: Option, + #[arg(long, value_parser = get_uuid)] pub uuid: Option, + #[arg(long, default_value = "true")] pub added: bool, } -#[derive(Clone, Default, Debug, Serialize, Deserialize)] +#[derive(Parser, Clone, Default, Debug, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct SmbiosType2Config { + #[arg(long, alias = "type", value_parser = ["2"])] + pub smbios_type: String, + #[arg(long)] pub manufacturer: Option, + #[arg(long)] pub product: Option, + #[arg(long)] pub version: Option, + #[arg(long)] pub serial: Option, + #[arg(long)] pub asset: Option, + #[arg(long)] pub location: Option, + #[arg(long, default_value = "true")] pub added: bool, } -#[derive(Clone, Default, Debug, Serialize, Deserialize)] +#[derive(Parser, Clone, Default, Debug, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct SmbiosType3Config { + #[arg(long, alias = "type", value_parser = ["3"])] + pub smbios_type: String, + #[arg(long)] pub manufacturer: Option, + #[arg(long)] pub version: Option, + #[arg(long)] pub serial: Option, + #[arg(long)] pub sku: Option, + #[arg(long)] pub asset: Option, + #[arg(long, default_value = "true")] pub added: bool, } -#[derive(Clone, Default, Debug, Serialize, Deserialize)] +#[derive(Parser, Clone, Default, Debug, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct SmbiosType4Config { + #[arg(long, alias = "type", value_parser = ["4"])] + pub smbios_type: String, + #[arg(long)] pub manufacturer: Option, + #[arg(long)] pub version: Option, + #[arg(long)] pub serial: Option, + #[arg(long)] pub asset: Option, + #[arg(long, alias = "sock_pfx")] pub sock_pfx: Option, + #[arg(long)] pub part: Option, + #[arg(long)] pub max_speed: Option, + #[arg(long)] pub current_speed: Option, + #[arg(long, default_value = "true")] pub added: bool, } -#[derive(Clone, Default, Debug, Serialize, Deserialize)] +#[derive(Parser, Clone, Default, Debug, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct SmbiosType17Config { + #[arg(long, alias = "type", value_parser = ["17"])] + pub smbios_type: String, + #[arg(long)] pub manufacturer: Option, + #[arg(long)] pub serial: Option, + #[arg(long)] pub asset: Option, + #[arg(long, alias = "loc_pfx")] pub loc_pfx: Option, + #[arg(long)] pub part: Option, + #[arg(long, default_value = "0")] pub speed: u16, + #[arg(long)] pub bank: Option, + #[arg(long, default_value = "true")] pub added: bool, } @@ -124,13 +188,13 @@ pub struct Uuid { } impl FromStr for Uuid { - type Err = (); + type Err = anyhow::Error; fn from_str(str: &str) -> std::result::Result { let name = str.to_string(); if !check_valid_uuid(&name) { - return Err(()); + return Err(anyhow!("Invalid uuid {}", name)); } let mut uuid_bytes = Vec::new(); @@ -149,6 +213,11 @@ impl FromStr for Uuid { } } +fn get_uuid(s: &str) -> Result { + let uuid = Uuid::from_str(s)?; + Ok(uuid) +} + impl VmConfig { /// # Arguments /// @@ -158,19 +227,8 @@ impl VmConfig { bail!("smbios type0 has been added"); } - let mut cmd_parser = CmdParser::new("smbios"); - cmd_parser - .push("") - .push("type") - .push("vendor") - .push("version") - .push("date"); - cmd_parser.parse(type0)?; - - self.smbios.type0.vender = cmd_parser.get_value::("vendor")?; - self.smbios.type0.version = cmd_parser.get_value::("version")?; - self.smbios.type0.date = cmd_parser.get_value::("date")?; - self.smbios.type0.added = true; + let type0_cfg = SmbiosType0Config::try_parse_from(str_slip_to_clap(type0, false, false))?; + self.smbios.type0 = type0_cfg; Ok(()) } @@ -183,27 +241,8 @@ impl VmConfig { bail!("smbios type1 has been added"); } - let mut cmd_parser = CmdParser::new("smbios"); - cmd_parser - .push("") - .push("type") - .push("manufacturer") - .push("product") - .push("version") - .push("serial") - .push("sku") - .push("uuid") - .push("family"); - cmd_parser.parse(type1)?; - - self.smbios.type1.manufacturer = cmd_parser.get_value::("manufacturer")?; - self.smbios.type1.product = cmd_parser.get_value::("product")?; - self.smbios.type1.version = cmd_parser.get_value::("version")?; - self.smbios.type1.serial = cmd_parser.get_value::("serial")?; - self.smbios.type1.sku = cmd_parser.get_value::("sku")?; - self.smbios.type1.family = cmd_parser.get_value::("family")?; - self.smbios.type1.uuid = cmd_parser.get_value::("uuid")?; - self.smbios.type1.added = true; + let type1_cfg = SmbiosType1Config::try_parse_from(str_slip_to_clap(type1, false, false))?; + self.smbios.type1 = type1_cfg; Ok(()) } @@ -215,26 +254,8 @@ impl VmConfig { if self.smbios.type2.added { bail!("smbios type2 has been added"); } - - let mut cmd_parser = CmdParser::new("smbios"); - cmd_parser - .push("") - .push("type") - .push("manufacturer") - .push("product") - .push("version") - .push("serial") - .push("asset") - .push("location"); - cmd_parser.parse(type2)?; - - self.smbios.type2.manufacturer = cmd_parser.get_value::("manufacturer")?; - self.smbios.type2.product = cmd_parser.get_value::("product")?; - self.smbios.type2.version = cmd_parser.get_value::("version")?; - self.smbios.type2.serial = cmd_parser.get_value::("serial")?; - self.smbios.type2.asset = cmd_parser.get_value::("asset")?; - self.smbios.type2.location = cmd_parser.get_value::("location")?; - self.smbios.type2.added = true; + let type2_cfg = SmbiosType2Config::try_parse_from(str_slip_to_clap(type2, false, false))?; + self.smbios.type2 = type2_cfg; Ok(()) } @@ -247,23 +268,8 @@ impl VmConfig { bail!("smbios type3 has been added"); } - let mut cmd_parser = CmdParser::new("smbios"); - cmd_parser - .push("") - .push("type") - .push("manufacturer") - .push("version") - .push("serial") - .push("sku") - .push("asset"); - cmd_parser.parse(type3)?; - - self.smbios.type3.manufacturer = cmd_parser.get_value::("manufacturer")?; - self.smbios.type3.version = cmd_parser.get_value::("version")?; - self.smbios.type3.serial = cmd_parser.get_value::("serial")?; - self.smbios.type3.sku = cmd_parser.get_value::("sku")?; - self.smbios.type3.asset = cmd_parser.get_value::("asset")?; - self.smbios.type3.added = true; + let type3_cfg = SmbiosType3Config::try_parse_from(str_slip_to_clap(type3, false, false))?; + self.smbios.type3 = type3_cfg; Ok(()) } @@ -276,29 +282,8 @@ impl VmConfig { bail!("smbios type4 has been added"); } - let mut cmd_parser = CmdParser::new("smbios"); - cmd_parser - .push("") - .push("type") - .push("manufacturer") - .push("version") - .push("serial") - .push("sock_pfx") - .push("max-speed") - .push("current-speed") - .push("part") - .push("asset"); - cmd_parser.parse(type4)?; - - self.smbios.type4.manufacturer = cmd_parser.get_value::("manufacturer")?; - self.smbios.type4.version = cmd_parser.get_value::("version")?; - self.smbios.type4.serial = cmd_parser.get_value::("serial")?; - self.smbios.type4.asset = cmd_parser.get_value::("asset")?; - self.smbios.type4.part = cmd_parser.get_value::("part")?; - self.smbios.type4.sock_pfx = cmd_parser.get_value::("sock_pfx")?; - self.smbios.type4.max_speed = cmd_parser.get_value::("max-speed")?; - self.smbios.type4.current_speed = cmd_parser.get_value::("current-speed")?; - self.smbios.type4.added = true; + let type4_cfg = SmbiosType4Config::try_parse_from(str_slip_to_clap(type4, false, false))?; + self.smbios.type4 = type4_cfg; Ok(()) } @@ -311,31 +296,9 @@ impl VmConfig { bail!("smbios type17 has been added"); } - let mut cmd_parser = CmdParser::new("smbios"); - cmd_parser - .push("") - .push("type") - .push("loc_pfx") - .push("bank") - .push("manufacturer") - .push("serial") - .push("speed") - .push("part") - .push("asset"); - cmd_parser.parse(type17)?; - - self.smbios.type17.manufacturer = cmd_parser.get_value::("manufacturer")?; - self.smbios.type17.loc_pfx = cmd_parser.get_value::("loc_pfx")?; - self.smbios.type17.serial = cmd_parser.get_value::("serial")?; - self.smbios.type17.asset = cmd_parser.get_value::("asset")?; - self.smbios.type17.part = cmd_parser.get_value::("part")?; - self.smbios.type17.speed = if let Some(speed) = cmd_parser.get_value::("speed")? { - speed - } else { - 0 - }; - self.smbios.type17.bank = cmd_parser.get_value::("bank")?; - self.smbios.type17.added = true; + let type17_cfg = + SmbiosType17Config::try_parse_from(str_slip_to_clap(type17, false, false))?; + self.smbios.type17 = type17_cfg; Ok(()) } @@ -346,13 +309,7 @@ impl VmConfig { /// /// * `smbios_args` - The args of object. pub fn add_smbios(&mut self, smbios_args: &str) -> Result<()> { - let mut cmd_params = CmdParser::new("smbios"); - cmd_params.push("").push("type"); - - cmd_params.get_parameters(smbios_args)?; - let smbios_type = cmd_params - .get_value::("type")? - .with_context(|| "smbios type not specified")?; + let smbios_type = get_value_of_parameter("type", smbios_args)?; match smbios_type.as_str() { "0" => { self.add_smbios_type0(smbios_args)?; @@ -397,4 +354,24 @@ mod test { ] ); } + + #[test] + fn test_add_smbios() { + let mut vm_config = VmConfig::default(); + + let smbios0 = "type=0,vendor=fake,version=fake,date=fake"; + let smbios1 = "type=1,manufacturer=fake,version=fake,product=fake,serial=fake,uuid=33DB4D5E-1FF7-401C-9657-7441C03DD766,sku=fake,family=fake"; + let smbios2 = "type=2,manufacturer=fake,product=fake,version=fake,serial=fake,asset=fake,location=fake"; + let smbios3 = "type=3,manufacturer=fake,version=fake,serial=fake,asset=fake,sku=fake"; + let smbios4 = "type=4,sock_pfx=fake,manufacturer=fake,version=fake,serial=fake,asset=fake,part=fake,max-speed=1,current-speed=1"; + let smbios17 = "type=17,loc_pfx=fake,bank=fake,manufacturer=fake,serial=fake,asset=fake,part=fake,speed=1"; + + assert!(vm_config.add_smbios(smbios0).is_ok()); + assert!(vm_config.add_smbios(smbios1).is_ok()); + assert!(vm_config.add_smbios(smbios2).is_ok()); + assert!(vm_config.add_smbios(smbios3).is_ok()); + assert!(vm_config.add_smbios(smbios4).is_ok()); + assert!(vm_config.add_smbios(smbios17).is_ok()); + assert!(vm_config.add_smbios(smbios0).is_err()); + } } diff --git a/machine_manager/src/config/tls_creds.rs b/machine_manager/src/config/tls_creds.rs index 8803ea42893d31782d77494c30458967f0a7c649..a3b7396c07d0ed789c5f2e5ec7392585d4af0e6f 100644 --- a/machine_manager/src/config/tls_creds.rs +++ b/machine_manager/src/config/tls_creds.rs @@ -10,64 +10,36 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use std::path::Path; - -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, Result}; +use clap::{ArgAction, Parser}; use serde::{Deserialize, Serialize}; -use crate::config::{ - ConfigError, {CmdParser, VmConfig}, -}; +use crate::config::{str_slip_to_clap, valid_dir, valid_id, ConfigError, VmConfig}; -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct TlsCredObjConfig { + #[arg(long)] + pub classtype: String, + #[arg(long, value_parser = valid_id)] pub id: String, + #[arg(long, value_parser = valid_dir)] pub dir: String, - pub cred_type: String, + #[arg(long)] pub endpoint: Option, + #[arg(long, alias = "verify-peer", default_value= "false", action = ArgAction::Append)] pub verifypeer: bool, } impl VmConfig { pub fn add_tlscred(&mut self, tlscred_config: &str) -> Result<()> { - let mut cmd_parser = CmdParser::new("tls-creds-x509"); - cmd_parser - .push("") - .push("id") - .push("dir") - .push("endpoint") - .push("verify-peer"); - cmd_parser.parse(tlscred_config)?; - - let mut tlscred = TlsCredObjConfig { - id: cmd_parser.get_value::("id")?.with_context(|| { - ConfigError::FieldIsMissing("id".to_string(), "vnc tls_creds".to_string()) - })?, - ..Default::default() - }; - - if let Some(dir) = cmd_parser.get_value::("dir")? { - if Path::new(&dir).is_dir() { - tlscred.dir = dir; - } else { - return Err(anyhow!(ConfigError::DirNotExist(dir))); - } - } - if let Some(endpoint) = cmd_parser.get_value::("endpoint")? { - tlscred.endpoint = Some(endpoint); - } - if let Some(verifypeer) = cmd_parser.get_value::("verify-peer")? { - tlscred.verifypeer = verifypeer == *"true"; - } - tlscred.cred_type = "x509".to_string(); - + let tlscred = + TlsCredObjConfig::try_parse_from(str_slip_to_clap(tlscred_config, true, false))?; let id = tlscred.id.clone(); - if self.object.tls_object.get(&id).is_none() { - self.object.tls_object.insert(id, tlscred); - } else { + if self.object.tls_object.contains_key(&id) { return Err(anyhow!(ConfigError::IdRepeat("tlscred".to_string(), id))); } - + self.object.tls_object.insert(id, tlscred); Ok(()) } } @@ -86,7 +58,7 @@ mod tests { if !dir.is_dir() { fs::create_dir(dir.clone()).unwrap(); } - assert_eq!(dir.is_dir(), true); + assert!(dir.is_dir()); // Certificate directory is exist. let tls_config: String = format!( @@ -100,12 +72,12 @@ mod tests { if let Some(tls_cred_cfg) = vm_config.object.tls_object.get(&id) { assert_eq!(tls_cred_cfg.dir, dir.to_str().unwrap()); assert_eq!(tls_cred_cfg.endpoint, Some("server".to_string())); - assert_eq!(tls_cred_cfg.verifypeer, false); + assert!(!tls_cred_cfg.verifypeer); } // Delete file. fs::remove_dir(dir.clone()).unwrap(); - assert_eq!(dir.is_dir(), false); + assert!(!dir.is_dir()); // Certificate directory does not exist. let mut vm_config = VmConfig::default(); assert!(vm_config.add_object(tls_config.as_str()).is_err()); diff --git a/machine_manager/src/config/usb.rs b/machine_manager/src/config/usb.rs deleted file mode 100644 index da363b7e43a835ef3de092292fe03196bc90c2f8..0000000000000000000000000000000000000000 --- a/machine_manager/src/config/usb.rs +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) 2022 Huawei Technologies Co.,Ltd. All rights reserved. -// -// StratoVirt is licensed under Mulan PSL v2. -// You can use this software according to the terms and conditions of the Mulan -// PSL v2. -// You may obtain a copy of Mulan PSL v2 at: -// http://license.coscl.org.cn/MulanPSL2 -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. -// See the Mulan PSL v2 for more details. - -use anyhow::{bail, Context, Result}; - -use super::error::ConfigError; -use crate::config::{ - check_arg_nonexist, check_arg_too_long, CmdParser, ConfigCheck, ScsiDevConfig, VmConfig, -}; -use util::aio::AioEngine; - -pub fn check_id(id: Option, device: &str) -> Result<()> { - check_arg_nonexist(id.clone(), "id", device)?; - check_arg_too_long(&id.unwrap(), "id")?; - - Ok(()) -} - -#[derive(Clone, Debug)] -pub struct UsbStorageConfig { - /// USB Storage device id. - pub id: Option, - /// The scsi backend config. - pub scsi_cfg: ScsiDevConfig, - /// The backend scsi device type(Disk or CD-ROM). - pub media: String, -} - -impl UsbStorageConfig { - fn new() -> Self { - Self { - id: None, - scsi_cfg: ScsiDevConfig::default(), - media: "".to_string(), - } - } -} - -impl Default for UsbStorageConfig { - fn default() -> Self { - Self::new() - } -} - -impl ConfigCheck for UsbStorageConfig { - fn check(&self) -> Result<()> { - check_id(self.id.clone(), "usb-storage")?; - - if self.scsi_cfg.aio_type != AioEngine::Off || self.scsi_cfg.direct { - bail!("USB-storage: \"aio=off,direct=false\" must be configured."); - } - - Ok(()) - } -} - -pub fn parse_usb_storage(vm_config: &mut VmConfig, drive_config: &str) -> Result { - let mut cmd_parser = CmdParser::new("usb-storage"); - cmd_parser - .push("") - .push("id") - .push("bus") - .push("port") - .push("drive"); - - cmd_parser.parse(drive_config)?; - - let mut dev = UsbStorageConfig::new(); - dev.id = cmd_parser.get_value::("id")?; - - let storage_drive = cmd_parser.get_value::("drive")?.with_context(|| { - ConfigError::FieldIsMissing("drive".to_string(), "usb storage device".to_string()) - })?; - - let drive_arg = &vm_config - .drives - .remove(&storage_drive) - .with_context(|| "No drive configured matched for usb storage device.")?; - dev.scsi_cfg.path_on_host = drive_arg.path_on_host.clone(); - dev.scsi_cfg.read_only = drive_arg.read_only; - dev.scsi_cfg.aio_type = drive_arg.aio; - dev.scsi_cfg.direct = drive_arg.direct; - dev.scsi_cfg.format = drive_arg.format; - dev.scsi_cfg.l2_cache_size = drive_arg.l2_cache_size; - dev.scsi_cfg.refcount_cache_size = drive_arg.refcount_cache_size; - dev.media = drive_arg.media.clone(); - - dev.check()?; - Ok(dev) -} diff --git a/machine_manager/src/config/vfio.rs b/machine_manager/src/config/vfio.rs deleted file mode 100644 index dddebde74ed9e358f3444b7bc0c44e663b411ee4..0000000000000000000000000000000000000000 --- a/machine_manager/src/config/vfio.rs +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright (c) 2020 Huawei Technologies Co.,Ltd. All rights reserved. -// -// StratoVirt is licensed under Mulan PSL v2. -// You can use this software according to the terms and conditions of the Mulan -// PSL v2. -// You may obtain a copy of Mulan PSL v2 at: -// http://license.coscl.org.cn/MulanPSL2 -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. -// See the Mulan PSL v2 for more details. - -use anyhow::{anyhow, Result}; - -use super::error::ConfigError; -use crate::config::{check_arg_too_long, CmdParser, ConfigCheck}; - -#[derive(Default, Debug)] -pub struct VfioConfig { - pub sysfsdev: String, - pub host: String, - pub id: String, -} - -impl ConfigCheck for VfioConfig { - fn check(&self) -> Result<()> { - check_arg_too_long(&self.host, "host")?; - check_arg_too_long(&self.id, "id")?; - - Ok(()) - } -} - -pub fn parse_vfio(vfio_config: &str) -> Result { - let mut cmd_parser = CmdParser::new("vfio-pci"); - cmd_parser - .push("") - .push("host") - .push("sysfsdev") - .push("id") - .push("bus") - .push("addr") - .push("multifunction"); - cmd_parser.parse(vfio_config)?; - - let mut vfio: VfioConfig = VfioConfig::default(); - if let Some(host) = cmd_parser.get_value::("host")? { - vfio.host = host; - } - - if let Some(sysfsdev) = cmd_parser.get_value::("sysfsdev")? { - vfio.sysfsdev = sysfsdev; - } - - if vfio.host.is_empty() && vfio.sysfsdev.is_empty() { - return Err(anyhow!(ConfigError::FieldIsMissing( - "host nor sysfsdev".to_string(), - "vfio".to_string() - ))); - } - - if !vfio.host.is_empty() && !vfio.sysfsdev.is_empty() { - return Err(anyhow!(ConfigError::InvalidParam( - "host and sysfsdev".to_string(), - "vfio".to_string() - ))); - } - - if let Some(id) = cmd_parser.get_value::("id")? { - vfio.id = id; - } - vfio.check()?; - - Ok(vfio) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::config::get_pci_bdf; - - #[test] - fn test_check_vfio_config() { - let mut vfio_config = - parse_vfio("vfio-pci,host=0000:1a:00.3,id=net,bus=pcie.0,addr=0x1.0x2").unwrap(); - assert!(vfio_config.check().is_ok()); - - vfio_config.host = "IYqUdAMXggoUMU28eBJCxQGUirYYSyW1cfGJI3ZpZAzMFCKnVPA5e7gnurLtXjCm\ - YoG5pfqRDbN7M2dpSd8fzSbufAJaor8UY9xbH7BybZ7WDEFmkxgCQp6PWgaBSmLOCe1tEMs4RQ938ZLnh8ej\ - Q81VovbrU7ecafacCn9AJQoidN3Seab3QOEd4SJbtd4hAPeYvsXLVa6xOZxtVjqjRxk9b36feF0C5JrucVcs\ - QsusZZtVfUFUZxOoV8JltVsBmdasnic" - .to_string(); - assert!(vfio_config.check().is_err()); - - vfio_config.id = "LPwM1h4QUTCjL4fX2gFdCdPrF9S0kGHf0onpU6E4fyI6Jmzg0DCM9sffvEVjaVu1ilp\ - 2OrgCWzvNBflYvUUihPj3ePPYs3erSHmSOmQZbnGEFsiBSTJHfPAsRtWJoipeIh9cgIR1tnU3OjwPPli4gmb6\ - E6GgSyMd0oQtUGFyNf5pRHlYqlx3s7PMPVUtRJP0bBnNd5eDwWAotInu33h6UI0zfKgckAxeVdEROKAExx5xWK\ - V3AgPhvvPzFx3chYymy" - .to_string(); - assert!(vfio_config.check().is_err()); - } - - #[test] - fn test_vfio_config_cmdline_parser() { - let vfio_cfg = parse_vfio("vfio-pci,host=0000:1a:00.3,id=net"); - assert!(vfio_cfg.is_ok()); - let vfio_config = vfio_cfg.unwrap(); - assert_eq!(vfio_config.host, "0000:1a:00.3"); - assert_eq!(vfio_config.id, "net"); - } - - #[test] - fn test_pci_vfio_config_cmdline_parser() { - let vfio_cfg1 = "vfio-pci,host=0000:1a:00.3,id=net,bus=pcie.0,addr=0x1.0x2"; - let config1 = parse_vfio(vfio_cfg1); - assert!(config1.is_ok()); - let vfio_cfg2 = "vfio-pci,host=0000:1a:00.3,bus=pcie.0,addr=0x1.0x2"; - let config2 = parse_vfio(vfio_cfg2); - assert!(config2.is_ok()); - let vfio_cfg3 = "vfio-pci,id=net,bus=pcie.0,addr=0x1.0x2"; - let config3 = parse_vfio(vfio_cfg3); - assert!(config3.is_err()); - - let pci_bdf = get_pci_bdf(vfio_cfg1); - assert!(pci_bdf.is_ok()); - let pci = pci_bdf.unwrap(); - assert_eq!(pci.bus, "pcie.0".to_string()); - assert_eq!(pci.addr, (1, 2)); - - let vfio_cfg1 = - "vfio-pci,host=0000:1a:00.3,id=net,bus=pcie.0,addr=0x1.0x2,multifunction=on"; - assert!(parse_vfio(vfio_cfg1).is_ok()); - } -} diff --git a/machine_manager/src/config/vnc.rs b/machine_manager/src/config/vnc.rs index b243d945464b7752a85c01e5906847238ae766bb..f257ae07c56315a420a40fd53526224b4d4f9a4e 100644 --- a/machine_manager/src/config/vnc.rs +++ b/machine_manager/src/config/vnc.rs @@ -13,22 +13,26 @@ use std::net::Ipv4Addr; use anyhow::{anyhow, Context, Result}; +use clap::{ArgAction, Parser}; use serde::{Deserialize, Serialize}; -use crate::config::{CmdParser, ConfigError, VmConfig}; +use crate::config::{str_slip_to_clap, ConfigError, VmConfig}; /// Configuration of vnc. -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +#[command(no_binary_name(true))] pub struct VncConfig { - /// Listening ip. - pub ip: String, - /// Listening port. - pub port: String, + /// Vnc listening addr (ip, port). + #[arg(long, alias = "classtype", value_parser = parse_ip_port)] + pub addr: (String, u16), /// Configuration of encryption. + #[arg(long, alias = "tls-creds", default_value = "")] pub tls_creds: String, /// Authentication switch. + #[arg(long, default_value = "false", action = ArgAction::SetTrue)] pub sasl: bool, /// Configuration of authentication. + #[arg(long, alias = "sasl-authz", default_value = "")] pub sasl_authz: String, } @@ -38,45 +42,13 @@ const VNC_PORT_OFFSET: i32 = 5900; impl VmConfig { /// Make configuration for vnc: "chardev" -> "vnc". pub fn add_vnc(&mut self, vnc_config: &str) -> Result<()> { - let mut cmd_parser = CmdParser::new("vnc"); - cmd_parser - .push("") - .push("tls-creds") - .push("sasl") - .push("sasl-authz"); - cmd_parser.parse(vnc_config)?; - - let mut vnc_config = VncConfig::default(); - // Parse Ip:Port. - if let Some(addr) = cmd_parser.get_value::("")? { - parse_port(&mut vnc_config, addr)?; - } else { - return Err(anyhow!(ConfigError::FieldIsMissing( - "ip".to_string(), - "port".to_string() - ))); - } - - // VNC Security Type. - if let Some(tls_creds) = cmd_parser.get_value::("tls-creds")? { - vnc_config.tls_creds = tls_creds - } - if let Some(_sasl) = cmd_parser.get_value::("sasl")? { - vnc_config.sasl = true - } else { - vnc_config.sasl = false - } - if let Some(sasl_authz) = cmd_parser.get_value::("sasl-authz")? { - vnc_config.sasl_authz = sasl_authz; - } - + let vnc_config = VncConfig::try_parse_from(str_slip_to_clap(vnc_config, true, false))?; self.vnc = Some(vnc_config); Ok(()) } } -/// Parse Ip:port. -fn parse_port(vnc_config: &mut VncConfig, addr: String) -> Result<()> { +fn parse_ip_port(addr: &str) -> Result<(String, u16)> { let v: Vec<&str> = addr.split(':').collect(); if v.len() != 2 { return Err(anyhow!(ConfigError::FieldIsMissing( @@ -97,10 +69,8 @@ fn parse_port(vnc_config: &mut VncConfig, addr: String) -> Result<()> { "port".to_string() ))); } - vnc_config.ip = ip.to_string(); - vnc_config.port = ((base_port + VNC_PORT_OFFSET) as u16).to_string(); - Ok(()) + Ok((ip.to_string(), (base_port + VNC_PORT_OFFSET) as u16)) } #[cfg(test)] @@ -113,18 +83,18 @@ mod tests { let config_line = "0.0.0.0:1,tls-creds=vnc-tls-creds0,sasl,sasl-authz=authz0"; assert!(vm_config.add_vnc(config_line).is_ok()); let vnc_config = vm_config.vnc.unwrap(); - assert_eq!(vnc_config.ip, String::from("0.0.0.0")); - assert_eq!(vnc_config.port, String::from("5901")); + assert_eq!(vnc_config.addr.0, String::from("0.0.0.0")); + assert_eq!(vnc_config.addr.1, 5901); assert_eq!(vnc_config.tls_creds, String::from("vnc-tls-creds0")); - assert_eq!(vnc_config.sasl, true); + assert!(vnc_config.sasl); assert_eq!(vnc_config.sasl_authz, String::from("authz0")); let mut vm_config = VmConfig::default(); let config_line = "0.0.0.0:5900,tls-creds=vnc-tls-creds0"; assert!(vm_config.add_vnc(config_line).is_ok()); let vnc_config = vm_config.vnc.unwrap(); - assert_eq!(vnc_config.sasl, false); - assert_eq!(vnc_config.port, String::from("11800")); + assert!(!vnc_config.sasl); + assert_eq!(vnc_config.addr.1, 11800); let mut vm_config = VmConfig::default(); let config_line = "0.0.0.0:1,sasl,sasl-authz=authz0"; diff --git a/machine_manager/src/event_loop.rs b/machine_manager/src/event_loop.rs index 7acaca84a91433f6ac736a887b979db2482aecb9..c3eb2999c254ab051d10372621148fe8a82e07b3 100644 --- a/machine_manager/src/event_loop.rs +++ b/machine_manager/src/event_loop.rs @@ -12,16 +12,17 @@ use std::collections::HashMap; use std::os::unix::prelude::RawFd; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Barrier, Mutex}; use std::{process, thread}; use anyhow::{bail, Result}; -use log::info; +use log::{error, info}; use super::config::IothreadConfig; use crate::machine::IOTHREADS; use crate::qmp::qmp_schema::IothreadInfo; use crate::signal_handler::get_signal; +use crate::temp_cleaner::TempCleaner; use util::loop_context::{ gen_delete_notifiers, get_notifiers_fds, EventLoopContext, EventLoopManager, EventNotifier, }; @@ -49,9 +50,18 @@ impl EventLoop { /// * `iothreads` - refer to `-iothread` params pub fn object_init(iothreads: &Option>) -> Result<()> { let mut io_threads = HashMap::new(); + let cnt = match iothreads { + Some(thrs) => thrs.len(), + None => 0, + }; + let thread_exit_barrier = Arc::new(Barrier::new(cnt + 1)); + if let Some(thrs) = iothreads { for thr in thrs { - io_threads.insert(thr.id.clone(), EventLoopContext::new()); + io_threads.insert( + thr.id.clone(), + EventLoopContext::new(thread_exit_barrier.clone()), + ); } } @@ -60,7 +70,7 @@ impl EventLoop { unsafe { if GLOBAL_EVENT_LOOP.is_none() { GLOBAL_EVENT_LOOP = Some(EventLoop { - main_loop: EventLoopContext::new(), + main_loop: EventLoopContext::new(thread_exit_barrier), io_threads, }); @@ -75,11 +85,17 @@ impl EventLoop { id: id.to_string(), }; IOTHREADS.lock().unwrap().push(iothread_info); - while let Ok(ret) = ctx.iothread_run() { - if !ret { + while ctx.iothread_run().is_ok() { + // If is_cleaned() is true, it means the main thread will exit. + // So, exit the iothread. + if TempCleaner::is_cleaned() { break; } } + if let Err(e) = ctx.clean_event_loop() { + error!("Failed to clean event loop {:?}", e); + } + ctx.thread_exit_barrier.wait(); })?; } } else { @@ -115,11 +131,16 @@ impl EventLoop { /// /// # Arguments /// - /// * `manager` - The main part to manager the event loop specified by name. - /// * `name` - specify which event loop to manage - pub fn set_manager(manager: Arc>, name: Option<&String>) { - if let Some(ctx) = Self::get_ctx(name) { - ctx.set_manager(manager) + /// * `manager` - The main part to manager the event loop. + pub fn set_manager(manager: Arc>) { + // SAFETY: All concurrently accessed data of EventLoopContext is protected. + unsafe { + if let Some(event_loop) = GLOBAL_EVENT_LOOP.as_mut() { + event_loop.main_loop.set_manager(manager.clone()); + for (_name, io_thread) in event_loop.io_threads.iter_mut() { + io_thread.set_manager(manager.clone()); + } + } } } @@ -166,12 +187,27 @@ impl EventLoop { } pub fn loop_clean() { + EventLoop::kick_iothreads(); // SAFETY: the main_loop ctx is dedicated for main thread, thus no concurrent // accessing. unsafe { + if let Some(event_loop) = GLOBAL_EVENT_LOOP.as_mut() { + event_loop.main_loop.thread_exit_barrier.wait(); + } GLOBAL_EVENT_LOOP = None; } } + + pub fn kick_iothreads() { + // SAFETY: All concurrently accessed data of EventLoopContext is protected. + unsafe { + if let Some(event_loop) = GLOBAL_EVENT_LOOP.as_mut() { + for (_name, io_thread) in event_loop.io_threads.iter_mut() { + io_thread.kick(); + } + } + } + } } pub fn register_event_helper( diff --git a/machine_manager/src/lib.rs b/machine_manager/src/lib.rs index 78c6a77e9d293a803f167c2a918fc6cbcde79dd0..cab6b880623bf93669ad786f016f0764dd7231fb 100644 --- a/machine_manager/src/lib.rs +++ b/machine_manager/src/lib.rs @@ -26,9 +26,11 @@ pub mod config; pub mod error; pub mod event_loop; pub mod machine; +pub mod notifier; pub mod qmp; pub mod signal_handler; pub mod socket; +pub mod state_query; pub mod temp_cleaner; pub mod test_server; diff --git a/machine_manager/src/machine.rs b/machine_manager/src/machine.rs index ff4d858633af21cc82bf4809ff4e149c21e15520..bc4a38ec9ecc012d1745503eb590d55ae3a8bf71 100644 --- a/machine_manager/src/machine.rs +++ b/machine_manager/src/machine.rs @@ -14,6 +14,7 @@ use std::os::unix::io::RawFd; use std::str::FromStr; use std::sync::Mutex; +use anyhow::anyhow; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use strum::VariantNames; @@ -54,13 +55,14 @@ pub enum HypervisorType { } impl FromStr for HypervisorType { - type Err = (); + type Err = anyhow::Error; fn from_str(s: &str) -> std::result::Result { match s { - "kvm" => Ok(HypervisorType::Kvm), + // Note: "kvm:tcg" is a configuration compatible with libvirt. + "kvm" | "kvm:tcg" => Ok(HypervisorType::Kvm), "test" => Ok(HypervisorType::Test), - _ => Err(()), + _ => Err(anyhow!("Not supported or invalid hypervisor type {}.", s)), } } } @@ -235,6 +237,14 @@ pub trait DeviceInterface { /// Query display of stratovirt. fn query_display_image(&self) -> Response; + /// Query state. + fn query_workloads(&self) -> Response { + Response::create_error_response( + QmpErrorClass::GenericError("query_workloads not supported for VM".to_string()), + None, + ) + } + /// Set balloon's size. fn balloon(&self, size: u64) -> Response; diff --git a/machine_manager/src/notifier.rs b/machine_manager/src/notifier.rs new file mode 100644 index 0000000000000000000000000000000000000000..36285f36c5337d981a83458e925cb997e2abbb13 --- /dev/null +++ b/machine_manager/src/notifier.rs @@ -0,0 +1,73 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; + +use log::error; +use once_cell::sync::Lazy; + +static NOTIFIER_MANAGER: Lazy> = + Lazy::new(|| RwLock::new(NotifierManager::new())); + +pub type PauseNOtifyCallback = dyn Fn(bool) + Send + Sync; + +struct NotifierManager { + pause_notifiers: HashMap>, + next_id: u64, +} + +impl NotifierManager { + fn new() -> Self { + Self { + pause_notifiers: HashMap::new(), + next_id: 1, + } + } + + fn register_pause_notifier(&mut self, notifier: Arc) -> u64 { + let id = self.next_id; + self.pause_notifiers.insert(id, notifier); + self.next_id += 1; + id + } + + fn unregister_pause_notifier(&mut self, id: u64) { + if self.pause_notifiers.remove(&id).is_none() { + error!("There is no pause notifier with id {}", id); + } + } + + fn pause_notify(&self, paused: bool) { + for (_, notify) in self.pause_notifiers.iter() { + notify(paused); + } + } +} + +pub fn register_vm_pause_notifier(notifier: Arc) -> u64 { + NOTIFIER_MANAGER + .write() + .unwrap() + .register_pause_notifier(notifier) +} + +pub fn unregister_vm_pause_notifier(id: u64) { + NOTIFIER_MANAGER + .write() + .unwrap() + .unregister_pause_notifier(id) +} + +pub fn pause_notify(paused: bool) { + NOTIFIER_MANAGER.read().unwrap().pause_notify(paused); +} diff --git a/machine_manager/src/qmp/qmp_channel.rs b/machine_manager/src/qmp/qmp_channel.rs index a0839501b622ee6ba31b55397fff8c39c9fe2844..51063331fa5e940dda2bc50c2018b5cb099eddb1 100644 --- a/machine_manager/src/qmp/qmp_channel.rs +++ b/machine_manager/src/qmp/qmp_channel.rs @@ -73,7 +73,7 @@ pub fn create_timestamp() -> TimeStamp { .expect("Time went backwards"); let seconds = u128::from(since_the_epoch.as_secs()); let microseconds = - (since_the_epoch.as_nanos() - seconds * (NANOSECONDS_PER_SECOND as u128)) / (1_000_u128); + (since_the_epoch.as_nanos() - seconds * u128::from(NANOSECONDS_PER_SECOND)) / (1_000_u128); TimeStamp { seconds: seconds as u64, microseconds: microseconds as u64, diff --git a/machine_manager/src/qmp/qmp_response.rs b/machine_manager/src/qmp/qmp_response.rs index 54fe538a3298ae2cc6fae55453b17a41ee6548a3..bbf2d41ed45652075da7044825b0c28d376fae3a 100644 --- a/machine_manager/src/qmp/qmp_response.rs +++ b/machine_manager/src/qmp/qmp_response.rs @@ -226,7 +226,7 @@ mod tests { running: true, status: qmp_schema::RunState::running, }; - let resp = Response::create_response(serde_json::to_value(&resp_value).unwrap(), None); + let resp = Response::create_response(serde_json::to_value(resp_value).unwrap(), None); let json_msg = r#"{"return":{"running":true,"singlestep":false,"status":"running"}}"#; assert_eq!(serde_json::to_string(&resp).unwrap(), json_msg); @@ -273,7 +273,7 @@ mod tests { let msg = ErrorMessage::new(&err_cls); assert_eq!(msg.desc, strange_msg); assert_eq!(msg.errorkind, "KVMMissingCap".to_string()); - let qmp_err = qmp_schema::QmpErrorClass::KVMMissingCap(strange_msg.clone()); + let qmp_err = qmp_schema::QmpErrorClass::KVMMissingCap(strange_msg); let resp = Response::create_error_response(qmp_err, None); assert_eq!(resp.error, Some(msg)); } diff --git a/machine_manager/src/qmp/qmp_schema.rs b/machine_manager/src/qmp/qmp_schema.rs index 4281624fc906da3747d275689fa78bafaf3bc464..5d5558d985ad2754a68a25e930cc879739159182 100644 --- a/machine_manager/src/qmp/qmp_schema.rs +++ b/machine_manager/src/qmp/qmp_schema.rs @@ -137,7 +137,8 @@ define_qmp_command_enum!( blockdev_snapshot_delete_internal_sync("blockdev-snapshot-delete-internal-sync", blockdev_snapshot_internal, FALSE), query_vcpu_reg("query-vcpu-reg", query_vcpu_reg, FALSE), trace_get_state("trace-get-state", trace_get_state, FALSE), - trace_set_state("trace-set-state", trace_set_state, FALSE) + trace_set_state("trace-set-state", trace_set_state, FALSE), + query_workloads("query-workloads", query_workloads, FALSE) ); /// Command trait for Deserialize and find back Response. @@ -1797,7 +1798,8 @@ define_qmp_event_enum!( Powerdown("POWERDOWN", Powerdown, default), CpuResize("CPU_RESIZE", CpuResize, default), DeviceDeleted("DEVICE_DELETED", DeviceDeleted), - BalloonChanged("BALLOON_CHANGED", BalloonInfo) + BalloonChanged("BALLOON_CHANGED", BalloonInfo), + UsbHostAddRes("USB_HOST_ADD_RES", UsbHostAddRes) ); /// Shutdown @@ -1909,6 +1911,28 @@ pub struct Powerdown {} #[serde(deny_unknown_fields)] pub struct CpuResize {} +/// UsbHostAddRes +/// +/// Emitted whenever the usb host device add completion is acknowledged. +/// At this point, it's safe to reuse the specified device ID. +/// +/// # Examples +/// +/// ```text +/// <- { "event": "USB_HOST_ADD_RES", +/// "data": { "device": "hw_vid_pid" }, +/// "timestamp": { "seconds": 1265044230, "microseconds": 450486 } } +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(deny_unknown_fields)] +pub struct UsbHostAddRes { + /// Device name. + #[serde(rename = "device", default, skip_serializing_if = "Option::is_none")] + pub device: Option, + #[serde(rename = "state_msg", default, skip_serializing_if = "Option::is_none")] + pub state_msg: Option, +} + /// DeviceDeleted /// /// Emitted whenever the device removal completion is acknowledged by the guest. @@ -1986,6 +2010,21 @@ pub struct trace_set_state { } pub type TraceSetArgument = trace_set_state; +/// query_workloads +/// +/// Query the current workloads of the running VM. +/// +/// # Examples +/// +/// ```text +/// -> {"execute": "query-workloads", "arguments": {}} +/// <- {"return":[{"module":"scream-play","state":"Off"},{"module":"tap-0","state":"upload: 0 download: 0"}]} +/// ``` +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct query_workloads {} +generate_command_impl!(query_workloads, Empty); + #[cfg(test)] mod tests { use super::*; @@ -1994,7 +2033,7 @@ mod tests { fn test_qmp_event_msg() { let event_json = r#"{"event":"STOP","data":{},"timestamp":{"seconds":1575531524,"microseconds":91519}}"#; - let qmp_event: QmpEvent = serde_json::from_str(&event_json).unwrap(); + let qmp_event: QmpEvent = serde_json::from_str(event_json).unwrap(); match qmp_event { QmpEvent::Stop { data: _, diff --git a/machine_manager/src/qmp/qmp_socket.rs b/machine_manager/src/qmp/qmp_socket.rs index 35ad238f41e6be4ff4e4e76700d2258102320896..aa6ee53f961064ed79a6163ccaa8459bd0e2c3dc 100644 --- a/machine_manager/src/qmp/qmp_socket.rs +++ b/machine_manager/src/qmp/qmp_socket.rs @@ -10,11 +10,14 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. +use std::fmt::Display; use std::net::IpAddr; use std::os::unix::io::{AsRawFd, RawFd}; use std::rc::Rc; use std::str::FromStr; use std::sync::{Arc, Mutex, RwLock}; +use std::thread; +use std::time::{Duration, Instant}; use anyhow::{bail, Context, Result}; use log::{error, info, warn}; @@ -25,16 +28,14 @@ use super::qmp_schema::QmpCommand; use super::{qmp_channel::QmpChannel, qmp_response::QmpGreeting, qmp_response::Response}; use crate::event; use crate::event_loop::EventLoop; -use crate::machine::MachineExternalInterface; +use crate::machine::{MachineExternalInterface, VmState}; use crate::socket::SocketHandler; use crate::socket::SocketRWHandler; -use crate::temp_cleaner::TempCleaner; use util::leak_bucket::LeakBucket; use util::loop_context::{ gen_delete_notifiers, read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, NotifierOperation, }; -use util::set_termi_canon_mode; use util::socket::{SocketListener, SocketStream}; use util::unix::parse_unix_uri; @@ -60,13 +61,13 @@ impl QmpSocketPath { } } -impl ToString for QmpSocketPath { - fn to_string(&self) -> String { +impl Display for QmpSocketPath { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { QmpSocketPath::Tcp { host, port } => { - format!("{}:{}", &host, &port) + write!(f, "{}:{}", &host, &port) } - QmpSocketPath::Unix { path } => path.clone(), + QmpSocketPath::Unix { path } => write!(f, "{}", path), } } } @@ -372,7 +373,7 @@ fn handle_qmp( // If flow over `LEAK_BUCKET_LIMIT` per seconds, discard the request and return // a `OperationThrottled` error. - if leak_bucket.throttled(EventLoop::get_ctx(None).unwrap(), 1_u64) { + if leak_bucket.throttled(EventLoop::get_ctx(None).unwrap(), 1_u32) { qmp_service.discard()?; let err_resp = qmp_schema::QmpErrorClass::OperationThrottled(LEAK_BUCKET_LIMIT); qmp_service @@ -399,10 +400,6 @@ fn handle_qmp( reason: "host-qmp-quit".to_string(), }; event!(Shutdown; shutdown_msg); - TempCleaner::clean(); - set_termi_canon_mode().expect("Failed to set terminal to canonical mode."); - - std::process::exit(0); } Ok(()) @@ -431,7 +428,6 @@ fn qmp_command_exec( // Use macro create match to cover most Qmp command let mut id = create_command_matches!( qmp_command.clone(); controller.lock().unwrap(); qmp_response; - (stop, pause), (cont, resume), (system_powerdown, powerdown), (system_reset, reset), @@ -465,7 +461,8 @@ fn qmp_command_exec( (query_vnc, query_vnc), (query_display_image, query_display_image), (list_type, list_type), - (query_hotpluggable_cpus, query_hotpluggable_cpus); + (query_hotpluggable_cpus, query_hotpluggable_cpus), + (query_workloads, query_workloads); (input_event, input_event, key, value), (device_list_properties, device_list_properties, typename), (device_del, device_del, id), @@ -492,6 +489,27 @@ fn qmp_command_exec( // Handle the Qmp command which macro can't cover if id.is_none() { id = match qmp_command { + QmpCommand::stop { arguments: _, id } => { + let now = Instant::now(); + while !controller.lock().unwrap().pause() { + thread::sleep(Duration::from_millis(5)); + if now.elapsed() > Duration::from_secs(2) { + // Not use resume() to avoid unnecessary qmp event. + controller + .lock() + .unwrap() + .notify_lifecycle(VmState::Paused, VmState::Running); + qmp_response = Response::create_error_response( + qmp_schema::QmpErrorClass::GenericError( + "Failed to pause VM".to_string(), + ), + None, + ); + break; + } + } + id + } QmpCommand::quit { id, .. } => { controller.lock().unwrap().destroy(); shutdown_flag = true; @@ -568,7 +586,7 @@ mod tests { // Environment Recovery for UnixSocket fn recover_unix_socket_environment(socket_id: &str) { let socket_name: String = format!("test_{}.sock", socket_id); - std::fs::remove_file(&socket_name).unwrap(); + std::fs::remove_file(socket_name).unwrap(); } #[test] @@ -579,20 +597,20 @@ mod tests { // life cycle test // 1.Unconnected - assert_eq!(socket.is_connected(), false); + assert!(!socket.is_connected()); // 2.Connected socket.bind_stream(server); - assert_eq!(socket.is_connected(), true); + assert!(socket.is_connected()); // 3.Unbind SocketStream, reset state socket.drop_stream(); - assert_eq!(socket.is_connected(), false); + assert!(!socket.is_connected()); // 4.Accept and reconnect a new UnixStream let _new_client = UnixStream::connect("test_04.sock"); socket.accept(); - assert_eq!(socket.is_connected(), true); + assert!(socket.is_connected()); // After test. Environment Recover recover_unix_socket_environment("04"); @@ -640,7 +658,7 @@ mod tests { serde_json::from_str(&(String::from_utf8_lossy(&buffer[..length]))).unwrap(); match qmp_event { qmp_schema::QmpEvent::Shutdown { data, timestamp: _ } => { - assert_eq!(data.guest, true); + assert!(data.guest); assert_eq!(data.reason, "guest-shutdown".to_string()); } _ => assert!(false), @@ -669,7 +687,7 @@ mod tests { serde_json::from_str(&(String::from_utf8_lossy(&buffer[..length]))).unwrap(); let qmp_greeting = QmpGreeting::create_greeting(1, 0, 5); assert_eq!(qmp_greeting, qmp_response); - assert_eq!(res.is_err(), false); + assert!(res.is_ok()); // 2.send empty response let res = socket.send_response(false); @@ -678,7 +696,7 @@ mod tests { serde_json::from_str(&(String::from_utf8_lossy(&buffer[..length]))).unwrap(); let qmp_empty_response = Response::create_empty_response(); assert_eq!(qmp_empty_response, qmp_response); - assert_eq!(res.is_err(), false); + assert!(res.is_ok()); // After test. Environment Recover recover_unix_socket_environment("07"); diff --git a/machine_manager/src/socket.rs b/machine_manager/src/socket.rs index 6bca5b216fe79a5571067141731ea43ed6bbaa0a..a2c0f3369bf588006a0e6c401b1c6729962494f8 100644 --- a/machine_manager/src/socket.rs +++ b/machine_manager/src/socket.rs @@ -96,7 +96,7 @@ impl SocketRWHandler { fn parse_fd(&mut self, mhdr: &msghdr) { // At least it should has one RawFd. // SAFETY: The input parameter is constant. - let min_cmsg_len = unsafe { CMSG_LEN(size_of::() as u32) as u64 }; + let min_cmsg_len = unsafe { u64::from(CMSG_LEN(size_of::() as u32)) }; if (mhdr.msg_controllen as u64) < min_cmsg_len { return; } @@ -111,8 +111,8 @@ impl SocketRWHandler { { // SAFETY: The pointer of scm can be guaranteed not null. let fds = unsafe { - let fd_num = - (scm.cmsg_len as u64 - CMSG_LEN(0) as u64) as usize / size_of::(); + let fd_num = (scm.cmsg_len as u64 - u64::from(CMSG_LEN(0))) as usize + / size_of::(); std::slice::from_raw_parts(CMSG_DATA(scm) as *const RawFd, fd_num) }; self.scm_fd.append(&mut fds.to_vec()); @@ -406,7 +406,7 @@ mod tests { // Environment Recovery for UnixSocket fn recover_unix_socket_environment(socket_id: &str) { let socket_name: String = format!("test_{}.sock", socket_id); - std::fs::remove_file(&socket_name).unwrap(); + std::fs::remove_file(socket_name).unwrap(); } fn socket_basic_rw(client_fd: RawFd, server_fd: RawFd) -> bool { diff --git a/machine_manager/src/state_query.rs b/machine_manager/src/state_query.rs new file mode 100644 index 0000000000000000000000000000000000000000..e5c0dd9f4f354657430f76dde98d3af7fd8dd0b7 --- /dev/null +++ b/machine_manager/src/state_query.rs @@ -0,0 +1,69 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; + +use log::error; +use once_cell::sync::Lazy; + +static STATE_QUERY_MANAGER: Lazy> = + Lazy::new(|| RwLock::new(StateQueryManager::new())); + +pub type StateQueryCallback = dyn Fn() -> String + Send + Sync; + +struct StateQueryManager { + query_callbacks: HashMap>, +} + +impl StateQueryManager { + fn new() -> Self { + Self { + query_callbacks: HashMap::new(), + } + } + + fn register_query_callback(&mut self, key: String, callback: Arc) { + self.query_callbacks.insert(key, callback); + } + + fn unregister_query_callback(&mut self, key: &str) { + if self.query_callbacks.remove(key).is_none() { + error!("There is no query callback with key {}", key); + } + } + + fn query_workloads(&self) -> Vec<(String, String)> { + self.query_callbacks + .iter() + .map(|(module, query)| (module.clone(), query())) + .collect() + } +} + +pub fn register_state_query_callback(key: String, callback: Arc) { + STATE_QUERY_MANAGER + .write() + .unwrap() + .register_query_callback(key, callback); +} + +pub fn unregister_state_query_callback(key: &str) { + STATE_QUERY_MANAGER + .write() + .unwrap() + .unregister_query_callback(key); +} + +pub fn query_workloads() -> Vec<(String, String)> { + STATE_QUERY_MANAGER.read().unwrap().query_workloads() +} diff --git a/machine_manager/src/temp_cleaner.rs b/machine_manager/src/temp_cleaner.rs index cd5b2f42ddb58c2dd69f1e3eeba51752e641ac27..ba92cd78e09baa390f962c7266a2f0d31aeb439e 100644 --- a/machine_manager/src/temp_cleaner.rs +++ b/machine_manager/src/temp_cleaner.rs @@ -13,7 +13,10 @@ use std::collections::HashMap; use std::fs; use std::path::Path; -use std::sync::Arc; +use std::sync::{ + atomic::{fence, Ordering}, + Arc, +}; use log::{error, info}; @@ -105,7 +108,19 @@ impl TempCleaner { if let Some(tmp) = GLOBAL_TEMP_CLEANER.as_mut() { tmp.clean_files(); tmp.exit_notifier(); + fence(Ordering::SeqCst); + GLOBAL_TEMP_CLEANER = None; } } } + + pub fn is_cleaned() -> bool { + // SAFETY: This global variable is read but not modified by iothread. + // so there is not need to add lock to it. + unsafe { + let ret = GLOBAL_TEMP_CLEANER.is_none(); + fence(Ordering::SeqCst); + ret + } + } } diff --git a/migration/Cargo.toml b/migration/Cargo.toml index e1469d9ffdd540a359b1e829a009ee5493dacea8..f0d10681b29e2ebc82811d1e8daa7e1d9378185a 100644 --- a/migration/Cargo.toml +++ b/migration/Cargo.toml @@ -5,7 +5,7 @@ authors = ["Huawei StratoVirt Team"] edition = "2021" [dependencies] -kvm-ioctls = "0.15.0" +kvm-ioctls = "0.16.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" once_cell = "1.18.0" diff --git a/migration/migration_derive/src/attr_parser.rs b/migration/migration_derive/src/attr_parser.rs index ef1269f2736cbadaf888ea26ce5586309e63ec1c..237949842e893cd2d9943ee1984e8b150cea220d 100644 --- a/migration/migration_derive/src/attr_parser.rs +++ b/migration/migration_derive/src/attr_parser.rs @@ -129,7 +129,7 @@ fn version_to_u32(version_str: &str) -> u32 { panic!("Version str is illegal."); } - (version_vec[2] as u32) + ((version_vec[1] as u32) << 8) + ((version_vec[0] as u32) << 16) + u32::from(version_vec[2]) + (u32::from(version_vec[1]) << 8) + (u32::from(version_vec[0]) << 16) } #[cfg(test)] diff --git a/migration/src/general.rs b/migration/src/general.rs index 0777e0de142d78d26fe351f654a1d8564a0cb510..d0c7c65dc053721871fb6bdbb51a2fe24e4b4ebc 100644 --- a/migration/src/general.rs +++ b/migration/src/general.rs @@ -14,6 +14,8 @@ use std::collections::{hash_map::DefaultHasher, HashMap}; use std::hash::{Hash, Hasher}; use std::io::{Read, Write}; use std::mem::size_of; +use std::thread; +use std::time::{Duration, Instant}; use anyhow::{anyhow, bail, Context, Result}; @@ -22,6 +24,7 @@ use crate::protocol::{ DeviceStateDesc, FileFormat, MigrationHeader, MigrationStatus, VersionCheck, HEADER_LENGTH, }; use crate::{MigrationError, MigrationManager}; +use machine_manager::machine::VmState; use util::unix::host_page_size; impl MigrationManager { @@ -260,7 +263,18 @@ pub trait Lifecycle { /// Pause VM during migration. fn pause() -> Result<()> { if let Some(locked_vm) = &MIGRATION_MANAGER.vmm.read().unwrap().vm { - locked_vm.lock().unwrap().pause(); + let now = Instant::now(); + while !locked_vm.lock().unwrap().pause() { + thread::sleep(Duration::from_millis(5)); + if now.elapsed() > Duration::from_secs(2) { + // Not use resume() to avoid unnecessary qmp event. + locked_vm + .lock() + .unwrap() + .notify_lifecycle(VmState::Paused, VmState::Running); + bail!("Failed to pause VM"); + } + } } Ok(()) diff --git a/migration/src/manager.rs b/migration/src/manager.rs index d381ae3125a46d0b8eb03bbca1d8d6c0c5b9a991..3b081e97161cf9425d8118255fb881bcff57c6b2 100644 --- a/migration/src/manager.rs +++ b/migration/src/manager.rs @@ -266,7 +266,7 @@ impl MigrationManager { let name = cpu_desc.name.clone() + "/" + &id.to_string(); let mut copied_cpu_desc = cpu_desc.clone(); copied_cpu_desc.name = name.clone(); - copied_cpu_desc.alias = cpu_desc.alias + id as u64; + copied_cpu_desc.alias = cpu_desc.alias + u64::from(id); Self::register_device_desc(copied_cpu_desc); let mut locked_vmm = MIGRATION_MANAGER.vmm.write().unwrap(); diff --git a/migration/src/protocol.rs b/migration/src/protocol.rs index 0d57220f00b7372687d310eaf68808b23d1c9a62..ca8e23c296ec325827f478286399ccd1391570b5 100644 --- a/migration/src/protocol.rs +++ b/migration/src/protocol.rs @@ -920,12 +920,9 @@ pub mod tests { ); let mut current_slice = device_v1.get_state_vec().unwrap(); - assert_eq!( - state_2_desc - .add_padding(&state_1_desc, &mut current_slice) - .is_ok(), - true - ); + assert!(state_2_desc + .add_padding(&state_1_desc, &mut current_slice) + .is_ok()); let mut device_v2 = DeviceV2 { state: DeviceV2State::default(), @@ -964,12 +961,9 @@ pub mod tests { ); let mut current_slice = device_v2.get_state_vec().unwrap(); - assert_eq!( - state_3_desc - .add_padding(&state_2_desc, &mut current_slice) - .is_ok(), - true - ); + assert!(state_3_desc + .add_padding(&state_2_desc, &mut current_slice) + .is_ok()); let mut device_v3 = DeviceV3 { state: DeviceV3State::default(), @@ -977,10 +971,10 @@ pub mod tests { device_v3.set_state_mut(¤t_slice).unwrap(); assert!(state_3_desc.current_version > state_2_desc.current_version); - assert_eq!(device_v3.state.ier, device_v2.state.ier as u64); - assert_eq!(device_v3.state.iir, device_v2.state.iir as u64); - assert_eq!(device_v3.state.lcr, device_v2.state.lcr as u64); - assert_eq!(device_v3.state.mcr, device_v2.state.mcr as u64); + assert_eq!(device_v3.state.ier, u64::from(device_v2.state.ier)); + assert_eq!(device_v3.state.iir, u64::from(device_v2.state.iir)); + assert_eq!(device_v3.state.lcr, u64::from(device_v2.state.lcr)); + assert_eq!(device_v3.state.mcr, u64::from(device_v2.state.mcr)); } #[test] @@ -1007,12 +1001,9 @@ pub mod tests { ); let mut current_slice = device_v3.get_state_vec().unwrap(); - assert_eq!( - state_4_desc - .add_padding(&state_3_desc, &mut current_slice) - .is_ok(), - true - ); + assert!(state_4_desc + .add_padding(&state_3_desc, &mut current_slice) + .is_ok()); let mut device_v4 = DeviceV4 { state: DeviceV4State::default(), @@ -1050,12 +1041,9 @@ pub mod tests { ); let mut current_slice = device_v4.get_state_vec().unwrap(); - assert_eq!( - state_5_desc - .add_padding(&state_4_desc, &mut current_slice) - .is_ok(), - true - ); + assert!(state_5_desc + .add_padding(&state_4_desc, &mut current_slice) + .is_ok()); let mut device_v5 = DeviceV5 { state: DeviceV5State::default(), @@ -1089,12 +1077,9 @@ pub mod tests { ); let mut current_slice = device_v2.get_state_vec().unwrap(); - assert_eq!( - state_5_desc - .add_padding(&state_2_desc, &mut current_slice) - .is_ok(), - true - ); + assert!(state_5_desc + .add_padding(&state_2_desc, &mut current_slice) + .is_ok()); let mut device_v5 = DeviceV5 { state: DeviceV5State::default(), @@ -1102,16 +1087,16 @@ pub mod tests { device_v5.set_state_mut(¤t_slice).unwrap(); assert!(state_5_desc.current_version > state_2_desc.current_version); - assert_eq!(device_v5.state.rii, device_v2.state.iir as u64); + assert_eq!(device_v5.state.rii, u64::from(device_v2.state.iir)); } #[test] fn test_check_header() { - if !Kvm::new().is_ok() { + if Kvm::new().is_err() { return; } let header = MigrationHeader::default(); - assert_eq!(header.check_header().is_ok(), true); + assert!(header.check_header().is_ok()); } } diff --git a/ozonec/Cargo.lock b/ozonec/Cargo.lock new file mode 100644 index 0000000000000000000000000000000000000000..b67a79bbbdd38f05c0def30a76251a40073ad7c4 --- /dev/null +++ b/ozonec/Cargo.lock @@ -0,0 +1,906 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "anyhow" +version = "1.0.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" + +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "caps" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "190baaad529bcfbde9e1a19022c42781bdb6ff9de25721abdb8fd98c0807730b" +dependencies = [ + "libc", + "thiserror", +] + +[[package]] +name = "cc" +version = "1.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72db2f7947ecee9b03b510377e8bb9077afa27176fdbff55c51027e976fdcc48" +dependencies = [ + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "num-traits", + "serde", + "windows-targets 0.52.6", +] + +[[package]] +name = "clap" +version = "4.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f13b9c79b5d1dd500d20ef541215a6423c75829ef43117e1b4d17fd8af0b5d76" +dependencies = [ + "bitflags 1.3.2", + "clap_derive", + "clap_lex", + "once_cell", +] + +[[package]] +name = "clap_derive" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "684a277d672e91966334af371f1a7b5833f9aa00b07c84e92fbce95e00208ce8" +dependencies = [ + "heck", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "clap_lex" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "033f6b7a4acb1f358c742aaca805c939ee73b4c6209ae4318ec7aca81c42e646" +dependencies = [ + "os_str_bytes", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "fastrand" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" + +[[package]] +name = "flate2" +version = "1.0.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f211bbe8e69bbd0cfdea405084f128ae8b4aaa6b0b522fc8f2b009084797920" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "io-lifetimes" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "js-sys" +version = "0.3.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.146" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" + +[[package]] +name = "libseccomp" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21c57fd8981a80019807b7b68118618d29a87177c63d704fc96e6ecd003ae5b3" +dependencies = [ + "bitflags 1.3.2", + "libc", + "libseccomp-sys", + "pkg-config", +] + +[[package]] +name = "libseccomp-sys" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a7cbbd4ad467251987c6e5b47d53b11a5a05add08f2447a9e2d70aef1e0d138" + +[[package]] +name = "linux-raw-sys" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" + +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + +[[package]] +name = "log" +version = "0.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "518ef76f2f87365916b142844c16d8fefd85039bc5699050210a7778ee1cd1de" + +[[package]] +name = "memoffset" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" +dependencies = [ + "autocfg", +] + +[[package]] +name = "miniz_oxide" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" +dependencies = [ + "adler", +] + +[[package]] +name = "nix" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" +dependencies = [ + "bitflags 1.3.2", + "cfg-if", + "libc", + "memoffset", + "pin-utils", + "static_assertions", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "oci_spec" +version = "0.1.0" +dependencies = [ + "anyhow", + "libc", + "nix", + "serde", + "serde_json", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "os_str_bytes" +version = "6.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1" + +[[package]] +name = "ozonec" +version = "0.1.0" +dependencies = [ + "anyhow", + "bitflags 1.3.2", + "caps", + "chrono", + "clap", + "libc", + "libseccomp", + "log", + "nix", + "oci_spec", + "procfs", + "rlimit", + "rusty-fork", + "serde", + "serde_json", + "thiserror", +] + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + +[[package]] +name = "proc-macro2" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "procfs" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1de8dacb0873f77e6aefc6d71e044761fcc68060290f5b1089fcdf84626bb69" +dependencies = [ + "bitflags 1.3.2", + "byteorder", + "chrono", + "flate2", + "hex", + "lazy_static", + "rustix 0.36.17", +] + +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "rlimit" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81a9ed03edbed449d6897c2092c71ab5f7b5fb80f6f0b1a3ed6d40a6f9fc0720" +dependencies = [ + "libc", +] + +[[package]] +name = "rustix" +version = "0.36.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "305efbd14fde4139eb501df5f136994bb520b033fa9fbdce287507dc23b8c7ed" +dependencies = [ + "bitflags 1.3.2", + "errno", + "io-lifetimes", + "libc", + "linux-raw-sys 0.1.4", + "windows-sys 0.45.0", +] + +[[package]] +name = "rustix" +version = "0.38.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac5ffa1efe7548069688cd7028f32591853cd7b5b756d41bcffd2353e4fc75b4" +dependencies = [ + "bitflags 2.6.0", + "errno", + "libc", + "linux-raw-sys 0.4.14", + "windows-sys 0.48.0", +] + +[[package]] +name = "rusty-fork" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb3dcc6e454c328bb824492db107ab7c0ae8fcffe4ad210136ef014458c1bc4f" +dependencies = [ + "fnv", + "quick-error", + "tempfile", + "wait-timeout", +] + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "serde" +version = "1.0.163" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2113ab51b87a539ae008b5c6c02dc020ffa39afd2d83cffcb3f4eb2722cebec2" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.163" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.74", +] + +[[package]] +name = "serde_json" +version = "1.0.96" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fceb41e3d546d0bd83421d3409b1460cc7444cd389341a4c880fe7a042cb3d7" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tempfile" +version = "3.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb94d2f3cc536af71caac6b6fcebf65860b347e7ce0cc9ebe8f70d3e521054ef" +dependencies = [ + "cfg-if", + "fastrand", + "redox_syscall", + "rustix 0.38.3", + "windows-sys 0.48.0", +] + +[[package]] +name = "thiserror" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.74", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wait-timeout" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6" +dependencies = [ + "libc", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +dependencies = [ + "cfg-if", + "once_cell", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.74", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.74", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" diff --git a/ozonec/Cargo.toml b/ozonec/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..a617a89a40f8257b51a6539ae4011533ff76e296 --- /dev/null +++ b/ozonec/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "ozonec" +version = "0.1.0" +authors = ["Huawei StratoVirt Team"] +edition = "2021" +license = "Mulan PSL v2" +description = "An OCI runtime implemented by Rust" + +[dependencies] +anyhow = "= 1.0.71" +bitflags = "= 1.3.2" +caps = "0.5.5" +chrono = { version = "0.4.31", default-features = false, features = ["clock", "serde"] } +clap = { version = "= 4.1.4", default-features = false, features = ["derive", "cargo", "std", "help", "usage"] } +libc = "= 0.2.146" +libseccomp = "0.3.0" +log = { version = "= 0.4.18", features = ["std"]} +nix = "= 0.26.2" +oci_spec = { path = "oci_spec" } +procfs = "0.14.0" +rlimit = "0.5.3" +rusty-fork = "0.3.0" +serde = { version = "= 1.0.163", features = ["derive"] } +serde_json = "= 1.0.96" +thiserror = "= 1.0.40" + +[workspace] + +[profile.dev] +panic = "unwind" + +[profile.release] +lto = true +strip = true +opt-level = 'z' +codegen-units = 1 +panic = "abort" diff --git a/ozonec/oci_spec/Cargo.toml b/ozonec/oci_spec/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..e5923ad52ead9f6f2528cf0d35eae5c28a9d0405 --- /dev/null +++ b/ozonec/oci_spec/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "oci_spec" +version = "0.1.0" +authors = ["Huawei StratoVirt Team"] +edition = "2021" +license = "Mulan PSL v2" +description = "Open Container Initiative (OCI) Specifications in Rust" + +[dependencies] +anyhow = "= 1.0.71" +libc = "= 0.2.146" +nix = "= 0.26.2" +serde = { version = "= 1.0.163", features = ["derive"] } +serde_json = "= 1.0.96" + +[profile.dev] +panic = "unwind" + +[profile.release] +lto = true +strip = true +opt-level = 'z' +codegen-units = 1 +panic = "abort" \ No newline at end of file diff --git a/ozonec/oci_spec/src/lib.rs b/ozonec/oci_spec/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..f0dd3fe5fcbf13d380e3b3c773f7bd29b24508fc --- /dev/null +++ b/ozonec/oci_spec/src/lib.rs @@ -0,0 +1,20 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +#[cfg(target_os = "linux")] +pub mod linux; +#[cfg(target_family = "unix")] +pub mod posix; +pub mod process; +pub mod runtime; +pub mod state; +pub mod vm; diff --git a/ozonec/oci_spec/src/linux.rs b/ozonec/oci_spec/src/linux.rs new file mode 100644 index 0000000000000000000000000000000000000000..6dafcd5a7a66001454e1ea34d42830dc06149b4a --- /dev/null +++ b/ozonec/oci_spec/src/linux.rs @@ -0,0 +1,1232 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{collections::HashMap, path::PathBuf}; + +use anyhow::{anyhow, Result}; +use nix::sched::CloneFlags; +use serde::{Deserialize, Serialize}; + +/// Available Linux namespaces. +#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize, Hash)] +#[serde(rename_all = "snake_case")] +pub enum NamespaceType { + Cgroup = 0x0200_0000, + Ipc = 0x0800_0000, + Network = 0x4000_0000, + Mount = 0x0002_0000, + Pid = 0x2000_0000, + Time = 0x0000_0080, + User = 0x1000_0000, + Uts = 0x0400_0000, +} + +impl TryInto for NamespaceType { + type Error = anyhow::Error; + + fn try_into(self) -> Result { + match self { + NamespaceType::Cgroup => Ok(CloneFlags::CLONE_NEWCGROUP), + NamespaceType::Ipc => Ok(CloneFlags::CLONE_NEWIPC), + NamespaceType::Network => Ok(CloneFlags::CLONE_NEWNET), + NamespaceType::Mount => Ok(CloneFlags::CLONE_NEWNS), + NamespaceType::Pid => Ok(CloneFlags::CLONE_NEWPID), + NamespaceType::Time => Err(anyhow!("Time namespace not supported with clone")), + NamespaceType::User => Ok(CloneFlags::CLONE_NEWUSER), + NamespaceType::Uts => Ok(CloneFlags::CLONE_NEWUTS), + } + } +} + +impl From for String { + fn from(ns_type: NamespaceType) -> Self { + match ns_type { + NamespaceType::Cgroup => String::from("cgroup"), + NamespaceType::Ipc => String::from("ipc"), + NamespaceType::Network => String::from("net"), + NamespaceType::Mount => String::from("mnt"), + NamespaceType::Pid => String::from("pid"), + NamespaceType::Time => String::from("time"), + NamespaceType::User => String::from("user"), + NamespaceType::Uts => String::from("uts"), + } + } +} + +/// Namespaces. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Namespace { + /// Namespace type. + #[serde(rename = "type")] + pub ns_type: NamespaceType, + /// Namespace file. If path is not specified, a new namespace is created. + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, +} + +/// UID/GID mapping. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct IdMapping { + /// Starting uid/gid in the container. + pub containerID: u32, + /// Starting uid/gid on the host to be mapped to containerID. + pub hostID: u32, + /// Number of ids to be mapped. + pub size: u32, +} + +/// Offset for Time Namespace. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct TimeOffsets { + /// Offset of clock (in seconds) in the container. + #[serde(skip_serializing_if = "Option::is_none")] + pub secs: Option, + /// Offset of clock (in nanoseconds) in the container. + #[serde(skip_serializing_if = "Option::is_none")] + pub nanosecs: Option, +} + +/// Devices available in the container. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Device { + /// Type of device. + #[serde(rename = "type")] + pub dev_type: String, + /// Full path to device inside container. + pub path: String, + /// Major number for the device. + #[serde(skip_serializing_if = "Option::is_none")] + pub major: Option, + /// Minor number for the device. + #[serde(skip_serializing_if = "Option::is_none")] + pub minor: Option, + /// File mode for the device. + #[serde(skip_serializing_if = "Option::is_none")] + pub fileMode: Option, + /// Id of device owner. + #[serde(skip_serializing_if = "Option::is_none")] + pub uid: Option, + /// Id of device group. + #[serde(skip_serializing_if = "Option::is_none")] + pub gid: Option, +} + +fn default_device_type() -> String { + "a".to_string() +} + +/// Allowed device in Device Cgroup. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct CgroupDevice { + /// Whether the entry is allowed or denied. + #[serde(default)] + pub allow: bool, + /// Type of device. + #[serde(default = "default_device_type", rename = "type")] + pub dev_type: String, + /// Major number for the device. + #[serde(skip_serializing_if = "Option::is_none")] + pub major: Option, + /// Minor number for the device. + #[serde(skip_serializing_if = "Option::is_none")] + pub minor: Option, + /// Cgroup permissions for device. + #[serde(skip_serializing_if = "Option::is_none")] + pub access: Option, +} + +/// Cgroup subsystem to set limits on the container's memory usage. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MemoryCgroup { + /// Limit of memory usage. + #[serde(skip_serializing_if = "Option::is_none")] + pub limit: Option, + /// Soft limit of memory usage. + #[serde(skip_serializing_if = "Option::is_none")] + pub reservation: Option, + /// Limits of memory +Swap usage. + #[serde(skip_serializing_if = "Option::is_none")] + pub swap: Option, + /// Hard limit for kernel memory. + #[serde(skip_serializing_if = "Option::is_none")] + pub kernel: Option, + /// Hard limit for kernel TCP buffer memory. + #[serde(skip_serializing_if = "Option::is_none")] + pub kernelTCP: Option, + /// Swappiness parameter of vmscan. + #[serde(skip_serializing_if = "Option::is_none")] + pub swappiness: Option, + /// Enable or disable the OOM killer. + #[serde(skip_serializing_if = "Option::is_none")] + pub disableOOMKiller: Option, + /// Enable or disable hierarchical memory accounting. + #[serde(skip_serializing_if = "Option::is_none")] + pub useHierarchy: Option, + /// Enable container memory usage check before setting a new limit. + #[serde(skip_serializing_if = "Option::is_none")] + pub checkBeforeUpdate: Option, +} + +/// Cgroup subsystems cpu and cpusets. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct CpuCgroup { + /// Relative share of CPU time available to the tasks in a cgroup. + #[serde(skip_serializing_if = "Option::is_none")] + pub shares: Option, + /// Total amount of time in microseconds for which all tasks in a + /// cgroup can run during one period. + #[serde(skip_serializing_if = "Option::is_none")] + pub quota: Option, + /// Maximum amount of accumulated time in microseconds for which + /// all tasks in a cgroup can run additionally for burst during + /// one period. + #[serde(skip_serializing_if = "Option::is_none")] + pub burst: Option, + /// Period of time in microseconds for how regularly a cgroup's access + /// to CPU resources should be reallocated (CFS scheduler only) + #[serde(skip_serializing_if = "Option::is_none")] + pub period: Option, + /// Period of time in microseconds for the longest continuous period + /// in which the tasks in a cgrouop have access to CPU resources. + #[serde(skip_serializing_if = "Option::is_none")] + pub realtimeRuntime: Option, + /// Same as period but applies to realtime scheduler only. + #[serde(skip_serializing_if = "Option::is_none")] + pub realtimePeriod: Option, + /// List of CPUs the container will run on. + #[serde(skip_serializing_if = "Option::is_none")] + pub cpus: Option, + /// List of memory nodes the container will run on. + #[serde(skip_serializing_if = "Option::is_none")] + pub mems: Option, + /// Cgroups are configured with minimum weight. + #[serde(skip_serializing_if = "Option::is_none")] + pub idle: Option, +} + +/// Per-device bandwidth weights. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct WeightDevice { + /// Major number for device. + pub major: i64, + /// Minor number for device. + pub minor: i64, + /// Bandwidth weight for the device. + #[serde(skip_serializing_if = "Option::is_none")] + pub weight: Option, + /// Bandwidth weight for the device while competing with the cgroup's + /// child cgroups (CFS scheduler only) + #[serde(skip_serializing_if = "Option::is_none")] + pub leafWeight: Option, +} + +/// Per-device bandwidth rate limits. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct ThrottleDevice { + /// Major number for device. + pub major: i64, + /// Minor number for device. + pub minor: i64, + /// Bandwidth rate limit in bytes per second or IO rate limit for + /// the device. + pub rate: u64, +} + +/// Cgroup subsystem blkio which implements the block IO controller. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct BlockIoCgroup { + /// Per-cgroup weight. + #[serde(skip_serializing_if = "Option::is_none")] + pub weight: Option, + /// Equivalents of weight for the purpose of deciding how much + /// weight tasks in the given cgroup has while competing with + /// the cgroup's child cgroups. + #[serde(skip_serializing_if = "Option::is_none")] + pub leafWeight: Option, + /// Array of per-device bandwidth weights. + #[serde(skip_serializing_if = "Option::is_none")] + pub weightDevice: Option>, + /// Array of per-device read bandwidth rate limits. + #[serde(skip_serializing_if = "Option::is_none")] + pub throttleReadBpsDevice: Option>, + /// Array of per-device write bandwidth rate limits. + #[serde(skip_serializing_if = "Option::is_none")] + pub throttleWriteBpsDevice: Option>, + /// Array of per-device read IO rate limits. + #[serde(skip_serializing_if = "Option::is_none")] + pub throttleReadIOPSDevice: Option>, + /// Array of per-device write IO rate limits. + #[serde(skip_serializing_if = "Option::is_none")] + pub throttleWriteIOPSDevice: Option>, +} + +/// hugetlb controller which allows to limit the HugeTLB reservations +/// (if supported) or usage (page fault). +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct HugetlbCgroup { + /// Hugepage size + pub pageSize: String, + /// Limit in bytes of hugepagesize HugeTLB reservations + /// (if supported) or usage. + pub limit: u64, +} + +/// Priority assigned to traffic originating from processes in the +/// group and egressing the system on various interfaces. +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct NetPriority { + /// Interface name. + pub name: String, + /// Priority applied to the interface. + pub priority: u32, +} + +/// Cgroup subsystems net_cls and net_prio. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct NetworkCgroup { + /// Network class identifier the cgroup's network packets will + /// be tagged with. + #[serde(skip_serializing_if = "Option::is_none")] + pub classID: Option, + /// List of objects of the priorities assigned to traffic + /// originating from processes in the group and egressing the + /// system on various interfaces. + #[serde(skip_serializing_if = "Option::is_none")] + pub priorities: Option>, +} + +/// Cgroup subsystem pids. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct PidsCgroup { + /// Maximum number of tasks in the cgroup. + pub limit: i64, +} + +/// Per-device rdma limit. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct RdmaLimit { + /// Maximum number of hca_handles in the cgroup. + pub hcaHandles: Option, + /// Maximum number of hca_objects in the cgroup. + pub hcaObjects: Option, +} + +/// Cgroup subsystem rdma. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct RdmaCgroup { + /// Rdma limit for mlx5_1. + #[serde(skip_serializing_if = "Option::is_none")] + pub mlx5_1: Option, + /// Rdma limit for mlx4_0. + #[serde(skip_serializing_if = "Option::is_none")] + pub mlx4_0: Option, + /// Rdma limit for rxe3. + #[serde(skip_serializing_if = "Option::is_none")] + pub rxe3: Option, +} + +/// Cgroups to restrict resource usage for a container and +/// handle device access. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Cgroups { + /// Device cgroup settings. + #[serde(skip_serializing_if = "Option::is_none")] + pub devices: Option>, + /// Memory cgroup settings. + #[serde(skip_serializing_if = "Option::is_none")] + pub memory: Option, + /// Cpu and Cpuset cgroup settings. + #[serde(skip_serializing_if = "Option::is_none")] + pub cpu: Option, + /// Blkio cgroup settings. + #[serde(skip_serializing_if = "Option::is_none")] + pub blockIO: Option, + /// Hugetlb cgroup settings. + #[serde(skip_serializing_if = "Option::is_none")] + pub hugepageLimits: Option>, + /// Network cgroup settings. + #[serde(skip_serializing_if = "Option::is_none")] + pub network: Option, + /// Pids cgroup settings. + #[serde(skip_serializing_if = "Option::is_none")] + pub pids: Option, +} + +#[cfg(target_arch = "x86_64")] +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +/// Intel Resource Director Technology +pub struct IntelRdt { + #[serde(skip_serializing_if = "Option::is_none")] + /// Identity for RDT Class of Service (CLOS). + pub closID: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// Schema for L3 cache id and capacity bitmask (CBM). + pub l3CacheSchema: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// Schema of memory bandwidth per L3 cache id. + pub memBwSchema: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// If Intel RDT CMT should be enabled. + pub enableCMT: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// If Intel RDT MBM should be enabled. + pub enableMBM: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +#[repr(u32)] +/// Action for seccomp rules. +pub enum SeccompAction { + ScmpActKill = 0x0000_0000, + ScmpActKillProcess = 0x8000_0000, + ScmpActTrap = 0x0003_0000, + ScmpActErrno = 0x0005_0001, + ScmpActNotify = 0x7fc0_0000, + ScmpActTrace = 0x7ff0_0001, + ScmpActLog = 0x7ffc_0000, + ScmpActAllow = 0x7fff_0000, +} + +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Default)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +#[repr(u32)] +/// Operator for syscall arguments in seccomp. +pub enum SeccompOp { + ScmpCmpNe = 1, + ScmpCmpLt = 2, + ScmpCmpLe = 3, + #[default] + ScmpCmpEq = 4, + ScmpCmpGe = 5, + ScmpCmpGt = 6, + ScmpCmpMaskedEq = 7, +} + +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +/// The specific syscall in seccomp. +pub struct SeccompSyscallArg { + /// Index for syscall arguments. + #[serde(default)] + pub index: usize, + /// Value for syscall arguments. + #[serde(default)] + pub value: u64, + #[serde(skip_serializing_if = "Option::is_none")] + /// Value for syscall arguments. + pub valueTwo: Option, + /// Operator for syscall arguments. + pub op: SeccompOp, +} + +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +/// Match a syscall in seccomp. +pub struct SeccompSyscall { + /// Names of the syscalls. + pub names: Vec, + /// Action for seccomp rules. + pub action: SeccompAction, + #[serde(skip_serializing_if = "Option::is_none")] + /// Errno return code to use. + pub errnoRet: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// Specific syscall in seccomp. + pub args: Option>, +} + +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +/// Seccomp provides application sandboxing mechanism in the Linux kernel. +pub struct Seccomp { + /// Default action for seccomp. + pub defaultAction: SeccompAction, + #[serde(skip_serializing_if = "Option::is_none")] + /// Errno return code to use. + pub defaultErrnoRet: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// Architecture used for system calls. + pub architectures: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + /// List of flags to use with seccomp. + pub flags: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + /// Path of UNIX domain socket over which the runtime will send the + /// container process state data structure when the SCMP_ACT_NOTIFY + /// action is used. + pub listennerPath: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// Seccomp file descriptor returned by the seccomp syscall. + pub seccompFd: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// Opaque data to pass to the seccomp agent. + pub listenerMetadata: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// Match a syscall in seccomp. + pub syscalls: Option>, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +/// Linux execution personality. +pub struct Personality { + /// Execution domain. + pub domain: String, + /// Additional flags to apply. + pub flags: Option>, +} + +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +/// Linux-specific configuration. +pub struct LinuxPlatform { + /// A namespace wraps a global system resource in an abstraction. + pub namespaces: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + /// User namespace uid mappings from the host to the container. + pub uidMappings: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + /// User namespace gid mappings from the host to the container. + pub gidMappings: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + /// Offset for Time Namespace. + pub timeOffsets: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// Lists devices that MUST be available in the container. + pub devices: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + /// Path to the cgroups. + pub cgroupsPath: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// Rootfs's mount propagation. + pub rootfsPropagation: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// Mask over the provided paths inside the container so + /// that they cannot be read. + pub maskedPaths: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + /// Set the provided paths as readonly inside the container. + pub readonlyPaths: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + /// Selinux context for the mounts in the container. + pub mountLabel: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// Linux execution personality. + pub personality: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// Configure a container's cgroups. + pub resources: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// The cgroup subsystem rdma. + pub rdma: Option, + #[serde(skip_serializing_if = "Option::is_none")] + /// Allows cgroup v2 parameters to be to be set and modified + /// for the container. + pub unified: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + /// Kernel parameters to be modified at runtime for the + /// container. + pub sysctl: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + /// Seccomp provides application sandboxing mechanism in + /// the Linux kernel. + pub seccomp: Option, + #[cfg(target_arch = "x86_64")] + #[serde(skip_serializing_if = "Option::is_none")] + /// Intel Resource Director Technology. + pub intelRdt: Option, +} + +/// Arrays that specifies the sets of capabilities for the process. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Capbilities { + /// Array of effective capabilities that are kept for the process. + #[serde(skip_serializing_if = "Option::is_none")] + pub effective: Option>, + /// Array of bounding capabilities that are kept for the process. + #[serde(skip_serializing_if = "Option::is_none")] + pub bounding: Option>, + /// Array of inheritable capabilities that are kept for the process. + #[serde(skip_serializing_if = "Option::is_none")] + pub inheritable: Option>, + /// Array of permitted capabilities that are kept for the process. + #[serde(skip_serializing_if = "Option::is_none")] + pub permitted: Option>, + /// Array of ambient capabilities that are kept for the process. + #[serde(skip_serializing_if = "Option::is_none")] + pub ambient: Option>, +} + +/// Scheduling policy. +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum SchedPolicy { + SchedOther, + SchedFifo, + SchedRr, + SchedBatch, + SchedIdle, +} + +impl From for libc::c_int { + fn from(value: SchedPolicy) -> Self { + match value { + SchedPolicy::SchedOther => libc::SCHED_OTHER, + SchedPolicy::SchedFifo => libc::SCHED_FIFO, + SchedPolicy::SchedRr => libc::SCHED_RR, + SchedPolicy::SchedBatch => libc::SCHED_BATCH, + SchedPolicy::SchedIdle => libc::SCHED_IDLE, + } + } +} + +/// Scheduler properties for the process. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Scheduler { + /// Scheduling policy. + pub policy: SchedPolicy, + /// Nice value for the process, affecting its priority. + #[serde(skip_serializing_if = "Option::is_none")] + pub nice: Option, + /// Static priority of the process. + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, + /// Array of strings representing scheduling flags. + #[serde(skip_serializing_if = "Option::is_none")] + pub flags: Option>, + /// Amount of time in nanoseconds during which the process is + /// allowed to run in a given period, used by the deadline + /// scheduler. + #[serde(skip_serializing_if = "Option::is_none")] + pub runtime: Option, + /// Absolute deadline for the process to complete its execution, + /// used by the deadline scheduler. + #[serde(skip_serializing_if = "Option::is_none")] + pub deadline: Option, + /// Length of the period in nanoseconds used for determining the + /// process runtime, used by the deadline scheduler. + #[serde(skip_serializing_if = "Option::is_none")] + pub period: Option, +} + +/// I/O scheduling class. +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum IoPriClass { + IoprioClassRt, + IoprioClassBe, + IoprioClassIdle, +} + +/// I/O priority settings for the container's processes within the +/// process group. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct IoPriority { + /// I/O scheduling class. + pub class: IoPriClass, + /// Priority level within the class. + pub priority: i64, +} + +/// CPU affinity used to execute the process. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ExecCpuAffinity { + /// List of CPUs a runtime parent process to be run on initially, + /// before the transition to container's cgroup. + #[serde(skip_serializing_if = "Option::is_none")] + pub initial: Option, + /// List of CPUs the process will be run on after the transition + /// to container's cgroup. + #[serde(skip_serializing_if = "Option::is_none", rename = "final")] + pub final_cpus: Option, +} + +#[cfg(test)] +mod tests { + use serde_json; + + use super::*; + + #[test] + fn test_namespaces() { + let json = r#"{ + "namespaces": [ + { + "type": "pid", + "path": "/proc/1234/ns/pid" + }, + { + "type": "network", + "path": "/var/run/netns/neta" + }, + { + "type": "mount" + }, + { + "type": "ipc" + }, + { + "type": "uts" + }, + { + "type": "user" + }, + { + "type": "cgroup" + }, + { + "type": "time" + } + ] + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + namespaces: Vec, + } + + let ns: Section = serde_json::from_str(json).unwrap(); + assert_eq!(ns.namespaces.len(), 8); + assert_eq!(ns.namespaces[0].ns_type, NamespaceType::Pid); + assert_eq!(ns.namespaces[1].ns_type, NamespaceType::Network); + assert_eq!(ns.namespaces[2].ns_type, NamespaceType::Mount); + assert_eq!(ns.namespaces[3].ns_type, NamespaceType::Ipc); + assert_eq!(ns.namespaces[4].ns_type, NamespaceType::Uts); + assert_eq!(ns.namespaces[5].ns_type, NamespaceType::User); + assert_eq!(ns.namespaces[6].ns_type, NamespaceType::Cgroup); + assert_eq!(ns.namespaces[7].ns_type, NamespaceType::Time); + } + + #[test] + fn test_ids_mapping() { + let json = r#"{ + "uidMappings": [ + { + "containerID": 0, + "hostID": 1000, + "size": 32000 + } + ], + "gidMappings": [ + { + "containerID": 0, + "hostID": 1000, + "size": 32000 + } + ] + }"#; + + #[allow(non_snake_case)] + #[derive(Serialize, Deserialize)] + struct Section { + uidMappings: Vec, + gidMappings: Vec, + } + + let ids_mapping: Section = serde_json::from_str(json).unwrap(); + assert_eq!(ids_mapping.uidMappings.len(), 1); + assert_eq!(ids_mapping.uidMappings[0].size, 32000 as u32); + assert_eq!(ids_mapping.gidMappings.len(), 1); + assert_eq!(ids_mapping.gidMappings[0].size, 32000 as u32); + } + + #[test] + fn test_time_offsets() { + let json = r#"{ + "timeOffsets": { + "secs": 100 + } + }"#; + + #[allow(non_snake_case)] + #[derive(Serialize, Deserialize)] + struct Section { + timeOffsets: TimeOffsets, + } + + let time_offsets: Section = serde_json::from_str(json).unwrap(); + assert_eq!(time_offsets.timeOffsets.secs, Some(100)); + assert_eq!(time_offsets.timeOffsets.nanosecs, None); + } + + #[test] + fn test_devices() { + let json = r#"{ + "devices": [ + { + "path": "/dev/fuse", + "type": "c", + "major": 10, + "minor": 229, + "fileMode": 438, + "uid": 0, + "gid": 0 + }, + { + "path": "/dev/sda", + "type": "b", + "major": 8, + "minor": 0 + } + ] + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + devices: Vec, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.devices.len(), 2); + assert_eq!(section.devices[1].path, "/dev/sda"); + assert_eq!(section.devices[1].dev_type, "b"); + assert_eq!(section.devices[1].major, Some(8)); + assert_eq!(section.devices[1].minor, Some(0)); + assert_eq!(section.devices[1].fileMode, None); + assert_eq!(section.devices[1].uid, None); + assert_eq!(section.devices[1].gid, None); + } + + #[test] + fn test_cgroup_devices() { + let json = r#"{ + "devices": [ + { + "allow": false + }, + { + "allow": true, + "type": "c", + "major": 10, + "minor": 229, + "access": "rw" + } + ] + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + devices: Vec, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.devices.len(), 2); + assert_eq!(section.devices[0].allow, false); + assert_eq!(section.devices[0].dev_type, "a"); + assert_eq!(section.devices[0].major, None); + assert_eq!(section.devices[0].minor, None); + assert_eq!(section.devices[0].access, None); + assert_eq!(section.devices[1].allow, true); + assert_eq!(section.devices[1].dev_type, "c"); + assert_eq!(section.devices[1].major, Some(10)); + assert_eq!(section.devices[1].minor, Some(229)); + assert_eq!(section.devices[1].access, Some("rw".to_string())); + } + + #[test] + fn test_cgroup_memory_01() { + let json = r#"{ + "memory": { + "limit": 536870912, + "reservation": 536870912, + "swap": 536870912, + "kernel": -1, + "kernelTCP": -1, + "swappiness": 0, + "disableOOMKiller": false + } + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + memory: MemoryCgroup, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.memory.limit, Some(536870912)); + assert_eq!(section.memory.reservation, Some(536870912)); + assert_eq!(section.memory.swap, Some(536870912)); + assert_eq!(section.memory.kernel, Some(-1)); + assert_eq!(section.memory.kernelTCP, Some(-1)); + assert_eq!(section.memory.swappiness, Some(0)); + assert_eq!(section.memory.disableOOMKiller, Some(false)); + assert_eq!(section.memory.useHierarchy, None); + assert_eq!(section.memory.checkBeforeUpdate, None); + } + + #[test] + fn test_cgroup_memory_02() { + let json = r#"{ + "memory": { + "useHierarchy": true, + "checkBeforeUpdate": true + } + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + memory: MemoryCgroup, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.memory.limit, None); + assert_eq!(section.memory.reservation, None); + assert_eq!(section.memory.swap, None); + assert_eq!(section.memory.kernel, None); + assert_eq!(section.memory.kernelTCP, None); + assert_eq!(section.memory.swappiness, None); + assert_eq!(section.memory.disableOOMKiller, None); + assert_eq!(section.memory.useHierarchy, Some(true)); + assert_eq!(section.memory.checkBeforeUpdate, Some(true)); + } + + #[test] + fn test_cgroup_cpu_01() { + let json = r#"{ + "cpu": { + "shares": 1024, + "quota": 1000000, + "burst": 1000000, + "period": 500000, + "realtimeRuntime": 950000, + "realtimePeriod": 1000000, + "cpus": "2-3", + "mems": "0-7", + "idle": 0 + } + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + cpu: CpuCgroup, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.cpu.shares, Some(1024)); + assert_eq!(section.cpu.quota, Some(1000000)); + assert_eq!(section.cpu.burst, Some(1000000)); + assert_eq!(section.cpu.period, Some(500000)); + assert_eq!(section.cpu.realtimeRuntime, Some(950000)); + assert_eq!(section.cpu.realtimePeriod, Some(1000000)); + assert_eq!(section.cpu.cpus, Some("2-3".to_string())); + assert_eq!(section.cpu.mems, Some("0-7".to_string())); + assert_eq!(section.cpu.idle, Some(0)); + } + + #[test] + fn test_cgroup_cpu_02() { + let json = r#"{ + "cpu": {} + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + cpu: CpuCgroup, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.cpu.shares, None); + assert_eq!(section.cpu.quota, None); + assert_eq!(section.cpu.burst, None); + assert_eq!(section.cpu.period, None); + assert_eq!(section.cpu.realtimeRuntime, None); + assert_eq!(section.cpu.realtimePeriod, None); + assert_eq!(section.cpu.cpus, None); + assert_eq!(section.cpu.mems, None); + assert_eq!(section.cpu.idle, None); + } + + #[test] + fn test_cgroup_blkio() { + let json = r#"{ + "blockIO": { + "weight": 10, + "leafWeight": 10, + "weightDevice": [ + { + "major": 8, + "minor": 0, + "weight": 500, + "leafWeight": 300 + }, + { + "major": 8, + "minor": 16 + } + ], + "throttleReadBpsDevice": [ + { + "major": 8, + "minor": 0, + "rate": 600 + }, + { + "major": 8, + "minor": 16, + "rate": 300 + } + ] + } + }"#; + + #[allow(non_snake_case)] + #[derive(Serialize, Deserialize)] + struct Section { + blockIO: BlockIoCgroup, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.blockIO.weight, Some(10)); + assert_eq!(section.blockIO.leafWeight, Some(10)); + assert_eq!(section.blockIO.throttleReadIOPSDevice, None); + assert_eq!(section.blockIO.throttleWriteBpsDevice, None); + assert_eq!(section.blockIO.throttleWriteIOPSDevice, None); + + let weight_device = section.blockIO.weightDevice.as_ref().unwrap(); + assert_eq!(weight_device.len(), 2); + assert_eq!(weight_device[0].major, 8); + assert_eq!(weight_device[0].minor, 0); + assert_eq!(weight_device[0].weight, Some(500)); + assert_eq!(weight_device[0].leafWeight, Some(300)); + assert_eq!(weight_device[1].major, 8); + assert_eq!(weight_device[1].minor, 16); + assert_eq!(weight_device[1].weight, None); + assert_eq!(weight_device[1].leafWeight, None); + + let throttle = section.blockIO.throttleReadBpsDevice.as_ref().unwrap(); + assert_eq!(throttle.len(), 2); + assert_eq!(throttle[1].major, 8); + assert_eq!(throttle[1].minor, 16); + assert_eq!(throttle[1].rate, 300); + } + + #[test] + fn test_cgroup_hugetlb() { + let json = r#"{ + "hugepageLimits": [ + { + "pageSize": "2MB", + "limit": 209715200 + } + ] + }"#; + + #[allow(non_snake_case)] + #[derive(Serialize, Deserialize)] + struct Section { + hugepageLimits: Vec, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.hugepageLimits[0].pageSize, "2MB"); + assert_eq!(section.hugepageLimits[0].limit, 209715200); + } + + #[test] + fn test_cgroup_network_01() { + let json = r#"{ + "network": { + "classID": 1048577, + "priorities": [ + { + "name": "eth0", + "priority": 500 + } + ] + } + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + network: NetworkCgroup, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.network.classID, Some(1048577)); + let priorities = section.network.priorities.as_ref().unwrap(); + assert_eq!(priorities[0].name, "eth0"); + assert_eq!(priorities[0].priority, 500); + } + + #[test] + fn test_cgroup_network_02() { + let json = r#"{ + "network": {} + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + network: NetworkCgroup, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.network.classID, None); + assert_eq!(section.network.priorities, None); + } + + #[test] + fn test_cgroup_pid() { + let json = r#"{ + "pids": { + "limit": 32771 + } + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + pids: PidsCgroup, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.pids.limit, 32771); + } + + #[test] + fn test_cgroup_rdma() { + let json = r#"{ + "rdma": { + "mlx5_1": { + "hcaHandles": 3, + "hcaObjects": 10000 + }, + "mlx4_0": { + "hcaObjects": 1000 + }, + "rxe3": { + "hcaHandles": 10000 + } + } + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + rdma: RdmaCgroup, + } + + let section: Section = serde_json::from_str(json).unwrap(); + let rdma_limit = section.rdma.mlx5_1.as_ref().unwrap(); + assert_eq!(rdma_limit.hcaHandles, Some(3)); + assert_eq!(rdma_limit.hcaObjects, Some(10000)); + let rdma_limit = section.rdma.mlx4_0.as_ref().unwrap(); + assert_eq!(rdma_limit.hcaHandles, None); + assert_eq!(rdma_limit.hcaObjects, Some(1000)); + let rdma_limit = section.rdma.rxe3.as_ref().unwrap(); + assert_eq!(rdma_limit.hcaHandles, Some(10000)); + assert_eq!(rdma_limit.hcaObjects, None); + } + + #[cfg(target_arch = "x86_64")] + #[test] + fn test_intel_rdt() { + let json = r#"{ + "intelRdt": { + "closID": "guaranteed_group", + "l3CacheSchema": "L3:0=7f0;1=1f", + "memBwSchema": "MB:0=20;1=70", + "enableCMT": true, + "enableMBM": true + } + }"#; + + #[allow(non_snake_case)] + #[derive(Serialize, Deserialize)] + struct Section { + intelRdt: IntelRdt, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!( + section.intelRdt.closID, + Some("guaranteed_group".to_string()) + ); + assert_eq!( + section.intelRdt.l3CacheSchema, + Some("L3:0=7f0;1=1f".to_string()) + ); + assert_eq!( + section.intelRdt.memBwSchema, + Some("MB:0=20;1=70".to_string()) + ); + assert_eq!(section.intelRdt.enableCMT, Some(true)); + assert_eq!(section.intelRdt.enableMBM, Some(true)); + } + + #[test] + fn test_seccomp() { + let json = r#"{ + "seccomp": { + "defaultAction": "SCMP_ACT_ALLOW", + "architectures": [ + "SCMP_ARCH_X86", + "SCMP_ARCH_X32" + ], + "syscalls": [ + { + "names": [ + "getcwd", + "chmod" + ], + "action": "SCMP_ACT_ERRNO" + } + ] + } + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + seccomp: Seccomp, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.seccomp.defaultAction, SeccompAction::ScmpActAllow); + let architectures = section.seccomp.architectures.as_ref().unwrap(); + assert_eq!(architectures.len(), 2); + assert_eq!(architectures[0], "SCMP_ARCH_X86"); + assert_eq!(architectures[1], "SCMP_ARCH_X32"); + let syscall_names = section.seccomp.syscalls.as_ref().unwrap(); + assert_eq!(syscall_names[0].names.len(), 2); + assert_eq!(syscall_names[0].names[0], "getcwd"); + assert_eq!(syscall_names[0].names[1], "chmod"); + assert_eq!(syscall_names[0].action, SeccompAction::ScmpActErrno); + } + + #[test] + fn test_personality() { + let json = r#"{ + "personality": { + "domain": "LINUX" + } + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + personality: Personality, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.personality.domain, "LINUX"); + assert_eq!(section.personality.flags, None); + } +} diff --git a/ozonec/oci_spec/src/posix.rs b/ozonec/oci_spec/src/posix.rs new file mode 100644 index 0000000000000000000000000000000000000000..b284d5d363829f1e74baed8037c318be79fed7bc --- /dev/null +++ b/ozonec/oci_spec/src/posix.rs @@ -0,0 +1,249 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use serde::{Deserialize, Serialize}; + +/// Container's root filesystem. +#[derive(Serialize, Deserialize, Debug, Clone, Default)] +pub struct Root { + /// Path to the root filesystem for the container. + pub path: String, + #[serde(default)] + /// If true then the root filesystem MUST be read-only inside the container. + pub readonly: bool, +} + +/// Resource limits for the process. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Rlimits { + /// The platform resource being limited. + #[serde(rename = "type")] + pub rlimit_type: String, + /// Value of the limit enforced for the corresponding resource. + pub soft: u64, + /// Ceiling for the soft limit that could be set by an + /// unprivileged process. + pub hard: u64, +} + +/// The user for the process that allows specific control over which user +/// the process runs as. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct User { + /// User ID in the container namespace. + #[serde(default)] + pub uid: u32, + /// Group ID in the container namespace. + #[serde(default)] + pub gid: u32, + /// [umask][umask_2] of the user. + #[serde(skip_serializing_if = "Option::is_none")] + pub umask: Option, + /// Additional group IDs in the container namespace to be added + /// to the process. + #[serde(skip_serializing_if = "Option::is_none")] + pub additionalGids: Option>, +} + +/// Hook Entry. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct HookEntry { + /// Similar semantics to IEEE Std 1003.1-2008 execv's path. + pub path: String, + /// Same semantics as IEEE Std 1003.1-2008 execv's argv. + #[serde(skip_serializing_if = "Option::is_none")] + pub args: Option>, + /// Same semantics as IEEE Std 1003.1-2008's environ. + #[serde(skip_serializing_if = "Option::is_none")] + pub env: Option>, + /// Number of seconds before aborting the hook. + #[serde(skip_serializing_if = "Option::is_none")] + pub timeout: Option, +} + +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Hooks { + /// Array of prestart hooks. + #[serde(skip_serializing_if = "Option::is_none")] + prestart: Option>, + /// Array of createRuntime hooks. + #[serde(skip_serializing_if = "Option::is_none")] + createRuntime: Option>, + /// Array of createContainer hooks. + #[serde(skip_serializing_if = "Option::is_none")] + createContainer: Option>, + /// Array of startContainer hooks. + #[serde(skip_serializing_if = "Option::is_none")] + startContainer: Option>, + /// Array of poststart hooks. + #[serde(skip_serializing_if = "Option::is_none")] + poststart: Option>, + /// Array of poststop hooks. + #[serde(skip_serializing_if = "Option::is_none")] + poststop: Option>, +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json; + + #[test] + fn test_root() { + let json = r#"{ + "root": { + "path": "rootfs", + "readonly": true + } + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + root: Root, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.root.path, "rootfs"); + assert_eq!(section.root.readonly, true); + } + + #[test] + fn test_hooks() { + let json = r#"{ + "hooks": { + "prestart": [ + { + "path": "/usr/bin/fix-mounts", + "args": ["fix-mounts", "arg1", "arg2"], + "env": [ "key1=value1"] + }, + { + "path": "/usr/bin/setup-network" + } + ], + "createRuntime": [ + { + "path": "/usr/bin/fix-mounts", + "args": ["fix-mounts", "arg1", "arg2"], + "env": [ "key1=value1"] + }, + { + "path": "/usr/bin/setup-network" + } + ], + "createContainer": [ + { + "path": "/usr/bin/mount-hook", + "args": ["-mount", "arg1", "arg2"], + "env": [ "key1=value1"] + } + ], + "startContainer": [ + { + "path": "/usr/bin/refresh-ldcache" + } + ], + "poststart": [ + { + "path": "/usr/bin/notify-start", + "timeout": 5 + } + ], + "poststop": [ + { + "path": "/usr/sbin/cleanup.sh", + "args": ["cleanup.sh", "-f"] + } + ] + } + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + hooks: Hooks, + } + + let section: Section = serde_json::from_str(json).unwrap(); + let prestart = section.hooks.prestart.as_ref().unwrap(); + assert_eq!(prestart.len(), 2); + assert_eq!(prestart[0].path, "/usr/bin/fix-mounts"); + let args = prestart[0].args.as_ref().unwrap(); + assert_eq!(args.len(), 3); + assert_eq!(args[0], "fix-mounts"); + assert_eq!(args[1], "arg1"); + assert_eq!(args[2], "arg2"); + let env = prestart[0].env.as_ref().unwrap(); + assert_eq!(env.len(), 1); + assert_eq!(env[0], "key1=value1"); + assert_eq!(prestart[0].timeout, None); + assert_eq!(prestart[1].path, "/usr/bin/setup-network"); + assert_eq!(prestart[1].args, None); + assert_eq!(prestart[1].env, None); + assert_eq!(prestart[1].timeout, None); + + let create_runtime = section.hooks.createRuntime.as_ref().unwrap(); + assert_eq!(create_runtime.len(), 2); + assert_eq!(create_runtime[0].path, "/usr/bin/fix-mounts"); + let args = create_runtime[0].args.as_ref().unwrap(); + assert_eq!(args.len(), 3); + assert_eq!(args[0], "fix-mounts"); + assert_eq!(args[1], "arg1"); + assert_eq!(args[2], "arg2"); + let env = create_runtime[0].env.as_ref().unwrap(); + assert_eq!(env.len(), 1); + assert_eq!(env[0], "key1=value1"); + assert_eq!(create_runtime[0].timeout, None); + assert_eq!(create_runtime[1].path, "/usr/bin/setup-network"); + assert_eq!(create_runtime[1].args, None); + assert_eq!(create_runtime[1].env, None); + assert_eq!(create_runtime[1].timeout, None); + + let create_container = section.hooks.createContainer.as_ref().unwrap(); + assert_eq!(create_container.len(), 1); + assert_eq!(create_container[0].path, "/usr/bin/mount-hook"); + let args = create_container[0].args.as_ref().unwrap(); + assert_eq!(args.len(), 3); + assert_eq!(args[0], "-mount"); + assert_eq!(args[1], "arg1"); + assert_eq!(args[2], "arg2"); + let env = create_container[0].env.as_ref().unwrap(); + assert_eq!(env.len(), 1); + assert_eq!(env[0], "key1=value1"); + assert_eq!(create_container[0].timeout, None); + + let start_container = section.hooks.startContainer.as_ref().unwrap(); + assert_eq!(start_container.len(), 1); + assert_eq!(start_container[0].path, "/usr/bin/refresh-ldcache"); + assert_eq!(start_container[0].args, None); + assert_eq!(start_container[0].env, None); + assert_eq!(start_container[0].timeout, None); + + let poststart = section.hooks.poststart.as_ref().unwrap(); + assert_eq!(poststart.len(), 1); + assert_eq!(poststart[0].path, "/usr/bin/notify-start"); + assert_eq!(poststart[0].args, None); + assert_eq!(poststart[0].env, None); + assert_eq!(poststart[0].timeout, Some(5)); + + let poststop = section.hooks.poststop.as_ref().unwrap(); + assert_eq!(poststop.len(), 1); + assert_eq!(poststop[0].path, "/usr/sbin/cleanup.sh"); + let args = poststop[0].args.as_ref().unwrap(); + assert_eq!(args.len(), 2); + assert_eq!(args[0], "cleanup.sh"); + assert_eq!(args[1], "-f"); + assert_eq!(poststop[0].env, None); + assert_eq!(poststop[0].timeout, None); + } +} diff --git a/ozonec/oci_spec/src/process.rs b/ozonec/oci_spec/src/process.rs new file mode 100644 index 0000000000000000000000000000000000000000..a558d78b0282e7134d81d8a46f822f055eb91656 --- /dev/null +++ b/ozonec/oci_spec/src/process.rs @@ -0,0 +1,240 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use serde::{Deserialize, Serialize}; + +#[cfg(target_os = "linux")] +use crate::linux::{Capbilities, ExecCpuAffinity, IoPriority, Scheduler}; +#[cfg(target_family = "unix")] +use crate::posix::{Rlimits, User}; + +/// Console size in characters of the terminal. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ConsoleSize { + /// Height size in characters. + #[serde(skip_serializing_if = "Option::is_none")] + pub height: Option, + /// Width size in characters. + #[serde(skip_serializing_if = "Option::is_none")] + pub width: Option, +} + +/// Container process. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Process { + /// Working directory that will be set for the executable. + pub cwd: String, + /// Similar semantics to IEEE Std 1003.1-2008 execvp's argv. + #[serde(skip_serializing_if = "Option::is_none")] + pub args: Option>, + /// Same semantics as IEEE Std 1003.1-2008's environ. + #[serde(skip_serializing_if = "Option::is_none")] + pub env: Option>, + /// Whether a terminal is attached to the process. + #[serde(default)] + pub terminal: bool, + /// Console size in characters of the terminal. + #[serde(skip_serializing_if = "Option::is_none")] + pub consoleSize: Option, + /// Full command line to be executed on Windows. + #[cfg(target_os = "windows")] + pub commandLine: Option, + /// Resource limits for the process. + #[cfg(target_os = "linux")] + #[serde(skip_serializing_if = "Option::is_none")] + pub rlimits: Option>, + /// Name of the AppArmor profile for the process. + #[cfg(target_os = "linux")] + #[serde(skip_serializing_if = "Option::is_none")] + pub apparmorProfile: Option, + /// Arrays that specifies the sets of capabilities for the process. + #[cfg(target_os = "linux")] + #[serde(skip_serializing_if = "Option::is_none")] + pub capabilities: Option, + /// Setting noNewPrivileges to true prevents the process from + /// gaining additional privileges. + #[cfg(target_os = "linux")] + #[serde(skip_serializing_if = "Option::is_none")] + pub noNewPrivileges: Option, + /// Oom-killer score in [pid]/oom_score_adj for the process's + /// [pid] in a proc pseudo-filesystem. + #[cfg(target_os = "linux")] + #[serde(skip_serializing_if = "Option::is_none")] + pub oomScoreAdj: Option, + /// Scheduler properties for the process. + #[cfg(target_os = "linux")] + #[serde(skip_serializing_if = "Option::is_none")] + pub scheduler: Option, + /// SELinux label for the process. + #[cfg(target_os = "linux")] + #[serde(skip_serializing_if = "Option::is_none")] + pub selinuxLabel: Option, + /// I/O priority settings for the container's processes within + /// the process group. + #[cfg(target_os = "linux")] + #[serde(skip_serializing_if = "Option::is_none")] + pub ioPriority: Option, + /// CPU affinity used to execute the process. + #[cfg(target_os = "linux")] + #[serde(skip_serializing_if = "Option::is_none")] + pub execCPUAffinity: Option, + /// The user for the process that allows specific control over + /// which user the process runs as. + pub user: User, +} + +#[cfg(test)] +mod tests { + use crate::linux::IoPriClass; + + use super::*; + use serde_json; + + #[test] + fn test_process() { + let json = r#"{ + "process": { + "terminal": true, + "consoleSize": { + "height": 25, + "width": 80 + }, + "user": { + "uid": 1, + "gid": 1, + "umask": 63, + "additionalGids": [5, 6] + }, + "env": [ + "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin", + "TERM=xterm" + ], + "cwd": "/root", + "args": [ + "sh" + ], + "apparmorProfile": "acme_secure_profile", + "selinuxLabel": "system_u:system_r:svirt_lxc_net_t:s0:c124,c675", + "ioPriority": { + "class": "IOPRIO_CLASS_IDLE", + "priority": 4 + }, + "noNewPrivileges": true, + "capabilities": { + "bounding": [ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE" + ], + "permitted": [ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE" + ], + "inheritable": [ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE" + ], + "effective": [ + "CAP_AUDIT_WRITE", + "CAP_KILL" + ], + "ambient": [ + "CAP_NET_BIND_SERVICE" + ] + }, + "rlimits": [ + { + "type": "RLIMIT_NOFILE", + "hard": 1024, + "soft": 1024 + } + ], + "execCPUAffinity": { + "initial": "7", + "final": "0-3,7" + } + } + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + process: Process, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.process.terminal, true); + let console_size = section.process.consoleSize.as_ref().unwrap(); + assert_eq!(console_size.height, Some(25)); + assert_eq!(console_size.width, Some(80)); + assert_eq!(section.process.user.uid, 1); + assert_eq!(section.process.user.gid, 1); + assert_eq!(section.process.user.umask, Some(63)); + assert_eq!(section.process.user.additionalGids, Some(vec![5, 6])); + let env = section.process.env.as_ref().unwrap(); + assert_eq!(env.len(), 2); + assert_eq!( + env[0], + "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" + ); + assert_eq!(env[1], "TERM=xterm"); + assert_eq!(section.process.cwd, "/root"); + let args = section.process.args.as_ref().unwrap(); + assert_eq!(args.len(), 1); + assert_eq!(args[0], "sh"); + assert_eq!( + section.process.apparmorProfile, + Some("acme_secure_profile".to_string()) + ); + assert_eq!( + section.process.selinuxLabel, + Some("system_u:system_r:svirt_lxc_net_t:s0:c124,c675".to_string()) + ); + let io_pri = section.process.ioPriority.as_ref().unwrap(); + assert_eq!(io_pri.class, IoPriClass::IoprioClassIdle); + assert_eq!(io_pri.priority, 4); + assert_eq!(section.process.noNewPrivileges, Some(true)); + let caps = section.process.capabilities.as_ref().unwrap(); + let bonding_caps = caps.bounding.as_ref().unwrap(); + assert_eq!(bonding_caps.len(), 3); + assert_eq!(bonding_caps[0], "CAP_AUDIT_WRITE"); + assert_eq!(bonding_caps[1], "CAP_KILL"); + assert_eq!(bonding_caps[2], "CAP_NET_BIND_SERVICE"); + let permitted_caps = caps.permitted.as_ref().unwrap(); + assert_eq!(permitted_caps.len(), 3); + assert_eq!(permitted_caps[0], "CAP_AUDIT_WRITE"); + assert_eq!(permitted_caps[1], "CAP_KILL"); + assert_eq!(permitted_caps[2], "CAP_NET_BIND_SERVICE"); + let inheritable_caps = caps.inheritable.as_ref().unwrap(); + assert_eq!(inheritable_caps.len(), 3); + assert_eq!(inheritable_caps[0], "CAP_AUDIT_WRITE"); + assert_eq!(inheritable_caps[1], "CAP_KILL"); + assert_eq!(inheritable_caps[2], "CAP_NET_BIND_SERVICE"); + let effective_caps = caps.effective.as_ref().unwrap(); + assert_eq!(effective_caps.len(), 2); + assert_eq!(effective_caps[0], "CAP_AUDIT_WRITE"); + assert_eq!(effective_caps[1], "CAP_KILL"); + let ambient_caps = caps.ambient.as_ref().unwrap(); + assert_eq!(ambient_caps.len(), 1); + assert_eq!(ambient_caps[0], "CAP_NET_BIND_SERVICE"); + let rlimits = section.process.rlimits.as_ref().unwrap(); + assert_eq!(rlimits.len(), 1); + assert_eq!(rlimits[0].rlimit_type, "RLIMIT_NOFILE"); + assert_eq!(rlimits[0].hard, 1024); + assert_eq!(rlimits[0].soft, 1024); + let exec_cpu_affinity = section.process.execCPUAffinity.as_ref().unwrap(); + assert_eq!(exec_cpu_affinity.initial, Some("7".to_string())); + assert_eq!(exec_cpu_affinity.final_cpus, Some("0-3,7".to_string())); + } +} diff --git a/ozonec/oci_spec/src/runtime.rs b/ozonec/oci_spec/src/runtime.rs new file mode 100644 index 0000000000000000000000000000000000000000..d68dea0bd7bbee0e50913704401e6a95de754e00 --- /dev/null +++ b/ozonec/oci_spec/src/runtime.rs @@ -0,0 +1,140 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{collections::HashMap, fs::File, io::BufReader, path::Path}; + +use anyhow::{anyhow, Context, Result}; +use serde::{Deserialize, Serialize}; + +#[cfg(target_os = "linux")] +use crate::linux::IdMapping; +#[cfg(target_family = "unix")] +use crate::posix::Root; +use crate::{linux::LinuxPlatform, posix::Hooks, process::Process, vm::VmPlatform}; + +/// Additional mounts beyond root. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Mount { + /// Destination of mount point: path inside container. + pub destination: String, + /// A device name, but can also be a file or directory name for bind mounts + /// or a dummy. + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option, + /// Mount options of the filesystem to be used. + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option>, + /// The type of the filesystem to be mounted. + #[serde(skip_serializing_if = "Option::is_none", rename = "type")] + pub fs_type: Option, + /// The mapping to convert UIDs from the source file system to the + /// destination mount point. + #[serde(skip_serializing_if = "Option::is_none")] + pub uidMappings: Option, + /// The mapping to convert GIDs from the source file system to the + /// destination mount point. + #[serde(skip_serializing_if = "Option::is_none")] + pub gidMappings: Option, +} + +/// Metadata necessary to implement standard operations against the container. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct RuntimeConfig { + /// Version of the Open Container Initiative Runtime Specification + /// with which the bundle complies. + pub ociVersion: String, + /// Container's root filesystem. + pub root: Root, + /// Additional mounts beyond root. + pub mounts: Vec, + /// Container process. + pub process: Process, + /// Container's hostname as seen by processes running inside the container. + #[serde(skip_serializing_if = "Option::is_none")] + pub hostname: Option, + /// Container's domainname as seen by processes running inside the + /// container. + #[serde(skip_serializing_if = "Option::is_none")] + pub domainname: Option, + /// Linux-specific section of the container configuration. + #[cfg(target_os = "linux")] + #[serde(skip_serializing_if = "Option::is_none")] + pub linux: Option, + /// Vm-specific section of the container configuration. + #[serde(skip_serializing_if = "Option::is_none")] + pub vm: Option, + /// Custom actions related to the lifecycle of the container. + #[cfg(target_family = "unix")] + #[serde(skip_serializing_if = "Option::is_none")] + pub hooks: Option, + /// Arbitrary metadata for the container. + #[serde(skip_serializing_if = "Option::is_none")] + pub annotations: Option>, +} + +impl RuntimeConfig { + pub fn from_file(path: &String) -> Result { + let file = File::open(Path::new(path)).with_context(|| "Failed to open config.json")?; + let reader = BufReader::new(file); + serde_json::from_reader(reader).map_err(|e| anyhow!("Failed to load config.json: {:?}", e)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json; + + #[test] + fn test_mounts() { + let json = r#"{ + "mounts": [ + { + "destination": "/proc", + "type": "proc", + "source": "proc" + }, + { + "destination": "/dev", + "type": "tmpfs", + "source": "tmpfs", + "options": [ + "nosuid", + "strictatime", + "mode=755", + "size=65536k" + ] + } + ] + }"#; + + #[allow(non_snake_case)] + #[derive(Serialize, Deserialize)] + struct Section { + mounts: Vec, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.mounts.len(), 2); + assert_eq!(section.mounts[0].destination, "/proc"); + assert_eq!(section.mounts[0].fs_type, Some("proc".to_string())); + assert_eq!(section.mounts[0].source, Some("proc".to_string())); + let options = section.mounts[1].options.as_ref().unwrap(); + assert_eq!(options.len(), 4); + assert_eq!(options[0], "nosuid"); + assert_eq!(options[1], "strictatime"); + assert_eq!(options[2], "mode=755"); + assert_eq!(options[3], "size=65536k"); + } +} diff --git a/ozonec/oci_spec/src/state.rs b/ozonec/oci_spec/src/state.rs new file mode 100644 index 0000000000000000000000000000000000000000..105f128ef92f442cda2b2bfae8b20245ca06a40d --- /dev/null +++ b/ozonec/oci_spec/src/state.rs @@ -0,0 +1,110 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +/// Runtime state of the container. +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Copy, Default, Eq)] +#[serde(rename_all = "lowercase")] +pub enum ContainerStatus { + Creating, + Created, + Running, + #[default] + Stopped, +} + +impl ToString for ContainerStatus { + fn to_string(&self) -> String { + match *self { + ContainerStatus::Creating => String::from("creating"), + ContainerStatus::Created => String::from("created"), + ContainerStatus::Running => String::from("running"), + ContainerStatus::Stopped => String::from("stopped"), + } + } +} + +/// The state of a container. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct State { + /// Version of the Open Container Initiative Runtime Specification + /// with which the state complies. + #[serde(default, skip_serializing_if = "String::is_empty")] + pub ociVersion: String, + /// Container's ID. + #[serde(default, skip_serializing_if = "String::is_empty")] + pub id: String, + /// Runtime state of the container. + pub status: ContainerStatus, + /// ID of the container process. + #[serde(default)] + pub pid: i32, + /// Absolute path to the container's bundle directory. + #[serde(default, skip_serializing_if = "String::is_empty")] + pub bundle: String, + /// List of annotations associated with the container. + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub annotations: HashMap, +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json; + + #[test] + fn test_state() { + let json = r#"{ + "ociVersion": "0.2.0", + "id": "oci-container1", + "status": "running", + "pid": 4422, + "bundle": "/containers/redis", + "annotations": { + "myKey": "myValue" + } + }"#; + + let state: State = serde_json::from_str(json).unwrap(); + assert_eq!(state.ociVersion, "0.2.0"); + assert_eq!(state.id, "oci-container1"); + assert_eq!(state.status, ContainerStatus::Running); + assert_eq!(state.pid, 4422); + assert_eq!(state.bundle, "/containers/redis"); + assert!(state.annotations.contains_key("myKey")); + assert_eq!(state.annotations.get("myKey"), Some(&"myValue".to_string())); + } + + #[test] + fn test_container_status_to_string() { + assert_eq!( + ContainerStatus::Creating.to_string(), + String::from("creating") + ); + assert_eq!( + ContainerStatus::Created.to_string(), + String::from("created") + ); + assert_eq!( + ContainerStatus::Running.to_string(), + String::from("running") + ); + assert_eq!( + ContainerStatus::Stopped.to_string(), + String::from("stopped") + ); + } +} diff --git a/ozonec/oci_spec/src/vm.rs b/ozonec/oci_spec/src/vm.rs new file mode 100644 index 0000000000000000000000000000000000000000..08d1a3548e08afc958bc92f27fa403026a2ebd5f --- /dev/null +++ b/ozonec/oci_spec/src/vm.rs @@ -0,0 +1,136 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use serde::{Deserialize, Serialize}; + +/// Hypervisor that manages the container virtual machine. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Hypervisor { + /// Path to the hypervisor binary that manages the container + /// virtual machine. + pub path: String, + /// Array of parameters to pass to the hypervisor. + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option>, +} + +/// Kernel to boot the container virtual machine with. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Kernel { + /// Path to the kernel used to boot the container virtual machine. + pub path: String, + #[serde(skip_serializing_if = "Option::is_none")] + /// Array of parameters to pass to the kernel. + pub parameters: Option>, + /// Path to an initial ramdisk to be used by the container + /// virtual machine. + #[serde(skip_serializing_if = "Option::is_none")] + pub initrd: Option, +} + +/// Image that contains the root filesystem for the container +/// virtual machine. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Image { + /// Path to the container virtual machine root image. + pub path: String, + /// Format of the container virtual machine root image. + pub format: String, +} + +/// Configuration for the hypervisor, kernel, and image. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct VmPlatform { + /// Hypervisor that manages the container virtual machine. + #[serde(skip_serializing_if = "Option::is_none")] + pub hypervisor: Option, + /// Kernel to boot the container virtual machine with. + pub kernel: Kernel, + /// Image that contains the root filesystem for the container + /// virtual machine. + #[serde(skip_serializing_if = "Option::is_none")] + pub image: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json; + + #[test] + fn test_hypervisor() { + let json = r#"{ + "hypervisor": { + "path": "/path/to/vmm", + "parameters": ["opts1=foo", "opts2=bar"] + } + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + hypervisor: Hypervisor, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.hypervisor.path, "/path/to/vmm"); + let parameters = section.hypervisor.parameters.as_ref().unwrap(); + assert_eq!(parameters.len(), 2); + assert_eq!(parameters[0], "opts1=foo"); + assert_eq!(parameters[1], "opts2=bar"); + } + + #[test] + fn test_kernel() { + let json = r#"{ + "kernel": { + "path": "/path/to/vmlinuz", + "parameters": ["foo=bar", "hello world"], + "initrd": "/path/to/initrd.img" + } + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + kernel: Kernel, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.kernel.path, "/path/to/vmlinuz"); + let parameters = section.kernel.parameters.as_ref().unwrap(); + assert_eq!(parameters.len(), 2); + assert_eq!(parameters[0], "foo=bar"); + assert_eq!(parameters[1], "hello world"); + assert_eq!( + section.kernel.initrd, + Some("/path/to/initrd.img".to_string()) + ); + } + + #[test] + fn test_image() { + let json = r#"{ + "image": { + "path": "/path/to/vm/rootfs.img", + "format": "raw" + } + }"#; + + #[derive(Serialize, Deserialize)] + struct Section { + image: Image, + } + + let section: Section = serde_json::from_str(json).unwrap(); + assert_eq!(section.image.path, "/path/to/vm/rootfs.img"); + assert_eq!(section.image.format, "raw"); + } +} diff --git a/ozonec/src/commands/create.rs b/ozonec/src/commands/create.rs new file mode 100644 index 0000000000000000000000000000000000000000..e4c802b000043df8ec061cfb514b9c00b778b1bc --- /dev/null +++ b/ozonec/src/commands/create.rs @@ -0,0 +1,79 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::path::{Path, PathBuf}; + +use anyhow::{Context, Ok, Result}; +use clap::{builder::NonEmptyStringValueParser, Parser}; + +use crate::container::{Action, Container, Launcher}; +use crate::linux::LinuxContainer; +use oci_spec::runtime::RuntimeConfig; + +/// Create a container from a bundle directory +#[derive(Parser, Debug)] +pub struct Create { + /// File to write the container PID to + #[arg(short, long)] + pub pid_file: Option, + /// Path to the bundle directory, defaults to the current working directory. + #[arg(short, long, default_value = ".")] + pub bundle: PathBuf, + /// Path to an AF_UNIX socket which will receive the pseudoterminal master + /// at a file descriptor. + #[arg(short, long)] + pub console_socket: Option, + /// Container ID to create. + #[arg(value_parser = NonEmptyStringValueParser::new(), required = true)] + pub container_id: String, +} + +impl Create { + fn launcher(&self, root: &Path, exist: &mut bool) -> Result { + let bundle_path = self + .bundle + .canonicalize() + .with_context(|| "Failed to canonicalize bundle path")?; + let config_path = bundle_path + .join("config.json") + .to_string_lossy() + .to_string(); + let mut config = RuntimeConfig::from_file(&config_path)?; + let mut rootfs_path = PathBuf::from(config.root.path); + + if !rootfs_path.is_absolute() { + rootfs_path = bundle_path.join(rootfs_path); + } + config.root.path = rootfs_path.to_string_lossy().to_string(); + + let container: Box = Box::new(LinuxContainer::new( + &self.container_id, + &root.to_string_lossy().to_string(), + &config, + &self.console_socket, + exist, + )?); + Ok(Launcher::new( + &bundle_path, + root, + true, + container, + self.pid_file.clone(), + )) + } + + pub fn run(&self, root: &Path, exist: &mut bool) -> Result<()> { + let mut launcher = self.launcher(root, exist)?; + launcher.launch(Action::Create)?; + Ok(()) + } +} diff --git a/util/src/syscall.rs b/ozonec/src/commands/delete.rs similarity index 31% rename from util/src/syscall.rs rename to ozonec/src/commands/delete.rs index f6088a72f3bdc30b50fc6d45a7cf7ffc55b0ae07..67f712e4dbe23dccde5d5b7433320b5ea5ff0f69 100644 --- a/util/src/syscall.rs +++ b/ozonec/src/commands/delete.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Huawei Technologies Co.,Ltd. All rights reserved. +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. // // StratoVirt is licensed under Mulan PSL v2. // You can use this software according to the terms and conditions of the Mulan @@ -10,48 +10,43 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. +use std::{fs, path::Path}; + use anyhow::{bail, Result}; -use libc::{c_void, syscall, SYS_mbind}; +use clap::{builder::NonEmptyStringValueParser, Parser}; -/// This function set memory policy for host NUMA node memory range. -/// -/// * Arguments -/// -/// * `addr` - The memory range starting with addr. -/// * `len` - Length of the memory range. -/// * `mode` - Memory policy mode. -/// * `node_mask` - node_mask specifies physical node ID. -/// * `max_node` - The max node. -/// * `flags` - Mode flags. -pub fn mbind( - addr: u64, - len: u64, - mode: u32, - node_mask: Vec, - max_node: u64, - flags: u32, -) -> Result<()> { - // SAFETY: - // 1. addr is managed by memory mapping, it can be guaranteed legal. - // 2. node_mask was created in function of set_host_memory_policy. - // 3. Upper limit of max_node is MAX_NODES. - let res = unsafe { - syscall( - SYS_mbind, - addr as *mut c_void, - len, - mode, - node_mask.as_ptr(), - max_node + 1, - flags, - ) - }; - if res < 0 { - bail!( - "Failed to apply host numa node policy, error is {}", - std::io::Error::last_os_error() - ); - } +use crate::{ + container::{Container, State}, + linux::LinuxContainer, +}; + +/// Release container resources after the container process has exited +#[derive(Debug, Parser)] +pub struct Delete { + /// Specify the container id + #[arg(value_parser = NonEmptyStringValueParser::new(), required = true)] + pub container_id: String, + /// Force to delete the container (kill the container using SIGKILL) + #[arg(short, long)] + pub force: bool, +} - Ok(()) +impl Delete { + pub fn run(&self, root: &Path) -> Result<()> { + let state_dir = root.join(&self.container_id); + if !state_dir.exists() { + bail!("{} doesn't exist", state_dir.display()); + } + + let state = if let Ok(s) = State::load(root, &self.container_id) { + s + } else { + fs::remove_dir_all(state_dir)?; + return Ok(()); + }; + + let container = LinuxContainer::load_from_state(&state, &None)?; + container.delete(&state, self.force)?; + Ok(()) + } } diff --git a/ozonec/src/commands/exec.rs b/ozonec/src/commands/exec.rs new file mode 100644 index 0000000000000000000000000000000000000000..9ee4e52110ebdfff34bb2c9bf9c26b13e80fc519 --- /dev/null +++ b/ozonec/src/commands/exec.rs @@ -0,0 +1,128 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use core::str; +use std::path::{Path, PathBuf}; + +use anyhow::{anyhow, bail, Context, Result}; +use clap::{builder::NonEmptyStringValueParser, Parser}; +use oci_spec::state::ContainerStatus; + +use crate::{ + container::{Action, Launcher, State}, + linux::LinuxContainer, + utils::OzonecErr, +}; + +/// Execute a new process inside the container +#[derive(Debug, Parser)] +pub struct Exec { + /// Path to an AF_UNIX socket which will receive a file descriptor of the master end + /// of the console's pseudoterminal + #[arg(long)] + pub console_socket: Option, + /// Allocate a pseudio-TTY + #[arg(short, long)] + pub tty: bool, + /// Current working directory in the container + #[arg(long)] + pub cwd: Option, + /// Specify the file to write the process pid to + #[arg(long)] + pub pid_file: Option, + /// Specify environment variables + #[arg(short, long, value_parser = parse_key_val::, number_of_values = 1)] + pub env: Vec<(String, String)>, + /// Prevent the process from gaining additional privileges + #[arg(long)] + pub no_new_privs: bool, + /// Specify the container id + #[arg(value_parser = NonEmptyStringValueParser::new(), required = true)] + pub container_id: String, + /// Specify the command to execute in the container + #[arg(required = false)] + pub command: Vec, +} + +fn parse_key_val(s: &str) -> Result<(T, U)> +where + T: str::FromStr, + T::Err: std::error::Error + Send + Sync + 'static, + U: str::FromStr, + U::Err: std::error::Error + Send + Sync + 'static, +{ + let pos = s + .find('=') + .ok_or_else(|| anyhow!("Invalid KEY=value: no '=' found in '{}'", s))?; + Ok((s[..pos].parse()?, s[pos + 1..].parse()?)) +} + +impl Exec { + fn launcher(&self, root: &Path) -> Result { + let mut container_state = + State::load(root, &self.container_id).with_context(|| OzonecErr::LoadConState)?; + + if let Some(config) = container_state.config.as_mut() { + config.process.terminal = self.tty; + config.process.cwd = if let Some(cwd) = &self.cwd { + cwd.to_string_lossy().to_string() + } else { + String::from("/") + }; + + for (env_name, env_value) in &self.env { + config + .process + .env + .as_mut() + .unwrap() + .push(format!("{}={}", env_name, env_value)); + } + config.process.noNewPrivileges = Some(self.no_new_privs); + config.process.args = Some(self.command.clone()); + } + + let container = LinuxContainer::load_from_state(&container_state, &self.console_socket)?; + let status = container.status()?; + if status != ContainerStatus::Created && status != ContainerStatus::Running { + bail!("Can't exec in container with {:?} state", status); + } + + Ok(Launcher::new( + &container_state.bundle, + root, + false, + Box::new(container), + self.pid_file.clone(), + )) + } + + pub fn run(&self, root: &Path) -> Result<()> { + let mut launcher = self.launcher(root)?; + launcher.launch(Action::Exec)?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_key_val() { + let (key, value): (String, String) = parse_key_val("OZONEC_LOG_LEVEL=info").unwrap(); + assert_eq!(key, "OZONEC_LOG_LEVEL"); + assert_eq!(value, "info"); + + assert!(parse_key_val::("OZONEC_LOG_LEVEL").is_err()); + } +} diff --git a/ozonec/src/commands/kill.rs b/ozonec/src/commands/kill.rs new file mode 100644 index 0000000000000000000000000000000000000000..e9ab6350bcbcc463e80f5e5b05e5e562f89c70ba --- /dev/null +++ b/ozonec/src/commands/kill.rs @@ -0,0 +1,71 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{path::Path, str::FromStr}; + +use anyhow::{Context, Result}; +use clap::{builder::NonEmptyStringValueParser, Parser}; +use nix::sys::signal::Signal; + +use crate::{ + container::{Container, State}, + linux::LinuxContainer, +}; + +/// Send a signal to the container process +#[derive(Parser, Debug)] +pub struct Kill { + /// Specify the container id + #[arg(value_parser = NonEmptyStringValueParser::new(), required = true)] + pub container_id: String, + /// The signal to send to the container process + pub signal: String, +} + +impl Kill { + pub fn run(&self, root: &Path) -> Result<()> { + let container_state = State::load(root, &self.container_id)?; + let signal = parse_signal(&self.signal).with_context(|| "Invalid signal")?; + let container = LinuxContainer::load_from_state(&container_state, &None)?; + + container.kill(signal)?; + Ok(()) + } +} + +fn parse_signal(signal: &str) -> Result { + if let Ok(num) = signal.parse::() { + return Ok(Signal::try_from(num)?); + } + + let mut uppercase_sig = signal.to_uppercase(); + if !uppercase_sig.starts_with("SIG") { + uppercase_sig = format!("SIG{}", &uppercase_sig); + } + Ok(Signal::from_str(&uppercase_sig)?) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_signal() { + assert_eq!(parse_signal("9").unwrap(), Signal::SIGKILL); + assert_eq!(parse_signal("sigterm").unwrap(), Signal::SIGTERM); + assert_eq!(parse_signal("SIGBUS").unwrap(), Signal::SIGBUS); + assert_eq!(parse_signal("hup").unwrap(), Signal::SIGHUP); + assert_eq!(parse_signal("ABRT").unwrap(), Signal::SIGABRT); + assert!(parse_signal("100").is_err()); + assert!(parse_signal("ERROR").is_err()); + } +} diff --git a/ozonec/src/commands/mod.rs b/ozonec/src/commands/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..f8096f3f655dda7bab90238a4124aa632f2e762c --- /dev/null +++ b/ozonec/src/commands/mod.rs @@ -0,0 +1,25 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +mod create; +mod delete; +mod exec; +mod kill; +mod start; +mod state; + +pub use create::Create; +pub use delete::Delete; +pub use exec::Exec; +pub use kill::Kill; +pub use start::Start; +pub use state::State; diff --git a/ozonec/src/commands/start.rs b/ozonec/src/commands/start.rs new file mode 100644 index 0000000000000000000000000000000000000000..33ce7dd6c8c55e7389858844d3f4449f17efb076 --- /dev/null +++ b/ozonec/src/commands/start.rs @@ -0,0 +1,56 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::path::Path; + +use anyhow::{bail, Context, Result}; +use clap::Parser; +use oci_spec::state::ContainerStatus; + +use crate::{ + container::{Action, Launcher, State}, + linux::LinuxContainer, + utils::OzonecErr, +}; + +/// Start the user-specified code from process +#[derive(Parser, Debug)] +pub struct Start { + pub container_id: String, +} + +impl Start { + fn launcher(&self, root: &Path) -> Result { + let container_state = + State::load(root, &self.container_id).with_context(|| OzonecErr::LoadConState)?; + let container = LinuxContainer::load_from_state(&container_state, &None)?; + let oci_status = container.status()?; + + if oci_status != ContainerStatus::Created { + bail!("Can't start a container with {:?} status", oci_status); + } + + Ok(Launcher::new( + &container_state.bundle, + root, + false, + Box::new(container), + None, + )) + } + + pub fn run(&self, root: &Path) -> Result<()> { + let mut launcher = self.launcher(root)?; + launcher.launch(Action::Start)?; + Ok(()) + } +} diff --git a/machine_manager/src/config/pvpanic_pci.rs b/ozonec/src/commands/state.rs similarity index 31% rename from machine_manager/src/config/pvpanic_pci.rs rename to ozonec/src/commands/state.rs index d0c3b8723a299639d88433ed1cee2651fd940419..d667694ff0d72595148bee28f5f0e7e272872962 100644 --- a/machine_manager/src/config/pvpanic_pci.rs +++ b/ozonec/src/commands/state.rs @@ -10,57 +10,46 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use crate::config::{CmdParser, ConfigCheck}; -use anyhow::{bail, Context, Result}; +use std::path::{Path, PathBuf}; + +use anyhow::{Context, Result}; +use clap::{builder::NonEmptyStringValueParser, Parser}; use serde::{Deserialize, Serialize}; -pub const PVPANIC_PANICKED: u32 = 1 << 0; -pub const PVPANIC_CRASHLOADED: u32 = 1 << 1; +use crate::{container::State as ContainerState, linux::LinuxContainer}; -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PvpanicDevConfig { - pub id: String, - pub supported_features: u32, +/// Request the container state +#[derive(Debug, Parser)] +pub struct State { + /// Specify the container id + #[arg(value_parser = NonEmptyStringValueParser::new(), required = true)] + pub container_id: String, } -impl Default for PvpanicDevConfig { - fn default() -> Self { - PvpanicDevConfig { - id: "".to_string(), - supported_features: PVPANIC_PANICKED | PVPANIC_CRASHLOADED, - } - } +#[derive(Serialize, Deserialize, Debug)] +pub struct RuntimeState { + pub oci_version: String, + pub id: String, + pub status: String, + pub pid: i32, + pub bundle: PathBuf, } -impl ConfigCheck for PvpanicDevConfig { - fn check(&self) -> Result<()> { +impl State { + pub fn run(&self, root: &Path) -> Result<()> { + let state = ContainerState::load(root, &self.container_id)?; + let container = LinuxContainer::load_from_state(&state, &None)?; + let runtime_state = RuntimeState { + oci_version: state.oci_version, + id: state.id, + pid: state.pid, + status: container.status()?.to_string(), + bundle: state.bundle, + }; + let json_data = &serde_json::to_string_pretty(&runtime_state) + .with_context(|| "Failed to get json data of container state")?; + + println!("{}", json_data); Ok(()) } } - -pub fn parse_pvpanic(args_config: &str) -> Result { - let mut cmd_parser = CmdParser::new("pvpanic"); - cmd_parser - .push("") - .push("id") - .push("bus") - .push("addr") - .push("supported-features"); - cmd_parser.parse(args_config)?; - - let mut pvpanicdevcfg = PvpanicDevConfig::default(); - - if let Some(features) = cmd_parser.get_value::("supported-features")? { - pvpanicdevcfg.supported_features = - match features & !(PVPANIC_PANICKED | PVPANIC_CRASHLOADED) { - 0 => features, - _ => bail!("Unsupported pvpanic device features {}", features), - } - } - - pvpanicdevcfg.id = cmd_parser - .get_value::("id")? - .with_context(|| "No id configured for pvpanic device")?; - - Ok(pvpanicdevcfg) -} diff --git a/ozonec/src/container/launcher.rs b/ozonec/src/container/launcher.rs new file mode 100644 index 0000000000000000000000000000000000000000..b9dca85d7fe68c9f57398151cddbe033c8a973da --- /dev/null +++ b/ozonec/src/container/launcher.rs @@ -0,0 +1,132 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +// Linux container create flow: +// ozonec create | State 1 process | Stage 2 process | ozonec start +// | | | +// -> clone3 -> | | | +// <- mapping request <- | | | +// write uid/gid mappings | | | +// -> send mapping done -> | | | +// | set uid/gid | | +// | set pid namespace | | +// <- send stage 2 pid | | -> clone3 -> | +// | exit | set rest namespaces | +// | | pivot_root/chroot | +// | | set capabilities | +// | | set seccomp | +// < send ready <- | | +// | | wait for start signal | +// update pid file | | | ozonec start $id +// exit | | | <- send start signal +// | | execvp cmd | exit + +use std::path::{Path, PathBuf}; + +use anyhow::{Context, Result}; + +use super::{state::State, Container}; +use crate::{linux::Process, utils::OzonecErr}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Action { + Create, + Start, + Exec, +} + +pub struct Launcher { + pub bundle: PathBuf, + pub root: PathBuf, + /// init is set to true when creating a container. + pub init: bool, + pub runner: Box, + pub pid_file: Option, +} + +impl Launcher { + pub fn new( + bundle: &Path, + root: &Path, + init: bool, + runner: Box, + pid_file: Option, + ) -> Self { + Self { + bundle: bundle.to_path_buf(), + root: root.to_path_buf(), + init, + runner, + pid_file, + } + } + + pub fn launch(&mut self, action: Action) -> Result<()> { + if self.init { + self.spawn_container()?; + } else { + self.spawn_process(action)?; + } + + if let Some(pid_file) = self.pid_file.as_ref() { + let pid = self.runner.get_pid(); + std::fs::write(pid_file, format!("{}", pid)).with_context(|| "Failed to write pid")?; + } + + Ok(()) + } + + fn spawn_container(&mut self) -> Result<()> { + self.spawn_process(Action::Create)?; + + let mut state = self + .get_state() + .with_context(|| "Failed to get container state")?; + state.update(); + state.save().with_context(|| "Failed to save state")?; + Ok(()) + } + + fn spawn_process(&mut self, action: Action) -> Result<()> { + let mut process = self.get_process(); + match action { + Action::Create => self.runner.create(&mut process), + Action::Start => self.runner.start(), + Action::Exec => self.runner.exec(&mut process), + } + } + + fn get_process(&self) -> Process { + let config = self.runner.get_config(); + Process::new(&config.process, self.init) + } + + fn get_state(&self) -> Result { + let state = self.runner.get_oci_state()?; + let pid = self.runner.get_pid(); + let proc = + procfs::process::Process::new(pid).with_context(|| OzonecErr::ReadProcPid(pid))?; + let start_time = proc + .stat() + .with_context(|| OzonecErr::ReadProcStat(pid))? + .starttime; + + Ok(State::new( + &self.root, + &self.bundle, + state, + start_time, + *self.runner.created_time(), + self.runner.get_config(), + )) + } +} diff --git a/ozonec/src/container/mod.rs b/ozonec/src/container/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..761e2517ec5380aa6563ffffb1c5737a28fb3324 --- /dev/null +++ b/ozonec/src/container/mod.rs @@ -0,0 +1,47 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +mod launcher; +mod state; + +pub use launcher::{Action, Launcher}; +pub use state::State; + +use std::time::SystemTime; + +use anyhow::Result; +use libc::pid_t; +use nix::sys::signal::Signal; + +use oci_spec::{runtime::RuntimeConfig, state::State as OciState}; + +use crate::linux::Process; + +pub trait Container { + fn get_config(&self) -> &RuntimeConfig; + + fn get_oci_state(&self) -> Result; + + fn get_pid(&self) -> pid_t; + + fn created_time(&self) -> &SystemTime; + + fn create(&mut self, process: &mut Process) -> Result<()>; + + fn start(&mut self) -> Result<()>; + + fn exec(&mut self, process: &mut Process) -> Result<()>; + + fn kill(&self, sig: Signal) -> Result<()>; + + fn delete(&self, state: &State, force: bool) -> Result<()>; +} diff --git a/ozonec/src/container/state.rs b/ozonec/src/container/state.rs new file mode 100644 index 0000000000000000000000000000000000000000..659752a098f2441f377203380128e095c6bb3d80 --- /dev/null +++ b/ozonec/src/container/state.rs @@ -0,0 +1,204 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{ + fs::{self, DirBuilder, File, OpenOptions}, + os::unix::fs::DirBuilderExt, + path::{Path, PathBuf}, + time::SystemTime, +}; + +use anyhow::{bail, Context, Result}; +use chrono::{DateTime, Utc}; +use libc::pid_t; +use nix::sys::stat::Mode; +use serde::{Deserialize, Serialize}; + +use oci_spec::{runtime::RuntimeConfig, state::State as OciState}; + +use crate::utils::OzonecErr; + +#[derive(Serialize, Deserialize, Debug, Clone, Default)] +#[serde(rename_all = "camelCase")] +pub struct State { + pub oci_version: String, + pub id: String, + pub pid: pid_t, + pub root: PathBuf, + pub bundle: PathBuf, + pub rootfs: String, + pub start_time: u64, + pub created_time: DateTime, + pub config: Option, +} + +impl State { + pub fn new( + root: &Path, + bundle: &Path, + oci_state: OciState, + start_time: u64, + created_time: SystemTime, + config: &RuntimeConfig, + ) -> Self { + Self { + oci_version: oci_state.ociVersion, + id: oci_state.id, + pid: oci_state.pid, + root: root.to_path_buf(), + bundle: bundle.to_path_buf(), + rootfs: config.root.path.clone(), + start_time, + created_time: DateTime::from(created_time), + config: Some(config.clone()), + } + } + + pub fn save(&self) -> Result<()> { + if !&self.root.exists() { + DirBuilder::new() + .recursive(true) + .mode(Mode::S_IRWXU.bits()) + .create(&self.root) + .with_context(|| "Failed to create root directory")?; + } + + let path = Self::file_path(&self.root, &self.id); + let state_file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&path) + .with_context(|| OzonecErr::OpenFile(path.to_string_lossy().to_string()))?; + serde_json::to_writer(&state_file, self)?; + Ok(()) + } + + pub fn update(&mut self) { + let linux = self.config.as_mut().unwrap().linux.as_mut(); + if let Some(config) = linux { + for ns in &mut config.namespaces { + if ns.path.is_none() { + let ns_name: String = ns.ns_type.into(); + ns.path = Some(PathBuf::from(format!("/proc/{}/ns/{}", self.pid, ns_name))) + } + } + } + } + + pub fn load(root: &Path, id: &str) -> Result { + let path = Self::file_path(root, id); + if !path.exists() { + bail!("Container {} doesn't exist", id); + } + + let state_file = File::open(&path) + .with_context(|| OzonecErr::OpenFile(path.to_string_lossy().to_string()))?; + let state = serde_json::from_reader(&state_file)?; + Ok(state) + } + + pub fn remove_dir(&self) -> Result<()> { + let state_dir = &self.root.join(&self.id); + fs::remove_dir_all(state_dir).with_context(|| "Failed to remove state directory")?; + Ok(()) + } + + fn file_path(root: &Path, id: &str) -> PathBuf { + root.join(id).join("state.json") + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use fs::{create_dir_all, remove_dir_all}; + use nix::unistd::getpid; + + use crate::linux::container::tests::init_config; + use oci_spec::{ + linux::{Namespace, NamespaceType}, + state::ContainerStatus, + }; + + use super::*; + + fn init_state(root: &Path, id: &str) -> State { + let oci_state = OciState { + ociVersion: String::from("1.2"), + id: String::from(id), + status: ContainerStatus::Created, + pid: 100, + bundle: root.to_string_lossy().to_string(), + annotations: HashMap::new(), + }; + State::new(root, root, oci_state, 0, SystemTime::now(), &init_config()) + } + + #[test] + fn test_state_update() { + let root = "/tmp/ozonec"; + remove_dir_all(root).unwrap_or_default(); + let mut state = init_state(Path::new(root), "test_state_update"); + state + .config + .as_mut() + .unwrap() + .linux + .as_mut() + .unwrap() + .namespaces + .push(Namespace { + ns_type: NamespaceType::Mount, + path: None, + }); + state.pid = getpid().as_raw(); + state.update(); + + for ns in &state + .config + .as_ref() + .unwrap() + .linux + .as_ref() + .unwrap() + .namespaces + { + assert_eq!( + ns.path.as_ref().unwrap().to_str().unwrap(), + format!( + "/proc/{}/ns/{}", + state.pid, + >::into(ns.ns_type) + ) + ); + } + } + + #[test] + fn test_state_load() { + let root = "/tmp/ozonec"; + remove_dir_all(root).unwrap_or_default(); + + let state = init_state(Path::new(root), "test_state_load"); + let dir = PathBuf::from(String::from(root)).join("test_state_load"); + create_dir_all(&dir).unwrap(); + + assert!(state.save().is_ok()); + assert!(dir.join("state.json").exists()); + let loaded_state = State::load(Path::new(root), "test_state_load").unwrap(); + assert_eq!(loaded_state.id, state.id); + assert!(state.remove_dir().is_ok()); + assert!(State::load(Path::new(root), "test_state_load").is_err()); + } +} diff --git a/ozonec/src/linux/apparmor.rs b/ozonec/src/linux/apparmor.rs new file mode 100644 index 0000000000000000000000000000000000000000..1f91c59c48317a46a109d2c8355d9fe65b294a7b --- /dev/null +++ b/ozonec/src/linux/apparmor.rs @@ -0,0 +1,44 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{fs, path::Path}; + +use anyhow::{Context, Result}; + +const APPARMOR_ENABLED_PATH: &str = "/sys/module/apparmor/parameters/enabled"; +const APPARMOR_INTERFACE: &str = "/proc/self/attr/apparmor/exec"; +const APPARMOR_LEGACY_INTERFACE: &str = "/proc/self/attr/exec"; + +pub fn is_enabled() -> Result { + let enabled = fs::read_to_string(APPARMOR_ENABLED_PATH) + .with_context(|| format!("Failed to read {}", APPARMOR_ENABLED_PATH))?; + Ok(enabled.starts_with('Y')) +} + +pub fn apply_profile(profile: &str) -> Result<()> { + if profile.is_empty() { + return Ok(()); + } + + // Try the module specific subdirectory. This is recommended to configure LSMs + // since Linux kernel 5.1. AppArmor has such a directory since Linux kernel 5.8. + match activate_profile(Path::new(APPARMOR_INTERFACE), profile) { + Ok(_) => Ok(()), + Err(_) => activate_profile(Path::new(APPARMOR_LEGACY_INTERFACE), profile) + .with_context(|| "Failed to apply apparmor profile"), + } +} + +fn activate_profile(path: &Path, profile: &str) -> Result<()> { + fs::write(path, format!("exec {}", profile))?; + Ok(()) +} diff --git a/ozonec/src/linux/container.rs b/ozonec/src/linux/container.rs new file mode 100644 index 0000000000000000000000000000000000000000..776a65e7a14a6a274187636074fc6692be7853a0 --- /dev/null +++ b/ozonec/src/linux/container.rs @@ -0,0 +1,1258 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{ + collections::HashMap, + fs::{self, canonicalize, create_dir_all, OpenOptions}, + io::Write, + os::unix::net::UnixStream, + path::{Path, PathBuf}, + thread::sleep, + time::{Duration, SystemTime}, +}; + +use anyhow::{anyhow, bail, Context, Result}; +use libc::{c_char, pid_t, setdomainname}; +use log::{debug, info}; +use nix::{ + errno::Errno, + mount::MsFlags, + sys::{ + signal::{kill, Signal}, + statfs::statfs, + wait::{waitpid, WaitStatus}, + }, + unistd::{self, chown, getegid, geteuid, sethostname, unlink, Gid, Pid, Uid}, +}; +use procfs::process::ProcState; + +use super::{ + namespace::NsController, + notify_socket::{NotifySocket, NOTIFY_SOCKET}, + process::clone_process, + NotifyListener, Process, +}; +use crate::{ + container::{Container, State}, + linux::{rootfs::Rootfs, seccomp::set_seccomp}, + utils::{prctl, Channel, Message, OzonecErr}, +}; +use oci_spec::{ + linux::{Device as OciDevice, IdMapping, NamespaceType}, + runtime::RuntimeConfig, + state::{ContainerStatus, State as OciState}, +}; + +pub struct LinuxContainer { + pub id: String, + pub root: String, + pub config: RuntimeConfig, + pub pid: pid_t, + pub start_time: u64, + pub created_time: SystemTime, + pub console_socket: Option, +} + +impl LinuxContainer { + pub fn new( + id: &String, + root: &String, + config: &RuntimeConfig, + console_socket: &Option, + exist: &mut bool, + ) -> Result { + let container_dir = format!("{}/{}", root, id); + + Self::validate_config(config)?; + + if Path::new(container_dir.as_str()).exists() { + *exist = true; + bail!("Container {} already exists", id); + } + create_dir_all(container_dir.as_str()) + .with_context(|| OzonecErr::CreateDir(container_dir.clone()))?; + chown(container_dir.as_str(), Some(geteuid()), Some(getegid())) + .with_context(|| "Failed to chown container directory")?; + + Ok(Self { + id: id.clone(), + root: container_dir, + config: config.clone(), + pid: -1, + start_time: 0, + created_time: SystemTime::now(), + console_socket: console_socket.clone(), + }) + } + + pub fn load_from_state(state: &State, console_socket: &Option) -> Result { + let root_path = format!("{}/{}", state.root.to_string_lossy(), &state.id); + let config = state + .config + .clone() + .ok_or_else(|| anyhow!("Can't find config in state"))?; + + Ok(Self { + id: state.id.clone(), + root: root_path, + config, + pid: state.pid, + start_time: state.start_time, + created_time: state.created_time.into(), + console_socket: console_socket.clone(), + }) + } + + fn validate_config(config: &RuntimeConfig) -> Result<()> { + if config.linux.is_none() { + bail!("There is no linux specific configuration in config.json for Linux container"); + } + if config.process.args.is_none() { + bail!("args in process is not set in config.json."); + } + Ok(()) + } + + fn do_first_stage( + &mut self, + process: &mut Process, + parent_channel: &Channel, + fst_stage_channel: &Channel, + notify_listener: &Option, + ) -> Result<()> { + debug!("First stage process start"); + + fst_stage_channel + .receiver + .close() + .with_context(|| "Failed to close receiver end of first stage channel")?; + + process + .set_rlimits() + .with_context(|| "Failed to set rlimit")?; + self.set_user_namespace(parent_channel, fst_stage_channel, process)?; + // New pid namespace goes intto effect in cloned child processes. + self.set_pid_namespace()?; + + // Spawn a child process to perform the second stage to initialize container. + let init_pid = clone_process("ozonec:[2:INIT]", || { + self.do_second_stage(process, parent_channel, notify_listener) + .with_context(|| "Second stage process encounters errors")?; + Ok(0) + })?; + + // Send the final container pid to the parent process. + parent_channel.send_init_pid(init_pid)?; + + debug!("First stage process exit"); + Ok(()) + } + + fn do_second_stage( + &mut self, + process: &mut Process, + parent_channel: &Channel, + notify_listener: &Option, + ) -> Result<()> { + debug!("Second stage process start"); + + unistd::setsid().with_context(|| "Failed to setsid")?; + process + .set_io_priority() + .with_context(|| "Failed to set io priority")?; + process + .set_scheduler() + .with_context(|| "Failed to set scheduler")?; + + let console_stream = match &self.console_socket { + Some(cs) => { + Some(UnixStream::connect(cs).with_context(|| "Failed to connect console socket")?) + } + None => None, + }; + self.set_rest_namespaces()?; + process.set_no_new_privileges()?; + + if process.init { + let propagation = self + .config + .linux + .as_ref() + .unwrap() + .rootfsPropagation + .clone(); + // Container running in a user namespace is not allowed to do mknod. + let mknod_device = !self.is_namespace_set(NamespaceType::User)?; + let mut devices: Vec = Vec::new(); + if let Some(devs) = self.config.linux.as_ref().unwrap().devices.as_ref() { + devices = devs.clone() + }; + let rootfs = Rootfs::new( + self.config.root.path.clone().into(), + propagation, + self.config.mounts.clone(), + mknod_device, + devices, + )?; + rootfs.prepare_rootfs(&self.config)?; + + // Entering into rootfs jail. If mount namespace is specified, use pivot_root. + // Otherwise use chroot. + if self.is_namespace_set(NamespaceType::Mount)? { + Rootfs::pivot_root(&rootfs.path).with_context(|| "Failed to pivot_root")?; + } else { + Rootfs::chroot(&rootfs.path).with_context(|| "Failed to chroot")?; + } + + self.set_sysctl_parameters()?; + } else if !self.is_namespace_set(NamespaceType::Mount)? { + Rootfs::chroot(&PathBuf::from(self.config.root.path.clone())) + .with_context(|| "Failed to chroot")?; + } + + process + .set_tty(console_stream, process.init) + .with_context(|| "Failed to set tty")?; + process.set_apparmor()?; + if self.config.root.readonly { + LinuxContainer::mount_rootfs_readonly()?; + } + self.set_readonly_paths()?; + self.set_masked_paths()?; + + let chdir_cwd_ret = process.chdir_cwd().is_err(); + process.set_additional_gids()?; + process.set_process_id()?; + + // Without setting no new privileges, setting seccomp is a privileged operation. + if !process.no_new_privileges() { + if let Some(seccomp) = &self.config.linux.as_ref().unwrap().seccomp { + set_seccomp(seccomp).with_context(|| "Failed to set seccomp")?; + } + } + process + .reset_capabilities() + .with_context(|| "Failed to reset capabilities")?; + process + .drop_capabilities() + .with_context(|| "Failed to drop capabilities")?; + if chdir_cwd_ret { + process.chdir_cwd()?; + } + // Ensure that the current working directory is inside the mount namespace root + // of the current container process. + Process::getcwd()?; + process.clean_envs(); + process.set_envs(); + if process.no_new_privileges() { + if let Some(seccomp) = &self.config.linux.as_ref().unwrap().seccomp { + set_seccomp(seccomp).with_context(|| "Failed to set seccomp")?; + } + } + + // Tell the parent process that the init process has been cloned. + parent_channel.send_container_created()?; + parent_channel + .sender + .close() + .with_context(|| "Failed to close sender of parent channel")?; + + // Listening on the notify socket to start container. + if let Some(listener) = notify_listener { + listener.wait_for_start_container()?; + listener + .close() + .with_context(|| "Failed to close notify socket")?; + } + process.exec_program(); + } + + fn mount_rootfs_readonly() -> Result<()> { + let ms_flags = MsFlags::MS_RDONLY | MsFlags::MS_REMOUNT | MsFlags::MS_BIND; + let root_path = Path::new("/"); + let fs_flags = statfs(root_path) + .with_context(|| "Statfs root directory error")? + .flags() + .bits(); + + nix::mount::mount( + None::<&str>, + root_path, + None::<&str>, + ms_flags | MsFlags::from_bits_truncate(fs_flags), + None::<&str>, + ) + .with_context(|| "Failed to remount rootfs readonly")?; + Ok(()) + } + + fn get_container_status(&self) -> Result { + if self.pid == -1 { + return Ok(ContainerStatus::Creating); + } + + let proc = procfs::process::Process::new(self.pid); + // If error occurs when accessing /proc/, the process most likely has stopped. + if proc.is_err() { + return Ok(ContainerStatus::Stopped); + } + let proc_stat = proc + .unwrap() + .stat() + .with_context(|| OzonecErr::ReadProcStat(self.pid))?; + // If starttime is not the same, then pid is reused, and the original process has stopped. + if proc_stat.starttime != self.start_time { + return Ok(ContainerStatus::Stopped); + } + + match proc_stat.state()? { + ProcState::Zombie | ProcState::Dead => Ok(ContainerStatus::Stopped), + _ => { + let notify_socket = PathBuf::from(&self.root).join(NOTIFY_SOCKET); + if notify_socket.exists() { + return Ok(ContainerStatus::Created); + } + Ok(ContainerStatus::Running) + } + } + } + + pub fn status(&self) -> Result { + Ok(self + .get_oci_state() + .with_context(|| OzonecErr::GetOciState)? + .status) + } + + fn ns_controller(&self) -> Result { + self.config + .linux + .as_ref() + .unwrap() + .namespaces + .clone() + .try_into() + } + + fn set_user_namespace( + &self, + parent_channel: &Channel, + fst_stage_channel: &Channel, + process: &Process, + ) -> Result<()> { + let ns_controller: NsController = self.ns_controller()?; + + if let Some(ns) = ns_controller.get(NamespaceType::User)? { + ns_controller + .set_namespace(NamespaceType::User) + .with_context(|| "Failed to set user namespace")?; + + if ns.path.is_none() { + // Child process needs to be dumpable, otherwise the parent process is not + // allowed to write the uid/gid mappings. + prctl::set_dumpable(true) + .map_err(|e| anyhow!("Failed to set process dumpable: {e}"))?; + parent_channel + .send_id_mappings() + .with_context(|| "Failed to send id mappings")?; + fst_stage_channel + .recv_id_mappings_done() + .with_context(|| "Failed to receive id mappings done")?; + prctl::set_dumpable(false) + .map_err(|e| anyhow!("Failed to set process undumpable: {e}"))?; + } + + // After UID/GID mappings are configured, ozonec wants to make sure continue as + // the root user inside the new user namespace. This is required because the + // process of configuring the container process will require root, even though + // the root in the user namespace is likely mapped to an non-privileged user. + process.set_id(Gid::from_raw(0), Uid::from_raw(0))?; + } + Ok(()) + } + + fn is_namespace_set(&self, ns_type: NamespaceType) -> Result { + let ns_controller: NsController = self.ns_controller()?; + Ok(ns_controller.get(ns_type)?.is_some()) + } + + fn set_pid_namespace(&self) -> Result<()> { + let ns_controller = self.ns_controller()?; + + if ns_controller.get(NamespaceType::Pid)?.is_some() { + ns_controller + .set_namespace(NamespaceType::Pid) + .with_context(|| "Failed to set pid namespace")?; + } + Ok(()) + } + + fn set_readonly_paths(&self) -> Result<()> { + if let Some(readonly_paths) = self.config.linux.as_ref().unwrap().readonlyPaths.clone() { + for p in readonly_paths { + let path = Path::new(&p); + if let Err(e) = nix::mount::mount( + Some(path), + path, + None::<&str>, + MsFlags::MS_BIND | MsFlags::MS_REC, + None::<&str>, + ) { + if matches!(e, Errno::ENOENT) { + return Ok(()); + } + bail!("Failed to make {} as recursive bind mount", path.display()); + } + + nix::mount::mount( + Some(path), + path, + None::<&str>, + MsFlags::MS_NOSUID + | MsFlags::MS_NODEV + | MsFlags::MS_NOEXEC + | MsFlags::MS_BIND + | MsFlags::MS_REMOUNT + | MsFlags::MS_RDONLY, + None::<&str>, + ) + .with_context(|| format!("Failed to remount {} readonly", path.display()))?; + } + } + Ok(()) + } + + fn set_masked_paths(&self) -> Result<()> { + let linux = self.config.linux.as_ref().unwrap(); + if let Some(masked_paths) = linux.maskedPaths.clone() { + for p in masked_paths { + let path = Path::new(&p); + if let Err(e) = nix::mount::mount( + Some(Path::new("/dev/null")), + path, + None::<&str>, + MsFlags::MS_BIND, + None::<&str>, + ) { + match e { + // Ignore if path doesn't exists. + Errno::ENOENT => (), + Errno::ENOTDIR => { + let label = match linux.mountLabel.clone() { + Some(l) => format!("context=\"{}\"", l), + None => "".to_string(), + }; + nix::mount::mount( + Some(Path::new("tmpfs")), + path, + Some("tmpfs"), + MsFlags::MS_RDONLY, + Some(label.as_str()), + ) + .with_context(|| { + format!( + "Failed to make {} as masked mount by tmpfs", + path.display() + ) + })?; + } + _ => bail!( + "Failed to make {} as masked mount by /dev/null", + path.display() + ), + } + } + } + } + Ok(()) + } + + fn set_rest_namespaces(&self) -> Result<()> { + let ns_config = &self.config.linux.as_ref().unwrap().namespaces; + let ns_controller: NsController = ns_config.clone().try_into()?; + let mut mnt_ns = false; + + for ns in ns_config { + match ns.ns_type { + // User namespace and pid namespace have been set in the first stage. + // Mount namespace is going to be set later to avoid failure with + // existed namespaces. + NamespaceType::User | NamespaceType::Pid => (), + NamespaceType::Mount => mnt_ns = true, + _ => ns_controller.set_namespace(ns.ns_type).with_context(|| { + format!( + "Failed to set {} namespace", + >::into(ns.ns_type) + ) + })?, + } + + if ns.ns_type == NamespaceType::Uts && ns.path.is_none() { + if let Some(hostname) = &self.config.hostname { + sethostname(hostname).with_context(|| "Failed to set hostname")?; + } + if let Some(domainname) = &self.config.domainname { + // SAFETY: FFI call with valid arguments. + let errno = match unsafe { + setdomainname( + domainname.as_bytes().as_ptr() as *const c_char, + domainname.len(), + ) + } { + 0 => return Ok(()), + -1 => nix::Error::last(), + _ => nix::Error::UnknownErrno, + }; + bail!("Failed to set domainname: {}", errno); + } + } + } + + if mnt_ns { + ns_controller + .set_namespace(NamespaceType::Mount) + .with_context(|| "Failed to set mount namespace")?; + } + Ok(()) + } + + fn set_id_mappings( + &self, + parent_channel: &Channel, + fst_stage_channel: &Channel, + fst_stage_pid: &Pid, + ) -> Result<()> { + parent_channel + .recv_id_mappings() + .with_context(|| "Failed to receive id mappings")?; + LinuxContainer::set_groups(fst_stage_pid, false) + .with_context(|| "Failed to disable setting groups")?; + + if let Some(linux) = self.config.linux.as_ref() { + if let Some(uid_mappings) = linux.uidMappings.as_ref() { + self.write_id_mapping(uid_mappings, fst_stage_pid, "uid_map")?; + } + if let Some(gid_mappings) = linux.gidMappings.as_ref() { + self.write_id_mapping(gid_mappings, fst_stage_pid, "gid_map")?; + } + } + + fst_stage_channel + .send_id_mappings_done() + .with_context(|| "Failed to send id mapping done")?; + fst_stage_channel + .sender + .close() + .with_context(|| "Failed to close fst_stage_channel sender")?; + Ok(()) + } + + fn write_id_mapping(&self, mappings: &Vec, pid: &Pid, file: &str) -> Result<()> { + let path = format!("/proc/{}/{}", pid.as_raw(), file); + let mut opened_file = OpenOptions::new() + .write(true) + .open(&path) + .with_context(|| OzonecErr::OpenFile(path))?; + let mut id_mappings = String::from(""); + + for m in mappings { + let mapping = format!("{} {} {}\n", m.containerID, m.hostID, m.size); + id_mappings = id_mappings + &mapping; + } + opened_file + .write_all(id_mappings.as_bytes()) + .with_context(|| "Failed to write id mappings")?; + Ok(()) + } + + fn set_groups(pid: &Pid, allow: bool) -> Result<()> { + let path = format!("/proc/{}/setgroups", pid.as_raw()); + if allow { + std::fs::write(&path, "allow")? + } else { + std::fs::write(&path, "deny")? + } + Ok(()) + } + + fn set_sysctl_parameters(&self) -> Result<()> { + if let Some(sysctl_params) = self.config.linux.as_ref().unwrap().sysctl.clone() { + let sys_path = PathBuf::from("/proc/sys"); + for (param, value) in sysctl_params { + let path = sys_path.join(param.replace('.', "/")); + fs::write(&path, value.as_bytes()) + .with_context(|| format!("Failed to set {} to {}", path.display(), value))?; + } + } + Ok(()) + } +} + +impl Container for LinuxContainer { + fn get_config(&self) -> &RuntimeConfig { + &self.config + } + + fn get_pid(&self) -> pid_t { + self.pid + } + + fn created_time(&self) -> &SystemTime { + &self.created_time + } + + fn get_oci_state(&self) -> Result { + let status = self.get_container_status()?; + let pid = if status != ContainerStatus::Stopped { + self.pid + } else { + 0 + }; + + let rootfs = canonicalize(self.config.root.path.clone()) + .with_context(|| "Failed to canonicalize root path")?; + let bundle = match rootfs.parent() { + Some(p) => p + .to_str() + .ok_or_else(|| anyhow!("root path is not valid unicode"))? + .to_string(), + None => bail!("Failed to get bundle directory"), + }; + let annotations = if let Some(a) = self.config.annotations.clone() { + a + } else { + HashMap::new() + }; + Ok(OciState { + ociVersion: self.config.ociVersion.clone(), + id: self.id.clone(), + status, + pid, + bundle, + annotations, + }) + } + + fn create(&mut self, process: &mut Process) -> Result<()> { + // Create notify socket to notify the container process to start. + let notify_listener = if process.init { + Some(NotifyListener::new(PathBuf::from(&self.root))?) + } else { + None + }; + + // As /proc/self/oom_score_adj is not allowed to write unless privileged, + // set oom_score_adj before setting process undumpable. + process + .set_oom_score_adj() + .with_context(|| "Failed to set oom_score_adj")?; + + // Make the process undumpable to avoid various race conditions that could cause + // processes in namespaces to join to access host resources (or execute code). + if !self.config.linux.as_ref().unwrap().namespaces.is_empty() { + prctl::set_dumpable(false) + .map_err(|e| anyhow!("Failed to set process undumpable: errno {}", e))?; + } + + // Create channels to communicate with child processes. + let parent_channel = Channel::::new() + .with_context(|| "Failed to create message channel for parent process")?; + let fst_stage_channel = Channel::::new()?; + // Set receivers timeout: 50ms. + parent_channel.receiver.set_timeout(50000)?; + fst_stage_channel.receiver.set_timeout(50000)?; + + // Spawn a child process to perform Stage 1. + let fst_stage_pid = clone_process("ozonec:[1:CHILD]", || { + self.do_first_stage( + process, + &parent_channel, + &fst_stage_channel, + ¬ify_listener, + ) + .with_context(|| "First stage process encounters errors")?; + Ok(0) + })?; + + if self.is_namespace_set(NamespaceType::User)? { + self.set_id_mappings(&parent_channel, &fst_stage_channel, &fst_stage_pid)?; + } + + let init_pid = parent_channel + .recv_init_pid() + .with_context(|| "Failed to receive init pid")?; + parent_channel.recv_container_created()?; + parent_channel + .receiver + .close() + .with_context(|| "Failed to close receiver end of parent channel")?; + + self.pid = init_pid.as_raw(); + self.start_time = procfs::process::Process::new(self.pid) + .with_context(|| OzonecErr::ReadProcPid(self.pid))? + .stat() + .with_context(|| OzonecErr::ReadProcStat(self.pid))? + .starttime; + + match waitpid(fst_stage_pid, None) { + Ok(WaitStatus::Exited(_, 0)) => (), + Ok(WaitStatus::Exited(_, s)) => { + info!("First stage process exits with status: {}", s); + } + Ok(WaitStatus::Signaled(_, sig, _)) => { + info!("First stage process killed by signal: {}", sig) + } + Ok(_) => (), + Err(Errno::ECHILD) => { + info!("First stage process has already been reaped"); + } + Err(e) => { + bail!("Failed to waitpid for first stage process: {e}"); + } + } + Ok(()) + } + + fn start(&mut self) -> Result<()> { + let path = PathBuf::from(&self.root).join(NOTIFY_SOCKET); + let mut notify_socket = NotifySocket::new(&path); + + notify_socket.notify_container_start()?; + unlink(&path).with_context(|| "Failed to delete notify.sock")?; + self.start_time = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .with_context(|| "Failed to get start time")? + .as_secs(); + Ok(()) + } + + fn exec(&mut self, process: &mut Process) -> Result<()> { + // process.init is false. + self.create(process)?; + Ok(()) + } + + fn kill(&self, sig: Signal) -> Result<()> { + let mut status = self.status()?; + if status == ContainerStatus::Stopped { + bail!("The container is already stopped"); + } + if status == ContainerStatus::Creating { + bail!("The container has not been created"); + } + + let pid = Pid::from_raw(self.pid); + match kill(pid, None) { + Err(errno) => { + if errno != Errno::ESRCH { + bail!("Failed to kill process {}: {:?}", pid, errno); + } + } + Ok(_) => kill(pid, sig)?, + } + + let mut _retry = 0; + status = self.status()?; + while status != ContainerStatus::Stopped { + sleep(Duration::from_millis(1)); + if _retry > 3 { + bail!("The container is still not stopped."); + } + status = self.status()?; + _retry += 1; + } + Ok(()) + } + + fn delete(&self, state: &State, force: bool) -> Result<()> { + match self.status()? { + ContainerStatus::Stopped => state.remove_dir()?, + _ => { + if force { + self.kill(Signal::SIGKILL) + .with_context(|| "Failed to kill the container by force")?; + state.remove_dir()?; + } else { + bail!( + "Failed to delete container {} which is not stopped", + &state.id + ); + } + } + } + Ok(()) + } +} + +#[cfg(test)] +pub mod tests { + use std::ffi::CStr; + + use chrono::DateTime; + use fs::{read_to_string, remove_dir_all, File}; + use libc::getdomainname; + use nix::sys::stat::stat; + use rusty_fork::rusty_fork_test; + use unistd::{gethostname, getpid}; + + use crate::linux::{ + mount::Mount, namespace::tests::set_namespace, process::tests::init_oci_process, + }; + use oci_spec::{ + linux::{LinuxPlatform, Namespace}, + posix::{Root, User}, + process::Process as OciProcess, + runtime::Mount as OciMount, + }; + + use super::*; + + pub fn init_config() -> RuntimeConfig { + let root = Root { + path: String::from("/tmp/ozonec/bundle/rootfs"), + readonly: true, + }; + let user = User { + uid: 0, + gid: 0, + umask: None, + additionalGids: None, + }; + let process = OciProcess { + cwd: String::from("/"), + args: Some(vec![String::from("bash")]), + env: None, + terminal: false, + consoleSize: None, + rlimits: None, + apparmorProfile: None, + capabilities: None, + noNewPrivileges: None, + oomScoreAdj: None, + scheduler: None, + selinuxLabel: None, + ioPriority: None, + execCPUAffinity: None, + user, + }; + let linux = LinuxPlatform { + namespaces: Vec::new(), + uidMappings: None, + gidMappings: None, + timeOffsets: None, + devices: None, + cgroupsPath: None, + rootfsPropagation: None, + maskedPaths: None, + readonlyPaths: None, + mountLabel: None, + personality: None, + resources: None, + rdma: None, + unified: None, + sysctl: None, + seccomp: None, + #[cfg(target_arch = "x86_64")] + intelRdt: None, + }; + RuntimeConfig { + ociVersion: String::from("1.2"), + root, + mounts: Vec::new(), + process, + hostname: None, + domainname: None, + linux: Some(linux), + vm: None, + hooks: None, + annotations: None, + } + } + + #[test] + fn test_linux_container_new() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + let config = init_config(); + let mut exist: bool = false; + let container = LinuxContainer::new( + &String::from("LinuxContainer_new"), + &String::from("/tmp/ozonec"), + &config, + &None, + &mut exist, + ) + .unwrap(); + + let root = Path::new(&container.root); + assert!(root.exists()); + let root_stat = stat(root).unwrap(); + assert_eq!(root_stat.st_uid, geteuid().as_raw()); + assert_eq!(root_stat.st_gid, getegid().as_raw()); + + assert!(LinuxContainer::new( + &String::from("LinuxContainer_new"), + &String::from("/tmp/ozonec"), + &config, + &None, + &mut exist, + ) + .is_err()); + assert_eq!(exist, true); + } + + #[test] + fn test_validate_config() { + let mut config = init_config(); + config.linux = None; + assert!(LinuxContainer::validate_config(&config).is_err()); + + let linux = LinuxPlatform { + namespaces: Vec::new(), + uidMappings: None, + gidMappings: None, + timeOffsets: None, + devices: None, + cgroupsPath: None, + rootfsPropagation: None, + maskedPaths: None, + readonlyPaths: None, + mountLabel: None, + personality: None, + resources: None, + rdma: None, + unified: None, + sysctl: None, + seccomp: None, + #[cfg(target_arch = "x86_64")] + intelRdt: None, + }; + config.process.args = None; + config.linux = Some(linux); + assert!(LinuxContainer::validate_config(&config).is_err()); + } + + #[test] + fn test_load_from_state() { + let mut state = State { + oci_version: String::from("1.2"), + id: String::from("load_from_state"), + pid: 0, + root: PathBuf::from("/tmp/ozonec/root"), + bundle: PathBuf::from("/tmp/ozonec/bundle"), + rootfs: String::from("/tmp/ozonec/bundle/rootfs"), + start_time: 0, + created_time: DateTime::from(SystemTime::now()), + config: None, + }; + assert!(LinuxContainer::load_from_state(&state, &None).is_err()); + + let config = init_config(); + state.config = Some(config); + assert!(LinuxContainer::load_from_state(&state, &None).is_ok()); + } + + #[test] + fn test_status() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + let config = init_config(); + create_dir_all(&config.root.path).unwrap(); + let mut exist: bool = false; + let mut container = LinuxContainer::new( + &String::from("get_container_status"), + &String::from("/tmp/ozonec"), + &config, + &None, + &mut exist, + ) + .unwrap(); + container.pid = -1; + + assert_eq!(container.status().unwrap(), ContainerStatus::Creating); + + container.pid = 0; + assert_eq!(container.status().unwrap(), ContainerStatus::Stopped); + + container.pid = getpid().as_raw(); + assert_eq!(container.status().unwrap(), ContainerStatus::Stopped); + + let proc_stat = procfs::process::Process::new(container.pid) + .unwrap() + .stat() + .unwrap(); + container.start_time = proc_stat.starttime; + assert_eq!(container.status().unwrap(), ContainerStatus::Running); + + let notify_socket = PathBuf::from(&container.root).join(NOTIFY_SOCKET); + File::create(¬ify_socket).unwrap(); + assert_eq!(container.status().unwrap(), ContainerStatus::Created); + } + + #[test] + fn test_is_namespace_set() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + let mut config = init_config(); + config.linux.as_mut().unwrap().namespaces.push(Namespace { + ns_type: NamespaceType::Mount, + path: None, + }); + let mut exist = false; + let container = LinuxContainer::new( + &String::from("test_is_namespace_set"), + &String::from("/tmp/ozonec/test_is_namespace_set"), + &config, + &None, + &mut exist, + ) + .unwrap(); + + assert!(container.is_namespace_set(NamespaceType::Mount).unwrap()); + assert!(!container.is_namespace_set(NamespaceType::User).unwrap()); + } + + #[test] + #[ignore = "unshare may not be permitted"] + fn test_set_pid_namespace() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + let mut config = init_config(); + config.linux.as_mut().unwrap().namespaces.push(Namespace { + ns_type: NamespaceType::Pid, + path: None, + }); + let mut exist = false; + let container = LinuxContainer::new( + &String::from("test_set_pid_namespace"), + &String::from("/tmp/ozonec/test_set_pid_namespace"), + &config, + &None, + &mut exist, + ) + .unwrap(); + + assert!(container.set_pid_namespace().is_ok()); + } + + #[test] + #[ignore = "unshare may not be permitted"] + fn test_set_id_mappings() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + let mut config = init_config(); + let linux = config.linux.as_mut().unwrap(); + linux.namespaces = vec![Namespace { + ns_type: NamespaceType::User, + path: None, + }]; + linux.uidMappings = Some(vec![IdMapping { + containerID: 0, + hostID: 0, + size: 1000, + }]); + linux.gidMappings = Some(vec![IdMapping { + containerID: 0, + hostID: 0, + size: 1000, + }]); + let mut exist = false; + let container = LinuxContainer::new( + &String::from("test_set_id_mappings"), + &String::from("/tmp/ozonec/test_set_id_mappings"), + &config, + &None, + &mut exist, + ) + .unwrap(); + + let fst_channel = Channel::::new().unwrap(); + let sec_channel = Channel::::new().unwrap(); + let child = clone_process("test_set_id_mappings", || { + let process = Process::new(&init_oci_process(), false); + assert!(container + .set_user_namespace(&fst_channel, &sec_channel, &process) + .is_ok()); + Ok(1) + }) + .unwrap(); + + assert!(container + .set_id_mappings(&fst_channel, &sec_channel, &child) + .is_ok()); + let path = format!("/proc/{}/setgroups", child.as_raw().to_string()); + let setgroups = fs::read_to_string(path).unwrap(); + assert_eq!(setgroups.trim(), "deny"); + let path = format!("/proc/{}/uid_map", child.as_raw().to_string()); + let uid_map = fs::read_to_string(path).unwrap(); + let mut iter = uid_map.split_ascii_whitespace(); + assert_eq!(iter.next(), Some("0")); + assert_eq!(iter.next(), Some("0")); + assert_eq!(iter.next(), Some("1000")); + assert_eq!(iter.next(), None); + let path = format!("/proc/{}/gid_map", child.as_raw().to_string()); + let gid_map = fs::read_to_string(path).unwrap(); + let mut iter = gid_map.split_ascii_whitespace(); + assert_eq!(iter.next(), Some("0")); + assert_eq!(iter.next(), Some("0")); + assert_eq!(iter.next(), Some("1000")); + assert_eq!(iter.next(), None); + + match waitpid(child, None) { + Ok(WaitStatus::Exited(_, s)) => { + assert_eq!(s, 1); + } + Ok(_) => (), + Err(e) => { + panic!("Failed to waitpid for child process: {e}"); + } + } + } + + rusty_fork_test! { + #[test] + #[ignore = "unshare may not be permitted"] + fn test_set_readonly_paths() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + set_namespace(NamespaceType::Mount); + let root = PathBuf::from("/tmp/ozonec/test_set_readonly_paths"); + let mut config = init_config(); + let path = root.to_string_lossy().to_string(); + config.linux.as_mut().unwrap().readonlyPaths = Some(vec![path.clone()]); + let mut exist = false; + let container = LinuxContainer::new( + &String::from("test_set_readonly_paths"), + &root.to_string_lossy().to_string(), + &config, + &None, + &mut exist, + ) + .unwrap(); + File::create(root.join("test")).unwrap(); + + assert!(container.set_readonly_paths().is_ok()); + let path = PathBuf::from(path).join("test"); + assert!(File::create(&path).is_err()); + } + + #[test] + #[ignore = "unshare may not be permitted"] + fn test_set_masked_paths() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + set_namespace(NamespaceType::Mount); + let root = PathBuf::from("/tmp/ozonec/test_set_masked_paths"); + let mut config = init_config(); + config.linux.as_mut().unwrap().maskedPaths = Some(vec![root.to_string_lossy().to_string()]); + let mut exist = false; + let container = LinuxContainer::new( + &String::from("test_set_masked_paths"), + &root.to_string_lossy().to_string(), + &config, + &None, + &mut exist, + ) + .unwrap(); + + File::create(root.join("test")).unwrap(); + assert!(container.set_masked_paths().is_ok()); + assert!(!root.join("test").exists()); + } + + #[test] + #[ignore = "unshare may not be permitted"] + fn test_set_rest_namespaces() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + let root = PathBuf::from("/tmp/ozonec/test_set_rest_namespaces"); + let mut config = init_config(); + config.linux.as_mut().unwrap().namespaces = vec![ + Namespace { + ns_type: NamespaceType::User, + path: None, + }, + Namespace { + ns_type: NamespaceType::Uts, + path: None, + }, + ]; + config.hostname = Some(String::from("test_set_rest_namespaces")); + config.domainname = Some(String::from("test_set_rest_namespaces")); + let mut exist = false; + let container = LinuxContainer::new( + &String::from("test_set_rest_namespaces"), + &root.to_string_lossy().to_string(), + &config, + &None, + &mut exist, + ) + .unwrap(); + + assert!(container.set_rest_namespaces().is_ok()); + assert_eq!( + gethostname().unwrap().to_str().unwrap(), + "test_set_rest_namespaces" + ); + let len = 100; + let mut domain: Vec = Vec::with_capacity(len); + unsafe { + getdomainname(domain.as_mut_ptr().cast(), len); + // Ensure always null-terminated. + domain.as_mut_ptr().wrapping_add(len - 1).write(0); + let len = CStr::from_ptr(domain.as_ptr().cast()).to_bytes().len(); + domain.set_len(len); + } + assert_eq!(String::from_utf8_lossy(&domain), "test_set_rest_namespaces"); + } + + #[test] + #[ignore = "unshare may not be permitted"] + fn test_set_sysctl_parameters() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + set_namespace(NamespaceType::Mount); + let root = PathBuf::from("/tmp/ozonec/test_set_sysctl_parameters"); + let mut config = init_config(); + config.linux.as_mut().unwrap().sysctl = Some(HashMap::new()); + let sysctl = &mut config.linux.as_mut().unwrap().sysctl; + sysctl + .as_mut() + .unwrap() + .insert(String::from("vm.oom_dump_tasks"), String::from("0")); + + let mut exist = false; + let container = LinuxContainer::new( + &String::from("test_set_sysctl_parameters"), + &root.to_string_lossy().to_string(), + &config, + &None, + &mut exist, + ) + .unwrap(); + + let mounts = vec![OciMount { + destination: String::from("/proc"), + source: Some(String::from("proc")), + options: None, + fs_type: Some(String::from("proc")), + uidMappings: None, + gidMappings: None, + }]; + let mnt = Mount::new(&root); + mnt.do_mounts(&mounts, &None).unwrap(); + + assert!(container.set_sysctl_parameters().is_ok()); + assert_eq!(read_to_string("/proc/sys/vm/oom_dump_tasks").unwrap().trim(), "0"); + } + } +} diff --git a/ozonec/src/linux/device.rs b/ozonec/src/linux/device.rs new file mode 100644 index 0000000000000000000000000000000000000000..f48d87f89e7553472e6656da126efaf01db17273 --- /dev/null +++ b/ozonec/src/linux/device.rs @@ -0,0 +1,438 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{ + fs::{create_dir_all, remove_file, File}, + path::{Path, PathBuf}, +}; + +use anyhow::{anyhow, bail, Context, Result}; +use nix::{ + mount::MsFlags, + sys::stat::{makedev, mknod, Mode, SFlag}, + unistd::{chown, Gid, Uid}, +}; +use oci_spec::linux::Device as OciDevice; + +use crate::utils::OzonecErr; + +pub struct Device { + rootfs: PathBuf, +} + +impl Device { + pub fn new(rootfs: PathBuf) -> Self { + Self { rootfs } + } + + pub fn default_devices(&self) -> Vec { + vec![ + DeviceInfo { + path: self.rootfs.join("dev/null"), + dev_type: "c".to_string(), + major: 1, + minor: 3, + file_mode: Some(0o666u32), + uid: None, + gid: None, + }, + DeviceInfo { + path: self.rootfs.join("dev/zero"), + dev_type: "c".to_string(), + major: 1, + minor: 5, + file_mode: Some(0o666u32), + uid: None, + gid: None, + }, + DeviceInfo { + path: self.rootfs.join("dev/full"), + dev_type: "c".to_string(), + major: 1, + minor: 7, + file_mode: Some(0o666u32), + uid: None, + gid: None, + }, + DeviceInfo { + path: self.rootfs.join("dev/random"), + dev_type: "c".to_string(), + major: 1, + minor: 8, + file_mode: Some(0o666u32), + uid: None, + gid: None, + }, + DeviceInfo { + path: self.rootfs.join("dev/urandom"), + dev_type: "c".to_string(), + major: 1, + minor: 9, + file_mode: Some(0o666u32), + uid: None, + gid: None, + }, + DeviceInfo { + path: self.rootfs.join("dev/tty"), + dev_type: "c".to_string(), + major: 5, + minor: 0, + file_mode: Some(0o666u32), + uid: None, + gid: None, + }, + ] + } + + fn create_device_dir(&self, path: &PathBuf) -> Result<()> { + let dir = Path::new(path) + .parent() + .ok_or_else(|| anyhow!("Failed to get parent directory: {}", path.display()))?; + if !dir.exists() { + create_dir_all(dir) + .with_context(|| OzonecErr::CreateDir(dir.to_string_lossy().to_string()))?; + } + Ok(()) + } + + fn get_sflag(&self, dev_type: &str) -> Result { + let sflag = match dev_type { + "c" => SFlag::S_IFCHR, + "b" => SFlag::S_IFBLK, + "u" => SFlag::S_IFCHR, + "p" => SFlag::S_IFIFO, + _ => bail!("Not supported device type: {}", dev_type), + }; + Ok(sflag) + } + + fn bind_device(&self, dev: &DeviceInfo) -> Result<()> { + self.create_device_dir(&dev.path)?; + + let binding = dev.path.to_string_lossy().to_string(); + let stripped_path = binding + .strip_prefix(&self.rootfs.to_string_lossy().to_string()) + .ok_or_else(|| anyhow!("Invalid device path"))?; + let src_path = PathBuf::from(stripped_path); + + if !dev.path.exists() { + File::create(&dev.path) + .with_context(|| format!("Failed to create {}", dev.path.display()))?; + } + nix::mount::mount( + Some(&src_path), + &dev.path, + Some("bind"), + MsFlags::MS_BIND, + None::<&str>, + ) + .with_context(|| OzonecErr::Mount(stripped_path.to_string()))?; + + Ok(()) + } + + fn mknod_device(&self, dev: &DeviceInfo) -> Result<()> { + self.create_device_dir(&dev.path)?; + + let sflag = self.get_sflag(&dev.dev_type)?; + let device = makedev(dev.major as u64, dev.minor as u64); + mknod( + &dev.path, + sflag, + Mode::from_bits_truncate(dev.file_mode.unwrap_or(0)), + device, + )?; + chown( + &dev.path, + dev.uid.map(Uid::from_raw), + dev.gid.map(Gid::from_raw), + ) + .with_context(|| "Failed to chown")?; + + Ok(()) + } + + pub fn create_default_devices(&self, mknod: bool) -> Result<()> { + let default_devs = self.default_devices(); + for dev in default_devs { + if mknod { + if self.mknod_device(&dev).is_err() { + self.bind_device(&dev).with_context(|| { + OzonecErr::BindDev(dev.path.to_string_lossy().to_string()) + })?; + } + } else { + self.bind_device(&dev) + .with_context(|| OzonecErr::BindDev(dev.path.to_string_lossy().to_string()))?; + } + } + Ok(()) + } + + pub fn is_default_device(&self, dev: &OciDevice) -> bool { + for d in &self.default_devices() { + let path = self.rootfs.join(&dev.path.clone()[1..]); + if path == d.path { + return true; + } + } + false + } + + pub fn delete_device(&self, dev: &OciDevice) -> Result<()> { + let path = self.rootfs.join(&dev.path.clone()[1..]); + remove_file(&path).with_context(|| format!("Failed to delete {}", path.display()))?; + Ok(()) + } + + pub fn create_device(&self, dev: &OciDevice, mknod: bool) -> Result<()> { + let path = self.rootfs.join(&dev.path.clone()[1..]); + let major = dev + .major + .ok_or_else(|| anyhow!("major not set for device {}", dev.path))?; + let minor = dev + .minor + .ok_or_else(|| anyhow!("minor not set for device {}", dev.path))?; + let dev_info = DeviceInfo { + path, + dev_type: dev.dev_type.clone(), + major, + minor, + file_mode: dev.fileMode, + uid: dev.uid, + gid: dev.gid, + }; + + if mknod { + if self.mknod_device(&dev_info).is_err() { + self.bind_device(&dev_info).with_context(|| { + OzonecErr::BindDev(dev_info.path.to_string_lossy().to_string()) + })?; + } + } else { + self.bind_device(&dev_info) + .with_context(|| OzonecErr::BindDev(dev_info.path.to_string_lossy().to_string()))?; + } + Ok(()) + } +} + +pub struct DeviceInfo { + pub path: PathBuf, + dev_type: String, + major: i64, + minor: i64, + file_mode: Option, + uid: Option, + gid: Option, +} + +#[cfg(test)] +mod tests { + use std::{ + fs, + os::unix::fs::{FileTypeExt, MetadataExt, PermissionsExt}, + }; + + use nix::mount::umount; + + use super::*; + + #[test] + #[ignore = "mount may not be permitted"] + fn test_mknod_dev() { + let rootfs = PathBuf::from("/tmp/ozonec/mknod_dev"); + create_dir_all(&rootfs).unwrap(); + let dev = Device::new(rootfs.clone()); + let path = rootfs.join("mknod_dev"); + if path.exists() { + remove_file(&path).unwrap(); + } + let dev_info = DeviceInfo { + path: path.clone(), + dev_type: "c".to_string(), + major: 1, + minor: 3, + file_mode: Some(0o644u32), + uid: Some(1000u32), + gid: Some(1000u32), + }; + + assert!(dev.mknod_device(&dev_info).is_ok()); + assert!(path.exists()); + + let metadata = fs::metadata(&path).unwrap(); + assert!(metadata.file_type().is_char_device()); + let major = (metadata.rdev() >> 8) as u32; + let minor = (metadata.rdev() & 0xff) as u32; + assert_eq!(major, 1); + assert_eq!(minor, 3); + let file_mode = metadata.permissions().mode(); + assert_eq!(file_mode & 0o777, 0o644u32); + assert_eq!(metadata.uid(), 1000); + assert_eq!(metadata.gid(), 1000); + + fs::remove_dir_all("/tmp/ozonec").unwrap(); + } + + #[test] + #[ignore = "mount may not be permitted"] + fn test_bind_dev() { + let rootfs = PathBuf::from("/tmp/ozonec/bind_dev"); + create_dir_all(&rootfs).unwrap(); + let dev_path = PathBuf::from("/mknod_dev"); + if dev_path.exists() { + remove_file(&dev_path).unwrap(); + } + let dev = makedev(1, 3); + mknod( + &dev_path, + SFlag::S_IFCHR, + Mode::from_bits_truncate(0o644u32), + dev, + ) + .unwrap(); + let dev_to_bind = Device::new(rootfs.clone()); + let binded_path = rootfs.join("mknod_dev"); + if binded_path.exists() { + umount(&binded_path).unwrap(); + remove_file(&binded_path).unwrap(); + } + let dev_info = DeviceInfo { + path: binded_path.clone(), + dev_type: "c".to_string(), + major: 1, + minor: 3, + file_mode: Some(0o644u32), + uid: Some(1000u32), + gid: Some(1000u32), + }; + + assert!(dev_to_bind.bind_device(&dev_info).is_ok()); + + let metadata = fs::metadata(&dev_path).unwrap(); + let binded_metadata = fs::metadata(&binded_path).unwrap(); + assert_eq!(binded_metadata.file_type(), metadata.file_type()); + assert_eq!(binded_metadata.rdev(), metadata.rdev()); + assert_eq!(binded_metadata.permissions(), metadata.permissions()); + assert_eq!(binded_metadata.uid(), metadata.uid()); + assert_eq!(binded_metadata.gid(), metadata.gid()); + + umount(&binded_path).unwrap(); + fs::remove_dir_all("/tmp/ozonec").unwrap(); + fs::remove_file(dev_path).unwrap(); + } + + #[test] + #[ignore = "mknod may not be permitted"] + fn test_create_device() { + let oci_dev = OciDevice { + dev_type: "c".to_string(), + path: "/mknod_dev".to_string(), + major: Some(1), + minor: Some(3), + fileMode: Some(0o644u32), + uid: Some(1000), + gid: Some(1000), + }; + let rootfs = PathBuf::from("/tmp/ozonec/create_device"); + create_dir_all(&rootfs).unwrap(); + let path = rootfs.join("mknod_dev"); + if path.exists() { + remove_file(&path).unwrap(); + } + let dev = Device::new(rootfs.clone()); + + assert!(dev.create_device(&oci_dev, true).is_ok()); + assert!(path.exists()); + + let metadata = fs::metadata(&path).unwrap(); + assert!(metadata.file_type().is_char_device()); + let major = (metadata.rdev() >> 8) as u32; + let minor = (metadata.rdev() & 0xff) as u32; + assert_eq!(major, 1); + assert_eq!(minor, 3); + let file_mode = metadata.permissions().mode(); + assert_eq!(file_mode & 0o777, 0o644u32); + assert_eq!(metadata.uid(), 1000); + assert_eq!(metadata.gid(), 1000); + + fs::remove_dir_all("/tmp/ozonec").unwrap(); + } + + #[test] + #[ignore = "mount may not be permitted"] + fn test_delete_device() { + let oci_dev = OciDevice { + dev_type: "c".to_string(), + path: "/mknod_dev".to_string(), + major: Some(1), + minor: Some(3), + fileMode: Some(0o644u32), + uid: Some(1000), + gid: Some(1000), + }; + let rootfs = PathBuf::from("/tmp/ozonec/delete_device"); + create_dir_all(&rootfs).unwrap(); + let path = rootfs.join("mknod_dev"); + if path.exists() { + remove_file(&path).unwrap(); + } + let dev = Device::new(rootfs.clone()); + dev.create_device(&oci_dev, true).unwrap(); + + assert!(dev.delete_device(&oci_dev).is_ok()); + assert!(!path.exists()); + + fs::remove_dir_all("/tmp/ozonec").unwrap(); + } + + #[test] + fn test_default_device() { + let rootfs = PathBuf::from("/tmp/ozonec/default_device"); + let dev = Device::new(rootfs.clone()); + + let mut oci_dev = OciDevice { + dev_type: "c".to_string(), + path: "mknod_dev".to_string(), + major: Some(1), + minor: Some(3), + fileMode: Some(0o644u32), + uid: Some(1000), + gid: Some(1000), + }; + assert!(!dev.is_default_device(&oci_dev)); + oci_dev.path = "/dev/null".to_string(); + assert!(dev.is_default_device(&oci_dev)); + oci_dev.path = "/dev/zero".to_string(); + assert!(dev.is_default_device(&oci_dev)); + oci_dev.path = "/dev/full".to_string(); + assert!(dev.is_default_device(&oci_dev)); + oci_dev.path = "/dev/random".to_string(); + assert!(dev.is_default_device(&oci_dev)); + oci_dev.path = "/dev/urandom".to_string(); + assert!(dev.is_default_device(&oci_dev)); + oci_dev.path = "/dev/tty".to_string(); + assert!(dev.is_default_device(&oci_dev)); + } + + #[test] + fn test_get_sflag() { + let rootfs = PathBuf::from("/tmp/ozonec/test_get_sflag"); + let dev = Device::new(rootfs.clone()); + + assert_eq!(dev.get_sflag("c").unwrap(), SFlag::S_IFCHR); + assert_eq!(dev.get_sflag("b").unwrap(), SFlag::S_IFBLK); + assert_eq!(dev.get_sflag("p").unwrap(), SFlag::S_IFIFO); + assert_eq!(dev.get_sflag("u").unwrap(), SFlag::S_IFCHR); + } +} diff --git a/ozonec/src/linux/mod.rs b/ozonec/src/linux/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..658f50c707d3170f476497750686f6be443f53ed --- /dev/null +++ b/ozonec/src/linux/mod.rs @@ -0,0 +1,29 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +pub mod container; + +mod apparmor; +mod device; +mod mount; +mod namespace; +mod notify_socket; +mod process; +mod rootfs; +mod seccomp; +mod terminal; + +pub use container::LinuxContainer; +pub use notify_socket::NotifyListener; +#[allow(unused_imports)] +pub use process::clone_process; +pub use process::Process; diff --git a/ozonec/src/linux/mount.rs b/ozonec/src/linux/mount.rs new file mode 100644 index 0000000000000000000000000000000000000000..7e9d51fdaf93f48c6e465184ba3c845226a2ed33 --- /dev/null +++ b/ozonec/src/linux/mount.rs @@ -0,0 +1,453 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{ + collections::HashMap, + fs::{self, canonicalize, create_dir_all, read_to_string}, + path::{Path, PathBuf}, +}; + +use anyhow::{anyhow, bail, Context, Result}; +use nix::{ + mount::MsFlags, + sys::statfs::{statfs, CGROUP2_SUPER_MAGIC}, + unistd::close, +}; +use procfs::process::{MountInfo, Process}; + +use crate::utils::{openat2_in_root, proc_fd_path, OzonecErr}; +use oci_spec::runtime::Mount as OciMount; + +#[derive(PartialEq, Debug)] +enum CgroupType { + CgroupV1, + CgroupV2, +} + +pub struct Mount { + rootfs: PathBuf, +} + +impl Mount { + pub fn new(rootfs: &Path) -> Self { + Self { + rootfs: rootfs.to_path_buf(), + } + } + + fn get_mount_flag_data(&self, mount: &OciMount) -> (MsFlags, String) { + let mut ms_flags = MsFlags::empty(); + let mut data = Vec::new(); + + if let Some(options) = &mount.options { + for option in options { + if let Some((clear, flag)) = match option.as_str() { + "defaults" => Some((false, MsFlags::empty())), + "ro" => Some((false, MsFlags::MS_RDONLY)), + "rw" => Some((true, MsFlags::MS_RDONLY)), + "suid" => Some((true, MsFlags::MS_NOSUID)), + "nosuid" => Some((false, MsFlags::MS_NOSUID)), + "dev" => Some((true, MsFlags::MS_NODEV)), + "nodev" => Some((false, MsFlags::MS_NODEV)), + "exec" => Some((true, MsFlags::MS_NOEXEC)), + "noexec" => Some((false, MsFlags::MS_NOEXEC)), + "sync" => Some((false, MsFlags::MS_SYNCHRONOUS)), + "async" => Some((true, MsFlags::MS_SYNCHRONOUS)), + "dirsync" => Some((false, MsFlags::MS_DIRSYNC)), + "remount" => Some((false, MsFlags::MS_REMOUNT)), + "mand" => Some((false, MsFlags::MS_MANDLOCK)), + "nomand" => Some((true, MsFlags::MS_MANDLOCK)), + "atime" => Some((true, MsFlags::MS_NOATIME)), + "noatime" => Some((false, MsFlags::MS_NOATIME)), + "diratime" => Some((true, MsFlags::MS_NODIRATIME)), + "nodiratime" => Some((false, MsFlags::MS_NODIRATIME)), + "bind" => Some((false, MsFlags::MS_BIND)), + "rbind" => Some((false, MsFlags::MS_BIND | MsFlags::MS_REC)), + "unbindable" => Some((false, MsFlags::MS_UNBINDABLE)), + "runbindable" => Some((false, MsFlags::MS_UNBINDABLE | MsFlags::MS_REC)), + "private" => Some((false, MsFlags::MS_PRIVATE)), + "rprivate" => Some((false, MsFlags::MS_PRIVATE | MsFlags::MS_REC)), + "shared" => Some((false, MsFlags::MS_SHARED)), + "rshared" => Some((false, MsFlags::MS_SHARED | MsFlags::MS_REC)), + "slave" => Some((false, MsFlags::MS_SLAVE)), + "rslave" => Some((false, MsFlags::MS_SLAVE | MsFlags::MS_REC)), + "relatime" => Some((false, MsFlags::MS_RELATIME)), + "norelatime" => Some((true, MsFlags::MS_RELATIME)), + "strictatime" => Some((false, MsFlags::MS_STRICTATIME)), + "nostrictatime" => Some((true, MsFlags::MS_STRICTATIME)), + _ => None, + } { + if clear { + ms_flags &= !flag; + } else { + ms_flags |= flag; + } + continue; + } + data.push(option.as_str()); + } + } + (ms_flags, data.join(",")) + } + + fn do_one_mount(&self, mount: &OciMount, label: &Option) -> Result<()> { + let mut fs_type = mount.fs_type.as_deref(); + let (mnt_flags, mut data) = self.get_mount_flag_data(mount); + if let Some(label) = label { + if fs_type != Some("proc") && fs_type != Some("sysfs") { + match data.is_empty() { + true => data = format!("context=\"{}\"", label), + false => data = format!("{},context=\"{}\"", data, label), + } + } + } + + let src_binding = mount + .source + .clone() + .ok_or_else(|| anyhow!("Mount source not set"))?; + let mut source = Path::new(&src_binding); + let canonicalized; + // Strip the first "/". + let target_binding = self.rootfs.join(&mount.destination[1..]); + let target = Path::new(&target_binding); + + if !(mnt_flags & MsFlags::MS_BIND).is_empty() { + canonicalized = canonicalize(source) + .with_context(|| format!("Failed to canonicalize {}", source.display()))?; + source = canonicalized.as_path(); + let dir = if source.is_file() { + target.parent().ok_or_else(|| { + anyhow!("Failed to get parent directory: {}", target.display()) + })? + } else { + target + }; + create_dir_all(dir) + .with_context(|| OzonecErr::CreateDir(dir.to_string_lossy().to_string()))?; + // Actually when MS_BIND is set, filesystemtype is ignored by mount syscall. + fs_type = Some("bind"); + } else { + // Sysfs doesn't support duplicate mounting to one directory. + if self.is_mounted_sysfs_dir(&target.to_string_lossy()) { + nix::mount::umount(target) + .with_context(|| format!("Failed to umount {}", target.display()))?; + } + } + + let target_fd = openat2_in_root( + Path::new(&self.rootfs), + Path::new(&mount.destination[1..]), + !source.is_file(), + )?; + nix::mount::mount( + Some(source), + &proc_fd_path(target_fd), + fs_type, + mnt_flags, + Some(data.as_str()), + ) + .with_context(|| OzonecErr::Mount(source.to_string_lossy().to_string()))?; + close(target_fd).with_context(|| OzonecErr::CloseFd)?; + Ok(()) + } + + fn is_mounted_sysfs_dir(&self, path: &str) -> bool { + if let Ok(metadata) = fs::metadata(path) { + if metadata.file_type().is_dir() { + if let Ok(mounts) = read_to_string("/proc/mounts") { + for line in mounts.lines() { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() >= 3 && parts[1] == path && parts[2] == "sysfs" { + return true; + } + } + } + } + } + false + } + + pub fn do_mounts(&self, mounts: &Vec, label: &Option) -> Result<()> { + for mount in mounts { + match mount.fs_type.as_deref() { + Some("cgroup") => match self.cgroup_type()? { + CgroupType::CgroupV1 => self + .do_cgroup_mount(mount) + .with_context(|| "Failed to do cgroup mount")?, + CgroupType::CgroupV2 => bail!("Cgroup V2 is not supported now"), + }, + _ => self.do_one_mount(mount, label)?, + } + } + Ok(()) + } + + fn do_cgroup_mount(&self, mount: &OciMount) -> Result<()> { + // Strip the first "/". + let rel_target = Path::new(&mount.destination[1..]); + let target_fd = openat2_in_root(Path::new(&self.rootfs), rel_target, true)?; + nix::mount::mount( + Some("tmpfs"), + &proc_fd_path(target_fd), + Some("tmpfs"), + MsFlags::MS_NOEXEC | MsFlags::MS_NOSUID | MsFlags::MS_NODEV, + None::<&str>, + ) + .with_context(|| OzonecErr::Mount(String::from("tmpfs")))?; + close(target_fd).with_context(|| OzonecErr::CloseFd)?; + + let process = Process::myself().with_context(|| OzonecErr::AccessProcSelf)?; + let mnt_info: Vec = + process.mountinfo().with_context(|| OzonecErr::GetMntInfo)?; + let proc_cgroups: HashMap = process + .cgroups() + .with_context(|| "Failed to get cgroups belong to")? + .into_iter() + .map(|cgroup| (cgroup.controllers.join(","), cgroup.pathname)) + .collect(); + // Get all of available cgroup mount points. + let host_cgroups: Vec = mnt_info + .into_iter() + .filter(|m| m.fs_type == "cgroup") + .map(|m| m.mount_point) + .collect(); + for cg_path in host_cgroups { + let cg = cg_path + .file_name() + .ok_or_else(|| anyhow!("Failed to get controller file"))? + .to_str() + .ok_or_else(|| { + anyhow!("Convert {:?} to string error", cg_path.file_name().unwrap()) + })?; + let proc_cg_key = if cg == "systemd" { + String::from("systemd") + } else { + cg.to_string() + }; + + if let Some(src) = proc_cgroups.get(&proc_cg_key) { + let source = cg_path.join(&src[1..]); + let rel_target = cg_path + .strip_prefix("/") + .with_context(|| format!("{} doesn't start with '/'", cg_path.display()))?; + let target_fd = openat2_in_root(Path::new(&self.rootfs), rel_target, true)?; + + nix::mount::mount( + Some(&source), + &proc_fd_path(target_fd), + Some("bind"), + MsFlags::MS_BIND | MsFlags::MS_REC, + None::<&str>, + ) + .with_context(|| OzonecErr::Mount(source.to_string_lossy().to_string()))?; + close(target_fd).with_context(|| OzonecErr::CloseFd)?; + } + } + + Ok(()) + } + + fn cgroup_type(&self) -> Result { + let cgroup_path = Path::new("/sys/fs/cgroup"); + if !cgroup_path.exists() { + bail!("/sys/fs/cgroup doesn't exist."); + } + + let st = statfs(cgroup_path).with_context(|| "statfs /sys/fs/cgroup error")?; + if st.filesystem_type() == CGROUP2_SUPER_MAGIC { + return Ok(CgroupType::CgroupV2); + } + Ok(CgroupType::CgroupV1) + } +} + +#[cfg(test)] +mod tests { + use rusty_fork::rusty_fork_test; + + use crate::linux::namespace::tests::set_namespace; + use oci_spec::linux::NamespaceType; + + use super::*; + + fn init_mount(rootfs: &str) -> Mount { + let path = PathBuf::from(rootfs); + create_dir_all(&path).unwrap(); + Mount::new(&path) + } + + #[test] + fn test_is_mounted_sysfs_dir() { + let mut path = PathBuf::from("/test"); + let mut mnt = Mount::new(&path); + assert!(!mnt.is_mounted_sysfs_dir(path.to_str().unwrap())); + + path = PathBuf::from("/sys"); + mnt = Mount::new(&path); + assert!(mnt.is_mounted_sysfs_dir(path.to_str().unwrap())); + } + + #[test] + #[ignore = "mount may not be permitted"] + fn test_cgroup_type() { + let rootfs = PathBuf::from("/tmp/ozonec/test_cgroup_type"); + let mnt = Mount::new(&rootfs); + let cgroup_path = Path::new("/sys/fs/cgroup"); + + if !cgroup_path.exists() { + assert!(mnt.cgroup_type().is_err()); + } else { + let st = statfs(cgroup_path).unwrap(); + if st.filesystem_type() == CGROUP2_SUPER_MAGIC { + assert_eq!(mnt.cgroup_type().unwrap(), CgroupType::CgroupV2); + } else { + assert_eq!(mnt.cgroup_type().unwrap(), CgroupType::CgroupV1); + } + } + } + + #[test] + fn test_get_mount_flag_data() { + let rootfs = PathBuf::from("/test_get_mount_flag_data"); + let mnt = Mount::new(&rootfs); + let mut oci_mnt = OciMount { + destination: String::new(), + source: None, + options: Some(vec![ + String::from("defaults"), + String::from("rw"), + String::from("suid"), + String::from("dev"), + String::from("exec"), + String::from("async"), + String::from("nomand"), + String::from("atime"), + String::from("diratime"), + String::from("norelatime"), + String::from("nostrictatime"), + ]), + fs_type: None, + uidMappings: None, + gidMappings: None, + }; + + let (flags, _data) = mnt.get_mount_flag_data(&oci_mnt); + assert_eq!(flags, MsFlags::empty()); + + oci_mnt.options = Some(vec![ + String::from("ro"), + String::from("nosuid"), + String::from("nodev"), + String::from("noexec"), + String::from("sync"), + String::from("dirsync"), + String::from("remount"), + String::from("mand"), + String::from("noatime"), + String::from("nodiratime"), + String::from("bind"), + String::from("unbindable"), + String::from("private"), + String::from("shared"), + String::from("slave"), + String::from("relatime"), + String::from("strictatime"), + ]); + let (flags, _data) = mnt.get_mount_flag_data(&oci_mnt); + assert_eq!( + flags, + MsFlags::MS_RDONLY + | MsFlags::MS_NOSUID + | MsFlags::MS_NODEV + | MsFlags::MS_NOEXEC + | MsFlags::MS_SYNCHRONOUS + | MsFlags::MS_DIRSYNC + | MsFlags::MS_REMOUNT + | MsFlags::MS_MANDLOCK + | MsFlags::MS_NOATIME + | MsFlags::MS_NODIRATIME + | MsFlags::MS_BIND + | MsFlags::MS_UNBINDABLE + | MsFlags::MS_PRIVATE + | MsFlags::MS_SHARED + | MsFlags::MS_SLAVE + | MsFlags::MS_RELATIME + | MsFlags::MS_STRICTATIME + ); + + oci_mnt.options = Some(vec![String::from("rbind")]); + let (flags, _data) = mnt.get_mount_flag_data(&oci_mnt); + assert_eq!(flags, MsFlags::MS_BIND | MsFlags::MS_REC); + oci_mnt.options = Some(vec![String::from("runbindable")]); + let (flags, _data) = mnt.get_mount_flag_data(&oci_mnt); + assert_eq!(flags, MsFlags::MS_UNBINDABLE | MsFlags::MS_REC); + oci_mnt.options = Some(vec![String::from("rprivate")]); + let (flags, _data) = mnt.get_mount_flag_data(&oci_mnt); + assert_eq!(flags, MsFlags::MS_PRIVATE | MsFlags::MS_REC); + oci_mnt.options = Some(vec![String::from("rshared")]); + let (flags, _data) = mnt.get_mount_flag_data(&oci_mnt); + assert_eq!(flags, MsFlags::MS_SHARED | MsFlags::MS_REC); + oci_mnt.options = Some(vec![String::from("rslave")]); + let (flags, _data) = mnt.get_mount_flag_data(&oci_mnt); + assert_eq!(flags, MsFlags::MS_SLAVE | MsFlags::MS_REC); + } + + rusty_fork_test! { + #[test] + #[ignore = "unshare may not be permitted"] + fn test_do_mounts_cgroup() { + set_namespace(NamespaceType::Mount); + + let mounts = vec![OciMount { + destination: String::from("/sys/fs/cgroup"), + source: Some(String::from("cgroup")), + options: Some(vec![ + String::from("nosuid"), + String::from("noexec"), + String::from("nodev"), + String::from("relatime"), + String::from("ro"), + ]), + fs_type: Some(String::from("cgroup")), + uidMappings: None, + gidMappings: None, + }]; + let mnt = init_mount("/tmp/ozonec/test_do_mounts_cgroup"); + + assert!(mnt.do_mounts(&mounts, &None).is_ok()); + assert!(mnt.rootfs.join("sys/fs/cgroup").exists()); + } + + #[test] + #[ignore = "unshare may not be permitted"] + fn test_do_mounts_bind() { + set_namespace(NamespaceType::Mount); + + let mounts = vec![OciMount { + destination: String::from("/dest"), + source: Some(String::from("/tmp/ozonec/test_do_mounts_bind/source")), + options: Some(vec![ + String::from("rbind") + ]), + fs_type: None, + uidMappings: None, + gidMappings: None, + }]; + let mnt = init_mount("/tmp/ozonec/test_do_mounts_bind"); + create_dir_all(&mnt.rootfs.join("source")).unwrap(); + + assert!(mnt.do_mounts(&mounts, &None).is_ok()); + assert!(mnt.rootfs.join("dest").exists()); + } + } +} diff --git a/ozonec/src/linux/namespace.rs b/ozonec/src/linux/namespace.rs new file mode 100644 index 0000000000000000000000000000000000000000..819bd99ef87ebb9ed143a0f201e07313abb9e52e --- /dev/null +++ b/ozonec/src/linux/namespace.rs @@ -0,0 +1,141 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::collections::HashMap; + +use anyhow::{Context, Result}; +use nix::{ + fcntl::{self, OFlag}, + sched::{setns, unshare, CloneFlags}, + sys::stat::Mode, + unistd, +}; +use oci_spec::linux::{Namespace, NamespaceType}; + +pub struct NsController { + pub namespaces: HashMap, +} + +impl TryFrom> for NsController { + type Error = anyhow::Error; + + fn try_from(namespaces: Vec) -> Result { + Ok(NsController { + namespaces: namespaces + .iter() + .map(|ns| match ns.ns_type.try_into() { + Ok(flag) => Ok((flag, ns.clone())), + Err(e) => Err(e), + }) + .collect::>>()? + .into_iter() + .collect(), + }) + } +} + +impl NsController { + pub fn set_namespace(&self, ns_type: NamespaceType) -> Result<()> { + if let Some(ns) = self.get(ns_type)? { + match ns.path.clone() { + Some(path) => { + let fd = fcntl::open(&path, OFlag::empty(), Mode::empty()) + .with_context(|| format!("fcntl error at opening {}", path.display()))?; + setns(fd, ns_type.try_into()?).with_context(|| "Failed to setns")?; + unistd::close(fd).with_context(|| "Close fcntl fd error")?; + } + None => unshare(ns_type.try_into()?).with_context(|| "Failed to unshare")?, + } + } + Ok(()) + } + + pub fn get(&self, ns_type: NamespaceType) -> Result> { + let clone_flags: CloneFlags = ns_type.try_into()?; + Ok(self.namespaces.get(&clone_flags)) + } +} + +#[cfg(test)] +pub mod tests { + use std::{path::PathBuf, thread::sleep, time::Duration}; + + use nix::sys::{ + signal::{self, Signal}, + wait::{waitpid, WaitStatus}, + }; + + use crate::linux::process::clone_process; + + use super::*; + + fn init_ns_controller(ns_type: NamespaceType) -> NsController { + let mut ns_ctrl = NsController { + namespaces: HashMap::new(), + }; + let ns = Namespace { + ns_type, + path: None, + }; + ns_ctrl.namespaces.insert(ns_type.try_into().unwrap(), ns); + ns_ctrl + } + + pub fn set_namespace(ns_type: NamespaceType) { + let ns_ctrl = init_ns_controller(ns_type); + ns_ctrl.set_namespace(ns_type).unwrap(); + } + + #[test] + #[ignore = "unshare may not be permitted"] + fn test_set_namespace() { + let mut ns_ctrl = init_ns_controller(NamespaceType::Mount); + let fst_child = clone_process("test_set_namespace_with_unshare", || { + assert!(ns_ctrl.set_namespace(NamespaceType::Mount).is_ok()); + sleep(Duration::from_secs(10)); + Ok(1) + }) + .unwrap(); + + let ns_path = PathBuf::from(format!("/proc/{}/ns/mnt", fst_child.as_raw())); + ns_ctrl + .namespaces + .get_mut(&CloneFlags::CLONE_NEWNS) + .unwrap() + .path = Some(ns_path); + let sec_child = clone_process("test_set_namespace_with_setns", || { + assert!(ns_ctrl.set_namespace(NamespaceType::Mount).is_ok()); + Ok(1) + }) + .unwrap(); + + match waitpid(sec_child, None) { + Ok(WaitStatus::Exited(_, s)) => { + assert_eq!(s, 1); + } + Ok(_) => (), + Err(e) => { + panic!("Failed to waitpid for unshare process: {e}"); + } + } + signal::kill(fst_child.clone(), Signal::SIGKILL).unwrap(); + match waitpid(fst_child, None) { + Ok(WaitStatus::Exited(_, s)) => { + assert_eq!(s, 1); + } + Ok(_) => (), + Err(e) => { + panic!("Failed to waitpid for setns process: {e}"); + } + } + } +} diff --git a/ozonec/src/linux/notify_socket.rs b/ozonec/src/linux/notify_socket.rs new file mode 100644 index 0000000000000000000000000000000000000000..5db9c57db6315b74bfef4ff8b4c6cbff60319954 --- /dev/null +++ b/ozonec/src/linux/notify_socket.rs @@ -0,0 +1,129 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{ + env, + io::{Read, Write}, + os::unix::{ + io::AsRawFd, + net::{UnixListener, UnixStream}, + }, + path::PathBuf, +}; + +use anyhow::{anyhow, bail, Context, Result}; +use nix::unistd::{self, chdir}; + +use crate::utils::OzonecErr; + +pub const NOTIFY_SOCKET: &str = "notify.sock"; + +pub struct NotifyListener { + socket: UnixListener, +} + +impl NotifyListener { + pub fn new(root: PathBuf) -> Result { + // The length of path of Unix domain socket has the limit 108, which is smaller then + // the maximum length of file on Linux (255). + let cwd = env::current_dir().with_context(|| OzonecErr::GetCurDir)?; + chdir(&root).with_context(|| "Failed to chdir to root directory")?; + let listener = + UnixListener::bind(NOTIFY_SOCKET).with_context(|| "Failed to bind notify socket")?; + chdir(&cwd).with_context(|| "Failed to chdir to previous working directory")?; + Ok(Self { socket: listener }) + } + + pub fn wait_for_start_container(&self) -> Result<()> { + match self.socket.accept() { + Ok((mut socket, _)) => { + let mut response = String::new(); + socket + .read_to_string(&mut response) + .with_context(|| "Invalid response from notify socket")?; + } + Err(e) => { + bail!("Failed to accept on notify socket: {}", e); + } + } + Ok(()) + } + + pub fn close(&self) -> Result<()> { + Ok(unistd::close(self.socket.as_raw_fd())?) + } +} + +pub struct NotifySocket { + path: PathBuf, +} + +impl NotifySocket { + pub fn new(path: &PathBuf) -> Self { + Self { path: path.into() } + } + + pub fn notify_container_start(&mut self) -> Result<()> { + let cwd = env::current_dir().with_context(|| OzonecErr::GetCurDir)?; + let root_path = self + .path + .parent() + .ok_or_else(|| anyhow!("Invalid notify socket path"))?; + chdir(root_path).with_context(|| "Failed to chdir to root directory")?; + + let mut stream = + UnixStream::connect(NOTIFY_SOCKET).with_context(|| "Failed to connect notify.sock")?; + stream.write_all(b"start container")?; + chdir(&cwd).with_context(|| "Failed to chdir to previous working directory")?; + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use std::fs::{create_dir_all, remove_dir_all}; + + use nix::sys::wait::{waitpid, WaitStatus}; + + use crate::linux::process::clone_process; + + use super::*; + + #[test] + fn test_notify_socket() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + let root = PathBuf::from("/tmp/ozonec/notify_socket"); + create_dir_all(&root).unwrap(); + + let socket_path = root.join(NOTIFY_SOCKET); + let mut socket = NotifySocket::new(&socket_path); + let listener = NotifyListener::new(root.clone()).unwrap(); + let child = clone_process("notify_socket", || { + listener.wait_for_start_container().unwrap(); + Ok(1) + }) + .unwrap(); + socket.notify_container_start().unwrap(); + + match waitpid(child, None) { + Ok(WaitStatus::Exited(_, s)) => { + assert_eq!(s, 1); + } + Ok(_) => (), + Err(e) => { + panic!("Failed to waitpid for child process: {e}"); + } + } + } +} diff --git a/ozonec/src/linux/process.rs b/ozonec/src/linux/process.rs new file mode 100644 index 0000000000000000000000000000000000000000..20bd35d6a195ef4df121bfb5b88f438ce0653e9b --- /dev/null +++ b/ozonec/src/linux/process.rs @@ -0,0 +1,766 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{ + env, + ffi::CString, + fs::{self, read_to_string}, + io::{stderr, stdin, stdout}, + mem, + os::unix::{ + io::{AsRawFd, RawFd}, + net::UnixStream, + }, + path::PathBuf, + str::FromStr, +}; + +use anyhow::{anyhow, bail, Context, Result}; +use caps::{self, CapSet, Capability, CapsHashSet}; +use libc::SIGCHLD; +use nix::{ + errno::Errno, + sched::{clone, CloneFlags}, + unistd::{self, chdir, setresgid, setresuid, Gid, Pid, Uid}, +}; +use rlimit::{setrlimit, Resource, Rlim}; + +use super::{apparmor, terminal::setup_console}; +use crate::utils::{prctl, Clone3, OzonecErr}; +use oci_spec::{linux::IoPriClass, process::Process as OciProcess}; + +pub struct Process { + pub stdin: Option, + pub stdout: Option, + pub stderr: Option, + pub init: bool, + pub tty: bool, + pub oci: OciProcess, +} + +impl Process { + pub fn new(oci: &OciProcess, init: bool) -> Self { + let mut p = Process { + stdin: None, + stdout: None, + stderr: None, + tty: oci.terminal, + init, + oci: oci.clone(), + }; + + if !p.tty { + p.stdin = Some(stdin().as_raw_fd()); + p.stdout = Some(stdout().as_raw_fd()); + p.stderr = Some(stderr().as_raw_fd()); + } + p + } + + pub fn set_tty(&self, console_fd: Option, mount: bool) -> Result<()> { + if self.tty { + if console_fd.is_none() { + bail!("Terminal is specified, but no console socket set"); + } + setup_console(&console_fd.unwrap().as_raw_fd(), mount) + .with_context(|| "Failed to setup console")?; + } + Ok(()) + } + + pub fn set_oom_score_adj(&self) -> Result<()> { + if let Some(score) = self.oci.oomScoreAdj { + fs::write("/proc/self/oom_score_adj", score.to_string().as_bytes())?; + } + Ok(()) + } + + pub fn set_rlimits(&self) -> Result<()> { + if let Some(rlimits) = self.oci.rlimits.as_ref() { + for rlimit in rlimits { + setrlimit( + Resource::from_str(&rlimit.rlimit_type) + .with_context(|| "rlimit type is ill-formatted")?, + Rlim::from_raw(rlimit.soft), + Rlim::from_raw(rlimit.hard), + )?; + } + } + Ok(()) + } + + pub fn set_io_priority(&self) -> Result<()> { + if let Some(io_prio) = &self.oci.ioPriority { + let class = match io_prio.class { + IoPriClass::IoprioClassRt => 1i64, + IoPriClass::IoprioClassBe => 2i64, + IoPriClass::IoprioClassIdle => 3i64, + }; + // Who is a process id or thread id identifying a single process or + // thread. If who is 0, then operate on the calling process or thread. + let io_prio_who_process: libc::c_int = 1; + let io_prio_who_pid = 0; + // SAFETY: FFI call with valid arguments. + match unsafe { + libc::syscall( + libc::SYS_ioprio_set, + io_prio_who_process, + io_prio_who_pid, + (class << 13) | io_prio.priority, + ) + } { + 0 => Ok(()), + -1 => Err(nix::Error::last()), + _ => Err(nix::Error::UnknownErrno), + }?; + } + Ok(()) + } + + pub fn set_scheduler(&self) -> Result<()> { + if let Some(scheduler) = &self.oci.scheduler { + // SAFETY: FFI call with valid arguments. + let mut param: libc::sched_param = unsafe { mem::zeroed() }; + param.sched_priority = scheduler.priority.unwrap_or_default(); + // SAFETY: FFI call with valid arguments. + match unsafe { libc::sched_setscheduler(0, scheduler.policy.into(), ¶m) } { + 0 => Ok(()), + -1 => Err(nix::Error::last()), + _ => Err(nix::Error::UnknownErrno), + }?; + } + Ok(()) + } + + pub fn no_new_privileges(&self) -> bool { + self.oci.noNewPrivileges.is_some() + } + + pub fn set_no_new_privileges(&self) -> Result<()> { + if let Some(no_new_privileges) = self.oci.noNewPrivileges { + if no_new_privileges { + prctl::set_no_new_privileges(true) + .map_err(|e| anyhow!("Failed to set no new privileges: {}", e))?; + } + } + Ok(()) + } + + pub fn chdir_cwd(&self) -> Result<()> { + if !self.oci.cwd.is_empty() { + chdir(&PathBuf::from(&self.oci.cwd)) + .with_context(|| format!("Failed to chdir to {}", &self.oci.cwd))?; + } + Ok(()) + } + + pub fn drop_capabilities(&self) -> Result<()> { + if let Some(caps) = self.oci.capabilities.as_ref() { + if let Some(bounding) = caps.bounding.as_ref() { + let all_caps = caps::read(None, CapSet::Bounding) + .with_context(|| OzonecErr::GetAllCaps("Bounding".to_string()))?; + let caps_hash_set = to_cap_set(bounding)?; + for cap in all_caps.difference(&caps_hash_set) { + caps::drop(None, CapSet::Bounding, *cap) + .with_context(|| format!("Failed to drop {} from bonding set", cap))?; + } + } + if let Some(effective) = caps.effective.as_ref() { + caps::set(None, CapSet::Effective, &to_cap_set(effective)?) + .with_context(|| OzonecErr::SetCaps("Effective".to_string()))?; + } + if let Some(permitted) = caps.permitted.as_ref() { + caps::set(None, CapSet::Permitted, &to_cap_set(permitted)?) + .with_context(|| OzonecErr::SetCaps("Permitted".to_string()))?; + } + if let Some(inheritable) = caps.inheritable.as_ref() { + caps::set(None, CapSet::Inheritable, &to_cap_set(inheritable)?) + .with_context(|| OzonecErr::SetCaps("Inheritable".to_string()))?; + } + if let Some(ambient) = caps.ambient.as_ref() { + caps::set(None, CapSet::Ambient, &to_cap_set(ambient)?) + .with_context(|| OzonecErr::SetCaps("Ambient".to_string()))?; + } + } + Ok(()) + } + + pub fn set_apparmor(&self) -> Result<()> { + if let Some(profile) = &self.oci.apparmorProfile { + if !apparmor::is_enabled()? { + bail!("Apparmor is disabled."); + } + apparmor::apply_profile(profile)?; + } + Ok(()) + } + + pub fn reset_capabilities(&self) -> Result<()> { + let permitted = caps::read(None, CapSet::Permitted) + .with_context(|| OzonecErr::GetAllCaps("Permitted".to_string()))?; + caps::set(None, CapSet::Effective, &permitted)?; + Ok(()) + } + + pub fn set_additional_gids(&self) -> Result<()> { + if let Some(additional_gids) = &self.oci.user.additionalGids { + let setgroups = read_to_string("proc/self/setgroups") + .with_context(|| "Failed to read setgroups")?; + if setgroups.trim() == "deny" { + bail!("Cannot set additional gids as setgroup is desabled"); + } + + let gids: Vec = additional_gids + .iter() + .map(|gid| Gid::from_raw(*gid)) + .collect(); + unistd::setgroups(&gids).with_context(|| "Failed to set additional gids")?; + } + Ok(()) + } + + pub fn set_process_id(&self) -> Result<()> { + let gid = Gid::from(self.oci.user.gid); + let uid = Uid::from(self.oci.user.uid); + self.set_id(gid, uid)?; + Ok(()) + } + + pub fn set_id(&self, gid: Gid, uid: Uid) -> Result<()> { + prctl::set_keep_capabilities(true) + .map_err(|e| anyhow!("Failed to enable keeping capabilities: {}", e))?; + setresgid(gid, gid, gid).with_context(|| "Failed to setresgid")?; + setresuid(uid, uid, uid).with_context(|| "Failed to setresuid")?; + + let permitted = caps::read(None, CapSet::Permitted) + .with_context(|| OzonecErr::GetAllCaps("Permitted".to_string()))?; + caps::set(None, CapSet::Effective, &permitted) + .with_context(|| OzonecErr::SetCaps("Effective".to_string()))?; + prctl::set_keep_capabilities(false) + .map_err(|e| anyhow!("Failed to disable keeping capabilities: {}", e))?; + Ok(()) + } + + // Check and reserve valid environment variables. + // Invalid env vars may cause panic, refer to https://doc.rust-lang.org/std/env/fn.set_var.html#panics + // Key should not : + // * contain NULL character '\0' + // * contain ASCII character '=' + // * be empty + // Value should not: + // * contain NULL character '\0' + fn is_env_valid(env: &str) -> Option<(&str, &str)> { + // Split the env var by '=' to ensure there is no '=' in key, and there is only one '=' + // in the whole env var. + if let Some((key, value)) = env.split_once('=') { + if !key.is_empty() + && !key.as_bytes().contains(&b'\0') + && !value.as_bytes().contains(&b'\0') + { + return Some((key.trim(), value.trim())); + } + } + None + } + + pub fn set_envs(&self) { + if let Some(envs) = &self.oci.env { + for env in envs { + if let Some((key, value)) = Self::is_env_valid(env) { + env::set_var(key, value); + } + } + } + } + + pub fn clean_envs(&self) { + env::vars().for_each(|(key, _value)| env::remove_var(key)); + } + + pub fn exec_program(&self) -> ! { + // It has been make sure that args is not None in validate_config(). + let args = &self.oci.args.as_ref().unwrap(); + // args don't have 0 byte in the middle such as "hello\0world". + let exec_bin = CString::new(args[0].as_str().as_bytes()).unwrap(); + let args: Vec = args + .iter() + .map(|s| CString::new(s.as_bytes()).unwrap_or_default()) + .collect(); + + let _ = unistd::execvp(&exec_bin, &args).map_err(|e| match e { + nix::Error::UnknownErrno => std::process::exit(-2), + _ => std::process::exit(e as i32), + }); + + unreachable!() + } + + pub fn getcwd() -> Result<()> { + unistd::getcwd().map_err(|e| match e { + Errno::ENOENT => anyhow!("Current working directory is out of container rootfs"), + _ => anyhow!("Failed to getcwd"), + })?; + Ok(()) + } +} + +// Clone a new child process. +pub fn clone_process Result>(child_name: &str, mut cb: F) -> Result { + let mut clone3 = Clone3::default(); + clone3.exit_signal(SIGCHLD as u64); + + let mut ret = clone3.call(); + if ret.is_err() { + // clone3() may not be supported in the kernel, fallback to clone(); + let mut stack = [0; 1024 * 1024]; + ret = clone( + Box::new(|| match cb() { + Ok(r) => r as isize, + Err(e) => { + eprintln!("{}", e); + -1 + } + }), + &mut stack, + CloneFlags::empty(), + Some(SIGCHLD), + ) + .map_err(|e| anyhow!("Clone error: errno {}", e)); + } + + match ret { + Ok(pid) => { + if pid.as_raw() != 0 { + return Ok(pid); + } + + prctl::set_name(child_name) + .map_err(|e| anyhow!("Failed to set process name: errno {}", e))?; + let ret = match cb() { + Err(e) => { + eprintln!("Child process exit with errors: {:?}", e); + -1 + } + Ok(exit_code) => exit_code, + }; + std::process::exit(ret); + } + Err(e) => bail!(e), + } +} + +fn to_cap_set(caps: &Vec) -> Result { + let mut caps_hash_set = CapsHashSet::new(); + + for c in caps { + let cap = to_cap(c)?; + caps_hash_set.insert(cap); + } + Ok(caps_hash_set) +} + +fn to_cap(value: &str) -> Result { + let binding = value.to_uppercase(); + let stripped = binding.strip_prefix("CAP_").unwrap_or(&binding); + + match stripped { + "AUDIT_CONTROL" => Ok(Capability::CAP_AUDIT_CONTROL), + "AUDIT_READ" => Ok(Capability::CAP_AUDIT_READ), + "AUDIT_WRITE" => Ok(Capability::CAP_AUDIT_WRITE), + "BLOCK_SUSPEND" => Ok(Capability::CAP_BLOCK_SUSPEND), + "BPF" => Ok(Capability::CAP_BPF), + "CHECKPOINT_RESTORE" => Ok(Capability::CAP_CHECKPOINT_RESTORE), + "CHOWN" => Ok(Capability::CAP_CHOWN), + "DAC_OVERRIDE" => Ok(Capability::CAP_DAC_OVERRIDE), + "DAC_READ_SEARCH" => Ok(Capability::CAP_DAC_READ_SEARCH), + "FOWNER" => Ok(Capability::CAP_FOWNER), + "FSETID" => Ok(Capability::CAP_FSETID), + "IPC_LOCK" => Ok(Capability::CAP_IPC_LOCK), + "IPC_OWNER" => Ok(Capability::CAP_IPC_OWNER), + "KILL" => Ok(Capability::CAP_KILL), + "LEASE" => Ok(Capability::CAP_LEASE), + "LINUX_IMMUTABLE" => Ok(Capability::CAP_LINUX_IMMUTABLE), + "MAC_ADMIN" => Ok(Capability::CAP_MAC_ADMIN), + "MAC_OVERRIDE" => Ok(Capability::CAP_MAC_OVERRIDE), + "MKNOD" => Ok(Capability::CAP_MKNOD), + "NET_ADMIN" => Ok(Capability::CAP_NET_ADMIN), + "NET_BIND_SERVICE" => Ok(Capability::CAP_NET_BIND_SERVICE), + "NET_BROADCAST" => Ok(Capability::CAP_NET_BROADCAST), + "NET_RAW" => Ok(Capability::CAP_NET_RAW), + "PERFMON" => Ok(Capability::CAP_PERFMON), + "SETGID" => Ok(Capability::CAP_SETGID), + "SETFCAP" => Ok(Capability::CAP_SETFCAP), + "SETPCAP" => Ok(Capability::CAP_SETPCAP), + "SETUID" => Ok(Capability::CAP_SETUID), + "SYS_ADMIN" => Ok(Capability::CAP_SYS_ADMIN), + "SYS_BOOT" => Ok(Capability::CAP_SYS_BOOT), + "SYS_CHROOT" => Ok(Capability::CAP_SYS_CHROOT), + "SYS_MODULE" => Ok(Capability::CAP_SYS_MODULE), + "SYS_NICE" => Ok(Capability::CAP_SYS_NICE), + "SYS_PACCT" => Ok(Capability::CAP_SYS_PACCT), + "SYS_PTRACE" => Ok(Capability::CAP_SYS_PTRACE), + "SYS_RAWIO" => Ok(Capability::CAP_SYS_RAWIO), + "SYS_RESOURCE" => Ok(Capability::CAP_SYS_RESOURCE), + "SYS_TIME" => Ok(Capability::CAP_SYS_TIME), + "SYS_TTY_CONFIG" => Ok(Capability::CAP_SYS_TTY_CONFIG), + "SYSLOG" => Ok(Capability::CAP_SYSLOG), + "WAKE_ALARM" => Ok(Capability::CAP_WAKE_ALARM), + _ => bail!("Invalid capability: {}", value), + } +} + +#[cfg(test)] +pub mod tests { + use std::path::Path; + + use nix::sys::resource::{getrlimit, Resource}; + use rusty_fork::rusty_fork_test; + use unistd::getcwd; + + use oci_spec::{ + linux::{Capbilities, IoPriority, SchedPolicy, Scheduler}, + posix::{Rlimits, User}, + }; + + use super::*; + + pub fn init_oci_process() -> OciProcess { + let user = User { + uid: 0, + gid: 0, + umask: None, + additionalGids: None, + }; + OciProcess { + cwd: String::from("/"), + args: Some(vec![String::from("bash")]), + env: None, + terminal: false, + consoleSize: None, + rlimits: None, + apparmorProfile: None, + capabilities: None, + noNewPrivileges: None, + oomScoreAdj: None, + scheduler: None, + selinuxLabel: None, + ioPriority: None, + execCPUAffinity: None, + user, + } + } + + #[test] + fn test_process_new() { + let mut oci_process = init_oci_process(); + + let process = Process::new(&oci_process, false); + assert_eq!(process.stdin.unwrap(), stdin().as_raw_fd()); + assert_eq!(process.stdout.unwrap(), stdout().as_raw_fd()); + assert_eq!(process.stderr.unwrap(), stderr().as_raw_fd()); + + oci_process.terminal = true; + let process = Process::new(&oci_process, false); + assert!(process.stdin.is_none()); + assert!(process.stdout.is_none()); + assert!(process.stderr.is_none()); + } + + #[test] + fn test_set_tty() { + let mut oci_process = init_oci_process(); + + let process = Process::new(&oci_process, false); + assert!(process.set_tty(None, false).is_ok()); + + oci_process.terminal = true; + let process = Process::new(&oci_process, false); + assert!(process.set_tty(None, false).is_err()); + } + + #[test] + fn test_chdir_cwd() { + let oci_process = init_oci_process(); + let process = Process::new(&oci_process, false); + + assert!(process.chdir_cwd().is_ok()); + assert_eq!(getcwd().unwrap().to_str().unwrap(), "/"); + } + + #[test] + fn test_set_envs() { + let mut oci_process = init_oci_process(); + oci_process.env = Some(vec![ + String::from("OZONEC_ENV_1=1"), + String::from("=OZONEC_ENV_2"), + String::from("OZONEC_ENV"), + ]); + let process = Process::new(&oci_process, false); + + process.set_envs(); + for (key, value) in env::vars() { + if key == "OZONEC_ENV_1" { + assert_eq!(value, "1"); + continue; + } + assert_ne!(value, "OZONEC_ENV_2"); + assert_ne!(key, "OZONEC_ENV"); + assert_ne!(value, "OZONEC_ENV"); + } + + env::remove_var("OZONEC_ENV_1"); + } + + #[test] + fn test_to_cap() { + assert_eq!( + to_cap("CAP_AUDIT_CONTROL").unwrap(), + Capability::CAP_AUDIT_CONTROL + ); + assert_eq!( + to_cap("CAP_AUDIT_READ").unwrap(), + Capability::CAP_AUDIT_READ + ); + assert_eq!( + to_cap("CAP_AUDIT_WRITE").unwrap(), + Capability::CAP_AUDIT_WRITE + ); + assert_eq!( + to_cap("CAP_BLOCK_SUSPEND").unwrap(), + Capability::CAP_BLOCK_SUSPEND + ); + assert_eq!(to_cap("CAP_BPF").unwrap(), Capability::CAP_BPF); + assert_eq!( + to_cap("CAP_CHECKPOINT_RESTORE").unwrap(), + Capability::CAP_CHECKPOINT_RESTORE + ); + assert_eq!(to_cap("CAP_CHOWN").unwrap(), Capability::CAP_CHOWN); + assert_eq!( + to_cap("CAP_DAC_OVERRIDE").unwrap(), + Capability::CAP_DAC_OVERRIDE + ); + assert_eq!( + to_cap("CAP_DAC_READ_SEARCH").unwrap(), + Capability::CAP_DAC_READ_SEARCH + ); + assert_eq!(to_cap("CAP_FOWNER").unwrap(), Capability::CAP_FOWNER); + assert_eq!(to_cap("CAP_FSETID").unwrap(), Capability::CAP_FSETID); + assert_eq!(to_cap("CAP_IPC_LOCK").unwrap(), Capability::CAP_IPC_LOCK); + assert_eq!(to_cap("CAP_IPC_OWNER").unwrap(), Capability::CAP_IPC_OWNER); + assert_eq!(to_cap("CAP_KILL").unwrap(), Capability::CAP_KILL); + assert_eq!(to_cap("CAP_LEASE").unwrap(), Capability::CAP_LEASE); + assert_eq!( + to_cap("CAP_LINUX_IMMUTABLE").unwrap(), + Capability::CAP_LINUX_IMMUTABLE + ); + assert_eq!(to_cap("CAP_MAC_ADMIN").unwrap(), Capability::CAP_MAC_ADMIN); + assert_eq!( + to_cap("CAP_MAC_OVERRIDE").unwrap(), + Capability::CAP_MAC_OVERRIDE + ); + assert_eq!(to_cap("CAP_MKNOD").unwrap(), Capability::CAP_MKNOD); + assert_eq!(to_cap("CAP_NET_ADMIN").unwrap(), Capability::CAP_NET_ADMIN); + assert_eq!( + to_cap("CAP_NET_BIND_SERVICE").unwrap(), + Capability::CAP_NET_BIND_SERVICE + ); + assert_eq!( + to_cap("CAP_NET_BROADCAST").unwrap(), + Capability::CAP_NET_BROADCAST + ); + assert_eq!(to_cap("CAP_NET_RAW").unwrap(), Capability::CAP_NET_RAW); + assert_eq!(to_cap("CAP_PERFMON").unwrap(), Capability::CAP_PERFMON); + assert_eq!(to_cap("CAP_SETGID").unwrap(), Capability::CAP_SETGID); + assert_eq!(to_cap("CAP_SETFCAP").unwrap(), Capability::CAP_SETFCAP); + assert_eq!(to_cap("CAP_SETPCAP").unwrap(), Capability::CAP_SETPCAP); + assert_eq!(to_cap("CAP_SETUID").unwrap(), Capability::CAP_SETUID); + assert_eq!(to_cap("CAP_SYS_ADMIN").unwrap(), Capability::CAP_SYS_ADMIN); + assert_eq!(to_cap("CAP_SYS_BOOT").unwrap(), Capability::CAP_SYS_BOOT); + assert_eq!( + to_cap("CAP_SYS_CHROOT").unwrap(), + Capability::CAP_SYS_CHROOT + ); + assert_eq!( + to_cap("CAP_SYS_MODULE").unwrap(), + Capability::CAP_SYS_MODULE + ); + assert_eq!(to_cap("CAP_SYS_NICE").unwrap(), Capability::CAP_SYS_NICE); + assert_eq!(to_cap("CAP_SYS_PACCT").unwrap(), Capability::CAP_SYS_PACCT); + assert_eq!( + to_cap("CAP_SYS_PTRACE").unwrap(), + Capability::CAP_SYS_PTRACE + ); + assert_eq!(to_cap("CAP_SYS_RAWIO").unwrap(), Capability::CAP_SYS_RAWIO); + assert_eq!( + to_cap("CAP_SYS_RESOURCE").unwrap(), + Capability::CAP_SYS_RESOURCE + ); + assert_eq!(to_cap("CAP_SYS_TIME").unwrap(), Capability::CAP_SYS_TIME); + assert_eq!( + to_cap("CAP_SYS_TTY_CONFIG").unwrap(), + Capability::CAP_SYS_TTY_CONFIG + ); + assert_eq!(to_cap("CAP_SYSLOG").unwrap(), Capability::CAP_SYSLOG); + assert_eq!( + to_cap("CAP_WAKE_ALARM").unwrap(), + Capability::CAP_WAKE_ALARM + ); + assert!(to_cap("CAP_TO_CAP").is_err()); + } + + rusty_fork_test! { + #[test] + #[ignore = "oom_score_adj may not be permitted to set"] + fn test_set_oom_score_adj() { + let mut oci_process = init_oci_process(); + oci_process.oomScoreAdj = Some(100); + let process = Process::new(&oci_process, false); + + assert!(process.set_oom_score_adj().is_ok()); + assert_eq!( + read_to_string(Path::new("/proc/self/oom_score_adj")).unwrap(), + String::from("100\n") + ); + } + + #[test] + #[ignore = "setrlimit may not be permitted"] + fn test_set_rlimits() { + let mut oci_process = init_oci_process(); + let rlimits = Rlimits { + rlimit_type: String::from("RLIMIT_CORE"), + soft: 10, + hard: 20, + }; + oci_process.rlimits = Some(vec![rlimits]); + let process = Process::new(&oci_process, false); + + assert!(process.set_rlimits().is_ok()); + assert_eq!(getrlimit(Resource::RLIMIT_CORE).unwrap().0, 10); + assert_eq!(getrlimit(Resource::RLIMIT_CORE).unwrap().1, 20); + } + + #[test] + fn test_set_io_priority() { + let mut oci_process = init_oci_process(); + let io_pri = IoPriority { + class: IoPriClass::IoprioClassBe, + priority: 7, + }; + oci_process.ioPriority = Some(io_pri.clone()); + let process = Process::new(&oci_process, false); + + assert!(process.set_io_priority().is_ok()); + + let io_prio_who_process: libc::c_int = 1; + let io_prio_who_pid = 0; + let ioprio = unsafe { + libc::syscall(libc::SYS_ioprio_get, io_prio_who_process, io_prio_who_pid) + }; + assert_eq!(ioprio, (2 as i64) << 13 | io_pri.priority); + } + + #[test] + fn test_set_scheduler() { + let mut oci_process = init_oci_process(); + let scheduler = Scheduler { + policy: SchedPolicy::SchedOther, + nice: None, + priority: None, + flags: None, + runtime: None, + deadline: None, + period: None, + }; + oci_process.scheduler = Some(scheduler); + let process = Process::new(&oci_process, false); + + assert!(process.set_scheduler().is_ok()); + } + + #[test] + fn test_set_no_new_privileges() { + let mut oci_process = init_oci_process(); + oci_process.noNewPrivileges = Some(true); + let process = Process::new(&oci_process, false); + + assert!(process.set_no_new_privileges().is_ok()); + } + + #[test] + #[ignore = "capset may not be permitted"] + fn test_drop_capabilities() { + let mut oci_process = init_oci_process(); + let caps = Capbilities { + effective: Some(vec![ + String::from("CAP_DAC_OVERRIDE"), + String::from("CAP_DAC_READ_SEARCH"), + String::from("CAP_SETFCAP"), + ]), + bounding: Some(vec![ + String::from("CAP_DAC_OVERRIDE"), + String::from("CAP_DAC_READ_SEARCH"), + ]), + inheritable: Some(vec![String::from("CAP_DAC_READ_SEARCH")]), + permitted: Some(vec![ + String::from("CAP_DAC_OVERRIDE"), + String::from("CAP_DAC_READ_SEARCH"), + String::from("CAP_SETFCAP"), + ]), + ambient: Some(vec![String::from("CAP_DAC_READ_SEARCH")]), + }; + oci_process.capabilities = Some(caps); + let process = Process::new(&oci_process, false); + + assert!(process.drop_capabilities().is_ok()); + let mut caps = caps::read(None, CapSet::Bounding).unwrap(); + assert_eq!(caps.len(), 2); + assert!(caps.get(&Capability::CAP_DAC_OVERRIDE).is_some()); + assert!(caps.get(&Capability::CAP_DAC_READ_SEARCH).is_some()); + caps = caps::read(None, CapSet::Effective).unwrap(); + assert_eq!(caps.len(), 3); + assert!(caps.get(&Capability::CAP_DAC_OVERRIDE).is_some()); + assert!(caps.get(&Capability::CAP_DAC_READ_SEARCH).is_some()); + assert!(caps.get(&Capability::CAP_SETFCAP).is_some()); + caps = caps::read(None, CapSet::Inheritable).unwrap(); + assert_eq!(caps.len(), 1); + assert!(caps.get(&Capability::CAP_DAC_READ_SEARCH).is_some()); + caps = caps::read(None, CapSet::Permitted).unwrap(); + assert_eq!(caps.len(), 3); + assert!(caps.get(&Capability::CAP_DAC_OVERRIDE).is_some()); + assert!(caps.get(&Capability::CAP_DAC_READ_SEARCH).is_some()); + assert!(caps.get(&Capability::CAP_SETFCAP).is_some()); + caps = caps::read(None, CapSet::Ambient).unwrap(); + assert_eq!(caps.len(), 1); + assert!(caps.get(&Capability::CAP_DAC_READ_SEARCH).is_some()); + } + + #[test] + fn test_reset_capabilities() { + let oci_process = init_oci_process(); + let process = Process::new(&oci_process, false); + + assert!(process.reset_capabilities().is_ok()); + let permit_caps = caps::read(None, CapSet::Permitted).unwrap(); + let eff_caps = caps::read(None, CapSet::Effective).unwrap(); + assert_eq!(permit_caps, eff_caps); + } + + #[test] + fn test_clean_envs() { + let oci_process = init_oci_process(); + let process = Process::new(&oci_process, false); + process.clean_envs(); + assert_eq!(env::vars().count(), 0); + } + } +} diff --git a/ozonec/src/linux/rootfs.rs b/ozonec/src/linux/rootfs.rs new file mode 100644 index 0000000000000000000000000000000000000000..48e433664788ae72dd73115a13bdebca5396bee0 --- /dev/null +++ b/ozonec/src/linux/rootfs.rs @@ -0,0 +1,505 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{ + fs::remove_file, + os::unix::fs::symlink, + path::{Path, PathBuf}, +}; + +use anyhow::{bail, Context, Result}; +use nix::{ + fcntl::{open, OFlag}, + mount::{umount2, MntFlags, MsFlags}, + sys::stat::{umask, Mode}, + unistd::{chroot, close, fchdir, pivot_root}, + NixPath, +}; +use procfs::process::Process; + +use super::{device::Device, mount::Mount}; +use crate::utils::OzonecErr; +use oci_spec::{ + linux::Device as OciDevice, + runtime::{Mount as OciMount, RuntimeConfig}, +}; + +pub struct Rootfs { + pub path: PathBuf, + propagation_flags: MsFlags, + mounts: Vec, + // Should we mknod the device or bind one. + mknod_device: bool, + devices: Vec, +} + +impl Rootfs { + pub fn new( + path: PathBuf, + propagation: Option, + mounts: Vec, + mknod_device: bool, + devices: Vec, + ) -> Result { + if !path.exists() { + bail!("Rootfs directory not exist"); + } + + let propagation_flags = Self::get_mount_flags(propagation)?; + Ok(Self { + path, + propagation_flags, + mounts, + mknod_device, + devices, + }) + } + + fn get_mount_flags(propagation: Option) -> Result { + let flags = match propagation.as_deref() { + Some("shared") => MsFlags::MS_SHARED, + Some("private") => MsFlags::MS_PRIVATE, + Some("slave") => MsFlags::MS_SLAVE, + Some("unbindable") => MsFlags::MS_UNBINDABLE, + Some(_) => bail!("Invalid rootfsPropagation"), + None => MsFlags::MS_REC | MsFlags::MS_SLAVE, + }; + Ok(flags) + } + + fn set_propagation(&self) -> Result<()> { + nix::mount::mount( + None::<&str>, + Path::new("/"), + None::<&str>, + self.propagation_flags, + None::<&str>, + ) + .with_context(|| "Failed to set rootfs mount propagation")?; + Ok(()) + } + + fn mount(&self) -> Result<()> { + nix::mount::mount( + Some(&self.path), + &self.path, + None::<&str>, + MsFlags::MS_BIND | MsFlags::MS_REC, + None::<&str>, + )?; + Ok(()) + } + + fn make_parent_mount_private(&self) -> Result<()> { + let process = Process::myself().with_context(|| OzonecErr::AccessProcSelf)?; + let mount_info = process.mountinfo().with_context(|| OzonecErr::GetMntInfo)?; + + if let Some(m) = mount_info + .into_iter() + .filter(|m| self.path.starts_with(&m.mount_point) && m.mount_point != self.path) + .map(|m| m.mount_point) + .max_by_key(|m| m.len()) + .as_ref() + { + nix::mount::mount(Some(m), m, None::<&str>, MsFlags::MS_PRIVATE, None::<&str>)?; + } + Ok(()) + } + + // OCI spec requires runtime MUST create the following symlinks if the source file exists after + // processing mounts: + // dev/fd -> /proc/self/fd + // dev/stdin -> /proc/self/fd/0 + // dev/stdout -> /proc/self/fd/1 + // dev/stderr -> /proc/self/fd/2 + fn set_default_symlinks(&self) -> Result<()> { + let link_pairs = vec![ + ((self.path).join("dev/fd"), "/proc/self/fd"), + ((self.path).join("dev/stdin"), "/proc/self/fd/0"), + ((self.path).join("dev/stdout"), "/proc/self/fd/1"), + ((self.path).join("dev/stderr"), "/proc/self/fd/2"), + ]; + + for pair in link_pairs { + let cloned_pair = pair.clone(); + symlink(pair.1, pair.0).with_context(|| { + format!( + "Failed to create symlink {} -> {}", + cloned_pair.0.display(), + cloned_pair.1 + ) + })?; + } + Ok(()) + } + + fn do_mounts(&self, config: &RuntimeConfig) -> Result<()> { + let mount = Mount::new(&self.path); + mount + .do_mounts(&self.mounts, &config.linux.as_ref().unwrap().mountLabel) + .with_context(|| "Failed to do mounts")?; + Ok(()) + } + + fn link_ptmx(&self) -> Result<()> { + let ptmx = self.path.clone().join("dev/ptmx"); + if ptmx.exists() { + remove_file(&ptmx).with_context(|| "Failed to delete ptmx")?; + } + symlink("pts/ptmx", &ptmx) + .with_context(|| format!("Failed to create symlink {} -> pts/ptmx", ptmx.display()))?; + Ok(()) + } + + fn create_default_devices(&self, mknod: bool) -> Result<()> { + let dev = Device::new(self.path.clone()); + dev.create_default_devices(mknod)?; + Ok(()) + } + + fn create_devices(&self, devices: &Vec, mknod: bool) -> Result<()> { + let dev = Device::new(self.path.clone()); + for d in devices { + if dev.is_default_device(d) { + dev.delete_device(d)?; + } + dev.create_device(d, mknod) + .with_context(|| format!("Failed to create device {}", d.path))?; + } + Ok(()) + } + + pub fn prepare_rootfs(&self, config: &RuntimeConfig) -> Result<()> { + self.set_propagation()?; + self.mount().with_context(|| "Failed to mount rootfs")?; + self.make_parent_mount_private() + .with_context(|| "Failed to make parent mount private")?; + self.do_mounts(config)?; + self.set_default_symlinks()?; + + let old_mode = umask(Mode::from_bits_truncate(0o000)); + self.create_default_devices(self.mknod_device)?; + self.create_devices(&self.devices, self.mknod_device)?; + umask(old_mode); + + self.link_ptmx()?; + Ok(()) + } + + pub fn chroot(path: &Path) -> Result<()> { + let new_root = open(path, OFlag::O_DIRECTORY | OFlag::O_RDONLY, Mode::empty()) + .with_context(|| OzonecErr::OpenFile(path.to_string_lossy().to_string()))?; + chroot(path)?; + fchdir(new_root).with_context(|| "Failed to chdir to new root directory")?; + Ok(()) + } + + pub fn pivot_root(path: &Path) -> Result<()> { + let new_root = open(path, OFlag::O_DIRECTORY | OFlag::O_RDONLY, Mode::empty()) + .with_context(|| OzonecErr::OpenFile(path.to_string_lossy().to_string()))?; + let old_root = open("/", OFlag::O_DIRECTORY | OFlag::O_RDONLY, Mode::empty()) + .with_context(|| OzonecErr::OpenFile("/".to_string()))?; + + pivot_root(path, path)?; + nix::mount::mount( + None::<&str>, + "/", + None::<&str>, + MsFlags::MS_SLAVE | MsFlags::MS_REC, + None::<&str>, + ) + .with_context(|| OzonecErr::Mount("/".to_string()))?; + + fchdir(old_root).with_context(|| "Failed to chdir to old root directory")?; + umount2(".", MntFlags::MNT_DETACH) + .with_context(|| "Failed to umount old root directory")?; + fchdir(new_root).with_context(|| "Failed to chdir to new root directory")?; + + close(old_root).with_context(|| "Failed to close old_root")?; + close(new_root).with_context(|| "Failed to close new_root")?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::{ + fs::{self, create_dir_all, read_link, remove_dir_all}, + os::unix::fs::FileTypeExt, + }; + + use nix::unistd::chdir; + use rusty_fork::rusty_fork_test; + + use crate::linux::{container::tests::init_config, namespace::tests::set_namespace}; + use oci_spec::linux::NamespaceType; + + use super::*; + + fn init_rootfs(path: &str, propagation: Option, mounts: Vec) -> Rootfs { + let path = PathBuf::from(path); + create_dir_all(&path).unwrap(); + Rootfs::new(path, propagation, mounts, true, Vec::new()).unwrap() + } + + #[test] + fn test_rootfs_new() { + let path = PathBuf::from("/test_rootfs_new"); + assert!(Rootfs::new(path, None, Vec::new(), true, Vec::new()).is_err()); + } + + #[test] + fn test_get_mount_flags() { + assert_eq!( + Rootfs::get_mount_flags(Some(String::from("shared"))).unwrap(), + MsFlags::MS_SHARED + ); + assert_eq!( + Rootfs::get_mount_flags(Some(String::from("private"))).unwrap(), + MsFlags::MS_PRIVATE + ); + assert_eq!( + Rootfs::get_mount_flags(Some(String::from("slave"))).unwrap(), + MsFlags::MS_SLAVE + ); + assert_eq!( + Rootfs::get_mount_flags(Some(String::from("unbindable"))).unwrap(), + MsFlags::MS_UNBINDABLE + ); + assert_eq!( + Rootfs::get_mount_flags(None).unwrap(), + MsFlags::MS_REC | MsFlags::MS_SLAVE + ); + assert!(Rootfs::get_mount_flags(Some(String::from("unbind"))).is_err()); + } + + rusty_fork_test! { + #[test] + #[ignore = "unshare may not be permitted"] + fn test_set_propagation() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + set_namespace(NamespaceType::Mount); + let rootfs = init_rootfs( + "/tmp/ozonec/test_set_propagation", + Some(String::from("shared")), + Vec::new(), + ); + + assert!(rootfs.set_propagation().is_ok()); + } + + #[test] + #[ignore = "unshare may not be permitted"] + fn test_make_parent_mount_private() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + set_namespace(NamespaceType::Mount); + + let parent = PathBuf::from("/tmp/ozonec/test_make_parent_mount_private"); + create_dir_all(&parent).unwrap(); + nix::mount::mount( + Some(&parent), + &parent, + None::<&str>, + MsFlags::MS_BIND, + None::<&str>, + ) + .unwrap(); + let rootfs = init_rootfs( + "/tmp/ozonec/test_make_parent_mount_private/rootfs", + Some(String::from("shared")), + Vec::new(), + ); + + assert!(rootfs.make_parent_mount_private().is_ok()); + } + + #[test] + #[ignore = "unshare may not be permitted"] + fn test_set_default_symlinks() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + set_namespace(NamespaceType::Mount); + let mounts = vec![ + OciMount { + destination: String::from("/proc"), + source: Some(String::from("/proc")), + options: Some(Vec::new()), + fs_type: Some(String::from("proc")), + uidMappings: None, + gidMappings: None, + }, + OciMount { + destination: String::from("/dev"), + source: Some(String::from("tmpfs")), + options: Some(vec![ + String::from("nosuid"), + String::from("strictatime"), + String::from("mode=755"), + String::from("size=65536k"), + ]), + fs_type: Some(String::from("tmpfs")), + uidMappings: None, + gidMappings: None, + }, + ]; + let rootfs = init_rootfs( + "/tmp/ozonec/test_set_default_symlinks", + Some(String::from("shared")), + mounts, + ); + rootfs.mount().unwrap(); + + let mut config = init_config(); + config.root.path = rootfs.path.to_string_lossy().to_string(); + rootfs.do_mounts(&config).unwrap(); + + assert!(rootfs.set_default_symlinks().is_ok()); + chdir(&rootfs.path).unwrap(); + let mut path = PathBuf::from("dev/fd"); + let mut metadata = fs::symlink_metadata(&path).unwrap(); + assert!(metadata.is_symlink()); + assert_eq!(read_link(&path).unwrap(), PathBuf::from("/proc/self/fd")); + path = PathBuf::from("dev/stdin"); + metadata = fs::symlink_metadata(&path).unwrap(); + assert!(metadata.is_symlink()); + assert_eq!(read_link(&path).unwrap(), PathBuf::from("/proc/self/fd/0")); + path = PathBuf::from("dev/stdout"); + metadata = fs::symlink_metadata(&path).unwrap(); + assert!(metadata.is_symlink()); + assert_eq!(read_link(&path).unwrap(), PathBuf::from("/proc/self/fd/1")); + path = PathBuf::from("dev/stderr"); + metadata = fs::symlink_metadata(&path).unwrap(); + assert!(metadata.is_symlink()); + assert_eq!(read_link(&path).unwrap(), PathBuf::from("/proc/self/fd/2")); + } + + #[test] + #[ignore = "unshare may not be permitted"] + fn test_link_ptmx() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + set_namespace(NamespaceType::Mount); + let mounts = vec![OciMount { + destination: String::from("/dev"), + source: Some(String::from("tmpfs")), + options: Some(vec![ + String::from("nosuid"), + String::from("strictatime"), + String::from("mode=755"), + String::from("size=65536k"), + ]), + fs_type: Some(String::from("tmpfs")), + uidMappings: None, + gidMappings: None, + }]; + let rootfs = init_rootfs( + "/tmp/ozonec/test_link_ptmx", + Some(String::from("shared")), + mounts, + ); + let mut config = init_config(); + config.root.path = rootfs.path.to_string_lossy().to_string(); + rootfs.do_mounts(&config).unwrap(); + + assert!(rootfs.link_ptmx().is_ok()); + + chdir(&rootfs.path).unwrap(); + let path = PathBuf::from("dev/ptmx"); + let metadata = fs::symlink_metadata(&path).unwrap(); + assert!(metadata.is_symlink()); + assert_eq!(read_link(&path).unwrap(), PathBuf::from("pts/ptmx")); + } + + #[test] + #[ignore = "unshare may not be permitted"] + fn test_create_default_devices() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + set_namespace(NamespaceType::Mount); + let mounts = vec![OciMount { + destination: String::from("/dev"), + source: Some(String::from("tmpfs")), + options: Some(vec![ + String::from("nosuid"), + String::from("strictatime"), + String::from("mode=755"), + String::from("size=65536k"), + ]), + fs_type: Some(String::from("tmpfs")), + uidMappings: None, + gidMappings: None, + }]; + let rootfs = init_rootfs( + "/tmp/ozonec/test_create_default_devices", + Some(String::from("shared")), + mounts, + ); + let mut config = init_config(); + config.root.path = rootfs.path.to_string_lossy().to_string(); + rootfs.do_mounts(&config).unwrap(); + + assert!(rootfs.create_default_devices(false).is_ok()); + for dev in Device::new(rootfs.path.clone()).default_devices() { + assert!(dev.path.exists()); + let metadata = fs::metadata(&dev.path).unwrap(); + assert!(metadata.file_type().is_char_device()); + } + } + + #[test] + #[ignore = "unshare may not be permitted"] + fn test_create_devices() { + remove_dir_all("/tmp/ozonec").unwrap_or_default(); + + set_namespace(NamespaceType::Mount); + + let mounts = vec![OciMount { + destination: String::from("/dev"), + source: Some(String::from("tmpfs")), + options: Some(vec![ + String::from("nosuid"), + String::from("strictatime"), + String::from("mode=755"), + String::from("size=65536k"), + ]), + fs_type: Some(String::from("tmpfs")), + uidMappings: None, + gidMappings: None, + }]; + let rootfs = init_rootfs( + "/tmp/ozonec/test_create_devices", + Some(String::from("shared")), + mounts, + ); + let mut config = init_config(); + config.root.path = rootfs.path.to_string_lossy().to_string(); + rootfs.do_mounts(&config).unwrap(); + + let devices = vec![OciDevice { + dev_type: String::from("c"), + path: String::from("/dev/test"), + major: Some(1), + minor: Some(3), + fileMode: Some(0o666u32), + uid: None, + gid: None, + }]; + assert!(rootfs.create_devices(&devices, true).is_ok()); + let path = rootfs.path.join("dev/test"); + assert!(path.exists()); + let metadata = fs::metadata(&path).unwrap(); + assert!(metadata.file_type().is_char_device()); + } + } +} diff --git a/ozonec/src/linux/seccomp.rs b/ozonec/src/linux/seccomp.rs new file mode 100644 index 0000000000000000000000000000000000000000..14f4ea4ca6a34d014939c5a3461a844b8f14015d --- /dev/null +++ b/ozonec/src/linux/seccomp.rs @@ -0,0 +1,249 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::vec; + +use anyhow::{bail, Context, Result}; + +use libseccomp::{ + ScmpAction, ScmpArch, ScmpArgCompare, ScmpCompareOp, ScmpFilterContext, ScmpSyscall, +}; +use oci_spec::linux::{Seccomp, SeccompAction as OciSeccompAction, SeccompOp}; + +use crate::utils::OzonecErr; + +fn parse_action(action: OciSeccompAction, errno: Option) -> ScmpAction { + let errno = errno.unwrap_or(libc::EPERM as u32); + match action { + OciSeccompAction::ScmpActKill => ScmpAction::KillThread, + OciSeccompAction::ScmpActKillProcess => ScmpAction::KillProcess, + OciSeccompAction::ScmpActTrap => ScmpAction::Trap, + OciSeccompAction::ScmpActErrno => ScmpAction::Errno(errno as i32), + OciSeccompAction::ScmpActTrace => ScmpAction::Trace(errno as u16), + OciSeccompAction::ScmpActLog => ScmpAction::Log, + OciSeccompAction::ScmpActAllow => ScmpAction::Allow, + OciSeccompAction::ScmpActNotify => ScmpAction::Notify, + } +} + +fn parse_cmp(op: SeccompOp, mask: u64) -> ScmpCompareOp { + match op { + SeccompOp::ScmpCmpNe => ScmpCompareOp::NotEqual, + SeccompOp::ScmpCmpLt => ScmpCompareOp::Less, + SeccompOp::ScmpCmpLe => ScmpCompareOp::LessOrEqual, + SeccompOp::ScmpCmpEq => ScmpCompareOp::Equal, + SeccompOp::ScmpCmpGe => ScmpCompareOp::GreaterEqual, + SeccompOp::ScmpCmpGt => ScmpCompareOp::Greater, + SeccompOp::ScmpCmpMaskedEq => ScmpCompareOp::MaskedEqual(mask), + } +} + +fn check_seccomp(seccomp: &Seccomp) -> Result<()> { + // We don't support NOTIFY as the default action. When the seccomp filter + // is created with NOTIFY, the container process will have to communicate + // the returned fd to another process. Therefore, ozonec needs to call + // the WRITE syscall. And then READ and CLOSE syscalls are also needed to + // be enabled to use. + if seccomp.defaultAction == OciSeccompAction::ScmpActNotify { + bail!("SCMP_ACT_NOTIFY is not supported as the default action"); + } + if let Some(syscalls) = &seccomp.syscalls { + for syscall in syscalls { + if syscall.action == OciSeccompAction::ScmpActNotify { + for name in &syscall.names { + if name == "write" { + bail!("SCMP_ACT_NOTIFY is not supported to be used for write syscall"); + } + } + } + } + } + + Ok(()) +} + +pub fn set_seccomp(seccomp: &Seccomp) -> Result<()> { + check_seccomp(seccomp)?; + + let default_action = parse_action(seccomp.defaultAction, seccomp.defaultErrnoRet); + if let Some(syscalls) = &seccomp.syscalls { + let mut filter = ScmpFilterContext::new_filter(default_action)?; + #[cfg(target_arch = "x86_64")] + filter + .add_arch(ScmpArch::X8664) + .with_context(|| OzonecErr::AddScmpArch)?; + #[cfg(target_arch = "aarch64")] + filter + .add_arch(ScmpArch::Aarch64) + .with_context(|| OzonecErr::AddScmpArch)?; + + for syscall in syscalls { + let action = parse_action(syscall.action, syscall.errnoRet); + if action == default_action { + continue; + } + + for name in &syscall.names { + let sc = ScmpSyscall::from_name(name)?; + let mut comparators: Vec = vec![]; + if let Some(args) = &syscall.args { + for arg in args { + let op = parse_cmp(arg.op, arg.value); + let cmp = match arg.op { + SeccompOp::ScmpCmpMaskedEq => { + ScmpArgCompare::new(arg.index as u32, op, arg.valueTwo.unwrap_or(0)) + } + _ => ScmpArgCompare::new(arg.index as u32, op, arg.value), + }; + comparators.push(cmp); + } + } + filter + .add_rule_conditional(action, sc, &comparators) + .with_context(|| "Failed to add conditional rule")?; + } + } + filter + .load() + .with_context(|| "Failed to load filter into the kernel")?; + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use rusty_fork::rusty_fork_test; + + use oci_spec::linux::{SeccompSyscall, SeccompSyscallArg}; + + use super::*; + + #[test] + fn test_parse_action() { + assert_eq!( + parse_action(OciSeccompAction::ScmpActKill, None), + ScmpAction::KillThread + ); + assert_eq!( + parse_action(OciSeccompAction::ScmpActKillProcess, None), + ScmpAction::KillProcess + ); + assert_eq!( + parse_action(OciSeccompAction::ScmpActTrap, None), + ScmpAction::Trap + ); + assert_eq!( + parse_action(OciSeccompAction::ScmpActErrno, Some(1)), + ScmpAction::Errno(1) + ); + assert_eq!( + parse_action(OciSeccompAction::ScmpActTrace, Some(1)), + ScmpAction::Trace(1) + ); + assert_eq!( + parse_action(OciSeccompAction::ScmpActLog, None), + ScmpAction::Log + ); + assert_eq!( + parse_action(OciSeccompAction::ScmpActAllow, None), + ScmpAction::Allow + ); + assert_eq!( + parse_action(OciSeccompAction::ScmpActNotify, None), + ScmpAction::Notify + ); + } + + #[test] + fn test_parse_cmp() { + assert_eq!(parse_cmp(SeccompOp::ScmpCmpNe, 0), ScmpCompareOp::NotEqual); + assert_eq!(parse_cmp(SeccompOp::ScmpCmpLt, 0), ScmpCompareOp::Less); + assert_eq!( + parse_cmp(SeccompOp::ScmpCmpLe, 0), + ScmpCompareOp::LessOrEqual + ); + assert_eq!(parse_cmp(SeccompOp::ScmpCmpEq, 0), ScmpCompareOp::Equal); + assert_eq!( + parse_cmp(SeccompOp::ScmpCmpGe, 0), + ScmpCompareOp::GreaterEqual + ); + assert_eq!(parse_cmp(SeccompOp::ScmpCmpGt, 0), ScmpCompareOp::Greater); + assert_eq!( + parse_cmp(SeccompOp::ScmpCmpMaskedEq, 1), + ScmpCompareOp::MaskedEqual(1) + ); + } + + #[test] + fn test_check_seccomp() { + let mut seccomp = Seccomp { + defaultAction: OciSeccompAction::ScmpActNotify, + defaultErrnoRet: None, + architectures: None, + flags: None, + listennerPath: None, + seccompFd: None, + listenerMetadata: None, + syscalls: None, + }; + assert!(check_seccomp(&seccomp).is_err()); + + seccomp.defaultAction = OciSeccompAction::ScmpActAllow; + let syscall = SeccompSyscall { + names: vec![String::from("write")], + action: OciSeccompAction::ScmpActNotify, + errnoRet: None, + args: None, + }; + seccomp.syscalls = Some(vec![syscall]); + assert!(check_seccomp(&seccomp).is_err()); + } + + rusty_fork_test! { + #[test] + fn test_set_seccomp() { + let mut seccomp = Seccomp { + defaultAction: OciSeccompAction::ScmpActAllow, + defaultErrnoRet: None, + architectures: None, + flags: None, + listennerPath: None, + seccompFd: None, + listenerMetadata: None, + syscalls: None, + }; + let syscall = SeccompSyscall { + names: vec![String::from("write")], + action: OciSeccompAction::ScmpActKill, + errnoRet: None, + args: Some(vec![ + SeccompSyscallArg { + index: 0, + value: 0, + valueTwo: Some(0), + op: SeccompOp::ScmpCmpEq, + }, + SeccompSyscallArg { + index: 2, + value: 0, + valueTwo: Some(0), + op: SeccompOp::ScmpCmpMaskedEq, + }, + ]), + }; + seccomp.syscalls = Some(vec![syscall]); + + assert!(set_seccomp(&seccomp).is_ok()); + } + } +} diff --git a/ozonec/src/linux/terminal.rs b/ozonec/src/linux/terminal.rs new file mode 100644 index 0000000000000000000000000000000000000000..26da7376e5ad2cd5486ad22d78d0e390eec95e8b --- /dev/null +++ b/ozonec/src/linux/terminal.rs @@ -0,0 +1,111 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{ + fs::File, + io::IoSlice, + mem::ManuallyDrop, + os::unix::io::{AsRawFd, RawFd}, + path::PathBuf, +}; + +use anyhow::{bail, Context, Result}; +use nix::{ + errno::errno, + fcntl::{open, OFlag}, + mount::MsFlags, + pty::{posix_openpt, ptsname, unlockpt}, + sys::{ + socket::{sendmsg, ControlMessage, MsgFlags, UnixAddr}, + stat::{fchmod, Mode}, + }, + unistd::{close, dup2}, +}; + +use crate::utils::OzonecErr; + +pub enum Stdio { + Stdin = 0, + Stdout = 1, + Stderr = 2, +} + +pub fn setup_console(console_fd: &RawFd, mount: bool) -> Result<()> { + let master_fd = posix_openpt(OFlag::O_RDWR).with_context(|| "openpt error")?; + let pty_name: &[u8] = b"/dev/ptmx"; + let iov = [IoSlice::new(pty_name)]; + // Use ManuallyDrop to keep fds open. + let master = ManuallyDrop::new(master_fd.as_raw_fd()); + let fds = [master.as_raw_fd()]; + let cmsg = ControlMessage::ScmRights(&fds); + sendmsg::( + console_fd.as_raw_fd(), + &iov, + &[cmsg], + MsgFlags::empty(), + None, + ) + .with_context(|| "sendmsg error")?; + + // SAFETY: FFI call with valid arguments. + let slave_name = unsafe { ptsname(&master_fd).with_context(|| "ptsname error")? }; + unlockpt(&master_fd).with_context(|| "unlockpt error")?; + let slave_path = PathBuf::from(&slave_name); + if mount { + let file = File::create("/dev/console").with_context(|| "Failed to create /dev/console")?; + fchmod(file.as_raw_fd(), Mode::from_bits_truncate(0o666u32)) + .with_context(|| "chmod error")?; + nix::mount::mount( + Some(&slave_path), + "/dev/console", + Some("bind"), + MsFlags::MS_BIND, + None::<&str>, + ) + .with_context(|| OzonecErr::Mount(slave_name.clone()))?; + } + + let slave_fd = open(&slave_path, OFlag::O_RDWR, Mode::empty()) + .with_context(|| OzonecErr::OpenFile(slave_name.clone()))?; + let slave = ManuallyDrop::new(slave_fd); + // SAFETY: FFI call with valid arguments. + if unsafe { libc::ioctl(slave.as_raw_fd(), libc::TIOCSCTTY) } != 0 { + bail!("TIOCSCTTY error: {}", errno()); + } + connect_stdio(&slave_fd, &slave_fd, &slave_fd)?; + close(console_fd.as_raw_fd()).with_context(|| "Failed to close console socket")?; + Ok(()) +} + +pub fn connect_stdio(stdin: &RawFd, stdout: &RawFd, stderr: &RawFd) -> Result<()> { + dup2(*stdin, (Stdio::Stdin as i32).as_raw_fd()) + .with_context(|| OzonecErr::Dup2("stdin".to_string()))?; + dup2(*stdout, (Stdio::Stdout as i32).as_raw_fd()) + .with_context(|| OzonecErr::Dup2("stdout".to_string()))?; + dup2(*stderr, (Stdio::Stderr as i32).as_raw_fd()) + .with_context(|| OzonecErr::Dup2("stderr".to_string()))?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_connect_stdio() { + let stdin: RawFd = 0; + let stdout: RawFd = 0; + let stderr: RawFd = 0; + + assert!(connect_stdio(&stdin, &stdout, &stderr).is_ok()); + } +} diff --git a/ozonec/src/main.rs b/ozonec/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..29c529bb9fbdcd75ef215d689b48c99ee5e9f6aa --- /dev/null +++ b/ozonec/src/main.rs @@ -0,0 +1,144 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +mod commands; +mod container; +mod linux; +mod utils; + +use std::{ + fs::remove_dir_all, + path::{Path, PathBuf}, + process::exit, +}; + +use anyhow::{anyhow, Context, Result}; +use clap::{crate_description, Args, Parser, Subcommand}; +use commands::{Delete, Exec, Kill, Start, State}; +use log::info; +use nix::unistd::geteuid; + +use crate::{commands::Create, utils::logger}; + +// Global options which are not binded to any specific command. +#[derive(Args, Debug)] +struct GlobalOpts { + /// Root directory to store container state. + #[arg(short, long)] + root: Option, + /// Path of log file. + #[arg(short, long)] + log: Option, + /// Enable debug log level. + #[arg(short, long)] + debug: bool, +} + +// Standard commands supported by [OCI runtime-spec] +// (https://github.com/opencontainers/runtime-spec/blob/master/runtime.md) +// and [OCI Command Line Interface] +// (https://github.com/opencontainers/runtime-tools/blob/master/docs/command-line-interface.md). +#[derive(Subcommand, Debug)] +enum StandardCmd { + Create(Create), + Start(Start), + State(State), + Kill(Kill), + Delete(Delete), +} + +// Extended commands not documented in [OCI Command Line Interface]. +#[derive(Subcommand, Debug)] +enum ExtendCmd { + Exec(Exec), +} + +#[derive(Subcommand, Debug)] +enum Command { + #[command(flatten)] + Standard(StandardCmd), + #[command(flatten)] + Extend(ExtendCmd), +} + +#[derive(Parser, Debug)] +#[command(version, author, about = crate_description!())] +#[command(propagate_version = true)] +struct Cli { + #[command(flatten)] + global: GlobalOpts, + #[command(subcommand)] + cmd: Command, +} + +fn cmd_run(command: Command, root: &Path) -> Result<()> { + match command { + Command::Standard(cmd) => match cmd { + StandardCmd::Create(create) => { + info!("Run command: {:?}", create); + + let mut root_exist = false; + create.run(root, &mut root_exist).map_err(|e| { + if !root_exist { + let _ = remove_dir_all(root); + } + anyhow!(e) + })? + } + StandardCmd::Start(start) => { + info!("Run command: {:?}", start); + start.run(root)? + } + StandardCmd::Kill(kill) => { + info!("Run command: {:?}", kill); + kill.run(root)? + } + StandardCmd::Delete(delete) => { + info!("Run command: {:?}", delete); + delete.run(root)? + } + StandardCmd::State(state) => { + info!("Run command: {:?}", state); + state.run(root)? + } + }, + Command::Extend(cmd) => match cmd { + ExtendCmd::Exec(exec) => { + info!("Run command: {:?}", exec); + exec.run(root)? + } + }, + } + Ok(()) +} + +fn real_main() -> Result<()> { + let cli = Cli::parse(); + + logger::init(&cli.global.log, cli.global.debug).with_context(|| "Failed to init logger")?; + + let root_path = if let Some(root) = cli.global.root { + root + } else { + let euid = geteuid(); + PathBuf::from(format!("/var/run/user/{}/ozonec", euid)) + }; + cmd_run(cli.cmd, &root_path) +} + +fn main() { + if let Err(e) = real_main() { + eprintln!("ERROR: {:?}", e); + exit(1); + } + exit(0); +} diff --git a/ozonec/src/utils/channel.rs b/ozonec/src/utils/channel.rs new file mode 100644 index 0000000000000000000000000000000000000000..b5d850e3e3d3971d38da8732ef93cbb3a603b58f --- /dev/null +++ b/ozonec/src/utils/channel.rs @@ -0,0 +1,263 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{ + fmt::Debug, + io::{IoSlice, IoSliceMut}, + marker::PhantomData, + mem, + os::unix::io::RawFd, + slice, +}; + +use anyhow::{bail, Context, Result}; +use nix::{ + sys::{ + socket::{ + recvmsg, sendmsg, setsockopt, socketpair, sockopt, AddressFamily, MsgFlags, SockFlag, + SockType, UnixAddr, + }, + time::TimeVal, + }, + unistd::{self, Pid}, +}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +// Wrapper for messages to be sent between parent and child processes. +#[derive(Debug, Serialize, Deserialize)] +pub enum Message { + IdMappingStart, + IdMappingDone, + InitReady(i32), + ContainerCreated, + ExecFailed(String), +} + +pub struct Sender { + fd: RawFd, + phantom: PhantomData, +} + +impl Sender +where + T: Serialize, +{ + pub fn close(&self) -> Result<()> { + Ok(unistd::close(self.fd)?) + } + + pub fn send(&self, msg: T) -> Result<()> { + let msg_vec = serde_json::to_vec(&msg).with_context(|| "Failed to load message")?; + let msg_len = msg_vec.len() as u64; + let iov = [ + // SAFETY: FFI call with valid arguments. + IoSlice::new(unsafe { + slice::from_raw_parts((&msg_len as *const u64) as *const u8, mem::size_of::()) + }), + IoSlice::new(&msg_vec), + ]; + + sendmsg::(self.fd, &iov, &[], MsgFlags::empty(), None)?; + Ok(()) + } +} + +pub struct Receiver { + fd: RawFd, + phantom: PhantomData, +} + +impl Receiver +where + T: DeserializeOwned, +{ + pub fn close(&self) -> Result<()> { + Ok(unistd::close(self.fd)?) + } + + pub fn set_timeout(&self, timeout: i64) -> Result<()> { + let timeval = TimeVal::new(0, timeout); + setsockopt(self.fd, sockopt::ReceiveTimeout, &timeval) + .with_context(|| "Failed to set receiver end timeout")?; + Ok(()) + } + + fn max_len_iovec(&self) -> Result { + let mut len: u64 = 0; + // SAFETY: len and type "u64" are both valid. + let mut iov = [IoSliceMut::new(unsafe { + slice::from_raw_parts_mut((&mut len as *mut u64) as *mut u8, mem::size_of::()) + })]; + + recvmsg::(self.fd, &mut iov, None, MsgFlags::MSG_PEEK)?; + match len { + 0 => bail!("Failed to get maximum length"), + _ => Ok(len), + } + } + + pub fn recv(&self) -> Result { + let msg_len = self.max_len_iovec()?; + let mut received_len: u64 = 0; + let mut buf = vec![0u8; msg_len as usize]; + let bytes = { + let mut iov = [ + // SAFETY: FFI call with valid arguments. + IoSliceMut::new(unsafe { + slice::from_raw_parts_mut( + (&mut received_len as *mut u64) as *mut u8, + mem::size_of::(), + ) + }), + IoSliceMut::new(&mut buf), + ]; + let mut cmsg = nix::cmsg_space!(T); + let msg = recvmsg::( + self.fd, + &mut iov, + Some(&mut cmsg), + MsgFlags::MSG_CMSG_CLOEXEC, + )?; + msg.bytes + }; + + match bytes { + 0 => bail!("Received zero length message"), + _ => Ok(serde_json::from_slice(&buf[..]) + .with_context(|| "Failed to read received message")?), + } + } +} + +pub struct Channel { + pub sender: Sender, + pub receiver: Receiver, +} + +impl Channel { + pub fn new() -> Result> { + let (sender_fd, receiver_fd) = socketpair( + AddressFamily::Unix, + SockType::SeqPacket, + None, + SockFlag::SOCK_CLOEXEC, + )?; + let sender = Sender { + fd: sender_fd, + phantom: PhantomData, + }; + let receiver = Receiver { + fd: receiver_fd, + phantom: PhantomData, + }; + + Ok(Channel { sender, receiver }) + } + + pub fn recv_container_created(&self) -> Result<()> { + let msg = self.receiver.recv()?; + match msg { + Message::ContainerCreated => Ok(()), + _ => bail!("Expect receiving ContainerCreated, but got {:?}", msg), + } + } + + pub fn send_container_created(&self) -> Result<()> { + self.sender + .send(Message::ContainerCreated) + .with_context(|| "Failed to send created message to parent process") + } + + pub fn recv_id_mappings(&self) -> Result<()> { + let msg = self.receiver.recv()?; + match msg { + Message::IdMappingStart => Ok(()), + _ => bail!("Expect receiving IdMappingStart, but got {:?}", msg), + } + } + + pub fn send_id_mappings(&self) -> Result<()> { + self.sender.send(Message::IdMappingStart) + } + + pub fn recv_init_pid(&self) -> Result { + let msg = self.receiver.recv()?; + match msg { + Message::InitReady(pid) => Ok(Pid::from_raw(pid)), + _ => bail!("Expect receiving InitReady, but got {:?}", msg), + } + } + + pub fn recv_id_mappings_done(&self) -> Result<()> { + let msg = self.receiver.recv()?; + match msg { + Message::IdMappingDone => Ok(()), + _ => bail!("Expect receiving IdMappingDone, but got {:?}", msg), + } + } + + pub fn send_id_mappings_done(&self) -> Result<()> { + self.sender.send(Message::IdMappingDone) + } + + pub fn send_init_pid(&self, pid: Pid) -> Result<()> { + self.sender + .send(Message::InitReady(pid.as_raw())) + .with_context(|| "Failed to send container process pid") + } +} + +#[cfg(test)] +mod tests { + use nix::sys::wait::{waitpid, WaitStatus}; + use unistd::getpid; + + use crate::linux::clone_process; + + use super::*; + + #[test] + fn test_channel() { + let channel = Channel::::new().unwrap(); + let child = clone_process("test_channel", || { + channel.receiver.close().unwrap(); + + channel.send_container_created().unwrap(); + channel.send_init_pid(getpid()).unwrap(); + channel.send_id_mappings().unwrap(); + channel.send_id_mappings_done().unwrap(); + + channel.sender.close().unwrap(); + Ok(0) + }) + .unwrap(); + + channel.sender.close().unwrap(); + + channel.recv_container_created().unwrap(); + channel.recv_init_pid().unwrap(); + channel.recv_id_mappings().unwrap(); + channel.recv_id_mappings_done().unwrap(); + + channel.receiver.close().unwrap(); + + match waitpid(child, None) { + Ok(WaitStatus::Exited(_, s)) => { + assert_eq!(s, 0); + } + Ok(_) => (), + Err(e) => { + panic!("Failed to waitpid for child process: {e}"); + } + } + } +} diff --git a/ozonec/src/utils/clone.rs b/ozonec/src/utils/clone.rs new file mode 100644 index 0000000000000000000000000000000000000000..2dd99e6629b2f0dfd63b1e122d915f18fb5327de --- /dev/null +++ b/ozonec/src/utils/clone.rs @@ -0,0 +1,125 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::os::unix::io::{AsRawFd, RawFd}; + +use anyhow::{bail, Context, Result}; +use libc::pid_t; +use nix::{errno::errno, unistd::Pid}; + +bitflags::bitflags! { + #[derive(Default)] + pub struct Flags: u64 { + const CHILD_CLEARTID = 0x00200000; + const CHILD_SETTID = 0x01000000; + const FILES = 0x00000400; + const FS = 0x00000200; + const INTO_CGROUP = 0x200000000; + const IO = 0x80000000; + const NEWCGROUP = 0x02000000; + const NEWIPC = 0x08000000; + const NEWNET = 0x40000000; + const NEWNS = 0x00020000; + const NEWPID = 0x20000000; + const NEWTIME = 0x00000080; + const NEWUSER = 0x10000000; + const NEWUTS = 0x04000000; + const PARENT = 0x00008000; + const PARENT_SETTID = 0x00100000; + const PIDFD = 0x00001000; + const PTRACE = 0x00002000; + const SETTLS = 0x00080000; + const SIGHAND = 0x00000800; + const SYSVSEM = 0x00040000; + const THREAD = 0x00010000; + const UNTRACED = 0x00800000; + const VFORK = 0x00004000; + const VM = 0x00000100; + } +} + +#[repr(C, align(8))] +#[derive(Debug, Default)] +pub struct CloneArgs { + pub flags: u64, + pub pid_fd: u64, + pub child_tid: u64, + pub parent_tid: u64, + pub exit_signal: u64, + pub stack: u64, + pub stack_size: u64, + pub tls: u64, + pub cgroup: u64, +} + +#[derive(Default)] +pub struct Clone3<'a> { + flags: Flags, + pidfd: Option<&'a mut RawFd>, + child_tid: Option<&'a mut libc::pid_t>, + parent_tid: Option<&'a mut libc::pid_t>, + exit_signal: u64, + stack: Option<&'a mut [u8]>, + tls: Option, + cgroup: Option<&'a dyn AsRawFd>, +} + +fn option_as_mut_ptr(o: &mut Option<&mut T>) -> *mut T { + match o { + Some(inner) => *inner as *mut T, + None => std::ptr::null_mut(), + } +} + +fn option_slice_as_mut_ptr(o: &mut Option<&mut [T]>) -> *mut T { + match o { + Some(inner) => inner.as_mut_ptr(), + None => std::ptr::null_mut(), + } +} + +impl<'a> Clone3<'a> { + pub fn exit_signal(&mut self, exit_signal: u64) -> &mut Self { + self.exit_signal = exit_signal; + self + } + + pub fn call(&mut self) -> Result { + let clone_args = CloneArgs { + flags: self.flags.bits(), + pid_fd: option_as_mut_ptr(&mut self.pidfd) as u64, + child_tid: option_as_mut_ptr(&mut self.child_tid) as u64, + parent_tid: option_as_mut_ptr(&mut self.parent_tid) as u64, + exit_signal: self.exit_signal, + stack: option_slice_as_mut_ptr(&mut self.stack) as u64, + stack_size: self.stack.as_ref().map(|stack| stack.len()).unwrap_or(0) as u64, + tls: self.tls.unwrap_or(0), + cgroup: self.cgroup.map(AsRawFd::as_raw_fd).unwrap_or(0) as u64, + }; + + // SAFETY: FFI call with valid arguments. + let ret = unsafe { + libc::syscall( + libc::SYS_clone3, + &clone_args as *const CloneArgs, + core::mem::size_of::(), + ) + }; + if ret == -1 { + bail!("clone3 error: errno {}", errno()); + } + + Ok(Pid::from_raw( + pid_t::try_from(ret).with_context(|| "Invalid pid")?, + )) + } +} diff --git a/ozonec/src/utils/error.rs b/ozonec/src/utils/error.rs new file mode 100644 index 0000000000000000000000000000000000000000..b3f93728d3a05efd55be82af0c62f8284dcd31ec --- /dev/null +++ b/ozonec/src/utils/error.rs @@ -0,0 +1,49 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum OzonecErr { + #[error("Failed to access /proc/{0}")] + ReadProcPid(i32), + #[error("Failed to access /proc/{0}/status")] + ReadProcStat(i32), + #[error("Failed to open {0}")] + OpenFile(String), + #[error("Failed to create directory {0}")] + CreateDir(String), + #[error("Failed to mount {0}")] + Mount(String), + #[error("Failed to access /proc/self")] + AccessProcSelf, + #[error("Failed to get mountinfo")] + GetMntInfo, + #[error("Dup2 {0} error")] + Dup2(String), + #[error("Failed to get all capabilities of {0} set")] + GetAllCaps(String), + #[error("Failed to set the capability set {0}")] + SetCaps(String), + #[error("Failed to add architecture to seccomp filter")] + AddScmpArch, + #[error("Failed to get current directory")] + GetCurDir, + #[error("Failed to load container state")] + LoadConState, + #[error("Failed to get oci state")] + GetOciState, + #[error("Failed to bind device: {0}")] + BindDev(String), + #[error("Close fd error")] + CloseFd, +} diff --git a/ozonec/src/utils/logger.rs b/ozonec/src/utils/logger.rs new file mode 100644 index 0000000000000000000000000000000000000000..33ecd86bc5ad6c60356b07a8c7951b63922729f2 --- /dev/null +++ b/ozonec/src/utils/logger.rs @@ -0,0 +1,292 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{ + fs::{remove_file, rename, File, OpenOptions}, + io::{stderr, Write}, + num::Wrapping, + os::unix::fs::OpenOptionsExt, + path::{Path, PathBuf}, + sync::Mutex, + time::UNIX_EPOCH, +}; + +use anyhow::{Context, Result}; +use log::{set_boxed_logger, set_max_level, Level, LevelFilter, Log, Metadata, Record}; +use nix::unistd::{getpid, gettid}; + +use super::OzonecErr; + +// Maximum size of log file is 100MB. +const LOG_ROTATE_SIZE_MAX: usize = 100 * 1024 * 1024; +// Logs are retained for seven days at most. +const LOG_ROTATE_CNT_MAX: u8 = 7; + +struct LogRotate { + handler: Box, + path: String, + size: Wrapping, + created_day: i32, +} + +impl LogRotate { + fn rotate(&mut self, inc_size: usize) -> Result<()> { + if self.path.is_empty() { + return Ok(()); + } + + self.size += Wrapping(inc_size); + let seconds = wall_time().0; + let today = formatted_time(seconds)[2]; + if self.size < Wrapping(LOG_ROTATE_SIZE_MAX) && self.created_day == today { + return Ok(()); + } + + // Delete oldest log file. + let mut rotate_cnt = LOG_ROTATE_CNT_MAX - 1; + let olddest = format!("{}{}", self.path, rotate_cnt); + if Path::new(&olddest).exists() { + remove_file(&olddest).with_context(|| "Failed to delete olddest log")?; + } + + // Rename remaining logs. + let mut new_log = olddest; + while rotate_cnt != 0 { + let mut old_log = self.path.clone(); + + rotate_cnt -= 1; + if rotate_cnt != 0 { + old_log += &rotate_cnt.to_string(); + } + + if Path::new(&old_log).exists() { + rename(&old_log, &new_log) + .with_context(|| format!("Failed to rename {} to {}", old_log, new_log))?; + } + new_log = old_log; + } + + self.handler = Box::new( + open_log_file(&PathBuf::from(self.path.clone())) + .with_context(|| format!("Failed to convert {}", self.path))?, + ); + self.size = Wrapping(0); + self.created_day = today; + Ok(()) + } +} + +fn open_log_file(path: &PathBuf) -> Result { + OpenOptions::new() + .read(false) + .write(true) + .append(true) + .create(true) + .mode(0o640) + .open(path) + .with_context(|| OzonecErr::OpenFile(path.to_string_lossy().to_string())) +} + +fn formatted_time(seconds: i64) -> [i32; 6] { + // SAFETY: an all-zero value is valid for libc::tm. + let mut ti: libc::tm = unsafe { std::mem::zeroed() }; + // SAFETY: seconds and ti are both local variables and valid. + unsafe { + libc::localtime_r(&seconds, &mut ti); + } + [ + ti.tm_year + 1900, + ti.tm_mon + 1, + ti.tm_mday, + ti.tm_hour, + ti.tm_min, + ti.tm_sec, + ] +} + +fn wall_time() -> (i64, i64) { + let mut ts = libc::timespec { + tv_sec: 0, + tv_nsec: 0, + }; + // SAFETY: ts is a local variable and valid. + unsafe { + libc::clock_gettime(libc::CLOCK_REALTIME, &mut ts); + } + (ts.tv_sec, ts.tv_nsec) +} + +fn formatted_now() -> String { + let (sec, nsec) = wall_time(); + let formatted_time = formatted_time(sec); + + format!( + "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}:{:09}", + formatted_time[0], + formatted_time[1], + formatted_time[2], + formatted_time[3], + formatted_time[4], + formatted_time[5], + nsec + ) +} + +struct Logger { + rotate: Mutex, + level: Level, +} + +impl Logger { + fn new(path: &Option, level: Level) -> Result { + let (log_file, log_size, created_day) = match path { + Some(p) => { + let file = Box::new(open_log_file(p)?); + let metadata = file.metadata().with_context(|| "Failed to get metadata")?; + let mod_time = metadata + .modified() + .with_context(|| "Failed to get modify time")?; + let seconds = mod_time + .duration_since(UNIX_EPOCH) + .with_context(|| "Failed to get duration time")? + .as_secs(); + let log_size = Wrapping(metadata.len() as usize); + let created_day = formatted_time(seconds as i64)[2]; + (file as Box, log_size, created_day) + } + None => (Box::new(stderr()) as Box, Wrapping(0), 0), + }; + + let rotate = Mutex::new(LogRotate { + handler: log_file, + path: path + .as_ref() + .unwrap_or(&PathBuf::new()) + .to_string_lossy() + .to_string(), + size: log_size, + created_day, + }); + Ok(Self { rotate, level }) + } +} + +impl Log for Logger { + fn enabled(&self, metadata: &Metadata) -> bool { + metadata.level() <= self.level + } + + fn log(&self, record: &Record) { + if !self.enabled(record.metadata()) { + return; + } + + let fmt_msg = format_args!( + "{:<5}: [{}][{}][{}: {}]:{}: {}\n", + formatted_now(), + getpid(), + gettid(), + record.file().unwrap_or(""), + record.line().unwrap_or(0), + record.level(), + record.args() + ) + .to_string(); + + let mut log_rotate = self.rotate.lock().unwrap(); + if let Err(e) = log_rotate.handler.write_all(fmt_msg.as_bytes()) { + eprintln!("Failed to log message: {:?}", e); + return; + } + if let Err(e) = log_rotate.rotate(fmt_msg.as_bytes().len()) { + eprintln!("Failed to rotate log files: {:?}", e); + } + } + + fn flush(&self) {} +} + +pub fn init(path: &Option, debug: bool) -> Result<()> { + let log_level = if debug { + Level::Debug + } else { + match std::env::var("OZONEC_LOG_LEVEL") { + Ok(level) => match level.to_lowercase().as_str() { + "error" => Level::Error, + "warn" => Level::Warn, + "info" => Level::Info, + "debug" => Level::Debug, + "trace" => Level::Trace, + _ => Level::Info, + }, + _ => Level::Info, + } + }; + + let logger = Box::new(Logger::new(path, log_level)?); + set_boxed_logger(logger) + .map(|_| set_max_level(LevelFilter::Trace)) + .with_context(|| "Logger has been already set")?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use std::{fs, os::unix::fs::MetadataExt}; + + use super::*; + + #[test] + fn test_logger_init() { + assert!(init(&Some(PathBuf::from("/tmp/ozonec.log")), false).is_ok()); + remove_file(Path::new("/tmp/ozonec.log")).unwrap(); + } + + #[test] + fn test_logger_rotate() { + let log_file = PathBuf::from("/tmp/ozonec.log"); + let logger = Logger::new(&Some(log_file.clone()), Level::Debug).unwrap(); + let mut locked_rotate = logger.rotate.lock().unwrap(); + // Time in metadata are not changed as the file descriptor is still opened. + let inode = fs::metadata(&log_file).unwrap().ino(); + for i in 1..LOG_ROTATE_CNT_MAX { + let file = format!("{}{}", locked_rotate.path, i); + let path = Path::new(&file); + File::create(path).unwrap(); + } + + locked_rotate.size = Wrapping(0); + assert!(locked_rotate.rotate(1024).is_ok()); + let mut new_inode = fs::metadata(&log_file).unwrap().ino(); + assert_eq!(inode, new_inode); + + locked_rotate.size = Wrapping(LOG_ROTATE_SIZE_MAX); + assert!(locked_rotate.rotate(1024).is_ok()); + new_inode = fs::metadata(&log_file).unwrap().ino(); + assert_ne!(inode, new_inode); + assert_eq!(locked_rotate.size, Wrapping(0)); + + locked_rotate.size = Wrapping(0); + locked_rotate.created_day = formatted_time(wall_time().0)[2] - 1; + assert!(locked_rotate.rotate(1024).is_ok()); + new_inode = fs::metadata(&log_file).unwrap().ino(); + assert_ne!(inode, new_inode); + assert_eq!(locked_rotate.size, Wrapping(0)); + + for i in 1..LOG_ROTATE_CNT_MAX { + let file = format!("{}{}", locked_rotate.path, i); + let path = Path::new(&file); + remove_file(path).unwrap(); + } + remove_file(Path::new("/tmp/ozonec.log")).unwrap(); + } +} diff --git a/ozonec/src/utils/mod.rs b/ozonec/src/utils/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..4e86fc35f5f65197816a8966230e0a38656bec98 --- /dev/null +++ b/ozonec/src/utils/mod.rs @@ -0,0 +1,123 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +pub mod logger; +pub mod prctl; + +mod channel; +mod clone; +mod error; + +pub use channel::{Channel, Message}; +pub use clone::Clone3; +pub use error::OzonecErr; + +use std::{ + fs::create_dir_all, + mem, + os::unix::io::{AsRawFd, RawFd}, + path::{Path, PathBuf}, +}; + +use anyhow::{bail, Context, Result}; +use nix::{ + errno::errno, + fcntl::{open, OFlag}, + sys::stat::Mode, + NixPath, +}; + +struct OpenHow(libc::open_how); + +bitflags::bitflags! { + struct ResolveFlag: libc::c_ulonglong { + const RESOLVE_BENEATH = libc::RESOLVE_BENEATH; + const RESOLVE_IN_ROOT = libc::RESOLVE_IN_ROOT; + const RESOLVE_NO_MAGICLINKS = libc::RESOLVE_NO_MAGICLINKS; + const RESOLVE_NO_SYMLINKS = libc::RESOLVE_NO_SYMLINKS; + const RESOLVE_NO_XDEV = libc::RESOLVE_NO_XDEV; + } +} + +impl OpenHow { + fn new() -> Self { + // SAFETY: FFI call with valid arguments. + unsafe { mem::zeroed() } + } + + fn flags(mut self, flags: OFlag) -> Self { + let flags = flags.bits() as libc::c_ulonglong; + self.0.flags = flags; + self + } + + fn mode(mut self, mode: Mode) -> Self { + let mode = mode.bits() as libc::c_ulonglong; + self.0.mode = mode; + self + } + + fn resolve(mut self, resolve: ResolveFlag) -> Self { + let resolve = resolve.bits() as libc::c_ulonglong; + self.0.resolve = resolve; + self + } +} + +/// Get a file descriptor by openat2 with `root` path, relative `target` path in `root` +/// and whether is director or not. If the target directory or file doesn't exist, create +/// automatically. +pub fn openat2_in_root(root: &Path, target: &Path, is_dir: bool) -> Result { + let mut flags = OFlag::O_CLOEXEC; + let mode; + if is_dir { + flags |= OFlag::O_DIRECTORY | OFlag::O_PATH; + mode = Mode::empty(); + create_dir_all(root.join(target)) + .with_context(|| OzonecErr::CreateDir(target.to_string_lossy().to_string()))?; + } else { + flags |= OFlag::O_CREAT; + mode = Mode::S_IRWXU; + }; + + let mut open_how = OpenHow::new() + .flags(flags) + .mode(mode) + .resolve(ResolveFlag::RESOLVE_IN_ROOT); + let dirfd = open(root, flags & !OFlag::O_CREAT, Mode::empty()) + .with_context(|| OzonecErr::OpenFile(root.to_string_lossy().to_string()))?; + let fd = target + // SAFETY: FFI call with valid arguments. + .with_nix_path(|p| unsafe { + libc::syscall( + libc::SYS_openat2, + dirfd.as_raw_fd(), + p.as_ptr(), + &mut open_how as *mut OpenHow, + mem::size_of::(), + ) + }) + .with_context(|| "with_nix_path error")?; + if fd < 0 { + bail!( + "openat2 {} error with RESOLVE_IN_ROOT: {}", + target.display(), + errno() + ); + } + Ok(RawFd::try_from(fd)?) +} + +/// Build path "/proc/self/fd/{}" with an opened file descriptor. +pub fn proc_fd_path(dirfd: RawFd) -> PathBuf { + PathBuf::from(format!("/proc/self/fd/{}", dirfd)) +} diff --git a/ozonec/src/utils/prctl.rs b/ozonec/src/utils/prctl.rs new file mode 100644 index 0000000000000000000000000000000000000000..5bc05441f7a484fa40589d3ce303d71c526fe484 --- /dev/null +++ b/ozonec/src/utils/prctl.rs @@ -0,0 +1,94 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::ffi::CString; + +use anyhow::{bail, Result}; +use libc::{c_int, c_ulong, prctl}; +use nix::errno::errno; + +#[allow(non_camel_case_types)] +enum PrctlOption { + PR_SET_DUMPABLE = 4, + PR_SET_KEEPCAPS = 8, + PR_SET_NAME = 15, + PR_SET_NO_NEW_PRIVS = 38, +} + +pub fn set_dumpable(dumpable: bool) -> Result<()> { + // SAFETY: FFI call with valid arguments. + let ret = unsafe { + prctl( + PrctlOption::PR_SET_DUMPABLE as c_int, + dumpable as c_ulong, + 0, + 0, + 0, + ) + }; + if ret != 0 { + bail!("errno {}", errno()); + } + Ok(()) +} + +pub fn set_keep_capabilities(keep_capabilities: bool) -> Result<()> { + // SAFETY: FFI call with valid arguments. + let ret = unsafe { + prctl( + PrctlOption::PR_SET_KEEPCAPS as c_int, + keep_capabilities as c_ulong, + 0, + 0, + 0, + ) + }; + if ret != 0 { + bail!("errno {}", errno()); + } + Ok(()) +} + +pub fn set_no_new_privileges(new_privileges: bool) -> Result<()> { + // SAFETY: FFI call with valid arguments. + let ret = unsafe { + prctl( + PrctlOption::PR_SET_NO_NEW_PRIVS as c_int, + new_privileges as c_ulong, + 0, + 0, + 0, + ) + }; + if ret != 0 { + bail!("errno {}", errno()); + } + Ok(()) +} + +pub fn set_name(name: &str) -> Result<()> { + let binding = CString::new(name).unwrap(); + // SAFETY: FFI call with valid arguments. + let ret = unsafe { + prctl( + PrctlOption::PR_SET_NAME as c_int, + binding.as_ptr() as c_ulong, + 0, + 0, + 0, + ) + }; + if ret != 0 { + bail!("errno {}", errno()); + } + Ok(()) +} diff --git a/ozonec/tests/README.md b/ozonec/tests/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c9d98cfee1d203149a6a9ca32aad1d26181bfb6c --- /dev/null +++ b/ozonec/tests/README.md @@ -0,0 +1,30 @@ +# Integration Tests + +ozonec uses [bats (Bash Automated Testing System)](https://github.com/bats-core/bats-core) framework to run +integration tests written in *bash*. + +## Before running tests + +Install [bats (Bash Automated Testing System)](https://github.com/bats-core/bats-core#installing-bats-from-source) from source: +``` +$ git clone https://github.com/bats-core/bats-core.git +$ cd bats-core +$ ./install.sh /usr/local +``` + +*bundle* directory which includes *config.json* and *rootfs* directory may be required to archived to bundle.tar.gz under the directory the test script belongs to. And *jq* may also be needed to modify json file in tests. + +## Running tests + +You can run tests using bats directly. For example: +``` +bats ./ +``` +Or you can just run a single test file. For example: +``` +bats create.bats +``` + +## Writing tests + +Please refer to [bats (Writing tests)](https://bats-core.readthedocs.io/en/stable/writing-tests.html). \ No newline at end of file diff --git a/ozonec/tests/config.json b/ozonec/tests/config.json new file mode 100644 index 0000000000000000000000000000000000000000..4698e5e5c6465d4ab995439f21f969a40b4640d0 --- /dev/null +++ b/ozonec/tests/config.json @@ -0,0 +1,170 @@ +{ + "ociVersion": "1.0.2-dev", + "process": { + "user": { + "uid": 0, + "gid": 0 + }, + "args": [ + "sleep", + "3600" + ], + "env": [ + "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin", + "TERM=xterm" + ], + "cwd": "/", + "capabilities": { + "bounding": [ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE" + ], + "effective": [ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE" + ], + "permitted": [ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE" + ] + }, + "rlimits": [ + { + "type": "RLIMIT_NOFILE", + "hard": 1024, + "soft": 1024 + } + ], + "noNewPrivileges": true + }, + "root": { + "path": "rootfs", + "readonly": true + }, + "hostname": "runc", + "mounts": [ + { + "destination": "/proc", + "type": "proc", + "source": "proc" + }, + { + "destination": "/dev", + "type": "tmpfs", + "source": "tmpfs", + "options": [ + "nosuid", + "strictatime", + "mode=755", + "size=65536k" + ] + }, + { + "destination": "/dev/pts", + "type": "devpts", + "source": "devpts", + "options": [ + "nosuid", + "noexec", + "newinstance", + "ptmxmode=0666", + "mode=0620", + "gid=5" + ] + }, + { + "destination": "/dev/shm", + "type": "tmpfs", + "source": "shm", + "options": [ + "nosuid", + "noexec", + "nodev", + "mode=1777", + "size=65536k" + ] + }, + { + "destination": "/dev/mqueue", + "type": "mqueue", + "source": "mqueue", + "options": [ + "nosuid", + "noexec", + "nodev" + ] + }, + { + "destination": "/sys", + "type": "sysfs", + "source": "sysfs", + "options": [ + "nosuid", + "noexec", + "nodev", + "ro" + ] + }, + { + "destination": "/sys/fs/cgroup", + "type": "cgroup", + "source": "cgroup", + "options": [ + "nosuid", + "noexec", + "nodev", + "relatime", + "ro" + ] + } + ], + "linux": { + "resources": { + "devices": [ + { + "allow": false, + "access": "rwm" + } + ] + }, + "namespaces": [ + { + "type": "pid" + }, + { + "type": "network" + }, + { + "type": "ipc" + }, + { + "type": "uts" + }, + { + "type": "mount" + } + ], + "maskedPaths": [ + "/proc/acpi", + "/proc/asound", + "/proc/kcore", + "/proc/keys", + "/proc/latency_stats", + "/proc/timer_list", + "/proc/timer_stats", + "/proc/sched_debug", + "/sys/firmware", + "/proc/scsi" + ], + "readonlyPaths": [ + "/proc/bus", + "/proc/fs", + "/proc/irq", + "/proc/sys", + "/proc/sysrq-trigger" + ] + } +} diff --git a/ozonec/tests/console.bats b/ozonec/tests/console.bats new file mode 100644 index 0000000000000000000000000000000000000000..480f310a34dc7845ea7baa02e553bb6f25eea3ec --- /dev/null +++ b/ozonec/tests/console.bats @@ -0,0 +1,44 @@ +#! /usr/bin/env bats + +load helpers + +setup_file() +{ + setup_bundle +} + +setup() +{ + CONTAINER_ID=$(uuidgen) + setup_pty_server +} + +teardown() +{ + ozonec kill "$CONTAINER_ID" + ozonec delete "$CONTAINER_ID" + killall -9 pty-server + rm -f $TEST_DIR/console* +} + +@test "ozonec create with console socket" { + update_config '.process.terminal = true' + update_config '.process.args = ["ls", "-alh"]' + ozonec create --console-socket "$CONSOLE_PATH" "$CONTAINER_ID" 3>&- + + run ozonec start "$CONTAINER_ID" + [[ $status -eq 0 ]] + run sed -n '1p' $TEST_DIR/console.log + [[ ${lines[0]} == *"total "* ]] +} + +@test "ozonec exec with console socket" { + update_config '.process.terminal = false' + update_config '.process.args = ["sleep", "3600"]' + ozonec create "$CONTAINER_ID" 3>&- + ozonec start "$CONTAINER_ID" + + run ozonec exec -t --console-socket "$CONSOLE_PATH" "$CONTAINER_ID" -- ls -alh + run sed -n '1p' $TEST_DIR/console.log + [[ ${lines[0]} == *"total "* ]] +} \ No newline at end of file diff --git a/ozonec/tests/create.bats b/ozonec/tests/create.bats new file mode 100644 index 0000000000000000000000000000000000000000..15b06b7a657dcea368b9c4dd683258c7350563ae --- /dev/null +++ b/ozonec/tests/create.bats @@ -0,0 +1,82 @@ +#! /usr/bin/env bats + +load helpers + +setup_file() +{ + setup_bundle +} + +setup() +{ + CONTAINER_ID=$(uuidgen) + ROOT_DIR="$DEFAULT_ROOT_DIR" +} + +teardown() +{ + if [ "$ROOT_DIR" == "$DEFAULT_ROOT_DIR" ]; then + ozonec kill "$CONTAINER_ID" 9 + ozonec delete "$CONTAINER_ID" + else + ozonec --root "$ROOT_DIR" kill "$CONTAINER_ID" 9 + ozonec --root "$ROOT_DIR" delete "$CONTAINER_ID" + fi +} + +@test "ozonec create" { + ozonec create "$CONTAINER_ID" 3>&- + check_container_status "$CONTAINER_ID" created "" + [ -d "$ROOT_DIR/$CONTAINER_ID" ] + [ -S "$ROOT_DIR/$CONTAINER_ID/notify.sock" ] + [ -f "$ROOT_DIR/$CONTAINER_ID/state.json" ] +} + +@test "ozonec create with absolute path of rootfs" { + local rootfs_dir="$(pwd)/rootfs" + update_config '.root.path = "'$rootfs_dir'"' + ozonec create "$CONTAINER_ID" 3>&- + check_container_status "$CONTAINER_ID" created "" +} + +@test "ozonec create with pidfile" { + ozonec create --pid-file ./pidfile "$CONTAINER_ID" 3>&- + local pid=$(cat ./pidfile) + check_container_status "$CONTAINER_ID" created "" "$pid" +} + +@test "ozonec create with duplicate id" { + ozonec create "$CONTAINER_ID" 3>&- + check_container_status "$CONTAINER_ID" created "" + ! ozonec create "$CONTAINER_ID" 3>&- +} + +@test "ozonec create with absolute bundle path" { + local bundle_dir="$(dirname `pwd`)/bundle" + ozonec create --bundle "$bundle_dir" "$CONTAINER_ID" 3>&- + check_container_status "$CONTAINER_ID" created "" +} + +@test "ozonec create with relative bundle path" { + local bundle_dir="../bundle" + ozonec create --bundle "$bundle_dir" "$CONTAINER_ID" 3>&- + check_container_status "$CONTAINER_ID" created "" +} + +@test "ozonec create with absolute root path" { + ROOT_DIR="$(dirname `pwd`)/root" + ozonec --root "$ROOT_DIR" create "$CONTAINER_ID" 3>&- + check_container_status "$CONTAINER_ID" created "$ROOT_DIR" + [ -d "$ROOT_DIR/$CONTAINER_ID" ] + [ -S "$ROOT_DIR/$CONTAINER_ID/notify.sock" ] + [ -f "$ROOT_DIR/$CONTAINER_ID/state.json" ] +} + +@test "ozonec create with relative root path" { + ROOT_DIR="../root" + ozonec --root "$ROOT_DIR" create "$CONTAINER_ID" 3>&- + check_container_status "$CONTAINER_ID" created "$ROOT_DIR" + [ -d "$ROOT_DIR/$CONTAINER_ID" ] + [ -S "$ROOT_DIR/$CONTAINER_ID/notify.sock" ] + [ -f "$ROOT_DIR/$CONTAINER_ID/state.json" ] +} \ No newline at end of file diff --git a/ozonec/tests/exec.bats b/ozonec/tests/exec.bats new file mode 100644 index 0000000000000000000000000000000000000000..418c07a167479cb097953151dc3eed5f621c0616 --- /dev/null +++ b/ozonec/tests/exec.bats @@ -0,0 +1,32 @@ +#! /usr/bin/env bats + +load helpers + +setup_file() +{ + setup_bundle + + export ROOT_DIR="$TEST_DIR/root" + export CONTAINER_ID=$(uuidgen) + + ozonec --root "$ROOT_DIR" create "$CONTAINER_ID" 3>&- + check_container_status "$CONTAINER_ID" created "$ROOT_DIR" + ozonec --root "$ROOT_DIR" start "$CONTAINER_ID" + check_container_status "$CONTAINER_ID" running "$ROOT_DIR" +} + +teardown_file() +{ + ozonec --root "$ROOT_DIR" kill "$CONTAINER_ID" 9 + ozonec --root "$ROOT_DIR" delete "$CONTAINER_ID" +} + +@test "ozonec exec" { + ozonec --root "$ROOT_DIR" exec "$CONTAINER_ID" -- ls -alh +} + +@test "ozonec exec with pidfile" { + ozonec --root "$ROOT_DIR" exec --pid-file pidfile "$CONTAINER_ID" -- ls -alh + local pid=$(cat pidfile) + [[ "$pid" -gt 0 ]] +} \ No newline at end of file diff --git a/ozonec/tests/helpers.bash b/ozonec/tests/helpers.bash new file mode 100644 index 0000000000000000000000000000000000000000..b37f098d8643afcc09631cdff0261d175d62d4bd --- /dev/null +++ b/ozonec/tests/helpers.bash @@ -0,0 +1,50 @@ +#! /bin/bash + +bats_require_minimum_version 1.5.0 + +DEFAULT_ROOT_DIR="/var/run/user/$(echo $UID)/ozonec" + +# Reformat config.json file with jq command. +function update_config() +{ + jq "$@" config.json | awk 'BEGIN{RS="";getline<"-";print>ARGV[1]}' config.json +} + +function setup_bundle() +{ + # Directory for each container. + export TEST_DIR=$(mktemp -d "$BATS_RUN_TMPDIR/ozonec.XXXXXX") + chmod a+x "$TEST_DIR" "$BATS_RUN_TMPDIR" + + local bundle="$BATS_TEST_DIRNAME/bundle.tar.gz" + tar --exclude 'rootfs/dev/*' -C "$TEST_DIR" -xf "$bundle" + cd "$TEST_DIR/bundle" +} + +function setup_pty_server() +{ + export CONSOLE_PATH="$TEST_DIR/console.sock" + rm -f $CONSOLE_PATH + # Fork twice to avoid sending SIGTTIN/SIGTTOU to pty-server. + # --no-stdin option is set to avoid SIGHUP sent to ozonec. + (pty-server --no-stdin $CONSOLE_PATH > $TEST_DIR/console.log &) & +} + +function check_container_status() { + local container_id="$1" + local state="$2" + local root="$3" + + if [ "$root" == "" ]; then + run ozonec state "$container_id" + else + run ozonec --root "$root" state "$container_id" + fi + [[ $status -eq 0 ]] + [[ "$output" == *"\"status\": \"$state\""* ]] + + if [ $# -gt 3 ]; then + local pid="$4" + [[ "$(expr match "$output" '.*"pid": \([0-9]*\).*')" == "$pid" ]] + fi +} \ No newline at end of file diff --git a/ozonec/tests/mount.bats b/ozonec/tests/mount.bats new file mode 100644 index 0000000000000000000000000000000000000000..ab563541d31bfe6d71500938229b69584d0e1e90 --- /dev/null +++ b/ozonec/tests/mount.bats @@ -0,0 +1,107 @@ +#! /usr/bin/env bats + +load helpers + +setup_file() +{ + setup_bundle +} + +setup() +{ + CONTAINER_ID=$(uuidgen) + setup_pty_server +} + +teardown() +{ + ozonec kill "$CONTAINER_ID" 9 + ozonec delete "$CONTAINER_ID" + killall -9 pty-server + rm -f $TEST_DIR/console* +} + +@test "ozonec with bind mount" { + update_config '.mounts += [{ + source: ".", + destination: "/tmp/rbind", + options: ["rbind"] + }]' + + ozonec create "$CONTAINER_ID" 3>&- + ozonec start "$CONTAINER_ID" + ozonec exec -t --console-socket "$CONSOLE_PATH" "$CONTAINER_ID" -- ls /tmp/rbind/config.json + run sed -n '1p' $TEST_DIR/console.log + [[ ${lines[0]} == *"/tmp/rbind/config.json"* ]] +} + +@test "ozonec mount /proc" { + ozonec create "$CONTAINER_ID" 3>&- + ozonec start "$CONTAINER_ID" + ozonec exec -t --console-socket "$CONSOLE_PATH" "$CONTAINER_ID" -- grep "^proc /proc proc " /proc/mounts + run sed -n '1p' $TEST_DIR/console.log + [[ ${lines[0]} == *"proc /proc proc "* ]] +} + +@test "ozonec mount /sys" { + ozonec create "$CONTAINER_ID" 3>&- + ozonec start "$CONTAINER_ID" + ozonec exec -t --console-socket "$CONSOLE_PATH" "$CONTAINER_ID" -- grep "^sysfs /sys sysfs " /proc/mounts + run sed -n '1p' $TEST_DIR/console.log + [[ ${lines[0]} == *"sysfs /sys "*"ro"*"nosuid"*"nodev"*"noexec"* ]] +} + +@test "ozonec mount /dev/pts" { + ozonec create "$CONTAINER_ID" 3>&- + ozonec start "$CONTAINER_ID" + ozonec exec -t --console-socket "$CONSOLE_PATH" "$CONTAINER_ID" -- grep "^devpts /dev/pts devpts " /proc/mounts + run sed -n '1p' $TEST_DIR/console.log + [[ ${lines[0]} == *"devpts /dev/pts devpts "*"nosuid"*"noexec"*"gid=5"*"mode=620"*"ptmxmode=666"* ]] +} + +@test "ozonec mount /dev/shm" { + ozonec create "$CONTAINER_ID" 3>&- + ozonec start "$CONTAINER_ID" + ozonec exec -t --console-socket "$CONSOLE_PATH" "$CONTAINER_ID" -- grep "^shm /dev/shm tmpfs" /proc/mounts + run sed -n '1p' $TEST_DIR/console.log + [[ ${lines[0]} == *"/dev/shm tmpfs "*"nosuid"*"nodev"*"noexec"*"size=65536k"* ]] +} + +@test "ozonec default devices" { + ozonec create "$CONTAINER_ID" 3>&- + ozonec start "$CONTAINER_ID" + + ozonec exec -t --console-socket "$CONSOLE_PATH" "$CONTAINER_ID" -- ls -al /dev/ + run grep -w "null" $TEST_DIR/console.log + [[ $status -eq 0 ]] + run grep -w "zero" $TEST_DIR/console.log + [[ $status -eq 0 ]] + run grep -w "full" $TEST_DIR/console.log + [[ $status -eq 0 ]] + run grep -w "random" $TEST_DIR/console.log + [[ $status -eq 0 ]] + run grep -w "urandom" $TEST_DIR/console.log + [[ $status -eq 0 ]] + run grep -w "tty" $TEST_DIR/console.log + [[ $status -eq 0 ]] + run grep -w "console" $TEST_DIR/console.log + [[ $status -ne 0 ]] + run grep -w "console" $TEST_DIR/console.log + [[ $status -ne 0 ]] + + rm -f $TEST_DIR/console.log + ozonec exec -t --console-socket "$CONSOLE_PATH" "$CONTAINER_ID" -- ls -al /dev/pts/ptmx + run grep -w "/dev/pts/ptmx" $TEST_DIR/console.log + [[ $status -ne 0 ]] +} + +@test "ozonec create /dev/console" { + update_config '.process.terminal = true' + update_config '.process.args = ["ls", "/dev/console"]' + ozonec create --console-socket "$CONSOLE_PATH" "$CONTAINER_ID" 3>&- + ozonec start "$CONTAINER_ID" + + cat $TEST_DIR/console.log + run sed -n '1p' $TEST_DIR/console.log + [[ ${lines[0]} == *"/dev/console"* ]] +} \ No newline at end of file diff --git a/ozonec/tests/namespace.bats b/ozonec/tests/namespace.bats new file mode 100644 index 0000000000000000000000000000000000000000..e4e22b4db1825075fdc802d6e51967f14768587e --- /dev/null +++ b/ozonec/tests/namespace.bats @@ -0,0 +1,51 @@ +#! /usr/bin/env bats + +load helpers + +setup_file() +{ + setup_bundle +} + +setup() +{ + CONTAINER_ID=$(uuidgen) +} + +teardown() +{ + ozonec kill "$CONTAINER_ID" 9 + ozonec delete "$CONTAINER_ID" +} + +@test "ozonec create new namespace" { + ozonec create -p $TEST_DIR/ozonec.pid "$CONTAINER_ID" 3>&- + ozonec start "$CONTAINER_ID" + + local self_mnt_ns=$(readlink /proc/self/ns/mnt) + local container_pid=$(cat $TEST_DIR/ozonec.pid) + local container_mnt_ns=$(readlink /proc/$container_pid/ns/mnt) + [[ "$self_mnt_ns" != "$container_mnt_ns" ]] +} + +@test "ozonec join existed namespace" { + ozonec create -p $TEST_DIR/fst.pid "$CONTAINER_ID" 3>&- + ozonec start "$CONTAINER_ID" + local fst_container_pid=$(cat $TEST_DIR/fst.pid) + local fst_pid_ns=$(readlink /proc/$fst_container_pid/ns/pid) + + update_config '.linux.namespaces |= [{ + type: "pid", + path: "'/proc/$fst_container_pid/ns/pid'" + }, { + type: "mount" + }]' + local sec_container_id=$(uuidgen) + ozonec create -p $TEST_DIR/second.pid "$sec_container_id" 3>&- + ozonec start "$sec_container_id" + local sec_container_pid=$(cat $TEST_DIR/second.pid) + local sec_pid_ns=$(readlink /proc/$sec_container_pid/ns/pid) + ozonec kill "$sec_container_id" 9 + ozonec delete "$sec_container_id" + [[ "$fst_pid_ns" == "$sec_pid_ns" ]] +} \ No newline at end of file diff --git a/ozonec/tests/tools/pty-server/Cargo.lock b/ozonec/tests/tools/pty-server/Cargo.lock new file mode 100644 index 0000000000000000000000000000000000000000..68cf35c39cffa3711b977cc58ba0280557a89b00 --- /dev/null +++ b/ozonec/tests/tools/pty-server/Cargo.lock @@ -0,0 +1,194 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "anyhow" +version = "1.0.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "clap" +version = "4.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f13b9c79b5d1dd500d20ef541215a6423c75829ef43117e1b4d17fd8af0b5d76" +dependencies = [ + "bitflags", + "clap_derive", + "clap_lex", + "once_cell", +] + +[[package]] +name = "clap_derive" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "684a277d672e91966334af371f1a7b5833f9aa00b07c84e92fbce95e00208ce8" +dependencies = [ + "heck", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "033f6b7a4acb1f358c742aaca805c939ee73b4c6209ae4318ec7aca81c42e646" +dependencies = [ + "os_str_bytes", +] + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "libc" +version = "0.2.159" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" + +[[package]] +name = "memoffset" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" +dependencies = [ + "autocfg", +] + +[[package]] +name = "nix" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" +dependencies = [ + "bitflags", + "cfg-if", + "libc", + "memoffset", + "pin-utils", + "static_assertions", +] + +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + +[[package]] +name = "os_str_bytes" +version = "6.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + +[[package]] +name = "proc-macro2" +version = "1.0.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pty-server" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "nix", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" diff --git a/ozonec/tests/tools/pty-server/Cargo.toml b/ozonec/tests/tools/pty-server/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..93889af0cf1c07f64935aa0049b2cf016f31edf8 --- /dev/null +++ b/ozonec/tests/tools/pty-server/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "pty-server" +version = "0.1.0" +authors = ["Huawei StratoVirt Team"] +edition = "2021" +license = "Mulan PSL v2" +description = "A reference implementation of a consumer of ozonec's --console-socket API." + +[dependencies] +anyhow = "= 1.0.71" +clap = { version = "= 4.1.4", default-features = false, features = ["derive", "cargo", "std", "help", "usage"] } +nix = "= 0.26.2" + +[workspace] + +[profile.dev] +panic = "unwind" + +[profile.release] +lto = true +strip = true +opt-level = 'z' +codegen-units = 1 +panic = "abort" diff --git a/ozonec/tests/tools/pty-server/src/main.rs b/ozonec/tests/tools/pty-server/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..533dfb46296024321ccb98f67f62e45399bef825 --- /dev/null +++ b/ozonec/tests/tools/pty-server/src/main.rs @@ -0,0 +1,123 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::{ + fs::File, + io::{self, stdin, stdout, IoSliceMut}, + os::{ + fd::{AsRawFd, FromRawFd, RawFd}, + unix::net::{UnixListener, UnixStream}, + }, + process::exit, + thread, +}; + +use anyhow::{anyhow, bail, Context, Result}; +use clap::{builder::NonEmptyStringValueParser, crate_description, Parser}; +use nix::{ + cmsg_space, + errno::errno, + sys::{ + socket::{recvmsg, ControlMessageOwned, MsgFlags, UnixAddr}, + termios::{tcgetattr, tcsetattr, OutputFlags, SetArg}, + }, +}; + +#[derive(Parser, Debug)] +#[command(version, author, about = crate_description!())] +struct Cli { + #[arg(short, long)] + pub no_stdin: bool, + // Specify path of console socket to connect. + #[arg(value_parser = NonEmptyStringValueParser::new(), required = true)] + pub console_socket: String, +} + +fn clear_onlcr(fd: RawFd) -> Result<()> { + let mut termios = + tcgetattr(fd).with_context(|| anyhow!("tcgetattr error: errno {}, fd: {}", errno(), fd))?; + termios.output_flags &= !OutputFlags::ONLCR; + tcsetattr(fd, SetArg::TCSANOW, &termios) + .with_context(|| anyhow!("tcsetattr error: errno {}", errno()))?; + Ok(()) +} + +fn handle_connection(stream: &UnixStream, no_stdin: bool) -> Result<()> { + let mut msg_iov = Vec::with_capacity(10); + let mut iov = [IoSliceMut::new(msg_iov.as_mut_slice())]; + let mut cmsg_buffer = cmsg_space!([RawFd; 1]); + let mut master: RawFd = -1; + + let ret = recvmsg::( + stream.as_raw_fd(), + &mut iov, + Some(&mut cmsg_buffer), + MsgFlags::empty(), + ) + .with_context(|| "recvmsg error")?; + for ctl_msg in ret.cmsgs() { + match ctl_msg { + ControlMessageOwned::ScmRights(fds) => master = fds[0], + _ => (), + } + } + + clear_onlcr(master)?; + let output = thread::spawn(move || { + let mut us = unsafe { File::from_raw_fd(master) }; + io::copy(&mut us, &mut stdout()) + }); + if !no_stdin { + let input = thread::spawn(move || { + let mut us = unsafe { File::from_raw_fd(master) }; + io::copy(&mut stdin(), &mut us) + }); + if let Err(e) = input.join().expect("Input thread has exited.") { + eprintln!("Input thread error: {}", e); + } + } + if let Err(e) = output.join().expect("Output thread has exited.") { + eprintln!("Output thread error: {}", e); + } + + Ok(()) +} + +fn listen_on_socket(listener: &UnixListener, no_stdin: bool) -> Result<()> { + for stream in listener.incoming() { + match stream { + Ok(s) => handle_connection(&s, no_stdin)?, + Err(e) => bail!("Failed to accept incoming connection: {}", e), + } + } + Ok(()) +} + +fn real_main() -> Result<()> { + let cli = Cli::parse(); + + let listener = + UnixListener::bind(&cli.console_socket).with_context(|| "Failed to bind to the socket")?; + listen_on_socket(&listener, cli.no_stdin)?; + + Ok(()) +} + +fn main() { + match real_main() { + Ok(_) => exit(0), + Err(e) => { + eprintln!("{}", e); + exit(1) + } + } +} diff --git a/src/main.rs b/src/main.rs index 89c4e23e46671413686f7735a51e25a47acc3ac4..de3fd8e69ca6d8d72fcbd4a2598871eba2f220ca 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,7 +18,7 @@ use anyhow::{bail, Context, Result}; use log::{error, info}; use thiserror::Error; -use machine::{LightMachine, MachineOps, StdMachine}; +use machine::{type_init, LightMachine, MachineOps, StdMachine}; use machine_manager::{ cmdline::{check_api_channel, create_args_parser, create_vmconfig}, config::MachineType, @@ -71,6 +71,8 @@ fn main() -> ExitCode { } fn run() -> Result<()> { + type_init()?; + let cmd_args = create_args_parser().get_matches()?; if cmd_args.is_present("mod-test") { @@ -101,7 +103,13 @@ fn run() -> Result<()> { exit_with_code(VM_EXIT_GENE_ERR); })); - let mut vm_config: VmConfig = create_vmconfig(&cmd_args)?; + let mut vm_config: VmConfig = match create_vmconfig(&cmd_args) { + Ok(vm_cfg) => vm_cfg, + Err(e) => { + error!("Failed to create vmconfig {:?}", e); + return Err(e); + } + }; info!("VmConfig is {:?}", vm_config); match real_main(&cmd_args, &mut vm_config) { @@ -109,18 +117,15 @@ fn run() -> Result<()> { info!("MainLoop over, Vm exit"); // clean temporary file TempCleaner::clean(); + EventLoop::loop_clean(); handle_signal(); } Err(ref e) => { set_termi_canon_mode().expect("Failed to set terminal to canonical mode."); - if cmd_args.is_present("display log") { - error!("{}", format!("{:?}\r\n", e)); - } else { - write!(&mut std::io::stderr(), "{}", format_args!("{:?}\r\n", e)) - .expect("Failed to write to stderr"); - } + error!("{}", format!("{:?}\r\n", e)); // clean temporary file TempCleaner::clean(); + EventLoop::loop_clean(); exit_with_code(VM_EXIT_GENE_ERR); } } @@ -160,7 +165,7 @@ fn real_main(cmd_args: &arg_parser::ArgMatches, vm_config: &mut VmConfig) -> Res LightMachine::new(vm_config).with_context(|| "Failed to init MicroVM")?, )); MachineOps::realize(&vm, vm_config).with_context(|| "Failed to realize micro VM.")?; - EventLoop::set_manager(vm.clone(), None); + EventLoop::set_manager(vm.clone()); for listener in listeners { sockets.push(Socket::from_listener(listener, Some(vm.clone()))); @@ -173,15 +178,13 @@ fn real_main(cmd_args: &arg_parser::ArgMatches, vm_config: &mut VmConfig) -> Res )); MachineOps::realize(&vm, vm_config) .with_context(|| "Failed to realize standard VM.")?; - EventLoop::set_manager(vm.clone(), None); + EventLoop::set_manager(vm.clone()); if is_test_enabled() { let sock_path = cmd_args.value_of("mod-test"); - let test_sock = Some(TestSock::new(sock_path.unwrap().as_str(), vm.clone())); + let test_sock = TestSock::new(sock_path.unwrap().as_str(), vm.clone()); EventLoop::update_event( - EventNotifierHelper::internal_notifiers(Arc::new(Mutex::new( - test_sock.unwrap(), - ))), + EventNotifierHelper::internal_notifiers(Arc::new(Mutex::new(test_sock))), None, ) .with_context(|| "Failed to add test socket to MainLoop")?; @@ -199,7 +202,7 @@ fn real_main(cmd_args: &arg_parser::ArgMatches, vm_config: &mut VmConfig) -> Res let vm = Arc::new(Mutex::new( StdMachine::new(vm_config).with_context(|| "Failed to init NoneVM")?, )); - EventLoop::set_manager(vm.clone(), None); + EventLoop::set_manager(vm.clone()); for listener in listeners { sockets.push(Socket::from_listener(listener, Some(vm.clone()))); @@ -208,7 +211,7 @@ fn real_main(cmd_args: &arg_parser::ArgMatches, vm_config: &mut VmConfig) -> Res } }; - let balloon_switch_on = vm_config.dev_name.get("balloon").is_some(); + let balloon_switch_on = vm_config.dev_name.contains_key("balloon"); if !cmd_args.is_present("disable-seccomp") { vm.lock() .unwrap() @@ -227,6 +230,5 @@ fn real_main(cmd_args: &arg_parser::ArgMatches, vm_config: &mut VmConfig) -> Res machine::vm_run(&vm, cmd_args).with_context(|| "Failed to start VM.")?; EventLoop::loop_run().with_context(|| "MainLoop exits unexpectedly: error occurs")?; - EventLoop::loop_clean(); Ok(()) } diff --git a/tests/mod_test/Cargo.toml b/tests/mod_test/Cargo.toml index 9ccebc0581651c59dc770d62a91f5996efe10fdd..9b144af1d1723862e1465ed46f7e0b67c7526537 100644 --- a/tests/mod_test/Cargo.toml +++ b/tests/mod_test/Cargo.toml @@ -8,7 +8,7 @@ license = "Mulan PSL v2" [dependencies] rand = "0.8.5" hex = "0.4.3" -vmm-sys-util = "0.11.1" +vmm-sys-util = "0.12.1" anyhow = "1.0" serde_json = "1.0" libc = "0.2" diff --git a/tests/mod_test/src/libdriver/ivshmem.rs b/tests/mod_test/src/libdriver/ivshmem.rs index 03d76629598f6b7c79fc489e7ac622b121abafb0..edb5ef6206fd45c12a91747be854d191a35e9f69 100644 --- a/tests/mod_test/src/libdriver/ivshmem.rs +++ b/tests/mod_test/src/libdriver/ivshmem.rs @@ -19,16 +19,18 @@ use super::{ pub struct TestIvshmemDev { pub pci_dev: TestPciDev, - pub bar_addr: PCIBarAddr, - bar_idx: u8, + bar0_addr: PCIBarAddr, + bar1_addr: PCIBarAddr, + pub bar2_addr: PCIBarAddr, } impl TestIvshmemDev { pub fn new(pci_bus: Rc>) -> Self { Self { pci_dev: TestPciDev::new(pci_bus), - bar_addr: 0, - bar_idx: 2, + bar0_addr: 0, + bar1_addr: 0, + bar2_addr: 0, } } @@ -37,30 +39,40 @@ impl TestIvshmemDev { assert!(self.pci_dev.find_pci_device(devfn)); self.pci_dev.enable(); - self.bar_addr = self.pci_dev.io_map(self.bar_idx); + self.bar0_addr = self.pci_dev.io_map(0); + self.bar1_addr = self.pci_dev.io_map(1); + self.bar2_addr = self.pci_dev.io_map(2); } pub fn writeb(&mut self, offset: u64, value: u8) { - self.pci_dev.io_writeb(self.bar_addr, offset, value); + self.pci_dev.io_writeb(self.bar2_addr, offset, value); } pub fn writew(&mut self, offset: u64, value: u16) { - self.pci_dev.io_writew(self.bar_addr, offset, value); + self.pci_dev.io_writew(self.bar2_addr, offset, value); } pub fn writel(&mut self, offset: u64, value: u32) { - self.pci_dev.io_writel(self.bar_addr, offset, value); + self.pci_dev.io_writel(self.bar2_addr, offset, value); } pub fn writeq(&mut self, offset: u64, value: u64) { - self.pci_dev.io_writeq(self.bar_addr, offset, value); + self.pci_dev.io_writeq(self.bar2_addr, offset, value); } pub fn readw(&self, offset: u64) -> u16 { - self.pci_dev.io_readw(self.bar_addr, offset) + self.pci_dev.io_readw(self.bar2_addr, offset) } pub fn readl(&self, offset: u64) -> u32 { - self.pci_dev.io_readl(self.bar_addr, offset) + self.pci_dev.io_readl(self.bar2_addr, offset) + } + + pub fn writel_reg(&self, offset: u64, value: u32) { + self.pci_dev.io_writel(self.bar0_addr, offset, value); + } + + pub fn readl_reg(&self, offset: u64) -> u32 { + self.pci_dev.io_readl(self.bar0_addr, offset) } } diff --git a/tests/mod_test/src/libdriver/pci.rs b/tests/mod_test/src/libdriver/pci.rs index e78f7d67ab66f0bea0c9d3cfd4480f46515a1b87..1fc0ad10cb4b2f25aff47fe4df7ad5118353b0b8 100644 --- a/tests/mod_test/src/libdriver/pci.rs +++ b/tests/mod_test/src/libdriver/pci.rs @@ -133,13 +133,13 @@ impl TestPciDev { pub fn enable(&self) { let mut cmd = self.config_readw(PCI_COMMAND); - cmd |= (PCI_COMMAND_IO | PCI_COMMAND_MEMORY | PCI_COMMAND_MASTER) as u16; + cmd |= u16::from(PCI_COMMAND_IO | PCI_COMMAND_MEMORY | PCI_COMMAND_MASTER); self.config_writew(PCI_COMMAND, cmd); cmd = self.config_readw(PCI_COMMAND); - assert!(cmd & PCI_COMMAND_IO as u16 == PCI_COMMAND_IO as u16); - assert!(cmd & PCI_COMMAND_MEMORY as u16 == PCI_COMMAND_MEMORY as u16); - assert!(cmd & PCI_COMMAND_MASTER as u16 == PCI_COMMAND_MASTER as u16); + assert!(cmd & u16::from(PCI_COMMAND_IO) == u16::from(PCI_COMMAND_IO)); + assert!(cmd & u16::from(PCI_COMMAND_MEMORY) == u16::from(PCI_COMMAND_MEMORY)); + assert!(cmd & u16::from(PCI_COMMAND_MASTER) == u16::from(PCI_COMMAND_MASTER)); } pub fn find_capability(&self, id: u8, start_addr: u8) -> u8 { @@ -200,7 +200,7 @@ impl TestPciDev { } else { self.io_map(bar_table as u8) }; - self.msix_table_off = (table & !PCI_MSIX_TABLE_BIR) as u64; + self.msix_table_off = u64::from(table & !PCI_MSIX_TABLE_BIR); let table = self.config_readl(addr + PCI_MSIX_PBA); let bar_pba = table & PCI_MSIX_TABLE_BIR; @@ -209,7 +209,7 @@ impl TestPciDev { } else { self.msix_pba_bar = self.msix_table_bar; } - self.msix_pba_off = (table & !PCI_MSIX_TABLE_BIR) as u64; + self.msix_pba_off = u64::from(table & !PCI_MSIX_TABLE_BIR); self.msix_enabled = true; } @@ -334,21 +334,16 @@ impl TestPciDev { } pub fn io_map(&self, barnum: u8) -> u64 { - let addr: u32; - let size: u64; - let location: u64; - let bar_addr: PCIBarAddr; - assert!(barnum <= 5); let bar_offset: u8 = BAR_MAP[barnum as usize]; self.config_writel(bar_offset, 0xFFFFFFFF); - addr = self.config_readl(bar_offset) & !(0x0F_u32); + let addr: u32 = self.config_readl(bar_offset) & !(0x0F_u32); assert!(addr != 0); let mut pci_bus = self.pci_bus.borrow_mut(); - size = 1 << addr.trailing_zeros(); - location = (pci_bus.mmio_alloc_ptr + size - 1) / size * size; + let size: u64 = 1 << addr.trailing_zeros(); + let location: u64 = (pci_bus.mmio_alloc_ptr + size - 1) / size * size; if location < pci_bus.mmio_alloc_ptr || location + size > pci_bus.mmio_limit { return INVALID_BAR_ADDR; } @@ -356,7 +351,7 @@ impl TestPciDev { pci_bus.mmio_alloc_ptr = location + size; drop(pci_bus); self.config_writel(bar_offset, location as u32); - bar_addr = location; + let bar_addr: PCIBarAddr = location; bar_addr } @@ -413,7 +408,7 @@ impl TestPciDev { impl PciMsixOps for TestPciDev { fn set_msix_vector(&self, msix_entry: u16, msix_addr: u64, msix_data: u32) { assert!(self.msix_enabled); - let offset = self.msix_table_off + (msix_entry * 16) as u64; + let offset = self.msix_table_off + u64::from(msix_entry * 16); let msix_table_bar = self.msix_table_bar; self.io_writel( diff --git a/tests/mod_test/src/libdriver/pci_bus.rs b/tests/mod_test/src/libdriver/pci_bus.rs index 1a146ebf29717959d8f265ac7ca75d8299196f21..db0889db6f304af5eb26b8fdf5c3c347c4dbb991 100644 --- a/tests/mod_test/src/libdriver/pci_bus.rs +++ b/tests/mod_test/src/libdriver/pci_bus.rs @@ -64,7 +64,8 @@ impl TestPciBus { } fn get_addr(&self, bus_num: u8, devfn: u8, offset: u8) -> u64 { - self.ecam_alloc_ptr + ((bus_num as u32) << 20 | (devfn as u32) << 12 | offset as u32) as u64 + self.ecam_alloc_ptr + + u64::from(u32::from(bus_num) << 20 | u32::from(devfn) << 12 | u32::from(offset)) } pub fn pci_auto_bus_scan(&self, root_port_num: u8) { @@ -106,11 +107,13 @@ impl TestPciBus { impl PciBusOps for TestPciBus { fn memread(&self, addr: u32, len: usize) -> Vec { - self.test_state.borrow().memread(addr as u64, len as u64) + self.test_state + .borrow() + .memread(u64::from(addr), len as u64) } fn memwrite(&self, addr: u32, buf: &[u8]) { - self.test_state.borrow().memwrite(addr as u64, buf); + self.test_state.borrow().memwrite(u64::from(addr), buf); } fn config_readb(&self, bus_num: u8, devfn: u8, offset: u8) -> u8 { diff --git a/tests/mod_test/src/libdriver/qcow2.rs b/tests/mod_test/src/libdriver/qcow2.rs index ca07a6093f4cf0cf3fcbc77bc5d9fe5269d7c5a8..fb6ba7685358ef4f5e3c4d0259332bb1c41eead6 100644 --- a/tests/mod_test/src/libdriver/qcow2.rs +++ b/tests/mod_test/src/libdriver/qcow2.rs @@ -66,21 +66,21 @@ impl Qcow2Driver { fn raw_read(&self, offset: u64, buf: &mut [u8]) -> i64 { let ptr = buf.as_mut_ptr() as u64; let cnt = buf.len() as u64; - let iovec = vec![Iovec::new(ptr, cnt)]; - let ret = unsafe { + let iovec = [Iovec::new(ptr, cnt)]; + + unsafe { preadv( self.file.as_raw_fd() as c_int, iovec.as_ptr() as *const iovec, iovec.len() as c_int, offset as off_t, ) as i64 - }; - ret + } } fn raw_write(&mut self, offset: u64, buf: &mut [u8]) { self.file.seek(SeekFrom::Start(offset)).unwrap(); - self.file.write_all(&buf).unwrap(); + self.file.write_all(buf).unwrap(); } } @@ -192,12 +192,7 @@ impl QcowHeader { // From size to bits. fn size_to_bits(size: u64) -> Option { - for i in 0..63 { - if size >> i == 1 { - return Some(i); - } - } - return None; + (0..63).find(|&i| size >> i == 1) } /// Create a qcow2 format image for test. @@ -236,15 +231,15 @@ pub fn create_qcow2_img(image_path: String, image_size: u64) { .custom_flags(libc::O_CREAT | libc::O_TRUNC) .open(image_path.clone()) .unwrap(); - file.set_len(cluster_sz * 3 + header.l1_size as u64 * ENTRY_SIZE) + file.set_len(cluster_sz * 3 + u64::from(header.l1_size) * ENTRY_SIZE) .unwrap(); file.write_all(&header.to_vec()).unwrap(); // Cluster 1 is the refcount table. - assert_eq!(header.refcount_table_offset, cluster_sz * 1); + assert_eq!(header.refcount_table_offset, cluster_sz); let mut refcount_table = [0_u8; ENTRY_SIZE as usize]; BigEndian::write_u64(&mut refcount_table, cluster_sz * 2); - file.seek(SeekFrom::Start(cluster_sz * 1)).unwrap(); + file.seek(SeekFrom::Start(cluster_sz)).unwrap(); file.write_all(&refcount_table).unwrap(); // Clusters which has been allocated. @@ -281,7 +276,7 @@ fn write_full_disk(image_path: String) { // Write l2 table. let mut refcount_block: Vec = Vec::new(); let mut l1_table = [0_u8; ENTRY_SIZE as usize]; - BigEndian::write_u64(&mut l1_table, cluster_size * 4 | QCOW2_OFFSET_COPIED); + BigEndian::write_u64(&mut l1_table, (cluster_size * 4) | QCOW2_OFFSET_COPIED); let mut l2_table: Vec = Vec::new(); for _ in 0..5 { refcount_block.push(0x00); @@ -318,7 +313,7 @@ pub fn delete_snapshot(state: Rc>, device: &str, snap: &str) pub fn query_snapshot(state: Rc>) -> Value { let qmp_str = - format!("{{\"execute\":\"human-monitor-command\",\"arguments\":{{\"command-line\":\"info snapshots\"}}}}"); + "{\"execute\":\"human-monitor-command\",\"arguments\":{\"command-line\":\"info snapshots\"}}".to_string(); let value = state.borrow_mut().qmp(&qmp_str); value @@ -326,7 +321,7 @@ pub fn query_snapshot(state: Rc>) -> Value { // Check if there exists snapshot with the specified name. pub fn check_snapshot(state: Rc>, snap: &str) -> bool { - let value = query_snapshot(state.clone()); + let value = query_snapshot(state); let str = (*value.get("return").unwrap()).as_str().unwrap(); let lines: Vec<&str> = str.split("\r\n").collect(); for line in lines { diff --git a/tests/mod_test/src/libdriver/usb.rs b/tests/mod_test/src/libdriver/usb.rs index 0a03dd213c273a75c8f06dc292b655504acf4695..c3633b48be864f6818f6051a6c8c1fbfdf118355 100644 --- a/tests/mod_test/src/libdriver/usb.rs +++ b/tests/mod_test/src/libdriver/usb.rs @@ -174,11 +174,11 @@ impl TestNormalTRB { pub fn generate_setup_td(device_req: &UsbDeviceRequest) -> TestNormalTRB { let mut setup_trb = TestNormalTRB::default(); - setup_trb.parameter = (device_req.length as u64) << 48 - | (device_req.index as u64) << 32 - | (device_req.value as u64) << 16 - | (device_req.request as u64) << 8 - | device_req.request_type as u64; + setup_trb.parameter = u64::from(device_req.length) << 48 + | u64::from(device_req.index) << 32 + | u64::from(device_req.value) << 16 + | u64::from(device_req.request) << 8 + | u64::from(device_req.request_type); setup_trb.set_idt_flag(true); setup_trb.set_ch_flag(true); setup_trb.set_trb_type(TRBType::TrSetup as u32); @@ -193,7 +193,7 @@ impl TestNormalTRB { data_trb.set_ch_flag(true); data_trb.set_dir_flag(in_dir); data_trb.set_trb_type(TRBType::TrData as u32); - data_trb.set_trb_transfer_length(len as u32); + data_trb.set_trb_transfer_length(u32::from(len)); data_trb } @@ -395,6 +395,12 @@ pub struct TestEventRingSegment { pub reserved: u32, } +impl Default for TestEventRingSegment { + fn default() -> Self { + Self::new() + } +} + impl TestEventRingSegment { pub fn new() -> Self { Self { @@ -503,21 +509,21 @@ impl TestXhciPciDevice { pub fn run(&mut self) { let status = self.pci_dev.io_readl( self.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_USBSTS as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_USBSTS, ); assert!(status & USB_STS_HCH == USB_STS_HCH); let cmd = self.pci_dev.io_readl( self.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_USBCMD as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_USBCMD, ); self.pci_dev.io_writel( self.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_USBCMD as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_USBCMD, cmd | USB_CMD_RUN, ); let status = self.pci_dev.io_readl( self.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_USBSTS as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_USBSTS, ); assert!(status & USB_STS_HCH != USB_STS_HCH); } @@ -589,7 +595,7 @@ impl TestXhciPciDevice { self.doorbell_write(slot_id, CONTROL_ENDPOINT_ID); let evt = self.fetch_event(PRIMARY_INTERRUPTER_ID).unwrap(); assert_eq!(evt.ccode, TRBCCode::ShortPacket as u32); - let buf = self.get_transfer_data_indirect(evt.ptr - TRB_SIZE as u64, 1); + let buf = self.get_transfer_data_indirect(evt.ptr - u64::from(TRB_SIZE), 1); assert_eq!(buf[0], 0); // configure endpoint self.configure_endpoint(slot_id, false); @@ -627,7 +633,7 @@ impl TestXhciPciDevice { pub fn reset_controller(&mut self, auto_run: bool) { // reset xhci self.oper_regs_write(0, USB_CMD_HCRST); - let status = self.oper_regs_read(XHCI_OPER_REG_USBSTS as u64); + let status = self.oper_regs_read(XHCI_OPER_REG_USBSTS); assert!(status & USB_STS_HCE != USB_STS_HCE); if auto_run { self.init_host_controller(XHCI_PCI_SLOT_NUM, XHCI_PCI_FUN_NUM); @@ -643,18 +649,21 @@ impl TestXhciPciDevice { pub fn oper_regs_read(&self, offset: u64) -> u32 { self.pci_dev - .io_readl(self.bar_addr, XHCI_PCI_OPER_OFFSET as u64 + offset) + .io_readl(self.bar_addr, u64::from(XHCI_PCI_OPER_OFFSET) + offset) } pub fn oper_regs_write(&mut self, offset: u64, value: u32) { - self.pci_dev - .io_writel(self.bar_addr, XHCI_PCI_OPER_OFFSET as u64 + offset, value); + self.pci_dev.io_writel( + self.bar_addr, + u64::from(XHCI_PCI_OPER_OFFSET) + offset, + value, + ); } pub fn interrupter_regs_read(&self, intr_idx: u64, offset: u64) -> u32 { self.pci_dev.io_readl( self.bar_addr, - XHCI_PCI_RUNTIME_OFFSET as u64 + u64::from(XHCI_PCI_RUNTIME_OFFSET) + XHCI_INTR_REG_SIZE + intr_idx * XHCI_INTR_REG_SIZE + offset, @@ -664,7 +673,7 @@ impl TestXhciPciDevice { pub fn interrupter_regs_write(&mut self, intr_idx: u64, offset: u64, value: u32) { self.pci_dev.io_writel( self.bar_addr, - XHCI_PCI_RUNTIME_OFFSET as u64 + u64::from(XHCI_PCI_RUNTIME_OFFSET) + RUNTIME_REGS_INTERRUPT_OFFSET + intr_idx * XHCI_INTR_REG_SIZE + offset, @@ -675,7 +684,7 @@ impl TestXhciPciDevice { pub fn interrupter_regs_readq(&self, intr_idx: u64, offset: u64) -> u64 { self.pci_dev.io_readq( self.bar_addr, - XHCI_PCI_RUNTIME_OFFSET as u64 + u64::from(XHCI_PCI_RUNTIME_OFFSET) + XHCI_INTR_REG_SIZE + intr_idx * XHCI_INTR_REG_SIZE + offset, @@ -685,7 +694,7 @@ impl TestXhciPciDevice { pub fn interrupter_regs_writeq(&mut self, intr_idx: u64, offset: u64, value: u64) { self.pci_dev.io_writeq( self.bar_addr, - XHCI_PCI_RUNTIME_OFFSET as u64 + u64::from(XHCI_PCI_RUNTIME_OFFSET) + RUNTIME_REGS_INTERRUPT_OFFSET + intr_idx * XHCI_INTR_REG_SIZE + offset, @@ -696,14 +705,14 @@ impl TestXhciPciDevice { pub fn port_regs_read(&self, port_id: u32, offset: u64) -> u32 { self.pci_dev.io_readl( self.bar_addr, - (XHCI_PCI_PORT_OFFSET + XHCI_PCI_PORT_LENGTH * (port_id - 1) as u32) as u64 + offset, + u64::from(XHCI_PCI_PORT_OFFSET + XHCI_PCI_PORT_LENGTH * (port_id - 1)) + offset, ) } pub fn port_regs_write(&mut self, port_id: u32, offset: u64, value: u32) { self.pci_dev.io_writel( self.bar_addr, - (XHCI_PCI_PORT_OFFSET + XHCI_PCI_PORT_LENGTH * (port_id - 1) as u32) as u64 + offset, + u64::from(XHCI_PCI_PORT_OFFSET + XHCI_PCI_PORT_LENGTH * (port_id - 1)) + offset, value, ); } @@ -711,7 +720,7 @@ impl TestXhciPciDevice { pub fn doorbell_write(&mut self, slot_id: u32, target: u32) { self.pci_dev.io_writel( self.bar_addr, - XHCI_PCI_DOORBELL_OFFSET as u64 + (slot_id << 2) as u64, + u64::from(XHCI_PCI_DOORBELL_OFFSET) + u64::from(slot_id << 2), target, ); } @@ -741,84 +750,84 @@ impl TestXhciPciDevice { // Interface Version Number let cap = self .pci_dev - .io_readl(self.bar_addr, XHCI_PCI_CAP_OFFSET as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET)); assert!(cap & 0x01000000 == 0x01000000); // HCSPARAMS1 let hcsparams1 = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x4) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x4)); assert_eq!(hcsparams1 & 0xffffff, 0x000140); // HCSPARAMS2 let hcsparams2 = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x8) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x8)); assert_eq!(hcsparams2, 0xf); // HCSPARAMS3 let hcsparams3 = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0xc) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0xc)); assert_eq!(hcsparams3, 0); // HCCPARAMS1 let hccparams1 = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x10) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x10)); // AC64 = 1 assert_eq!(hccparams1 & 1, 1); // doorbell offset let db_offset = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x14) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x14)); assert_eq!(db_offset, 0x2000); // runtime offset let runtime_offset = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x18) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x18)); assert_eq!(runtime_offset, 0x1000); // HCCPARAMS2 let hccparams2 = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x1c) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x1c)); assert_eq!(hccparams2, 0); // USB 2.0 let usb2_version = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x20) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x20)); assert!(usb2_version & 0x02000000 == 0x02000000); let usb2_name = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x24) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x24)); assert_eq!(usb2_name, 0x20425355); let usb2_port = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x28) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x28)); let usb2_port_num = (usb2_port >> 8) & 0xff; // extend capability end let end = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x2c) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x2c)); assert_eq!(end, 0); // USB 3.0 let usb3_version = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x30) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x30)); assert!(usb3_version & 0x03000000 == 0x03000000); let usb3_name = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x34) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x34)); assert_eq!(usb3_name, 0x20425355); let usb3_port = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x38) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x38)); let usb3_port_num = (usb3_port >> 8) & 0xff; // extend capability end let end = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x3c) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x3c)); assert_eq!(end, 0); // Max ports let hcsparams1 = self .pci_dev - .io_readl(self.bar_addr, (XHCI_PCI_CAP_OFFSET + 0x4) as u64); + .io_readl(self.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET + 0x4)); assert_eq!(hcsparams1 >> 24, usb2_port_num + usb3_port_num); } @@ -827,36 +836,36 @@ impl TestXhciPciDevice { let enabled_slot = USB_CONFIG_MAX_SLOTS_ENABLED & USB_CONFIG_MAX_SLOTS_EN_MASK; self.pci_dev.io_writel( self.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_CONFIG as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_CONFIG, enabled_slot, ); let config = self.pci_dev.io_readl( self.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_CONFIG as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_CONFIG, ); assert_eq!(config, enabled_slot); } pub fn init_device_context_base_address_array_pointer(&mut self) { let dcba = DEVICE_CONTEXT_ENTRY_SIZE * (USB_CONFIG_MAX_SLOTS_ENABLED + 1); - let dcbaap = self.allocator.borrow_mut().alloc(dcba as u64); + let dcbaap = self.allocator.borrow_mut().alloc(u64::from(dcba)); self.pci_dev.io_writeq( self.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_DCBAAP as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_DCBAAP, dcbaap, ); let value = self.pci_dev.io_readq( self.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_DCBAAP as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_DCBAAP, ); assert_eq!(value, dcbaap); self.xhci.dcbaap = value; } pub fn init_command_ring_dequeue_pointer(&mut self) { - let cmd_ring_sz = TRB_SIZE as u64 * COMMAND_RING_LEN; + let cmd_ring_sz = u64::from(TRB_SIZE) * COMMAND_RING_LEN; let cmd_ring = self.allocator.borrow_mut().alloc(cmd_ring_sz); self.pci_dev .pci_bus @@ -867,13 +876,13 @@ impl TestXhciPciDevice { self.xhci.cmd_ring.init(cmd_ring, cmd_ring_sz); self.pci_dev.io_writeq( self.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_CMD_RING_CTRL as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_CMD_RING_CTRL, cmd_ring, ); // Read dequeue pointer return 0. let cmd_ring = self.pci_dev.io_readq( self.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_CMD_RING_CTRL as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_CMD_RING_CTRL, ); assert_eq!(cmd_ring, 0); } @@ -903,13 +912,9 @@ impl TestXhciPciDevice { pub fn reset_port(&mut self, port_id: u32) { assert!(port_id > 0); - let port_offset = - (XHCI_PCI_PORT_OFFSET + XHCI_PCI_PORT_LENGTH * (port_id - 1) as u32) as u64; - self.pci_dev.io_writel( - self.bar_addr, - port_offset + XHCI_PORTSC_OFFSET, - PORTSC_PR as u32, - ); + let port_offset = u64::from(XHCI_PCI_PORT_OFFSET + XHCI_PCI_PORT_LENGTH * (port_id - 1)); + self.pci_dev + .io_writel(self.bar_addr, port_offset + XHCI_PORTSC_OFFSET, PORTSC_PR); self.oper_regs_write(XHCI_OPER_REG_USBSTS, USB_STS_PCD); let status = self.oper_regs_read(XHCI_OPER_REG_USBSTS); assert!(status & USB_STS_PCD != USB_STS_PCD); @@ -954,14 +959,14 @@ impl TestXhciPciDevice { let ep0_tr_ring = self .allocator .borrow_mut() - .alloc(TRB_SIZE as u64 * TRANSFER_RING_LEN); + .alloc(u64::from(TRB_SIZE) * TRANSFER_RING_LEN); ep0_ctx.set_tr_dequeue_pointer(ep0_tr_ring | 1); ep0_ctx.set_ep_state(0); ep0_ctx.set_ep_type(4); self.mem_write_u32(input_ctx_addr + 0x40, ep0_ctx.as_dwords()); self.xhci.device_slot[slot_id as usize].endpoints[(CONTROL_ENDPOINT_ID - 1) as usize] .transfer_ring - .init(ep0_tr_ring, TRB_SIZE as u64 * TRANSFER_RING_LEN); + .init(ep0_tr_ring, u64::from(TRB_SIZE) * TRANSFER_RING_LEN); let mut trb = TestNormalTRB::default(); trb.parameter = input_ctx_addr; @@ -1030,7 +1035,7 @@ impl TestXhciPciDevice { { TD_TRB_LIMIT } else { - TRB_SIZE as u64 * TRANSFER_RING_LEN + u64::from(TRB_SIZE) * TRANSFER_RING_LEN }; for i in 0..endpoint_id.len() { @@ -1137,7 +1142,7 @@ impl TestXhciPciDevice { self.interrupter_regs_writeq( intr_idx as u64, XHCI_INTR_REG_ERDP_LO, - self.xhci.interrupter[intr_idx].er_pointer | ERDP_EHB as u64, + self.xhci.interrupter[intr_idx].er_pointer | u64::from(ERDP_EHB), ); self.event_list.push_back(event); } else { @@ -1153,10 +1158,13 @@ impl TestXhciPciDevice { pub fn queue_device_request(&mut self, slot_id: u32, device_req: &UsbDeviceRequest) -> u64 { // Setup Stage. - let mut setup_trb = TestNormalTRB::generate_setup_td(&device_req); + let mut setup_trb = TestNormalTRB::generate_setup_td(device_req); self.queue_trb(slot_id, CONTROL_ENDPOINT_ID, &mut setup_trb); // Data Stage. - let ptr = self.allocator.borrow_mut().alloc(device_req.length as u64); + let ptr = self + .allocator + .borrow_mut() + .alloc(u64::from(device_req.length)); let in_dir = device_req.request_type & USB_DIRECTION_DEVICE_TO_HOST == USB_DIRECTION_DEVICE_TO_HOST; let mut data_trb = TestNormalTRB::generate_data_td(ptr, device_req.length, in_dir); @@ -1252,16 +1260,15 @@ impl TestXhciPciDevice { // Read data from parameter directly. pub fn get_transfer_data_direct(&self, addr: u64, len: u64) -> Vec { - let buf = self.mem_read(addr, len as usize); - buf + self.mem_read(addr, len as usize) } // Read data from parameter as address. pub fn get_transfer_data_indirect(&self, addr: u64, len: u64) -> Vec { let buf = self.mem_read(addr, 8); let mem = LittleEndian::read_u64(&buf); - let buf = self.mem_read(mem, len as usize); - buf + + self.mem_read(mem, len as usize) } pub fn get_transfer_data_indirect_with_offset( @@ -1272,8 +1279,8 @@ impl TestXhciPciDevice { ) -> Vec { let buf = self.mem_read(addr, 8); let mem = LittleEndian::read_u64(&buf); - let buf = self.mem_read(mem + offset, len); - buf + + self.mem_read(mem + offset, len) } pub fn get_command_pointer(&self) -> u64 { @@ -1311,7 +1318,7 @@ impl TestXhciPciDevice { let output_ctx_addr = self.get_device_context_address(slot_id); let mut ep_ctx = XhciEpCtx::default(); self.mem_read_u32( - output_ctx_addr + 0x20 * ep_id as u64, + output_ctx_addr + 0x20 * u64::from(ep_id), ep_ctx.as_mut_dwords(), ); ep_ctx @@ -1339,7 +1346,7 @@ impl TestXhciPciDevice { trb.set_cycle_bit(self.get_cycle_bit(slot_id, ep_id)); } let en_ptr = self.get_transfer_pointer(slot_id, ep_id); - self.write_trb(en_ptr, &trb); + self.write_trb(en_ptr, trb); self.increase_transfer_ring(slot_id, ep_id, 1); } @@ -1383,11 +1390,11 @@ impl TestXhciPciDevice { assert_eq!(data, erstsz); // ERSTBA let table_size = EVENT_RING_SEGMENT_TABLE_ENTRY_SIZE * erstsz; - let evt_ring_seg_table = self.allocator.borrow_mut().alloc(table_size as u64); + let evt_ring_seg_table = self.allocator.borrow_mut().alloc(u64::from(table_size)); self.xhci.interrupter[intr_idx].erstba = evt_ring_seg_table; // NOTE: Only support one Segment now. let mut seg = TestEventRingSegment::new(); - let evt_ring_sz = (TRB_SIZE * ersz) as u64; + let evt_ring_sz = u64::from(TRB_SIZE * ersz); let evt_ring = self.allocator.borrow_mut().alloc(evt_ring_sz); seg.init(evt_ring, ersz); self.pci_dev @@ -1420,7 +1427,7 @@ impl TestXhciPciDevice { assert_eq!(data, self.get_event_pointer(intr_idx)); // enable USB_CMD_INTE - let value = self.oper_regs_read(XHCI_OPER_REG_USBCMD as u64); + let value = self.oper_regs_read(XHCI_OPER_REG_USBCMD); self.oper_regs_write(XHCI_OPER_REG_USBCMD, value | USB_CMD_INTE); // enable INTE let value = self.interrupter_regs_read(intr_idx as u64, XHCI_INTR_REG_IMAN); @@ -1480,7 +1487,7 @@ impl TestXhciPciDevice { fn increase_event_ring(&mut self, intr_idx: usize) { self.xhci.interrupter[intr_idx].trb_count -= 1; - self.xhci.interrupter[intr_idx].er_pointer += TRB_SIZE as u64; + self.xhci.interrupter[intr_idx].er_pointer += u64::from(TRB_SIZE); if self.xhci.interrupter[intr_idx].trb_count == 0 { self.xhci.interrupter[intr_idx].segment_index += 1; if self.xhci.interrupter[intr_idx].segment_index @@ -1503,7 +1510,7 @@ impl TestXhciPciDevice { fn read_segment_entry(&self, intr_idx: usize, index: u32) -> TestEventRingSegment { assert!(index <= self.xhci.interrupter[intr_idx].erstsz); - let addr = self.xhci.interrupter[intr_idx].erstba + (TRB_SIZE * index) as u64; + let addr = self.xhci.interrupter[intr_idx].erstba + u64::from(TRB_SIZE * index); let evt_seg_buf = self.mem_read(addr, TRB_SIZE as usize); let mut evt_seg = TestEventRingSegment::new(); evt_seg.addr = LittleEndian::read_u64(&evt_seg_buf); @@ -1513,17 +1520,17 @@ impl TestXhciPciDevice { } fn set_device_context_address(&mut self, slot_id: u32, addr: u64) { - let device_ctx_addr = self.xhci.dcbaap + (slot_id * DEVICE_CONTEXT_ENTRY_SIZE) as u64; + let device_ctx_addr = self.xhci.dcbaap + u64::from(slot_id * DEVICE_CONTEXT_ENTRY_SIZE); let mut buf = [0_u8; 8]; LittleEndian::write_u64(&mut buf, addr); self.mem_write(device_ctx_addr, &buf); } fn get_device_context_address(&self, slot_id: u32) -> u64 { - let device_ctx_addr = self.xhci.dcbaap + (slot_id * DEVICE_CONTEXT_ENTRY_SIZE) as u64; + let device_ctx_addr = self.xhci.dcbaap + u64::from(slot_id * DEVICE_CONTEXT_ENTRY_SIZE); let mut buf = self.mem_read(device_ctx_addr, 8); - let addr = LittleEndian::read_u64(&mut buf); - addr + + LittleEndian::read_u64(&mut buf) } fn has_msix(&mut self, msix_addr: u64, msix_data: u32) -> bool { @@ -1537,21 +1544,25 @@ impl TestXhciPciDevice { fn increase_command_ring(&mut self) { let cmd_ring = self.xhci.cmd_ring; - if cmd_ring.pointer + TRB_SIZE as u64 >= cmd_ring.start + cmd_ring.size * TRB_SIZE as u64 { + if cmd_ring.pointer + u64::from(TRB_SIZE) + >= cmd_ring.start + cmd_ring.size * u64::from(TRB_SIZE) + { self.queue_link_trb(0, 0, cmd_ring.start, true); } - self.xhci.cmd_ring.pointer += TRB_SIZE as u64; + self.xhci.cmd_ring.pointer += u64::from(TRB_SIZE); } fn increase_transfer_ring(&mut self, slot_id: u32, ep_id: u32, len: u64) { let tr_ring = self.xhci.device_slot[slot_id as usize].endpoints[(ep_id - 1) as usize].transfer_ring; - if tr_ring.pointer + TRB_SIZE as u64 >= tr_ring.start + tr_ring.size * TRB_SIZE as u64 { + if tr_ring.pointer + u64::from(TRB_SIZE) + >= tr_ring.start + tr_ring.size * u64::from(TRB_SIZE) + { self.queue_link_trb(slot_id, ep_id, tr_ring.start, true); } self.xhci.device_slot[slot_id as usize].endpoints[(ep_id - 1) as usize] .transfer_ring - .increase_pointer(TRB_SIZE as u64 * len); + .increase_pointer(u64::from(TRB_SIZE) * len); } fn write_trb(&mut self, addr: u64, trb: &TestNormalTRB) { @@ -1585,7 +1596,7 @@ impl TestXhciPciDevice { // Descriptor impl TestXhciPciDevice { fn get_usb_device_type(&mut self) -> UsbDeviceType { - let usb_device_type = if *self.device_config.get("tablet").unwrap_or(&false) { + if *self.device_config.get("tablet").unwrap_or(&false) { UsbDeviceType::Tablet } else if *self.device_config.get("keyboard").unwrap_or(&false) { UsbDeviceType::Keyboard @@ -1595,9 +1606,7 @@ impl TestXhciPciDevice { UsbDeviceType::Camera } else { UsbDeviceType::Other - }; - - usb_device_type + } } fn get_iad_desc(&mut self, offset: &mut u64, addr: u64) { @@ -1607,8 +1616,8 @@ impl TestXhciPciDevice { } // 1. IAD header descriptor - *offset += USB_DT_CONFIG_SIZE as u64; - let buf = self.get_transfer_data_indirect_with_offset(addr, 8 as usize, *offset); + *offset += u64::from(USB_DT_CONFIG_SIZE); + let buf = self.get_transfer_data_indirect_with_offset(addr, 8_usize, *offset); // descriptor type assert_eq!(buf[1], USB_DT_INTERFACE_ASSOCIATION); @@ -1631,8 +1640,8 @@ impl TestXhciPciDevice { assert_eq!(buf[6], SC_VIDEOCONTROL); // get total vc length from its header descriptor - *offset += USB_DT_INTERFACE_SIZE as u64; - let buf = self.get_transfer_data_indirect_with_offset(addr, 0xd as usize, *offset); + *offset += u64::from(USB_DT_INTERFACE_SIZE); + let buf = self.get_transfer_data_indirect_with_offset(addr, 0xd_usize, *offset); let total = u16::from_le_bytes(buf[5..7].try_into().unwrap()); let remained = total - 0xd; @@ -1641,7 +1650,7 @@ impl TestXhciPciDevice { let _buf = self.get_transfer_data_indirect_with_offset(addr, remained as usize, *offset); // 3. VS interface - *offset += remained as u64; + *offset += u64::from(remained); let buf = self.get_transfer_data_indirect_with_offset( addr, USB_DT_INTERFACE_SIZE as usize, @@ -1653,8 +1662,8 @@ impl TestXhciPciDevice { assert_eq!(buf[6], SC_VIDEOSTREAMING); // get total vs length from its header descriptor - *offset += USB_DT_INTERFACE_SIZE as u64; - let buf = self.get_transfer_data_indirect_with_offset(addr, 0xf as usize, *offset); + *offset += u64::from(USB_DT_INTERFACE_SIZE); + let buf = self.get_transfer_data_indirect_with_offset(addr, 0xf_usize, *offset); let total = u16::from_le_bytes(buf[4..6].try_into().unwrap()); let remained = total - 0xf; @@ -1668,7 +1677,7 @@ impl TestXhciPciDevice { return; } - *offset += USB_DT_CONFIG_SIZE as u64; + *offset += u64::from(USB_DT_CONFIG_SIZE); let buf = self.get_transfer_data_indirect_with_offset( addr, USB_DT_INTERFACE_SIZE as usize, @@ -1692,7 +1701,7 @@ impl TestXhciPciDevice { match usb_device_type { UsbDeviceType::Tablet => { // hid descriptor - *offset += USB_DT_INTERFACE_SIZE as u64; + *offset += u64::from(USB_DT_INTERFACE_SIZE); let buf = self.get_transfer_data_indirect_with_offset(addr, 9, *offset); assert_eq!( buf, @@ -1711,14 +1720,14 @@ impl TestXhciPciDevice { } UsbDeviceType::Keyboard => { // hid descriptor - *offset += USB_DT_INTERFACE_SIZE as u64; + *offset += u64::from(USB_DT_INTERFACE_SIZE); let buf = self.get_transfer_data_indirect_with_offset(addr, 9, *offset); assert_eq!(buf, [0x09, 0x21, 0x11, 0x01, 0x00, 0x01, 0x22, 0x3f, 0]); } _ => {} } - *offset += USB_DT_INTERFACE_SIZE as u64; + *offset += u64::from(USB_DT_INTERFACE_SIZE); // endpoint descriptor let buf = self.get_transfer_data_indirect_with_offset( addr, @@ -1738,7 +1747,7 @@ impl TestXhciPciDevice { assert_eq!(buf[1], USB_DESCRIPTOR_TYPE_ENDPOINT); // endpoint address assert_eq!(buf[2], USB_DIRECTION_DEVICE_TO_HOST | 0x01); - *offset += USB_DT_ENDPOINT_SIZE as u64; + *offset += u64::from(USB_DT_ENDPOINT_SIZE); // endpoint descriptor let buf = self.get_transfer_data_indirect_with_offset( addr, @@ -1761,7 +1770,7 @@ impl TestXhciPciDevice { assert_eq!(evt.ccode, TRBCCode::ShortPacket as u32); let len = name.len() * 2 + 2; - let buf = self.get_transfer_data_indirect(evt.ptr - TRB_SIZE as u64, len as u64); + let buf = self.get_transfer_data_indirect(evt.ptr - u64::from(TRB_SIZE), len as u64); for i in 0..name.len() { assert_eq!(buf[2 * i + 2], name.as_bytes()[i]); } @@ -1774,8 +1783,10 @@ impl TestXhciPciDevice { self.doorbell_write(slot_id, CONTROL_ENDPOINT_ID); let evt = self.fetch_event(PRIMARY_INTERRUPTER_ID).unwrap(); assert_eq!(evt.ccode, TRBCCode::ShortPacket as u32); - let buf = - self.get_transfer_data_indirect(evt.ptr - TRB_SIZE as u64, USB_DT_DEVICE_SIZE as u64); + let buf = self.get_transfer_data_indirect( + evt.ptr - u64::from(TRB_SIZE), + u64::from(USB_DT_DEVICE_SIZE), + ); // descriptor type assert_eq!(buf[1], USB_DESCRIPTOR_TYPE_DEVICE); // bcdUSB @@ -1796,7 +1807,7 @@ impl TestXhciPciDevice { self.doorbell_write(slot_id, CONTROL_ENDPOINT_ID); let evt = self.fetch_event(PRIMARY_INTERRUPTER_ID).unwrap(); assert_eq!(evt.ccode, TRBCCode::ShortPacket as u32); - let addr = evt.ptr - TRB_SIZE as u64; + let addr = evt.ptr - u64::from(TRB_SIZE); let mut offset = 0; let buf = self.get_transfer_data_indirect_with_offset(addr, USB_DT_CONFIG_SIZE as usize, offset); @@ -1842,7 +1853,7 @@ impl TestXhciPciDevice { self.doorbell_write(slot_id, CONTROL_ENDPOINT_ID); let evt = self.fetch_event(PRIMARY_INTERRUPTER_ID).unwrap(); assert_eq!(evt.ccode, TRBCCode::Success as u32); - let buf = self.get_transfer_data_indirect(evt.ptr - TRB_SIZE as u64, 63); + let buf = self.get_transfer_data_indirect(evt.ptr - u64::from(TRB_SIZE), 63); assert_eq!( buf, [ @@ -1856,13 +1867,13 @@ impl TestXhciPciDevice { ); } UsbDeviceType::Tablet => { - self.get_hid_report_descriptor(slot_id, HID_POINTER_REPORT_LEN as u16); + self.get_hid_report_descriptor(slot_id, u16::from(HID_POINTER_REPORT_LEN)); self.doorbell_write(slot_id, CONTROL_ENDPOINT_ID); let evt = self.fetch_event(PRIMARY_INTERRUPTER_ID).unwrap(); assert_eq!(evt.ccode, TRBCCode::Success as u32); let buf = self.get_transfer_data_indirect( - evt.ptr - TRB_SIZE as u64, - HID_POINTER_REPORT_LEN as u64, + evt.ptr - u64::from(TRB_SIZE), + u64::from(HID_POINTER_REPORT_LEN), ); assert_eq!( buf, @@ -1887,7 +1898,7 @@ impl TestXhciPciDevice { let device_req = UsbDeviceRequest { request_type: USB_DEVICE_IN_REQUEST, request: USB_REQUEST_GET_DESCRIPTOR, - value: (USB_DT_DEVICE as u16) << 8, + value: u16::from(USB_DT_DEVICE) << 8, index: 0, length: buf_len, }; @@ -1899,7 +1910,7 @@ impl TestXhciPciDevice { let device_req = UsbDeviceRequest { request_type: USB_DEVICE_IN_REQUEST, request: USB_REQUEST_GET_DESCRIPTOR, - value: (USB_DT_CONFIGURATION as u16) << 8, + value: u16::from(USB_DT_CONFIGURATION) << 8, index: 0, length: buf_len, }; @@ -1911,7 +1922,7 @@ impl TestXhciPciDevice { let device_req = UsbDeviceRequest { request_type: USB_DEVICE_IN_REQUEST, request: USB_REQUEST_GET_DESCRIPTOR, - value: (USB_DT_STRING as u16) << 8 | index, + value: u16::from(USB_DT_STRING) << 8 | index, index: 0, length: buf_len, }; @@ -2119,7 +2130,7 @@ impl TestXhciPciDevice { // Memory operation impl TestXhciPciDevice { pub fn mem_read_u32(&self, addr: u64, buf: &mut [u32]) { - let vec_len = size_of::() * buf.len(); + let vec_len = std::mem::size_of_val(buf); let tmp = self.mem_read(addr, vec_len); for i in 0..buf.len() { buf[i] = LittleEndian::read_u32(&tmp[(size_of::() * i)..]); @@ -2127,7 +2138,7 @@ impl TestXhciPciDevice { } pub fn mem_write_u32(&self, addr: u64, buf: &[u32]) { - let vec_len = size_of::() * buf.len(); + let vec_len = std::mem::size_of_val(buf); let mut vec = vec![0_u8; vec_len]; let tmp = vec.as_mut_slice(); for i in 0..buf.len() { @@ -2189,7 +2200,7 @@ impl TestXhciPciDevice { let device_req = UsbDeviceRequest { request_type: USB_INTERFACE_CLASS_IN_REQUEST, request: GET_INFO, - value: (VS_PROBE_CONTROL as u16) << 8, + value: u16::from(VS_PROBE_CONTROL) << 8, index: VS_INTERFACE_NUM, length: 1, }; @@ -2197,7 +2208,7 @@ impl TestXhciPciDevice { self.doorbell_write(slot_id, CONTROL_ENDPOINT_ID); let evt = self.fetch_event(PRIMARY_INTERRUPTER_ID).unwrap(); assert_eq!(evt.ccode, TRBCCode::Success as u32); - let buf = self.get_transfer_data_indirect(evt.ptr - TRB_SIZE as u64, 1); + let buf = self.get_transfer_data_indirect(evt.ptr - u64::from(TRB_SIZE), 1); buf[0] } @@ -2206,7 +2217,7 @@ impl TestXhciPciDevice { let device_req = UsbDeviceRequest { request_type: USB_INTERFACE_CLASS_IN_REQUEST, request: GET_CUR, - value: (VS_PROBE_CONTROL as u16) << 8, + value: u16::from(VS_PROBE_CONTROL) << 8, index: VS_INTERFACE_NUM, length: len, }; @@ -2214,7 +2225,7 @@ impl TestXhciPciDevice { self.doorbell_write(slot_id, CONTROL_ENDPOINT_ID); let evt = self.fetch_event(PRIMARY_INTERRUPTER_ID).unwrap(); assert_eq!(evt.ccode, TRBCCode::Success as u32); - let buf = self.get_transfer_data_indirect(evt.ptr - TRB_SIZE as u64, len as u64); + let buf = self.get_transfer_data_indirect(evt.ptr - u64::from(TRB_SIZE), u64::from(len)); let mut vs_control = VideoStreamingControl::default(); vs_control.as_mut_bytes().copy_from_slice(&buf); vs_control @@ -2228,7 +2239,7 @@ impl TestXhciPciDevice { let device_req = UsbDeviceRequest { request_type: USB_INTERFACE_CLASS_OUT_REQUEST, request: SET_CUR, - value: (VS_PROBE_CONTROL as u16) << 8, + value: u16::from(VS_PROBE_CONTROL) << 8, index: VS_INTERFACE_NUM, length: len, }; @@ -2243,7 +2254,7 @@ impl TestXhciPciDevice { let device_req = UsbDeviceRequest { request_type: USB_INTERFACE_CLASS_OUT_REQUEST, request: SET_CUR, - value: (VS_COMMIT_CONTROL as u16) << 8, + value: u16::from(VS_COMMIT_CONTROL) << 8, index: VS_INTERFACE_NUM, length: 0, }; @@ -2287,16 +2298,16 @@ impl TestXhciPciDevice { let cnt = (total + TRB_MAX_LEN - 1) / TRB_MAX_LEN; let mut data = Vec::new(); for _ in 0..cnt { - self.queue_indirect_td(slot_id, ep_id, TRB_MAX_LEN as u64); + self.queue_indirect_td(slot_id, ep_id, u64::from(TRB_MAX_LEN)); self.doorbell_write(slot_id, ep_id); // wait for frame done. std::thread::sleep(std::time::Duration::from_millis(FRAME_WAIT_MS)); let evt = self.fetch_event(PRIMARY_INTERRUPTER_ID).unwrap(); if evt.ccode == TRBCCode::Success as u32 { - let mut buf = self.get_transfer_data_indirect(evt.ptr, TRB_MAX_LEN as u64); + let mut buf = self.get_transfer_data_indirect(evt.ptr, u64::from(TRB_MAX_LEN)); data.append(&mut buf); } else if evt.ccode == TRBCCode::ShortPacket as u32 { - let copied = (TRB_MAX_LEN - evt.length) as u64; + let copied = u64::from(TRB_MAX_LEN - evt.length); let mut buf = self.get_transfer_data_indirect(evt.ptr, copied); data.append(&mut buf); if total == data.len() as u32 { @@ -2318,6 +2329,12 @@ pub struct TestUsbBuilder { config: HashMap, } +impl Default for TestUsbBuilder { + fn default() -> Self { + Self::new() + } +} + impl TestUsbBuilder { pub fn new() -> Self { let mut args = Vec::new(); @@ -2374,7 +2391,7 @@ impl TestUsbBuilder { } pub fn with_usb_storage(mut self, image_path: &str, media: &str) -> Self { - let args = format!("-device usb-storage,drive=drive0,id=storage0"); + let args = "-device usb-storage,drive=drive0,id=storage0".to_string(); let args: Vec<&str> = args[..].split(' ').collect(); let mut args = args.into_iter().map(|s| s.to_string()).collect(); self.args.append(&mut args); @@ -2480,8 +2497,7 @@ pub fn qmp_plug_keyboard_event(test_state: RefMut, num: u32) -> Value str += &num_str; str += "\",\"bus\":\"usb.0\",\"port\":\"1\"}}"; - let value = test_state.qmp(&str); - value + test_state.qmp(&str) } pub fn qmp_plug_tablet_event(test_state: RefMut, num: u32) -> Value { @@ -2492,8 +2508,7 @@ pub fn qmp_plug_tablet_event(test_state: RefMut, num: u32) -> Value { str += &num_str; str += "\",\"bus\":\"usb.0\",\"port\":\"2\"}}"; - let value = test_state.qmp(&str); - value + test_state.qmp(&str) } pub fn qmp_unplug_usb_event(test_state: RefMut, num: u32) -> Value { @@ -2502,8 +2517,7 @@ pub fn qmp_unplug_usb_event(test_state: RefMut, num: u32) -> Value { str += &num_str; str += "\"}}"; - let value = test_state.qmp(&str); - value + test_state.qmp(&str) } pub fn qmp_event_read(test_state: RefMut) { @@ -2512,6 +2526,6 @@ pub fn qmp_event_read(test_state: RefMut) { pub fn clear_iovec(test_state: RefMut, iovecs: &Vec) { for iov in iovecs.iter() { - test_state.memwrite(iov.io_base, &vec![0; iov.io_len as usize]); + test_state.memwrite(iov.io_base, &vec![0; iov.io_len]); } } diff --git a/tests/mod_test/src/libdriver/virtio.rs b/tests/mod_test/src/libdriver/virtio.rs index 1c9297b7d03cd61011ab7ec0beededb4b1c328cb..c5d357e3b376fef1c5039816052f5b945626be34 100644 --- a/tests/mod_test/src/libdriver/virtio.rs +++ b/tests/mod_test/src/libdriver/virtio.rs @@ -258,7 +258,7 @@ impl TestVringIndirectDesc { let mut flags = test_state.borrow().readw( self.desc - + (size_of::() as u64 * self.index as u64) + + (size_of::() as u64 * u64::from(self.index)) + offset_of!(VringDesc, flags) as u64, ); @@ -359,7 +359,7 @@ impl TestVirtQueue { test_state.borrow().writew( self.used + offset_of!(VringUsed, ring) as u64 - + (size_of::() as u64 * self.size as u64), + + (size_of::() as u64 * u64::from(self.size)), 0, ); } @@ -376,7 +376,7 @@ impl TestVirtQueue { let features = virtio_dev.get_guest_features(); virtio_dev.queue_select(index); - let queue_size = virtio_dev.get_queue_size() as u32; + let queue_size = u32::from(virtio_dev.get_queue_size()); assert!(queue_size != 0); assert!(queue_size & (queue_size - 1) == 0); @@ -390,12 +390,12 @@ impl TestVirtQueue { let addr = alloc .borrow_mut() - .alloc(get_vring_size(self.size, self.align) as u64); + .alloc(u64::from(get_vring_size(self.size, self.align))); self.desc = addr; - self.avail = self.desc + (self.size * size_of::() as u32) as u64; + self.avail = self.desc + u64::from(self.size * size_of::() as u32); self.used = round_up( - self.avail + (size_of::() as u32 * (3 + self.size)) as u64, - self.align as u64, + self.avail + u64::from(size_of::() as u32 * (3 + self.size)), + u64::from(self.align), ) .unwrap(); } @@ -413,7 +413,7 @@ impl TestVirtQueue { let elem_addr = self.used + offset_of!(VringUsed, ring) as u64 - + (self.last_used_idx as u32 % self.size) as u64 + + u64::from(u32::from(self.last_used_idx) % self.size) * size_of::() as u64; let id_addr = elem_addr + offset_of!(VringUsedElem, id) as u64; @@ -434,7 +434,7 @@ impl TestVirtQueue { test_state.borrow().readw( self.used + offset_of!(VringUsed, ring) as u64 - + (size_of::() as u64 * self.size as u64), + + (size_of::() as u64 * u64::from(self.size)), ) } @@ -442,7 +442,7 @@ impl TestVirtQueue { test_state.borrow().writew( self.avail + offset_of!(VringAvail, ring) as u64 - + (size_of::() as u64 * self.size as u64), + + (size_of::() as u64 * u64::from(self.size)), index, ); } @@ -466,7 +466,7 @@ impl TestVirtQueue { test_state.borrow().writew( self.avail + offset_of!(VringAvail, ring) as u64 - + (size_of::() * (idx as u32 % self.size) as usize) as u64, + + (size_of::() * (u32::from(idx) % self.size) as usize) as u64, desc_idx, ); } @@ -505,7 +505,7 @@ impl TestVirtQueue { next: 0, }; self.add_elem_to_desc(test_state.clone(), desc_elem); - self.update_avail(test_state.clone(), free_head); + self.update_avail(test_state, free_head); free_head } @@ -536,7 +536,7 @@ impl TestVirtQueue { }; self.add_elem_to_desc(test_state.clone(), desc_elem); } - self.update_avail(test_state.clone(), free_head); + self.update_avail(test_state, free_head); free_head } @@ -551,13 +551,13 @@ impl TestVirtQueue { let free_head = self.free_head; let desc_elem = VringDesc { addr: indirect.desc, - len: size_of::() as u32 * indirect.elem as u32, + len: size_of::() as u32 * u32::from(indirect.elem), flags: VRING_DESC_F_INDIRECT, next: 0, }; self.add_elem_to_desc(test_state.clone(), desc_elem); if !mixed { - self.update_avail(test_state.clone(), free_head); + self.update_avail(test_state, free_head); } free_head } @@ -565,7 +565,7 @@ impl TestVirtQueue { // Add a vring desc elem to desc table. fn add_elem_to_desc(&mut self, test_state: Rc>, elem: VringDesc) { self.num_free -= 1; - let desc_elem_addr = self.desc + VRING_DESC_SIZE * self.free_head as u64; + let desc_elem_addr = self.desc + VRING_DESC_SIZE * u64::from(self.free_head); test_state .borrow() .memwrite(desc_elem_addr, elem.as_bytes()); @@ -591,7 +591,7 @@ impl TestVirtioDev { #[inline] pub fn get_vring_size(num: u32, align: u32) -> u32 { let desc_avail = - (size_of::() as u32 * num + size_of::() as u32 * (3 + num)) as u64; - let desc_avail_align = round_up(desc_avail, align as u64).unwrap() as u32; + u64::from(size_of::() as u32 * num + size_of::() as u32 * (3 + num)); + let desc_avail_align = round_up(desc_avail, u64::from(align)).unwrap() as u32; desc_avail_align + size_of::() as u32 * 3 + size_of::() as u32 * num } diff --git a/tests/mod_test/src/libdriver/virtio_block.rs b/tests/mod_test/src/libdriver/virtio_block.rs index 67340432fec1b93e7a04bd24de07c4b225369ce9..5cc5121c974bd0922f3f9bfafbe2b8c7c5d00595 100644 --- a/tests/mod_test/src/libdriver/virtio_block.rs +++ b/tests/mod_test/src/libdriver/virtio_block.rs @@ -154,7 +154,7 @@ pub fn create_blk( let machine = TestStdMachine::new(test_state.clone()); let allocator = machine.allocator.clone(); - let virtio_blk = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus.clone()))); + let virtio_blk = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus))); virtio_blk.borrow_mut().init(pci_slot, pci_fn); @@ -188,7 +188,7 @@ pub fn virtio_blk_request( .alloc((size_of::() + data_size + 512) as u64); let data_addr = if align { - round_up(addr + REQ_ADDR_LEN as u64, 512).unwrap() + round_up(addr + u64::from(REQ_ADDR_LEN), 512).unwrap() } else { addr + REQ_DATA_OFFSET }; @@ -233,11 +233,11 @@ pub fn add_blk_request( read = false; } // Get addr and write to Stratovirt. - let req_addr = virtio_blk_request(test_state.clone(), alloc.clone(), blk_req, align); + let req_addr = virtio_blk_request(test_state.clone(), alloc, blk_req, align); // Desc elem: [addr, len, flags, next]. let data_addr = if align { - round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap() + round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap() } else { req_addr + REQ_DATA_OFFSET }; @@ -254,14 +254,12 @@ pub fn add_blk_request( write: read, }); data_entries.push(TestVringDescEntry { - data: data_addr + REQ_DATA_LEN as u64, + data: data_addr + u64::from(REQ_DATA_LEN), len: REQ_STATUS_LEN, write: true, }); - let free_head = vq - .borrow_mut() - .add_chained(test_state.clone(), data_entries); + let free_head = vq.borrow_mut().add_chained(test_state, data_entries); (free_head, req_addr) } @@ -292,7 +290,7 @@ pub fn virtio_blk_write( .kick_virtqueue(test_state.clone(), virtqueue.clone()); blk.borrow().poll_used_elem( test_state.clone(), - virtqueue.clone(), + virtqueue, free_head, TIMEOUT_US, &mut None, @@ -300,7 +298,7 @@ pub fn virtio_blk_write( ); let status_addr = if align { - round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap() + REQ_DATA_LEN as u64 + round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap() + u64::from(REQ_DATA_LEN) } else { req_addr + REQ_STATUS_OFFSET }; @@ -320,7 +318,7 @@ pub fn virtio_blk_read( ) { let (free_head, req_addr) = add_blk_request( test_state.clone(), - alloc.clone(), + alloc, virtqueue.clone(), VIRTIO_BLK_T_IN, sector, @@ -331,7 +329,7 @@ pub fn virtio_blk_read( .kick_virtqueue(test_state.clone(), virtqueue.clone()); blk.borrow().poll_used_elem( test_state.clone(), - virtqueue.clone(), + virtqueue, free_head, TIMEOUT_US, &mut None, @@ -339,13 +337,13 @@ pub fn virtio_blk_read( ); let data_addr = if align { - round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap() + round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap() } else { - req_addr + REQ_ADDR_LEN as u64 + req_addr + u64::from(REQ_ADDR_LEN) }; let status_addr = if align { - round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap() + REQ_DATA_LEN as u64 + round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap() + u64::from(REQ_DATA_LEN) } else { req_addr + REQ_STATUS_OFFSET }; @@ -376,8 +374,8 @@ pub fn virtio_blk_read_write_zeroes( } read = false; } - let req_addr = virtio_blk_request(test_state.clone(), alloc.clone(), blk_req, false); - let data_addr = req_addr + REQ_ADDR_LEN as u64; + let req_addr = virtio_blk_request(test_state.clone(), alloc, blk_req, false); + let data_addr = req_addr + u64::from(REQ_ADDR_LEN); let data_entries: Vec = vec![ TestVringDescEntry { data: req_addr, @@ -401,13 +399,13 @@ pub fn virtio_blk_read_write_zeroes( blk.borrow().kick_virtqueue(test_state.clone(), vq.clone()); blk.borrow().poll_used_elem( test_state.clone(), - vq.clone(), + vq, free_head, TIMEOUT_US, &mut None, true, ); - let status_addr = req_addr + REQ_ADDR_LEN as u64 + data_len as u64; + let status_addr = req_addr + u64::from(REQ_ADDR_LEN) + data_len as u64; let status = test_state.borrow().readb(status_addr); assert_eq!(status, VIRTIO_BLK_S_OK); @@ -459,7 +457,7 @@ pub fn tear_down( vqs: Vec>>, image_path: Rc, ) { - blk.borrow_mut().destroy_device(alloc.clone(), vqs); + blk.borrow_mut().destroy_device(alloc, vqs); test_state.borrow_mut().stop(); if !image_path.is_empty() { cleanup_img(image_path.to_string()); diff --git a/tests/mod_test/src/libdriver/virtio_gpu.rs b/tests/mod_test/src/libdriver/virtio_gpu.rs index b8bd09d7b4818f4ba069412dc95dda43cabe1cc0..ac5236d8991f44500b51037115e3bad248b8462c 100644 --- a/tests/mod_test/src/libdriver/virtio_gpu.rs +++ b/tests/mod_test/src/libdriver/virtio_gpu.rs @@ -430,8 +430,8 @@ impl TestVirtioGpu { .borrow_mut() .setup_virtqueue_intr(2, self.allocator.clone(), cursor_q.clone()); - self.ctrl_q = ctrl_q.clone(); - self.cursor_q = cursor_q.clone(); + self.ctrl_q = ctrl_q; + self.cursor_q = cursor_q; self.device.borrow_mut().set_driver_ok(); } @@ -648,9 +648,9 @@ pub fn set_up( demo_dpy.borrow_mut().init(dpy_pci_slot); let virtgpu = Rc::new(RefCell::new(TestVirtioGpu::new( - machine.pci_bus.clone(), - allocator.clone(), - test_state.clone(), + machine.pci_bus, + allocator, + test_state, ))); virtgpu.borrow_mut().init(gpu_pci_slot, gpu_pci_fn); @@ -671,7 +671,7 @@ pub fn get_display_info(gpu: &Rc>) -> VirtioGpuDisplayInf gpu.borrow_mut() .request_complete(true, hdr.as_bytes(), None, None, Some(&mut resp)); - return resp; + resp } // VIRTIO_GPU_CMD_GET_EDID @@ -689,7 +689,7 @@ pub fn get_edid(gpu: &Rc>, hdr_ctx: VirtioGpuGetEdid) -> None, Some(&mut resp), ); - return resp; + resp } pub fn current_curosr_check(dpy: &Rc>, local: &Vec) -> bool { @@ -697,7 +697,7 @@ pub fn current_curosr_check(dpy: &Rc>, local: &Vec>, local: &Vec>) -> VirtioGpuCtrlHdr { @@ -924,5 +924,5 @@ pub fn invalid_cmd_test(gpu: &Rc>) -> VirtioGpuCtrlHdr { gpu.borrow_mut() .request_complete(true, hdr.as_bytes(), None, None, Some(&mut resp)); - return resp; + resp } diff --git a/tests/mod_test/src/libdriver/virtio_pci_modern.rs b/tests/mod_test/src/libdriver/virtio_pci_modern.rs index 2054bcd1e36673855e2fcf014a1ab2f615f5a0dd..dfc0f2cee8fd6ca34354da37bcb4dbae8af4b562 100644 --- a/tests/mod_test/src/libdriver/virtio_pci_modern.rs +++ b/tests/mod_test/src/libdriver/virtio_pci_modern.rs @@ -263,7 +263,7 @@ impl TestVirtioPciDev { } fn has_msix(&self, msix_addr: u64, msix_data: u32) -> bool { - return self.pci_dev.has_msix(msix_addr, msix_data); + self.pci_dev.has_msix(msix_addr, msix_data) } pub fn setup_virtqueue_intr( @@ -288,50 +288,50 @@ impl TestVirtioPciDev { impl VirtioDeviceOps for TestVirtioPciDev { fn config_readb(&self, addr: u64) -> u8 { self.pci_dev - .io_readb(self.bar, self.device_base as u64 + addr) + .io_readb(self.bar, u64::from(self.device_base) + addr) } fn config_readw(&self, addr: u64) -> u16 { self.pci_dev - .io_readw(self.bar, self.device_base as u64 + addr) + .io_readw(self.bar, u64::from(self.device_base) + addr) } fn config_readl(&self, addr: u64) -> u32 { self.pci_dev - .io_readl(self.bar, self.device_base as u64 + addr) + .io_readl(self.bar, u64::from(self.device_base) + addr) } fn config_readq(&self, addr: u64) -> u64 { self.pci_dev - .io_readq(self.bar, self.device_base as u64 + addr) + .io_readq(self.bar, u64::from(self.device_base) + addr) } #[allow(unused)] fn config_writeb(&self, addr: u64, value: u8) { self.pci_dev - .io_writeb(self.bar, self.device_base as u64 + addr, value) + .io_writeb(self.bar, u64::from(self.device_base) + addr, value) } #[allow(unused)] fn config_writew(&self, addr: u64, value: u16) { self.pci_dev - .io_writew(self.bar, self.device_base as u64 + addr, value) + .io_writew(self.bar, u64::from(self.device_base) + addr, value) } #[allow(unused)] fn config_writel(&self, addr: u64, value: u32) { self.pci_dev - .io_writel(self.bar, self.device_base as u64 + addr, value) + .io_writel(self.bar, u64::from(self.device_base) + addr, value) } #[allow(unused)] fn config_writeq(&self, addr: u64, value: u64) { self.pci_dev - .io_writeq(self.bar, self.device_base as u64 + addr, value) + .io_writeq(self.bar, u64::from(self.device_base) + addr, value) } fn isr_readb(&self) -> u8 { - self.pci_dev.io_readb(self.bar, self.isr_base as u64) + self.pci_dev.io_readb(self.bar, u64::from(self.isr_base)) } fn enable_interrupt(&mut self) { @@ -345,46 +345,50 @@ impl VirtioDeviceOps for TestVirtioPciDev { fn get_device_features(&self) -> u64 { self.pci_dev.io_writel( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, device_feature_select) as u64, + u64::from(self.common_base) + + offset_of!(VirtioPciCommonCfg, device_feature_select) as u64, 0, ); - let lo: u64 = self.pci_dev.io_readl( + let lo: u64 = u64::from(self.pci_dev.io_readl( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, device_feature) as u64, - ) as u64; + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, device_feature) as u64, + )); self.pci_dev.io_writel( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, device_feature_select) as u64, + u64::from(self.common_base) + + offset_of!(VirtioPciCommonCfg, device_feature_select) as u64, 1, ); - let hi: u64 = self.pci_dev.io_readl( + let hi: u64 = u64::from(self.pci_dev.io_readl( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, device_feature) as u64, - ) as u64; + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, device_feature) as u64, + )); (hi << 32) | lo } fn set_guest_features(&self, features: u64) { self.pci_dev.io_writel( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, guest_feature_select) as u64, + u64::from(self.common_base) + + offset_of!(VirtioPciCommonCfg, guest_feature_select) as u64, 0, ); self.pci_dev.io_writel( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, guest_feature) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, guest_feature) as u64, features as u32, ); self.pci_dev.io_writel( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, guest_feature_select) as u64, + u64::from(self.common_base) + + offset_of!(VirtioPciCommonCfg, guest_feature_select) as u64, 1, ); self.pci_dev.io_writel( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, guest_feature) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, guest_feature) as u64, (features >> 32) as u32, ); } @@ -392,36 +396,38 @@ impl VirtioDeviceOps for TestVirtioPciDev { fn get_guest_features(&self) -> u64 { self.pci_dev.io_writel( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, guest_feature_select) as u64, + u64::from(self.common_base) + + offset_of!(VirtioPciCommonCfg, guest_feature_select) as u64, 0, ); - let lo: u64 = self.pci_dev.io_readl( + let lo: u64 = u64::from(self.pci_dev.io_readl( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, guest_feature) as u64, - ) as u64; + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, guest_feature) as u64, + )); self.pci_dev.io_writel( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, guest_feature_select) as u64, + u64::from(self.common_base) + + offset_of!(VirtioPciCommonCfg, guest_feature_select) as u64, 1, ); - let hi: u64 = self.pci_dev.io_readl( + let hi: u64 = u64::from(self.pci_dev.io_readl( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, guest_feature) as u64, - ) as u64; + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, guest_feature) as u64, + )); (hi << 32) | lo } fn get_status(&self) -> u8 { self.pci_dev.io_readb( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, device_status) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, device_status) as u64, ) } fn set_status(&self, status: u8) { self.pci_dev.io_writeb( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, device_status) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, device_status) as u64, status, ) } @@ -429,21 +435,21 @@ impl VirtioDeviceOps for TestVirtioPciDev { fn get_generation(&self) -> u8 { self.pci_dev.io_readb( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, config_generation) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, config_generation) as u64, ) } fn get_queue_nums(&self) -> u16 { self.pci_dev.io_readw( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, num_queues) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, num_queues) as u64, ) } fn queue_select(&self, index: u16) { self.pci_dev.io_writew( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, queue_select) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, queue_select) as u64, index, ); } @@ -451,14 +457,14 @@ impl VirtioDeviceOps for TestVirtioPciDev { fn get_queue_select(&self) -> u16 { self.pci_dev.io_readw( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, queue_select) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, queue_select) as u64, ) } fn set_queue_size(&self, size: u16) { self.pci_dev.io_writew( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, queue_size) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, queue_size) as u64, size, ) } @@ -466,39 +472,39 @@ impl VirtioDeviceOps for TestVirtioPciDev { fn get_queue_size(&self) -> u16 { self.pci_dev.io_readw( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, queue_size) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, queue_size) as u64, ) } fn activate_queue(&self, desc: u64, avail: u64, used: u64) { self.pci_dev.io_writel( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, queue_desc_lo) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, queue_desc_lo) as u64, desc as u32, ); self.pci_dev.io_writel( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, queue_desc_hi) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, queue_desc_hi) as u64, (desc >> 32) as u32, ); self.pci_dev.io_writel( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, queue_avail_lo) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, queue_avail_lo) as u64, avail as u32, ); self.pci_dev.io_writel( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, queue_avail_hi) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, queue_avail_hi) as u64, (avail >> 32) as u32, ); self.pci_dev.io_writel( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, queue_used_lo) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, queue_used_lo) as u64, used as u32, ); self.pci_dev.io_writel( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, queue_used_hi) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, queue_used_hi) as u64, (used >> 32) as u32, ); } @@ -555,15 +561,15 @@ impl VirtioDeviceOps for TestVirtioPciDev { let notify_off = self.pci_dev.io_readw( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, queue_notify_off) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, queue_notify_off) as u64, ); - virtqueue.borrow_mut().queue_notify_off = - self.notify_base as u64 + notify_off as u64 * self.notify_off_multiplier as u64; + virtqueue.borrow_mut().queue_notify_off = u64::from(self.notify_base) + + u64::from(notify_off) * u64::from(self.notify_off_multiplier); self.pci_dev.io_writew( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, queue_enable) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, queue_enable) as u64, 1, ); @@ -591,7 +597,7 @@ impl VirtioDeviceOps for TestVirtioPciDev { let vq = virtqueue.borrow(); let idx: u16 = test_state.borrow().readw(vq.avail + 2); - if (!vq.event) || (idx >= vq.get_avail_event(test_state.clone()) + 1) { + if (!vq.event) || (idx > vq.get_avail_event(test_state)) { self.virtqueue_notify(virtqueue.clone()); } } @@ -681,12 +687,12 @@ impl VirtioPCIMSIXOps for TestVirtioPciDev { fn set_config_vector(&self, vector: u16) { self.pci_dev.io_writew( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, msix_config) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, msix_config) as u64, vector, ); let vector_get: u16 = self.pci_dev.io_readw( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, msix_config) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, msix_config) as u64, ); assert_eq!( vector, vector_get, @@ -699,12 +705,12 @@ impl VirtioPCIMSIXOps for TestVirtioPciDev { self.queue_select(vq_idx); self.pci_dev.io_writew( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, queue_msix_vector) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, queue_msix_vector) as u64, vector, ); let vector_get: u16 = self.pci_dev.io_readw( self.bar, - self.common_base as u64 + offset_of!(VirtioPciCommonCfg, queue_msix_vector) as u64, + u64::from(self.common_base) + offset_of!(VirtioPciCommonCfg, queue_msix_vector) as u64, ); if vector_get != vector { println!("WARN: set vector {}, get vector {}", vector, vector_get); diff --git a/tests/mod_test/src/libdriver/virtio_rng.rs b/tests/mod_test/src/libdriver/virtio_rng.rs index 9b7520cda3a5ba3449e4cbdaeba60e1a418404a4..082d714a096633b6408add03e7876aa35431dd9c 100644 --- a/tests/mod_test/src/libdriver/virtio_rng.rs +++ b/tests/mod_test/src/libdriver/virtio_rng.rs @@ -49,7 +49,7 @@ pub fn create_rng( let machine = TestStdMachine::new(test_state.clone()); let allocator = machine.allocator.clone(); - let rng = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus.clone()))); + let rng = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus))); rng.borrow_mut().init(pci_slot, pci_fn); diff --git a/tests/mod_test/src/libdriver/vnc.rs b/tests/mod_test/src/libdriver/vnc.rs index 3158c5ee730b1a3db4f96cf07f465793a2dee11c..bf7034702bedc06b0804dd133c3b573e8e65eeb1 100644 --- a/tests/mod_test/src/libdriver/vnc.rs +++ b/tests/mod_test/src/libdriver/vnc.rs @@ -57,19 +57,14 @@ pub const PIXMAN_YUY2: u32 = 4; pub const REFRESH_TIME_INTERVAL: u64 = 3000 * 1000 * 1000; /// Input event. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Default)] pub enum InputEvent { KbdEvent = 0, MouseEvent = 1, + #[default] InvalidEvent = 255, } -impl Default for InputEvent { - fn default() -> Self { - InputEvent::InvalidEvent - } -} - impl From for InputEvent { fn from(v: u8) -> Self { match v { @@ -91,21 +86,16 @@ pub struct InputMessage { } /// GPU device Event. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Default)] pub enum GpuEvent { ReplaceSurface = 0, ReplaceCursor = 1, GraphicUpdateArea = 2, GraphicUpdateDirty = 3, + #[default] Deactive = 4, } -impl Default for GpuEvent { - fn default() -> Self { - GpuEvent::Deactive - } -} - #[derive(Debug, Clone, Copy, Default)] pub struct TestGpuCmd { pub event_type: GpuEvent, @@ -118,7 +108,7 @@ pub struct TestGpuCmd { // Encodings Type #[repr(u32)] -#[derive(Clone, Copy, PartialEq, Debug)] +#[derive(Clone, Copy, PartialEq, Eq, Debug)] pub enum EncodingType { EncodingRaw = 0x00000000, EncodingCopyrect = 0x00000001, @@ -190,7 +180,7 @@ impl From for EncodingType { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq)] pub enum RfbServerMsg { FramebufferUpdate = 0, SetColourMapEntries = 1, @@ -691,7 +681,7 @@ impl VncClient { pub fn epoll_ctl(&mut self, event: EpollEvent) -> io::Result<()> { self.epoll - .ctl(ControlOperation::Add, self.stream.as_raw_fd() as i32, event) + .ctl(ControlOperation::Add, self.stream.as_raw_fd(), event) } /// Wait for events on the epoll. @@ -701,13 +691,8 @@ impl VncClient { /// 2. Return if event happen or time out. pub fn epoll_wait(&mut self, event_set: EventSet) -> io::Result { let event = EpollEvent::new(event_set, self.stream.as_raw_fd() as u64); - if let Err(e) = self.epoll.ctl( - ControlOperation::Modify, - self.stream.as_raw_fd() as i32, - event, - ) { - return Err(e); - } + self.epoll + .ctl(ControlOperation::Modify, self.stream.as_raw_fd(), event)?; self.epoll .wait(EPOLL_DEFAULT_TIMEOUT, &mut self.ready_events[..]) } @@ -716,11 +701,8 @@ impl VncClient { pub fn stream_read_to_end(&mut self) -> Result<()> { let mut buf: Vec = Vec::new(); let event = EpollEvent::new(EventSet::IN, self.stream.as_raw_fd() as u64); - self.epoll.ctl( - ControlOperation::Modify, - self.stream.as_raw_fd() as i32, - event, - )?; + self.epoll + .ctl(ControlOperation::Modify, self.stream.as_raw_fd(), event)?; match self .epoll @@ -787,7 +769,7 @@ impl VncClient { if "RFB 003.008\n".as_bytes().to_vec() != buf[..12].to_vec() { bail!("Unsupported RFB version"); } - self.write_msg(&"RFB 003.008\n".as_bytes().to_vec())?; + self.write_msg("RFB 003.008\n".as_bytes())?; buf.drain(..12); // Step 2: Auth num is 1. @@ -800,25 +782,22 @@ impl VncClient { bail!("Unsupported security type!"); } buf.drain(..auth_num as usize); - self.write_msg(&(sec_type as u8).to_be_bytes().to_vec())?; - - match sec_type { - TestAuthType::VncAuthNone => { - // Step 3. Handle_auth: Authstate::No, Server accept auth and client send share - // mode. - self.read_msg(&mut buf, 4)?; - if buf[..4].to_vec() != [0_u8; 4].to_vec() { - bail!("Reject by vnc server"); - } - self.write_msg(&0_u8.to_be_bytes().to_vec())?; - buf.drain(..4); - - // Step 4. display mode information init: width + height + pixelformat + app_name. - self.read_msg(&mut buf, 24)?; - self.display_mod.from_bytes(&mut buf); - self.display_mod.check(); + self.write_msg((sec_type as u8).to_be_bytes().as_ref())?; + + if let TestAuthType::VncAuthNone = sec_type { + // Step 3. Handle_auth: Authstate::No, Server accept auth and client send share + // mode. + self.read_msg(&mut buf, 4)?; + if buf[..4].to_vec() != [0_u8; 4].to_vec() { + bail!("Reject by vnc server"); } - _ => {} + self.write_msg(0_u8.to_be_bytes().as_ref())?; + buf.drain(..4); + + // Step 4. display mode information init: width + height + pixelformat + app_name. + self.read_msg(&mut buf, 24)?; + self.display_mod.from_bytes(&mut buf); + self.display_mod.check(); } self.stream_read_to_end()?; println!("Connection established!"); @@ -856,10 +835,7 @@ impl VncClient { let mut test_event = TestSetupEncoding::new(); if let Some(encoding) = enc { test_event.encs.push(encoding); - test_event.num_encodings = match enc_num { - Some(num) => num, - None => 1_u16, - }; + test_event.num_encodings = enc_num.unwrap_or(1_u16); } else { for encoding in EncodingType::ENCODINGTYPE { test_event.encs.push(encoding); @@ -1020,7 +996,7 @@ impl VncClient { let message_len: usize = frame_buff.w as usize * frame_buff.h as usize * (pf.bit_per_pixel as usize / 8); println!("Total bytes of image data: {:?}", message_len); - self.read_msg(buf, message_len as usize)?; + self.read_msg(buf, message_len)?; buf.drain(..message_len); Ok(()) } @@ -1211,7 +1187,7 @@ impl TestDemoGpuDevice { test_state.borrow_mut().writel(addr + 13, cmd.h); test_state.borrow_mut().writel(addr + 17, cmd.data_len); // Write to specific address. - self.pci_dev.io_writeq(self.bar_addr, 0 as u64, addr); + self.pci_dev.io_writeq(self.bar_addr, 0_u64, addr); test_state.borrow().clock_step_ns(REFRESH_TIME_INTERVAL); println!("cmd : {:?}", cmd); } @@ -1376,7 +1352,7 @@ pub fn set_up( } let input = Rc::new(RefCell::new(TestDemoInputDevice::new( - machine.pci_bus.clone(), + machine.pci_bus, allocator, ))); input.borrow_mut().init(input_conf.pci_slot); diff --git a/tests/mod_test/src/libtest.rs b/tests/mod_test/src/libtest.rs index 6de33d6645e5ef3c651335892469dc415bc4d0f9..c98c32244b51528f4661583fa3cf56f2a66e1f78 100644 --- a/tests/mod_test/src/libtest.rs +++ b/tests/mod_test/src/libtest.rs @@ -54,6 +54,12 @@ impl StreamHandler { .unwrap(); } + fn clear_stream(&self) { + let mut stream = self.stream.try_clone().unwrap(); + stream.set_nonblocking(true).unwrap(); + let _ = stream.read(&mut [0_u8; 1024]); + } + fn read_line(&self, timeout: Duration) -> String { let start = Instant::now(); let mut resp = self.read_buffer.borrow_mut(); @@ -132,11 +138,12 @@ impl TestState { let resp: Value = serde_json::from_slice(self.qmp_sock.read_line(timeout).as_bytes()).unwrap(); assert!(resp.get("event").is_some()); - return resp; + resp } pub fn qmp(&self, cmd: &str) -> Value { let timeout = Duration::from_secs(10); + self.qmp_sock.clear_stream(); self.qmp_sock.write_line(cmd); serde_json::from_slice(self.qmp_sock.read_line(timeout).as_bytes()).unwrap() } @@ -198,7 +205,7 @@ impl TestState { pub fn readq(&self, addr: u64) -> u64 { let cmd = format!("readq 0x{:x}", addr); - self.send_read_cmd(&cmd) as u64 + self.send_read_cmd(&cmd) } pub fn memread(&self, addr: u64, size: u64) -> Vec { @@ -355,7 +362,12 @@ pub fn test_init(extra_arg: Vec<&str>) -> TestState { let listener = init_socket(&test_socket); - let child = Command::new(binary_path) + let mut cmd = Command::new(binary_path); + + #[cfg(target_env = "ohos")] + cmd.args(["-disable-seccomp"]); + + let child = cmd .args(["-accel", "test"]) .args(["-qmp", &format!("unix:{},server,nowait", qmp_socket)]) .args(["-mod-test", &test_socket]) diff --git a/tests/mod_test/src/utils.rs b/tests/mod_test/src/utils.rs index b84205d4bcfb2190877971ff5d3038c68565e48a..06a2118a2fa30bbd8ba6d6ac0bbf33d8f0aef6da 100644 --- a/tests/mod_test/src/utils.rs +++ b/tests/mod_test/src/utils.rs @@ -56,18 +56,18 @@ pub fn read_le_u64(input: &mut &[u8]) -> u64 { } pub fn swap_u16(value: u16) -> u16 { - return value << 8 | value >> 8; + value << 8 | value >> 8 } pub fn swap_u32(value: u32) -> u32 { - let lower_u16 = swap_u16(value as u16) as u32; - let higher_u16 = swap_u16((value >> 16) as u16) as u32; + let lower_u16 = u32::from(swap_u16(value as u16)); + let higher_u16 = u32::from(swap_u16((value >> 16) as u16)); lower_u16 << 16 | higher_u16 } pub fn swap_u64(value: u64) -> u64 { - let lower_u32 = swap_u32(value as u32) as u64; - let higher_u32 = swap_u32((value >> 32) as u32) as u64; + let lower_u32 = u64::from(swap_u32(value as u32)); + let higher_u32 = u64::from(swap_u32((value >> 32) as u32)); lower_u32 << 32 | higher_u32 } @@ -127,3 +127,27 @@ pub fn cleanup_img(image_path: String) { fs::remove_file(img_path).expect("lack permissions to remove the file"); } + +pub fn support_numa() -> bool { + let numa_nodes_path = "/sys/devices/system/node/"; + + if Path::new(numa_nodes_path).exists() { + match fs::read_dir(numa_nodes_path) { + Ok(entries) => { + let mut has_nodes = false; + for entry in entries { + if let Ok(entry) = entry { + if entry.file_name().to_str().unwrap_or("").starts_with("node") { + has_nodes = true; + break; + } + } + } + has_nodes + } + Err(_) => false, + } + } else { + false + } +} diff --git a/tests/mod_test/tests/aarch64/acpi_test.rs b/tests/mod_test/tests/aarch64/acpi_test.rs index d28c8221ddf355607c625458c934790f5db39a9f..563d5ebc9b6af7201ffa3fd70b62a27cf467fad4 100644 --- a/tests/mod_test/tests/aarch64/acpi_test.rs +++ b/tests/mod_test/tests/aarch64/acpi_test.rs @@ -103,8 +103,8 @@ fn check_fadt(data: &[u8]) -> (u32, u64) { assert_eq!(String::from_utf8_lossy(&data[..4]), "FACP"); assert_eq!(LittleEndian::read_u32(&data[4..]), FADT_TABLE_DATA_LENGTH); // Check length - // Enable HW_REDUCED_ACPI and LOW_POWER_S0_IDLE_CAPABLE bit - assert_eq!(LittleEndian::read_i32(&data[112..]), 0x30_0500); + // Enable HW_REDUCED_ACPI bit + assert_eq!(LittleEndian::read_i32(&data[112..]), 0x10_0500); assert_eq!(LittleEndian::read_u16(&data[129..]), 0x3); // ARM Boot Architecture Flags assert_eq!(LittleEndian::read_i32(&data[131..]), 3); // FADT minor revision @@ -140,12 +140,12 @@ fn check_madt(data: &[u8], cpu: u8) { offset += mem::size_of::(); for i in 0..cpu { assert_eq!(data[offset + 1], 80); // The length of this structure - assert_eq!(LittleEndian::read_u32(&data[(offset + 4)..]), i as u32); // CPU interface number - assert_eq!(LittleEndian::read_u32(&data[(offset + 8)..]), i as u32); // ACPI processor UID + assert_eq!(LittleEndian::read_u32(&data[(offset + 4)..]), u32::from(i)); // CPU interface number + assert_eq!(LittleEndian::read_u32(&data[(offset + 8)..]), u32::from(i)); // ACPI processor UID assert_eq!(LittleEndian::read_u32(&data[(offset + 12)..]), 5); // Flags assert_eq!(LittleEndian::read_u32(&data[(offset + 20)..]), 23); // Performance monitoring interrupts assert_eq!(LittleEndian::read_u64(&data[(offset + 56)..]), 25); // Virtual GIC maintenance interrupt - assert_eq!(LittleEndian::read_u64(&data[(offset + 68)..]), i as u64); // MPIDR + assert_eq!(LittleEndian::read_u64(&data[(offset + 68)..]), u64::from(i)); // MPIDR offset += mem::size_of::(); } @@ -462,7 +462,7 @@ fn check_madt_of_two_gicr( let len = LittleEndian::read_u32(&read_data[(madt_addr + offset + 12)..]); assert_eq!( MEM_LAYOUT[LayoutEntryType::HighGicRedist as usize].1, - len as u64 + u64::from(len) ); } diff --git a/tests/mod_test/tests/balloon_test.rs b/tests/mod_test/tests/balloon_test.rs index dbcd67fde8dcd7de8a18bd2964fd83dbba02801b..8fbc400285b29b0934f8960354fdc88089c6970a 100644 --- a/tests/mod_test/tests/balloon_test.rs +++ b/tests/mod_test/tests/balloon_test.rs @@ -17,6 +17,7 @@ use std::process::Command; use std::rc::Rc; use std::{thread, time}; +use mod_test::utils::support_numa; use serde_json::json; use mod_test::libdriver::machine::TestStdMachine; @@ -38,7 +39,7 @@ const ADDRESS_BASE: u64 = 0x4000_0000; fn read_lines(filename: String) -> io::Lines> { let file = File::open(filename).unwrap(); - return io::BufReader::new(file).lines(); + io::BufReader::new(file).lines() } fn get_hugesize() -> u64 { @@ -47,12 +48,12 @@ fn get_hugesize() -> u64 { for line in lines { if let Ok(info) = line { if info.starts_with("HugePages_Free:") { - let free: Vec<&str> = info.split(":").collect(); + let free: Vec<&str> = info.split(':').collect(); free_page = free[1].trim().parse::().unwrap(); } if info.starts_with("Hugepagesize:") { - let huges: Vec<&str> = info.split(":").collect(); - let sizes: Vec<&str> = huges[1].trim().split(" ").collect(); + let huges: Vec<&str> = info.split(':').collect(); + let sizes: Vec<&str> = huges[1].trim().split(' ').collect(); let size = sizes[0].trim().parse::().unwrap(); return free_page * size; } @@ -90,7 +91,7 @@ impl VirtioBalloonTest { let mut extra_args: Vec<&str> = Vec::new(); let mut fpr_switch = String::from("false"); let mut auto_switch = String::from("false"); - let mem_path = format!("-mem-path /tmp/stratovirt/hugepages"); + let mem_path = "-mem-path /tmp/stratovirt/hugepages".to_string(); let mut machine_args = MACHINE_TYPE_ARG.to_string(); if shared { @@ -125,12 +126,11 @@ impl VirtioBalloonTest { let machine = TestStdMachine::new_bymem(test_state.clone(), memsize * MBSIZE, page_size); let allocator = machine.allocator.clone(); - let dev = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus.clone()))); + let dev = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus))); dev.borrow_mut().init(pci_slot, 0); let features = dev.borrow_mut().get_device_features(); - let inf_queue; - let def_queue; + let mut fpr_queue = None; let mut auto_queue = None; let mut que_num = 2_usize; @@ -144,8 +144,8 @@ impl VirtioBalloonTest { let ques = dev.borrow_mut() .init_device(test_state.clone(), allocator.clone(), features, que_num); - inf_queue = ques[0].clone(); - def_queue = ques[1].clone(); + let inf_queue = ques[0].clone(); + let def_queue = ques[1].clone(); if cfg.fpr { fpr_queue = Some(ques[idx].clone()); idx += 1; @@ -208,18 +208,16 @@ impl VirtioBalloonTest { let machine = TestStdMachine::new_bymem(test_state.clone(), 2 * MBSIZE, 4096); let allocator = machine.allocator.clone(); - let dev = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus.clone()))); + let dev = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus))); dev.borrow_mut().init(4, 0); let features = dev.borrow_mut().get_device_features(); - let inf_queue; - let def_queue; let ques = dev .borrow_mut() .init_device(test_state.clone(), allocator.clone(), features, 2); - inf_queue = ques[0].clone(); - def_queue = ques[1].clone(); + let inf_queue = ques[0].clone(); + let def_queue = ques[1].clone(); VirtioBalloonTest { device: dev, @@ -326,18 +324,18 @@ fn balloon_fun(shared: bool, huge: bool) { let free_page = balloon .allocator .borrow_mut() - .alloc(page_num as u64 * PAGE_SIZE_UNIT); + .alloc(u64::from(page_num) * PAGE_SIZE_UNIT); let pfn = (free_page >> 12) as u32; let pfn_addr = balloon.allocator.borrow_mut().alloc(PAGE_SIZE_UNIT); while idx < page_num { balloon .state .borrow_mut() - .writel(pfn_addr + 4 * idx as u64, pfn + idx); + .writel(pfn_addr + 4 * u64::from(idx), pfn + idx); balloon .state .borrow_mut() - .writeb(free_page + PAGE_SIZE_UNIT * idx as u64, 1); + .writeb(free_page + PAGE_SIZE_UNIT * u64::from(idx), 1); idx += 1; } @@ -347,7 +345,7 @@ fn balloon_fun(shared: bool, huge: bool) { while loop_num < page_num { let entry = TestVringDescEntry { - data: pfn_addr + (loop_num as u64 * 4), + data: pfn_addr + (u64::from(loop_num) * 4), len: 4, write: false, }; @@ -376,7 +374,7 @@ fn balloon_fun(shared: bool, huge: bool) { while loop_num < page_num { let entry = TestVringDescEntry { - data: pfn_addr + (loop_num as u64 * 4), + data: pfn_addr + (u64::from(loop_num) * 4), len: 4, write: false, }; @@ -569,12 +567,12 @@ fn balloon_feature_001() { let machine = TestStdMachine::new_bymem(test_state.clone(), 128 * MBSIZE, PAGE_SIZE_UNIT); let allocator = machine.allocator.clone(); - let dev = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus.clone()))); + let dev = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus))); dev.borrow_mut().init(pci_slot, pci_fn); dev.borrow_mut().pci_dev.enable_msix(None); dev.borrow_mut() - .setup_msix_configuration_vector(allocator.clone(), 0); + .setup_msix_configuration_vector(allocator, 0); let features = dev.borrow_mut().get_device_features(); @@ -616,12 +614,12 @@ fn balloon_feature_002() { let machine = TestStdMachine::new_bymem(test_state.clone(), 128 * MBSIZE, PAGE_SIZE_UNIT); let allocator = machine.allocator.clone(); - let dev = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus.clone()))); + let dev = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus))); dev.borrow_mut().init(pci_slot, pci_fn); dev.borrow_mut().pci_dev.enable_msix(None); dev.borrow_mut() - .setup_msix_configuration_vector(allocator.clone(), 0); + .setup_msix_configuration_vector(allocator, 0); let features = dev.borrow_mut().get_device_features(); @@ -654,18 +652,18 @@ fn balloon_fpr_fun(shared: bool) { let free_page = balloon .allocator .borrow_mut() - .alloc(page_num as u64 * PAGE_SIZE_UNIT); + .alloc(u64::from(page_num) * PAGE_SIZE_UNIT); let pfn = (free_page >> 12) as u32; let pfn_addr = balloon.allocator.borrow_mut().alloc(PAGE_SIZE_UNIT); while idx < page_num { balloon .state .borrow_mut() - .writel(pfn_addr + 4 * idx as u64, pfn + idx); + .writel(pfn_addr + 4 * u64::from(idx), pfn + idx); balloon .state .borrow_mut() - .writeb(free_page + PAGE_SIZE_UNIT * idx as u64, 1); + .writeb(free_page + PAGE_SIZE_UNIT * u64::from(idx), 1); idx += 1; } // balloon Illegal addresses @@ -693,7 +691,7 @@ fn balloon_fpr_fun(shared: bool) { while loop_num < page_num { let entry = TestVringDescEntry { - data: pfn_addr + (loop_num as u64 * 4), + data: pfn_addr + (u64::from(loop_num) * 4), len: 4, write: true, }; @@ -707,7 +705,7 @@ fn balloon_fpr_fun(shared: bool) { .kick_virtqueue(balloon.state.clone(), fpr.clone()); balloon.device.borrow_mut().poll_used_elem( balloon.state.clone(), - fpr.clone(), + fpr, free_head, TIMEOUT_US, &mut None, @@ -778,7 +776,7 @@ fn query() { assert_eq!( *ret.get("return").unwrap(), - json!({"actual": 2147483648 as u64}) + json!({"actual": 2147483648_u64}) ); balloon.state.borrow_mut().stop(); @@ -834,10 +832,7 @@ fn balloon_config_001() { let ten_millis = time::Duration::from_millis(10); thread::sleep(ten_millis); let ret = balloon.state.borrow_mut().qmp_read(); - assert_eq!( - *ret.get("data").unwrap(), - json!({"actual": 536870912 as u64}) - ); + assert_eq!(*ret.get("data").unwrap(), json!({"actual": 536870912_u64})); balloon .state @@ -905,10 +900,7 @@ fn balloon_config_002() { let ten_millis = time::Duration::from_millis(10); thread::sleep(ten_millis); let ret = balloon.state.borrow_mut().qmp_read(); - assert_eq!( - *ret.get("data").unwrap(), - json!({"actual": 536870912 as u64}) - ); + assert_eq!(*ret.get("data").unwrap(), json!({"actual": 536870912_u64})); balloon .state @@ -939,7 +931,7 @@ fn balloon_deactive_001() { let balloon = VirtioBalloonTest::new(1024, PAGE_SIZE_UNIT, false, false, cfg); let bar = balloon.device.borrow().bar; - let common_base = balloon.device.borrow().common_base as u64; + let common_base = u64::from(balloon.device.borrow().common_base); balloon.device.borrow().pci_dev.io_writel( bar, @@ -956,7 +948,7 @@ fn balloon_deactive_001() { .qmp("{\"execute\": \"query-balloon\"}"); assert_eq!( *ret.get("return").unwrap(), - json!({"actual": 1073741824 as u64}) + json!({"actual": 1073741824_u64}) ); balloon.state.borrow_mut().stop(); } @@ -1008,7 +1000,7 @@ fn auto_balloon_test_001() { balloon .state .borrow_mut() - .memwrite(msg_addr, &stat.as_bytes()); + .memwrite(msg_addr, stat.as_bytes()); let auto_queue = balloon.auto_queue.unwrap(); @@ -1024,7 +1016,7 @@ fn auto_balloon_test_001() { .kick_virtqueue(balloon.state.clone(), auto_queue.clone()); balloon.device.borrow_mut().poll_used_elem( balloon.state.clone(), - auto_queue.clone(), + auto_queue, free_head, TIMEOUT_US, &mut None, @@ -1054,6 +1046,10 @@ fn auto_balloon_test_001() { /// Expect: /// 1/2.Success fn balloon_numa1() { + if !support_numa() { + return; + } + let page_num = 255_u32; let mut idx = 0_u32; let balloon = VirtioBalloonTest::numa_node_new(); @@ -1065,11 +1061,11 @@ fn balloon_numa1() { balloon .state .borrow_mut() - .writel(pfn_addr + 4 * idx as u64, pfn + idx); + .writel(pfn_addr + 4 * u64::from(idx), pfn + idx); balloon .state .borrow_mut() - .writeb(free_page + PAGE_SIZE_UNIT * idx as u64, 1); + .writeb(free_page + PAGE_SIZE_UNIT * u64::from(idx), 1); idx += 1; } @@ -1079,7 +1075,7 @@ fn balloon_numa1() { while loop_num < page_num { let entry = TestVringDescEntry { - data: pfn_addr + (loop_num as u64 * 4), + data: pfn_addr + (u64::from(loop_num) * 4), len: 4, write: false, }; @@ -1108,7 +1104,7 @@ fn balloon_numa1() { while loop_num < page_num { let entry = TestVringDescEntry { - data: pfn_addr + (loop_num as u64 * 4), + data: pfn_addr + (u64::from(loop_num) * 4), len: 4, write: false, }; diff --git a/tests/mod_test/tests/block_test.rs b/tests/mod_test/tests/block_test.rs index e7a633aa4e0dbaaf09918fdd989c070ac24f8220..1f0a513c6735250ae7af920669c9a8c5a0dc9571 100644 --- a/tests/mod_test/tests/block_test.rs +++ b/tests/mod_test/tests/block_test.rs @@ -63,7 +63,7 @@ fn virtio_blk_discard_and_write_zeroes( TestVirtBlkReq::new(VIRTIO_BLK_T_WRITE_ZEROES, 1, 0, req_len) }; blk_req.data = unsafe { String::from_utf8_unchecked(req_data.to_vec()) }; - let req_addr = virtio_blk_request(test_state.clone(), alloc.clone(), blk_req, false); + let req_addr = virtio_blk_request(test_state.clone(), alloc, blk_req, false); let mut data_entries: Vec = Vec::with_capacity(3); data_entries.push(TestVringDescEntry { @@ -89,7 +89,7 @@ fn virtio_blk_discard_and_write_zeroes( if need_poll_elem { blk.borrow().poll_used_elem( test_state.clone(), - virtqueue.clone(), + virtqueue, free_head, TIMEOUT_US, &mut None, @@ -153,7 +153,7 @@ fn virtio_blk_get_id( ) { let (free_head, req_addr) = add_blk_request( test_state.clone(), - alloc.clone(), + alloc, virtqueue.clone(), VIRTIO_BLK_T_GET_ID, 0, @@ -162,18 +162,19 @@ fn virtio_blk_get_id( blk.borrow().virtqueue_notify(virtqueue.clone()); blk.borrow().poll_used_elem( test_state.clone(), - virtqueue.clone(), + virtqueue, free_head, TIMEOUT_US, &mut None, true, ); - let status_addr = round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap() + REQ_DATA_LEN as u64; + let status_addr = + round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap() + u64::from(REQ_DATA_LEN); let status = test_state.borrow().readb(status_addr); assert_eq!(status, VIRTIO_BLK_S_OK); - let data_addr = round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap(); + let data_addr = round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap(); assert_eq!( String::from_utf8( test_state @@ -194,7 +195,7 @@ fn virtio_blk_flush( ) { let (free_head, req_addr) = add_blk_request( test_state.clone(), - alloc.clone(), + alloc, virtqueue.clone(), VIRTIO_BLK_T_FLUSH, sector, @@ -203,14 +204,15 @@ fn virtio_blk_flush( blk.borrow().virtqueue_notify(virtqueue.clone()); blk.borrow().poll_used_elem( test_state.clone(), - virtqueue.clone(), + virtqueue, free_head, TIMEOUT_US, &mut None, true, ); - let status_addr = round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap() + REQ_DATA_LEN as u64; + let status_addr = + round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap() + u64::from(REQ_DATA_LEN); let status = test_state.borrow().readb(status_addr); assert_eq!(status, VIRTIO_BLK_S_OK); } @@ -224,7 +226,7 @@ fn virtio_blk_illegal_req( ) { let (free_head, req_addr) = add_blk_request( test_state.clone(), - alloc.clone(), + alloc, virtqueue.clone(), req_type, 0, @@ -233,14 +235,15 @@ fn virtio_blk_illegal_req( blk.borrow().virtqueue_notify(virtqueue.clone()); blk.borrow().poll_used_elem( test_state.clone(), - virtqueue.clone(), + virtqueue, free_head, TIMEOUT_US, &mut None, true, ); - let status_addr = round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap() + REQ_DATA_LEN as u64; + let status_addr = + round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap() + u64::from(REQ_DATA_LEN); let status = test_state.borrow().readb(status_addr); assert_eq!(status, VIRTIO_BLK_S_UNSUPP); } @@ -265,7 +268,7 @@ fn blk_basic() { .init_device(test_state.clone(), alloc.clone(), features, 1); let capacity = blk.borrow().config_readq(0); - assert_eq!(capacity, TEST_IMAGE_SIZE / REQ_DATA_LEN as u64); + assert_eq!(capacity, TEST_IMAGE_SIZE / u64::from(REQ_DATA_LEN)); virtio_blk_write( blk.clone(), @@ -436,7 +439,7 @@ fn blk_feature_ro() { .init_device(test_state.clone(), alloc.clone(), features, 1); let capacity = blk.borrow().config_readq(0); - assert_eq!(capacity, TEST_IMAGE_SIZE / REQ_DATA_LEN as u64); + assert_eq!(capacity, TEST_IMAGE_SIZE / u64::from(REQ_DATA_LEN)); virtio_blk_write( blk.clone(), @@ -473,7 +476,7 @@ fn blk_feature_ro() { .init_device(test_state.clone(), alloc.clone(), features, 1); let capacity = blk.borrow().config_readq(0); - assert_eq!(capacity, TEST_IMAGE_SIZE / REQ_DATA_LEN as u64); + assert_eq!(capacity, TEST_IMAGE_SIZE / u64::from(REQ_DATA_LEN)); virtio_blk_read( blk.clone(), @@ -503,7 +506,7 @@ fn blk_feature_ro() { ); let status_addr = - round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap() + REQ_DATA_LEN as u64; + round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap() + u64::from(REQ_DATA_LEN); let status = test_state.borrow().readb(status_addr); assert_eq!(status, VIRTIO_BLK_S_IOERR); @@ -623,7 +626,7 @@ fn blk_feature_mq() { true, )); - let data_addr = round_up(req_addr[i] + REQ_ADDR_LEN as u64, 512).unwrap(); + let data_addr = round_up(req_addr[i] + u64::from(REQ_ADDR_LEN), 512).unwrap(); let mut data_entries: Vec = Vec::with_capacity(3); data_entries.push(TestVringDescEntry { @@ -637,7 +640,7 @@ fn blk_feature_mq() { write: false, }); data_entries.push(TestVringDescEntry { - data: data_addr + REQ_DATA_LEN as u64, + data: data_addr + u64::from(REQ_DATA_LEN), len: REQ_STATUS_LEN, write: true, }); @@ -666,8 +669,8 @@ fn blk_feature_mq() { } for i in 0..num_queues { - let status_addr = - round_up(req_addr[i] + REQ_ADDR_LEN as u64, 512).unwrap() + REQ_DATA_LEN as u64; + let status_addr = round_up(req_addr[i] + u64::from(REQ_ADDR_LEN), 512).unwrap() + + u64::from(REQ_DATA_LEN); let status = test_state.borrow().readb(status_addr); assert_eq!(status, VIRTIO_BLK_S_OK); } @@ -708,6 +711,11 @@ fn blk_all_features() { let device_args = Rc::new(String::from( ",multifunction=on,serial=111111,num-queues=4,bootindex=1,iothread=iothread1", )); + #[cfg(target_env = "ohos")] + let drive_args = Rc::new(String::from( + ",direct=false,readonly=off,throttling.iops-total=1024", + )); + #[cfg(not(target_env = "ohos"))] let drive_args = if aio_probe(AioEngine::IoUring).is_ok() { Rc::new(String::from( ",direct=on,aio=io_uring,readonly=off,throttling.iops-total=1024", @@ -793,13 +801,13 @@ fn blk_small_file_511b() { .init_device(test_state.clone(), alloc.clone(), features, 1); let capacity = blk.borrow().config_readq(0); - assert_eq!(capacity, size / REQ_DATA_LEN as u64); + assert_eq!(capacity, size / u64::from(REQ_DATA_LEN)); let mut blk_req = TestVirtBlkReq::new(VIRTIO_BLK_T_OUT, 1, 0, REQ_DATA_LEN as usize); blk_req.data.push_str("TEST"); let req_addr = virtio_blk_request(test_state.clone(), alloc.clone(), blk_req, true); - let data_addr = round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap(); + let data_addr = round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap(); let mut data_entries: Vec = Vec::with_capacity(3); data_entries.push(TestVringDescEntry { @@ -813,7 +821,7 @@ fn blk_small_file_511b() { write: false, }); data_entries.push(TestVringDescEntry { - data: data_addr + REQ_DATA_LEN as u64, + data: data_addr + u64::from(REQ_DATA_LEN), len: REQ_STATUS_LEN, write: true, }); @@ -833,7 +841,7 @@ fn blk_small_file_511b() { ); let status_addr = - round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap() + REQ_DATA_LEN as u64; + round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap() + u64::from(REQ_DATA_LEN); let status = test_state.borrow().readb(status_addr); assert_eq!(status, VIRTIO_BLK_S_IOERR); @@ -858,7 +866,7 @@ fn blk_small_file_511b() { ); let status_addr = - round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap() + REQ_DATA_LEN as u64; + round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap() + u64::from(REQ_DATA_LEN); let status = test_state.borrow().readb(status_addr); assert_eq!(status, VIRTIO_BLK_S_IOERR); @@ -996,7 +1004,7 @@ fn blk_iops() { } let status_addr = - round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap() + REQ_DATA_LEN as u64; + round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap() + u64::from(REQ_DATA_LEN); let status = test_state.borrow().readb(status_addr); assert_ne!(status, VIRTIO_BLK_S_OK); @@ -1006,16 +1014,15 @@ fn blk_iops() { if blk.borrow().queue_was_notified(virtqueues[0].clone()) && virtqueues[0].borrow_mut().get_buf(test_state.clone()) + && virtqueues[0].borrow().desc_len.contains_key(&free_head) { - if virtqueues[0].borrow().desc_len.contains_key(&free_head) { - break; - } + break; } assert!(Instant::now() <= time_out); } let status_addr = - round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap() + REQ_DATA_LEN as u64; + round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap() + u64::from(REQ_DATA_LEN); let status = test_state.borrow().readb(status_addr); assert_eq!(status, VIRTIO_BLK_S_OK); @@ -1038,21 +1045,23 @@ fn blk_iops() { /// 1/2/3: success. #[test] fn blk_with_different_aio() { - const BLOCK_DRIVER_CFG: [(ImageType, &str, AioEngine); 6] = [ + let block_driver_cfg: Vec<(ImageType, &str, AioEngine)> = vec![ (ImageType::Raw, "off", AioEngine::Off), (ImageType::Qcow2, "off", AioEngine::Off), (ImageType::Raw, "off", AioEngine::Threads), (ImageType::Qcow2, "off", AioEngine::Threads), + #[cfg(not(target_env = "ohos"))] (ImageType::Raw, "on", AioEngine::Native), + #[cfg(not(target_env = "ohos"))] (ImageType::Raw, "on", AioEngine::IoUring), ]; - for (image_type, direct, aio_engine) in BLOCK_DRIVER_CFG { + for (image_type, direct, aio_engine) in block_driver_cfg { println!("Image type: {:?}", image_type); let image_path = Rc::new(create_img(TEST_IMAGE_SIZE_1M, 1, &image_type)); let device_args = Rc::new(String::from("")); let drive_args = if aio_probe(aio_engine).is_ok() { - Rc::new(format!(",direct={},aio={}", direct, aio_engine.to_string())) + Rc::new(format!(",direct={},aio={}", direct, aio_engine)) } else { continue; }; @@ -1106,6 +1115,7 @@ fn blk_with_different_aio() { /// 3. Destroy device. /// Expect: /// 1/2/3: success. +#[cfg(not(target_env = "ohos"))] #[test] fn blk_aio_io_uring() { for image_type in ImageType::IMAGE_TYPE { @@ -1218,7 +1228,7 @@ fn blk_rw_config() { .init_device(test_state.clone(), alloc.clone(), features, 1); let capacity = blk.borrow().config_readq(0); - assert_eq!(capacity, TEST_IMAGE_SIZE / REQ_DATA_LEN as u64); + assert_eq!(capacity, TEST_IMAGE_SIZE / u64::from(REQ_DATA_LEN)); blk.borrow().config_writeq(0, 1024); let capacity = blk.borrow().config_readq(0); @@ -1589,8 +1599,8 @@ fn blk_parallel_req() { ); for i in 0..4 { - let status_addr = - round_up(req_addr_vec[i] + REQ_ADDR_LEN as u64, 512).unwrap() + REQ_DATA_LEN as u64; + let status_addr = round_up(req_addr_vec[i] + u64::from(REQ_ADDR_LEN), 512).unwrap() + + u64::from(REQ_DATA_LEN); let status = test_state.borrow().readb(status_addr); assert_eq!(status, VIRTIO_BLK_S_OK); } @@ -1625,7 +1635,7 @@ fn blk_exceed_capacity() { .init_device(test_state.clone(), alloc.clone(), features, 1); let capacity = blk.borrow().config_readq(0); - assert_eq!(capacity, TEST_IMAGE_SIZE / REQ_DATA_LEN as u64); + assert_eq!(capacity, TEST_IMAGE_SIZE / u64::from(REQ_DATA_LEN)); let (free_head, req_addr) = add_blk_request( test_state.clone(), @@ -1647,7 +1657,7 @@ fn blk_exceed_capacity() { ); let status_addr = - round_up(req_addr + REQ_ADDR_LEN as u64, 512).unwrap() + REQ_DATA_LEN as u64; + round_up(req_addr + u64::from(REQ_ADDR_LEN), 512).unwrap() + u64::from(REQ_DATA_LEN); let status = test_state.borrow().readb(status_addr); assert_eq!(status, VIRTIO_BLK_S_IOERR); @@ -1759,16 +1769,16 @@ fn blk_feature_discard() { ); if image_type == ImageType::Raw && status == VIRTIO_BLK_S_OK { let image_size = get_disk_size(image_path.clone()); - assert_eq!(image_size, full_disk_size - num_sectors as u64 / 2); + assert_eq!(image_size, full_disk_size - u64::from(num_sectors) / 2); } else if image_type == ImageType::Qcow2 && status == VIRTIO_BLK_S_OK - && (num_sectors as u64 * 512 & CLUSTER_SIZE - 1) == 0 + && ((u64::from(num_sectors) * 512) & (CLUSTER_SIZE - 1)) == 0 { // If the disk format is equal to Qcow2. // the length of the num sectors needs to be aligned with the cluster size, // otherwise the calculated file size is not accurate. let image_size = get_disk_size(image_path.clone()); - let delete_num = (num_sectors as u64 * 512) >> 10; + let delete_num = (u64::from(num_sectors) * 512) >> 10; assert_eq!(image_size, full_disk_size - delete_num); } @@ -1895,7 +1905,7 @@ fn blk_feature_write_zeroes() { test_state.clone(), alloc.clone(), virtqueues[0].clone(), - &req_data.as_bytes().to_vec(), + req_data.as_bytes(), status, true, false, @@ -1929,17 +1939,17 @@ fn blk_feature_write_zeroes() { && (write_zeroes == "unmap" && discard == "unmap" && flags == 1 || len != wz_len) { let image_size = get_disk_size(image_path.clone()); - assert_eq!(image_size, full_disk_size - num_sectors as u64 / 2); + assert_eq!(image_size, full_disk_size - u64::from(num_sectors) / 2); } else if image_type == ImageType::Qcow2 && status == VIRTIO_BLK_S_OK && (write_zeroes == "unmap" && discard == "unmap" && flags == 1 || len != wz_len) - && (num_sectors as u64 * 512 & CLUSTER_SIZE - 1) == 0 + && ((u64::from(num_sectors) * 512) & (CLUSTER_SIZE - 1)) == 0 { // If the disk format is equal to Qcow2. // the length of the num sectors needs to be aligned with the cluster size, // otherwise the calculated file size is not accurate. let image_size = get_disk_size(image_path.clone()); - let delete_num = (num_sectors as u64 * 512) >> 10; + let delete_num = (u64::from(num_sectors) * 512) >> 10; assert_eq!(image_size, full_disk_size - delete_num); } @@ -1972,7 +1982,7 @@ fn blk_snapshot_basic() { .init_device(test_state.clone(), alloc.clone(), features, 1); create_snapshot(test_state.clone(), "drive0", "snap0"); - assert_eq!(check_snapshot(test_state.clone(), "snap0"), true); + assert!(check_snapshot(test_state.clone(), "snap0")); virtio_blk_write( blk.clone(), @@ -1992,7 +2002,7 @@ fn blk_snapshot_basic() { ); delete_snapshot(test_state.clone(), "drive0", "snap0"); - assert_eq!(check_snapshot(test_state.clone(), "snap0"), false); + assert!(!check_snapshot(test_state.clone(), "snap0")); virtio_blk_write( blk.clone(), @@ -2011,13 +2021,7 @@ fn blk_snapshot_basic() { true, ); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - virtqueues, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, virtqueues, image_path); } /// Block device whose backend file has snapshot sends I/O request. @@ -2041,14 +2045,8 @@ fn blk_snapshot_basic2() { .borrow_mut() .init_device(test_state.clone(), alloc.clone(), features, 1); create_snapshot(test_state.clone(), "drive0", "snap0"); - assert_eq!(check_snapshot(test_state.clone(), "snap0"), true); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - virtqueues, - Rc::new("".to_string()), - ); + assert!(check_snapshot(test_state.clone(), "snap0")); + tear_down(blk, test_state, alloc, virtqueues, Rc::new("".to_string())); let device_args = Rc::new(String::from("")); let drive_args = Rc::new(String::from(",direct=false")); @@ -2084,10 +2082,10 @@ fn blk_snapshot_basic2() { ); create_snapshot(test_state.clone(), "drive0", "snap1"); - assert_eq!(check_snapshot(test_state.clone(), "snap1"), true); + assert!(check_snapshot(test_state.clone(), "snap1")); delete_snapshot(test_state.clone(), "drive0", "snap0"); - assert_eq!(check_snapshot(test_state.clone(), "snap0"), false); + assert!(!check_snapshot(test_state.clone(), "snap0")); virtio_blk_write( blk.clone(), @@ -2106,11 +2104,5 @@ fn blk_snapshot_basic2() { true, ); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - virtqueues, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, virtqueues, image_path); } diff --git a/tests/mod_test/tests/fwcfg_test.rs b/tests/mod_test/tests/fwcfg_test.rs index f82083256d070e0d066495be7a372658352258e6..c51535a3cff57a55db3adda2571f77ed5b3f98f6 100644 --- a/tests/mod_test/tests/fwcfg_test.rs +++ b/tests/mod_test/tests/fwcfg_test.rs @@ -34,7 +34,7 @@ fn test_signature() { let mut test_state = test_init(args); let mut read_data: Vec = Vec::with_capacity(4); - let target_data: [u8; 4] = ['Q' as u8, 'E' as u8, 'M' as u8, 'U' as u8]; + let target_data: [u8; 4] = [b'Q', b'E', b'M', b'U']; // Select Signature entry and read it. test_state.fw_cfg_read_bytes(FwCfgEntryType::Signature as u16, &mut read_data, 4); @@ -163,7 +163,7 @@ fn test_filedir_by_dma() { bios_args(&mut args); let test_state = Rc::new(RefCell::new(test_init(args))); let machine = TestStdMachine::new(test_state.clone()); - let allocator = machine.allocator.clone(); + let allocator = machine.allocator; let file_name = "etc/boot-fail-wait"; let mut read_data: Vec = Vec::with_capacity(mem::size_of::()); @@ -207,7 +207,7 @@ fn test_boot_index() { let test_state = Rc::new(RefCell::new(test_init(args))); let machine = TestStdMachine::new(test_state.clone()); - let allocator = machine.allocator.clone(); + let allocator = machine.allocator; let file_name = "bootorder"; let mut read_data: Vec = Vec::with_capacity(dev_path.len()); @@ -240,7 +240,7 @@ fn test_smbios_type0() { let test_state = Rc::new(RefCell::new(test_init(args))); let machine = TestStdMachine::new(test_state.clone()); - let allocator = machine.allocator.clone(); + let allocator = machine.allocator; let anchor_file = "etc/smbios/smbios-anchor"; let tables_file = "etc/smbios/smbios-tables"; @@ -251,12 +251,12 @@ fn test_smbios_type0() { &mut allocator.borrow_mut(), anchor_file, &mut read_data, - 24 as u32, + 24_u32, ); - assert_eq!(anchor_size, 24 as u32); + assert_eq!(anchor_size, 24_u32); assert_eq!(String::from_utf8_lossy(&read_data[..5]), "_SM3_"); - assert_eq!(read_data[6], 24 as u8); + assert_eq!(read_data[6], 24_u8); let talble_len = LittleEndian::read_u32(&read_data[12..]); assert_eq!(talble_len, 372); @@ -304,7 +304,7 @@ fn test_smbios_type1() { let test_state = Rc::new(RefCell::new(test_init(args))); let machine = TestStdMachine::new(test_state.clone()); - let allocator = machine.allocator.clone(); + let allocator = machine.allocator; let anchor_file = "etc/smbios/smbios-anchor"; let tables_file = "etc/smbios/smbios-tables"; @@ -315,12 +315,12 @@ fn test_smbios_type1() { &mut allocator.borrow_mut(), anchor_file, &mut read_data, - 24 as u32, + 24_u32, ); - assert_eq!(anchor_size, 24 as u32); + assert_eq!(anchor_size, 24_u32); assert_eq!(String::from_utf8_lossy(&read_data[..5]), "_SM3_"); - assert_eq!(read_data[6], 24 as u8); + assert_eq!(read_data[6], 24_u8); let talble_len = LittleEndian::read_u32(&read_data[12..]); assert_eq!(talble_len, 414); @@ -342,7 +342,7 @@ fn test_smbios_type1() { "version0" ); assert_eq!(read_table_date[48], 1); - assert_eq!(read_table_date[49], 27 as u8); + assert_eq!(read_table_date[49], 27_u8); let handle1 = LittleEndian::read_u16(&read_table_date[50..]); assert_eq!(handle1, 0x100); @@ -409,7 +409,7 @@ fn test_smbios_type2() { let test_state = Rc::new(RefCell::new(test_init(args))); let machine = TestStdMachine::new(test_state.clone()); - let allocator = machine.allocator.clone(); + let allocator = machine.allocator; let anchor_file = "etc/smbios/smbios-anchor"; let tables_file = "etc/smbios/smbios-tables"; @@ -420,12 +420,12 @@ fn test_smbios_type2() { &mut allocator.borrow_mut(), anchor_file, &mut read_data, - 24 as u32, + 24_u32, ); - assert_eq!(anchor_size, 24 as u32); + assert_eq!(anchor_size, 24_u32); assert_eq!(String::from_utf8_lossy(&read_data[..5]), "_SM3_"); - assert_eq!(read_data[6], 24 as u8); + assert_eq!(read_data[6], 24_u8); let talble_len = LittleEndian::read_u32(&read_data[12..]); let mut read_table_date: Vec = Vec::with_capacity(talble_len as usize); @@ -484,7 +484,7 @@ fn test_smbios_type3() { let test_state = Rc::new(RefCell::new(test_init(args))); let machine = TestStdMachine::new(test_state.clone()); - let allocator = machine.allocator.clone(); + let allocator = machine.allocator; let anchor_file = "etc/smbios/smbios-anchor"; let tables_file = "etc/smbios/smbios-tables"; @@ -495,12 +495,12 @@ fn test_smbios_type3() { &mut allocator.borrow_mut(), anchor_file, &mut read_data, - 24 as u32, + 24_u32, ); - assert_eq!(anchor_size, 24 as u32); + assert_eq!(anchor_size, 24_u32); assert_eq!(String::from_utf8_lossy(&read_data[..5]), "_SM3_"); - assert_eq!(read_data[6], 24 as u8); + assert_eq!(read_data[6], 24_u8); let talble_len = LittleEndian::read_u32(&read_data[12..]); let mut read_table_date: Vec = Vec::with_capacity(talble_len as usize); @@ -547,7 +547,7 @@ fn test_smbios_type4() { let mut args: Vec<&str> = Vec::new(); bios_args(&mut args); - let cpu_args = format!("-smp 8,maxcpus=8,sockets=2,cores=2,threads=2"); + let cpu_args = "-smp 8,maxcpus=8,sockets=2,cores=2,threads=2".to_string(); let mut extra_args = cpu_args.split(' ').collect(); args.append(&mut extra_args); @@ -559,7 +559,7 @@ fn test_smbios_type4() { let test_state = Rc::new(RefCell::new(test_init(args))); let machine = TestStdMachine::new(test_state.clone()); - let allocator = machine.allocator.clone(); + let allocator = machine.allocator; let anchor_file = "etc/smbios/smbios-anchor"; let tables_file = "etc/smbios/smbios-tables"; @@ -570,12 +570,12 @@ fn test_smbios_type4() { &mut allocator.borrow_mut(), anchor_file, &mut read_data, - 24 as u32, + 24_u32, ); - assert_eq!(anchor_size, 24 as u32); + assert_eq!(anchor_size, 24_u32); assert_eq!(String::from_utf8_lossy(&read_data[..5]), "_SM3_"); - assert_eq!(read_data[6], 24 as u8); + assert_eq!(read_data[6], 24_u8); let talble_len = LittleEndian::read_u32(&read_data[12..]); let mut read_table_date: Vec = Vec::with_capacity(talble_len as usize); @@ -631,7 +631,7 @@ fn test_smbios_type17() { let mut args: Vec<&str> = Vec::new(); bios_args(&mut args); - let cpu_args = format!("-smp 8,maxcpus=8,sockets=2,cores=2,threads=2"); + let cpu_args = "-smp 8,maxcpus=8,sockets=2,cores=2,threads=2".to_string(); let mut extra_args = cpu_args.split(' ').collect(); args.append(&mut extra_args); @@ -644,7 +644,7 @@ fn test_smbios_type17() { let test_state = Rc::new(RefCell::new(test_init(args))); let machine = TestStdMachine::new(test_state.clone()); - let allocator = machine.allocator.clone(); + let allocator = machine.allocator; let anchor_file = "etc/smbios/smbios-anchor"; let tables_file = "etc/smbios/smbios-tables"; @@ -655,12 +655,12 @@ fn test_smbios_type17() { &mut allocator.borrow_mut(), anchor_file, &mut read_data, - 24 as u32, + 24_u32, ); - assert_eq!(anchor_size, 24 as u32); + assert_eq!(anchor_size, 24_u32); assert_eq!(String::from_utf8_lossy(&read_data[..5]), "_SM3_"); - assert_eq!(read_data[6], 24 as u8); + assert_eq!(read_data[6], 24_u8); let talble_len = LittleEndian::read_u32(&read_data[12..]); assert_eq!(talble_len, 467); diff --git a/tests/mod_test/tests/memory_test.rs b/tests/mod_test/tests/memory_test.rs index 155ce51a9a2475280d08ce3cd0901ed45430e694..9c553949f0e390d66100bae4cc744fd65c4574a3 100644 --- a/tests/mod_test/tests/memory_test.rs +++ b/tests/mod_test/tests/memory_test.rs @@ -16,6 +16,7 @@ use std::process::Command; use std::rc::Rc; use std::string::String; +use mod_test::utils::support_numa; use serde_json::{json, Value::String as JsonString}; use mod_test::{ @@ -76,7 +77,7 @@ impl MemoryTest { let test_state = Rc::new(RefCell::new(test_init(extra_args))); let machine = TestStdMachine::new_bymem(test_state.clone(), memsize * 1024 * 1024, page_size); - let allocator = machine.allocator.clone(); + let allocator = machine.allocator; MemoryTest { state: test_state, @@ -97,7 +98,7 @@ fn ram_read_write(memory_test: &MemoryTest) { .state .borrow_mut() .memread(addr, str.len() as u64); - assert_eq!(str, String::from_utf8(ret.clone()).unwrap()); + assert_eq!(str, String::from_utf8(ret).unwrap()); memory_test.state.borrow_mut().stop(); } @@ -269,7 +270,7 @@ fn rom_device_region_readwrite() { // Add a dummy rom device by qmp. The function of the device is to multiply the written value by // 2 through the write interface and save it, and read the saved value through the read // interface. - let file = File::create(&ROM_DEV_PATH).unwrap(); + let file = File::create(ROM_DEV_PATH).unwrap(); file.set_len(PAGE_SIZE).unwrap(); let qmp_str = format!( "{{ \"execute\": \"update_region\", @@ -308,7 +309,7 @@ fn rom_device_region_readwrite() { // device. The device can set the write mode to writable according to the device status during // the write operation, or directly return an error indicating that the write is not allowed. // The read operation is the same as that of IO region. - let file = File::create(&ROM_DEV_PATH).unwrap(); + let file = File::create(ROM_DEV_PATH).unwrap(); file.set_len(PAGE_SIZE).unwrap(); let qmp_str = format!( "{{ \"execute\": \"update_region\", @@ -351,7 +352,7 @@ fn ram_device_region_readwrite() { let memory_test = MemoryTest::new(MEM_SIZE, PAGE_SIZE, false, false, None, None); let addr = 0x1_0000_0000; // 4GB - let file = File::create(&RAM_DEV_PATH).unwrap(); + let file = File::create(RAM_DEV_PATH).unwrap(); file.set_len(PAGE_SIZE).unwrap(); let qmp_str = format!( "{{ \"execute\": \"update_region\", @@ -474,6 +475,7 @@ fn prealloc_ram_read_write() { /// 4. Destroy device. /// Expect: /// 1/2/3/4: success. +#[cfg(not(target_env = "ohos"))] #[test] fn hugepage_ram_read_write() { // crate hugetlbfs directory @@ -609,6 +611,10 @@ fn ram_readwrite_exception() { /// 1/2/3/4: success. #[test] fn ram_readwrite_numa() { + if !support_numa() { + return; + } + let mut args: Vec<&str> = Vec::new(); let mut extra_args: Vec<&str> = MACHINE_TYPE_ARG.split(' ').collect(); args.append(&mut extra_args); @@ -649,7 +655,7 @@ fn ram_readwrite_numa() { let ret = test_state .borrow_mut() .memread(start_base, str.len() as u64); - assert_eq!(str, String::from_utf8(ret.clone()).unwrap()); + assert_eq!(str, String::from_utf8(ret).unwrap()); test_state.borrow_mut().stop(); } @@ -665,6 +671,10 @@ fn ram_readwrite_numa() { /// 1/2/3/4: success. #[test] fn ram_readwrite_numa1() { + if !support_numa() { + return; + } + let mut args: Vec<&str> = Vec::new(); let mut extra_args: Vec<&str> = MACHINE_TYPE_ARG.split(' ').collect(); args.append(&mut extra_args); @@ -706,10 +716,10 @@ fn ram_readwrite_numa1() { let ret = test_state .borrow_mut() .memread(start_base, str.len() as u64); - assert_eq!(str, String::from_utf8(ret.clone()).unwrap()); + assert_eq!(str, String::from_utf8(ret).unwrap()); test_state.borrow_mut().qmp("{\"execute\": \"query-mem\"}"); - let file = File::create(&RAM_DEV_PATH).unwrap(); + let file = File::create(RAM_DEV_PATH).unwrap(); file.set_len(PAGE_SIZE).unwrap(); let qmp_str = format!( "{{ \"execute\": \"update_region\", diff --git a/tests/mod_test/tests/net_test.rs b/tests/mod_test/tests/net_test.rs index 3ec79a1eb219df844005c3bdf807d908f79d63c4..34ffe1d36a17282ffee1aae3934b2a35a77d2a33 100644 --- a/tests/mod_test/tests/net_test.rs +++ b/tests/mod_test/tests/net_test.rs @@ -332,7 +332,7 @@ impl ByteCode for VirtioNetHdr {} fn execute_cmd(cmd: String, check: bool) { let args = cmd.split(' ').collect::>(); - if args.len() <= 0 { + if args.is_empty() { return; } @@ -343,7 +343,7 @@ fn execute_cmd(cmd: String, check: bool) { let output = cmd_exe .output() - .expect(format!("Failed to execute {}", cmd).as_str()); + .unwrap_or_else(|_| panic!("Failed to execute {}", cmd)); println!("{:?}", args); if check { assert!(output.status.success()); @@ -361,25 +361,21 @@ fn execute_cmd_checked(cmd: String) { fn create_tap(id: u8, mq: bool) { let br_name = "mst_net_qbr".to_string() + &id.to_string(); let tap_name = "mst_net_qtap".to_string() + &id.to_string(); - execute_cmd_checked("ip link add name ".to_string() + &br_name + &" type bridge".to_string()); + execute_cmd_checked("ip link add name ".to_string() + &br_name + " type bridge"); if mq { - execute_cmd_checked( - "ip tuntap add ".to_string() + &tap_name + &" mode tap multi_queue".to_string(), - ); + execute_cmd_checked("ip tuntap add ".to_string() + &tap_name + " mode tap multi_queue"); } else { - execute_cmd_checked("ip tuntap add ".to_string() + &tap_name + &" mode tap".to_string()); + execute_cmd_checked("ip tuntap add ".to_string() + &tap_name + " mode tap"); } - execute_cmd_checked( - "ip link set ".to_string() + &tap_name + &" master ".to_string() + &br_name, - ); - execute_cmd_checked("ip link set ".to_string() + &br_name + &" up".to_string()); - execute_cmd_checked("ip link set ".to_string() + &tap_name + &" up".to_string()); + execute_cmd_checked("ip link set ".to_string() + &tap_name + " master " + &br_name); + execute_cmd_checked("ip link set ".to_string() + &br_name + " up"); + execute_cmd_checked("ip link set ".to_string() + &tap_name + " up"); execute_cmd_checked( "ip address add ".to_string() + &id.to_string() - + &".1.1.".to_string() + + ".1.1." + &id.to_string() - + &"/24 dev ".to_string() + + "/24 dev " + &br_name, ); } @@ -387,16 +383,14 @@ fn create_tap(id: u8, mq: bool) { fn clear_tap(id: u8, mq: bool) { let br_name = "mst_net_qbr".to_string() + &id.to_string(); let tap_name = "mst_net_qtap".to_string() + &id.to_string(); - execute_cmd_unchecked("ip link set ".to_string() + &tap_name + &" down".to_string()); - execute_cmd_unchecked("ip link set ".to_string() + &br_name + &" down".to_string()); + execute_cmd_unchecked("ip link set ".to_string() + &tap_name + " down"); + execute_cmd_unchecked("ip link set ".to_string() + &br_name + " down"); if mq { - execute_cmd_unchecked( - "ip tuntap del ".to_string() + &tap_name + &" mode tap multi_queue".to_string(), - ); + execute_cmd_unchecked("ip tuntap del ".to_string() + &tap_name + " mode tap multi_queue"); } else { - execute_cmd_unchecked("ip tuntap del ".to_string() + &tap_name + &" mode tap".to_string()); + execute_cmd_unchecked("ip tuntap del ".to_string() + &tap_name + " mode tap"); } - execute_cmd_unchecked("ip link delete ".to_string() + &br_name + &" type bridge".to_string()); + execute_cmd_unchecked("ip link delete ".to_string() + &br_name + " type bridge"); } #[allow(unused)] @@ -451,7 +445,7 @@ pub fn create_net( let test_state = Rc::new(RefCell::new(test_init(extra_args))); let machine = TestStdMachine::new(test_state.clone()); let allocator = machine.allocator.clone(); - let virtio_net = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus.clone()))); + let virtio_net = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus))); virtio_net.borrow_mut().init(pci_slot, pci_fn); (virtio_net, test_state, allocator) @@ -496,7 +490,7 @@ fn tear_down( id: u8, mq: bool, ) { - net.borrow_mut().destroy_device(alloc.clone(), vqs); + net.borrow_mut().destroy_device(alloc, vqs); test_state.borrow_mut().stop(); clear_tap(id, mq); } @@ -513,7 +507,7 @@ fn fill_rx_vq( vq.borrow_mut() .add(test_state.clone(), addr, MAX_PACKET_LEN as u32, true); } - vq.borrow().set_used_event(test_state.clone(), 0); + vq.borrow().set_used_event(test_state, 0); } fn init_net_device( @@ -552,8 +546,8 @@ fn poll_used_ring( let mut idx = test_state .borrow() .readw(vq.borrow().used + offset_of!(VringUsed, idx) as u64); - while start < idx as u64 { - for i in start..idx as u64 { + while start < u64::from(idx) { + for i in start..u64::from(idx) { let len = test_state.borrow().readw( vq.borrow().used + offset_of!(VringUsed, ring) as u64 @@ -567,8 +561,8 @@ fn poll_used_ring( let addr = test_state .borrow() - .readq(vq.borrow().desc + id as u64 * VRING_DESC_SIZE); - let packets = test_state.borrow().memread(addr, len as u64); + .readq(vq.borrow().desc + u64::from(id) * VRING_DESC_SIZE); + let packets = test_state.borrow().memread(addr, u64::from(len)); let src_mac_pos = VIRTIO_NET_HDR_SIZE + ETHERNET_HDR_SIZE + ARP_HDR_SIZE; let dst_mac_pos = src_mac_pos + 10; if arp_request[src_mac_pos..src_mac_pos + MAC_ADDR_LEN] @@ -582,7 +576,7 @@ fn poll_used_ring( } } } - start = idx as u64; + start = u64::from(idx); vq.borrow().set_used_event(test_state.clone(), start as u16); idx = test_state .borrow() @@ -695,14 +689,8 @@ fn send_request( .borrow_mut() .add(test_state.clone(), addr, request.len() as u32, false); net.borrow().virtqueue_notify(vq.clone()); - net.borrow().poll_used_elem( - test_state.clone(), - vq, - free_head, - TIMEOUT_US, - &mut None, - true, - ); + net.borrow() + .poll_used_elem(test_state, vq, free_head, TIMEOUT_US, &mut None, true); } fn send_arp_request( @@ -716,17 +704,11 @@ fn send_arp_request( send_request( net.clone(), test_state.clone(), - alloc.clone(), + alloc, vqs[1].clone(), - &arp_request, - ); - check_arp_mac( - net.clone(), - test_state.clone(), - vqs[0].clone(), - &arp_request, - need_reply, + arp_request, ); + check_arp_mac(net, test_state, vqs[0].clone(), arp_request, need_reply); } fn check_device_status(net: Rc>, status: u8) { @@ -750,7 +732,7 @@ fn check_device_status(net: Rc>, status: u8) { /// 1/2/3: success. #[test] fn virtio_net_rx_tx_test() { - let id = 1 * TEST_MAC_ADDR_NUMS; + let id = TEST_MAC_ADDR_NUMS; let (net, test_state, alloc) = set_up(id, false, 0, false); // Three virtqueues: tx/rx/ctrl. @@ -768,18 +750,11 @@ fn virtio_net_rx_tx_test() { test_state.clone(), alloc.clone(), vqs.clone(), - &arp_request.as_bytes(), + arp_request.as_bytes(), true, ); - tear_down( - net.clone(), - test_state.clone(), - alloc.clone(), - vqs, - id, - false, - ); + tear_down(net, test_state, alloc, vqs, id, false); } /// Send and receive packet test with iothread. @@ -809,18 +784,11 @@ fn virtio_net_rx_tx_test_iothread() { test_state.clone(), alloc.clone(), vqs.clone(), - &arp_request.as_bytes(), + arp_request.as_bytes(), true, ); - tear_down( - net.clone(), - test_state.clone(), - alloc.clone(), - vqs, - id, - false, - ); + tear_down(net, test_state, alloc, vqs, id, false); } /// Test the control mq command. @@ -877,7 +845,7 @@ fn virtio_net_ctrl_mq_test() { class: VIRTIO_NET_CTRL_MQ, cmd, }; - test_state.borrow().memwrite(addr, &ctrl_hdr.as_bytes()); + test_state.borrow().memwrite(addr, ctrl_hdr.as_bytes()); test_state .borrow() .writew(addr + size_of::() as u64, vq_pairs); @@ -925,14 +893,7 @@ fn virtio_net_ctrl_mq_test() { assert_eq!(ack, status); } - tear_down( - net.clone(), - test_state.clone(), - alloc.clone(), - vqs, - id, - true, - ); + tear_down(net, test_state, alloc, vqs, id, true); } /// Write or Read mac address from device config. @@ -958,7 +919,7 @@ fn net_config_mac_rw( /// Virtio net configure is not allowed to change except mac. fn write_net_config_check(net: Rc>, offset: u64, value: u64, size: u8) { - let origin_value = net.borrow().config_readw(offset) as u64; + let origin_value = u64::from(net.borrow().config_readw(offset)); assert_ne!(origin_value, value); match size { 1 => net.borrow().config_writeb(offset, value as u8), @@ -966,7 +927,7 @@ fn write_net_config_check(net: Rc>, offset: u64, value 4 => net.borrow().config_writel(offset, value as u32), _ => (), }; - let value = net.borrow().config_readw(offset) as u64; + let value = u64::from(net.borrow().config_readw(offset)); assert_eq!(origin_value, value); } @@ -1015,38 +976,38 @@ fn virtio_net_write_and_check_config() { write_net_config_check( net.clone(), offset_of!(VirtioNetConfig, status) as u64, - u16::MAX as u64, + u64::from(u16::MAX), 2, ); write_net_config_check( net.clone(), offset_of!(VirtioNetConfig, max_virtqueue_pairs) as u64, - u16::MAX as u64, + u64::from(u16::MAX), 2, ); write_net_config_check( net.clone(), offset_of!(VirtioNetConfig, mtu) as u64, - u16::MAX as u64, + u64::from(u16::MAX), 2, ); write_net_config_check( net.clone(), offset_of!(VirtioNetConfig, speed) as u64, - u32::MAX as u64, + u64::from(u32::MAX), 4, ); write_net_config_check( net.clone(), offset_of!(VirtioNetConfig, duplex) as u64, - u8::MAX as u64, + u64::from(u8::MAX), 1, ); write_net_config_check( net.clone(), size_of:: as u64 + 1, - u8::MAX as u64, + u64::from(u8::MAX), 1, ); @@ -1072,7 +1033,7 @@ fn send_ctrl_vq_request( ) { let ctrl_vq = &vqs[2]; let addr = alloc.borrow_mut().alloc(ctrl_data.len() as u64); - test_state.borrow().memwrite(addr, &ctrl_data); + test_state.borrow().memwrite(addr, ctrl_data); let data_entries: Vec = vec![ TestVringDescEntry { data: addr, @@ -1154,14 +1115,7 @@ fn ctrl_vq_set_mac_table( ctrl_data.len() ); - send_ctrl_vq_request( - net.clone(), - test_state.clone(), - alloc.clone(), - vqs.clone(), - &ctrl_data, - ack, - ); + send_ctrl_vq_request(net, test_state, alloc, vqs, &ctrl_data, ack); } fn ctrl_vq_set_mac_address( @@ -1184,14 +1138,14 @@ fn ctrl_vq_set_mac_address( }; send_ctrl_vq_request( net.clone(), - test_state.clone(), - alloc.clone(), - vqs.clone(), - &ctrl_mac_addr.as_bytes(), + test_state, + alloc, + vqs, + ctrl_mac_addr.as_bytes(), VIRTIO_NET_OK, ); // Check mac address result. - let config_mac = net_config_mac_rw(net.clone(), None); + let config_mac = net_config_mac_rw(net, None); assert_eq!(config_mac, ARP_SOURCE_MAC); } @@ -1231,7 +1185,7 @@ fn virtio_net_ctrl_vlan_test() { test_state.clone(), alloc.clone(), vqs.clone(), - &ctrl_rx_info.as_bytes(), + ctrl_rx_info.as_bytes(), VIRTIO_NET_OK, ); @@ -1250,7 +1204,7 @@ fn virtio_net_ctrl_vlan_test() { test_state.clone(), alloc.clone(), vqs.clone(), - &ctrl_vlan_info.as_bytes(), + ctrl_vlan_info.as_bytes(), ack, ); } @@ -1262,7 +1216,7 @@ fn virtio_net_ctrl_vlan_test() { test_state.clone(), alloc.clone(), vqs.clone(), - &ctrl_vlan_info.as_bytes(), + ctrl_vlan_info.as_bytes(), ack, ); } @@ -1273,7 +1227,7 @@ fn virtio_net_ctrl_vlan_test() { test_state.clone(), alloc.clone(), vqs.clone(), - &ctrl_vlan_info.as_bytes(), + ctrl_vlan_info.as_bytes(), VIRTIO_NET_ERR, ); // Test invalid cmd. @@ -1283,7 +1237,7 @@ fn virtio_net_ctrl_vlan_test() { test_state.clone(), alloc.clone(), vqs.clone(), - &ctrl_vlan_info.as_bytes(), + ctrl_vlan_info.as_bytes(), VIRTIO_NET_ERR, ); // Test invalid vid length. @@ -1303,7 +1257,7 @@ fn virtio_net_ctrl_vlan_test() { test_state.clone(), alloc.clone(), vqs.clone(), - &get_arp_request(id).as_bytes(), + get_arp_request(id).as_bytes(), true, ); send_arp_request( @@ -1311,18 +1265,11 @@ fn virtio_net_ctrl_vlan_test() { test_state.clone(), alloc.clone(), vqs.clone(), - &get_arp_request_vlan(id).as_bytes(), + get_arp_request_vlan(id).as_bytes(), false, ); - tear_down( - net.clone(), - test_state.clone(), - alloc.clone(), - vqs, - id, - false, - ); + tear_down(net, test_state, alloc, vqs, id, false); } /// Test the control mac command. @@ -1406,7 +1353,7 @@ fn virtio_net_ctrl_mac_test() { test_state.clone(), alloc.clone(), vqs.clone(), - &ctrl_rx_info.as_bytes(), + ctrl_rx_info.as_bytes(), VIRTIO_NET_OK, ); @@ -1470,7 +1417,7 @@ fn virtio_net_ctrl_mac_test() { test_state.clone(), alloc.clone(), vqs.clone(), - &arp_request.as_bytes(), + arp_request.as_bytes(), true, ); @@ -1556,7 +1503,7 @@ fn virtio_net_ctrl_rx_test() { test_state.clone(), alloc.clone(), vqs.clone(), - &ctrl_rx_info.as_bytes(), + ctrl_rx_info.as_bytes(), VIRTIO_NET_OK, ); let mut ctrl_rx_info = CtrlRxInfo::new(VIRTIO_NET_CTRL_RX, 0, 0); @@ -1571,7 +1518,7 @@ fn virtio_net_ctrl_rx_test() { alloc.clone(), vqs.clone(), 0, - value as u32, + u32::from(value), VIRTIO_NET_OK, ); arp_request.arp_packet.src_mac[0] += 1; @@ -1602,7 +1549,7 @@ fn virtio_net_ctrl_rx_test() { test_state.clone(), alloc.clone(), vqs.clone(), - &ctrl_rx_info.as_bytes(), + ctrl_rx_info.as_bytes(), ack, ); } @@ -1617,7 +1564,7 @@ fn virtio_net_ctrl_rx_test() { test_state.clone(), alloc.clone(), vqs.clone(), - &arp_request.as_bytes(), + arp_request.as_bytes(), need_reply, ); @@ -1663,7 +1610,7 @@ fn virtio_net_ctrl_abnormal_test() { for i in 0..test_num { let ctrl_vq = &vqs[2]; let addr = alloc.borrow_mut().alloc(ctrl_data.len() as u64); - test_state.borrow().memwrite(addr, &ctrl_data); + test_state.borrow().memwrite(addr, ctrl_data); // ctrl_rx_info.switch: u8 let mut data_len = 1; @@ -1698,14 +1645,7 @@ fn virtio_net_ctrl_abnormal_test() { check_device_status(net.clone(), VIRTIO_CONFIG_S_NEEDS_RESET); } - tear_down( - net.clone(), - test_state.clone(), - alloc.clone(), - vqs, - id, - false, - ); + tear_down(net, test_state, alloc, vqs, id, false); } /// Test the abnormal rx/tx request. @@ -1750,7 +1690,7 @@ fn virtio_net_abnormal_rx_tx_test() { assert_eq!(size, QUEUE_SIZE_NET); for _ in 0..size { let addr = alloc.borrow_mut().alloc(length); - test_state.borrow().memwrite(addr, &request.as_bytes()); + test_state.borrow().memwrite(addr, request.as_bytes()); vqs[1] .borrow_mut() .add(test_state.clone(), addr, length as u32, false); @@ -1775,14 +1715,7 @@ fn virtio_net_abnormal_rx_tx_test() { assert!(time::Instant::now() - start_time < timeout_us); } - tear_down( - net.clone(), - test_state.clone(), - alloc.clone(), - vqs, - id, - false, - ); + tear_down(net, test_state, alloc, vqs, id, false); } /// Test the abnormal rx/tx request 2. @@ -1830,7 +1763,7 @@ fn virtio_net_abnormal_rx_tx_test_2() { let request = get_arp_request(id); let length = request.as_bytes().len() as u64; let addr = alloc.borrow_mut().alloc(length); - test_state.borrow().memwrite(addr, &request.as_bytes()); + test_state.borrow().memwrite(addr, request.as_bytes()); vqs[1] .borrow_mut() .add(test_state.clone(), addr, length as u32, false); @@ -1897,18 +1830,11 @@ fn virtio_net_set_abnormal_feature() { test_state.clone(), alloc.clone(), vqs.clone(), - &arp_request.as_bytes(), + arp_request.as_bytes(), true, ); - tear_down( - net.clone(), - test_state.clone(), - alloc.clone(), - vqs, - id, - false, - ); + tear_down(net, test_state, alloc, vqs, id, false); } /// Send abnormal packet. @@ -1943,7 +1869,7 @@ fn virtio_net_send_abnormal_packet() { test_state.clone(), alloc.clone(), vqs.clone(), - &arp_request.as_bytes(), + arp_request.as_bytes(), false, ); @@ -1961,7 +1887,7 @@ fn virtio_net_send_abnormal_packet() { test_state.clone(), alloc.clone(), vqs[1].clone(), - &data_bytes, + data_bytes, ); } @@ -1984,14 +1910,7 @@ fn virtio_net_send_abnormal_packet() { .qmp("{\"execute\": \"qmp_capabilities\"}"); assert_eq!(*ret.get("return").unwrap(), json!({})); - tear_down( - net.clone(), - test_state.clone(), - alloc.clone(), - vqs, - id, - false, - ); + tear_down(net, test_state, alloc, vqs, id, false); } /// Send and receive packet test with mq. @@ -2022,18 +1941,11 @@ fn virtio_net_rx_tx_mq_test() { test_state.clone(), alloc.clone(), vqs[i as usize * 2 + 1].clone(), - &get_arp_request(id + i as u8 * TEST_MAC_ADDR_NUMS).as_bytes(), + get_arp_request(id + i as u8 * TEST_MAC_ADDR_NUMS).as_bytes(), ); } - tear_down( - net.clone(), - test_state.clone(), - alloc.clone(), - vqs, - id, - true, - ); + tear_down(net, test_state, alloc, vqs, id, true); } /// Test the abnormal rx/tx request 3. @@ -2071,24 +1983,27 @@ fn virtio_net_abnormal_rx_tx_test_3() { let notify_off = net.borrow().pci_dev.io_readw( net.borrow().bar, - net.borrow().common_base as u64 + u64::from(net.borrow().common_base) + offset_of!(VirtioPciCommonCfg, queue_notify_off) as u64, ); - vq.borrow_mut().queue_notify_off = net.borrow().notify_base as u64 - + notify_off as u64 * net.borrow().notify_off_multiplier as u64; + vq.borrow_mut().queue_notify_off = u64::from(net.borrow().notify_base) + + u64::from(notify_off) * u64::from(net.borrow().notify_off_multiplier); net.borrow() - .setup_virtqueue_intr((i + 1) as u16, alloc.clone(), vq.clone()); + .setup_virtqueue_intr(i + 1, alloc.clone(), vq.clone()); vqs.push(vq); } fill_rx_vq(test_state.clone(), alloc.clone(), vqs[0].clone()); - net.borrow().set_driver_ok(); + + // Set driver ok without check. + let status = net.borrow().get_status() | VIRTIO_CONFIG_S_DRIVER_OK; + net.borrow().set_status(status); let request = get_arp_request(id); let length = request.as_bytes().len() as u64; let addr = alloc.borrow_mut().alloc(length); - test_state.borrow().memwrite(addr, &request.as_bytes()); + test_state.borrow().memwrite(addr, request.as_bytes()); vqs[1] .borrow_mut() .add(test_state.clone(), addr, length as u32, false); @@ -2099,12 +2014,5 @@ fn virtio_net_abnormal_rx_tx_test_3() { .readw(vqs[1].borrow().used + offset_of!(VringUsed, idx) as u64); assert_eq!(used_idx, 0); - tear_down( - net.clone(), - test_state.clone(), - alloc.clone(), - vqs, - id, - false, - ); + tear_down(net, test_state, alloc, vqs, id, false); } diff --git a/tests/mod_test/tests/pci_test.rs b/tests/mod_test/tests/pci_test.rs index e23d67b123e3c4291272a5a2c8f1419db6f1ff1f..dcad2ea4aeedfcaac3c081be34433007a6a7ddf8 100644 --- a/tests/mod_test/tests/pci_test.rs +++ b/tests/mod_test/tests/pci_test.rs @@ -74,7 +74,7 @@ fn init_demo_dev(cfg: DemoDev, dev_num: u8) -> (Rc>, Rc = "-D /tmp/oscar.log".split(' ').collect(); demo_dev_args.append(&mut args); - let demo_str = fmt_demo_deves(cfg.clone(), dev_num); + let demo_str = fmt_demo_deves(cfg, dev_num); args = demo_str[..].split(' ').collect(); demo_dev_args.append(&mut args); @@ -82,7 +82,7 @@ fn init_demo_dev(cfg: DemoDev, dev_num: u8) -> (Rc>, Rc Self { - let mut root_port = TestPciDev::new(machine.clone().borrow().pci_bus.clone()); + let mut root_port = TestPciDev::new(machine.borrow().pci_bus.clone()); root_port.set_bus_num(bus_num); root_port.devfn = devfn; assert_eq!(root_port.config_readw(PCI_SUB_CLASS_DEVICE), 0x0604); root_port.enable(); root_port.enable_msix(None); - let root_port_msix = MsixVector::new(0, alloc.clone()); + let root_port_msix = MsixVector::new(0, alloc); root_port.set_msix_vector( root_port_msix.msix_entry, root_port_msix.msix_addr, @@ -169,7 +169,7 @@ fn build_root_port_args(root_port_nums: u8) -> Vec { if multifunc { addr = bus / 8 + 1; func += 1; - func = func % 8; + func %= 8; } else { addr += 1; func = 0; @@ -250,8 +250,8 @@ fn build_hotplug_blk_cmd( let add_blk_command = format!( "{{\"execute\": \"blockdev-add\", \ \"arguments\": {{\"node-name\": \"drive-{}\", \"file\": {{\"driver\": \ - \"file\", \"filename\": \"{}\", \"aio\": \"native\"}}, \ - \"cache\": {{\"direct\": true}}, \"read-only\": false}}}}", + \"file\", \"filename\": \"{}\", \"aio\": \"off\"}}, \ + \"cache\": {{\"direct\": false}}, \"read-only\": false}}}}", hotplug_blk_id, hotplug_image_path ); @@ -287,7 +287,7 @@ fn build_all_device_args( ) -> Vec { let mut device_args: Vec = Vec::new(); let mut root_port_args = build_root_port_args(root_port_nums); - if root_port_args.len() != 0 { + if !root_port_args.is_empty() { device_args.append(&mut root_port_args); } @@ -319,7 +319,7 @@ fn create_blk( pci_fn: u8, ) -> Rc> { let virtio_blk = Rc::new(RefCell::new(TestVirtioPciDev::new( - machine.clone().borrow().pci_bus.clone(), + machine.borrow().pci_bus.clone(), ))); virtio_blk.borrow_mut().pci_dev.set_bus_num(bus_num); virtio_blk.borrow_mut().init(pci_slot, pci_fn); @@ -366,7 +366,7 @@ fn create_machine( .borrow() .pci_bus .borrow() - .pci_auto_bus_scan(root_port_nums as u8); + .pci_auto_bus_scan(root_port_nums); let allocator = machine.borrow().allocator.clone(); (test_state, machine, allocator) @@ -405,15 +405,14 @@ fn tear_down( blk.clone().unwrap().borrow_mut().pci_dev.disable_msix(); } if vqs.is_some() { - blk.clone() - .unwrap() + blk.unwrap() .borrow_mut() - .destroy_device(alloc.clone(), vqs.unwrap()); + .destroy_device(alloc, vqs.unwrap()); } test_state.borrow_mut().stop(); if let Some(img_paths) = image_paths { - img_paths.iter().enumerate().for_each(|(_i, image_path)| { + img_paths.iter().for_each(|image_path| { cleanup_img(image_path.to_string()); }) } @@ -564,12 +563,7 @@ fn validate_blk_io_success( .borrow_mut() .init_device(test_state.clone(), alloc.clone(), features, 1); - validate_std_blk_io( - blk.clone(), - test_state.clone(), - virtqueues.clone(), - alloc.clone(), - ); + validate_std_blk_io(blk.clone(), test_state, virtqueues.clone(), alloc.clone()); blk.borrow_mut().pci_dev.disable_msix(); blk.borrow() @@ -583,14 +577,14 @@ fn simple_blk_io_req( alloc: Rc>, ) -> u32 { let (free_head, _req_addr) = add_blk_request( - test_state.clone(), - alloc.clone(), + test_state, + alloc, virtqueue.clone(), VIRTIO_BLK_T_OUT, 0, false, ); - blk.borrow().virtqueue_notify(virtqueue.clone()); + blk.borrow().virtqueue_notify(virtqueue); free_head } @@ -629,14 +623,7 @@ fn validate_std_blk_io( false, ); - virtio_blk_read( - blk.clone(), - test_state.clone(), - alloc.clone(), - virtqueues[0].clone(), - 0, - false, - ); + virtio_blk_read(blk, test_state, alloc, virtqueues[0].clone(), 0, false); } fn wait_root_port_msix(root_port: Rc>) -> bool { @@ -690,8 +677,8 @@ fn lookup_all_cap_addr(cap_id: u8, pci_dev: TestPciDev) -> Vec { fn get_msix_flag(pci_dev: TestPciDev) -> u16 { let addr = pci_dev.find_capability(PCI_CAP_ID_MSIX, 0); assert_ne!(addr, 0); - let old_value = pci_dev.config_readw(addr + PCI_MSIX_MSG_CTL); - old_value + + pci_dev.config_readw(addr + PCI_MSIX_MSG_CTL) } fn set_msix_enable(pci_dev: TestPciDev) { @@ -728,31 +715,26 @@ fn unmask_msix_global(pci_dev: TestPciDev) { } fn mask_msix_vector(pci_dev: TestPciDev, vector: u16) { - let offset: u64 = pci_dev.msix_table_off + (vector * PCI_MSIX_ENTRY_SIZE) as u64; + let offset: u64 = pci_dev.msix_table_off + u64::from(vector * PCI_MSIX_ENTRY_SIZE); - let vector_mask = pci_dev.io_readl( - pci_dev.msix_table_bar, - offset + PCI_MSIX_ENTRY_VECTOR_CTRL as u64, - ); + let vector_mask = pci_dev.io_readl(pci_dev.msix_table_bar, offset + PCI_MSIX_ENTRY_VECTOR_CTRL); pci_dev.io_writel( pci_dev.msix_table_bar, - offset + PCI_MSIX_ENTRY_VECTOR_CTRL as u64, + offset + PCI_MSIX_ENTRY_VECTOR_CTRL, vector_mask | PCI_MSIX_ENTRY_CTRL_MASKBIT, ); } fn unmask_msix_vector(pci_dev: TestPciDev, vector: u16) { - let offset: u64 = pci_dev.msix_table_off + (vector * PCI_MSIX_ENTRY_SIZE) as u64; + let offset: u64 = pci_dev.msix_table_off + u64::from(vector * PCI_MSIX_ENTRY_SIZE); - let vector_control = pci_dev.io_readl( - pci_dev.msix_table_bar, - offset + PCI_MSIX_ENTRY_VECTOR_CTRL as u64, - ); + let vector_control = + pci_dev.io_readl(pci_dev.msix_table_bar, offset + PCI_MSIX_ENTRY_VECTOR_CTRL); pci_dev.io_writel( pci_dev.msix_table_bar, - offset + PCI_MSIX_ENTRY_VECTOR_CTRL as u64, + offset + PCI_MSIX_ENTRY_VECTOR_CTRL, vector_control & !PCI_MSIX_ENTRY_CTRL_MASKBIT, ); } @@ -771,7 +753,7 @@ fn hotplug_blk( // Hotplug a block device whose bdf is 2:0:0. let (add_blk_command, add_device_command) = - build_hotplug_blk_cmd(hotplug_blk_id, hotplug_image_path.clone(), bus, slot, func); + build_hotplug_blk_cmd(hotplug_blk_id, hotplug_image_path, bus, slot, func); let ret = test_state.borrow().qmp(&add_blk_command); assert_eq!(*ret.get("return").unwrap(), json!({})); @@ -795,7 +777,7 @@ fn hotplug_blk( validate_hotplug(root_port.clone()); handle_isr(root_port.clone()); - power_on_device(root_port.clone()); + power_on_device(root_port); } fn hotunplug_blk( @@ -828,7 +810,7 @@ fn hotunplug_blk( "Wait for interrupt of root port timeout" ); validate_cmd_complete(root_port.clone()); - handle_isr(root_port.clone()); + handle_isr(root_port); // Verify the vendor id for the virtio block device. validate_config_value_2byte( blk.borrow().pci_dev.pci_bus.clone(), @@ -904,13 +886,13 @@ fn test_pci_device_discovery_001() { let (test_state, machine, alloc, image_paths) = set_up(blk_nums, root_port_nums, true, false); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); // Verify the vendor id for non-existent devices. validate_config_value_2byte( blk.borrow().pci_dev.pci_bus.clone(), 1, - 1 << 3 | 0, + 1 << 3, PCI_VENDOR_ID, 0xFFFF, 0xFFFF, @@ -952,7 +934,7 @@ fn test_pci_device_discovery_002() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); // Create a root port whose bdf is 0:2:0. @@ -960,7 +942,7 @@ fn test_pci_device_discovery_002() { machine.clone(), alloc.clone(), 0, - 2 << 3 | 0, + 2 << 3, ))); // Create a block device whose bdf is 1:0:0. @@ -977,12 +959,12 @@ fn test_pci_device_discovery_002() { ); // Hotplug a block device whose id is 0. - hotunplug_blk(test_state.clone(), blk.clone(), root_port_1.clone(), 0); + hotunplug_blk(test_state.clone(), blk, root_port_1, 0); // Hotplug a block device whose id is 1 and bdf is 2:0:0. hotplug_blk( test_state.clone(), - root_port_2.clone(), + root_port_2, &mut image_paths, 1, 2, @@ -991,7 +973,7 @@ fn test_pci_device_discovery_002() { ); // Create a block device whose bdf is 2:0:0. - let blk = create_blk(machine.clone(), 2, 0, 0); + let blk = create_blk(machine, 2, 0, 0); // Verify the vendor id for the virtio block device hotplugged. validate_config_value_2byte( blk.borrow().pci_dev.pci_bus.clone(), @@ -1015,10 +997,10 @@ fn test_pci_device_discovery_003() { // Create a root port whose bdf is 0:1:0. let root_port = Rc::new(RefCell::new(RootPort::new( - machine.clone(), + machine, alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); // Verify the vendor id for the virtio block device hotplugged. @@ -1037,7 +1019,7 @@ fn test_pci_device_discovery_003() { // Hotplug a block device whose bdf is 1:0:0. let (add_blk_command, add_device_command) = - build_hotplug_blk_cmd(blk_id, hotplug_image_path.clone(), 1, 0, 0); + build_hotplug_blk_cmd(blk_id, hotplug_image_path, 1, 0, 0); let ret = test_state.borrow().qmp(&add_blk_command); assert_eq!(*ret.get("return").unwrap(), json!({})); let ret = test_state.borrow().qmp(&add_device_command); @@ -1069,12 +1051,12 @@ fn test_pci_device_discovery_004() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); let blk_id = 0; let hotplug_image_path = create_img(TEST_IMAGE_SIZE, 1, &ImageType::Raw); - image_paths.push(hotplug_image_path.clone()); + image_paths.push(hotplug_image_path); // Hotplug a block device whose id is 0 and bdf is 1:0:0. hotplug_blk( @@ -1088,10 +1070,10 @@ fn test_pci_device_discovery_004() { ); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); // Hotunplug the virtio block device whose id is 0. - hotunplug_blk(test_state.clone(), blk.clone(), root_port.clone(), blk_id); + hotunplug_blk(test_state.clone(), blk, root_port, blk_id); tear_down(None, test_state, alloc, None, Some(image_paths)); } @@ -1104,7 +1086,7 @@ fn test_pci_type0_config() { let (test_state, machine, alloc, image_paths) = set_up(root_port_nums, blk_nums, true, false); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); // Verify that the vendor id of type0 device is read-only. validate_config_perm_2byte( @@ -1249,7 +1231,7 @@ fn test_pci_type1_config() { let (test_state, machine, alloc, image_paths) = set_up(root_port_nums, blk_nums, true, false); // Create a root port whose bdf is 0:1:0. - let root_port = RootPort::new(machine.clone(), alloc.clone(), 0, 1 << 3 | 0); + let root_port = RootPort::new(machine, alloc.clone(), 0, 1 << 3); assert_eq!(root_port.rp_dev.config_readb(PCI_PRIMARY_BUS), 0); assert_ne!(root_port.rp_dev.config_readb(PCI_SECONDARY_BUS), 0); @@ -1265,16 +1247,16 @@ fn test_pci_type1_reset() { let (test_state, machine, alloc, image_paths) = set_up(root_port_nums, blk_nums, true, false); // Create a root port whose bdf is 0:1:0. - let root_port = RootPort::new(machine.clone(), alloc.clone(), 0, 1 << 3 | 0); + let root_port = RootPort::new(machine, alloc.clone(), 0, 1 << 3); let command = root_port.rp_dev.config_readw(PCI_COMMAND); - let cmd_memory = command & PCI_COMMAND_MEMORY as u16; + let cmd_memory = command & u16::from(PCI_COMMAND_MEMORY); // Bitwise inversion of memory space enable. let write_cmd = if cmd_memory != 0 { - command & !PCI_COMMAND_MEMORY as u16 + command & u16::from(!PCI_COMMAND_MEMORY) } else { - command | PCI_COMMAND_MEMORY as u16 + command | u16::from(PCI_COMMAND_MEMORY) }; root_port.rp_dev.config_writew(PCI_COMMAND, write_cmd); let old_command = root_port.rp_dev.config_readw(PCI_COMMAND); @@ -1303,9 +1285,8 @@ fn test_out_boundary_config_access() { let (test_state, machine, alloc, image_paths) = set_up(root_port_nums, blk_nums, true, false); let devfn = 1 << 3 | 1; - let addr = machine.borrow().pci_bus.borrow().ecam_alloc_ptr - + ((0 as u32) << 20 | (devfn as u32) << 12 | 0 as u32) as u64 - - 1; + let addr = + machine.borrow().pci_bus.borrow().ecam_alloc_ptr + u64::from((devfn as u32) << 12) - 1; let write_value = u16::max_value(); let buf = write_value.to_le_bytes(); @@ -1326,14 +1307,14 @@ fn test_out_size_config_access() { let (test_state, machine, alloc, image_paths) = set_up(root_port_nums, blk_nums, true, false); // Create a root port whose bdf is 0:1:0. - let root_port = RootPort::new(machine.clone(), alloc.clone(), 0, 1 << 3 | 0); + let root_port = RootPort::new(machine, alloc.clone(), 0, 1 << 3); let vendor_device_id = root_port.rp_dev.config_readl(PCI_VENDOR_ID); let command_status = root_port.rp_dev.config_readl(PCI_COMMAND); let value = root_port.rp_dev.config_readq(0); assert_ne!( value, - (vendor_device_id as u64) << 32 | command_status as u64 + u64::from(vendor_device_id) << 32 | u64::from(command_status) ); tear_down(None, test_state, alloc, None, Some(image_paths)); @@ -1347,7 +1328,7 @@ fn test_out_boundary_msix_access() { let (test_state, machine, alloc, image_paths) = set_up(root_port_nums, blk_nums, true, false); // Create a root port whose bdf is 0:1:0. - let root_port = RootPort::new(machine.clone(), alloc.clone(), 0, 1 << 3 | 0); + let root_port = RootPort::new(machine, alloc.clone(), 0, 1 << 3); // Out-of-bounds access to the msix table. let write_value = u32::max_value(); @@ -1377,7 +1358,7 @@ fn test_repeat_io_map_bar() { let (test_state, machine, alloc, image_paths) = set_up(root_port_nums, blk_nums, true, false); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); let vqs = blk.borrow_mut().init_device( test_state.clone(), @@ -1412,7 +1393,7 @@ fn test_pci_type0_msix_config() { let root_port_nums = 0; let (test_state, machine, alloc, image_paths) = set_up(root_port_nums, blk_nums, false, false); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 0, 1, 0); + let blk = create_blk(machine, 0, 1, 0); // Verify that there is only one msix capability addr of the type0 pci device. let blk_cap_msix_addrs = lookup_all_cap_addr(PCI_CAP_ID_MSIX, blk.borrow().pci_dev.clone()); @@ -1478,7 +1459,7 @@ fn test_pci_msix_global_ctl() { let (test_state, machine, alloc, image_paths) = set_up(root_port_nums, blk_nums, true, false); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); let vqs = blk.borrow_mut().init_device( test_state.clone(), alloc.clone(), @@ -1552,7 +1533,7 @@ fn test_pci_msix_local_ctl() { let (test_state, machine, alloc, image_paths) = set_up(root_port_nums, blk_nums, true, false); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); let vqs = blk.borrow_mut().init_device( test_state.clone(), alloc.clone(), @@ -1592,7 +1573,7 @@ fn test_alloc_abnormal_vector() { let (test_state, machine, alloc, image_paths) = set_up(root_port_nums, blk_nums, true, false); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); // 1. Init device. blk.borrow_mut().reset(); @@ -1608,7 +1589,7 @@ fn test_alloc_abnormal_vector() { let virtqueue = blk .borrow() - .setup_virtqueue(test_state.clone(), alloc.clone(), 0 as u16); + .setup_virtqueue(test_state.clone(), alloc.clone(), 0_u16); blk.borrow() .setup_virtqueue_intr((queue_num + 2) as u16, alloc.clone(), virtqueue.clone()); blk.borrow().set_driver_ok(); @@ -1633,7 +1614,7 @@ fn test_intx_basic() { let root_port_nums = 1; let (test_state, machine, alloc, image_paths) = set_up(root_port_nums, blk_nums, true, false); - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); // 1. Init device. blk.borrow_mut().reset(); @@ -1643,11 +1624,11 @@ fn test_intx_basic() { blk.borrow_mut().set_features_ok(); set_msix_disable(blk.borrow().pci_dev.clone()); - blk.borrow_mut().pci_dev.set_intx_irq_num(1 as u8); + blk.borrow_mut().pci_dev.set_intx_irq_num(1_u8); let virtqueue = blk .borrow() - .setup_virtqueue(test_state.clone(), alloc.clone(), 0 as u16); + .setup_virtqueue(test_state.clone(), alloc.clone(), 0_u16); blk.borrow().set_driver_ok(); let free_head = simple_blk_io_req( @@ -1692,7 +1673,7 @@ fn test_intx_disable() { let root_port_nums = 1; let (test_state, machine, alloc, image_paths) = set_up(root_port_nums, blk_nums, true, false); - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); // 1. Init device. blk.borrow_mut().reset(); @@ -1702,11 +1683,11 @@ fn test_intx_disable() { blk.borrow_mut().set_features_ok(); set_msix_disable(blk.borrow().pci_dev.clone()); - blk.borrow_mut().pci_dev.set_intx_irq_num(1 as u8); + blk.borrow_mut().pci_dev.set_intx_irq_num(1_u8); let virtqueue = blk .borrow() - .setup_virtqueue(test_state.clone(), alloc.clone(), 0 as u16); + .setup_virtqueue(test_state.clone(), alloc.clone(), 0_u16); blk.borrow().set_driver_ok(); // Disable INTx. @@ -1784,22 +1765,14 @@ fn test_pci_hotplug_001() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); // Hotplug a block device whose id is 1 and bdf is 1:0:0. - hotplug_blk( - test_state.clone(), - root_port.clone(), - &mut image_paths, - 0, - 1, - 0, - 0, - ); + hotplug_blk(test_state.clone(), root_port, &mut image_paths, 0, 1, 0, 0); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); let vqs = blk.borrow_mut().init_device( test_state.clone(), alloc.clone(), @@ -1825,7 +1798,7 @@ fn test_pci_hotplug_002() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); // Create a root port whose bdf is 0:2:0. @@ -1833,13 +1806,13 @@ fn test_pci_hotplug_002() { machine.clone(), alloc.clone(), 0, - 2 << 3 | 0, + 2 << 3, ))); // Hotplug a block device whose id is 1 and bdf is 1:0:0. hotplug_blk( test_state.clone(), - root_port_1.clone(), + root_port_1, &mut image_paths, 1, 1, @@ -1851,17 +1824,17 @@ fn test_pci_hotplug_002() { // Hotplug a block device whose id is 2 and bdf is 2:0:0. hotplug_blk( test_state.clone(), - root_port_2.clone(), + root_port_2, &mut image_paths, 2, 2, 0, 0, ); - let blk_2 = create_blk(machine.clone(), 2, 0, 0); + let blk_2 = create_blk(machine, 2, 0, 0); - validate_blk_io_success(blk_1.clone(), test_state.clone(), alloc.clone()); - validate_blk_io_success(blk_2.clone(), test_state.clone(), alloc.clone()); + validate_blk_io_success(blk_1, test_state.clone(), alloc.clone()); + validate_blk_io_success(blk_2, test_state.clone(), alloc.clone()); tear_down(None, test_state, alloc, None, Some(image_paths)); } @@ -1879,7 +1852,7 @@ fn test_pci_hotplug_003() { // Hotplug a block device whose id is 0, bdf is 1:1:0. let (add_blk_command, add_device_command) = - build_hotplug_blk_cmd(0, hotplug_image_path.clone(), 1, 1, 0); + build_hotplug_blk_cmd(0, hotplug_image_path, 1, 1, 0); let ret = test_state.borrow().qmp(&add_blk_command); assert_eq!(*ret.get("return").unwrap(), json!({})); // Verify that hotpluging the device in non-zero slot will fail. @@ -1902,7 +1875,7 @@ fn test_pci_hotplug_004() { let hotplug_blk_id = 1; let (add_blk_command, add_device_command) = - build_hotplug_blk_cmd(hotplug_blk_id, hotplug_image_path.clone(), 0, 1, 0); + build_hotplug_blk_cmd(hotplug_blk_id, hotplug_image_path, 0, 1, 0); let ret = test_state.borrow().qmp(&add_blk_command); assert_eq!(*ret.get("return").unwrap(), json!({})); let ret = test_state.borrow().qmp(&add_device_command); @@ -1920,7 +1893,7 @@ fn test_pci_hotplug_005() { set_up(root_port_nums, blk_nums, true, false); let hotplug_image_path = create_img(TEST_IMAGE_SIZE, 1, &ImageType::Raw); - image_paths.push(hotplug_image_path.clone()); + image_paths.push(hotplug_image_path); let hotplug_blk_id = 0; let (add_blk_command, add_device_command) = @@ -1971,11 +1944,11 @@ fn test_pci_hotplug_007() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); set_msix_disable(root_port.borrow().rp_dev.clone()); - root_port.borrow_mut().rp_dev.set_intx_irq_num(1 as u8); + root_port.borrow_mut().rp_dev.set_intx_irq_num(1_u8); // Hotplug a block device whose id is 1 and bdf is 1:0:0. let bus = 1; @@ -1987,7 +1960,7 @@ fn test_pci_hotplug_007() { // Hotplug a block device whose bdf is 1:0:0. let (add_blk_command, add_device_command) = - build_hotplug_blk_cmd(hotplug_blk_id, hotplug_image_path.clone(), bus, slot, 0); + build_hotplug_blk_cmd(hotplug_blk_id, hotplug_image_path, bus, slot, 0); let ret = test_state.borrow().qmp(&add_blk_command); assert_eq!(*ret.get("return").unwrap(), json!({})); @@ -2011,10 +1984,10 @@ fn test_pci_hotplug_007() { validate_hotplug(root_port.clone()); handle_isr(root_port.clone()); - power_on_device(root_port.clone()); + power_on_device(root_port); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); let vqs = blk.borrow_mut().init_device( test_state.clone(), alloc.clone(), @@ -2039,14 +2012,14 @@ fn test_pci_hotunplug_001() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); // Hotunplug the block device whose bdf is 1:0:0. - hotunplug_blk(test_state.clone(), blk.clone(), root_port.clone(), 0); + hotunplug_blk(test_state.clone(), blk, root_port, 0); tear_down(None, test_state, alloc, None, Some(image_paths)); } @@ -2080,11 +2053,11 @@ fn test_pci_hotunplug_003() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); let unplug_blk_id = 0; // Hotunplug the block device attaching the root port. @@ -2135,7 +2108,7 @@ fn test_pci_hotunplug_003() { assert!(!(*ret.get("error").unwrap()).is_null()); // The block device will be unplugged when indicator of power and slot is power off. - power_off_device(root_port.clone()); + power_off_device(root_port); test_state.borrow().wait_qmp_event(); // Verify the vendor id for the virtio block device. @@ -2163,7 +2136,7 @@ fn test_pci_hotunplug_004() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); // Create root port whose bdf is 0:2:0. @@ -2171,14 +2144,14 @@ fn test_pci_hotunplug_004() { machine.clone(), alloc.clone(), 0, - 2 << 3 | 0, + 2 << 3, ))); // Create a block device whose bdf is 1:0:0. let blk_1 = create_blk(machine.clone(), 1, 0, 0); // Create a block device whose bdf is 2:0:0. - let blk_2 = create_blk(machine.clone(), 2, 0, 0); + let blk_2 = create_blk(machine, 2, 0, 0); let unplug_blk_id = 0; let (delete_device_command, delete_blk_command_1) = build_hotunplug_blk_cmd(unplug_blk_id); @@ -2200,10 +2173,10 @@ fn test_pci_hotunplug_004() { "Wait for interrupt of root port timeout" ); - power_off_device(root_port_1.clone()); + power_off_device(root_port_1); test_state.borrow().wait_qmp_event(); - power_off_device(root_port_2.clone()); + power_off_device(root_port_2); test_state.borrow().wait_qmp_event(); // The block device will be unplugged when indicator of power and slot is power off. @@ -2248,13 +2221,13 @@ fn test_pci_hotunplug_005() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); // Hotplug the block device whose id is 0 and bdf is 1:0:0. - hotunplug_blk(test_state.clone(), blk.clone(), root_port.clone(), 0); + hotunplug_blk(test_state.clone(), blk, root_port, 0); let (delete_device_command, _delete_blk_command) = build_hotunplug_blk_cmd(0); let ret = test_state.borrow().qmp(&delete_device_command); @@ -2291,11 +2264,11 @@ fn test_pci_hotunplug_007() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); let unplug_blk_id = 0; // Hotunplug the block device attaching the root port. @@ -2309,7 +2282,7 @@ fn test_pci_hotunplug_007() { // The block device will be unplugged when indicator of power and slot is power off. power_off_device(root_port.clone()); // Trigger a 2nd write to PIC/PCC, which will be ignored by the device, and causes no harm. - power_off_device(root_port.clone()); + power_off_device(root_port); test_state.borrow().wait_qmp_event(); @@ -2338,14 +2311,14 @@ fn test_pci_hotunplug_008() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); set_msix_disable(root_port.borrow().rp_dev.clone()); - root_port.borrow_mut().rp_dev.set_intx_irq_num(1 as u8); + root_port.borrow_mut().rp_dev.set_intx_irq_num(1_u8); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); // Hotunplug the block device whose bdf is 1:0:0. let hotunplug_blk_id = 0; @@ -2372,7 +2345,7 @@ fn test_pci_hotunplug_008() { "Wait for interrupt of root port timeout" ); validate_cmd_complete(root_port.clone()); - handle_isr(root_port.clone()); + handle_isr(root_port); // Verify the vendor id for the virtio block device. validate_config_value_2byte( blk.borrow().pci_dev.pci_bus.clone(), @@ -2399,7 +2372,7 @@ fn test_pci_hotplug_combine_001() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); let hotplug_blk_id = 0; @@ -2408,7 +2381,7 @@ fn test_pci_hotplug_combine_001() { // Hotplug a block device whose bdf is 1:0:0. let (add_blk_command, add_device_command) = - build_hotplug_blk_cmd(hotplug_blk_id, hotplug_image_path.clone(), 1, 0, 0); + build_hotplug_blk_cmd(hotplug_blk_id, hotplug_image_path, 1, 0, 0); let ret = test_state.borrow().qmp(&add_blk_command); assert_eq!(*ret.get("return").unwrap(), json!({})); let ret = test_state.borrow().qmp(&add_device_command); @@ -2431,7 +2404,7 @@ fn test_pci_hotplug_combine_001() { 1, ); // Verify that the function of the block device is normal. - validate_std_blk_io(blk.clone(), test_state.clone(), vqs.clone(), alloc.clone()); + validate_std_blk_io(blk.clone(), test_state.clone(), vqs, alloc.clone()); let (delete_device_command, delete_blk_command) = build_hotunplug_blk_cmd(hotplug_blk_id); let ret = test_state.borrow().qmp(&delete_device_command); @@ -2464,7 +2437,7 @@ fn test_pci_hotplug_combine_001() { // Hotplug a block device whose bdf is 1:0:0. let (add_blk_command, add_device_command) = - build_hotplug_blk_cmd(hotplug_blk_id, hotplug_image_path.clone(), 1, 0, 0); + build_hotplug_blk_cmd(hotplug_blk_id, hotplug_image_path, 1, 0, 0); let ret = test_state.borrow().qmp(&add_blk_command); assert_eq!(*ret.get("return").unwrap(), json!({})); let ret = test_state.borrow().qmp(&add_device_command); @@ -2488,7 +2461,7 @@ fn test_pci_hotplug_combine_001() { 0xFFFF, ); - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); let vqs = blk.borrow_mut().init_device( test_state.clone(), alloc.clone(), @@ -2496,7 +2469,7 @@ fn test_pci_hotplug_combine_001() { 1, ); // Verify that the function of the block device is normal. - validate_std_blk_io(blk.clone(), test_state.clone(), vqs.clone(), alloc.clone()); + validate_std_blk_io(blk.clone(), test_state.clone(), vqs, alloc.clone()); let (delete_device_command, delete_blk_command) = build_hotunplug_blk_cmd(hotplug_blk_id); let ret = test_state.borrow().qmp(&delete_device_command); @@ -2506,7 +2479,7 @@ fn test_pci_hotplug_combine_001() { ); handle_isr(root_port.clone()); - power_off_device(root_port.clone()); + power_off_device(root_port); assert_eq!(*ret.get("return").unwrap(), json!({})); test_state.borrow().wait_qmp_event(); @@ -2539,7 +2512,7 @@ fn test_pci_hotplug_combine_002() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); let hotplug_blk_id = 0; @@ -2557,7 +2530,7 @@ fn test_pci_hotplug_combine_002() { power_indicator_off(root_port.clone()); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); validate_blk_io_success(blk.clone(), test_state.clone(), alloc.clone()); @@ -2597,7 +2570,7 @@ fn test_pci_hotplug_combine_002() { ); handle_isr(root_port.clone()); - power_off_device(root_port.clone()); + power_off_device(root_port); test_state.borrow().wait_qmp_event(); let ret = test_state.borrow().qmp(&delete_blk_command); @@ -2629,7 +2602,7 @@ fn test_pci_hotplug_combine_003() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); let hotunplug_blk_id = 0; @@ -2649,25 +2622,25 @@ fn test_pci_hotplug_combine_003() { // Hotplug a block device whose bdf is 1:0:0. let (add_blk_command, add_device_command) = - build_hotplug_blk_cmd(hotunplug_blk_id, hotplug_image_path.clone(), 1, 0, 0); + build_hotplug_blk_cmd(hotunplug_blk_id, hotplug_image_path, 1, 0, 0); let ret = test_state.borrow().qmp(&add_blk_command); assert_eq!(*ret.get("return").unwrap(), json!({})); let ret = test_state.borrow().qmp(&add_device_command); assert!(!(*ret.get("error").unwrap()).is_null()); - power_off_device(root_port.clone()); + power_off_device(root_port); test_state.borrow().wait_qmp_event(); let hotplug_image_path = create_img(TEST_IMAGE_SIZE, 1, &ImageType::Raw); image_paths.push(hotplug_image_path.clone()); // Hotplug a block device whose bdf is 1:0:0. let (add_blk_command, add_device_command) = - build_hotplug_blk_cmd(hotunplug_blk_id, hotplug_image_path.clone(), 1, 0, 0); + build_hotplug_blk_cmd(hotunplug_blk_id, hotplug_image_path, 1, 0, 0); let ret = test_state.borrow().qmp(&add_blk_command); assert!(!(*ret.get("error").unwrap()).is_null()); let ret = test_state.borrow().qmp(&add_device_command); assert_eq!(*ret.get("return").unwrap(), json!({})); - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); let vqs = blk.borrow_mut().init_device( test_state.clone(), alloc.clone(), @@ -2695,7 +2668,7 @@ fn test_pci_root_port_exp_cap() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); let cap_exp_addr = root_port.borrow().rp_dev.find_capability(PCI_CAP_ID_EXP, 0); @@ -2779,7 +2752,7 @@ fn test_pci_root_port_exp_cap() { 0, ); // Create a block device whose bdf is 1:0:0. - let blk = create_blk(machine.clone(), 1, 0, 0); + let blk = create_blk(machine, 1, 0, 0); let nlw_mask = PCI_EXP_LNKSTA_NLW; let negotiated_link_width = (root_port.borrow().rp_dev.pci_bus.borrow().config_readw( @@ -2849,12 +2822,7 @@ fn test_pci_root_port_exp_cap() { ); // Hotplug the block device whose id is 0 and bdf is 1:0:0. - hotunplug_blk( - test_state.clone(), - blk.clone(), - root_port.clone(), - hotplug_blk_id, - ); + hotunplug_blk(test_state.clone(), blk, root_port.clone(), hotplug_blk_id); let dllla_mask = PCI_EXP_LNKSTA_DLLLA; validate_config_value_2byte( @@ -2953,7 +2921,7 @@ fn test_pci_combine_001() { // set memory enabled = 0 let mut val = dev_locked.config_readw(PCI_COMMAND); - val &= !(PCI_COMMAND_MEMORY as u16); + val &= !u16::from(PCI_COMMAND_MEMORY); dev_locked.config_writew(PCI_COMMAND, val); // mmio r/w stops working. @@ -2962,7 +2930,7 @@ fn test_pci_combine_001() { assert_ne!(out, 10); // set memory enabled = 1 - val |= PCI_COMMAND_MEMORY as u16; + val |= u16::from(PCI_COMMAND_MEMORY); dev_locked.config_writew(PCI_COMMAND, val); // mmio r/w gets back to work. @@ -2987,7 +2955,7 @@ fn test_pci_combine_002() { machine.clone(), alloc.clone(), 0, - 1 << 3 | 0, + 1 << 3, ))); let blk = Rc::new(RefCell::new(TestVirtioPciDev::new( machine.borrow().pci_bus.clone(), @@ -3007,7 +2975,7 @@ fn test_pci_combine_002() { wait_root_port_msix(root_port.clone()), "Wait for interrupt of root port timeout" ); - power_off_device(root_port.clone()); + power_off_device(root_port); // r/w mmio during hotunplug test_state.borrow().writeb(bar_addr, 5); diff --git a/tests/mod_test/tests/pvpanic_test.rs b/tests/mod_test/tests/pvpanic_test.rs index 044515966d1e5ed8af641eb86ddf4af36c96da68..fd084959909e76ebdb1cc0b5e134b33128f88c2f 100644 --- a/tests/mod_test/tests/pvpanic_test.rs +++ b/tests/mod_test/tests/pvpanic_test.rs @@ -15,11 +15,11 @@ use std::fs; use std::path::Path; use std::rc::Rc; +use devices::misc::pvpanic::{PVPANIC_CRASHLOADED, PVPANIC_PANICKED}; use devices::pci::config::{ PCI_CLASS_SYSTEM_OTHER, PCI_DEVICE_ID_REDHAT_PVPANIC, PCI_SUBDEVICE_ID_QEMU, PCI_VENDOR_ID_REDHAT, PCI_VENDOR_ID_REDHAT_QUMRANET, }; -use machine_manager::config::{PVPANIC_CRASHLOADED, PVPANIC_PANICKED}; use mod_test::{ libdriver::{machine::TestStdMachine, pci::*}, libtest::{test_init, TestState, MACHINE_TYPE_ARG}, @@ -59,14 +59,14 @@ impl PvPanicDevCfg { test_machine_args.append(&mut args); } - let pvpanic_str = fmt_pvpanic_deves(self.clone()); + let pvpanic_str = fmt_pvpanic_deves(*self); args = pvpanic_str[..].split(' ').collect(); test_machine_args.append(&mut args); let test_state = Rc::new(RefCell::new(test_init(test_machine_args))); let machine = Rc::new(RefCell::new(TestStdMachine::new(test_state.clone()))); - let mut pvpanic_pci_dev = TestPciDev::new(machine.clone().borrow().pci_bus.clone()); + let mut pvpanic_pci_dev = TestPciDev::new(machine.borrow().pci_bus.clone()); let devfn = self.addr << 3; pvpanic_pci_dev.devfn = devfn; diff --git a/tests/mod_test/tests/rng_test.rs b/tests/mod_test/tests/rng_test.rs index c7a982212be57705eb7327cf3edbe53d89350b32..9296ebb084b4578a32b1558d4cd974b2adf5be38 100644 --- a/tests/mod_test/tests/rng_test.rs +++ b/tests/mod_test/tests/rng_test.rs @@ -82,7 +82,7 @@ fn virtio_rng_read_batch( .kick_virtqueue(test_state.clone(), virtqueue.clone()); rng.borrow().poll_used_elem( test_state.clone(), - virtqueue.clone(), + virtqueue, free_head, TIMEOUT_US, &mut len, @@ -90,7 +90,7 @@ fn virtio_rng_read_batch( ); assert!(len.unwrap() >= 1); - assert!(len.unwrap() as u64 <= bytes); + assert!(u64::from(len.unwrap()) <= bytes); test_state.borrow().memread(req_addr, RNG_DATA_BYTES) } @@ -123,7 +123,7 @@ fn virtio_rng_read_chained( .kick_virtqueue(test_state.clone(), virtqueue.clone()); rng.borrow().poll_used_elem( test_state.clone(), - virtqueue.clone(), + virtqueue, free_head, TIMEOUT_US, &mut len, @@ -131,7 +131,7 @@ fn virtio_rng_read_chained( ); assert!(len.unwrap() >= 1); - assert!(len.unwrap() as u64 <= bytes * DEFAULT_RNG_REQS); + assert!(u64::from(len.unwrap()) <= bytes * DEFAULT_RNG_REQS); test_state.borrow().memread(req_addr, RNG_DATA_BYTES) } @@ -142,7 +142,7 @@ fn tear_down( alloc: Rc>, vqs: Vec>>, ) { - rng.borrow_mut().destroy_device(alloc.clone(), vqs); + rng.borrow_mut().destroy_device(alloc, vqs); test_state.borrow_mut().stop(); } @@ -186,7 +186,7 @@ fn rng_read() { ); assert!(random_num_check(data)); - tear_down(rng.clone(), test_state.clone(), alloc.clone(), virtqueues); + tear_down(rng, test_state, alloc, virtqueues); } /// Rng device batch read random numbers function test. @@ -229,7 +229,7 @@ fn rng_read_batch() { ); assert!(random_num_check(data)); - tear_down(rng.clone(), test_state.clone(), alloc.clone(), virtqueues); + tear_down(rng, test_state, alloc, virtqueues); } /// Rng device rate limit random numbers reading test. @@ -300,7 +300,7 @@ fn rng_limited_rate() { RNG_DATA_BYTES ))); - tear_down(rng.clone(), test_state.clone(), alloc.clone(), virtqueues); + tear_down(rng, test_state, alloc, virtqueues); } /// Rng device read a large number of random numbers test. @@ -344,7 +344,7 @@ fn rng_read_with_max() { ); assert!(random_num_check(data)); - tear_down(rng.clone(), test_state.clone(), alloc.clone(), virtqueues); + tear_down(rng, test_state, alloc, virtqueues); } /// Rng device read/write config space. @@ -376,5 +376,5 @@ fn rng_rw_config() { let config = rng.borrow().config_readq(0); assert_ne!(config, 0xff); - tear_down(rng.clone(), test_state.clone(), alloc.clone(), virtqueues); + tear_down(rng, test_state, alloc, virtqueues); } diff --git a/tests/mod_test/tests/scream_test.rs b/tests/mod_test/tests/scream_test.rs index 6a7d76c0139bf4228c94e8f718d9d84767a8073b..d7669fe09b012f8c511b84badb101b45e893760c 100644 --- a/tests/mod_test/tests/scream_test.rs +++ b/tests/mod_test/tests/scream_test.rs @@ -22,7 +22,10 @@ use std::{ use core::time; -use devices::misc::scream::{ShmemHeader, ShmemStreamFmt, ShmemStreamHeader, SCREAM_MAGIC}; +use devices::misc::scream::{ + audio_demo::INITIAL_VOLUME_VAL, ShmemHeader, ShmemStreamFmt, ShmemStreamHeader, + IVSHMEM_BAR0_STATUS, IVSHMEM_BAR0_VOLUME, SCREAM_MAGIC, STATUS_PLAY_BIT, STATUS_START_BIT, +}; use mod_test::{ libdriver::{ivshmem::TestIvshmemDev, machine::TestStdMachine}, libtest::{test_init, TestState, MACHINE_TYPE_ARG}, @@ -235,13 +238,16 @@ fn scream_playback_basic_test() { thread::sleep(time::Duration::from_millis(1000)); play_header_init(&mut ivshmem.borrow_mut()); + ivshmem + .borrow_mut() + .writel_reg(IVSHMEM_BAR0_STATUS, STATUS_PLAY_BIT | STATUS_START_BIT); thread::sleep(time::Duration::from_millis(POLL_DELAY_MS)); // write one audio chunk for i in 0..AUDIO_CHUNK_SIZE { ivshmem.borrow_mut().writeb( - PLAY_DADA_OFFSET + (AUDIO_CHUNK_SIZE + i) as u64, + PLAY_DADA_OFFSET + u64::from(AUDIO_CHUNK_SIZE + i), AUDIO_DEFAULT_DATA[i as usize], ); } @@ -257,9 +263,10 @@ fn scream_playback_basic_test() { // When four consecutive frames of data are written, only the last two frames of data can be // read. for i in 0..AUDIO_CHUNK_SIZE { - ivshmem - .borrow_mut() - .writeb(PLAY_DADA_OFFSET + i as u64, AUDIO_DEFAULT_DATA[i as usize]); + ivshmem.borrow_mut().writeb( + PLAY_DADA_OFFSET + u64::from(i), + AUDIO_DEFAULT_DATA[i as usize], + ); } // update play header chunk_idx @@ -319,6 +326,38 @@ fn scream_playback_basic_test() { scream_tmp_clear(playback_path, record_path); } +/// scream device volume synchronization. +/// TestStep: +/// 1. Init scream device. +/// 2. Check volume's initial value. +/// 3. Set volume and read back to check. +/// 4. Stop VM. +/// Expect: +/// 1/2/3/4: success. +#[test] +fn scream_volume_sync_test() { + let pci_slot = 0x1; + let (playback_path, record_path) = get_audio_file_name(); + audio_data_init(playback_path.clone(), record_path.clone()); + let (ivshmem, test_state) = set_up( + IVSHMEM_DEFAULT_SIZE, + pci_slot, + playback_path.clone(), + record_path.clone(), + ); + ivshmem.borrow_mut().init(pci_slot); + + let init_val = ivshmem.borrow_mut().readl_reg(IVSHMEM_BAR0_VOLUME); + assert_eq!(init_val, INITIAL_VOLUME_VAL); + + ivshmem.borrow_mut().writel_reg(IVSHMEM_BAR0_VOLUME, 0xff); + let second_val = ivshmem.borrow_mut().readl_reg(IVSHMEM_BAR0_VOLUME); + assert_eq!(second_val, 0xff); + + test_state.borrow_mut().stop(); + scream_tmp_clear(playback_path, record_path); +} + /// scream device record audio. /// TestStep: /// 1. Init scream device and start recording. @@ -341,6 +380,9 @@ fn scream_record_basic_test() { ivshmem.borrow_mut().init(pci_slot); record_header_init(&mut ivshmem.borrow_mut()); + ivshmem + .borrow_mut() + .writel_reg(IVSHMEM_BAR0_STATUS, STATUS_START_BIT); let mut cnt = 0; let mut chunk_idx = 0; @@ -374,7 +416,10 @@ fn scream_record_basic_test() { let offset = RECORD_BASE + offset_of!(ShmemStreamHeader, chunk_idx) as u64; chunk_idx = ivshmem.borrow_mut().readw(offset); - assert_eq!(chunk_idx as u32, AUDIO_CHUNK_CNT % (AUDIO_CHUNK_CNT - 1)); + assert_eq!( + u32::from(chunk_idx), + AUDIO_CHUNK_CNT % (AUDIO_CHUNK_CNT - 1) + ); let audio_data = ivshmem.borrow_mut().readl(RECORD_DATA_OFFSET); let mut check_data = 0; @@ -498,7 +543,7 @@ fn scream_exception_002() { // write one audio chunk for i in 0..AUDIO_CHUNK_SIZE { ivshmem.borrow_mut().writeb( - PLAY_DADA_OFFSET + (AUDIO_CHUNK_SIZE + i) as u64, + PLAY_DADA_OFFSET + u64::from(AUDIO_CHUNK_SIZE + i), AUDIO_DEFAULT_DATA[i as usize], ); } diff --git a/tests/mod_test/tests/scsi_test.rs b/tests/mod_test/tests/scsi_test.rs index aecb4b808040cc45a20385184895c185481a234c..acbdf17a284e34e8640a8a886cc9a7b408ecbbef 100644 --- a/tests/mod_test/tests/scsi_test.rs +++ b/tests/mod_test/tests/scsi_test.rs @@ -27,6 +27,7 @@ use mod_test::libdriver::virtio::{ use mod_test::libdriver::virtio_pci_modern::TestVirtioPciDev; use mod_test::libtest::{test_init, TestState, MACHINE_TYPE_ARG}; use mod_test::utils::{cleanup_img, create_img, ImageType, TEST_IMAGE_SIZE}; +#[cfg(not(target_env = "ohos"))] use util::aio::{aio_probe, AioEngine}; use util::byte_code::ByteCode; use util::offset_of; @@ -133,15 +134,11 @@ impl VirtioScsiTest { use_iothread: iothread, }; - let readonly = if scsi_type == ScsiDeviceType::ScsiHd { - false - } else { - true - }; + let readonly = scsi_type != ScsiDeviceType::ScsiHd; let scsi_devices: Vec = vec![ScsiDeviceConfig { cntlr_id: 0, device_type: scsi_type, - image_path: image_path.clone(), + image_path, target, lun, read_only: readonly, @@ -201,7 +198,7 @@ impl VirtioScsiTest { if let Some(data) = data_out { let out_len = data.len() as u32; let out_bytes = data.as_bytes().to_vec(); - let out_addr = self.alloc.borrow_mut().alloc(out_len as u64); + let out_addr = self.alloc.borrow_mut().alloc(u64::from(out_len)); self.state.borrow().memwrite(out_addr, out_bytes.as_slice()); data_entries.push(TestVringDescEntry { data: out_addr, @@ -215,7 +212,7 @@ impl VirtioScsiTest { let resp_addr = self .alloc .borrow_mut() - .alloc(cmdresp_len + data_in_len as u64); + .alloc(cmdresp_len + u64::from(data_in_len)); let resp_bytes = resp.as_bytes(); self.state.borrow().memwrite(resp_addr, resp_bytes); @@ -257,13 +254,13 @@ impl VirtioScsiTest { size_of::(), ) }; - *resp = slice[0].clone(); + *resp = slice[0]; if data_in_len > 0 { data_in.append( self.state .borrow() - .memread(resp_addr + cmdresp_len, data_in_len as u64) + .memread(resp_addr + cmdresp_len, u64::from(data_in_len)) .as_mut(), ); } @@ -439,8 +436,8 @@ impl TestVirtioScsiCmdReq { let mut target_lun = [0_u8; 8]; target_lun[0] = 1; target_lun[1] = target; - target_lun[2] = (lun >> 8) as u8 & 0xff; - target_lun[3] = lun as u8 & 0xff; + target_lun[2] = (lun >> 8) as u8; + target_lun[3] = lun as u8; req.lun = target_lun; req.cdb = cdb; @@ -505,6 +502,7 @@ impl std::fmt::Display for ScsiDeviceType { } } +#[allow(dead_code)] #[derive(Clone, Debug, Copy)] enum TestAioType { AioOff = 0, @@ -623,7 +621,7 @@ fn scsi_test_init( let machine = TestStdMachine::new(test_state.clone()); let allocator = machine.allocator.clone(); - let virtio_scsi = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus.clone()))); + let virtio_scsi = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus))); virtio_scsi.borrow_mut().init(pci_slot, pci_fn); (virtio_scsi, test_state, allocator) @@ -665,7 +663,7 @@ fn scsi_hd_basic_test() { for i in 0..32 { // Test 1 Result: Only response 0 for target == 31. Otherwise response // VIRTIO_SCSI_S_BAD_TARGET. - let expect_result = if i == target as u16 { + let expect_result = if i == u16::from(target) { VIRTIO_SCSI_S_OK } else { VIRTIO_SCSI_S_BAD_TARGET @@ -675,7 +673,7 @@ fn scsi_hd_basic_test() { target: i as u8, lun: 0, data_out: None, - data_in_length: INQUIRY_DATA_LEN as u32, + data_in_length: u32::from(INQUIRY_DATA_LEN), expect_response: expect_result, expect_status: GOOD, expect_result_data: None, @@ -705,7 +703,7 @@ fn scsi_hd_basic_test() { target, lun, data_out: None, - data_in_length: REPORT_LUNS_DATA_LEN as u32, + data_in_length: u32::from(REPORT_LUNS_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -740,7 +738,7 @@ fn scsi_hd_basic_test() { target, lun, data_out: None, - data_in_length: READ_CAPACITY_10_DATA_LEN as u32, + data_in_length: u32::from(READ_CAPACITY_10_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: None, @@ -753,8 +751,12 @@ fn scsi_hd_basic_test() { // Bytes[4-7]: Logical Block Length In Bytes. // Total size = (last logical block address + 1) * block length. assert_eq!( - (u32::from_be_bytes(data_in.as_ref().unwrap()[0..4].try_into().unwrap()) as u64 + 1) - * (u32::from_be_bytes(data_in.as_ref().unwrap()[4..8].try_into().unwrap()) as u64), + (u64::from(u32::from_be_bytes( + data_in.as_ref().unwrap()[0..4].try_into().unwrap() + )) + 1) + * u64::from(u32::from_be_bytes( + data_in.as_ref().unwrap()[4..8].try_into().unwrap() + )), TEST_IMAGE_SIZE ); @@ -793,7 +795,7 @@ fn scsi_hd_basic_test() { target, lun, data_out: None, - data_in_length: MODE_SENSE_PAGE_CACHE_LEN_DATA_LEN as u32, + data_in_length: u32::from(MODE_SENSE_PAGE_CACHE_LEN_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -838,7 +840,7 @@ fn scsi_hd_basic_test() { target, lun, data_out: None, - data_in_length: MODE_SENSE_PAGE_ALL_DATA_LEN as u32, + data_in_length: u32::from(MODE_SENSE_PAGE_ALL_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -861,7 +863,7 @@ fn scsi_hd_basic_test() { target, lun, data_out: None, - data_in_length: INQUIRY_DATA_LEN as u32, + data_in_length: u32::from(INQUIRY_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: None, @@ -885,7 +887,7 @@ fn scsi_hd_basic_test() { target, lun, data_out: None, - data_in_length: INQUIRY_SUPPORTED_VPD_PAGES_DATA_LEN as u32, + data_in_length: u32::from(INQUIRY_SUPPORTED_VPD_PAGES_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -906,7 +908,7 @@ fn scsi_hd_basic_test() { target, lun, data_out: None, - data_in_length: INQUIRY_UNIT_SERIAL_NUMBER_DATA_LEN as u32, + data_in_length: u32::from(INQUIRY_UNIT_SERIAL_NUMBER_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: None, @@ -931,7 +933,7 @@ fn scsi_hd_basic_test() { target, lun, data_out: None, - data_in_length: INQUIRY_DEVICE_IDENTIFICATION_DATA_LEN as u32, + data_in_length: u32::from(INQUIRY_DEVICE_IDENTIFICATION_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: None, @@ -953,7 +955,7 @@ fn scsi_hd_basic_test() { target, lun, data_out: None, - data_in_length: INQUIRY_BLOCK_LIMITS_DATA_LEN as u32, + data_in_length: u32::from(INQUIRY_BLOCK_LIMITS_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: None, @@ -981,7 +983,7 @@ fn scsi_hd_basic_test() { target, lun, data_out: None, - data_in_length: INQUIRY_BLOCK_DEVICE_CHARACTERISTICS_DATA_LEN as u32, + data_in_length: u32::from(INQUIRY_BLOCK_DEVICE_CHARACTERISTICS_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -1010,7 +1012,7 @@ fn scsi_hd_basic_test() { target, lun, data_out: None, - data_in_length: INQUIRY_LOGICAL_BLOCK_PROVISIONING_DATA_LEN as u32, + data_in_length: u32::from(INQUIRY_LOGICAL_BLOCK_PROVISIONING_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -1031,7 +1033,7 @@ fn scsi_hd_basic_test() { target, lun, data_out: None, - data_in_length: INQUIRY_REFERRALS_DATA_LEN as u32, + data_in_length: u32::from(INQUIRY_REFERRALS_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: CHECK_CONDITION, expect_result_data: None, @@ -1110,7 +1112,7 @@ fn scsi_cd_basic_test() { target, lun, data_out: None, - data_in_length: MODE_SENSE_LEN_DATA_LEN as u32, + data_in_length: u32::from(MODE_SENSE_LEN_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -1129,7 +1131,7 @@ fn scsi_cd_basic_test() { target, lun, data_out: None, - data_in_length: TEST_SCSI_SENSE_LEN as u32, + data_in_length: TEST_SCSI_SENSE_LEN, expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: None, @@ -1172,7 +1174,7 @@ fn scsi_cd_basic_test() { target, lun, data_out: None, - data_in_length: READ_TOC_DATA_LEN as u32, + data_in_length: u32::from(READ_TOC_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -1205,7 +1207,7 @@ fn scsi_cd_basic_test() { target, lun, data_out: None, - data_in_length: READ_TOC_MSF_DATA_LEN as u32, + data_in_length: u32::from(READ_TOC_MSF_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -1229,7 +1231,7 @@ fn scsi_cd_basic_test() { target, lun, data_out: None, - data_in_length: READ_TOC_FORMAT_DATA_LEN as u32, + data_in_length: u32::from(READ_TOC_FORMAT_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -1270,7 +1272,7 @@ fn scsi_cd_basic_test() { target, lun, data_out: None, - data_in_length: READ_DISC_INFORMATION_DATA_LEN as u32, + data_in_length: u32::from(READ_DISC_INFORMATION_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -1339,7 +1341,7 @@ fn scsi_cd_basic_test() { target, lun, data_out: None, - data_in_length: GET_CONFIGURATION_DATA_LEN as u32, + data_in_length: u32::from(GET_CONFIGURATION_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -1375,7 +1377,7 @@ fn scsi_cd_basic_test() { target, lun, data_out: None, - data_in_length: GET_EVENT_STATUS_NOTIFICATION_DATA_LEN as u32, + data_in_length: u32::from(GET_EVENT_STATUS_NOTIFICATION_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -1423,7 +1425,7 @@ fn scsi_target_cdb_test() { target, lun: req_lun, data_out: None, - data_in_length: REPORT_LUNS_DATA_LEN as u32, + data_in_length: u32::from(REPORT_LUNS_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -1447,7 +1449,7 @@ fn scsi_target_cdb_test() { target, lun: req_lun, data_out: None, - data_in_length: INQUIRY_TARGET_DATA_LEN as u32, + data_in_length: u32::from(INQUIRY_TARGET_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -1480,7 +1482,7 @@ fn scsi_target_cdb_test() { target, lun: 0, data_out: None, - data_in_length: INQUIRY_TARGET_DATA_LEN as u32, + data_in_length: u32::from(INQUIRY_TARGET_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: Some(expect_result_vec), @@ -1503,7 +1505,7 @@ fn scsi_target_cdb_test() { target, lun: req_lun, data_out: None, - data_in_length: INQUIRY_TARGET_DATA_LEN as u32, + data_in_length: u32::from(INQUIRY_TARGET_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: GOOD, expect_result_data: None, @@ -1524,7 +1526,7 @@ fn scsi_target_cdb_test() { target, lun: 0, data_out: None, - data_in_length: INQUIRY_TARGET_DATA_LEN as u32, + data_in_length: u32::from(INQUIRY_TARGET_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: CHECK_CONDITION, expect_result_data: None, @@ -1545,7 +1547,7 @@ fn scsi_target_cdb_test() { target, lun: req_lun, data_out: None, - data_in_length: INQUIRY_TARGET_DATA_LEN as u32, + data_in_length: u32::from(INQUIRY_TARGET_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: CHECK_CONDITION, expect_result_data: None, @@ -1619,7 +1621,7 @@ fn scsi_target_cdb_test() { target, lun: req_lun, data_out: None, - data_in_length: READ_CAPACITY_10_DATA_LEN as u32, + data_in_length: u32::from(READ_CAPACITY_10_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: CHECK_CONDITION, expect_result_data: None, @@ -1865,13 +1867,14 @@ fn aio_model_test() { let mut lun = 0x2; let mut device_vec: Vec = Vec::new(); + #[cfg(not(target_env = "ohos"))] if aio_probe(AioEngine::IoUring).is_ok() { // Scsi Disk 1. AIO io_uring. Direct false. let image_path = Rc::new(create_img(TEST_IMAGE_SIZE, 0, &ImageType::Raw)); device_vec.push(ScsiDeviceConfig { cntlr_id: 0, device_type: ScsiDeviceType::ScsiHd, - image_path: image_path.clone(), + image_path, target, lun, read_only: false, @@ -1886,7 +1889,7 @@ fn aio_model_test() { device_vec.push(ScsiDeviceConfig { cntlr_id: 0, device_type: ScsiDeviceType::ScsiHd, - image_path: image_path.clone(), + image_path, target, lun, read_only: false, @@ -1905,7 +1908,7 @@ fn aio_model_test() { device_vec.push(ScsiDeviceConfig { cntlr_id: 0, device_type: ScsiDeviceType::ScsiHd, - image_path: image_path.clone(), + image_path, target, lun, read_only: false, @@ -1916,6 +1919,7 @@ fn aio_model_test() { // Scsi Disk 5. AIO native. Direct false. This is not allowed. // Stratovirt will report "native aio type should be used with direct on" + #[cfg(not(target_env = "ohos"))] if aio_probe(AioEngine::Native).is_ok() { // Scsi Disk 6. AIO native. Direct true. lun += 1; @@ -1923,7 +1927,7 @@ fn aio_model_test() { device_vec.push(ScsiDeviceConfig { cntlr_id: 0, device_type: ScsiDeviceType::ScsiHd, - image_path: image_path.clone(), + image_path, target, lun, read_only: false, @@ -2246,7 +2250,7 @@ fn send_cd_command_to_hd_test() { target, lun, data_out: None, - data_in_length: MODE_SENSE_LEN_DATA_LEN as u32, + data_in_length: u32::from(MODE_SENSE_LEN_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: CHECK_CONDITION, expect_result_data: None, @@ -2265,7 +2269,7 @@ fn send_cd_command_to_hd_test() { target, lun, data_out: None, - data_in_length: READ_DISC_INFORMATION_DATA_LEN as u32, + data_in_length: u32::from(READ_DISC_INFORMATION_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: CHECK_CONDITION, expect_result_data: None, @@ -2283,7 +2287,7 @@ fn send_cd_command_to_hd_test() { target, lun, data_out: None, - data_in_length: GET_CONFIGURATION_DATA_LEN as u32, + data_in_length: u32::from(GET_CONFIGURATION_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: CHECK_CONDITION, expect_result_data: None, @@ -2305,7 +2309,7 @@ fn send_cd_command_to_hd_test() { target, lun, data_out: None, - data_in_length: GET_EVENT_STATUS_NOTIFICATION_DATA_LEN as u32, + data_in_length: u32::from(GET_EVENT_STATUS_NOTIFICATION_DATA_LEN), expect_response: VIRTIO_SCSI_S_OK, expect_status: CHECK_CONDITION, expect_result_data: None, @@ -2332,7 +2336,7 @@ fn send_cd_command_to_hd_test() { fn wrong_io_test() { let target = 0xff; let lun = 0xff; - let size = 1 * 1024; // Disk size: 1K. + let size = 1024; // Disk size: 1K. let mut vst = VirtioScsiTest::testcase_start_with_config(ScsiDeviceType::ScsiHd, target, lun, size, true); diff --git a/tests/mod_test/tests/serial_test.rs b/tests/mod_test/tests/serial_test.rs index ac46ce5164c9c5271b1fa8e7d3667f7c32b87307..8009f8f2043a1cab8be2a4b5264b5ca4d6d89c33 100644 --- a/tests/mod_test/tests/serial_test.rs +++ b/tests/mod_test/tests/serial_test.rs @@ -111,7 +111,7 @@ impl SerialTest { fn get_pty_path(&mut self) -> String { let ret = self.state.borrow().qmp("{\"execute\": \"query-chardev\"}"); - if (*ret.get("return").unwrap()).as_array().unwrap().len() != 0 + if !(*ret.get("return").unwrap()).as_array().unwrap().is_empty() && (*ret.get("return").unwrap())[0].get("filename").is_some() { let filename = (*ret.get("return").unwrap())[0] @@ -120,9 +120,9 @@ impl SerialTest { .to_string() .replace('"', ""); let mut file_path: Vec<&str> = filename.split("pty:").collect(); - return file_path.pop().unwrap().to_string(); + file_path.pop().unwrap().to_string() } else { - return String::from(""); + String::from("") } } @@ -155,7 +155,7 @@ impl SerialTest { self.serial .borrow() - .kick_virtqueue(self.state.clone(), queue.clone()); + .kick_virtqueue(self.state.clone(), queue); (addr, free_head) } @@ -214,7 +214,8 @@ impl SerialTest { // Port Ready. for port in self.ports.clone().iter() { - let ready_msg = VirtioConsoleControl::new(*port.0 as u32, VIRTIO_CONSOLE_PORT_READY, 1); + let ready_msg = + VirtioConsoleControl::new(u32::from(*port.0), VIRTIO_CONSOLE_PORT_READY, 1); self.out_control_event(ready_msg); // If it's a console port. @@ -258,7 +259,7 @@ impl SerialTest { // driver -> device: port open. let open_msg: VirtioConsoleControl = - VirtioConsoleControl::new(*port.0 as u32, VIRTIO_CONSOLE_PORT_OPEN, 1); + VirtioConsoleControl::new(u32::from(*port.0), VIRTIO_CONSOLE_PORT_OPEN, 1); self.out_control_event(open_msg); } } @@ -314,14 +315,14 @@ impl SerialTest { server: _, nowait: _, } => { - stream = self.connect_socket_host(&path); + stream = self.connect_socket_host(path); } } // Connect Guest. // driver -> device: port open. let open_msg: VirtioConsoleControl = - VirtioConsoleControl::new(port.nr as u32, VIRTIO_CONSOLE_PORT_OPEN, 1); + VirtioConsoleControl::new(u32::from(port.nr), VIRTIO_CONSOLE_PORT_OPEN, 1); self.out_control_event(open_msg); // IO: Guest -> Host. @@ -354,9 +355,9 @@ impl SerialTest { let result = match port.chardev_type { ChardevType::Pty => { let output = self.connect_pty_host(true); - output.unwrap().write(&test_data.as_bytes()) + output.unwrap().write(test_data.as_bytes()) } - _ => stream.as_ref().unwrap().write(&test_data.as_bytes()), + _ => stream.as_ref().unwrap().write(test_data.as_bytes()), }; match result { Ok(_num) => { @@ -461,7 +462,7 @@ fn create_serial(ports_config: Vec, pci_slot: u8, pci_fn: u8) -> Ser fn verify_output_data(test_state: Rc>, addr: u64, len: u32, test_data: &String) { let mut data_buf: Vec = Vec::with_capacity(len as usize); - data_buf.append(test_state.borrow().memread(addr, len as u64).as_mut()); + data_buf.append(test_state.borrow().memread(addr, u64::from(len)).as_mut()); let data = String::from_utf8(data_buf).unwrap(); assert_eq!(data, *test_data); } @@ -598,7 +599,7 @@ fn virtserialport_socket_basic() { nowait: true, }; let port = PortConfig { - chardev_type: socket.clone(), + chardev_type: socket, nr: 1, is_console: false, }; @@ -656,18 +657,19 @@ fn virtconsole_pty_err_out_control_msg() { }; let pci_slot = 0x04; let pci_fn = 0x0; - let mut st = create_serial(vec![port.clone()], pci_slot, pci_fn); + let mut st = create_serial(vec![port], pci_slot, pci_fn); st.serial_init(); // Error out control msg which has invalid event. Just discard this invalid msg. Nothing // happened. - let invalid_event_msg = VirtioConsoleControl::new(nr as u32, VIRTIO_CONSOLE_PORT_NAME, 1); + let invalid_event_msg = VirtioConsoleControl::new(u32::from(nr), VIRTIO_CONSOLE_PORT_NAME, 1); st.out_control_event(invalid_event_msg); // Error out control msg which has non-existed port id. Just discard this invalid msg. Nothing // happened. - let invalid_event_msg = VirtioConsoleControl::new((nr + 5) as u32, VIRTIO_CONSOLE_PORT_OPEN, 1); + let invalid_event_msg = + VirtioConsoleControl::new(u32::from(nr + 5), VIRTIO_CONSOLE_PORT_OPEN, 1); st.out_control_event(invalid_event_msg); // Error out control msg which size is illegal. @@ -705,7 +707,7 @@ fn virtconsole_pty_invalid_in_control_buffer() { }; let pci_slot = 0x04; let pci_fn = 0x0; - let mut st = create_serial(vec![port.clone()], pci_slot, pci_fn); + let mut st = create_serial(vec![port], pci_slot, pci_fn); // Init virtqueues. st.virtqueue_setup(DEFAULT_SERIAL_VIRTQUEUES); @@ -774,7 +776,7 @@ fn virtserialport_socket_not_connect() { nowait: true, }; let port = PortConfig { - chardev_type: socket.clone(), + chardev_type: socket, nr, is_console: false, }; diff --git a/tests/mod_test/tests/usb_camera_test.rs b/tests/mod_test/tests/usb_camera_test.rs index b519e1ca53075b2c35fbd1e9015f13fc9c623af6..98ac301a0840b8699c368471168a7272fb3570f9 100644 --- a/tests/mod_test/tests/usb_camera_test.rs +++ b/tests/mod_test/tests/usb_camera_test.rs @@ -155,7 +155,7 @@ fn check_multi_frames( slot_id, VS_ENDPOINT_ID, frame_len, - UVC_HEADER_LEN as u32, + u32::from(UVC_HEADER_LEN), max_payload, ); for buf in &payload_list { @@ -234,7 +234,7 @@ fn qmp_plug_camera(test_state: &Rc>, id: &str, camdev: &str) let test_state = test_state.borrow_mut(); let cmd = r#"{"execute": "device_add", "arguments": {"id": "ID", "driver": "usb-camera", "cameradev": "CAMDEV"}}"#; let cmd = cmd.replace("ID", id); - let cmd = cmd.replace("CAMDEV", &camdev); + let cmd = cmd.replace("CAMDEV", camdev); test_state.qmp(&cmd) } @@ -314,7 +314,7 @@ fn test_xhci_camera_invalid_frame_len() { slot_id, VS_ENDPOINT_ID, len as u32, - UVC_HEADER_LEN as u32, + u32::from(UVC_HEADER_LEN), cur.dwMaxPayloadTransferSize, ); for item in payload_list { @@ -553,13 +553,19 @@ fn test_xhci_camera_hotplug_invalid() { .with_config("auto_run", true) .build(); + #[cfg(not(target_env = "ohos"))] qmp_cameradev_add(&test_state, "camdev0", "v4l2", "/tmp/not-existed"); + #[cfg(target_env = "ohos")] + qmp_cameradev_add(&test_state, "camdev0", "ohcamera", "InvalidNum"); // Invalid cameradev. let value = qmp_plug_camera(&test_state, "usbcam0", "camdev0"); let desc = value["error"]["desc"].as_str().unwrap().to_string(); + #[cfg(not(target_env = "ohos"))] assert_eq!(desc, "Failed to open v4l2 backend /tmp/not-existed."); + #[cfg(target_env = "ohos")] + assert_eq!(desc, "OH Camera: failed to init cameras"); // Invalid device id. - let value = qmp_unplug_camera(&test_state.clone(), "usbcam0"); + let value = qmp_unplug_camera(&test_state, "usbcam0"); let desc = value["error"]["desc"].as_str().unwrap().to_string(); assert_eq!(desc, "Failed to detach device: id usbcam0 not found"); // Invalid cameradev id. diff --git a/tests/mod_test/tests/usb_storage_test.rs b/tests/mod_test/tests/usb_storage_test.rs index 472d91ef27c90b77cd77186976086c8faea97d5e..38a45a7c12f9c1105a0564d3f0923983d84ab27c 100644 --- a/tests/mod_test/tests/usb_storage_test.rs +++ b/tests/mod_test/tests/usb_storage_test.rs @@ -83,7 +83,7 @@ fn cbw_phase( } let mut iovecs = Vec::new(); - let ptr = guest_allocator.alloc(CBW_SIZE as u64); + let ptr = guest_allocator.alloc(u64::from(CBW_SIZE)); xhci.mem_write(ptr, &cbw_buf); let iovec = TestIovec::new(ptr, len as usize, false); @@ -104,10 +104,10 @@ fn data_phase( ) { let mut iovecs = Vec::new(); let ptr = guest_allocator.alloc(buf.len() as u64); - let iovec = TestIovec::new(ptr, buf.len() as usize, false); + let iovec = TestIovec::new(ptr, buf.len(), false); if !to_host { - xhci.mem_write(ptr, &buf); + xhci.mem_write(ptr, buf); } iovecs.push(iovec); @@ -142,7 +142,7 @@ fn csw_phase( sig_check: bool, ) -> u64 { let mut iovecs = Vec::new(); - let ptr = guest_allocator.alloc(len as u64); + let ptr = guest_allocator.alloc(u64::from(len)); let iovec = TestIovec::new(ptr, len as usize, false); iovecs.push(iovec); @@ -335,7 +335,7 @@ fn usb_storage_functional_get_max_lun() { xhci.doorbell_write(slot_id, CONTROL_ENDPOINT_ID); let evt = xhci.fetch_event(PRIMARY_INTERRUPTER_ID).unwrap(); assert_eq!(evt.ccode, TRBCCode::Success as u32); - let buf = xhci.get_transfer_data_indirect(evt.ptr - TRB_SIZE as u64, 1); + let buf = xhci.get_transfer_data_indirect(evt.ptr - u64::from(TRB_SIZE), 1); assert_eq!(buf, [0]); @@ -878,7 +878,7 @@ fn usb_storage_cbw_invalid_endpoint() { LittleEndian::write_u32(&mut cbw_buf[0..4], cbw.sig); let mut iovecs = Vec::new(); - let ptr = guest_allocator.borrow_mut().alloc(CBW_SIZE as u64); + let ptr = guest_allocator.borrow_mut().alloc(u64::from(CBW_SIZE)); xhci.mem_write(ptr, &cbw_buf); let iovec = TestIovec::new(ptr, CBW_SIZE as usize, false); @@ -927,7 +927,7 @@ fn usb_storage_csw_invalid_endpoint() { // Test 2: CSW phase. let mut iovecs = Vec::new(); - let ptr = guest_allocator.borrow_mut().alloc(CSW_SIZE as u64); + let ptr = guest_allocator.borrow_mut().alloc(u64::from(CSW_SIZE)); let iovec = TestIovec::new(ptr, CSW_SIZE as usize, false); iovecs.push(iovec); diff --git a/tests/mod_test/tests/usb_test.rs b/tests/mod_test/tests/usb_test.rs index 8e7070fabc890f8c40e2239b0b4bff44e428c050..8a5f81140c783c9908691389926d5785081bb7df 100644 --- a/tests/mod_test/tests/usb_test.rs +++ b/tests/mod_test/tests/usb_test.rs @@ -487,7 +487,7 @@ fn test_xhci_keyboard_over_ring_limit() { xhci.queue_link_trb( slot_id, HID_DEVICE_ENDPOINT_ID, - org_ptr + TRB_SIZE as u64 * 64, + org_ptr + u64::from(TRB_SIZE) * 64, false, ); } else if i == 1 { @@ -544,7 +544,7 @@ fn test_xhci_keyboard_reorder() { let buf = xhci.get_transfer_data_indirect(evt.ptr, HID_KEYBOARD_LEN); assert_eq!(buf, [0, 0, 30, 31, 32, 33, 0, 0]); // 1 2 3 4 Up - let key_list = vec![ + let key_list = [ KEYCODE_NUM1, KEYCODE_NUM1 + 1, KEYCODE_NUM1 + 2, @@ -736,7 +736,7 @@ fn test_xhci_keyboard_invalid_value() { xhci.queue_trb(slot_id, HID_DEVICE_ENDPOINT_ID, &mut trb); xhci.doorbell_write(slot_id, HID_DEVICE_ENDPOINT_ID); // NOTE: no HCE, only primary interrupter supported now. - let status = xhci.oper_regs_read(XHCI_OPER_REG_USBSTS as u64); + let status = xhci.oper_regs_read(XHCI_OPER_REG_USBSTS); assert!(status & USB_STS_HCE != USB_STS_HCE); test_state.borrow_mut().stop(); @@ -845,7 +845,7 @@ fn test_xhci_keyboard_over_transfer_ring() { xhci.queue_link_trb(slot_id, HID_DEVICE_ENDPOINT_ID, ptr, false); xhci.doorbell_write(slot_id, HID_DEVICE_ENDPOINT_ID); // Host Controller Error - let status = xhci.oper_regs_read(XHCI_OPER_REG_USBSTS as u64); + let status = xhci.oper_regs_read(XHCI_OPER_REG_USBSTS); assert!(status & USB_STS_HCE == USB_STS_HCE); xhci.reset_controller(true); @@ -856,7 +856,7 @@ fn test_xhci_keyboard_over_transfer_ring() { xhci.queue_td_by_iovec(slot_id, HID_DEVICE_ENDPOINT_ID, &mut iovecs, false); xhci.doorbell_write(slot_id, HID_DEVICE_ENDPOINT_ID); // Host Controller Error - let status = xhci.oper_regs_read(XHCI_OPER_REG_USBSTS as u64); + let status = xhci.oper_regs_read(XHCI_OPER_REG_USBSTS); assert!(status & USB_STS_HCE == USB_STS_HCE); xhci.reset_controller(true); @@ -948,58 +948,59 @@ fn test_xhci_keyboard_controller_init_invalid_register() { xhci.read_capability(); let old_value = xhci .pci_dev - .io_readl(xhci.bar_addr, XHCI_PCI_CAP_OFFSET as u64 + 0x2c); + .io_readl(xhci.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET) + 0x2c); xhci.pci_dev - .io_writel(xhci.bar_addr, XHCI_PCI_CAP_OFFSET as u64 + 0x2c, 0xffff); + .io_writel(xhci.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET) + 0x2c, 0xffff); let value = xhci .pci_dev - .io_readl(xhci.bar_addr, XHCI_PCI_CAP_OFFSET as u64 + 0x2c); + .io_readl(xhci.bar_addr, u64::from(XHCI_PCI_CAP_OFFSET) + 0x2c); assert_eq!(value, old_value); // Case 3: write invalid slot. xhci.pci_dev.io_writel( xhci.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_CONFIG as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_CONFIG, 0xffff, ); let config = xhci.pci_dev.io_readl( xhci.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_CONFIG as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_CONFIG, ); assert_ne!(config, 0xffff); // Case 4: invalid oper xhci.pci_dev.io_writel( xhci.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_USBSTS as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_USBSTS, 0xffff, ); let status = xhci.pci_dev.io_readl( xhci.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_USBSTS as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_USBSTS, ); assert_ne!(status, 0xffff); // Device Notify Control xhci.pci_dev.io_writel( xhci.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_DNCTRL as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_DNCTRL, 0x12345, ); let ndctrl = xhci.pci_dev.io_readl( xhci.bar_addr, - XHCI_PCI_OPER_OFFSET as u64 + XHCI_OPER_REG_DNCTRL as u64, + u64::from(XHCI_PCI_OPER_OFFSET) + XHCI_OPER_REG_DNCTRL, ); assert_eq!(ndctrl, 0x12345 & XHCI_OPER_NE_MASK); // invalid port offset. let invalid_offset = 0x7; xhci.pci_dev.io_writel( xhci.bar_addr, - XHCI_PCI_PORT_OFFSET as u64 + invalid_offset, + u64::from(XHCI_PCI_PORT_OFFSET) + invalid_offset, 0xff, ); - let invalid_offset = xhci - .pci_dev - .io_readl(xhci.bar_addr, XHCI_PCI_PORT_OFFSET as u64 + invalid_offset); + let invalid_offset = xhci.pci_dev.io_readl( + xhci.bar_addr, + u64::from(XHCI_PCI_PORT_OFFSET) + invalid_offset, + ); assert_eq!(invalid_offset, 0); xhci.init_device_context_base_address_array_pointer(); @@ -1010,25 +1011,28 @@ fn test_xhci_keyboard_controller_init_invalid_register() { xhci.interrupter_regs_writeq(0, XHCI_INTR_REG_ERSTBA_LO, 0); // micro frame index. xhci.pci_dev - .io_writel(xhci.bar_addr, XHCI_PCI_RUNTIME_OFFSET as u64, 0xf); + .io_writel(xhci.bar_addr, u64::from(XHCI_PCI_RUNTIME_OFFSET), 0xf); let mf_index = xhci .pci_dev - .io_readl(xhci.bar_addr, XHCI_PCI_RUNTIME_OFFSET as u64); + .io_readl(xhci.bar_addr, u64::from(XHCI_PCI_RUNTIME_OFFSET)); assert!(mf_index <= 0x3fff); // invalid offset - xhci.pci_dev - .io_writel(xhci.bar_addr, XHCI_PCI_RUNTIME_OFFSET as u64 + 0x1008, 0xf); + xhci.pci_dev.io_writel( + xhci.bar_addr, + u64::from(XHCI_PCI_RUNTIME_OFFSET) + 0x1008, + 0xf, + ); let over_offset = xhci .pci_dev - .io_readl(xhci.bar_addr, XHCI_PCI_RUNTIME_OFFSET as u64 + 0x1008); + .io_readl(xhci.bar_addr, u64::from(XHCI_PCI_RUNTIME_OFFSET) + 0x1008); assert_eq!(over_offset, 0); // Case 6: invalid doorbell xhci.pci_dev - .io_writel(xhci.bar_addr, XHCI_PCI_DOORBELL_OFFSET as u64, 0xf); + .io_writel(xhci.bar_addr, u64::from(XHCI_PCI_DOORBELL_OFFSET), 0xf); let invalid_db = xhci .pci_dev - .io_readl(xhci.bar_addr, XHCI_PCI_DOORBELL_OFFSET as u64); + .io_readl(xhci.bar_addr, u64::from(XHCI_PCI_DOORBELL_OFFSET)); assert_eq!(invalid_db, 0); // Case 7: invalid size @@ -1080,7 +1084,7 @@ fn test_xhci_keyboard_controller_init_miss_step() { xhci.enable_slot(); assert!(xhci.fetch_event(PRIMARY_INTERRUPTER_ID).is_none()); // Host Controller Error - let status = xhci.oper_regs_read(XHCI_OPER_REG_USBSTS as u64); + let status = xhci.oper_regs_read(XHCI_OPER_REG_USBSTS); assert!(status & USB_STS_HCE == USB_STS_HCE); xhci.reset_controller(false); @@ -1104,7 +1108,7 @@ fn test_xhci_keyboard_controller_init_miss_step() { xhci.address_device(slot_id, false, port_id); assert!(xhci.fetch_event(PRIMARY_INTERRUPTER_ID).is_none()); // Host Controller Error - let status = xhci.oper_regs_read(XHCI_OPER_REG_USBSTS as u64); + let status = xhci.oper_regs_read(XHCI_OPER_REG_USBSTS); assert!(status & USB_STS_HCE == USB_STS_HCE); xhci.reset_controller(false); @@ -1326,7 +1330,7 @@ fn test_xhci_keyboard_over_command_ring() { xhci.queue_link_trb(0, 0, ptr, false); xhci.doorbell_write(0, 0); // Host Controller Error - let status = xhci.oper_regs_read(XHCI_OPER_REG_USBSTS as u64); + let status = xhci.oper_regs_read(XHCI_OPER_REG_USBSTS); assert!(status & USB_STS_HCE == USB_STS_HCE); xhci.reset_controller(true); @@ -1468,7 +1472,7 @@ fn test_xhci_keyboard_device_init_invalid_request() { let device_req = UsbDeviceRequest { request_type: USB_DEVICE_IN_REQUEST, request: USB_REQUEST_GET_DESCRIPTOR, - value: (USB_DT_CONFIGURATION as u16) << 8 | 6, + value: u16::from(USB_DT_CONFIGURATION) << 8 | 6, index: 10, length: 10, }; @@ -1545,13 +1549,15 @@ fn test_xhci_keyboard_device_init_invalid_control() { let device_req = UsbDeviceRequest { request_type: USB_DEVICE_IN_REQUEST, request: USB_REQUEST_GET_DESCRIPTOR, - value: (USB_DT_CONFIGURATION as u16) << 8, + value: u16::from(USB_DT_CONFIGURATION) << 8, index: 0, length: 64, }; // Case 1: no SetUp Stage. // Data Stage. - let ptr = guest_allocator.borrow_mut().alloc(device_req.length as u64); + let ptr = guest_allocator + .borrow_mut() + .alloc(u64::from(device_req.length)); let in_dir = device_req.request_type & USB_DIRECTION_DEVICE_TO_HOST == USB_DIRECTION_DEVICE_TO_HOST; let mut data_trb = TestNormalTRB::generate_data_td(ptr, device_req.length, in_dir); @@ -1569,7 +1575,9 @@ fn test_xhci_keyboard_device_init_invalid_control() { let mut setup_trb = TestNormalTRB::generate_setup_td(&device_req); xhci.queue_trb(slot_id, CONTROL_ENDPOINT_ID, &mut setup_trb); // Data Stage. - let ptr = guest_allocator.borrow_mut().alloc(device_req.length as u64); + let ptr = guest_allocator + .borrow_mut() + .alloc(u64::from(device_req.length)); let in_dir = device_req.request_type & USB_DIRECTION_DEVICE_TO_HOST == USB_DIRECTION_DEVICE_TO_HOST; let mut data_trb = TestNormalTRB::generate_data_td(ptr, device_req.length, in_dir); @@ -1587,7 +1595,9 @@ fn test_xhci_keyboard_device_init_invalid_control() { setup_trb.set_idt_flag(false); xhci.queue_trb(slot_id, CONTROL_ENDPOINT_ID, &mut setup_trb); // Data Stage. - let ptr = guest_allocator.borrow_mut().alloc(device_req.length as u64); + let ptr = guest_allocator + .borrow_mut() + .alloc(u64::from(device_req.length)); let in_dir = device_req.request_type & USB_DIRECTION_DEVICE_TO_HOST == USB_DIRECTION_DEVICE_TO_HOST; let mut data_trb = TestNormalTRB::generate_data_td(ptr, device_req.length, in_dir); @@ -1605,7 +1615,9 @@ fn test_xhci_keyboard_device_init_invalid_control() { setup_trb.set_trb_transfer_length(11); xhci.queue_trb(slot_id, CONTROL_ENDPOINT_ID, &mut setup_trb); // Data Stage. - let ptr = guest_allocator.borrow_mut().alloc(device_req.length as u64); + let ptr = guest_allocator + .borrow_mut() + .alloc(u64::from(device_req.length)); let in_dir = device_req.request_type & USB_DIRECTION_DEVICE_TO_HOST == USB_DIRECTION_DEVICE_TO_HOST; let mut data_trb = TestNormalTRB::generate_data_td(ptr, device_req.length, in_dir); @@ -1622,7 +1634,9 @@ fn test_xhci_keyboard_device_init_invalid_control() { let mut setup_trb = TestNormalTRB::generate_setup_td(&device_req); xhci.queue_trb(slot_id, CONTROL_ENDPOINT_ID, &mut setup_trb); // Data Stage. - let ptr = guest_allocator.borrow_mut().alloc(device_req.length as u64); + let ptr = guest_allocator + .borrow_mut() + .alloc(u64::from(device_req.length)); let in_dir = device_req.request_type & USB_DIRECTION_DEVICE_TO_HOST == USB_DIRECTION_DEVICE_TO_HOST; let mut data_trb = TestNormalTRB::generate_data_td(ptr, device_req.length, in_dir); @@ -1752,7 +1766,7 @@ fn test_xhci_keyboard_device_init_reset_device() { let slot_id = evt.get_slot_id(); // Case 1: reset after enable slot. xhci.reset_device(slot_id); - let status = xhci.oper_regs_read(XHCI_OPER_REG_USBSTS as u64); + let status = xhci.oper_regs_read(XHCI_OPER_REG_USBSTS); assert!(status & USB_STS_HCE == USB_STS_HCE); xhci.reset_controller(true); @@ -1871,7 +1885,7 @@ fn test_xhci_keyboard_device_init_device_request_repeat() { xhci.doorbell_write(slot_id, CONTROL_ENDPOINT_ID); let evt = xhci.fetch_event(PRIMARY_INTERRUPTER_ID).unwrap(); assert_eq!(evt.ccode, TRBCCode::ShortPacket as u32); - let buf = xhci.get_transfer_data_indirect(evt.ptr - TRB_SIZE as u64, 2); + let buf = xhci.get_transfer_data_indirect(evt.ptr - u64::from(TRB_SIZE), 2); assert_eq!(buf, [0, 0]); // set configuration xhci.set_configuration(slot_id, 1); @@ -1883,7 +1897,7 @@ fn test_xhci_keyboard_device_init_device_request_repeat() { xhci.doorbell_write(slot_id, CONTROL_ENDPOINT_ID); let evt = xhci.fetch_event(PRIMARY_INTERRUPTER_ID).unwrap(); assert_eq!(evt.ccode, TRBCCode::ShortPacket as u32); - let buf = xhci.get_transfer_data_indirect(evt.ptr - TRB_SIZE as u64, 2); + let buf = xhci.get_transfer_data_indirect(evt.ptr - u64::from(TRB_SIZE), 2); assert_eq!(buf[0], 1); // Set remote wakeup. xhci.set_feature(slot_id, USB_DEVICE_REMOTE_WAKEUP as u16); @@ -1895,7 +1909,7 @@ fn test_xhci_keyboard_device_init_device_request_repeat() { xhci.doorbell_write(slot_id, CONTROL_ENDPOINT_ID); let evt = xhci.fetch_event(PRIMARY_INTERRUPTER_ID).unwrap(); assert_eq!(evt.ccode, TRBCCode::ShortPacket as u32); - let buf = xhci.get_transfer_data_indirect(evt.ptr - TRB_SIZE as u64, 2); + let buf = xhci.get_transfer_data_indirect(evt.ptr - u64::from(TRB_SIZE), 2); assert_eq!(buf, [2, 0]); // Clear remote wakeup. xhci.clear_feature(slot_id, USB_DEVICE_REMOTE_WAKEUP as u16); @@ -2026,9 +2040,9 @@ fn test_xhci_tablet_basic() { [ i as u8 % 3, (i * 10) as u8, - (i * 10 >> 8) as u8, + ((i * 10) >> 8) as u8, (i * 20) as u8, - (i * 20 >> 8) as u8, + ((i * 20) >> 8) as u8, 0, 0 ] @@ -2039,9 +2053,9 @@ fn test_xhci_tablet_basic() { [ 0, (i * 10) as u8, - (i * 10 >> 8) as u8, + ((i * 10) >> 8) as u8, (i * 20) as u8, - (i * 20 >> 8) as u8, + ((i * 20) >> 8) as u8, 0, 0 ] @@ -2162,7 +2176,7 @@ fn test_xhci_tablet_over_ring_limit() { xhci.queue_link_trb( slot_id, HID_DEVICE_ENDPOINT_ID, - org_ptr + TRB_SIZE as u64 * 64, + org_ptr + u64::from(TRB_SIZE) * 64, false, ); } else if i == 1 { @@ -2236,7 +2250,7 @@ fn test_xhci_tablet_device_init_control_command() { xhci.doorbell_write(slot_id, CONTROL_ENDPOINT_ID); let evt = xhci.fetch_event(PRIMARY_INTERRUPTER_ID).unwrap(); assert_eq!(evt.ccode, TRBCCode::ShortPacket as u32); - let buf = xhci.get_transfer_data_indirect(evt.ptr - TRB_SIZE as u64, HID_POINTER_LEN); + let buf = xhci.get_transfer_data_indirect(evt.ptr - u64::from(TRB_SIZE), HID_POINTER_LEN); assert_eq!(buf, [0, 0, 0, 0, 0, 0, 0]); xhci.test_pointer_event(slot_id, test_state.clone()); @@ -2419,11 +2433,11 @@ fn test_xhci_disable_interrupt() { // Case: disable USB_CMD_INTE qmp_send_pointer_event(test_state.borrow_mut(), 100, 200, 0, true); xhci.queue_direct_td(slot_id, HID_DEVICE_ENDPOINT_ID, HID_POINTER_LEN); - let value = xhci.oper_regs_read(XHCI_OPER_REG_USBCMD as u64); + let value = xhci.oper_regs_read(XHCI_OPER_REG_USBCMD); xhci.oper_regs_write(XHCI_OPER_REG_USBCMD, value & !USB_CMD_INTE); xhci.doorbell_write(slot_id, HID_DEVICE_ENDPOINT_ID); assert!(xhci.fetch_event(PRIMARY_INTERRUPTER_ID).is_none()); - let value = xhci.oper_regs_read(XHCI_OPER_REG_USBCMD as u64); + let value = xhci.oper_regs_read(XHCI_OPER_REG_USBCMD); xhci.oper_regs_write(XHCI_OPER_REG_USBCMD, value | USB_CMD_INTE); let evt = xhci.fetch_event(PRIMARY_INTERRUPTER_ID).unwrap(); assert_eq!(evt.ccode, TRBCCode::Success as u32); @@ -2433,8 +2447,7 @@ fn test_xhci_disable_interrupt() { // Case: disable IMAN_IE qmp_send_pointer_event(test_state.borrow_mut(), 100, 200, 0, true); xhci.queue_direct_td(slot_id, HID_DEVICE_ENDPOINT_ID, HID_POINTER_LEN); - let value = - xhci.interrupter_regs_read(PRIMARY_INTERRUPTER_ID as u64, XHCI_INTR_REG_IMAN as u64); + let value = xhci.interrupter_regs_read(PRIMARY_INTERRUPTER_ID as u64, XHCI_INTR_REG_IMAN); xhci.interrupter_regs_write( PRIMARY_INTERRUPTER_ID as u64, XHCI_INTR_REG_IMAN, diff --git a/tests/mod_test/tests/virtio_gpu_test.rs b/tests/mod_test/tests/virtio_gpu_test.rs index 26d9c1c5fffd0a4a995efaa402070bc55fd16b06..c9baef2382ecec91effdbb5b85d67506abd2ee1a 100644 --- a/tests/mod_test/tests/virtio_gpu_test.rs +++ b/tests/mod_test/tests/virtio_gpu_test.rs @@ -60,19 +60,19 @@ fn image_display_fun() { let (dpy, gpu) = set_up(&gpu_cfg); let image_addr = gpu.borrow_mut().allocator.borrow_mut().alloc(image_size); - let image_byte_0 = vec![0 as u8; 1]; - let image_byte_1 = vec![1 as u8; 1]; - let image_0 = vec![0 as u8; image_size as usize]; + let image_byte_0 = vec![0_u8; 1]; + let image_byte_1 = vec![1_u8; 1]; + let image_0 = vec![0_u8; image_size as usize]; // image with half data 1 - let mut image_half_1 = vec![0 as u8; image_size as usize]; + let mut image_half_1 = vec![0_u8; image_size as usize]; let mut i = 0; while i < image_size / 2 { image_half_1[i as usize] = 1; i += 1; } // image with quarter data1 - let mut image_quarter_1 = vec![0 as u8; image_size as usize]; + let mut image_quarter_1 = vec![0_u8; image_size as usize]; let mut i = 0; while i < image_size / 4 { image_quarter_1[i as usize] = 1; @@ -190,9 +190,9 @@ fn image_display_fun() { #[test] fn cursor_display_fun() { - let image_0: Vec = vec![0 as u8; D_CURSOR_IMG_SIZE as usize]; - let image_1: Vec = vec![1 as u8; D_CURSOR_IMG_SIZE as usize]; - let image_byte_1 = vec![1 as u8; 1]; + let image_0: Vec = vec![0_u8; D_CURSOR_IMG_SIZE as usize]; + let image_1: Vec = vec![1_u8; D_CURSOR_IMG_SIZE as usize]; + let image_byte_1 = vec![1_u8; 1]; let image_size = cal_image_hostmem(D_FMT, D_CURSOR_WIDTH, D_CURSOR_HEIGHT); let image_size = image_size.0.unwrap() as u64; @@ -676,7 +676,7 @@ fn cursor_update_dfx() { gpu.borrow_mut().allocator.borrow_mut().alloc(image_size); let image_empty: Vec = vec![]; - let image_0: Vec = vec![0 as u8; D_CURSOR_IMG_SIZE as usize]; + let image_0: Vec = vec![0_u8; D_CURSOR_IMG_SIZE as usize]; // invalid scanout id assert!(current_curosr_check(&dpy, &image_empty)); diff --git a/tests/mod_test/tests/virtio_test.rs b/tests/mod_test/tests/virtio_test.rs index 0e7cb23749b5a2f33852d3b98848bfae1dda41a1..c306f3592a2d417d0a8c702f356f09b19b973838 100644 --- a/tests/mod_test/tests/virtio_test.rs +++ b/tests/mod_test/tests/virtio_test.rs @@ -78,17 +78,12 @@ fn send_one_request( alloc: Rc>, vq: Rc>, ) { - let (free_head, req_addr) = add_request( - test_state.clone(), - alloc.clone(), - vq.clone(), - VIRTIO_BLK_T_OUT, - 0, - ); + let (free_head, req_addr) = + add_request(test_state.clone(), alloc, vq.clone(), VIRTIO_BLK_T_OUT, 0); blk.borrow().virtqueue_notify(vq.clone()); blk.borrow().poll_used_elem( test_state.clone(), - vq.clone(), + vq, free_head, TIMEOUT_US, &mut None, @@ -128,7 +123,6 @@ fn init_device_step( vqs = blk .borrow_mut() .init_virtqueue(test_state.clone(), alloc.clone(), 1); - () } 8 => { blk.borrow().set_driver_ok(); @@ -140,7 +134,7 @@ fn init_device_step( // Try to send write and read request to StratoVirt, ignore // the interrupt from device. - if vqs.len() > 0 { + if !vqs.is_empty() { let (_, _) = add_request( test_state.clone(), alloc.clone(), @@ -171,16 +165,14 @@ fn check_req_result( addr: u64, timeout_us: u64, ) { - let status = blk - .borrow() - .req_result(test_state.clone(), addr, timeout_us); + let status = blk.borrow().req_result(test_state, addr, timeout_us); assert!(!blk.borrow().queue_was_notified(vq)); assert_eq!(status, VIRTIO_BLK_S_OK); } fn check_queue(blk: Rc>, desc: u64, avail: u64, used: u64) { let bar = blk.borrow().bar; - let common_base = blk.borrow().common_base as u64; + let common_base = u64::from(blk.borrow().common_base); let reqs = [ (offset_of!(VirtioPciCommonCfg, queue_desc_lo), desc), (offset_of!(VirtioPciCommonCfg, queue_desc_hi), desc >> 32), @@ -193,7 +185,7 @@ fn check_queue(blk: Rc>, desc: u64, avail: u64, used: let addr = blk .borrow() .pci_dev - .io_readl(bar, common_base as u64 + offset as u64); + .io_readl(bar, common_base + offset as u64); assert_eq!(addr, value as u32); } } @@ -289,13 +281,7 @@ fn do_event_idx_with_flag(flag: u16) { DEFAULT_IO_REQS * 2 - 1, ); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Feature Test. @@ -333,13 +319,7 @@ fn virtio_feature_none() { check_stratovirt_status(test_state.clone()); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Feature Test. @@ -415,13 +395,7 @@ fn virtio_feature_vertion_1() { DEFAULT_IO_REQS, ); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Driver just enable VIRTIO_F_VERSION_1|VIRTIO_RING_F_INDIRECT_DESC feature, @@ -452,14 +426,13 @@ fn virtio_feature_indirect() { free_head = vqs[0] .borrow_mut() .add(test_state.clone(), req_addr, 8, false); - let offset = free_head as u64 * VRING_DESC_SIZE + offset_of!(VringDesc, flags) as u64; + let offset = u64::from(free_head) * VRING_DESC_SIZE + offset_of!(VringDesc, flags) as u64; test_state .borrow() - .writew(vqs[0].borrow().desc + offset as u64, VRING_DESC_F_NEXT); - test_state.borrow().writew( - vqs[0].borrow().desc + offset as u64 + 2, - free_head as u16 + 1, - ); + .writew(vqs[0].borrow().desc + offset, VRING_DESC_F_NEXT); + test_state + .borrow() + .writew(vqs[0].borrow().desc + offset + 2, free_head as u16 + 1); let mut indirect_req = TestVringIndirectDesc::new(); indirect_req.setup(alloc.clone(), test_state.clone(), 2); indirect_req.add_desc(test_state.clone(), req_addr + 8, 520, false); @@ -485,20 +458,19 @@ fn virtio_feature_indirect() { free_head = vqs[0] .borrow_mut() .add(test_state.clone(), req_addr, 8, false); - let offset = free_head as u64 * VRING_DESC_SIZE + offset_of!(VringDesc, flags) as u64; + let offset = u64::from(free_head) * VRING_DESC_SIZE + offset_of!(VringDesc, flags) as u64; test_state .borrow() - .writew(vqs[0].borrow().desc + offset as u64, VRING_DESC_F_NEXT); - test_state.borrow().writew( - vqs[0].borrow().desc + offset as u64 + 2, - free_head as u16 + 1, - ); + .writew(vqs[0].borrow().desc + offset, VRING_DESC_F_NEXT); + test_state + .borrow() + .writew(vqs[0].borrow().desc + offset + 2, free_head as u16 + 1); let mut indirect_req = TestVringIndirectDesc::new(); indirect_req.setup(alloc.clone(), test_state.clone(), 2); indirect_req.add_desc(test_state.clone(), req_addr + 8, 8, false); indirect_req.add_desc( test_state.clone(), - req_addr + REQ_ADDR_LEN as u64, + req_addr + u64::from(REQ_ADDR_LEN), 513, true, ); @@ -523,19 +495,13 @@ fn virtio_feature_indirect() { String::from_utf8( test_state .borrow() - .memread(req_addr + REQ_ADDR_LEN as u64, 4) + .memread(req_addr + u64::from(REQ_ADDR_LEN), 4) ) .unwrap(), "TEST" ); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Driver just enable VIRTIO_F_VERSION_1|VIRTIO_RING_F_EVENT_IDX feature, @@ -595,20 +561,19 @@ fn virtio_feature_indirect_and_event_idx() { let free_head = vqs[0] .borrow_mut() .add(test_state.clone(), req_addr, REQ_ADDR_LEN, false); - let offset = free_head as u64 * VRING_DESC_SIZE + offset_of!(VringDesc, flags) as u64; + let offset = u64::from(free_head) * VRING_DESC_SIZE + offset_of!(VringDesc, flags) as u64; test_state .borrow() - .writew(vqs[0].borrow().desc + offset as u64, VRING_DESC_F_NEXT); - test_state.borrow().writew( - vqs[0].borrow().desc + offset as u64 + 2, - free_head as u16 + 1, - ); + .writew(vqs[0].borrow().desc + offset, VRING_DESC_F_NEXT); + test_state + .borrow() + .writew(vqs[0].borrow().desc + offset + 2, free_head as u16 + 1); // 2 desc elems in indirect desc table. let mut indirect_req = TestVringIndirectDesc::new(); indirect_req.setup(alloc.clone(), test_state.clone(), 2); indirect_req.add_desc( test_state.clone(), - req_addr + REQ_ADDR_LEN as u64, + req_addr + u64::from(REQ_ADDR_LEN), REQ_DATA_LEN, false, ); @@ -678,13 +643,7 @@ fn virtio_feature_indirect_and_event_idx() { DEFAULT_IO_REQS * 2 - 1, ); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Setting abnormal status in device initialization. @@ -763,13 +722,7 @@ fn virtio_init_device_abnormal_status() { check_stratovirt_status(test_state.clone()); // 4. Destroy device. - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Setting abnormal feature in device initialization. @@ -874,12 +827,12 @@ fn virtio_init_device_abnormal_features() { fn virtio_init_device_abnormal_vring_info() { // (err_type, value, ack, device_status) let reqs = [ - (0, u16::MAX as u64, 0, 0), + (0, u64::from(u16::MAX), 0, 0), (0, 2, 0, 0), (1, 0_u64, 0xff, 0), (1, 255, 0xff, 0), (1, 1 << 15, 0xff, 0), - (1, u16::MAX as u64, 0xff, 0), + (1, u64::from(u16::MAX), 0xff, 0), (2, 0, 0xff, 0), (3, 0, 0xff, 0), (4, 0, 0xff, 0), @@ -921,7 +874,7 @@ fn virtio_init_device_abnormal_vring_info() { blk.borrow().queue_select(value as u16); } - let queue_size = blk.borrow().get_queue_size() as u32; + let queue_size = u32::from(blk.borrow().get_queue_size()); // Set invalid queue size. if err_type == 1 { @@ -937,18 +890,19 @@ fn virtio_init_device_abnormal_vring_info() { vq.borrow_mut().indirect = (features & (1 << VIRTIO_RING_F_INDIRECT_DESC)) != 0; vq.borrow_mut().event = (features & (1 << VIRTIO_RING_F_EVENT_IDX)) != 0; - let addr = alloc - .borrow_mut() - .alloc(get_vring_size(queue_size, VIRTIO_PCI_VRING_ALIGN) as u64); + let addr = alloc.borrow_mut().alloc(u64::from(get_vring_size( + queue_size, + VIRTIO_PCI_VRING_ALIGN, + ))); vq.borrow_mut().desc = addr; - let avail = addr + (queue_size * size_of::() as u32) as u64 + 16; + let avail = addr + u64::from(queue_size * size_of::() as u32) + 16; vq.borrow_mut().avail = avail; let used = (avail - + (size_of::() as u32 * (3 + queue_size)) as u64 - + VIRTIO_PCI_VRING_ALIGN as u64 + + u64::from(size_of::() as u32 * (3 + queue_size)) + + u64::from(VIRTIO_PCI_VRING_ALIGN) - 1) - & !(VIRTIO_PCI_VRING_ALIGN as u64 - 1) + 16; + & (!(u64::from(VIRTIO_PCI_VRING_ALIGN) - 1) + 16); vq.borrow_mut().used = used + 16; match err_type { @@ -1015,32 +969,32 @@ fn virtio_init_device_abnormal_vring_info() { let notify_off = blk.borrow().pci_dev.io_readw( blk.borrow().bar, - blk.borrow().common_base as u64 + u64::from(blk.borrow().common_base) + offset_of!(VirtioPciCommonCfg, queue_notify_off) as u64, ); - vq.borrow_mut().queue_notify_off = blk.borrow().notify_base as u64 - + notify_off as u64 * blk.borrow().notify_off_multiplier as u64; + vq.borrow_mut().queue_notify_off = u64::from(blk.borrow().notify_base) + + u64::from(notify_off) * u64::from(blk.borrow().notify_off_multiplier); let offset = offset_of!(VirtioPciCommonCfg, queue_enable) as u64; // TEST enable vq with 0 if err_type == 9 { blk.borrow().pci_dev.io_writew( blk.borrow().bar, - blk.borrow().common_base as u64 + offset, + u64::from(blk.borrow().common_base) + offset, 0, ); } else { blk.borrow().pci_dev.io_writew( blk.borrow().bar, - blk.borrow().common_base as u64 + u64::from(blk.borrow().common_base) + offset_of!(VirtioPciCommonCfg, queue_enable) as u64, 1, ); if err_type == 10 { - let status = blk - .borrow() - .pci_dev - .io_readw(blk.borrow().bar, blk.borrow().common_base as u64 + offset); + let status = blk.borrow().pci_dev.io_readw( + blk.borrow().bar, + u64::from(blk.borrow().common_base) + offset, + ); assert_eq!(status, 1); } } @@ -1136,13 +1090,7 @@ fn virtio_init_device_out_of_order_1() { 0, ); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Init device out of order test 2. @@ -1195,13 +1143,7 @@ fn virtio_init_device_out_of_order_2() { 0, ); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Init device out of order test 3. @@ -1256,13 +1198,7 @@ fn virtio_init_device_out_of_order_3() { 0, ); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Repeat the initialization operation. @@ -1326,13 +1262,7 @@ fn virtio_init_device_repeat() { 0, ); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Setting abnormal desc addr in IO request. @@ -1481,11 +1411,11 @@ fn virtio_io_abnormal_desc_len() { if length == 16 { test_state.borrow().writel( indirect_desc + offset_of!(VringDesc, len) as u64, - u16::MAX as u32 * (VRING_DESC_SIZE as u32 + 1), + u32::from(u16::MAX) * (VRING_DESC_SIZE as u32 + 1), ); test_state.borrow().writel( indirect_desc + offset_of!(VringDesc, flags) as u64, - (VRING_DESC_F_INDIRECT | VRING_DESC_F_NEXT) as u32, + u32::from(VRING_DESC_F_INDIRECT | VRING_DESC_F_NEXT), ); } } @@ -1595,19 +1525,18 @@ fn virtio_io_abnormal_desc_flags_2() { let free_head = vqs[0] .borrow_mut() .add(test_state.clone(), req_addr, REQ_ADDR_LEN, false); - let offset = free_head as u64 * VRING_DESC_SIZE + offset_of!(VringDesc, flags) as u64; + let offset = u64::from(free_head) * VRING_DESC_SIZE + offset_of!(VringDesc, flags) as u64; test_state .borrow() - .writew(vqs[0].borrow().desc + offset as u64, VRING_DESC_F_NEXT); - test_state.borrow().writew( - vqs[0].borrow().desc + offset as u64 + 2, - free_head as u16 + 1, - ); + .writew(vqs[0].borrow().desc + offset, VRING_DESC_F_NEXT); + test_state + .borrow() + .writew(vqs[0].borrow().desc + offset + 2, free_head as u16 + 1); let mut indirect_req = TestVringIndirectDesc::new(); indirect_req.setup(alloc.clone(), test_state.clone(), 2); indirect_req.add_desc( test_state.clone(), - req_addr + REQ_ADDR_LEN as u64, + req_addr + u64::from(REQ_ADDR_LEN), REQ_DATA_LEN, false, ); @@ -1626,13 +1555,7 @@ fn virtio_io_abnormal_desc_flags_2() { assert!(blk.borrow().get_status() & VIRTIO_CONFIG_S_NEEDS_RESET > 0); check_stratovirt_status(test_state.clone()); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Setting abnormal desc flag in IO request, testcase 3. @@ -1670,14 +1593,13 @@ fn virtio_io_abnormal_desc_flags_3() { .borrow_mut() .add(test_state.clone(), req_addr, 8, false); - let offset = free_head as u64 * VRING_DESC_SIZE + offset_of!(VringDesc, flags) as u64; + let offset = u64::from(free_head) * VRING_DESC_SIZE + offset_of!(VringDesc, flags) as u64; test_state .borrow() - .writew(vqs[0].borrow().desc + offset as u64, VRING_DESC_F_NEXT); - test_state.borrow().writew( - vqs[0].borrow().desc + offset as u64 + 2, - free_head as u16 + 1, - ); + .writew(vqs[0].borrow().desc + offset, VRING_DESC_F_NEXT); + test_state + .borrow() + .writew(vqs[0].borrow().desc + offset + 2, free_head as u16 + 1); let mut indirect_req = TestVringIndirectDesc::new(); indirect_req.setup(alloc.clone(), test_state.clone(), 2); indirect_req.add_desc(test_state.clone(), req_addr + 8, 520, false); @@ -1687,7 +1609,7 @@ fn virtio_io_abnormal_desc_flags_3() { .add_indirect(test_state.clone(), indirect_req, true); // Add VRING_DESC_F_WRITE or VRING_DESC_F_NEXT to desc[0]->flags; - let addr = vqs[0].borrow().desc + 16_u64 * (free_head + 1) as u64 + 12; + let addr = vqs[0].borrow().desc + 16_u64 * u64::from(free_head + 1) + 12; let flags = test_state.borrow().readw(addr) | flag; test_state.borrow().writew(addr, flags); blk.borrow().virtqueue_notify(vqs[0].clone()); @@ -1834,13 +1756,7 @@ fn virtio_io_abnormal_desc_elem_place() { check_stratovirt_status(test_state.clone()); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Setting (queue_size + 1) indirect desc elems in IO request. @@ -1871,14 +1787,13 @@ fn virtio_io_abnormal_indirect_desc_elem_num() { let free_head = vqs[0] .borrow_mut() .add(test_state.clone(), req_addr, REQ_ADDR_LEN, false); - let offset = free_head as u64 * VRING_DESC_SIZE + offset_of!(VringDesc, flags) as u64; + let offset = u64::from(free_head) * VRING_DESC_SIZE + offset_of!(VringDesc, flags) as u64; test_state .borrow() - .writew(vqs[0].borrow().desc + offset as u64, VRING_DESC_F_NEXT); - test_state.borrow().writew( - vqs[0].borrow().desc + offset as u64 + 2, - free_head as u16 + 1, - ); + .writew(vqs[0].borrow().desc + offset, VRING_DESC_F_NEXT); + test_state + .borrow() + .writew(vqs[0].borrow().desc + offset + 2, free_head as u16 + 1); let mut indirect_req = TestVringIndirectDesc::new(); indirect_req.setup(alloc.clone(), test_state.clone(), queue_size as u16 + 1); for i in 0..queue_size { @@ -1910,13 +1825,7 @@ fn virtio_io_abnormal_indirect_desc_elem_num() { check_stratovirt_status(test_state.clone()); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Setting invalid flags to avail->flag in IO request. @@ -2186,13 +2095,7 @@ fn virtio_io_abnormal_used_idx() { true, ); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Virtio test step out of order, testcase 1. @@ -2256,13 +2159,7 @@ fn virtio_test_out_of_order_1() { check_stratovirt_status(test_state.clone()); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Virtio test step out of order, testcase 2. @@ -2284,13 +2181,7 @@ fn virtio_test_out_of_order_2() { 1, ); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); let (blk, test_state, alloc, image_path) = set_up(&ImageType::Raw); let vqs = blk.borrow_mut().init_device( @@ -2315,13 +2206,7 @@ fn virtio_test_out_of_order_2() { 0, ); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs, - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } /// Virtio test step repeat. @@ -2386,11 +2271,5 @@ fn virtio_test_repeat() { blk.borrow_mut().destroy_device(alloc.clone(), vqs.clone()); blk.borrow_mut().destroy_device(alloc.clone(), vqs.clone()); - tear_down( - blk.clone(), - test_state.clone(), - alloc.clone(), - vqs.clone(), - image_path.clone(), - ); + tear_down(blk, test_state, alloc, vqs, image_path); } diff --git a/tests/mod_test/tests/virtiofs_test.rs b/tests/mod_test/tests/virtiofs_test.rs index 544c8808310f21d6a0e32991885976c0cc583735..6598a9c7e5fe9b88ac980294326aa4d28c588cad 100644 --- a/tests/mod_test/tests/virtiofs_test.rs +++ b/tests/mod_test/tests/virtiofs_test.rs @@ -75,7 +75,7 @@ fn env_prepare(temp: bool) -> (String, String, String) { .unwrap(); Command::new("mknod") - .arg(virtiofs_test_character_device.clone()) + .arg(virtiofs_test_character_device) .arg("c") .arg("1") .arg("1") @@ -132,7 +132,7 @@ impl VirtioFsTest { let machine = TestStdMachine::new_bymem(test_state.clone(), memsize * 1024 * 1024, page_size); let allocator = machine.allocator.clone(); - let dev = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus.clone()))); + let dev = Rc::new(RefCell::new(TestVirtioPciDev::new(machine.pci_bus))); dev.borrow_mut().init(pci_slot, pci_fn); let features = virtio_fs_default_feature(dev.clone()); let queues = @@ -156,7 +156,7 @@ impl VirtioFsTest { if let Some(member) = reqmember { let member_size = member.len() as u64; let member_addr = self.allocator.borrow_mut().alloc(member_size); - self.state.borrow().memwrite(member_addr, &member); + self.state.borrow().memwrite(member_addr, member); data_entries.push(TestVringDescEntry { data: member_addr, len: member_size as u32, @@ -288,7 +288,7 @@ impl VirtioFsTest { } fn testcase_end(&self, test_dir: String) { - self.testcase_check_and_end(None, test_dir.clone()); + self.testcase_check_and_end(None, test_dir); } fn testcase_check_and_end(&self, absolute_virtiofs_sock: Option, test_dir: String) { @@ -297,9 +297,9 @@ impl VirtioFsTest { .destroy_device(self.allocator.clone(), self.queues.clone()); if let Some(path) = absolute_virtiofs_sock { - let path_clone = path.clone(); + let path_clone = path; let sock_path = Path::new(&path_clone); - assert_eq!(sock_path.exists(), true); + assert!(sock_path.exists()); self.state.borrow_mut().stop(); } else { self.state.borrow_mut().stop(); @@ -339,10 +339,10 @@ fn fuse_init(fs: &VirtioFsTest) -> (FuseOutHeader, FuseInitOut) { let fuse_out_head = FuseOutHeader::default(); let fuse_init_out = FuseInitOut::default(); let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), - &fuse_init_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_init_out.as_bytes(), + fuse_in_head.as_bytes(), + fuse_init_in.as_bytes(), + fuse_out_head.as_bytes(), + fuse_init_out.as_bytes(), ); let out_header = read_obj::(fs.state.clone(), outheaderaddr); @@ -356,15 +356,13 @@ fn fuse_destroy(fs: &VirtioFsTest) -> FuseOutHeader { let fuse_in_head = FuseInHeader::new(len as u32, FUSE_DESTROY, 0, 0, 0, 0, 0, 0); let fuse_out_head = FuseOutHeader::default(); let (_, _, outheaderaddr, _outbodyaddr) = fs.do_virtio_request( - Some(&fuse_in_head.as_bytes()), + Some(fuse_in_head.as_bytes()), None, - Some(&fuse_out_head.as_bytes()), + Some(fuse_out_head.as_bytes()), None, ); - let out_header = read_obj::(fs.state.clone(), outheaderaddr.unwrap()); - - out_header + read_obj::(fs.state.clone(), outheaderaddr.unwrap()) } fn fuse_lookup(fs: &VirtioFsTest, name: String) -> u64 { @@ -375,10 +373,10 @@ fn fuse_lookup(fs: &VirtioFsTest, name: String) -> u64 { let fuse_out_head = FuseOutHeader::default(); let fuse_lookup_out = FuseEntryOut::default(); let (_outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &fuse_lookup_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_lookup_out.as_bytes(), + fuse_out_head.as_bytes(), + fuse_lookup_out.as_bytes(), ); let entry_out = read_obj::(fs.state.clone(), outbodyaddr); @@ -396,10 +394,10 @@ fn fuse_open(fs: &VirtioFsTest, nodeid: u64) -> u64 { let fuse_out_head = FuseOutHeader::default(); let fuse_open_out = FuseOpenOut::default(); let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), - &fuse_open_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_open_out.as_bytes(), + fuse_in_head.as_bytes(), + fuse_open_in.as_bytes(), + fuse_out_head.as_bytes(), + fuse_open_out.as_bytes(), ); let out_header = read_obj::(fs.state.clone(), outheaderaddr); @@ -419,10 +417,10 @@ fn fuse_open_dir(fs: &VirtioFsTest, nodeid: u64) -> u64 { let fuse_out_head = FuseOutHeader::default(); let fuse_open_out = FuseOpenOut::default(); let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), - &fuse_open_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_open_out.as_bytes(), + fuse_in_head.as_bytes(), + fuse_open_in.as_bytes(), + fuse_out_head.as_bytes(), + fuse_open_out.as_bytes(), ); let out_header = read_obj::(fs.state.clone(), outheaderaddr); @@ -452,10 +450,10 @@ fn fuse_lseek( let len = (size_of::() + trim_lseek_in_len) as u32; let fuse_in_head = FuseInHeader::new(len, FUSE_LSEEK, 0, nodeid, 0, 0, 0, 0); let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &fuse_lseek_in.as_bytes()[0..lseek_in_len - trim], - &fuse_out_head.as_bytes(), - &fuse_lseek_out.as_bytes(), + fuse_out_head.as_bytes(), + fuse_lseek_out.as_bytes(), ); let out_header = read_obj::(fs.state.clone(), outheaderaddr); @@ -475,10 +473,10 @@ fn fuse_getattr(fs: &VirtioFsTest, nodeid: u64, fh: u64) -> (FuseOutHeader, Fuse let fuse_out_head = FuseOutHeader::default(); let fuse_getattr_out = FuseAttrOut::default(); let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), - &fuse_getattr_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_getattr_out.as_bytes(), + fuse_in_head.as_bytes(), + fuse_getattr_in.as_bytes(), + fuse_out_head.as_bytes(), + fuse_getattr_out.as_bytes(), ); let out_header = read_obj::(fs.state.clone(), outheaderaddr); @@ -568,10 +566,10 @@ fn mkdir_test() { let fuse_out_head = FuseOutHeader::default(); let fuse_mkdir_out = FuseEntryOut::default(); let (outheaderaddr, _outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &fuse_mkdir_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_mkdir_out.as_bytes(), + fuse_out_head.as_bytes(), + fuse_mkdir_out.as_bytes(), ); // Check. @@ -582,7 +580,7 @@ fn mkdir_test() { linkpath.push_str("/shared/dir"); let linkpath_clone = linkpath.clone(); let link_path = Path::new(&linkpath_clone); - assert_eq!(link_path.is_dir(), true); + assert!(link_path.is_dir()); // kill process and clean env. fs.testcase_end(virtiofs_test_dir); @@ -619,9 +617,9 @@ fn sync_fun() { }; let fuse_out_head = FuseOutHeader::default(); let (_, _, outheader, _outbodyaddr) = fs.do_virtio_request( - Some(&fuse_in_head.as_bytes()), - Some(&fuse_fallocate_in.as_bytes()), - Some(&fuse_out_head.as_bytes()), + Some(fuse_in_head.as_bytes()), + Some(fuse_fallocate_in.as_bytes()), + Some(fuse_out_head.as_bytes()), None, ); @@ -654,9 +652,9 @@ fn syncdir_test() { }; let fuse_out_head = FuseOutHeader::default(); let (_, _, outheader, _outbodyaddr) = fs.do_virtio_request( - Some(&fuse_in_head.as_bytes()), - Some(&fuse_fallocate_in.as_bytes()), - Some(&fuse_out_head.as_bytes()), + Some(fuse_in_head.as_bytes()), + Some(fuse_fallocate_in.as_bytes()), + Some(fuse_out_head.as_bytes()), None, ); @@ -684,9 +682,9 @@ fn invalid_fuse_test() { let fuse_out_head = FuseOutHeader::default(); let fake_fuse_out_body = [0]; let (outheaderaddr, _outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &fake_fuse_in_body, - &fuse_out_head.as_bytes(), + fuse_out_head.as_bytes(), &fake_fuse_out_body, ); @@ -805,9 +803,9 @@ fn ls_test() { let fuse_out_head = FuseOutHeader::default(); let fuse_read_out = [0; DEFAULT_READ_SIZE]; let (outheaderaddr, _outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), - &fuse_read_in.as_bytes(), - &fuse_out_head.as_bytes(), + fuse_in_head.as_bytes(), + fuse_read_in.as_bytes(), + fuse_out_head.as_bytes(), &fuse_read_out, ); let out_header = read_obj::(fs.state.clone(), outheaderaddr); @@ -820,10 +818,10 @@ fn ls_test() { let fuse_out_head = FuseOutHeader::default(); let fuse_forget_out = FuseForgetOut::default(); let (outheaderaddr, _outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), - &fuse_read_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_forget_out.as_bytes(), + fuse_in_head.as_bytes(), + fuse_read_in.as_bytes(), + fuse_out_head.as_bytes(), + fuse_forget_out.as_bytes(), ); let out_header = read_obj::(fs.state.clone(), outheaderaddr); assert_eq!(out_header.error, 0); @@ -840,9 +838,9 @@ fn ls_test() { let fuse_out_head = FuseOutHeader::default(); let fuse_read_out = [0_u8; DEFAULT_READ_SIZE]; let (outheaderaddr, _outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), - &fuse_read_in.as_bytes(), - &fuse_out_head.as_bytes(), + fuse_in_head.as_bytes(), + fuse_read_in.as_bytes(), + fuse_out_head.as_bytes(), &fuse_read_out, ); let out_header = read_obj::(fs.state.clone(), outheaderaddr); @@ -858,9 +856,9 @@ fn ls_test() { let fuse_out_head = FuseOutHeader::default(); let fuse_read_out = [0_u8; DEFAULT_READ_SIZE]; let (outheaderaddr, _outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), - &fuse_read_in.as_bytes(), - &fuse_out_head.as_bytes(), + fuse_in_head.as_bytes(), + fuse_read_in.as_bytes(), + fuse_out_head.as_bytes(), &fuse_read_out, ); @@ -881,10 +879,10 @@ fn fuse_setattr( let fuse_out_head = FuseOutHeader::default(); let fuse_attr_out = FuseAttrOut::default(); let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), - &fuse_setattr_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_attr_out.as_bytes(), + fuse_in_head.as_bytes(), + fuse_setattr_in.as_bytes(), + fuse_out_head.as_bytes(), + fuse_attr_out.as_bytes(), ); let out_header = read_obj::(fs.state.clone(), outheaderaddr); @@ -1012,10 +1010,10 @@ fn unlink_test() { let fuse_out_head = FuseOutHeader::default(); let fuse_unlink_out = FuseEntryOut::default(); let (outheaderaddr, _outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &fuse_unlink_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_unlink_out.as_bytes(), + fuse_out_head.as_bytes(), + fuse_unlink_out.as_bytes(), ); // Check. @@ -1026,7 +1024,7 @@ fn unlink_test() { linkpath.push_str("/shared/testfile"); let linkpath_clone = linkpath.clone(); let link_path = Path::new(&linkpath_clone); - assert_eq!(link_path.exists(), false); + assert!(!link_path.exists()); // kill process and clean env. fs.testcase_end(virtiofs_test_dir); @@ -1060,10 +1058,10 @@ fn rmdir_test() { let fuse_out_head = FuseOutHeader::default(); let fuse_unlink_out = FuseEntryOut::default(); let (outheaderaddr, _outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &fuse_unlink_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_unlink_out.as_bytes(), + fuse_out_head.as_bytes(), + fuse_unlink_out.as_bytes(), ); // Check. @@ -1074,7 +1072,7 @@ fn rmdir_test() { linkpath.push_str("/shared/dir"); let linkpath_clone = linkpath.clone(); let link_path = Path::new(&linkpath_clone); - assert_eq!(link_path.exists(), false); + assert!(!link_path.exists()); // kill process and clean env. fs.testcase_end(virtiofs_test_dir); @@ -1100,10 +1098,10 @@ fn symlink_test() { let fuse_out_head = FuseOutHeader::default(); let fuse_init_out = FuseEntryOut::default(); let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &fuse_init_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_init_out.as_bytes(), + fuse_out_head.as_bytes(), + fuse_init_out.as_bytes(), ); // Check. @@ -1117,18 +1115,18 @@ fn symlink_test() { linkpath.push_str("/shared/link"); let linkpath_clone = linkpath.clone(); let link_path = Path::new(&linkpath_clone); - assert_eq!(link_path.is_symlink(), true); + assert!(link_path.is_symlink()); // Read link - let node_id = fuse_lookup(&fs, linkname.clone()); + let node_id = fuse_lookup(&fs, linkname); let len = size_of::() as u32; let fuse_in_head = FuseInHeader::new(len, FUSE_READLINK, 8, node_id, 0, 0, 0, 0); let fuse_out_head = FuseOutHeader::default(); let fuse_read_link_out = [0_u8; 1024]; let (_, _, outheader, outbodyaddr) = fs.do_virtio_request( - Some(&fuse_in_head.as_bytes()), + Some(fuse_in_head.as_bytes()), None, - Some(&fuse_out_head.as_bytes()), + Some(fuse_out_head.as_bytes()), Some(&fuse_read_link_out), ); @@ -1168,9 +1166,9 @@ fn fallocate_test() { }; let fuse_out_head = FuseOutHeader::default(); let (_, _, outheader, _outbodyaddr) = fs.do_virtio_request( - Some(&fuse_in_head.as_bytes()), - Some(&fuse_fallocate_in.as_bytes()), - Some(&fuse_out_head.as_bytes()), + Some(fuse_in_head.as_bytes()), + Some(fuse_fallocate_in.as_bytes()), + Some(fuse_out_head.as_bytes()), None, ); @@ -1214,10 +1212,10 @@ fn posix_file_lock_test() { let fuse_out_head = FuseOutHeader::default(); let fuse_lk_out = FuseLkOut::default(); let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), - &fuse_lk_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_lk_out.as_bytes(), + fuse_in_head.as_bytes(), + fuse_lk_in.as_bytes(), + fuse_out_head.as_bytes(), + fuse_lk_out.as_bytes(), ); // Check file is unlock. @@ -1244,10 +1242,10 @@ fn posix_file_lock_test() { let fuse_out_head = FuseOutHeader::default(); let fuse_lk_out = FuseLkOut::default(); let (outheaderaddr, _outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), - &fuse_lk_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_lk_out.as_bytes(), + fuse_in_head.as_bytes(), + fuse_lk_in.as_bytes(), + fuse_out_head.as_bytes(), + fuse_lk_out.as_bytes(), ); // check. @@ -1281,10 +1279,10 @@ fn mknod_test() { let fuse_out_head = FuseOutHeader::default(); let fuse_init_out = FuseEntryOut::default(); let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &fuse_mknod_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_init_out.as_bytes(), + fuse_out_head.as_bytes(), + fuse_init_out.as_bytes(), ); // Check. @@ -1316,9 +1314,9 @@ fn get_xattr(fs: &VirtioFsTest, name: String, nodeid: u64) -> (FuseOutHeader, St let fuse_out_head = FuseOutHeader::default(); let fuse_out = [0_u8; DEFAULT_XATTR_SIZE as usize]; let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &fuse_in.as_bytes(), - &fuse_out_head.as_bytes(), + fuse_out_head.as_bytes(), &fuse_out, ); @@ -1326,7 +1324,7 @@ fn get_xattr(fs: &VirtioFsTest, name: String, nodeid: u64) -> (FuseOutHeader, St let fuse_read_out = fs .state .borrow() - .memread(outbodyaddr, DEFAULT_XATTR_SIZE as u64); + .memread(outbodyaddr, u64::from(DEFAULT_XATTR_SIZE)); let attr = String::from_utf8(fuse_read_out).unwrap(); (out_header, attr) @@ -1343,9 +1341,9 @@ fn flush_file(fs: &VirtioFsTest, nodeid: u64, fh: u64) { }; let fuse_out_head = FuseOutHeader::default(); let (_, _, outheader, _) = fs.do_virtio_request( - Some(&fuse_in_head.as_bytes()), - Some(&fuse_in.as_bytes()), - Some(&fuse_out_head.as_bytes()), + Some(fuse_in_head.as_bytes()), + Some(fuse_in.as_bytes()), + Some(fuse_out_head.as_bytes()), None, ); @@ -1361,10 +1359,10 @@ fn write_file(fs: &VirtioFsTest, nodeid: u64, fh: u64, write_buf: String) { let fuse_out_head = FuseOutHeader::default(); let fuse_write_out = FuseWriteOut::default(); let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &fuse_write_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_write_out.as_bytes(), + fuse_out_head.as_bytes(), + fuse_write_out.as_bytes(), ); let out_header = read_obj::(fs.state.clone(), outheaderaddr); @@ -1384,9 +1382,9 @@ fn release_file(fs: &VirtioFsTest, nodeid: u64, fh: u64) { }; let fuse_out_head = FuseOutHeader::default(); let (_, _, outheader, _) = fs.do_virtio_request( - Some(&fuse_in_head.as_bytes()), - Some(&fuse_read_in.as_bytes()), - Some(&fuse_out_head.as_bytes()), + Some(fuse_in_head.as_bytes()), + Some(fuse_read_in.as_bytes()), + Some(fuse_out_head.as_bytes()), None, ); @@ -1407,10 +1405,10 @@ fn create_file(fs: &VirtioFsTest, name: String) -> (FuseOutHeader, FuseCreateOut let fuse_out_head = FuseOutHeader::default(); let fuse_out = FuseCreateOut::default(); let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &fuse_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_out.as_bytes(), + fuse_out_head.as_bytes(), + fuse_out.as_bytes(), ); // Check. @@ -1470,9 +1468,9 @@ fn read_file(fs: &VirtioFsTest, nodeid: u64, fh: u64) -> String { let fuse_out_head = FuseOutHeader::default(); let fuse_out = [0_u8; DEFAULT_READ_SIZE]; let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), - &fuse_in.as_bytes(), - &fuse_out_head.as_bytes(), + fuse_in_head.as_bytes(), + fuse_in.as_bytes(), + fuse_out_head.as_bytes(), &fuse_out, ); @@ -1480,9 +1478,8 @@ fn read_file(fs: &VirtioFsTest, nodeid: u64, fh: u64) -> String { let out_header = read_obj::(fs.state.clone(), outheaderaddr); assert_eq!(out_header.error, 0); let fuse_read_out = fs.state.borrow().memread(outbodyaddr, 5); - let str = String::from_utf8(fuse_read_out).unwrap(); - str + String::from_utf8(fuse_read_out).unwrap() } #[test] @@ -1495,7 +1492,7 @@ fn openfile_test() { // start vm. let fs = VirtioFsTest::new(TEST_MEM_SIZE, TEST_PAGE_SIZE, virtiofs_sock); fuse_init(&fs); - let nodeid = fuse_lookup(&fs, file.clone()); + let nodeid = fuse_lookup(&fs, file); // open/write/flush/close/open/read/close let fh = fuse_open(&fs, nodeid); @@ -1537,9 +1534,9 @@ fn rename_test() { let fuse_in_head = FuseInHeader::new(len, FUSE_RENAME, 0, PARENT_NODEID, 0, 0, 0, 0); let fuse_out_head = FuseOutHeader::default(); let (_, _, outheader, _outbodyaddr) = fs.do_virtio_request( - Some(&fuse_in_head.as_bytes()), + Some(fuse_in_head.as_bytes()), Some(&fuse_rename_in.as_bytes()), - Some(&fuse_out_head.as_bytes()), + Some(fuse_out_head.as_bytes()), None, ); @@ -1575,10 +1572,10 @@ fn link_test() { let fuse_out_head = FuseOutHeader::default(); let fuse_entry_out = FuseEntryOut::default(); let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &fuse_rename_in.as_bytes(), - &fuse_out_head.as_bytes(), - &fuse_entry_out.as_bytes(), + fuse_out_head.as_bytes(), + fuse_entry_out.as_bytes(), ); // Check. @@ -1607,10 +1604,10 @@ fn statfs_test() { let fuse_out_head = FuseOutHeader::default(); let fuse_statfs_out = FuseKstatfs::default(); let (_, _, outheader, _outbodyaddr) = fs.do_virtio_request( - Some(&fuse_in_head.as_bytes()), + Some(fuse_in_head.as_bytes()), None, - Some(&fuse_out_head.as_bytes()), - Some(&fuse_statfs_out.as_bytes()), + Some(fuse_out_head.as_bytes()), + Some(fuse_statfs_out.as_bytes()), ); // Check. @@ -1640,9 +1637,9 @@ fn virtio_fs_fuse_ioctl_test() { let fuse_in_head = FuseInHeader::new(len, FUSE_IOCTL, 0, nodeid, 0, 0, 0, 0); let fuse_out_head = FuseOutHeader::default(); let (outheaderaddr, _outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &[0], - &fuse_out_head.as_bytes(), + fuse_out_head.as_bytes(), &[0], ); @@ -1670,9 +1667,9 @@ fn virtio_fs_fuse_abnormal_test() { let fuse_out_head = FuseOutHeader::default(); let (outheaderaddr, _outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &[0], - &fuse_out_head.as_bytes(), + fuse_out_head.as_bytes(), &[0], ); @@ -1716,15 +1713,13 @@ fn fuse_setxattr(fs: &VirtioFsTest, name: String, value: String, nodeid: u64) -> }; let fuse_out_head = FuseOutHeader::default(); let (_, _, outheader, _outbodyaddr) = fs.do_virtio_request( - Some(&fuse_in_head.as_bytes()), + Some(fuse_in_head.as_bytes()), Some(&fuse_setxattr_in.as_bytes()), - Some(&fuse_out_head.as_bytes()), + Some(fuse_out_head.as_bytes()), None, ); - let out_header = read_obj::(fs.state.clone(), outheader.unwrap()); - - out_header + read_obj::(fs.state.clone(), outheader.unwrap()) } fn fuse_removexattr(fs: &VirtioFsTest, name: String, nodeid: u64) -> FuseOutHeader { @@ -1733,14 +1728,13 @@ fn fuse_removexattr(fs: &VirtioFsTest, name: String, nodeid: u64) -> FuseOutHead let fuse_removexattr_in = FuseRemoveXattrIn { name }; let fuse_out_head = FuseOutHeader::default(); let (_, _, outheader, _outbodyaddr) = fs.do_virtio_request( - Some(&fuse_in_head.as_bytes()), + Some(fuse_in_head.as_bytes()), Some(&fuse_removexattr_in.as_bytes()), - Some(&fuse_out_head.as_bytes()), + Some(fuse_out_head.as_bytes()), None, ); - let out_header = read_obj::(fs.state.clone(), outheader.unwrap()); - out_header + read_obj::(fs.state.clone(), outheader.unwrap()) } fn fuse_listxattr(fs: &VirtioFsTest, nodeid: u64) -> (FuseOutHeader, u64) { @@ -1755,9 +1749,9 @@ fn fuse_listxattr(fs: &VirtioFsTest, nodeid: u64) -> (FuseOutHeader, u64) { let fuse_out_head = FuseOutHeader::default(); let fuse_out = [0_u8; DEFAULT_XATTR_SIZE as usize]; let (outheaderaddr, outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &fuse_in.as_bytes(), - &fuse_out_head.as_bytes(), + fuse_out_head.as_bytes(), &fuse_out, ); @@ -1804,11 +1798,11 @@ fn regularfile_xattr_test() { let attr_list = fs .state .borrow() - .memread(outbodyaddr, DEFAULT_XATTR_SIZE as u64); + .memread(outbodyaddr, u64::from(DEFAULT_XATTR_SIZE)); // The first attr is "security.selinux" let (_attr1, next1) = read_cstring(attr_list.clone(), 0); // The next attrs are what we set by FUSE_SETXATTR. Check it. - let (attr2, _next2) = read_cstring(attr_list.clone(), next1); + let (attr2, _next2) = read_cstring(attr_list, next1); assert_eq!(attr2.unwrap(), testattr_name); // REMOVEXATTR @@ -1917,7 +1911,7 @@ fn fuse_batch_forget(fs: &VirtioFsTest, nodeid: u64, trim: usize) { ] .concat(); let (_, _) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), + fuse_in_head.as_bytes(), &data_bytes[0..data_bytes.len() - trim], &[0], &[0], @@ -2009,9 +2003,9 @@ fn virtio_fs_fuse_setlkw_test() { } let (outheaderaddr, _outbodyaddr) = fs.virtiofs_do_virtio_request( - &fuse_in_head.as_bytes(), - &fuse_lk_in_bytes, - &fuse_out_head.as_bytes(), + fuse_in_head.as_bytes(), + fuse_lk_in_bytes, + fuse_out_head.as_bytes(), &[0], ); diff --git a/tests/mod_test/tests/vnc_test.rs b/tests/mod_test/tests/vnc_test.rs index e4f4fdf16be1a0ad8bf9a6d413dee02141c86f1c..61868f9e08f06f6a3e44351e252efc680f8f0c30 100644 --- a/tests/mod_test/tests/vnc_test.rs +++ b/tests/mod_test/tests/vnc_test.rs @@ -77,7 +77,7 @@ fn test_set_area_dirty() { let pf = RfbPixelFormat::new(32, 8, 0_u8, 1_u8, 255, 255, 255, 16, 8, 0); assert!(vnc_client.test_set_pixel_format(pf).is_ok()); assert!(vnc_client - .test_update_request(UpdateState::Incremental, 0, 0, 640 as u16, 480 as u16,) + .test_update_request(UpdateState::Incremental, 0, 0, 640_u16, 480_u16,) .is_ok()); demo_gpu.borrow_mut().update_image_area(0, 0, 64, 64); demo_gpu.borrow_mut().set_area_dirty(0, 0, 64, 64); @@ -93,7 +93,7 @@ fn test_set_area_dirty() { demo_gpu.borrow_mut().update_image_area(0, 0, 64, 64); demo_gpu.borrow_mut().set_area_dirty(0, 0, 64, 64); assert!(vnc_client - .test_update_request(UpdateState::Incremental, 0, 0, 640 as u16, 480 as u16,) + .test_update_request(UpdateState::Incremental, 0, 0, 640_u16, 480_u16,) .is_ok()); let res = vnc_client.test_recv_server_data(pf); @@ -110,7 +110,7 @@ fn test_set_area_dirty() { let pf = RfbPixelFormat::new(32, 8, 0_u8, 1_u8, 255, 255, 255, 16, 8, 0); assert!(vnc_client.test_set_pixel_format(pf).is_ok()); assert!(vnc_client - .test_update_request(UpdateState::Incremental, 0, 0, 640 as u16, 480 as u16,) + .test_update_request(UpdateState::Incremental, 0, 0, 640_u16, 480_u16,) .is_ok()); demo_gpu.borrow_mut().update_image_area(0, 0, 64, 64); demo_gpu.borrow_mut().set_area_dirty(0, 0, 64, 64); @@ -171,7 +171,7 @@ fn test_set_multiple_area_dirty() { demo_gpu.borrow_mut().update_image_area(119, 120, 160, 160); demo_gpu.borrow_mut().set_area_dirty(0, 0, 640, 480); assert!(vnc_client - .test_update_request(UpdateState::NotIncremental, 0, 0, 640 as u16, 480 as u16,) + .test_update_request(UpdateState::NotIncremental, 0, 0, 640_u16, 480_u16,) .is_ok()); let res = vnc_client.test_recv_server_data(pf); @@ -193,7 +193,7 @@ fn test_set_multiple_area_dirty() { demo_gpu.borrow_mut().update_image_area(119, 120, 160, 160); demo_gpu.borrow_mut().set_area_dirty(0, 0, 640, 480); assert!(vnc_client - .test_update_request(UpdateState::NotIncremental, 0, 0, 640 as u16, 480 as u16,) + .test_update_request(UpdateState::NotIncremental, 0, 0, 640_u16, 480_u16,) .is_ok()); let res = vnc_client.test_recv_server_data(pf); @@ -418,7 +418,7 @@ fn test_set_pixel_format() { // Raw + bit_per_pixel=32 + true_color_flag=1. let pf = RfbPixelFormat::new(32, 8, 0_u8, 1_u8, 255, 255, 255, 16, 8, 0); - assert!(vnc_client.test_set_pixel_format(pf.clone()).is_ok()); + assert!(vnc_client.test_set_pixel_format(pf).is_ok()); assert!(vnc_client .test_update_request(UpdateState::NotIncremental, 0, 0, 2560, 2048) .is_ok()); @@ -430,7 +430,7 @@ fn test_set_pixel_format() { // Raw + bit_per_pixel=16 + true_color_flag=1. let pf = RfbPixelFormat::new(16, 8, 0_u8, 1_u8, 255, 255, 255, 16, 8, 0); - assert!(vnc_client.test_set_pixel_format(pf.clone()).is_ok()); + assert!(vnc_client.test_set_pixel_format(pf).is_ok()); assert!(vnc_client .test_update_request(UpdateState::NotIncremental, 0, 0, 2560, 2048) .is_ok()); @@ -443,12 +443,12 @@ fn test_set_pixel_format() { // Raw + bit_per_pixel=8 + true_color_flag=0. let pf = RfbPixelFormat::new(8, 8, 0_u8, 0_u8, 255, 255, 255, 16, 8, 0); - assert!(vnc_client.test_set_pixel_format(pf.clone()).is_ok()); + assert!(vnc_client.test_set_pixel_format(pf).is_ok()); assert!(vnc_client .test_update_request(UpdateState::NotIncremental, 0, 0, 2560, 2048) .is_ok()); - let res = vnc_client.test_recv_server_data(pf.clone()); + let res = vnc_client.test_recv_server_data(pf); assert!(res.is_ok()); let res = res.unwrap(); assert!(res.contains(&(RfbServerMsg::FramebufferUpdate, EncodingType::EncodingRaw))); @@ -463,7 +463,7 @@ fn test_set_pixel_format() { .is_ok()); assert!(vnc_client.stream_read_to_end().is_ok()); let pf = RfbPixelFormat::new(32, 8, 0_u8, 1_u8, 255, 255, 255, 16, 8, 0); - assert!(vnc_client.test_set_pixel_format(pf.clone()).is_ok()); + assert!(vnc_client.test_set_pixel_format(pf).is_ok()); assert!(vnc_client .test_update_request(UpdateState::NotIncremental, 0, 0, 2560, 2048) .is_ok()); @@ -481,7 +481,7 @@ fn test_set_pixel_format() { .is_ok()); assert!(vnc_client.stream_read_to_end().is_ok()); let pf = RfbPixelFormat::new(8, 8, 0_u8, 1_u8, 255, 255, 255, 16, 8, 0); - assert!(vnc_client.test_set_pixel_format(pf.clone()).is_ok()); + assert!(vnc_client.test_set_pixel_format(pf).is_ok()); assert!(vnc_client .test_update_request(UpdateState::NotIncremental, 0, 0, 2560, 2048) .is_ok()); @@ -498,7 +498,7 @@ fn test_set_pixel_format() { .is_ok()); assert!(vnc_client.stream_read_to_end().is_ok()); let pf = RfbPixelFormat::new(8, 8, 0_u8, 0_u8, 255, 255, 255, 16, 8, 0); - assert!(vnc_client.test_set_pixel_format(pf.clone()).is_ok()); + assert!(vnc_client.test_set_pixel_format(pf).is_ok()); assert!(vnc_client .test_update_request(UpdateState::NotIncremental, 0, 0, 2560, 2048) .is_ok()); @@ -554,12 +554,12 @@ fn test_vnc_kbd_mouse() { assert!(vnc_client.connect(TestAuthType::VncAuthNone).is_ok()); // Key event. for &(name, keysym, keycode) in KEYEVENTLIST.iter() { - assert!(vnc_client.test_key_event(0, keysym as u32).is_ok()); + assert!(vnc_client.test_key_event(0, u32::from(keysym)).is_ok()); let msg = input.borrow_mut().read_input_event(); println!("key {:?}: {:?}", name, msg); assert_eq!(msg.keycode, keycode); assert_eq!(msg.down, 0); - assert!(vnc_client.test_key_event(1, keysym as u32).is_ok()); + assert!(vnc_client.test_key_event(1, u32::from(keysym)).is_ok()); let msg = input.borrow_mut().read_input_event(); println!("key {:?}: {:?}", name, msg); @@ -715,9 +715,7 @@ fn test_rfb_version_abnormal(test_state: Rc>, port: u16) -> R assert!(vnc_client.read_msg(&mut buf, 12).is_ok()); assert_eq!(buf[..12].to_vec(), "RFB 003.008\n".as_bytes().to_vec()); println!("Client Rfb version: RFB 003.010"); - assert!(vnc_client - .write_msg(&"RFB 003.010\n".as_bytes().to_vec()) - .is_ok()); + assert!(vnc_client.write_msg("RFB 003.010\n".as_bytes()).is_ok()); buf.drain(..12); // VNC server closed connection. let res = vnc_client.epoll_wait(EventSet::READ_HANG_UP); @@ -738,9 +736,7 @@ fn test_unsupported_sec_type(test_state: Rc>, port: u16) -> R println!("Connect to server."); assert!(vnc_client.read_msg(&mut buf, 12).is_ok()); assert_eq!(buf[..12].to_vec(), "RFB 003.008\n".as_bytes().to_vec()); - assert!(vnc_client - .write_msg(&"RFB 003.008\n".as_bytes().to_vec()) - .is_ok()); + assert!(vnc_client.write_msg("RFB 003.008\n".as_bytes()).is_ok()); buf.drain(..12); // Step 2: Auth num is 1. @@ -751,7 +747,7 @@ fn test_unsupported_sec_type(test_state: Rc>, port: u16) -> R assert!(vnc_client.read_msg(&mut buf, auth_num as usize).is_ok()); buf.drain(..auth_num as usize); assert!(vnc_client - .write_msg(&(TestAuthType::Invalid as u8).to_be_bytes().to_vec()) + .write_msg((TestAuthType::Invalid as u8).to_be_bytes().as_ref()) .is_ok()); // VNC server close the connection. let res = vnc_client.epoll_wait(EventSet::READ_HANG_UP); @@ -770,7 +766,7 @@ fn test_set_pixel_format_abnormal(test_state: Rc>, port: u16) let mut vnc_client = create_new_client(test_state, port).unwrap(); assert!(vnc_client.connect(TestAuthType::VncAuthNone).is_ok()); let pf = RfbPixelFormat::new(17, 8, 0_u8, 1_u8, 255, 255, 255, 16, 8, 0); - assert!(vnc_client.test_set_pixel_format(pf.clone()).is_ok()); + assert!(vnc_client.test_set_pixel_format(pf).is_ok()); // VNC server close the connection. let res = vnc_client.epoll_wait(EventSet::READ_HANG_UP)?; @@ -789,7 +785,7 @@ fn test_set_encoding_abnormal(test_state: Rc>, port: u16) -> assert!(vnc_client.connect(TestAuthType::VncAuthNone).is_ok()); assert!(vnc_client.test_setup_encodings(Some(100), None).is_ok()); // Send a qmp to query vnc client state. - let value = qmp_query_vnc(test_state.clone()); + let value = qmp_query_vnc(test_state); let client_num = value["return"]["clients"].as_array().unwrap().len(); assert_eq!(client_num, 1); assert!(vnc_client.disconnect().is_ok()); @@ -809,7 +805,7 @@ fn test_client_cut_event(test_state: Rc>, port: u16) -> Resul }; assert!(vnc_client.test_send_client_cut(client_cut).is_ok()); // Send a qmp to query vnc client state. - let value = qmp_query_vnc(test_state.clone()); + let value = qmp_query_vnc(test_state); let client_num = value["return"]["clients"].as_array().unwrap().len(); assert_eq!(client_num, 1); assert!(vnc_client.disconnect().is_ok()); diff --git a/tests/mod_test/tests/x86_64/cpu_hotplug_test.rs b/tests/mod_test/tests/x86_64/cpu_hotplug_test.rs index b3b14123d43b64914491e60c7e28d63cff702bd4..6a279843d33dec4d5208778773ee93bf19865699 100644 --- a/tests/mod_test/tests/x86_64/cpu_hotplug_test.rs +++ b/tests/mod_test/tests/x86_64/cpu_hotplug_test.rs @@ -40,7 +40,7 @@ fn set_up(cpu: u8, max_cpus: Option) -> TestState { args = cpu_args[..].split(' ').collect(); extra_args.append(&mut args); - let mem_args = format!("-m 512"); + let mem_args = "-m 512".to_string(); args = mem_args[..].split(' ').collect(); extra_args.append(&mut args); @@ -48,11 +48,12 @@ fn set_up(cpu: u8, max_cpus: Option) -> TestState { extra_args.push("root=/dev/vda panic=1"); let uefi_drive = - format!("-drive file=/usr/share/edk2/ovmf/OVMF_CODE.fd,if=pflash,unit=0,readonly=true"); + "-drive file=/usr/share/edk2/ovmf/OVMF_CODE.fd,if=pflash,unit=0,readonly=true".to_string(); args = uefi_drive[..].split(' ').collect(); extra_args.append(&mut args); - let root_device = format!("-device pcie-root-port,port=0x0,addr=0x1.0x0,bus=pcie.0,id=pcie.1"); + let root_device = + "-device pcie-root-port,port=0x0,addr=0x1.0x0,bus=pcie.0,id=pcie.1".to_string(); args = root_device[..].split(' ').collect(); extra_args.append(&mut args); diff --git a/trace/Cargo.toml b/trace/Cargo.toml index 267a35661e967ede87fe07741b8e95cd71bc93e9..4fd31409de9fea6bf25f507922ef9a66d9c9d9c6 100644 --- a/trace/Cargo.toml +++ b/trace/Cargo.toml @@ -12,7 +12,8 @@ lazy_static = "1.4.0" regex = "1" anyhow = "1.0" trace_generator = { path = "trace_generator" } -vmm-sys-util = "0.11.1" +vmm-sys-util = "0.12.1" +libloading = "0.7.4" [features] trace_to_logger = [] diff --git a/trace/src/hitrace.rs b/trace/src/hitrace.rs index 1cff1cd7cefed5fc767e05ef14b55dae7e3c8e44..625ea6209c3a1e48d6fc0e654f87d5fe0ed1174d 100644 --- a/trace/src/hitrace.rs +++ b/trace/src/hitrace.rs @@ -10,27 +10,84 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. +use std::ffi::OsStr; + +use anyhow::{Context, Result}; +use lazy_static::lazy_static; +use libloading::os::unix::Symbol; +use libloading::Library; +use log::error; + const HITRACE_TAG_VIRSE: u64 = 1u64 << 11; -#[link(name = "hitrace_meter")] -extern "C" { - fn StartTraceWrapper(label: u64, value: *const u8); - fn FinishTrace(label: u64); - fn StartAsyncTraceWrapper(label: u64, value: *const u8, taskId: i32); - fn FinishAsyncTraceWrapper(label: u64, value: *const u8, taskId: i32); +lazy_static! { + static ref HITRACE_FUNC_TABLE: HitraceFuncTable = + // SAFETY: The dynamic library should be always existing. + unsafe { + HitraceFuncTable::new(OsStr::new("libhitrace_meter.so")) + .map_err(|e| { + error!("failed to init HitraceFuncTable with error: {:?}", e); + e + }) + .unwrap() + }; +} + +macro_rules! get_libfn { + ( $lib: ident, $tname: ident, $fname: ident ) => { + $lib.get::<$tname>(stringify!($fname).as_bytes()) + .with_context(|| format!("failed to get function {}", stringify!($fname)))? + .into_raw() + }; +} + +type StartTraceWrapperFn = unsafe extern "C" fn(u64, *const u8); +type FinishTraceFn = unsafe extern "C" fn(u64); +type StartAsyncTraceWrapperFn = unsafe extern "C" fn(u64, *const u8, i32); +type FinishAsyncTraceWrapperFn = unsafe extern "C" fn(u64, *const u8, i32); + +struct HitraceFuncTable { + pub start_trace: Symbol, + pub finish_trace: Symbol, + pub start_trace_async: Symbol, + pub finish_trace_async: Symbol, +} + +impl HitraceFuncTable { + unsafe fn new(library_name: &OsStr) -> Result { + let library = + Library::new(library_name).with_context(|| "failed to load hitrace_meter library")?; + + Ok(Self { + start_trace: get_libfn!(library, StartTraceWrapperFn, StartTraceWrapper), + finish_trace: get_libfn!(library, FinishTraceFn, FinishTrace), + start_trace_async: get_libfn!( + library, + StartAsyncTraceWrapperFn, + StartAsyncTraceWrapper + ), + finish_trace_async: get_libfn!( + library, + FinishAsyncTraceWrapperFn, + FinishAsyncTraceWrapper + ), + }) + } } pub fn start_trace(value: &str) { if let Ok(value_ptr) = std::ffi::CString::new(value) { // SAFETY: All parameters have been checked. - unsafe { StartTraceWrapper(HITRACE_TAG_VIRSE, value_ptr.as_ptr() as *const u8) } + unsafe { + (HITRACE_FUNC_TABLE.start_trace)(HITRACE_TAG_VIRSE, value_ptr.as_ptr() as *const u8) + } } } pub fn finish_trace() { // SAFETY: All parameters have been checked. unsafe { - FinishTrace(HITRACE_TAG_VIRSE); + (HITRACE_FUNC_TABLE.finish_trace)(HITRACE_TAG_VIRSE); } } @@ -38,7 +95,11 @@ pub fn start_trace_async(value: &str, task_id: i32) { if let Ok(value_ptr) = std::ffi::CString::new(value) { // SAFETY: All parameters have been checked. unsafe { - StartAsyncTraceWrapper(HITRACE_TAG_VIRSE, value_ptr.as_ptr() as *const u8, task_id) + (HITRACE_FUNC_TABLE.start_trace_async)( + HITRACE_TAG_VIRSE, + value_ptr.as_ptr() as *const u8, + task_id, + ) } } } @@ -47,7 +108,11 @@ pub fn finish_trace_async(value: &str, task_id: i32) { if let Ok(value_ptr) = std::ffi::CString::new(value) { // SAFETY: All parameters have been checked. unsafe { - FinishAsyncTraceWrapper(HITRACE_TAG_VIRSE, value_ptr.as_ptr() as *const u8, task_id) + (HITRACE_FUNC_TABLE.finish_trace_async)( + HITRACE_TAG_VIRSE, + value_ptr.as_ptr() as *const u8, + task_id, + ) } } } diff --git a/trace/src/lib.rs b/trace/src/lib.rs index 1501a5028b69881841b7b3cf569b04c85c3322a2..ba8aa2ab1ef761ce1ce1e4a8d5ddf9a6aed14492 100644 --- a/trace/src/lib.rs +++ b/trace/src/lib.rs @@ -19,7 +19,7 @@ pub(crate) mod hitrace; feature = "trace_to_ftrace", all(target_env = "ohos", feature = "trace_to_hitrace") ))] -pub(crate) mod trace_scope; +pub mod trace_scope; use std::{ fmt, @@ -30,6 +30,7 @@ use std::{ use anyhow::{Ok, Result}; use lazy_static::lazy_static; +use log::warn; use regex::Regex; use vmm_sys_util::eventfd::EventFd; @@ -37,16 +38,33 @@ use trace_generator::{ add_trace_state_to, gen_trace_event_func, gen_trace_scope_func, gen_trace_state, }; +#[derive(PartialEq, Eq)] +pub enum TraceType { + Event, + Scope, + Unknown, +} + struct TraceState { name: String, + trace_type: TraceType, get_state: fn() -> bool, set_state: fn(bool), } impl TraceState { - fn new(name: String, get_state: fn() -> bool, set_state: fn(bool)) -> Self { + fn new(name: String, type_str: &str, get_state: fn() -> bool, set_state: fn(bool)) -> Self { + let trace_type = match type_str { + "event" => TraceType::Event, + "scope" => TraceType::Scope, + _ => { + warn!("The type of {} is Unknown: {}", name, type_str); + TraceType::Unknown + } + }; TraceState { name, + trace_type, get_state, set_state, } @@ -73,6 +91,15 @@ impl TraceStateSet { Ok(()) } + fn enable_state_by_type(&self, trace_type: TraceType) -> Result<()> { + for state in &self.state_list { + if state.trace_type == trace_type { + (state.set_state)(true); + } + } + Ok(()) + } + fn get_state_by_pattern(&self, pattern: String) -> Result> { let re = Regex::new(&pattern)?; let mut ret: Vec<(String, bool)> = Vec::new(); @@ -126,3 +153,7 @@ pub fn get_state_by_pattern(pattern: String) -> Result> { pub fn set_state_by_pattern(pattern: String, state: bool) -> Result<()> { TRACE_STATE_SET.set_state_by_pattern(pattern, state) } + +pub fn enable_state_by_type(trace_type: TraceType) -> Result<()> { + TRACE_STATE_SET.enable_state_by_type(trace_type) +} diff --git a/trace/src/trace_scope.rs b/trace/src/trace_scope.rs index 4a02101705ea8468eb9d087a935d618d5ff6ed46..a860ef8d7dfdda3be65263d201c7761db7ca150e 100644 --- a/trace/src/trace_scope.rs +++ b/trace/src/trace_scope.rs @@ -20,12 +20,14 @@ use crate::ftrace::write_trace_marker; static mut TRACE_SCOPE_COUNTER: AtomicI32 = AtomicI32::new(i32::MIN); +#[derive(Clone)] pub enum Scope { Common(TraceScope), Asyn(TraceScopeAsyn), None, } +#[derive(Clone)] pub struct TraceScope {} impl TraceScope { @@ -63,6 +65,7 @@ impl Drop for TraceScope { } } +#[derive(Clone)] pub struct TraceScopeAsyn { value: String, id: i32, diff --git a/trace/trace_generator/src/lib.rs b/trace/trace_generator/src/lib.rs index 008263cc0e72ac1568f7adb9ad34239a19e6ea23..ce6ce12618f9d13802ea93438e1f346a739ddc42 100644 --- a/trace/trace_generator/src/lib.rs +++ b/trace/trace_generator/src/lib.rs @@ -56,24 +56,24 @@ fn get_trace_desc() -> TraceConf { #[proc_macro] pub fn add_trace_state_to(input: TokenStream) -> TokenStream { let trace_conf = get_trace_desc(); - let mut state_name = Vec::new(); + let mut state = Vec::new(); for desc in trace_conf.events.unwrap_or_default() { if desc.enabled { - state_name.push(desc.name.trim().to_string()); + state.push((desc.name.trim().to_string(), "event")); } } for desc in trace_conf.scopes.unwrap_or_default() { if desc.enabled { - state_name.push(desc.name.trim().to_string()); + state.push((desc.name.trim().to_string(), "scope")); } } let set = parse_macro_input!(input as Ident); - let init_code = state_name.iter().map(|name| { + let init_code = state.iter().map(|(name, type_str)| { let get_func = parse_str::(format!("get_{}_state", name).as_str()).unwrap(); let set_func = parse_str::(format!("set_{}_state", name).as_str()).unwrap(); quote!( - #set.add_trace_state(TraceState::new(#name.to_string(), #get_func, #set_func)); + #set.add_trace_state(TraceState::new(#name.to_string(), #type_str, #get_func, #set_func)); ) }); TokenStream::from(quote! { #( #init_code )* }) @@ -216,6 +216,11 @@ pub fn gen_trace_scope_func(_input: TokenStream) -> TokenStream { } }; + let func_decl = match desc.enabled { + true => quote!(pub fn #func_name(asyn: bool, #func_args) -> trace_scope::Scope), + false => quote!(pub fn #func_name(asyn: bool, #func_args)), + }; + let message_args = match desc.args.is_empty() { true => quote!(), false => { @@ -258,7 +263,7 @@ pub fn gen_trace_scope_func(_input: TokenStream) -> TokenStream { all(target_env = "ohos", feature = "trace_to_hitrace") ))] #[inline(always)] - pub fn #func_name(asyn: bool, #func_args) -> trace_scope::Scope { + #func_decl { #func_body } diff --git a/trace/trace_info/acpi.toml b/trace/trace_info/acpi.toml new file mode 100644 index 0000000000000000000000000000000000000000..9b4352aaabca473eaf0addd6e761e49e05f18488 --- /dev/null +++ b/trace/trace_info/acpi.toml @@ -0,0 +1,23 @@ +[[events]] +name = "ged_inject_acpi_event" +args = "event: u32" +message = "acpi_sevent {}" +enabled = true + +[[events]] +name = "ged_read" +args = "event: u32" +message = "acpi_sevent {}" +enabled = true + +[[events]] +name = "power_read" +args = "reg_idx: u64, value: u32" +message = "reg_idx {} value {}" +enabled = true + +[[events]] +name = "power_status_read" +args = "regs: &dyn fmt::Debug" +message = "regs {:?}" +enabled = true diff --git a/trace/trace_info/camera.toml b/trace/trace_info/camera.toml index 4e0ee65deafe6a4be1234cd033f9782908931556..e1f044ffb8d32c898acd946cf8229a04eadbe94d 100644 --- a/trace/trace_info/camera.toml +++ b/trace/trace_info/camera.toml @@ -21,3 +21,15 @@ name = "camera_get_format_by_index" args = "format_index: u8, frame_index: u8, out: &dyn fmt::Debug" message = "V4l2 fmt {}, frm {}, info {:?}." enabled = true + +[[scopes]] +name = "ohcam_get_frame" +args = "offset: usize, len: usize" +message = "ohcam get frame offset {} len {}" +enabled = true + +[[scopes]] +name = "ohcam_next_frame" +args = "frame_id: u64" +message = "ohcam next frame {}" +enabled = true diff --git a/trace/trace_info/device_legacy.toml b/trace/trace_info/device_legacy.toml index 14ed94332e548cf7e203b8d0432e76871a40928d..42bbae08d9431af9cffeae7d494e42a8f4d0d216 100644 --- a/trace/trace_info/device_legacy.toml +++ b/trace/trace_info/device_legacy.toml @@ -201,3 +201,27 @@ name = "pflash_write_data" args = "offset: u64, size: usize, value: &[u8], counter: u32" message = "data offset: 0x{:04x}, size: {}, value: 0x{:x?}, counter: 0x{:04x}" enabled = true + +[[events]] +name = "fwcfg_select_entry" +args = "key: u16, key_name: &'static str, ret: i32" +message = "key_value {} key_name {:?} ret {}" +enabled = true + +[[events]] +name = "fwcfg_add_entry" +args = "key: u16, key_name: &'static str, data: Vec" +message = "key_value {} key_name {:?} data {:?}" +enabled = true + +[[events]] +name = "fwcfg_read_data" +args = "value: u64" +message = "value {}" +enabled = true + +[[events]] +name = "fwcfg_add_file" +args = "index: usize, filename: &str, data_len: usize" +message = "index {} filename {:?} data_len {}" +enabled = true diff --git a/trace/trace_info/memory.toml b/trace/trace_info/memory.toml new file mode 100644 index 0000000000000000000000000000000000000000..c7d8171c5d0ac5dd8d285908e42b7008c15cd03f --- /dev/null +++ b/trace/trace_info/memory.toml @@ -0,0 +1,29 @@ +[[events]] +name = "address_space_read" +args = "addr: &dyn fmt::Debug, count: u64" +message = "Memory: flatview_read addr {:?}, count {}" +enabled = true + +[[events]] +name = "address_space_write" +args = "addr: &dyn fmt::Debug, count: u64" +message = "Memory: flatview_write addr {:?}, count {}" +enabled = true + +[[scopes]] +name = "address_update_topology" +args = "" +message = "Memory: update opology" +enabled = true + +[[scopes]] +name = "pre_alloc" +args = "size: u64" +message = "Memory: pre_alloc ram size is {}" +enabled = true + +[[scopes]] +name = "init_memory" +args = "" +message = "Memory: init memory" +enabled = true diff --git a/trace/trace_info/misc.toml b/trace/trace_info/misc.toml index 78ac9d19168bfb1171525969ab0de01977feb889..28cf5132667509a234bec29f8968fbd17269c4ce 100644 --- a/trace/trace_info/misc.toml +++ b/trace/trace_info/misc.toml @@ -27,3 +27,63 @@ name = "scream_setup_alsa_hwp" args = "name: &str, hwp: &dyn fmt::Debug" message = "scream {} setup hardware parameters: {:?}" enabled = true + +[[events]] +name = "oh_scream_render_init" +args = "context: &dyn fmt::Debug" +message = "context: {:?}" +enabled = true + +[[events]] +name = "oh_scream_render_destroy" +args = "" +message = "" +enabled = true + +[[events]] +name = "oh_scream_capture_init" +args = "context: &dyn fmt::Debug" +message = "context: {:?}" +enabled = true + +[[events]] +name = "oh_scream_capture_destroy" +args = "" +message = "" +enabled = true + +[[events]] +name = "oh_scream_on_write_data_cb" +args = "len: usize" +message = "len: {}" +enabled = true + +[[events]] +name = "oh_scream_on_read_data_cb" +args = "len: usize" +message = "len: {}" +enabled = true + +[[scopes]] +name = "ohaudio_render_process" +args = "data: &dyn fmt::Debug" +message = "audio data {:?} to render" +enabled = true + +[[scopes]] +name = "ohaudio_capturer_process" +args = "data: &dyn fmt::Debug" +message = "audio data {:?} to capture" +enabled = true + +[[scopes]] +name = "ohaudio_write_cb" +args = "to_copy: usize" +message = "OH audio expect audio data {} bytes" +enabled = true + +[[scopes]] +name = "ohaudio_read_cb" +args = "len: i32" +message = "OH audio captured {} bytes" +enabled = true diff --git a/trace/trace_info/pci.toml b/trace/trace_info/pci.toml index dfd7642fab0c15d803eaae2039ffcbbe4ceb2788..be76599f3165de03b03852e405ee0f327c0b4801 100644 --- a/trace/trace_info/pci.toml +++ b/trace/trace_info/pci.toml @@ -15,15 +15,3 @@ name = "msix_write_config" args = "dev_id: u16, masked: bool, enabled: bool" message = "dev id: {} masked: {} enabled: {}" enabled = true - -[[events]] -name = "pci_update_mappings_add" -args = "bar_id: usize, addr: u64, size: u64" -message = "bar id: {} addr: 0x{:#X} size: {}" -enabled = true - -[[events]] -name = "pci_update_mappings_del" -args = "bar_id: usize, addr: u64, size: u64" -message = "bar id: {} addr: 0x{:#X} size: {}" -enabled = true diff --git a/trace/trace_info/ui.toml b/trace/trace_info/ui.toml index 091c05648a3a07515979c9f49f587cddbfc25328..1df64556aed17bfc6491229948f624b2daba92d1 100644 --- a/trace/trace_info/ui.toml +++ b/trace/trace_info/ui.toml @@ -219,3 +219,63 @@ name = "console_select" args = "con_id: &dyn fmt::Debug" message = "console id={:?}" enabled = true + +[[events]] +name = "oh_event_mouse_button" +args = "msg_btn: u32, action: u32" +message = "msg_btn={} action={}" +enabled = true + +[[events]] +name = "oh_event_mouse_motion" +args = "x: f64, y: f64" +message = "x={} y={}" +enabled = true + +[[events]] +name = "oh_event_keyboard" +args = "keycode: u16, key_action: u16" +message = "keycode={} key_action={}" +enabled = true + +[[events]] +name = "oh_event_windowinfo" +args = "width: u32, height: u32" +message = "width={} height={}" +enabled = true + +[[events]] +name = "oh_event_scroll" +args = "direction: u32" +message = "direction={}" +enabled = true + +[[events]] +name = "oh_event_ledstate" +args = "state: u32" +message = "state={}" +enabled = true + +[[events]] +name = "oh_event_focus" +args = "state: u32" +message = "state={}" +enabled = true + +[[events]] +name = "oh_event_greet" +args = "id: u64" +message = "token_id={}" +enabled = true + +[[events]] +name = "oh_event_unsupported_type" +args = "ty: &dyn fmt::Debug, size: u32" +message = "type={:?} body_size={}" +enabled = true + +[[scopes]] +name = "handle_msg" +args = "opcode: &dyn fmt::Debug" +message = "handle ohui {:?} message" +enabled = true diff --git a/trace/trace_info/usb.toml b/trace/trace_info/usb.toml index fa5b94542d70a899d40d908be5ec25a230e0608c..9defe6d287c5195b74d4f605c66bbfc5c4aecbf9 100644 --- a/trace/trace_info/usb.toml +++ b/trace/trace_info/usb.toml @@ -232,6 +232,36 @@ args = "str: &dyn fmt::Debug" message = "{:?}" enabled = true +[[events]] +name = "usb_xhci_set_state" +args = "ep_id: u32, new_state: u32" +message = "Endpoint {} set new state {}." +enabled = true + +[[events]] +name = "usb_xhci_update_dequeue" +args = "ep_id: u32, dequeue: u64, stream_id: u32" +message = "Endpoint {} update dequeue {} on Stream ID {}." +enabled = true + +[[events]] +name = "usb_xhci_reset_streams" +args = "ep_id: u32" +message = "Resetting streams on Endpoint {}." +enabled = true + +[[events]] +name = "usb_xhci_get_ring" +args = "ep_id: u32, stream_id: u32" +message = "Found Transfer ring on Endpoint {} Stream ID {}." +enabled = true + +[[events]] +name = "usb_xhci_get_stream" +args = "stream_id: u32, ep_id: u32" +message = "Found Stream Context {} for Endpoint {}." +enabled = true + [[events]] name = "usb_handle_control" args = "device: &str, req: &dyn fmt::Debug" @@ -465,3 +495,111 @@ name = "usb_host_req_complete" args = "bus_num: u8, addr: u8, packet: u64, status: &dyn fmt::Debug, actual_length: usize" message = "dev bus 0x{:X} addr 0x{:X}, packet 0x{:#X}, status {:?} actual length {}" enabled = true + +[[events]] +name = "usb_uas_handle_control" +args = "packet_id: u32, device_id: &str, req: &[u8]" +message = "USB {} packet received on UAS {} device, the request is {:?}." +enabled = true + +[[events]] +name = "usb_uas_handle_iu_command" +args = "device_id: &str, cdb: u8" +message = "UAS {} device handling IU with cdb[0] {}." +enabled = true + +[[events]] +name = "usb_uas_fill_sense" +args = "status: u8, iu_len: usize, sense_len: usize" +message = "UAS device is filling sense with status {:02} URB length {} sense length {}." +enabled = true + +[[events]] +name = "usb_uas_fill_fake_sense" +args = "status: u8, iu_len: usize, sense_len: usize" +message = "UAS device is filling fake sense with status {:02} URB length {} sense length {}." +enabled = true + +[[events]] +name = "usb_uas_fill_packet" +args = "iovec_size: usize" +message = "UAS device is filling USB packet with iovec of size {}." +enabled = true + +[[events]] +name = "usb_uas_try_start_next_transfer" +args = "device_id: &str, xfer_len: i64" +message = "UAS {} device is trying to start next transfer of length {}." +enabled = true + +[[events]] +name = "usb_uas_start_next_transfer" +args = "device_id: &str, stream: usize" +message = "UAS {} device starting a transfer on stream {}." +enabled = true + +[[events]] +name = "usb_uas_handle_data" +args = "device_id: &str, endpoint: u8, stream: usize" +message = "UAS {} device handling data on endpoint {} and stream {}." +enabled = true + +[[events]] +name = "usb_uas_command_received" +args = "packet_id: u32, device_id: &str" +message = "USB {} command packet received on UAS {} device." +enabled = true + +[[events]] +name = "usb_uas_command_completed" +args = "packet_id: u32, device_id: &str" +message = "USB {} command packet completed on UAS {} device." +enabled = true + +[[events]] +name = "usb_uas_status_received" +args = "packet_id: u32, device_id: &str" +message = "USB {} status packet received on UAS {} device." +enabled = true + +[[events]] +name = "usb_uas_status_completed" +args = "packet_id: u32, device_id: &str" +message = "USB {} status packet completed on UAS {} device." +enabled = true + +[[events]] +name = "usb_uas_status_queued_async" +args = "packet_id: u32, device_id: &str" +message = "USB {} status packet queued async on UAS {} device." +enabled = true + +[[events]] +name = "usb_uas_data_received" +args = "packet_id: u32, device_id: &str" +message = "USB {} data packet received on UAS {} device." +enabled = true + +[[events]] +name = "usb_uas_data_completed" +args = "packet_id: u32, device_id: &str" +message = "USB {} data packet completed on UAS {} device." +enabled = true + +[[events]] +name = "usb_uas_data_queued_async" +args = "packet_id: u32, device_id: &str" +message = "USB {} data packet queued async on UAS {} device." +enabled = true + +[[events]] +name = "usb_uas_handle_iu_task_management" +args = "device_id: &str, tmf: u8, tag: u16" +message = "UAS {} device handling TMF {} with tag {}." +enabled = true + +[[events]] +name = "usb_uas_tmf_abort_task" +args = "device_id: &str, task_tag: usize" +message = "UAS {} device aborting task with tag {}." +enabled = true diff --git a/trace/trace_info/virtio.toml b/trace/trace_info/virtio.toml index 7bff17c5e9aaa065a04a09dade2731bad878b195..5d3240b00c317b5c30b641a79e771e0136a7a5fc 100644 --- a/trace/trace_info/virtio.toml +++ b/trace/trace_info/virtio.toml @@ -297,3 +297,27 @@ name = "vhost_delete_mem_range_failed" args = "" message = "Vhost: deleting mem region failed: not matched." enabled = true + +[[events]] +name = "auto_msg_evt_handler" +args = "" +message = "Balloon: handle auto balloon message" +enabled = true + +[[events]] +name = "reporting_evt_handler" +args = "" +message = "Balloon: handle fpr message" +enabled = true + +[[events]] +name = "virtio_read_object_direct" +args = "host_addr: u64, count: usize" +message = "Memory: virtio_read_object_direct host_addr {}, count {}" +enabled = true + +[[events]] +name = "virtio_write_object_direct" +args = "host_addr: u64, count: usize" +message = "Memory: virtio_write_object_direct host_addr {}, count {}" +enabled = true diff --git a/ui/Cargo.toml b/ui/Cargo.toml index 56f2b6f67c27c9006f684b42907219773b822c9b..b7faa42406adbf0ba8776cfcc0d385c97f8d1860 100644 --- a/ui/Cargo.toml +++ b/ui/Cargo.toml @@ -12,7 +12,7 @@ anyhow = "1.0" libc = "0.2" log = "0.4" serde_json = "1.0" -vmm-sys-util = "0.11.1" +vmm-sys-util = "0.12.1" once_cell = "1.18.0" sscanf = "0.4.1" bitintr = "0.3.0" diff --git a/ui/src/console.rs b/ui/src/console.rs index 97dc06e98e0ca4bb26240aa503aa3b04a4c69ebb..7d0e27bc0f599b8e02291424ce42906850e21a2c 100644 --- a/ui/src/console.rs +++ b/ui/src/console.rs @@ -55,6 +55,10 @@ pub const DISPLAY_UPDATE_INTERVAL_INC: u64 = 50; /// Maximum refresh interval in ms. pub const DISPLAY_UPDATE_INTERVAL_MAX: u64 = 3_000; +pub const DEFAULT_CURSOR_WIDTH: usize = 32; +pub const DEFAULT_CURSOR_HEIGHT: usize = 32; +pub const DEFAULT_CURSOR_BPP: usize = 4; + pub enum ConsoleType { Graphic, Text, @@ -777,8 +781,8 @@ pub fn console_select(con_id: Option) -> Result<()> { /// * `height` - height of image. /// * `msg` - test messages showed in display. pub fn create_msg_surface(width: i32, height: i32, msg: String) -> Option { - if !(0..MAX_WINDOW_WIDTH as i32).contains(&width) - || !(0..MAX_WINDOW_HEIGHT as i32).contains(&height) + if !(0..i32::from(MAX_WINDOW_WIDTH)).contains(&width) + || !(0..i32::from(MAX_WINDOW_HEIGHT)).contains(&height) { error!("The size of image is invalid!"); return None; @@ -844,23 +848,23 @@ mod tests { #[test] fn test_console_select() { let con_opts = Arc::new(HwOpts {}); - let dev_name0 = format!("test_device0"); + let dev_name0 = "test_device0".to_string(); let con_0 = console_init(dev_name0, ConsoleType::Graphic, con_opts.clone()); let clone_con = con_0.clone(); assert_eq!( clone_con.unwrap().upgrade().unwrap().lock().unwrap().con_id, 0 ); - let dev_name1 = format!("test_device1"); + let dev_name1 = "test_device1".to_string(); let con_1 = console_init(dev_name1, ConsoleType::Graphic, con_opts.clone()); assert_eq!(con_1.unwrap().upgrade().unwrap().lock().unwrap().con_id, 1); - let dev_name2 = format!("test_device2"); + let dev_name2 = "test_device2".to_string(); let con_2 = console_init(dev_name2, ConsoleType::Graphic, con_opts.clone()); assert_eq!(con_2.unwrap().upgrade().unwrap().lock().unwrap().con_id, 2); assert!(console_close(&con_0).is_ok()); assert_eq!(CONSOLES.lock().unwrap().activate_id, Some(1)); - let dev_name3 = format!("test_device3"); - let con_3 = console_init(dev_name3, ConsoleType::Graphic, con_opts.clone()); + let dev_name3 = "test_device3".to_string(); + let con_3 = console_init(dev_name3, ConsoleType::Graphic, con_opts); assert_eq!(con_3.unwrap().upgrade().unwrap().lock().unwrap().con_id, 3); assert!(console_select(Some(0)).is_ok()); assert_eq!(CONSOLES.lock().unwrap().activate_id, Some(0)); @@ -891,10 +895,7 @@ mod tests { None, dcl_opts.clone(), ))); - let dcl_3 = Arc::new(Mutex::new(DisplayChangeListener::new( - None, - dcl_opts.clone(), - ))); + let dcl_3 = Arc::new(Mutex::new(DisplayChangeListener::new(None, dcl_opts))); assert!(register_display(&dcl_0).is_ok()); assert_eq!(dcl_0.lock().unwrap().dcl_id, Some(0)); diff --git a/ui/src/gtk/draw.rs b/ui/src/gtk/draw.rs index 9d9dde3cd961a9e764e6ec570e8510c2f746003e..b9cbfc6eee87151c85c7f810c8631e1bdeb3ecb3 100644 --- a/ui/src/gtk/draw.rs +++ b/ui/src/gtk/draw.rs @@ -218,7 +218,7 @@ fn da_event_callback(gs: &Rc>, event: &gdk::Event) -> fn gd_cursor_move_event(gs: &Rc>, event: &gdk::Event) -> Result<()> { let mut borrowed_gs = gs.borrow_mut(); let (width, height) = match &borrowed_gs.cairo_image { - Some(image) => (image.width() as f64, image.height() as f64), + Some(image) => (f64::from(image.width()), f64::from(image.height())), None => return Ok(()), }; @@ -231,8 +231,8 @@ fn gd_cursor_move_event(gs: &Rc>, event: &gdk::Event) let standard_x = ((real_x * (ABS_MAX as f64)) / width) as u16; let standard_y = ((real_y * (ABS_MAX as f64)) / height) as u16; - input_move_abs(Axis::X, standard_x as u32)?; - input_move_abs(Axis::Y, standard_y as u32)?; + input_move_abs(Axis::X, u32::from(standard_x))?; + input_move_abs(Axis::Y, u32::from(standard_y))?; input_point_sync() } @@ -304,7 +304,7 @@ fn da_draw_callback(gs: &Rc>, cr: &cairo::Context) -> let mut borrowed_gs = gs.borrow_mut(); let scale_mode = borrowed_gs.scale_mode.clone(); let (mut surface_width, mut surface_height) = match &borrowed_gs.cairo_image { - Some(image) => (image.width() as f64, image.height() as f64), + Some(image) => (f64::from(image.width()), f64::from(image.height())), None => return Ok(()), }; diff --git a/ui/src/gtk/menu.rs b/ui/src/gtk/menu.rs index e89b7a3ea307cea8c8458da5d31d4c0f0a5038bc..c1e4b6b64d13223939c46dca48057624bab061af 100644 --- a/ui/src/gtk/menu.rs +++ b/ui/src/gtk/menu.rs @@ -42,7 +42,7 @@ pub(crate) struct GtkMenu { pub(crate) window: ApplicationWindow, container: gtk::Box, pub(crate) note_book: gtk::Notebook, - pub(crate) radio_group: Vec, + pub(crate) radio_group: Rc>>, accel_group: AccelGroup, menu_bar: MenuBar, machine_menu: Menu, @@ -64,7 +64,7 @@ impl GtkMenu { window, container: gtk::Box::new(Orientation::Vertical, 0), note_book: gtk::Notebook::default(), - radio_group: vec![], + radio_group: Rc::new(RefCell::new(vec![])), accel_group: AccelGroup::default(), menu_bar: MenuBar::new(), machine_menu: Menu::new(), @@ -258,6 +258,11 @@ impl GtkMenu { self.zoom_fit.activate(); } + if let Some(page_num) = self.note_book.current_page() { + let radio_item = &self.radio_group.borrow()[page_num as usize]; + radio_item.activate(); + } + self.menu_bar.hide(); } } diff --git a/ui/src/gtk/mod.rs b/ui/src/gtk/mod.rs index ced3792bdc3df32ffeb24945616fe26e7aec1cf7..9c3fdcaa8e94939f39491662b29f9fa4c59744b0 100644 --- a/ui/src/gtk/mod.rs +++ b/ui/src/gtk/mod.rs @@ -18,6 +18,7 @@ use std::{ cmp, collections::HashMap, env, fs, + os::unix::fs::OpenOptionsExt, path::Path, ptr, rc::Rc, @@ -304,14 +305,17 @@ impl GtkDisplay { })); self.gtk_menu.view_menu.append(&gs_show_menu); - if !self.gtk_menu.radio_group.is_empty() { - let first_radio = &self.gtk_menu.radio_group[0]; + if !self.gtk_menu.radio_group.borrow().is_empty() { + let first_radio = &self.gtk_menu.radio_group.borrow()[0]; gs_show_menu.join_group(Some(first_radio)); } else { note_book.set_current_page(Some(page_num)); } - self.gtk_menu.radio_group.push(gs_show_menu.clone()); + self.gtk_menu + .radio_group + .borrow_mut() + .push(gs_show_menu.clone()); gs.borrow_mut().show_menu = gs_show_menu; gs.borrow_mut().draw_area = draw_area; @@ -426,8 +430,8 @@ impl GtkDisplayScreen { fn get_window_size(&self) -> Option<(f64, f64)> { if let Some(win) = self.draw_area.window() { - let w_width = win.width() as f64; - let w_height = win.height() as f64; + let w_width = f64::from(win.width()); + let w_height = f64::from(win.height()); if w_width.ne(&0.0) && w_height.ne(&0.0) { return Some((w_width, w_height)); @@ -449,8 +453,8 @@ impl GtkDisplayScreen { None => bail!("No display image."), }; let (scale_width, scale_height) = ( - (surface_width as f64) * self.scale_x, - (surface_height as f64) * self.scale_y, + f64::from(surface_width) * self.scale_x, + f64::from(surface_height) * self.scale_y, ); let (mut window_width, mut window_height) = (0.0, 0.0); @@ -458,7 +462,7 @@ impl GtkDisplayScreen { (window_width, window_height) = (w, h); }; let scale_factor = match self.draw_area.window() { - Some(window) => window.scale_factor() as f64, + Some(window) => f64::from(window.scale_factor()), None => bail!("No display window."), }; @@ -813,13 +817,13 @@ fn do_update_event(gs: &Rc>, event: DisplayChangeEvent drop(locked_con); // Image scalling. - let x1 = ((x as f64) * borrowed_gs.scale_x).floor(); - let y1 = ((y as f64) * borrowed_gs.scale_y).floor(); - let x2 = ((x as f64) * borrowed_gs.scale_x + (w as f64) * borrowed_gs.scale_x).ceil(); - let y2 = ((y as f64) * borrowed_gs.scale_y + (h as f64) * borrowed_gs.scale_y).ceil(); + let x1 = (f64::from(x) * borrowed_gs.scale_x).floor(); + let y1 = (f64::from(y) * borrowed_gs.scale_y).floor(); + let x2 = (f64::from(x) * borrowed_gs.scale_x + f64::from(w) * borrowed_gs.scale_x).ceil(); + let y2 = (f64::from(y) * borrowed_gs.scale_y + f64::from(h) * borrowed_gs.scale_y).ceil(); - let scale_width = (surface_width as f64) * borrowed_gs.scale_x; - let scale_height = (surface_height as f64) * borrowed_gs.scale_y; + let scale_width = f64::from(surface_width) * borrowed_gs.scale_x; + let scale_height = f64::from(surface_height) * borrowed_gs.scale_y; let (window_width, window_height); match borrowed_gs.get_window_size() { Some((w, h)) => (window_width, window_height) = (w, h), @@ -961,8 +965,8 @@ fn do_switch_event(gs: &Rc>) -> Result<()> { None => return Ok(()), }; if scale_mode.borrow().is_full_screen() || scale_mode.borrow().is_free_scale() { - borrowed_gs.scale_x = window_width / surface_width as f64; - borrowed_gs.scale_y = window_height / surface_height as f64; + borrowed_gs.scale_x = window_width / f64::from(surface_width); + borrowed_gs.scale_y = window_height / f64::from(surface_height); } // Vm desktop manage its own cursor, gtk cursor need to be trsp firstly. @@ -998,7 +1002,7 @@ pub(crate) fn update_window_size(gs: &Rc>) -> Result<( let borrowed_gs = gs.borrow(); let scale_mode = borrowed_gs.scale_mode.borrow().clone(); let (width, height) = match &borrowed_gs.cairo_image { - Some(image) => (image.width() as f64, image.height() as f64), + Some(image) => (f64::from(image.width()), f64::from(image.height())), None => (0.0, 0.0), }; let (mut scale_width, mut scale_height) = if scale_mode.is_free_scale() { @@ -1006,8 +1010,8 @@ pub(crate) fn update_window_size(gs: &Rc>) -> Result<( } else { (width * borrowed_gs.scale_x, height * borrowed_gs.scale_y) }; - scale_width = scale_width.max(DEFAULT_SURFACE_WIDTH as f64); - scale_height = scale_height.max(DEFAULT_SURFACE_HEIGHT as f64); + scale_width = scale_width.max(f64::from(DEFAULT_SURFACE_WIDTH)); + scale_height = scale_height.max(f64::from(DEFAULT_SURFACE_HEIGHT)); let geo: Geometry = Geometry::new( scale_width as i32, @@ -1101,6 +1105,10 @@ fn create_file(gpu_info: &mut GpuInfo, dev_name: &String) -> Result { gpu_info.fileDir = file_dir.clone(); let nsec = gettime()?.1; let file_name = file_dir + "/stratovirt-display-" + dev_name + "-" + &nsec.to_string() + ".png"; - let file = fs::File::create(file_name)?; + let file = fs::OpenOptions::new() + .create_new(true) + .write(true) + .mode(0o600) + .open(file_name)?; Ok(file) } diff --git a/ui/src/input.rs b/ui/src/input.rs index d540c998e18d0960eb9f22df4e13fb3cefe7540e..c59926be70370fd4366dbaf130380a34be4adf24 100644 --- a/ui/src/input.rs +++ b/ui/src/input.rs @@ -34,6 +34,7 @@ pub const INPUT_BUTTON_WHEEL_DOWN: u32 = 0x40; pub const INPUT_BUTTON_WHEEL_LEFT: u32 = 0x80; pub const INPUT_BUTTON_WHEEL_RIGHT: u32 = 0x100; pub const INPUT_BUTTON_MASK: u32 = 0x1f; +pub const INPUT_POINT_MAX: u32 = INPUT_POINT_FORWARD; // ASCII value. pub const ASCII_A: i32 = 65; @@ -252,23 +253,6 @@ impl KeyBoardState { #[derive(Default)] struct LedState { kbd_led: u8, - sync: Option>, -} - -pub trait SyncLedstate: Send + Sync { - fn sync_to_host(&self, state: u8) { - debug!("ledstate in guest is {}", state); - } -} - -impl LedState { - fn register_led_sync(&mut self, sync: Arc) { - self.sync = Some(sync); - } - - fn unregister_led_sync(&mut self) { - self.sync = None; - } } #[derive(Default)] @@ -278,6 +262,7 @@ struct Inputs { tablet_ids: Vec, tablet_lists: HashMap>>, keyboard_state: KeyBoardState, + btn_state: u32, } impl Inputs { @@ -343,14 +328,14 @@ impl Inputs { } Ok(()) } -} - -pub fn register_led_sync(sync: Arc) { - LED_STATE.lock().unwrap().register_led_sync(sync); -} -pub fn unregister_led_sync() { - LED_STATE.lock().unwrap().unregister_led_sync(); + fn update_button_state(&mut self, btn: u32, press: bool) { + if press { + self.btn_state |= btn; + } else { + self.btn_state &= !btn; + } + } } pub fn register_keyboard(device: &str, kbd: Arc>) { @@ -390,11 +375,26 @@ pub fn input_button(button: u32, down: bool) -> Result<()> { let mouse = INPUTS.lock().unwrap().get_active_mouse(); if let Some(m) = mouse { m.lock().unwrap().update_point_state(input_event)?; + INPUTS.lock().unwrap().update_button_state(button, down); } Ok(()) } +/// Release all pressed button. +pub fn release_all_btn() -> Result<()> { + let state = INPUTS.lock().unwrap().btn_state; + let mut checked_btn = INPUT_POINT_LEFT; + while checked_btn <= INPUT_POINT_MAX { + if (state & checked_btn) != 0 { + input_button(checked_btn, false)?; + input_point_sync()?; + } + checked_btn <<= 1; + } + Ok(()) +} + pub fn input_point_sync() -> Result<()> { let mouse = INPUTS.lock().unwrap().get_active_mouse(); if let Some(m) = mouse { @@ -490,9 +490,6 @@ pub fn get_kbd_led_state() -> u8 { pub fn set_kbd_led_state(state: u8) { LED_STATE.lock().unwrap().kbd_led = state; - if let Some(sync_cb) = LED_STATE.lock().unwrap().sync.as_ref() { - sync_cb.sync_to_host(state); - } } pub fn keyboard_modifier_get(key_mod: KeyboardModifier) -> bool { @@ -626,7 +623,7 @@ mod tests { assert!(key_event(12, true).is_ok()); assert_eq!(test_kdb.lock().unwrap().keycode, 12); - assert_eq!(test_kdb.lock().unwrap().down, true); + assert!(test_kdb.lock().unwrap().down); // Test point event. assert_eq!(test_mouse.lock().unwrap().button, 0); @@ -670,14 +667,14 @@ mod tests { let keysym_lists: Vec = vec![0x07D0, 0x07E1, 0x0802]; let keycode_lists: Vec = keysym_lists .iter() - .map(|x| *keysym2qkeycode.get(&x).unwrap()) + .map(|x| *keysym2qkeycode.get(x).unwrap()) .collect(); for idx in 0..keysym_lists.len() { let keysym = keycode_lists[idx]; let keycode = keycode_lists[idx]; - assert!(do_key_event(true, keysym as i32, keycode).is_ok()); + assert!(do_key_event(true, i32::from(keysym), keycode).is_ok()); assert_eq!(test_kdb.lock().unwrap().keycode, keycode); - assert_eq!(test_kdb.lock().unwrap().down, true); + assert!(test_kdb.lock().unwrap().down); } let locked_input = INPUTS.lock().unwrap(); diff --git a/ui/src/keycode.rs b/ui/src/keycode.rs index 18bbaa7c567892db355ed2ee110c336bdbed9b8a..16db2dbd431fa95fb2293a96c4d2c3c69ef9d73a 100644 --- a/ui/src/keycode.rs +++ b/ui/src/keycode.rs @@ -458,7 +458,7 @@ const KEY_CODE_OH: [(KeyCode, u16); 105] = [ (KeyCode::Home, 0x0821), (KeyCode::SysReq, 0x081F), (KeyCode::Right, 0x07DF), - (KeyCode::Menu, 0x09A2), + (KeyCode::Menu, 0x0813), (KeyCode::Prior, 0x0814), (KeyCode::Insert, 0x0823), (KeyCode::NumLock, 0x0836), diff --git a/ui/src/ohui_srv/channel.rs b/ui/src/ohui_srv/channel.rs index 861a16342042c0819bf77c92c7d3a686cfce5540..e55279c99852accc9da5cc8f0d69cffbfa07495f 100755 --- a/ui/src/ohui_srv/channel.rs +++ b/ui/src/ohui_srv/channel.rs @@ -10,119 +10,98 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use std::os::raw::c_void; +use std::io::{ErrorKind, Read, Write}; +use std::os::fd::AsRawFd; use std::os::unix::io::RawFd; -use std::sync::RwLock; -use anyhow::Result; -use libc::iovec; +use anyhow::{bail, Result}; use log::error; use util::byte_code::ByteCode; -use util::unix::UnixSock; +use util::socket::{SocketListener, SocketStream}; +use util::unix::limit_permission; pub struct OhUiChannel { - pub sock: RwLock, - pub path: String, + listener: SocketListener, + stream: Option, } impl OhUiChannel { - pub fn new(path: &str) -> Self { - OhUiChannel { - sock: RwLock::new(UnixSock::new(path)), - path: String::from(path), - } - } + pub fn new(path: &str) -> Result { + let listener = match SocketListener::bind_by_uds(path) { + Ok(l) => l, + Err(e) => bail!("Failed to create listener with path {}, {:?}", path, e), + }; + limit_permission(path).unwrap_or_else(|e| { + error!( + "Failed to limit permission for ohui-sock {}, err: {:?}", + path, e + ); + }); - pub fn bind(&self) -> Result<()> { - self.sock.write().unwrap().bind(true) + Ok(OhUiChannel { + listener, + stream: None, + }) } pub fn get_listener_raw_fd(&self) -> RawFd { - self.sock.read().unwrap().get_listener_raw_fd() + self.listener.as_raw_fd() } - pub fn get_stream_raw_fd(&self) -> RawFd { - self.sock.read().unwrap().get_stream_raw_fd() + pub fn get_stream_raw_fd(&self) -> Option { + self.stream.as_ref().map(|s| s.as_raw_fd()) } - pub fn set_nonblocking(&self, nb: bool) -> Result<()> { - self.sock.read().unwrap().set_nonblocking(nb) - } - - pub fn set_listener_nonblocking(&self, nb: bool) -> Result<()> { - self.sock.read().unwrap().listen_set_nonblocking(nb) - } - - pub fn accept(&self) -> Result<()> { - self.sock.write().unwrap().accept() + pub fn accept(&mut self) -> Result<()> { + self.stream = Some(self.listener.accept()?); + Ok(()) } - pub fn send(&self, data: *const u8, len: usize) -> Result { - let mut iovs = Vec::with_capacity(1); - iovs.push(iovec { - iov_base: data as *mut c_void, - iov_len: len, - }); - let ret = self.sock.read().unwrap().send_msg(&mut iovs, &[])?; - Ok(ret) + pub fn disconnect(&mut self) { + self.stream = None; } +} - pub fn send_by_obj(&self, obj: &T) -> Result<()> { - let slice = obj.as_bytes(); - let mut left = slice.len(); - let mut count = 0_usize; +pub fn recv_slice(stream: &mut dyn Read, data: &mut [u8]) -> Result { + let len = data.len(); + let mut ret = 0_usize; - while left > 0 { - let buf = &slice[count..]; - match self.send(buf.as_ptr(), left) { - Ok(n) => { - left -= n; - count += n; - } - Err(e) => { - if std::io::Error::last_os_error().raw_os_error().unwrap() == libc::EAGAIN { - continue; - } - return Err(e); + while ret < len { + match stream.read(&mut data[ret..len]) { + Ok(0) => break, + Ok(n) => ret += n, + Err(e) => { + let ek = e.kind(); + if ek != ErrorKind::WouldBlock && ek != ErrorKind::Interrupted { + bail!("recv_slice: error occurred: {:?}", e); } + break; } } - Ok(()) } + Ok(ret) +} - pub fn recv_slice(&self, data: &mut [u8]) -> Result { - let len = data.len(); - if len == 0 { - return Ok(0); - } - let ret = self.recv(data.as_mut_ptr(), len); - match ret { - Ok(n) => Ok(n), +pub fn send_obj(stream: &mut dyn Write, obj: &T) -> Result<()> { + let slice = obj.as_bytes(); + let mut left = slice.len(); + let mut count = 0_usize; + + while left > 0 { + match stream.write(&slice[count..]) { + Ok(n) => { + left -= n; + count += n; + } Err(e) => { - if std::io::Error::last_os_error() - .raw_os_error() - .unwrap_or(libc::EIO) - != libc::EAGAIN - { - error!("recv_slice(): error occurred: {}", e); + let ek = e.kind(); + if ek == ErrorKind::WouldBlock || ek == ErrorKind::Interrupted { + continue; } - Ok(0) + bail!(e); } } } - - pub fn recv(&self, data: *mut u8, len: usize) -> Result { - let mut iovs = Vec::with_capacity(1); - iovs.push(iovec { - iov_base: data as *mut c_void, - iov_len: len, - }); - - let ret = self.sock.read().unwrap().recv_msg(&mut iovs, &mut []); - match ret { - Ok((n, _)) => Ok(n), - Err(e) => Err(e.into()), - } - } + Ok(()) } diff --git a/ui/src/ohui_srv/mod.rs b/ui/src/ohui_srv/mod.rs index b5689b5759cc70e4e23375acc48a96ed4122182c..7fc8b631330a3950a626eeb503ae71b67b1a8edb 100755 --- a/ui/src/ohui_srv/mod.rs +++ b/ui/src/ohui_srv/mod.rs @@ -14,7 +14,6 @@ pub mod channel; pub mod msg; pub mod msg_handle; -use std::mem::size_of; use std::os::unix::io::RawFd; use std::path::Path; use std::ptr; @@ -24,7 +23,7 @@ use std::sync::{ Arc, Mutex, RwLock, }; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, bail, Context, Result}; use log::{error, info}; use once_cell::sync::OnceCell; use vmm_sys_util::epoll::EventSet; @@ -35,7 +34,6 @@ use crate::{ DisplayChangeListenerOperations, DisplayMouse, DisplaySurface, DISPLAY_UPDATE_INTERVAL_DEFAULT, }, - input::{register_led_sync, unregister_led_sync}, pixman::{bytes_per_pixel, get_image_data, ref_pixman_image, unref_pixman_image}, }; use address_space::FileBackend; @@ -52,7 +50,7 @@ use util::{ NotifierOperation, }, pixman::{pixman_format_code_t, pixman_image_t}, - unix::do_mmap, + unix::{do_mmap, limit_permission}, }; #[derive(Debug, Clone)] @@ -93,9 +91,9 @@ pub struct OhUiServer { // guest surface for framebuffer surface: RwLock, // transfer channel via unix sock - channel: Arc, + channel: Arc>, // message handler - msg_handler: Arc, + msg_handler: OhUiMsgHandler, // connected or not connected: AtomicBool, // iothread processing unix socket @@ -111,13 +109,13 @@ pub struct OhUiServer { } impl OhUiServer { - fn init_channel(path: &String) -> Result> { + fn init_channel(path: &String) -> Result>> { let file_path = Path::new(path.as_str()).join("ohui.sock"); let sock_file = file_path .to_str() .ok_or_else(|| anyhow!("init_channel: Failed to get str from {}", path))?; TempCleaner::add_path(sock_file.to_string()); - Ok(Arc::new(OhUiChannel::new(sock_file))) + Ok(Arc::new(Mutex::new(OhUiChannel::new(sock_file)?))) } fn init_fb_file(path: &String) -> Result<(Option, u64)> { @@ -127,6 +125,12 @@ impl OhUiServer { .ok_or_else(|| anyhow!("init_fb_file: Failed to get str from {}", path))?; let fb_backend = FileBackend::new_mem(fb_file, VIRTIO_GPU_ENABLE_BAR0_SIZE)?; TempCleaner::add_path(fb_file.to_string()); + limit_permission(fb_file).unwrap_or_else(|e| { + error!( + "Failed to limit permission for ohui-fb {}, err: {:?}", + fb_file, e + ); + }); let host_addr = do_mmap( &Some(fb_backend.file.as_ref()), @@ -147,6 +151,12 @@ impl OhUiServer { .ok_or_else(|| anyhow!("init_cursor_file: Failed to get str from {}", path))?; let cursor_backend = FileBackend::new_mem(cursor_file, CURSOR_SIZE)?; TempCleaner::add_path(cursor_file.to_string()); + limit_permission(cursor_file).unwrap_or_else(|e| { + error!( + "Failed to limit permission for ohui-cursor {}, err: {:?}", + cursor_file, e + ); + }); let cursorbuffer = do_mmap( &Some(cursor_backend.file.as_ref()), @@ -160,16 +170,16 @@ impl OhUiServer { Ok(cursorbuffer) } - pub fn new(path: String) -> Result { - let channel = Self::init_channel(&path)?; - let (fb_file, framebuffer) = Self::init_fb_file(&path)?; - let cursorbuffer = Self::init_cursor_file(&path)?; + pub fn new(ui_path: String, sock_path: String) -> Result { + let channel = Self::init_channel(&sock_path)?; + let (fb_file, framebuffer) = Self::init_fb_file(&ui_path)?; + let cursorbuffer = Self::init_cursor_file(&ui_path)?; Ok(OhUiServer { passthru: OnceCell::new(), surface: RwLock::new(GuestSurface::new()), - channel: channel.clone(), - msg_handler: Arc::new(OhUiMsgHandler::new(channel)), + channel, + msg_handler: OhUiMsgHandler::new(), connected: AtomicBool::new(false), iothread: OnceCell::new(), cursorbuffer, @@ -186,8 +196,8 @@ impl OhUiServer { } #[inline(always)] - fn get_channel(&self) -> &OhUiChannel { - self.channel.as_ref() + fn get_channel(&self) -> Arc> { + self.channel.clone() } #[inline(always)] @@ -202,17 +212,22 @@ impl OhUiServer { self.msg_handler.handle_msg(self.token_id.clone()) } - fn raw_update_dirty_area( + // check dirty area data before call it. + unsafe fn raw_update_dirty_area( &self, surface_data: *mut u32, stride: i32, pos: (i32, i32), size: (i32, i32), + force_copy: bool, ) { let (x, y) = pos; let (w, h) = size; - if self.framebuffer == 0 || *self.passthru.get_or_init(|| false) { + if self.framebuffer == 0 + || surface_data.is_null() + || (!force_copy && *self.passthru.get_or_init(|| false)) + { return; } @@ -254,9 +269,10 @@ impl OhUiServer { fn set_connect(&self, conn: bool) { self.connected.store(conn, Ordering::Relaxed); if conn { - register_led_sync(self.msg_handler.clone()); + self.msg_handler.update_sock(self.channel.clone()); } else { - unregister_led_sync(); + self.channel.lock().unwrap().disconnect(); + self.msg_handler.reset(); } } @@ -269,6 +285,15 @@ impl OhUiServer { impl DisplayChangeListenerOperations for OhUiServer { fn dpy_switch(&self, surface: &DisplaySurface) -> Result<()> { + let height = surface.height() as u64; + let stride = surface.stride() as u64; + if self.framebuffer != 0 && height * stride > VIRTIO_GPU_ENABLE_BAR0_SIZE { + bail!( + "surface size is larger than ohui buffer size {}", + VIRTIO_GPU_ENABLE_BAR0_SIZE + ); + } + let mut locked_surface = self.surface.write().unwrap(); unref_pixman_image(locked_surface.guest_image); @@ -280,12 +305,16 @@ impl DisplayChangeListenerOperations for OhUiServer { locked_surface.height = surface.height(); drop(locked_surface); let locked_surface = self.surface.read().unwrap(); - self.raw_update_dirty_area( - get_image_data(locked_surface.guest_image), - locked_surface.stride, - (0, 0), - (locked_surface.width, locked_surface.height), - ); + // SAFETY: Dirty area does not exceed surface buffer. + unsafe { + self.raw_update_dirty_area( + get_image_data(locked_surface.guest_image), + locked_surface.stride, + (0, 0), + (locked_surface.width, locked_surface.height), + true, + ) + }; if !self.connected() { return Ok(()); @@ -311,12 +340,24 @@ impl DisplayChangeListenerOperations for OhUiServer { return Ok(()); } - self.raw_update_dirty_area( - get_image_data(locked_surface.guest_image), - locked_surface.stride, - (x, y), - (w, h), - ); + if locked_surface.width < x + || locked_surface.height < y + || locked_surface.width < x.saturating_add(w) + || locked_surface.height < y.saturating_add(h) + { + bail!("dpy_image_update: invalid dirty area"); + } + + // SAFETY: We checked dirty area data before. + unsafe { + self.raw_update_dirty_area( + get_image_data(locked_surface.guest_image), + locked_surface.stride, + (x, y), + (w, h), + false, + ) + }; self.msg_handler .handle_dirty_area(x as u32, y as u32, w as u32, h as u32); @@ -330,14 +371,19 @@ impl DisplayChangeListenerOperations for OhUiServer { return Ok(()); } - let len = cursor.width * cursor.height * size_of::() as u32; - if len > CURSOR_SIZE as u32 { + let len = cursor + .width + .checked_mul(cursor.height) + .with_context(|| "Invalid cursor width * height")? + .checked_mul(bytes_per_pixel() as u32) + .with_context(|| "Invalid cursor size")?; + if len > CURSOR_SIZE as u32 || len > cursor.data.len().try_into()? { error!("Too large cursor length {}.", len); // No need to return Err for this situation is not fatal return Ok(()); } - // SAFETY: len is checked before copying,it's safe to do this. + // SAFETY: len is checked before copying, it's safe to do this. unsafe { ptr::copy_nonoverlapping( cursor.data.as_ptr(), @@ -351,7 +397,7 @@ impl DisplayChangeListenerOperations for OhUiServer { cursor.height, cursor.hot_x, cursor.hot_y, - size_of::() as u32, + bytes_per_pixel() as u32, ); Ok(()) } @@ -359,7 +405,7 @@ impl DisplayChangeListenerOperations for OhUiServer { pub fn ohui_init(ohui_srv: Arc, cfg: &DisplayConfig) -> Result<()> { // set iothread - ohui_srv.set_iothread(cfg.ohui_config.iothread.clone()); + ohui_srv.set_iothread(cfg.iothread.clone()); // Register ohui interface let dcl = Arc::new(Mutex::new(DisplayChangeListener::new( None, @@ -392,7 +438,12 @@ impl OhUiTrans { } fn get_fd(&self) -> RawFd { - self.server.get_channel().get_stream_raw_fd() + self.server + .get_channel() + .lock() + .unwrap() + .get_stream_raw_fd() + .unwrap() } } @@ -437,8 +488,6 @@ impl OhUiListener { } fn handle_connection(&self) -> Result<()> { - // Set stream sock with nonblocking - self.server.get_channel().set_nonblocking(true)?; // Register OhUiTrans read notifier ohui_register_event(OhUiTrans::new(self.server.clone()), self.server.clone())?; self.server.set_connect(true); @@ -448,11 +497,15 @@ impl OhUiListener { } fn accept(&self) -> Result<()> { - self.server.get_channel().accept() + self.server.get_channel().lock().unwrap().accept() } fn get_fd(&self) -> RawFd { - self.server.get_channel().get_listener_raw_fd() + self.server + .get_channel() + .lock() + .unwrap() + .get_listener_raw_fd() } } @@ -499,11 +552,7 @@ fn ohui_register_event(e: T, srv: Arc) -> Re } fn ohui_start_listener(server: Arc) -> Result<()> { - // Bind and set listener nonblocking - let channel = server.get_channel(); - channel.bind()?; - channel.set_listener_nonblocking(true)?; - ohui_register_event(OhUiListener::new(server.clone()), server.clone())?; + ohui_register_event(OhUiListener::new(server.clone()), server)?; info!("Successfully start listener."); Ok(()) } diff --git a/ui/src/ohui_srv/msg.rs b/ui/src/ohui_srv/msg.rs index c359d71f55b1593ee0ccbaa192011cb14075b933..1ce2d4465610e4509ee1edbc2f4c2b57249dabc7 100755 --- a/ui/src/ohui_srv/msg.rs +++ b/ui/src/ohui_srv/msg.rs @@ -110,6 +110,8 @@ impl ByteCode for MouseMotionEvent {} pub struct KeyboardEvent { pub key_action: u16, pub keycode: u16, + pub led_state: u8, + pad: [u8; 3], } impl ByteCode for KeyboardEvent {} @@ -132,12 +134,6 @@ pub struct LedstateEvent { impl ByteCode for LedstateEvent {} -impl LedstateEvent { - pub fn new(state: u32) -> Self { - LedstateEvent { state } - } -} - #[repr(C, packed)] #[derive(Debug, Default, Copy, Clone)] pub struct GreetEvent { diff --git a/ui/src/ohui_srv/msg_handle.rs b/ui/src/ohui_srv/msg_handle.rs index 9f063f1fe5144c1bd25dc4726c93c555abcd4cc6..47d2939f39edd4aa0af5629cf8ba6184aef44998 100755 --- a/ui/src/ohui_srv/msg_handle.rs +++ b/ui/src/ohui_srv/msg_handle.rs @@ -11,18 +11,23 @@ // See the Mulan PSL v2 for more details. use std::collections::HashMap; +use std::os::fd::{FromRawFd, RawFd}; +use std::os::unix::net::UnixStream; use std::sync::{Arc, Mutex, RwLock}; -use anyhow::{anyhow, bail, Result}; -use log::{error, warn}; +use anyhow::{anyhow, bail, Context, Result}; +use log::error; use util::byte_code::ByteCode; -use super::{channel::OhUiChannel, msg::*}; +use super::{ + channel::{recv_slice, send_obj, OhUiChannel}, + msg::*, +}; use crate::{ console::{get_active_console, graphic_hardware_ui_info}, input::{ self, get_kbd_led_state, input_button, input_move_abs, input_point_sync, keyboard_update, - release_all_key, trigger_key, Axis, SyncLedstate, ABS_MAX, CAPS_LOCK_LED, + release_all_btn, release_all_key, trigger_key, Axis, ABS_MAX, CAPS_LOCK_LED, INPUT_BUTTON_WHEEL_DOWN, INPUT_BUTTON_WHEEL_LEFT, INPUT_BUTTON_WHEEL_RIGHT, INPUT_BUTTON_WHEEL_UP, INPUT_POINT_BACK, INPUT_POINT_FORWARD, INPUT_POINT_LEFT, INPUT_POINT_MIDDLE, INPUT_POINT_RIGHT, KEYCODE_CAPS_LOCK, KEYCODE_NUM_LOCK, @@ -49,11 +54,20 @@ fn trans_mouse_pos(x: f64, y: f64, w: f64, h: f64) -> (u32, u32) { ) } +#[derive(Clone, Default)] +struct CursorState { + w: u32, + h: u32, + hot_x: u32, + hot_y: u32, + size_per_pixel: u32, +} + #[derive(Default)] struct WindowState { width: u32, height: u32, - led_state: Option, + cursor: CursorState, } impl WindowState { @@ -86,25 +100,16 @@ impl WindowState { } fn move_pointer(&mut self, x: f64, y: f64) -> Result<()> { - let (pos_x, pos_y) = trans_mouse_pos(x, y, self.width as f64, self.height as f64); + let (pos_x, pos_y) = trans_mouse_pos(x, y, f64::from(self.width), f64::from(self.height)); input_move_abs(Axis::X, pos_x)?; input_move_abs(Axis::Y, pos_y)?; input_point_sync() } - fn update_host_ledstate(&mut self, led: u8) { - self.led_state = Some(led); - } - - fn sync_kbd_led_state(&mut self) -> Result<()> { - if self.led_state.is_none() { - return Ok(()); - } - - let host_stat = self.led_state.unwrap(); + fn sync_kbd_led_state(&mut self, led: u8) -> Result<()> { let guest_stat = get_kbd_led_state(); - if host_stat != guest_stat { - let sync_bits = host_stat ^ guest_stat; + if led != guest_stat { + let sync_bits = led ^ guest_stat; if (sync_bits & CAPS_LOCK_LED) != 0 { trigger_key(KEYCODE_CAPS_LOCK)?; } @@ -115,61 +120,47 @@ impl WindowState { trigger_key(KEYCODE_SCR_LOCK)?; } } - self.led_state = None; Ok(()) } } +#[derive(Default)] pub struct OhUiMsgHandler { state: Mutex, hmcode2svcode: HashMap, - reader: Mutex, - writer: Mutex, -} - -impl SyncLedstate for OhUiMsgHandler { - fn sync_to_host(&self, state: u8) { - let body = LedstateEvent::new(state as u32); - if let Err(e) = self - .writer - .lock() - .unwrap() - .send_message(EventType::Ledstate, &body) - { - error!("sync_to_host: failed to send message with error {e}"); - } - } + reader: Mutex>, + writer: Mutex>, } impl OhUiMsgHandler { - pub fn new(channel: Arc) -> Self { + pub fn new() -> Self { OhUiMsgHandler { state: Mutex::new(WindowState::default()), hmcode2svcode: KeyCode::keysym_to_qkeycode(DpyMod::Ohui), - reader: Mutex::new(MsgReader::new(channel.clone())), - writer: Mutex::new(MsgWriter::new(channel)), + reader: Mutex::new(None), + writer: Mutex::new(None), } } + pub fn update_sock(&self, channel: Arc>) { + let fd = channel.lock().unwrap().get_stream_raw_fd().unwrap(); + *self.reader.lock().unwrap() = Some(MsgReader::new(fd)); + *self.writer.lock().unwrap() = Some(MsgWriter::new(fd)); + } + pub fn handle_msg(&self, token_id: Arc>) -> Result<()> { - let mut reader = self.reader.lock().unwrap(); + let mut locked_reader = self.reader.lock().unwrap(); + let reader = locked_reader + .as_mut() + .with_context(|| "handle_msg: no connection established")?; if !reader.recv()? { return Ok(()); } let hdr = &reader.header; - let body_size = hdr.size as usize; let event_type = hdr.event_type; - if body_size != event_msg_data_len(hdr.event_type) { - warn!( - "{:?} data len is wrong, we want {}, but receive {}", - event_type, - event_msg_data_len(hdr.event_type), - body_size - ); - reader.clear(); - return Ok(()); - } + let body_size = hdr.size as usize; + trace::trace_scope_start!(handle_msg, args = (&event_type)); let body_bytes = reader.body.as_ref().unwrap(); if let Err(e) = match event_type { @@ -196,20 +187,27 @@ impl OhUiMsgHandler { } EventType::Focus => { let body = FocusEvent::from_bytes(&body_bytes[..]).unwrap(); + trace::oh_event_focus(body.state); if body.state == CLIENT_FOCUSOUT_EVENT { reader.clear(); release_all_key()?; + release_all_btn()?; } Ok(()) } - EventType::Ledstate => { - let body = LedstateEvent::from_bytes(&body_bytes[..]).unwrap(); - self.handle_ledstate(body); - Ok(()) - } + EventType::Ledstate => Ok(()), EventType::Greet => { let body = GreetEvent::from_bytes(&body_bytes[..]).unwrap(); + trace::oh_event_greet(body.token_id); *token_id.write().unwrap() = body.token_id; + let cursor = self.state.lock().unwrap().cursor.clone(); + self.handle_cursor_define( + cursor.w, + cursor.h, + cursor.hot_x, + cursor.hot_y, + cursor.size_per_pixel, + ); Ok(()) } _ => { @@ -217,6 +215,7 @@ impl OhUiMsgHandler { "unsupported type {:?} and body size {}", event_type, body_size ); + trace::oh_event_unsupported_type(&event_type, body_size.try_into().unwrap()); Ok(()) } } { @@ -228,6 +227,7 @@ impl OhUiMsgHandler { fn handle_mouse_button(&self, mb: &MouseButtonEvent) -> Result<()> { let (msg_btn, action) = (mb.button, mb.btn_action); + trace::oh_event_mouse_button(msg_btn, action); let btn = match msg_btn { CLIENT_MOUSE_BUTTON_LEFT => INPUT_POINT_LEFT, CLIENT_MOUSE_BUTTON_RIGHT => INPUT_POINT_RIGHT, @@ -251,26 +251,33 @@ impl OhUiMsgHandler { hot_y: u32, size_per_pixel: u32, ) { - let body = HWCursorEvent::new(w, h, hot_x, hot_y, size_per_pixel); - if let Err(e) = self - .writer - .lock() - .unwrap() - .send_message(EventType::CursorDefine, &body) - { - error!("handle_cursor_define: failed to send message with error {e}"); + self.state.lock().unwrap().cursor = CursorState { + w, + h, + hot_x, + hot_y, + size_per_pixel, + }; + + if let Some(writer) = self.writer.lock().unwrap().as_mut() { + let body = HWCursorEvent::new(w, h, hot_x, hot_y, size_per_pixel); + if let Err(e) = writer.send_message(EventType::CursorDefine, &body) { + error!("handle_cursor_define: failed to send message with error {e}"); + } } } // NOTE: we only support absolute position info now, that means usb-mouse does not work. fn handle_mouse_motion(&self, mm: &MouseMotionEvent) -> Result<()> { + trace::oh_event_mouse_motion(mm.x, mm.y); self.state.lock().unwrap().move_pointer(mm.x, mm.y) } fn handle_keyboard(&self, ke: &KeyboardEvent) -> Result<()> { - if self.state.lock().unwrap().led_state.is_some() { - self.state.lock().unwrap().sync_kbd_led_state()?; - } + self.state + .lock() + .unwrap() + .sync_kbd_led_state(ke.led_state)?; let hmkey = ke.keycode; let keycode = match self.hmcode2svcode.get(&hmkey) { Some(k) => *k, @@ -278,6 +285,7 @@ impl OhUiMsgHandler { bail!("not supported keycode {}", hmkey); } }; + trace::oh_event_keyboard(keycode, ke.key_action); self.state .lock() .unwrap() @@ -295,6 +303,7 @@ impl OhUiMsgHandler { }; self.state.lock().unwrap().press_btn(dir)?; self.state.lock().unwrap().release_btn(dir)?; + trace::oh_event_scroll(dir); Ok(()) } @@ -310,85 +319,95 @@ impl OhUiMsgHandler { error!("handle_windowinfo failed with error {e}"); } } - } - - fn handle_ledstate(&self, led: &LedstateEvent) { - self.state - .lock() - .unwrap() - .update_host_ledstate(led.state as u8); + trace::oh_event_windowinfo(wi.width, wi.height); } pub fn send_windowinfo(&self, w: u32, h: u32) { self.state.lock().unwrap().update_window_info(w, h); - let body = WindowInfoEvent::new(w, h); - if let Err(e) = self - .writer - .lock() - .unwrap() - .send_message(EventType::WindowInfo, &body) - { - error!("send_windowinfo: failed to send message with error {e}"); + if let Some(writer) = self.writer.lock().unwrap().as_mut() { + let body = WindowInfoEvent::new(w, h); + if let Err(e) = writer.send_message(EventType::WindowInfo, &body) { + error!("send_windowinfo: failed to send message with error {e}"); + } } } pub fn handle_dirty_area(&self, x: u32, y: u32, w: u32, h: u32) { - let body = FrameBufferDirtyEvent::new(x, y, w, h); - if let Err(e) = self - .writer - .lock() - .unwrap() - .send_message(EventType::FrameBufferDirty, &body) - { - error!("handle_dirty_area: failed to send message with error {e}"); + if let Some(writer) = self.writer.lock().unwrap().as_mut() { + let body = FrameBufferDirtyEvent::new(x, y, w, h); + if let Err(e) = writer.send_message(EventType::FrameBufferDirty, &body) { + error!("handle_dirty_area: failed to send message with error {e}"); + } } } + + pub fn reset(&self) { + *self.reader.lock().unwrap() = None; + *self.writer.lock().unwrap() = None; + } } struct MsgReader { - /// socket to read - channel: Arc, /// cache for header - pub header: EventMsgHdr, + header: EventMsgHdr, /// received byte size of header - pub header_ready: usize, + header_ready: usize, /// cache of body - pub body: Option>, + body: Option>, /// received byte size of body - pub body_ready: usize, + body_ready: usize, + /// UnixStream to read + sock: UnixStream, } impl MsgReader { - pub fn new(channel: Arc) -> Self { + pub fn new(fd: RawFd) -> Self { MsgReader { - channel, header: EventMsgHdr::default(), header_ready: 0, body: None, body_ready: 0, + // SAFETY: The fd is valid only when the new connection has been established + // and MsgReader instance would be destroyed when disconnected. + sock: unsafe { UnixStream::from_raw_fd(fd) }, } } pub fn recv(&mut self) -> Result { if self.recv_header()? { + self.check_header()?; return self.recv_body(); } Ok(false) } - pub fn clear(&mut self) { + fn clear(&mut self) { self.header_ready = 0; self.body_ready = 0; self.body = None; } + fn check_header(&mut self) -> Result<()> { + let expected_size = event_msg_data_len(self.header.event_type); + if expected_size != self.header.size as usize { + self.clear(); + bail!( + "{:?} data len is wrong, we want {}, but receive {}", + self.header.event_type as EventType, + expected_size, + self.header.size as usize, + ); + } + Ok(()) + } + fn recv_header(&mut self) -> Result { if self.header_ready == EVENT_MSG_HDR_SIZE as usize { return Ok(true); } let buf = self.header.as_mut_bytes(); - self.header_ready += self.channel.recv_slice(&mut buf[self.header_ready..])?; + self.header_ready += recv_slice(&mut self.sock, &mut buf[self.header_ready..])?; Ok(self.header_ready == EVENT_MSG_HDR_SIZE as usize) } @@ -410,22 +429,32 @@ impl MsgReader { unsafe { buf.set_len(body_size); } - self.body_ready += self.channel.recv_slice(&mut buf[self.body_ready..])?; + self.body_ready += recv_slice(&mut self.sock, &mut buf[self.body_ready..])?; Ok(self.body_ready == body_size) } } -struct MsgWriter(Arc); +struct MsgWriter { + sock: UnixStream, +} impl MsgWriter { - fn new(channel: Arc) -> Self { - MsgWriter(channel) + fn new(fd: RawFd) -> Self { + Self { + // SAFETY: The fd is valid only when the new connection has been established + // and MsgWriter instance would be destroyed when disconnected. + sock: unsafe { UnixStream::from_raw_fd(fd) }, + } } - fn send_message(&self, t: EventType, body: &T) -> Result<()> { + fn send_message( + &mut self, + t: EventType, + body: &T, + ) -> Result<()> { let hdr = EventMsgHdr::new(t); - self.0.send_by_obj(&hdr)?; - self.0.send_by_obj(body) + send_obj(&mut self.sock, &hdr)?; + send_obj(&mut self.sock, body) } } diff --git a/ui/src/pixman.rs b/ui/src/pixman.rs index 8c5a2008761d4d49a0577e53129b51c0bb8b0d10..2163210659e497e4f07030548a0c09a0df7ff268 100644 --- a/ui/src/pixman.rs +++ b/ui/src/pixman.rs @@ -36,7 +36,7 @@ pub struct ColorInfo { impl ColorInfo { pub fn set_color_info(&mut self, shift: u8, max: u16) { - self.mask = (max as u32) << (shift as u32); + self.mask = u32::from(max) << u32::from(shift); self.shift = shift; self.max = if max == 0 { 0xFF } else { max as u8 }; self.bits = max.popcnt() as u8; @@ -84,10 +84,14 @@ impl PixelFormat { self.green.max = ((1 << self.green.bits) - 1) as u8; self.blue.max = ((1 << self.blue.bits) - 1) as u8; - self.alpha_chl.mask = self.alpha_chl.max.wrapping_shl(self.alpha_chl.shift as u32) as u32; - self.red.mask = self.red.max.wrapping_shl(self.red.shift as u32) as u32; - self.green.mask = self.green.max.wrapping_shl(self.green.shift as u32) as u32; - self.blue.mask = self.blue.max.wrapping_shl(self.blue.shift as u32) as u32; + self.alpha_chl.mask = u32::from( + self.alpha_chl + .max + .wrapping_shl(u32::from(self.alpha_chl.shift)), + ); + self.red.mask = u32::from(self.red.max.wrapping_shl(u32::from(self.red.shift))); + self.green.mask = u32::from(self.green.max.wrapping_shl(u32::from(self.green.shift))); + self.blue.mask = u32::from(self.blue.max.wrapping_shl(u32::from(self.blue.shift))); } pub fn is_default_pixel_format(&self) -> bool { diff --git a/ui/src/utils.rs b/ui/src/utils.rs index 7bb7844d31a7cf98bd6a90628373e0ca6591a674..c61ae8552d74edbe50b90f5de815653c945d5222 100644 --- a/ui/src/utils.rs +++ b/ui/src/utils.rs @@ -182,12 +182,12 @@ mod tests { fn test_buffpool_base() { let mut buffpool = BuffPool::new(); buffpool.set_limit(Some(7)); - buffpool.append_limit((0x12345678 as u32).to_be_bytes().to_vec()); - buffpool.append_limit((0x12 as u8).to_be_bytes().to_vec()); - buffpool.append_limit((0x1234 as u16).to_be_bytes().to_vec()); - assert!(buffpool.len() == 7 as usize); + buffpool.append_limit(0x12345678_u32.to_be_bytes().to_vec()); + buffpool.append_limit(0x12_u8.to_be_bytes().to_vec()); + buffpool.append_limit(0x1234_u16.to_be_bytes().to_vec()); + assert!(buffpool.len() == 7_usize); buffpool.remove_front(1); - assert!(buffpool.len() == 6 as usize); + assert!(buffpool.len() == 6_usize); let mut buf: Vec = vec![0_u8; 4]; buffpool.read_front(&mut buf, 4); assert!(buf == vec![52, 86, 120, 18]); diff --git a/ui/src/vnc/client_io.rs b/ui/src/vnc/client_io.rs index 347e7d58022f5bcc367a15984f1cd0f9d526cf83..7ffa8996833d8a66ef0bc3d6a399519df943ab91 100644 --- a/ui/src/vnc/client_io.rs +++ b/ui/src/vnc/client_io.rs @@ -47,8 +47,8 @@ use crate::{ use util::{ bitmap::Bitmap, loop_context::{ - gen_delete_notifiers, read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, - NotifierOperation, + create_new_eventfd, gen_delete_notifiers, read_fd, EventNotifier, EventNotifierHelper, + NotifierCallback, NotifierOperation, }, }; @@ -392,15 +392,15 @@ impl ClientState { pub fn new(addr: String) -> Self { ClientState { addr, - disconn_evt: Arc::new(Mutex::new(EventFd::new(libc::EFD_NONBLOCK).unwrap())), - write_fd: Arc::new(Mutex::new(EventFd::new(libc::EFD_NONBLOCK).unwrap())), + disconn_evt: Arc::new(Mutex::new(create_new_eventfd().unwrap())), + write_fd: Arc::new(Mutex::new(create_new_eventfd().unwrap())), in_buffer: Arc::new(Mutex::new(BuffPool::new())), out_buffer: Arc::new(Mutex::new(BuffPool::new())), client_dpm: Arc::new(Mutex::new(DisplayMode::default())), conn_state: Arc::new(Mutex::new(ConnState::default())), dirty_bitmap: Arc::new(Mutex::new(Bitmap::::new( MAX_WINDOW_HEIGHT as usize - * round_up_div(DIRTY_WIDTH_BITS as u64, u64::BITS as u64) as usize, + * round_up_div(u64::from(DIRTY_WIDTH_BITS), u64::from(u64::BITS)) as usize, ))), } } @@ -728,9 +728,9 @@ impl ClientIoHandler { let pf = self.client.client_dpm.lock().unwrap().pf.clone(); for i in 0..NUM_OF_COLORMAP { - let r = ((i >> pf.red.shift) & pf.red.max as u16) << (16 - pf.red.bits); - let g = ((i >> pf.green.shift) & pf.green.max as u16) << (16 - pf.green.bits); - let b = ((i >> pf.blue.shift) & pf.blue.max as u16) << (16 - pf.blue.bits); + let r = ((i >> pf.red.shift) & u16::from(pf.red.max)) << (16 - pf.red.bits); + let g = ((i >> pf.green.shift) & u16::from(pf.green.max)) << (16 - pf.green.bits); + let b = ((i >> pf.blue.shift) & u16::from(pf.blue.max)) << (16 - pf.blue.bits); buf.append(&mut r.to_be_bytes().to_vec()); buf.append(&mut g.to_be_bytes().to_vec()); buf.append(&mut b.to_be_bytes().to_vec()); @@ -922,10 +922,10 @@ impl ClientIoHandler { } } else { locked_state.update_state = UpdateState::Force; - let x = u16::from_be_bytes([buf[2], buf[3]]) as i32; - let y = u16::from_be_bytes([buf[4], buf[5]]) as i32; - let w = u16::from_be_bytes([buf[6], buf[7]]) as i32; - let h = u16::from_be_bytes([buf[8], buf[9]]) as i32; + let x = i32::from(u16::from_be_bytes([buf[2], buf[3]])); + let y = i32::from(u16::from_be_bytes([buf[4], buf[5]])); + let w = i32::from(u16::from_be_bytes([buf[6], buf[7]])); + let h = i32::from(u16::from_be_bytes([buf[8], buf[9]])); set_area_dirty( &mut client.dirty_bitmap.lock().unwrap(), x, @@ -992,8 +992,8 @@ impl ClientIoHandler { } let buf = self.read_incoming_msg(); - let mut x = ((buf[2] as u16) << 8) + buf[3] as u16; - let mut y = ((buf[4] as u16) << 8) + buf[5] as u16; + let mut x = (u16::from(buf[2]) << 8) + u16::from(buf[3]); + let mut y = (u16::from(buf[4]) << 8) + u16::from(buf[5]); trace::vnc_client_point_event(&buf[1], &x, &y); // Window size alignment. @@ -1001,8 +1001,8 @@ impl ClientIoHandler { let width = get_image_width(locked_surface.server_image); let height = get_image_height(locked_surface.server_image); drop(locked_surface); - x = ((x as u64 * ABS_MAX) / width as u64) as u16; - y = ((y as u64 * ABS_MAX) / height as u64) as u16; + x = ((u64::from(x) * ABS_MAX) / width as u64) as u16; + y = ((u64::from(y) * ABS_MAX) / height as u64) as u16; // ASCII -> HidCode. let new_button = buf[1]; @@ -1023,15 +1023,15 @@ impl ClientIoHandler { VNC_INPUT_BUTTON_WHEEL_RIGHT => INPUT_BUTTON_WHEEL_RIGHT, VNC_INPUT_BUTTON_WHEEL_LEFT => INPUT_BUTTON_WHEEL_LEFT, VNC_INPUT_BUTTON_BACK => INPUT_POINT_BACK, - _ => button_mask as u32, + _ => u32::from(button_mask), }; input_button(button, new_button & button_mask != 0)?; } self.client.client_dpm.lock().unwrap().last_button = new_button; } - input_move_abs(Axis::X, x as u32)?; - input_move_abs(Axis::Y, y as u32)?; + input_move_abs(Axis::X, u32::from(x))?; + input_move_abs(Axis::Y, u32::from(y))?; input_point_sync()?; self.update_event_handler(1, ClientIoHandler::handle_protocol_msg); @@ -1061,7 +1061,7 @@ impl ClientIoHandler { fn auth_failed(&mut self, msg: &str) { let auth_rej: u8 = 1; let mut buf: Vec = vec![1u8]; - buf.append(&mut (auth_rej as u32).to_be_bytes().to_vec()); + buf.append(&mut u32::from(auth_rej).to_be_bytes().to_vec()); // If the RFB protocol version is above 3.8, an error reason will be returned. if self.client.conn_state.lock().unwrap().version.minor >= 8 { let err_msg = msg; @@ -1250,17 +1250,17 @@ pub fn get_rects(client: &Arc, server: &Arc, dirty_num: } h = i - y; - x2 = cmp::min(x2, width / DIRTY_PIXELS_NUM as u64); + x2 = cmp::min(x2, width / u64::from(DIRTY_PIXELS_NUM)); if x2 > x { rects.push(Rectangle::new( - (x * DIRTY_PIXELS_NUM as u64) as i32, + (x * u64::from(DIRTY_PIXELS_NUM)) as i32, y as i32, - ((x2 - x) * DIRTY_PIXELS_NUM as u64) as i32, + ((x2 - x) * u64::from(DIRTY_PIXELS_NUM)) as i32, h as i32, )); } - if x == 0 && x2 == width / DIRTY_PIXELS_NUM as u64 { + if x == 0 && x2 == width / u64::from(DIRTY_PIXELS_NUM) { y += h; if y == height { break; @@ -1289,9 +1289,9 @@ fn pixel_format_message(client: &Arc, buf: &mut Vec) { buf.append(&mut locked_dpm.pf.depth.to_be_bytes().to_vec()); // Depth. buf.append(&mut big_endian.to_be_bytes().to_vec()); // Big-endian flag. buf.append(&mut (1_u8).to_be_bytes().to_vec()); // True-color flag. - buf.append(&mut (locked_dpm.pf.red.max as u16).to_be_bytes().to_vec()); // Red max. - buf.append(&mut (locked_dpm.pf.green.max as u16).to_be_bytes().to_vec()); // Green max. - buf.append(&mut (locked_dpm.pf.blue.max as u16).to_be_bytes().to_vec()); // Blue max. + buf.append(&mut u16::from(locked_dpm.pf.red.max).to_be_bytes().to_vec()); // Red max. + buf.append(&mut u16::from(locked_dpm.pf.green.max).to_be_bytes().to_vec()); // Green max. + buf.append(&mut u16::from(locked_dpm.pf.blue.max).to_be_bytes().to_vec()); // Blue max. buf.append(&mut locked_dpm.pf.red.shift.to_be_bytes().to_vec()); // Red shift. buf.append(&mut locked_dpm.pf.green.shift.to_be_bytes().to_vec()); // Green shift. buf.append(&mut locked_dpm.pf.blue.shift.to_be_bytes().to_vec()); // Blue shift. @@ -1416,7 +1416,7 @@ pub fn display_cursor_define( buf, ); let dpm = client.client_dpm.lock().unwrap().clone(); - let data_size = cursor.width * cursor.height * dpm.pf.pixel_bytes as u32; + let data_size = cursor.width * cursor.height * u32::from(dpm.pf.pixel_bytes); let data_ptr = cursor.data.as_ptr() as *mut u8; write_pixel(data_ptr, data_size as usize, &dpm, buf); buf.append(&mut mask); @@ -1442,7 +1442,7 @@ pub fn vnc_update_output_throttle(client: &Arc) { let width = locked_dpm.client_width; let height = locked_dpm.client_height; let bytes_per_pixel = locked_dpm.pf.pixel_bytes; - let mut offset = width * height * (bytes_per_pixel as i32) * OUTPUT_THROTTLE_SCALE; + let mut offset = width * height * i32::from(bytes_per_pixel) * OUTPUT_THROTTLE_SCALE; drop(locked_dpm); offset = cmp::max(offset, MIN_OUTPUT_LIMIT); diff --git a/ui/src/vnc/encoding/enc_hextile.rs b/ui/src/vnc/encoding/enc_hextile.rs index 0bada3d581158be48737c493faf9a9f8630955bb..f41a4d8c1c273870d97edcce0b63d134dd5bf062 100644 --- a/ui/src/vnc/encoding/enc_hextile.rs +++ b/ui/src/vnc/encoding/enc_hextile.rs @@ -127,7 +127,8 @@ fn compress_each_tile<'a>( &mut tmp_buf, ); // If the length becomes longer after compression, give up compression. - if tmp_buf.len() > (sub_rect.h * sub_rect.w * client_dpm.pf.pixel_bytes as i32) as usize + if tmp_buf.len() + > (sub_rect.h * sub_rect.w * i32::from(client_dpm.pf.pixel_bytes)) as usize { flag = RAW; *last_bg = None; @@ -396,8 +397,8 @@ mod tests { let image = create_pixman_image( pixman_format_code_t::PIXMAN_x8r8g8b8, - image_width as i32, - image_height as i32, + image_width, + image_height, image_data.as_ptr() as *mut u32, image_stride, ); @@ -427,8 +428,8 @@ mod tests { let image = create_pixman_image( pixman_format_code_t::PIXMAN_x8r8g8b8, - image_width as i32, - image_height as i32, + image_width, + image_height, image_data.as_ptr() as *mut u32, image_stride, ); @@ -458,8 +459,8 @@ mod tests { let image = create_pixman_image( pixman_format_code_t::PIXMAN_x8r8g8b8, - image_width as i32, - image_height as i32, + image_width, + image_height, image_data.as_ptr() as *mut u32, image_stride, ); diff --git a/ui/src/vnc/mod.rs b/ui/src/vnc/mod.rs index 9f0a7e93d5b704788329fe6be38c28a7bd6f1671..a760ae3b331efd8fb69a28c9a0efb0f1af549071 100644 --- a/ui/src/vnc/mod.rs +++ b/ui/src/vnc/mod.rs @@ -234,10 +234,10 @@ impl DisplayChangeListenerOperations for VncInterface { return Ok(()); } let server = VNC_SERVERS.lock().unwrap()[0].clone(); - let width = cursor.width as u64; - let height = cursor.height as u64; + let width = u64::from(cursor.width); + let height = u64::from(cursor.height); trace::vnc_dpy_cursor_update(&width, &height); - let bpl = round_up_div(width, BIT_PER_BYTE as u64); + let bpl = round_up_div(width, u64::from(BIT_PER_BYTE)); // Set the bit for mask. let bit_mask: u8 = 0x80; @@ -254,7 +254,7 @@ impl DisplayChangeListenerOperations for VncInterface { let idx = ((i + j * width) as usize) * bytes_per_pixel() + first_bit; if let Some(n) = cursor.data.get(idx) { if *n == 0xff { - mask[(j * bpl + i / BIT_PER_BYTE as u64) as usize] |= bit; + mask[(j * bpl + i / u64::from(BIT_PER_BYTE)) as usize] |= bit; } } bit >>= 1; @@ -290,7 +290,7 @@ pub fn vnc_init(vnc: &Option, object: &ObjectConfig) -> Result<()> { None => return Ok(()), }; - let addr = format!("{}:{}", vnc_cfg.ip, vnc_cfg.port); + let addr = format!("{}:{}", vnc_cfg.addr.0, vnc_cfg.addr.1); let listener: TcpListener = match TcpListener::bind(addr.as_str()) { Ok(l) => l, Err(e) => { @@ -426,16 +426,16 @@ pub fn set_area_dirty( let width: i32 = vnc_width(g_w); let height: i32 = vnc_height(g_h); - w += x % DIRTY_PIXELS_NUM as i32; - x -= x % DIRTY_PIXELS_NUM as i32; + w += x % i32::from(DIRTY_PIXELS_NUM); + x -= x % i32::from(DIRTY_PIXELS_NUM); x = cmp::min(x, width); y = cmp::min(y, height); w = cmp::min(x + w, width) - x; h = cmp::min(y + h, height); while y < h { - let pos = (y * VNC_BITMAP_WIDTH as i32 + x / DIRTY_PIXELS_NUM as i32) as usize; - let len = round_up_div(w as u64, DIRTY_PIXELS_NUM as u64) as usize; + let pos = (y * VNC_BITMAP_WIDTH as i32 + x / i32::from(DIRTY_PIXELS_NUM)) as usize; + let len = round_up_div(w as u64, u64::from(DIRTY_PIXELS_NUM)) as usize; dirty.set_range(pos, len)?; y += 1; } @@ -445,14 +445,14 @@ pub fn set_area_dirty( /// Get the width of image. fn vnc_width(width: i32) -> i32 { cmp::min( - MAX_WINDOW_WIDTH as i32, - round_up(width as u64, DIRTY_PIXELS_NUM as u64) as i32, + i32::from(MAX_WINDOW_WIDTH), + round_up(width as u64, u64::from(DIRTY_PIXELS_NUM)) as i32, ) } /// Get the height of image. fn vnc_height(height: i32) -> i32 { - cmp::min(MAX_WINDOW_HEIGHT as i32, height) + cmp::min(i32::from(MAX_WINDOW_HEIGHT), height) } /// Update server image diff --git a/ui/src/vnc/server_io.rs b/ui/src/vnc/server_io.rs index 76af9e8d434b1292e8e83ecc75c9cdd37d559e4a..2b29035593a4b765fe01e5309e9473df26bfbe8f 100644 --- a/ui/src/vnc/server_io.rs +++ b/ui/src/vnc/server_io.rs @@ -173,7 +173,7 @@ struct ImageInfo { impl ImageInfo { fn new(image: *mut pixman_image_t) -> Self { let bpp = pixman_format_bpp(get_image_format(image) as u32); - let length = get_image_width(image) * round_up_div(bpp as u64, 8) as i32; + let length = get_image_width(image) * round_up_div(u64::from(bpp), 8) as i32; ImageInfo { data: get_image_data(image) as *mut u8, stride: get_image_stride(image), @@ -221,7 +221,7 @@ impl SecurityType { // Tls configuration. if let Some(tls_cred) = object.tls_object.get(&vnc_cfg.tls_creds) { let tlscred = TlsCreds { - cred_type: tls_cred.cred_type.clone(), + cred_type: "x509".to_string(), dir: tls_cred.dir.clone(), endpoint: tls_cred.endpoint.clone(), verifypeer: tls_cred.verifypeer, @@ -310,8 +310,8 @@ impl VncSurface { guest_dirty_bitmap: Bitmap::::new( MAX_WINDOW_HEIGHT as usize * round_up_div( - (MAX_WINDOW_WIDTH / DIRTY_PIXELS_NUM) as u64, - u64::BITS as u64, + u64::from(MAX_WINDOW_WIDTH / DIRTY_PIXELS_NUM), + u64::from(u64::BITS), ) as usize, ), server_image: ptr::null_mut(), @@ -424,7 +424,7 @@ impl VncSurface { let width = self.get_min_width(); let line_bytes = cmp::min(s_info.stride, g_info.length); - while x < round_up_div(width as u64, DIRTY_PIXELS_NUM as u64) as usize { + while x < round_up_div(width as u64, u64::from(DIRTY_PIXELS_NUM)) as usize { if !self .guest_dirty_bitmap .contain(x + y * VNC_BITMAP_WIDTH as usize) diff --git a/util/Cargo.toml b/util/Cargo.toml index 95aefcdac2967e6400493391c47ca1f51921d8c3..806fe935e14fe015275892b2b9982bbd9479cf5f 100644 --- a/util/Cargo.toml +++ b/util/Cargo.toml @@ -11,12 +11,12 @@ license = "Mulan PSL v2" arc-swap = "1.6.0" thiserror = "1.0" anyhow = "1.0" -kvm-bindings = { version = "0.6.0", features = ["fam-wrappers"] } +kvm-bindings = { version = "0.7.0", features = ["fam-wrappers"] } nix = { version = "0.26.2", default-features = false, features = ["poll", "term", "time", "signal", "fs", "feature"] } libc = "0.2" libloading = "0.7.4" log = { version = "0.4", features = ["std"]} -vmm-sys-util = "0.11.1" +vmm-sys-util = "0.12.1" byteorder = "1.4.3" once_cell = "1.18.0" io-uring = "0.6.0" @@ -28,5 +28,6 @@ trace = {path = "../trace"} default = [] usb_camera_v4l2 = ["dep:v4l2-sys-mit"] usb_camera_oh = [] +usb_host = [] scream_ohaudio = [] pixman = [] diff --git a/util/src/aio/mod.rs b/util/src/aio/mod.rs index d8d733da3b3a67989c7c757976ea62ca093fe937..7c3410fb8eab6d2b2007a4dd6844a6d4e6ea0080 100644 --- a/util/src/aio/mod.rs +++ b/util/src/aio/mod.rs @@ -18,6 +18,7 @@ mod uring; pub use raw::*; use std::clone::Clone; +use std::fmt::Display; use std::io::Write; use std::os::unix::io::RawFd; use std::sync::atomic::{AtomicI64, AtomicU32, AtomicU64, Ordering}; @@ -32,6 +33,7 @@ use uring::IoUringContext; use vmm_sys_util::eventfd::EventFd; use super::link_list::{List, Node}; +use crate::loop_context::create_new_eventfd; use crate::num_ops::{round_down, round_up}; use crate::thread_pool::ThreadPool; use crate::unix::host_page_size; @@ -66,7 +68,7 @@ pub enum AioEngine { } impl FromStr for AioEngine { - type Err = (); + type Err = anyhow::Error; fn from_str(s: &str) -> std::result::Result { match s { @@ -74,24 +76,25 @@ impl FromStr for AioEngine { AIO_NATIVE => Ok(AioEngine::Native), AIO_IOURING => Ok(AioEngine::IoUring), AIO_THREADS => Ok(AioEngine::Threads), - _ => Err(()), + _ => Err(anyhow!("Unknown aio type")), } } } -impl ToString for AioEngine { - fn to_string(&self) -> String { - match *self { - AioEngine::Off => "off".to_string(), - AioEngine::Native => "native".to_string(), - AioEngine::IoUring => "io_uring".to_string(), - AioEngine::Threads => "threads".to_string(), +impl Display for AioEngine { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AioEngine::Off => write!(f, "off"), + AioEngine::Native => write!(f, "native"), + AioEngine::IoUring => write!(f, "io_uring"), + AioEngine::Threads => write!(f, "threads"), } } } -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] pub enum WriteZeroesState { + #[default] Off, On, Unmap, @@ -130,7 +133,7 @@ impl Iovec { } pub fn get_iov_size(iovecs: &[Iovec]) -> u64 { - let mut sum = 0; + let mut sum: u64 = 0; for iov in iovecs { sum += iov.iov_len; } @@ -210,8 +213,10 @@ impl AioCb { pub fn rw_sync(&self) -> i32 { let mut ret = match self.opcode { - OpCode::Preadv => raw_readv(self.file_fd, &self.iovec, self.offset), - OpCode::Pwritev => raw_writev(self.file_fd, &self.iovec, self.offset), + // SAFETY: iovec of aiocb is valid. + OpCode::Preadv => unsafe { raw_readv(self.file_fd, &self.iovec, self.offset) }, + // SAFETY: iovec of aiocb is valid. + OpCode::Pwritev => unsafe { raw_writev(self.file_fd, &self.iovec, self.offset) }, _ => -1, }; if ret < 0 { @@ -263,9 +268,10 @@ impl AioCb { // If the buffer is full with zero and the operation is Pwritev, // It's equal to write zero operation. fn try_convert_to_write_zero(&mut self) { - if self.opcode == OpCode::Pwritev - && self.write_zeroes != WriteZeroesState::Off - && iovec_is_zero(&self.iovec) + if self.opcode == OpCode::Pwritev && + self.write_zeroes != WriteZeroesState::Off && + // SAFETY: iovec is generated by address_space. + unsafe { iovec_is_zero(&self.iovec) } { self.opcode = OpCode::WriteZeroes; if self.write_zeroes == WriteZeroesState::Unmap && self.discard { @@ -276,14 +282,14 @@ impl AioCb { pub fn is_misaligned(&self) -> bool { if self.direct && (self.opcode == OpCode::Preadv || self.opcode == OpCode::Pwritev) { - if (self.offset as u64) & (self.req_align as u64 - 1) != 0 { + if (self.offset as u64) & (u64::from(self.req_align) - 1) != 0 { return true; } for iov in self.iovec.iter() { - if iov.iov_base & (self.buf_align as u64 - 1) != 0 { + if iov.iov_base & (u64::from(self.buf_align) - 1) != 0 { return true; } - if iov.iov_len & (self.req_align as u64 - 1) != 0 { + if iov.iov_len & (u64::from(self.req_align) - 1) != 0 { return true; } } @@ -293,8 +299,8 @@ impl AioCb { pub fn handle_misaligned(&mut self) -> Result { let max_len = round_down( - self.nbytes + self.req_align as u64 * 2, - self.req_align as u64, + self.nbytes + u64::from(self.req_align) * 2, + u64::from(self.req_align), ) .with_context(|| "Failed to round down request length.")?; // Set upper limit of buffer length to avoid OOM. @@ -325,10 +331,10 @@ impl AioCb { bounce_buffer: *mut c_void, buffer_len: u64, ) -> Result<()> { - let offset_align = round_down(self.offset as u64, self.req_align as u64) + let offset_align = round_down(self.offset as u64, u64::from(self.req_align)) .with_context(|| "Failed to round down request offset.")?; let high = self.offset as u64 + self.nbytes; - let high_align = round_up(high, self.req_align as u64) + let high_align = round_up(high, u64::from(self.req_align)) .with_context(|| "Failed to round up request high edge.")?; match self.opcode { @@ -338,12 +344,15 @@ impl AioCb { loop { // Step1: Read file to bounce buffer. let nbytes = cmp::min(high_align - offset, buffer_len); - let len = raw_read( - self.file_fd, - bounce_buffer as u64, - nbytes as usize, - offset as usize, - ); + // SAFETY: bounce_buffer is valid and large enough. + let len = unsafe { + raw_read( + self.file_fd, + bounce_buffer as u64, + nbytes as usize, + offset as usize, + ) + }; if len < 0 { bail!("Failed to do raw read for misaligned read."); } @@ -367,7 +376,8 @@ impl AioCb { }; // Step2: Copy bounce buffer to iovec. - iov_from_buf_direct(iovecs, src).and_then(|v| { + // SAFETY: iovecs is generated by address_space. + unsafe { iov_from_buf_direct(iovecs, src) }.and_then(|v| { if v == real_nbytes as usize { Ok(()) } else { @@ -389,19 +399,22 @@ impl AioCb { // Load the head from file before fill iovec to buffer. let mut head_loaded = false; if self.offset as u64 > offset_align { - let len = raw_read( - self.file_fd, - bounce_buffer as u64, - self.req_align as usize, - offset_align as usize, - ); + // SAFETY: bounce_buffer is valid and large enough. + let len = unsafe { + raw_read( + self.file_fd, + bounce_buffer as u64, + self.req_align as usize, + offset_align as usize, + ) + }; if len < 0 || len as u32 != self.req_align { bail!("Failed to load head for misaligned write."); } head_loaded = true; } // Is head and tail in the same alignment section? - let same_section = (offset_align + self.req_align as u64) >= high; + let same_section = (offset_align + u64::from(self.req_align)) >= high; let need_tail = !(same_section && head_loaded) && (high_align > high); let mut offset = offset_align; @@ -415,12 +428,15 @@ impl AioCb { let real_nbytes = real_high - real_offset; if real_high == high && need_tail { - let len = raw_read( - self.file_fd, - bounce_buffer as u64 + nbytes - self.req_align as u64, - self.req_align as usize, - (offset + nbytes) as usize - self.req_align as usize, - ); + // SAFETY: bounce_buffer is valid and large enough. + let len = unsafe { + raw_read( + self.file_fd, + bounce_buffer as u64 + nbytes - u64::from(self.req_align), + self.req_align as usize, + (offset + nbytes) as usize - self.req_align as usize, + ) + }; if len < 0 || len as u32 != self.req_align { bail!("Failed to load tail for misaligned write."); } @@ -433,7 +449,8 @@ impl AioCb { real_nbytes as usize, ) }; - iov_to_buf_direct(iovecs, 0, dst).and_then(|v| { + // SAFETY: iovecs is generated by address_space. + unsafe { iov_to_buf_direct(iovecs, 0, dst) }.and_then(|v| { if v == real_nbytes as usize { Ok(()) } else { @@ -442,12 +459,15 @@ impl AioCb { })?; // Step2: Write bounce buffer to file. - let len = raw_write( - self.file_fd, - bounce_buffer as u64, - nbytes as usize, - offset as usize, - ); + // SAFETY: bounce_buffer is valid and large enough. + let len = unsafe { + raw_write( + self.file_fd, + bounce_buffer as u64, + nbytes as usize, + offset as usize, + ) + }; if len < 0 || len as u64 != nbytes { bail!("Failed to do raw write for misaligned write."); } @@ -503,7 +523,7 @@ impl Aio { thread_pool: Option>, ) -> Result { let max_events: usize = 128; - let fd = EventFd::new(libc::EFD_NONBLOCK)?; + let fd = create_new_eventfd()?; let ctx: Option>> = if let Some(pool) = thread_pool { let threads_aio_ctx = ThreadsAioContext::new(max_events as u32, &fd, pool); match engine { @@ -544,7 +564,14 @@ impl Aio { pub fn submit_request(&mut self, mut cb: AioCb) -> Result<()> { trace::aio_submit_request(cb.file_fd, &cb.opcode, cb.offset, cb.nbytes); - if self.ctx.is_none() { + if self.ctx.is_none() + || [ + OpCode::Discard, + OpCode::WriteZeroes, + OpCode::WriteZeroesUnmap, + ] + .contains(&cb.opcode) + { return self.handle_sync_request(cb); } @@ -572,7 +599,7 @@ impl Aio { -1 } }; - return (self.complete_func)(&cb, ret as i64); + return (self.complete_func)(&cb, i64::from(ret)); } cb.try_convert_to_write_zero(); @@ -584,7 +611,7 @@ impl Aio { OpCode::WriteZeroes | OpCode::WriteZeroesUnmap => cb.write_zeroes_sync(), OpCode::Noop => return Err(anyhow!("Aio opcode is not specified.")), }; - (self.complete_func)(&cb, ret as i64) + (self.complete_func)(&cb, i64::from(ret)) } pub fn flush_request(&mut self) -> Result<()> { @@ -610,8 +637,11 @@ impl Aio { evt.res } else { error!( - "Async IO request failed, status {} res {}", - evt.status, evt.res + "Async IO request failed, opcode {:?} status {} res {} expect {}", + (*node).value.opcode, + evt.status, + evt.res, + (*node).value.nbytes ); -1 }; @@ -633,10 +663,10 @@ impl Aio { warn!("Can not process aio list with invalid ctx."); return Ok(()); } - while self.aio_in_queue.len > 0 && self.aio_in_flight.len < self.max_events { + while !self.aio_in_queue.is_empty() && self.aio_in_flight.len() < self.max_events { let mut iocbs = Vec::new(); - for _ in self.aio_in_flight.len..self.max_events { + for _ in self.aio_in_flight.len()..self.max_events { match self.aio_in_queue.pop_tail() { Some(node) => { iocbs.push(&node.value as *const AioCb); @@ -699,7 +729,7 @@ impl Aio { self.aio_in_queue.add_head(node); self.incomplete_cnt.fetch_add(1, Ordering::SeqCst); - if self.aio_in_queue.len + self.aio_in_flight.len >= self.max_events { + if self.aio_in_queue.len() + self.aio_in_flight.len() >= self.max_events { self.process_list()?; } @@ -707,22 +737,28 @@ impl Aio { } } -pub fn mem_from_buf(buf: &[u8], hva: u64) -> Result<()> { - // SAFETY: all callers have valid hva address. - let mut slice = unsafe { std::slice::from_raw_parts_mut(hva as *mut u8, buf.len()) }; - (&mut slice) - .write(buf) +/// # Safety +/// +/// Caller should has valid hva address. +pub unsafe fn mem_from_buf(buf: &[u8], hva: u64) -> Result<()> { + let mut slice = std::slice::from_raw_parts_mut(hva as *mut u8, buf.len()); + slice + .write_all(buf) .with_context(|| format!("Failed to write buf to hva:{})", hva))?; Ok(()) } /// Write buf to iovec and return the written number of bytes. -pub fn iov_from_buf_direct(iovec: &[Iovec], buf: &[u8]) -> Result { +/// # Safety +/// +/// Caller should has valid iovec. +pub unsafe fn iov_from_buf_direct(iovec: &[Iovec], buf: &[u8]) -> Result { let mut start: usize = 0; let mut end: usize = 0; for iov in iovec.iter() { end = cmp::min(start + iov.iov_len as usize, buf.len()); + // iov len is not less than buf's. mem_from_buf(&buf[start..end], iov.iov_base)?; if end >= buf.len() { break; @@ -732,16 +768,21 @@ pub fn iov_from_buf_direct(iovec: &[Iovec], buf: &[u8]) -> Result { Ok(end) } -pub fn mem_to_buf(mut buf: &mut [u8], hva: u64) -> Result<()> { - // SAFETY: all callers have valid hva address. - let slice = unsafe { std::slice::from_raw_parts(hva as *const u8, buf.len()) }; - buf.write(slice) +/// # Safety +/// +/// Caller should has valid hva address. +pub unsafe fn mem_to_buf(mut buf: &mut [u8], hva: u64) -> Result<()> { + let slice = std::slice::from_raw_parts(hva as *const u8, buf.len()); + buf.write_all(slice) .with_context(|| format!("Failed to read buf from hva:{})", hva))?; Ok(()) } /// Read iovec to buf and return the read number of bytes. -pub fn iov_to_buf_direct(iovec: &[Iovec], offset: u64, buf: &mut [u8]) -> Result { +/// # Safety +/// +/// Caller should has valid iovec. +pub unsafe fn iov_to_buf_direct(iovec: &[Iovec], offset: u64, buf: &mut [u8]) -> Result { let mut iovec2: Option<&[Iovec]> = None; let mut start: usize = 0; let mut end: usize = 0; @@ -753,6 +794,7 @@ pub fn iov_to_buf_direct(iovec: &[Iovec], offset: u64, buf: &mut [u8]) -> Result for (index, iov) in iovec.iter().enumerate() { if iov.iov_len > offset { end = cmp::min((iov.iov_len - offset) as usize, buf.len()); + // iov len is not less than buf's. mem_to_buf(&mut buf[..end], iov.iov_base + offset)?; if end >= buf.len() || index >= (iovec.len() - 1) { return Ok(end); @@ -770,6 +812,7 @@ pub fn iov_to_buf_direct(iovec: &[Iovec], offset: u64, buf: &mut [u8]) -> Result for iov in iovec2.unwrap() { end = cmp::min(start + iov.iov_len as usize, buf.len()); + // iov len is not less than buf's. mem_to_buf(&mut buf[start..end], iov.iov_base)?; if end >= buf.len() { break; @@ -792,28 +835,32 @@ pub fn iov_discard_front_direct(iovec: &mut [Iovec], mut size: u64) -> Option<&m None } -fn iovec_is_zero(iovecs: &[Iovec]) -> bool { - let size = std::mem::size_of::() as u64; +/// # Safety +/// +/// Caller should have valid buffer base/len. +pub unsafe fn buffer_is_zero(base: u64, len: u64) -> bool { + let slice = std::slice::from_raw_parts(base as *const u8, len as usize); + let (prefix, aligned, suffix) = slice.align_to::(); + prefix.iter().all(|&x| x == 0) + && aligned.iter().all(|&x| x == 0) + && suffix.iter().all(|&x| x == 0) +} + +// Caller should have valid hva iovec. +unsafe fn iovec_is_zero(iovecs: &[Iovec]) -> bool { for iov in iovecs { - if iov.iov_len % size != 0 { - return false; - } // SAFETY: iov_base and iov_len has been checked in pop_avail(). - let slice = unsafe { - std::slice::from_raw_parts(iov.iov_base as *const u64, (iov.iov_len / size) as usize) - }; - for val in slice.iter() { - if *val != 0 { - return false; - } + if !buffer_is_zero(iov.iov_base, iov.iov_len) { + return false; } } true } pub fn iovecs_split(iovecs: Vec, mut size: u64) -> (Vec, Vec) { - let mut begin = Vec::new(); - let mut end = Vec::new(); + let len = iovecs.len(); + let mut begin: Vec = Vec::with_capacity(len); + let mut end: Vec = Vec::with_capacity(len); for iov in iovecs { if size == 0 { end.push(iov); @@ -831,12 +878,12 @@ pub fn iovecs_split(iovecs: Vec, mut size: u64) -> (Vec, Vec = vec![0; 6]; + let result1 = unsafe { buffer_is_zero(buf1.as_ptr() as u64, buf1.len() as u64) }; + assert_eq!(result1, true); + + let buf2: Vec = vec![0; 128]; + let result2 = unsafe { buffer_is_zero(buf2.as_ptr() as u64, buf2.len() as u64) }; + assert_eq!(result2, true); + + let buf3: Vec = vec![0; 513]; + let result3 = unsafe { buffer_is_zero(buf3.as_ptr() as u64, buf3.len() as u64) }; + assert_eq!(result3, true); + + let buf4: Vec = Vec::new(); + let result4 = unsafe { buffer_is_zero(buf4.as_ptr() as u64, buf4.len() as u64) }; + assert_eq!(result4, true); + + let buf5: Vec = vec![0, 1, 0]; + let result5 = unsafe { buffer_is_zero(buf5.as_ptr() as u64, buf5.len() as u64) }; + assert_eq!(result5, false); + + let buf6: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 0]; + let result6 = unsafe { buffer_is_zero(buf6.as_ptr() as u64, buf6.len() as u64) }; + assert_eq!(result6, false); + + let mut buf7: Vec = vec![0; 1025]; + buf7[700] = 1; + let result7 = unsafe { buffer_is_zero(buf7.as_ptr() as u64, buf7.len() as u64) }; + assert_eq!(result7, false); + } + + #[test] + fn test_iovec_is_zero() { + let buf1: Vec = vec![0; 5]; + let buf2: Vec = vec![0; 16]; + let iovecs1 = vec![ + Iovec::new(buf1.as_ptr() as u64, buf1.len() as u64), + Iovec::new(buf2.as_ptr() as u64, buf2.len() as u64), + ]; + + let result1 = unsafe { iovec_is_zero(&iovecs1) }; + assert_eq!(result1, true); + + let buf3: Vec = vec![0, 1, 0]; + let buf4: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 0]; + let iovecs2 = vec![ + Iovec::new(buf3.as_ptr() as u64, buf3.len() as u64), + Iovec::new(buf4.as_ptr() as u64, buf4.len() as u64), + ]; + + let result2 = unsafe { iovec_is_zero(&iovecs2) }; + assert_eq!(result2, false); + } } diff --git a/util/src/aio/raw.rs b/util/src/aio/raw.rs index 4654a7c07acebde6a98af8aa147c2742df04888e..159740009137e5a41409ae7dc9a928d237955d6a 100644 --- a/util/src/aio/raw.rs +++ b/util/src/aio/raw.rs @@ -18,18 +18,18 @@ use vmm_sys_util::fallocate::{fallocate, FallocateMode}; use super::Iovec; -pub fn raw_read(fd: RawFd, buf: u64, size: usize, offset: usize) -> i64 { +/// # Safety +/// +/// Caller should has valid buf. +pub unsafe fn raw_read(fd: RawFd, buf: u64, size: usize, offset: usize) -> i64 { let mut ret; loop { - // SAFETY: fd and buf is valid. - ret = unsafe { - pread( - fd as c_int, - buf as *mut c_void, - size as size_t, - offset as off_t, - ) as i64 - }; + ret = pread( + fd as c_int, + buf as *mut c_void, + size as size_t, + offset as off_t, + ) as i64; if !(ret < 0 && (nix::errno::errno() == libc::EINTR || nix::errno::errno() == libc::EAGAIN)) { break; @@ -48,18 +48,18 @@ pub fn raw_read(fd: RawFd, buf: u64, size: usize, offset: usize) -> i64 { ret } -pub fn raw_readv(fd: RawFd, iovec: &[Iovec], offset: usize) -> i64 { +/// # Safety +/// +/// Caller should has valid iovec. +pub unsafe fn raw_readv(fd: RawFd, iovec: &[Iovec], offset: usize) -> i64 { let mut ret; loop { - // SAFETY: fd and buf is valid. - ret = unsafe { - preadv( - fd as c_int, - iovec.as_ptr() as *const iovec, - iovec.len() as c_int, - offset as off_t, - ) as i64 - }; + ret = preadv( + fd as c_int, + iovec.as_ptr() as *const iovec, + iovec.len() as c_int, + offset as off_t, + ) as i64; if !(ret < 0 && (nix::errno::errno() == libc::EINTR || nix::errno::errno() == libc::EAGAIN)) { break; @@ -76,18 +76,18 @@ pub fn raw_readv(fd: RawFd, iovec: &[Iovec], offset: usize) -> i64 { ret } -pub fn raw_write(fd: RawFd, buf: u64, size: usize, offset: usize) -> i64 { +/// # Safety +/// +/// Caller should has valid buf. +pub unsafe fn raw_write(fd: RawFd, buf: u64, size: usize, offset: usize) -> i64 { let mut ret; loop { - // SAFETY: fd and buf is valid. - ret = unsafe { - pwrite( - fd as c_int, - buf as *mut c_void, - size as size_t, - offset as off_t, - ) as i64 - }; + ret = pwrite( + fd as c_int, + buf as *mut c_void, + size as size_t, + offset as off_t, + ) as i64; if !(ret < 0 && (nix::errno::errno() == libc::EINTR || nix::errno::errno() == libc::EAGAIN)) { break; @@ -106,18 +106,19 @@ pub fn raw_write(fd: RawFd, buf: u64, size: usize, offset: usize) -> i64 { ret } -pub fn raw_writev(fd: RawFd, iovec: &[Iovec], offset: usize) -> i64 { +/// # Safety +/// +/// Caller should has valid iovec. +pub unsafe fn raw_writev(fd: RawFd, iovec: &[Iovec], offset: usize) -> i64 { let mut ret; loop { - // SAFETY: fd and buf is valid. - ret = unsafe { - pwritev( - fd as c_int, - iovec.as_ptr() as *const iovec, - iovec.len() as c_int, - offset as off_t, - ) as i64 - }; + // Caller should has valid iovec. + ret = pwritev( + fd as c_int, + iovec.as_ptr() as *const iovec, + iovec.len() as c_int, + offset as off_t, + ) as i64; if !(ret < 0 && (nix::errno::errno() == libc::EINTR || nix::errno::errno() == libc::EAGAIN)) { break; @@ -171,7 +172,7 @@ fn do_fallocate( offset: u64, size: u64, ) -> i32 { - let mut ret = 0; + let mut ret: i32 = 0; loop { let mode = match &fallocate_mode { FallocateMode::PunchHole => FallocateMode::PunchHole, diff --git a/util/src/aio/threads.rs b/util/src/aio/threads.rs index cab4fe08b93772e9cbf0680d64aafc7be38156a2..1aecf948ea676c711ec6f50f4fc3a2e725e06dcf 100644 --- a/util/src/aio/threads.rs +++ b/util/src/aio/threads.rs @@ -63,7 +63,7 @@ impl ThreadsTasks { let aio_event = AioEvent { user_data: task.user_data, status: 0, - res: res as i64, + res: i64::from(res), }; self.complete_lists.lock().unwrap().push(aio_event); self.notify_event diff --git a/util/src/aio/uring.rs b/util/src/aio/uring.rs index 8f95832e92146d0105c392db5b587f7842d8ec58..f1d373a65f7d6fb793bb60acbf059b3d9a590ad4 100644 --- a/util/src/aio/uring.rs +++ b/util/src/aio/uring.rs @@ -13,7 +13,7 @@ use std::os::unix::io::AsRawFd; use anyhow::{bail, Context}; -use io_uring::{opcode, squeue, types, IoUring}; +use io_uring::{opcode, types, IoUring}; use libc; use vmm_sys_util::eventfd::EventFd; @@ -70,17 +70,12 @@ impl AioContext for IoUringContext { OpCode::Preadv => opcode::Readv::new(fd, iovs as *const libc::iovec, len as u32) .offset(offset) .build() - .flags(squeue::Flags::ASYNC) .user_data(data), OpCode::Pwritev => opcode::Writev::new(fd, iovs as *const libc::iovec, len as u32) .offset(offset) .build() - .flags(squeue::Flags::ASYNC) - .user_data(data), - OpCode::Fdsync => opcode::Fsync::new(fd) - .build() - .flags(squeue::Flags::ASYNC) .user_data(data), + OpCode::Fdsync => opcode::Fsync::new(fd).build().user_data(data), _ => { bail!("Invalid entry code"); } @@ -110,7 +105,7 @@ impl AioContext for IoUringContext { self.events.push(AioEvent { user_data: cqe.user_data(), status: 0, - res: cqe.result() as i64, + res: i64::from(cqe.result()), }); } &self.events diff --git a/util/src/arg_parser.rs b/util/src/arg_parser.rs index ad7d1fcb11118e32ec88310133c447be219ffe54..1d6ddf91018c72d1dfdcb6919f8939948360bdfb 100644 --- a/util/src/arg_parser.rs +++ b/util/src/arg_parser.rs @@ -608,6 +608,9 @@ impl<'a> ArgMatches<'a> { fn split_arg(args: &[String]) -> (&[String], &[String]) { if let Some(index) = args.iter().position(|arg| arg == ARG_SEPARATOR) { + if index == args.len() - 1 { + return (&args[..index], &[]); + } return (&args[..index], &args[index + 1..]); } (args, &[]) @@ -626,8 +629,8 @@ fn parse_cmdline( let mut arg_map: BTreeMap> = BTreeMap::new(); let mut multi_vec: Vec = Vec::new(); - let mut i = (0, ""); - let mut j = 1; + let mut i: (usize, &str) = (0, ""); + let mut j: usize = 1; for cmd_arg in &cmd_args[1..] { if !allow_list.contains(cmd_arg) && cmd_arg.starts_with(PREFIX_CHARS_SHORT) @@ -800,7 +803,7 @@ mod tests { arg_parser.output_help(&mut buffer.inner); let help_str = buffer.get_msg_vec(); - let help_msg = help_str.split("\n").collect::>(); + let help_msg = help_str.split('\n').collect::>(); assert_eq!(help_msg[0], "StratoVirt 1.0.0"); assert_eq!(help_msg[1], "Huawei Technologies Co., Ltd"); assert_eq!(help_msg[2], "A light kvm-based hypervisor."); @@ -832,10 +835,10 @@ mod tests { arg.possible_values.as_ref().unwrap(), &vec!["vm1", "vm2", "vm3"] ); - assert_eq!(arg.required, false); - assert_eq!(arg.presented, true); - assert_eq!(arg.hiddable, false); - assert_eq!(arg.can_no_value, false); + assert!(!arg.required); + assert!(arg.presented); + assert!(!arg.hiddable); + assert!(!arg.can_no_value); assert_eq!(arg.value.as_ref().unwrap(), "vm1"); let (help_msg, help_type) = arg.help_message(); diff --git a/util/src/bitmap.rs b/util/src/bitmap.rs index fed47e0fb96aedb6c708ba574f23a8596e96702f..3d8111c676451b57ff46aaaf95dfb7be8ff5e4f2 100644 --- a/util/src/bitmap.rs +++ b/util/src/bitmap.rs @@ -230,7 +230,7 @@ impl Bitmap { /// /// * `num` - the input number. pub fn contain(&self, num: usize) -> Result { - if num > self.vol() { + if num >= self.vol() { return Err(anyhow!(UtilError::OutOfBound( num as u64, self.size() as u64 * T::len() as u64, @@ -255,7 +255,7 @@ impl Bitmap { self.size() as u64 ))); } - let mut num = 0; + let mut num: usize = 0; for i in 0..self.bit_index(offset) + 1 { if i == self.bit_index(offset) { for j in i * T::len()..offset { @@ -434,12 +434,12 @@ mod tests { let mut bitmap = Bitmap::::new(1); assert!(bitmap.set(15).is_ok()); assert!(bitmap.set(16).is_err()); - assert_eq!(bitmap.contain(15).unwrap(), true); + assert!(bitmap.contain(15).unwrap()); assert_eq!(bitmap.count_front_bits(16).unwrap(), 1); assert_eq!(bitmap.count_front_bits(15).unwrap(), 0); assert!(bitmap.change(15).is_ok()); assert!(bitmap.change(16).is_err()); - assert_eq!(bitmap.contain(15).unwrap(), false); + assert!(!bitmap.contain(15).unwrap()); } #[test] @@ -451,25 +451,25 @@ mod tests { bitmap.clear_all(); assert!(bitmap.set_range(65, 10).is_ok()); - assert_eq!(bitmap.contain(64).unwrap(), false); - assert_eq!(bitmap.contain(65).unwrap(), true); - assert_eq!(bitmap.contain(70).unwrap(), true); - assert_eq!(bitmap.contain(74).unwrap(), true); - assert_eq!(bitmap.contain(75).unwrap(), false); + assert!(!bitmap.contain(64).unwrap()); + assert!(bitmap.contain(65).unwrap()); + assert!(bitmap.contain(70).unwrap()); + assert!(bitmap.contain(74).unwrap()); + assert!(!bitmap.contain(75).unwrap()); bitmap.clear_all(); assert!(bitmap.set_range(63, 1).is_ok()); - assert_eq!(bitmap.contain(62).unwrap(), false); - assert_eq!(bitmap.contain(63).unwrap(), true); - assert_eq!(bitmap.contain(64).unwrap(), false); + assert!(!bitmap.contain(62).unwrap()); + assert!(bitmap.contain(63).unwrap()); + assert!(!bitmap.contain(64).unwrap()); bitmap.clear_all(); assert!(bitmap.set_range(63, 66).is_ok()); - assert_eq!(bitmap.contain(62).unwrap(), false); - assert_eq!(bitmap.contain(63).unwrap(), true); - assert_eq!(bitmap.contain(67).unwrap(), true); - assert_eq!(bitmap.contain(128).unwrap(), true); - assert_eq!(bitmap.contain(129).unwrap(), false); + assert!(!bitmap.contain(62).unwrap()); + assert!(bitmap.contain(63).unwrap()); + assert!(bitmap.contain(67).unwrap()); + assert!(bitmap.contain(128).unwrap()); + assert!(!bitmap.contain(129).unwrap()); bitmap.clear_all(); } @@ -483,25 +483,25 @@ mod tests { assert!(bitmap.set_range(0, 256).is_ok()); assert!(bitmap.clear_range(65, 10).is_ok()); - assert_eq!(bitmap.contain(64).unwrap(), true); - assert_eq!(bitmap.contain(65).unwrap(), false); - assert_eq!(bitmap.contain(70).unwrap(), false); - assert_eq!(bitmap.contain(74).unwrap(), false); - assert_eq!(bitmap.contain(75).unwrap(), true); + assert!(bitmap.contain(64).unwrap()); + assert!(!bitmap.contain(65).unwrap()); + assert!(!bitmap.contain(70).unwrap()); + assert!(!bitmap.contain(74).unwrap()); + assert!(bitmap.contain(75).unwrap()); assert!(bitmap.set_range(0, 256).is_ok()); assert!(bitmap.clear_range(63, 1).is_ok()); - assert_eq!(bitmap.contain(62).unwrap(), true); - assert_eq!(bitmap.contain(63).unwrap(), false); - assert_eq!(bitmap.contain(64).unwrap(), true); + assert!(bitmap.contain(62).unwrap()); + assert!(!bitmap.contain(63).unwrap()); + assert!(bitmap.contain(64).unwrap()); assert!(bitmap.set_range(0, 256).is_ok()); assert!(bitmap.clear_range(63, 66).is_ok()); - assert_eq!(bitmap.contain(62).unwrap(), true); - assert_eq!(bitmap.contain(63).unwrap(), false); - assert_eq!(bitmap.contain(67).unwrap(), false); - assert_eq!(bitmap.contain(128).unwrap(), false); - assert_eq!(bitmap.contain(129).unwrap(), true); + assert!(bitmap.contain(62).unwrap()); + assert!(!bitmap.contain(63).unwrap()); + assert!(!bitmap.contain(67).unwrap()); + assert!(!bitmap.contain(128).unwrap()); + assert!(bitmap.contain(129).unwrap()); assert!(bitmap.clear_range(0, 256).is_ok()); } @@ -515,7 +515,7 @@ mod tests { assert!(bitmap.clear(64).is_ok()); assert!(bitmap.clear(128).is_ok()); - let mut offset = 0; + let mut offset = 0_usize; offset = bitmap.find_next_zero(offset).unwrap(); assert_eq!(offset, 0); offset = bitmap.find_next_zero(offset + 1).unwrap(); @@ -537,7 +537,7 @@ mod tests { assert!(bitmap.set(64).is_ok()); assert!(bitmap.set(128).is_ok()); - let mut offset = 0; + let mut offset = 0_usize; offset = bitmap.find_next_bit(offset).unwrap(); assert_eq!(offset, 0); offset = bitmap.find_next_bit(offset + 1).unwrap(); diff --git a/util/src/byte_code.rs b/util/src/byte_code.rs index 64fe6c9ac7be18bea3ec3d026fba201b68ebcc80..29d82c68ae508117dc68a1d7c45f63b01faaeb9d 100644 --- a/util/src/byte_code.rs +++ b/util/src/byte_code.rs @@ -15,7 +15,7 @@ use std::slice::{from_raw_parts, from_raw_parts_mut}; /// A trait bound defined for types which are safe to convert to a byte slice and /// to create from a byte slice. -pub trait ByteCode: Default + Copy + Send + Sync { +pub trait ByteCode: Clone + Default + Send + Sync { /// Return the contents of an object (impl trait `ByteCode`) as a slice of bytes. /// the inverse of this function is "from_bytes" fn as_bytes(&self) -> &[u8] { diff --git a/util/src/daemonize.rs b/util/src/daemonize.rs index d62440650027fa4e79e389a9e175de68eb21e098..f693abb4e612d457816e500a4b1db942ce574d32 100644 --- a/util/src/daemonize.rs +++ b/util/src/daemonize.rs @@ -50,6 +50,7 @@ fn create_pid_file(path: &str) -> Result<()> { let mut pid_file: File = OpenOptions::new() .write(true) .create(true) + .truncate(true) .mode(0o600) .open(path)?; write!(pid_file, "{}", pid)?; diff --git a/util/src/device_tree.rs b/util/src/device_tree.rs index ebb50d0e28729c0dc879c4d4111aa38d411e5d36..db590271de9f7a650a4e32c8a12e6cf3ff964322 100644 --- a/util/src/device_tree.rs +++ b/util/src/device_tree.rs @@ -10,7 +10,7 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, bail, Context, Result}; use byteorder::{BigEndian, ByteOrder}; use crate::UtilError; @@ -142,6 +142,9 @@ impl FdtBuilder { let off_dt_strings = FDT_HEADER_SIZE + self.mem_reserve.len() + self.structure_blk.len(); let off_mem_rsvmap = FDT_HEADER_SIZE; + if self.fdt_header.len() < FDT_HEADER_SIZE { + bail!("fdt header size too small"); + } BigEndian::write_u32(&mut self.fdt_header[0..4], FDT_MAGIC); BigEndian::write_u32(&mut self.fdt_header[4..8], total_size as u32); BigEndian::write_u32(&mut self.fdt_header[8..12], off_dt_struct as u32); diff --git a/util/src/edid.rs b/util/src/edid.rs index 17080d70b5225769bb410237f3a7c20fe88bc883..f124d6d68fd98043fb83d27642064324bcca9476 100644 --- a/util/src/edid.rs +++ b/util/src/edid.rs @@ -495,7 +495,7 @@ impl EdidInfo { fn fullfill_checksum(&mut self, edid_array: &mut [u8]) { let mut sum: u32 = 0; for elem in edid_array.iter() { - sum += *elem as u32; + sum += u32::from(*elem); } sum &= 0xff; if sum != 0 { diff --git a/util/src/evdev.rs b/util/src/evdev.rs new file mode 100644 index 0000000000000000000000000000000000000000..76573f0aca2a3c016f6ed9def45b58c52a5c5af7 --- /dev/null +++ b/util/src/evdev.rs @@ -0,0 +1,181 @@ +// Copyright (c) 2025 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::collections::BTreeMap; +use std::fs::File; + +use anyhow::{bail, Result}; +use libc::c_int; +use log::error; +use vmm_sys_util::{ioctl::ioctl_with_mut_ref, ioctl_ioc_nr, ioctl_ior_nr, ioctl_iow_nr}; + +use crate::byte_code::ByteCode; + +/// Event Code type, used for autorepeating devices. +pub const EV_REP: u8 = 0x14; +/// Max event type. +pub const EV_MAX: u8 = 0x1F; +/// Max ABS_* event type. +pub const ABS_MAX: u8 = 0x3F; + +/// Sync event type. +pub const EV_SYN: u16 = 0x00; +/// Synchronization event. +pub const SYN_REPORT: u16 = 0x00; + +/// The payload(union) size of the virtio_input_config. +pub const VIRTIO_INPUT_CFG_PAYLOAD_SIZE: usize = 128; + +#[derive(Copy, Clone)] +pub struct EvdevBuf { + pub buf: [u8; VIRTIO_INPUT_CFG_PAYLOAD_SIZE], + pub len: usize, +} + +impl EvdevBuf { + pub fn new() -> Self { + Self { + buf: [0_u8; VIRTIO_INPUT_CFG_PAYLOAD_SIZE], + len: 0, + } + } + + pub fn get_bit(&self, bit: usize) -> bool { + if (bit + 7) / 8 > self.len { + return false; + } + let idx = bit / 8; + let offset = bit % 8; + self.buf[idx] & (1u8 << offset) != 0 + } + + pub fn to_vec(self) -> Vec { + self.buf[0..self.len].to_vec() + } +} + +impl Default for EvdevBuf { + fn default() -> Self { + Self::new() + } +} + +impl ByteCode for EvdevBuf {} + +#[derive(Copy, Clone, Default)] +#[repr(C)] +pub struct EvdevId { + pub bustype: u16, + pub vendor: u16, + pub product: u16, + pub version: u16, +} + +impl EvdevId { + pub fn from_buf(buf: EvdevBuf) -> Self { + *Self::from_bytes(buf.to_vec().as_slice()).unwrap() + } +} + +impl ByteCode for EvdevId {} + +#[derive(Copy, Clone, Default)] +#[repr(C)] +pub struct InputAbsInfo { + pub value: u32, + pub minimum: u32, + pub maximum: u32, + pub fuzz: u32, + pub flat: u32, + pub resolution: u32, +} + +const EVDEV: u32 = 69; // 'E' +ioctl_ior_nr!(EVIOCGVERSION, EVDEV, 0x01, c_int); +ioctl_ior_nr!(EVIOCGID, EVDEV, 0x02, EvdevId); +ioctl_ior_nr!(EVIOCGNAME, EVDEV, 0x06, EvdevBuf); +ioctl_ior_nr!(EVIOCGUNIQ, EVDEV, 0x08, EvdevBuf); +ioctl_ior_nr!(EVIOCGPROP, EVDEV, 0x09, EvdevBuf); +ioctl_ior_nr!(EVIOCGBIT, EVDEV, 0x20 + evt, EvdevBuf, evt); +ioctl_ior_nr!(EVIOCGABS, EVDEV, 0x40 + abs, InputAbsInfo, abs); +ioctl_iow_nr!(EVIOCGRAB, EVDEV, 0x90, c_int); + +pub fn evdev_ioctl(fd: &File, req: u64, len: usize) -> EvdevBuf { + let mut evbuf = EvdevBuf::new(); + // SAFETY: file is `evdev` fd, and we check the return. + let ret = unsafe { ioctl_with_mut_ref(fd, req, &mut evbuf.buf) }; + if ret < 0 { + error!( + "Ioctl {} failed, error is {}.", + req, + std::io::Error::last_os_error() + ); + evbuf.len = 0; + return evbuf; + } + + evbuf.len = len; + if evbuf.len == 0 { + if ret != 0 { + evbuf.len = ret as usize; + } else { + evbuf.len = VIRTIO_INPUT_CFG_PAYLOAD_SIZE; + } + } + + evbuf +} + +pub fn evdev_evt_supported(fd: &File) -> Result> { + let mut evts: BTreeMap = BTreeMap::new(); + let evt_type = evdev_ioctl(fd, EVIOCGBIT(0), 0); + if evt_type.len == 0 { + bail!(format!( + "Failed to get bit 0, error {}", + std::io::Error::last_os_error() + )) + } + for ev in 1..EV_MAX { + if ev == EV_REP || !evt_type.get_bit(ev as usize) { + // Not supported event + continue; + } + evts.insert(ev, evdev_ioctl(fd, EVIOCGBIT(ev as u32), 0)); + } + + Ok(evts) +} + +pub fn evdev_abs(fd: &File) -> Result> { + let mut absinfo_db: BTreeMap = BTreeMap::new(); + for abs in 0..ABS_MAX { + let mut absinfo = InputAbsInfo::default(); + // SAFETY: file is `evdev` fd, and we check the return. + let len = unsafe { ioctl_with_mut_ref(fd, EVIOCGABS(abs as u32), &mut absinfo) }; + if len == 0 { + absinfo_db.insert(abs, absinfo); + } + } + + Ok(absinfo_db) +} + +#[repr(C)] +#[derive(Clone, Copy, Default)] +pub struct InputEvent { + pub timestamp: [u64; 2], + pub ev_type: u16, + pub code: u16, + pub value: i32, +} + +impl ByteCode for InputEvent {} diff --git a/util/src/file.rs b/util/src/file.rs index 0fd944508d041f7711620e4a4be6bd1347d1d13f..c1efa02b21d1fef5142a48364f720cd3694cc5f4 100644 --- a/util/src/file.rs +++ b/util/src/file.rs @@ -67,8 +67,8 @@ pub fn get_file_alignment(file: &File, direct: bool) -> (u32, u32) { return (1, 1); } - let mut req_align = 0; - let mut buf_align = 0; + let mut req_align: u32 = 0; + let mut buf_align: u32 = 0; // SAFETY: we allocate aligned memory and free it later. let aligned_buffer = unsafe { libc::memalign( @@ -76,6 +76,10 @@ pub fn get_file_alignment(file: &File, direct: bool) -> (u32, u32) { (MAX_FILE_ALIGN * 2) as libc::size_t, ) }; + if aligned_buffer.is_null() { + log::warn!("OOM occurs when get file alignment, assume max alignment"); + return (MAX_FILE_ALIGN, MAX_FILE_ALIGN); + } // Guess alignment requirement of request. let mut align = MIN_FILE_ALIGN; @@ -92,7 +96,7 @@ pub fn get_file_alignment(file: &File, direct: bool) -> (u32, u32) { while align <= MAX_FILE_ALIGN { if is_io_aligned( file, - aligned_buffer as u64 + align as u64, + aligned_buffer as u64 + u64::from(align), MAX_FILE_ALIGN as usize, ) { buf_align = align; diff --git a/util/src/leak_bucket.rs b/util/src/leak_bucket.rs index 5dd65f2f0fc7fa9e686ac68f7d60e69703b3c4a4..23a1d9109784a3ef1b9ef6bf478acfab966419b0 100644 --- a/util/src/leak_bucket.rs +++ b/util/src/leak_bucket.rs @@ -15,12 +15,12 @@ use std::os::unix::io::{AsRawFd, RawFd}; use std::sync::Arc; use std::time::{Duration, Instant}; -use anyhow::Result; +use anyhow::{Context, Result}; use log::error; use vmm_sys_util::eventfd::EventFd; use crate::clock::get_current_time; -use crate::loop_context::EventLoopContext; +use crate::loop_context::{create_new_eventfd, EventLoopContext}; use crate::time::NANOSECONDS_PER_SECOND; /// Used to improve the accuracy of bucket level. @@ -49,11 +49,13 @@ impl LeakBucket { /// * `units_ps` - units per second. pub fn new(units_ps: u64) -> Result { Ok(LeakBucket { - capacity: units_ps * ACCURACY_SCALE, + capacity: units_ps + .checked_mul(ACCURACY_SCALE) + .with_context(|| "capacity overflow")?, level: 0, prev_time: get_current_time(), timer_started: false, - timer_wakeup: Arc::new(EventFd::new(libc::EFD_NONBLOCK)?), + timer_wakeup: Arc::new(create_new_eventfd()?), }) } @@ -63,7 +65,7 @@ impl LeakBucket { /// # Arguments /// /// * `loop_context` - used for delay function call. - pub fn throttled(&mut self, loop_context: &mut EventLoopContext, need_units: u64) -> bool { + pub fn throttled(&mut self, loop_context: &mut EventLoopContext, need_units: u32) -> bool { // capacity value is zero, indicating that there is no need to limit if self.capacity == 0 { return false; @@ -75,10 +77,13 @@ impl LeakBucket { // update the water level let now = get_current_time(); let nanos = (now - self.prev_time).as_nanos(); - if nanos > (self.level * NANOSECONDS_PER_SECOND / self.capacity) as u128 { + let throttle_timeout = + u128::from(self.level) * u128::from(NANOSECONDS_PER_SECOND) / u128::from(self.capacity); + if nanos > throttle_timeout { self.level = 0; } else { - self.level -= nanos as u64 * self.capacity / NANOSECONDS_PER_SECOND; + self.level -= + (nanos * u128::from(self.capacity) / u128::from(NANOSECONDS_PER_SECOND)) as u64; } self.prev_time = now; @@ -92,19 +97,17 @@ impl LeakBucket { .unwrap_or_else(|e| error!("LeakBucket send event to device failed {:?}", e)); }); - loop_context.timer_add( - func, - Duration::from_nanos( - (self.level - self.capacity) * NANOSECONDS_PER_SECOND / self.capacity, - ), - ); + let timeout = + (self.level - self.capacity).saturating_mul(NANOSECONDS_PER_SECOND) / self.capacity; + loop_context.timer_add(func, Duration::from_nanos(timeout)); self.timer_started = true; return true; } - self.level += need_units * ACCURACY_SCALE; + let scaled_need = u64::from(need_units) * ACCURACY_SCALE; + self.level = self.level.saturating_add(scaled_need); false } diff --git a/util/src/lib.rs b/util/src/lib.rs index f861d6f6a1cca1d01970d52d8b25957f4015df1d..33d7769094e16f93aedd3b3efc868ce0ee295426 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -21,6 +21,7 @@ pub mod daemonize; pub mod device_tree; pub mod edid; pub mod error; +pub mod evdev; pub mod file; pub mod leak_bucket; pub mod link_list; @@ -34,7 +35,6 @@ pub mod ohos_binding; pub mod pixman; pub mod seccomp; pub mod socket; -pub mod syscall; pub mod tap; pub mod test_helper; pub mod thread_pool; @@ -95,6 +95,60 @@ pub fn set_termi_canon_mode() -> std::io::Result<()> { Ok(()) } +/// Macro: Generate base getting function. +/// +/// # Arguments +/// +/// * `get_func` - Name of getting `&base` function. +/// * `get_mut_func` - Name of getting `&mut base` function. +/// * `base_type` - Type of `base`. +/// * `base` - `base` in self. +/// +/// # Examples +/// +/// ```rust +/// use util::gen_base_func; +/// struct TestBase(u8); +/// struct Test { +/// base: TestBase, +/// } +/// +/// impl Test { +/// gen_base_func!(test_base, test_base_mut, TestBase, base); +/// } +/// ``` +/// +/// This is equivalent to: +/// +/// ```rust +/// struct TestBase(u8); +/// struct Test { +/// base: TestBase, +/// } +/// +/// impl Test { +/// fn test_base(&self) -> &TestBase { +/// &self.base +/// } +/// +/// fn test_base_mut(&mut self) -> &mut TestBase { +/// &mut self.base +/// } +/// } +/// ``` +#[macro_export] +macro_rules! gen_base_func { + ($get_func: ident, $get_mut_func: ident, $base_type: ty, $($base: tt).*) => { + fn $get_func(&self) -> &$base_type { + &self.$($base).* + } + + fn $get_mut_func(&mut self) -> &mut $base_type { + &mut self.$($base).* + } + }; +} + /// This trait is to cast trait object to struct. pub trait AsAny { fn as_any(&self) -> &dyn Any; diff --git a/util/src/link_list.rs b/util/src/link_list.rs index 264b5d6fd1b7afb00cf6a5554822accb9db4065d..a779bce17724c480e46d003c0ca66c28b4916b6e 100644 --- a/util/src/link_list.rs +++ b/util/src/link_list.rs @@ -23,7 +23,7 @@ pub struct Node { pub struct List { head: Option>>, tail: Option>>, - pub len: usize, + len: usize, marker: PhantomData>>, } @@ -139,4 +139,78 @@ impl List { node }) } + + #[inline(always)] + pub fn len(&self) -> usize { + self.len + } + + #[inline(always)] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + #[inline(always)] + pub fn iter(&'_ self) -> impl Iterator { + Iter::new(self) + } + + #[inline(always)] + pub fn iter_mut(&'_ mut self) -> impl Iterator { + IterMut::new(self) + } +} + +struct Iter<'a, T> { + curr: Option>>, + list: PhantomData<&'a List>, +} + +impl<'a, T> Iter<'a, T> { + fn new(list: &'a List) -> Self { + Self { + curr: list.head, + list: PhantomData, + } + } +} + +impl<'a, T> Iterator for Iter<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + self.curr.map(|node| { + // SAFETY: node is guaranteed not to be null. + let node = unsafe { node.as_ref() }; + self.curr = node.next; + &node.value + }) + } +} + +struct IterMut<'a, T> { + curr: Option>>, + list: PhantomData<&'a mut List>, +} + +impl<'a, T> IterMut<'a, T> { + fn new(list: &'a mut List) -> Self { + Self { + curr: list.head, + list: PhantomData, + } + } +} + +impl<'a, T> Iterator for IterMut<'a, T> { + type Item = &'a mut T; + + fn next(&mut self) -> Option { + self.curr.map(|mut node| { + // SAFETY: node is guaranteed not to be null. + let node = unsafe { node.as_mut() }; + self.curr = node.next; + &mut node.value + }) + } } diff --git a/util/src/logger.rs b/util/src/logger.rs index 889e788f91e7ed0c3caeb2ff0ead4d72d49fecf0..de35d83f6dcad5e8b49b9754279c7f3b37877dd8 100644 --- a/util/src/logger.rs +++ b/util/src/logger.rs @@ -34,7 +34,7 @@ fn format_now() -> String { println!("{:?}", e); (0, 0) }); - let format_time = get_format_time(sec as i64); + let format_time = get_format_time(sec); format!( "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:09}", @@ -63,7 +63,7 @@ impl FileRotate { self.current_size += Wrapping(size_inc); let sec = gettime()?.0; - let today = get_format_time(sec as i64)[2]; + let today = get_format_time(sec)[2]; if self.current_size < Wrapping(LOG_ROTATE_SIZE_MAX) && self.create_day == today { return Ok(()); } @@ -159,7 +159,7 @@ fn init_vm_logger( current_size = Wrapping(metadata.len() as usize); let mod_time = metadata.modified()?; let sec = mod_time.duration_since(UNIX_EPOCH)?.as_secs(); - create_day = get_format_time(sec as i64)[2]; + create_day = get_format_time(i64::try_from(sec)?)[2]; }; let rotate = Mutex::new(FileRotate { handler: logfile, @@ -193,7 +193,6 @@ fn init_logger_with_env(logfile: Box, logfile_path: String) -> fn open_log_file(path: &str) -> Result { std::fs::OpenOptions::new() .read(false) - .write(true) .append(true) .create(true) .mode(0o640) diff --git a/util/src/loop_context.rs b/util/src/loop_context.rs index 10e4cf7068856ba70b39754b24fee40589d8b9b4..d1a3f9bf4ff45a5e66cb1f1f30f85780c675ed00 100644 --- a/util/src/loop_context.rs +++ b/util/src/loop_context.rs @@ -11,16 +11,17 @@ // See the Mulan PSL v2 for more details. use std::collections::BTreeMap; -use std::fmt; use std::fmt::Debug; +use std::io::Error; use std::os::unix::io::{AsRawFd, RawFd}; use std::rc::Rc; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::{Arc, Mutex, RwLock}; +use std::sync::{Arc, Barrier, Mutex, RwLock}; use std::time::{Duration, Instant}; +use std::{fmt, i32}; use anyhow::{anyhow, Context, Result}; -use libc::{c_void, read, EFD_NONBLOCK}; +use libc::{c_void, read, EFD_CLOEXEC, EFD_NONBLOCK}; use log::{error, warn}; use nix::errno::Errno; use nix::{ @@ -54,6 +55,10 @@ pub enum NotifierOperation { Park = 16, /// Resume a file descriptor from the event table Resume = 32, + /// Add events to current event table for a file descriptor + AddEvents = 64, + /// Delete events from current event table for a file descriptor + DeleteEvents = 128, } #[derive(Debug, PartialEq)] @@ -154,6 +159,10 @@ pub fn gen_delete_notifiers(fds: &[RawFd]) -> Vec { notifiers } +pub fn create_new_eventfd() -> Result { + EventFd::new(EFD_NONBLOCK | EFD_CLOEXEC) +} + /// EventLoop manager, advise continue running or stop running pub trait EventLoopManager: Send + Sync { fn loop_should_exit(&self) -> bool; @@ -215,6 +224,8 @@ pub struct EventLoopContext { pub thread_pool: Arc, /// Record VM clock state. pub clock_state: Arc>, + /// The io thread barrier. + pub thread_exit_barrier: Arc, } impl Drop for EventLoopContext { @@ -231,11 +242,11 @@ unsafe impl Send for EventLoopContext {} impl EventLoopContext { /// Constructs a new `EventLoopContext`. - pub fn new() -> Self { + pub fn new(thread_exit_barrier: Arc) -> Self { let mut ctx = EventLoopContext { epoll: Epoll::new().unwrap(), manager: None, - kick_event: EventFd::new(EFD_NONBLOCK).unwrap(), + kick_event: create_new_eventfd().unwrap(), kick_me: AtomicBool::new(false), kicked: AtomicBool::new(false), events: Arc::new(RwLock::new(BTreeMap::new())), @@ -245,6 +256,7 @@ impl EventLoopContext { timer_next_id: AtomicU64::new(0), thread_pool: Arc::new(ThreadPool::default()), clock_state: Arc::new(Mutex::new(ClockState::default())), + thread_exit_barrier, }; ctx.init_kick(); ctx @@ -284,7 +296,7 @@ impl EventLoopContext { fn clear_gc(&mut self) { let max_cnt = self.gc.write().unwrap().len(); - let mut pop_cnt = 0; + let mut pop_cnt: usize = 0; loop { // Loop to avoid hold lock for long time. @@ -466,6 +478,35 @@ impl EventLoopContext { Ok(()) } + fn update_events_for_fd(&mut self, event: &EventNotifier, add: bool) -> Result<()> { + let mut events_map = self.events.write().unwrap(); + match events_map.get_mut(&event.raw_fd) { + Some(notifier) => { + let new_events = if add { + event.event | notifier.event + } else { + !event.event & notifier.event + }; + if new_events != notifier.event { + self.epoll + .ctl( + ControlOperation::Modify, + notifier.raw_fd, + EpollEvent::new(new_events, &**notifier as *const _ as u64), + ) + .with_context(|| { + format!("Failed to add events, event fd: {}", notifier.raw_fd) + })?; + notifier.event = new_events; + } + } + _ => { + return Err(anyhow!(UtilError::NoRegisterFd(event.raw_fd))); + } + } + Ok(()) + } + /// update fds registered to `EventLoop` according to the operation type. /// /// # Arguments @@ -490,6 +531,12 @@ impl EventLoopContext { NotifierOperation::Resume => { self.resume_event(&en)?; } + NotifierOperation::AddEvents => { + self.update_events_for_fd(&en, true)?; + } + NotifierOperation::DeleteEvents => { + self.update_events_for_fd(&en, false)?; + } } } self.kick(); @@ -506,17 +553,11 @@ impl EventLoopContext { } } - self.epoll_wait_manager(self.timers_min_duration()) + self.epoll_wait_manager(self.timers_min_duration())?; + Ok(true) } - pub fn iothread_run(&mut self) -> Result { - if let Some(manager) = &self.manager { - if manager.lock().unwrap().loop_should_exit() { - manager.lock().unwrap().loop_cleanup()?; - return Ok(false); - } - } - + pub fn iothread_run(&mut self) -> Result<()> { let min_timeout_ns = self.timers_min_duration(); if min_timeout_ns.is_none() { for _i in 0..AIO_PRFETCH_CYCLE_TIME { @@ -532,7 +573,8 @@ impl EventLoopContext { } } } - self.epoll_wait_manager(min_timeout_ns) + self.epoll_wait_manager(min_timeout_ns)?; + Ok(()) } /// Call the function given by `func` after `delay` time. @@ -593,7 +635,7 @@ impl EventLoopContext { /// Call function of the timers which have already expired. pub fn run_timers(&mut self) { let now = get_current_time(); - let mut expired_nr = 0; + let mut expired_nr: usize = 0; let mut timers = self.timers.lock().unwrap(); for timer in timers.iter() { @@ -611,7 +653,7 @@ impl EventLoopContext { } } - fn epoll_wait_manager(&mut self, mut time_out: Option) -> Result { + fn epoll_wait_manager(&mut self, mut time_out: Option) -> Result<()> { let need_kick = !(time_out.is_some() && *time_out.as_ref().unwrap() == Duration::ZERO); if need_kick { self.kick_me.store(true, Ordering::SeqCst); @@ -628,13 +670,13 @@ impl EventLoopContext { match ppoll(&mut pollfds, time_out_spec, None) { Ok(_) => time_out = Some(Duration::ZERO), - Err(e) if e == Errno::EINTR => time_out = Some(Duration::ZERO), + Err(Errno::EINTR) => time_out = Some(Duration::ZERO), Err(e) => return Err(anyhow!(UtilError::EpollWait(e.into()))), }; } let time_out_ms = match time_out { - Some(t) => t.as_millis() as i32, + Some(t) => i32::try_from(t.as_millis()).unwrap_or(i32::MAX), None => -1, }; let ev_count = match self.epoll.wait(time_out_ms, &mut self.ready_events[..]) { @@ -655,8 +697,7 @@ impl EventLoopContext { let mut notifiers = Vec::new(); let status_locked = event.status.lock().unwrap(); if *status_locked == EventStatus::Alive { - for j in 0..event.handlers.len() { - let handler = &event.handlers[j]; + for handler in event.handlers.iter() { match handler(self.ready_events[i].event_set(), event.raw_fd) { None => {} Some(mut notifier) => { @@ -673,30 +714,25 @@ impl EventLoopContext { self.run_timers(); self.clear_gc(); - Ok(true) + Ok(()) } -} -impl Default for EventLoopContext { - fn default() -> Self { - Self::new() + pub fn clean_event_loop(&mut self) -> Result<()> { + if let Some(manager) = &self.manager { + manager.lock().unwrap().loop_cleanup()?; + } + Ok(()) } } pub fn read_fd(fd: RawFd) -> u64 { let mut value: u64 = 0; - // SAFETY: this is called by notifier handler and notifier handler - // is executed with fd is is valid. The value is defined above thus - // valid too. - let ret = unsafe { - read( - fd, - &mut value as *mut u64 as *mut c_void, - std::mem::size_of::(), - ) - }; + let buf = &mut value as *mut u64 as *mut c_void; + let count = std::mem::size_of::(); + // SAFETY: The buf refers to local value and count equals to value size. + let ret = unsafe { read(fd, buf, count) }; if ret == -1 { warn!("Failed to read fd"); } @@ -707,6 +743,7 @@ pub fn read_fd(fd: RawFd) -> u64 { #[cfg(test)] mod test { use std::os::unix::io::{AsRawFd, RawFd}; + use std::sync::Barrier; use vmm_sys_util::{epoll::EventSet, eventfd::EventFd}; @@ -715,12 +752,9 @@ mod test { impl EventLoopContext { fn check_existence(&self, fd: RawFd) -> Option { let events_map = self.events.read().unwrap(); - match events_map.get(&fd) { - None => { - return None; - } - Some(notifier) => Some(*notifier.status.lock().unwrap() == EventStatus::Alive), - } + events_map + .get(&fd) + .map(|notifier| *notifier.status.lock().unwrap() == EventStatus::Alive) } fn create_event(&mut self) -> i32 { @@ -755,14 +789,13 @@ mod test { #[test] fn basic_test() { - let mut mainloop = EventLoopContext::new(); + let mut mainloop = EventLoopContext::new(Arc::new(Barrier::new(1))); let mut notifiers = Vec::new(); let fd1 = EventFd::new(EFD_NONBLOCK).unwrap(); let fd1_related = EventFd::new(EFD_NONBLOCK).unwrap(); let handler1 = generate_handler(fd1_related.as_raw_fd()); - let mut handlers = Vec::new(); - handlers.push(handler1); + let handlers = vec![handler1]; let event1 = EventNotifier::new( NotifierOperation::AddShared, fd1.as_raw_fd(), @@ -783,7 +816,7 @@ mod test { #[test] fn parked_event_test() { - let mut mainloop = EventLoopContext::new(); + let mut mainloop = EventLoopContext::new(Arc::new(Barrier::new(1))); let mut notifiers = Vec::new(); let fd1 = EventFd::new(EFD_NONBLOCK).unwrap(); let fd2 = EventFd::new(EFD_NONBLOCK).unwrap(); @@ -830,7 +863,7 @@ mod test { #[test] fn event_handler_test() { - let mut mainloop = EventLoopContext::new(); + let mut mainloop = EventLoopContext::new(Arc::new(Barrier::new(1))); let mut notifiers = Vec::new(); let fd1 = EventFd::new(EFD_NONBLOCK).unwrap(); let fd1_related = EventFd::new(EFD_NONBLOCK).unwrap(); @@ -869,7 +902,7 @@ mod test { #[test] fn error_operation_test() { - let mut mainloop = EventLoopContext::new(); + let mut mainloop = EventLoopContext::new(Arc::new(Barrier::new(1))); let fd1 = EventFd::new(EFD_NONBLOCK).unwrap(); let leisure_fd = EventFd::new(EFD_NONBLOCK).unwrap(); @@ -906,7 +939,7 @@ mod test { #[test] fn error_parked_operation_test() { - let mut mainloop = EventLoopContext::new(); + let mut mainloop = EventLoopContext::new(Arc::new(Barrier::new(1))); let fd1 = EventFd::new(EFD_NONBLOCK).unwrap(); let fd2 = EventFd::new(EFD_NONBLOCK).unwrap(); @@ -941,7 +974,7 @@ mod test { #[test] fn fd_released_test() { - let mut mainloop = EventLoopContext::new(); + let mut mainloop = EventLoopContext::new(Arc::new(Barrier::new(1))); let fd = mainloop.create_event(); // In this case, fd is already closed. But program was wrote to ignore the error. diff --git a/util/src/num_ops.rs b/util/src/num_ops.rs index 2be535ae5669329e3d17b6773520908d81343fb0..f5ea59be220b5323a490857dc203d3d8ac00fc76 100644 --- a/util/src/num_ops.rs +++ b/util/src/num_ops.rs @@ -35,7 +35,7 @@ use log::error; /// assert!(value == Some(1004)); /// ``` pub fn round_up(origin: u64, align: u64) -> Option { - match origin % align { + match origin.checked_rem(align)? { 0 => Some(origin), diff => origin.checked_add(align - diff), } @@ -58,7 +58,7 @@ pub fn round_up(origin: u64, align: u64) -> Option { /// assert!(value == Some(1000)); /// ``` pub fn round_down(origin: u64, align: u64) -> Option { - match origin % align { + match origin.checked_rem(align)? { 0 => Some(origin), diff => origin.checked_sub(diff), } @@ -81,14 +81,12 @@ pub fn round_down(origin: u64, align: u64) -> Option { /// assert!(value == Some(3)); /// ``` pub fn div_round_up(dividend: u64, divisor: u64) -> Option { - if let Some(res) = dividend.checked_div(divisor) { - if dividend % divisor == 0 { - return Some(res); - } else { - return Some(res + 1); - } + let res = dividend.checked_div(divisor)?; + if dividend.checked_rem(divisor)? == 0 { + Some(res) + } else { + Some(res + 1) } - None } /// Get the first half or second half of u64. @@ -203,7 +201,7 @@ pub fn write_u64_high(origin: u64, value: u32) -> u64 { /// assert!(value == 0xfa); /// ``` pub fn extract_u32(value: u32, start: u32, length: u32) -> Option { - if length > 32 - start { + if length > 32_u32.checked_sub(start)? { error!( "extract_u32: ( start {} length {} ) is out of range", start, length @@ -235,7 +233,7 @@ pub fn extract_u32(value: u32, start: u32, length: u32) -> Option { /// assert!(value == 0xffff); /// ``` pub fn extract_u64(value: u64, start: u32, length: u32) -> Option { - if length > 64 - start { + if length > 64_u32.checked_sub(start)? { error!( "extract_u64: ( start {} length {} ) is out of range", start, length @@ -243,7 +241,7 @@ pub fn extract_u64(value: u64, start: u32, length: u32) -> Option { return None; } - Some((value >> start as u64) & (!(0_u64) >> (64 - length) as u64)) + Some((value >> u64::from(start)) & (!(0_u64) >> u64::from(64 - length))) } /// Deposit @fieldval into the 32 bit @value at the bit field specified @@ -271,7 +269,7 @@ pub fn extract_u64(value: u64, start: u32, length: u32) -> Option { /// assert!(value == 0xffba); /// ``` pub fn deposit_u32(value: u32, start: u32, length: u32, fieldval: u32) -> Option { - if length > 32 - start { + if length > 32_u32.checked_sub(start)? { error!( "deposit_u32: ( start {} length {} ) is out of range", start, length @@ -371,8 +369,8 @@ pub fn write_data_u32(data: &mut [u8], value: u32) -> bool { /// ``` pub fn read_data_u32(data: &[u8], value: &mut u32) -> bool { *value = match data.len() { - 1 => data[0] as u32, - 2 => LittleEndian::read_u16(data) as u32, + 1 => u32::from(data[0]), + 2 => u32::from(LittleEndian::read_u16(data)), 4 => LittleEndian::read_u32(data), _ => { error!("Invalid data length: data len {}", data.len()); @@ -401,7 +399,7 @@ pub fn read_data_u32(data: &[u8], value: &mut u32) -> bool { /// ``` pub fn read_data_u16(data: &[u8], value: &mut u16) -> bool { *value = match data.len() { - 1 => data[0] as u16, + 1 => u16::from(data[0]), 2 => LittleEndian::read_u16(data), _ => { error!("Invalid data length: data len {}", data.len()); @@ -451,7 +449,7 @@ int_trait_impl!(Num for u8 u16 usize); /// assert!(value == 17); /// ``` pub fn str_to_num(s: &str) -> Result { - let mut base = 10; + let mut base: u32 = 10; if s.starts_with("0x") || s.starts_with("0X") { base = 16; } @@ -496,13 +494,13 @@ mod test { #[test] fn round_up_test() { - let result = round_up(10001 as u64, 100 as u64); + let result = round_up(10001_u64, 100_u64); assert_eq!(result, Some(10100)); } #[test] fn round_down_test() { - let result = round_down(10001 as u64, 100 as u64); + let result = round_down(10001_u64, 100_u64); assert_eq!(result, Some(10000)); } diff --git a/util/src/offsetof.rs b/util/src/offsetof.rs index 2cc01699b3c6fabcd0cfbc4301777c898e5b9e2f..055b24e26fb40dca6233ce48b879f3e96e35e7f2 100644 --- a/util/src/offsetof.rs +++ b/util/src/offsetof.rs @@ -10,6 +10,8 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. +// Note: This can be replaced with std::mem::offset_of! within higher rust version. + /// Macro: Calculate offset of specified field in a type. #[macro_export] macro_rules! __offset_of { diff --git a/util/src/ohos_binding/audio/mod.rs b/util/src/ohos_binding/audio/mod.rs index 9590dad7ef1fc5e0798ba68194ba787616a0d397..7c5585c1ac7abf4ec14e879a7ee8f73339e576df 100755 --- a/util/src/ohos_binding/audio/mod.rs +++ b/util/src/ohos_binding/audio/mod.rs @@ -10,17 +10,21 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -mod sys; +pub mod sys; use std::os::raw::c_void; use std::ptr; +use std::sync::{Arc, RwLock}; use log::error; +use once_cell::sync::Lazy; -use sys as capi; +use super::hwf_adapter::{hwf_adapter_volume_api, volume::VolumeFuncTable}; +pub use sys as capi; const AUDIO_SAMPLE_RATE_44KHZ: u32 = 44100; const AUDIO_SAMPLE_RATE_48KHZ: u32 = 48000; +const RENDER_CB_FREQUENCY: i32 = 50; macro_rules! call_capi { ( $f: ident ( $($x: expr),* ) ) => { @@ -184,6 +188,14 @@ pub enum AudioProcessCb { length: i32, ) -> i32, >, + Option< + extern "C" fn( + capturer: *mut capi::OhAudioCapturer, + userData: *mut c_void, + source_type: capi::OHAudioInterruptSourceType, + hint: capi::OHAudioInterruptHint, + ) -> i32, + >, ), RendererCb( Option< @@ -194,9 +206,18 @@ pub enum AudioProcessCb { length: i32, ) -> i32, >, + Option< + extern "C" fn( + capturer: *mut capi::OhAudioRenderer, + userData: *mut c_void, + source_type: capi::OHAudioInterruptSourceType, + hint: capi::OHAudioInterruptHint, + ) -> i32, + >, ), } +#[derive(Debug)] pub struct AudioContext { stream_type: AudioStreamType, spec: AudioSpec, @@ -253,10 +274,18 @@ impl AudioContext { )) } + pub fn set_frame_size(&self, size: i32) -> Result<(), OAErr> { + call_capi!(OH_AudioStreamBuilder_SetFrameSizeInCallback( + self.builder, + size + )) + } + fn create_renderer(&mut self, cb: AudioProcessCb) -> Result<(), OAErr> { let mut cbs = capi::OhAudioRendererCallbacks::default(); - if let AudioProcessCb::RendererCb(f) = cb { - cbs.oh_audio_renderer_on_write_data = f; + if let AudioProcessCb::RendererCb(data_cb, interrupt_cb) = cb { + cbs.oh_audio_renderer_on_write_data = data_cb; + cbs.oh_audio_renderer_on_interrupt_event = interrupt_cb; } call_capi!(OH_AudioStreamBuilder_SetRendererCallback( self.builder, @@ -271,14 +300,19 @@ impl AudioContext { fn create_capturer(&mut self, cb: AudioProcessCb) -> Result<(), OAErr> { let mut cbs = capi::OhAudioCapturerCallbacks::default(); - if let AudioProcessCb::CapturerCb(v) = cb { - cbs.oh_audio_capturer_on_read_data = v; + if let AudioProcessCb::CapturerCb(data_cb, interrupt_cb) = cb { + cbs.oh_audio_capturer_on_read_data = data_cb; + cbs.oh_audio_capturer_on_interrupt_event = interrupt_cb; } call_capi!(OH_AudioStreamBuilder_SetCapturerCallback( self.builder, cbs, self.userdata ))?; + call_capi!(OH_AudioStreamBuilder_SetCapturerInfo( + self.builder, + capi::OH_AUDIO_STREAM_SOURCE_TYPE_AUDIOSTREAM_SOURCE_TYPE_VOICE_COMMUNICATION + ))?; call_capi!(OH_AudioStreamBuilder_GenerateCapturer( self.builder, &mut self.capturer @@ -328,6 +362,9 @@ impl AudioContext { self.set_fmt(size, rate, channels)?; self.set_sample_rate()?; self.set_sample_format()?; + if capi::OH_AUDIO_STREAM_TYPE_AUDIOSTREAM_TYPE_RERNDERER == self.stream_type.into() { + self.set_frame_size(rate as i32 / RENDER_CB_FREQUENCY)?; + } self.create_processor(cb) } @@ -365,6 +402,94 @@ impl AudioContext { } } +// From here, the code is related to ohaudio volume. +static OH_VOLUME_ADAPTER: Lazy> = Lazy::new(|| RwLock::new(OhVolume::new())); + +pub trait GuestVolumeNotifier: Send + Sync { + fn notify(&self, vol: u32); +} + +struct OhVolume { + capi: Arc, + notifiers: Vec>, +} + +impl OhVolume { + fn new() -> Self { + let capi = hwf_adapter_volume_api(); + // SAFETY: We call related API sequentially for specified ctx. + unsafe { (*capi.register_volume_change)(on_ohos_volume_changed) }; + Self { + capi, + notifiers: Vec::new(), + } + } + + fn get_ohos_volume(&self) -> u32 { + // SAFETY: We call related API sequentially for specified ctx. + unsafe { (self.capi.get_volume)() as u32 } + } + + fn get_max_volume(&self) -> u32 { + // SAFETY: We call related API sequentially for specified ctx. + unsafe { (self.capi.get_max_volume)() as u32 } + } + + fn get_min_volume(&self) -> u32 { + // SAFETY: We call related API sequentially for specified ctx. + unsafe { (self.capi.get_min_volume)() as u32 } + } + + fn set_ohos_volume(&self, volume: i32) { + // SAFETY: We call related API sequentially for specified ctx. + unsafe { (self.capi.set_volume)(volume) }; + } + + fn notify_volume_change(&self, volume: i32) { + for notifier in self.notifiers.iter() { + notifier.notify(volume as u32); + } + } + + fn register_guest_notifier(&mut self, notifier: Arc) { + self.notifiers.push(notifier); + } +} + +// SAFETY: use RW lock to ensure the security of resources. +unsafe extern "C" fn on_ohos_volume_changed(volume: i32) { + OH_VOLUME_ADAPTER + .read() + .unwrap() + .notify_volume_change(volume); +} + +pub fn register_guest_volume_notifier(notifier: Arc) { + OH_VOLUME_ADAPTER + .write() + .unwrap() + .register_guest_notifier(notifier); +} + +pub fn get_ohos_volume_max() -> u32 { + OH_VOLUME_ADAPTER.read().unwrap().get_max_volume() +} + +pub fn get_ohos_volume_min() -> u32 { + OH_VOLUME_ADAPTER.read().unwrap().get_min_volume() +} + +pub fn get_ohos_volume() -> u32 { + OH_VOLUME_ADAPTER.read().unwrap().get_ohos_volume() +} + +pub fn set_ohos_volume(vol: u32) { + OH_VOLUME_ADAPTER + .read() + .unwrap() + .set_ohos_volume(vol as i32); +} + #[cfg(test)] mod tests { use crate::ohos_binding::audio::sys as capi; diff --git a/util/src/ohos_binding/audio/sys.rs b/util/src/ohos_binding/audio/sys.rs index 289becbe377a117b3b33ae35ced560504f975142..66440c38fb0210160ecb9b2ba80c947afde4d443 100755 --- a/util/src/ohos_binding/audio/sys.rs +++ b/util/src/ohos_binding/audio/sys.rs @@ -131,6 +131,31 @@ pub const OH_AUDIO_STREAM_SOURCE_TYPE_AUDIOSTREAM_SOURCE_TYPE_VOICE_COMMUNICATIO /// @since 10 pub type OHAudioStreamSourceType = ::std::os::raw::c_int; +#[allow(unused)] +pub const AUDIOSTREAM_INTERRUPT_FORCE: OHAudioInterruptSourceType = 0; +#[allow(unused)] +pub const AUDIOSTREAM_INTERRUPT_SHARE: OHAudioInterruptSourceType = 1; + +/// Defines the audio interrupt source type. +/// +/// @since 10 +pub type OHAudioInterruptSourceType = ::std::os::raw::c_int; + +#[allow(unused)] +pub const AUDIOSTREAM_INTERRUPT_HINT_RESUME: OHAudioInterruptHint = 1; +pub const AUDIOSTREAM_INTERRUPT_HINT_PAUSE: OHAudioInterruptHint = 2; +#[allow(unused)] +pub const AUDIOSTREAM_INTERRUPT_HINT_STOP: OHAudioInterruptHint = 3; +#[allow(unused)] +pub const AUDIOSTREAM_INTERRUPT_HINT_DUCK: OHAudioInterruptHint = 4; +#[allow(unused)] +pub const AUDIOSTREAM_INTERRUPT_HINT_UNDUCK: OHAudioInterruptHint = 5; + +/// Defines the audio interrupt hint type. +/// +/// @since 10 +pub type OHAudioInterruptHint = ::std::os::raw::c_int; + #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct OH_AudioStreamBuilderStruct { @@ -183,7 +208,14 @@ pub struct OhAudioRendererCallbacks { ) -> i32, >, pub oh_audio_renderer_on_stream_event: PlaceHolderFn, - pub oh_audio_renderer_on_interrpt_event: PlaceHolderFn, + pub oh_audio_renderer_on_interrupt_event: ::std::option::Option< + extern "C" fn( + renderer: *mut OhAudioRenderer, + userData: *mut ::std::os::raw::c_void, + source_type: OHAudioInterruptSourceType, + hint: OHAudioInterruptHint, + ) -> i32, + >, pub oh_audio_renderer_on_error: PlaceHolderFn, } @@ -204,7 +236,14 @@ pub struct OhAudioCapturerCallbacks { ) -> i32, >, pub oh_audio_capturer_on_stream_event: PlaceHolderFn, - pub oh_audio_capturer_on_interrpt_event: PlaceHolderFn, + pub oh_audio_capturer_on_interrupt_event: ::std::option::Option< + extern "C" fn( + capturer: *mut OhAudioCapturer, + userData: *mut ::std::os::raw::c_void, + source_type: OHAudioInterruptSourceType, + hint: OHAudioInterruptHint, + ) -> i32, + >, pub oh_audio_capturer_on_error: PlaceHolderFn, } @@ -249,6 +288,10 @@ extern "C" { renderer: *mut OhAudioRenderer, encodingType: *mut OhAudioStreamEncodingType, ) -> OhAudioStreamResult; + pub fn OH_AudioStreamBuilder_SetFrameSizeInCallback( + builder: *mut OhAudioStreamBuilder, + size: i32, + ) -> OhAudioStreamResult; /// Create a streamBuilder can be used to open a renderer or capturer client. /// /// OH_AudioStreamBuilder_Destroy() must be called when you are done using the builder. diff --git a/util/src/ohos_binding/camera.rs b/util/src/ohos_binding/camera.rs index c481403ebb6b7c24b7fb262da27febeb0b9fb270..4121fcf328635c2b9941ca119a0f3b835880df59 100644 --- a/util/src/ohos_binding/camera.rs +++ b/util/src/ohos_binding/camera.rs @@ -10,11 +10,12 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. +use std::ffi::CString; use std::os::raw::{c_int, c_void}; use std::ptr; use std::sync::Arc; -use anyhow::{bail, Result}; +use anyhow::{bail, Context, Result}; use super::hwf_adapter::camera::{ BrokenProcessFn, BufferProcessFn, CamFuncTable, OhCameraCtx, ProfileRecorder, @@ -45,42 +46,33 @@ impl Drop for OhCamera { } impl OhCamera { - pub fn new(idx: i32) -> Result { + pub fn new(id: String) -> Result<(OhCamera, i32)> { let capi = hwf_adapter_camera_api(); // SAFETY: We call related API sequentially for specified ctx. let mut ctx = unsafe { (capi.create_ctx)() }; if ctx.is_null() { bail!("OH Camera: failed to create camera ctx"); } + let id_c = CString::new(id).with_context(|| "failed to create CString id")?; + let fmt_cnt; // SAFETY: We call related API sequentially for specified ctx. unsafe { - let n = (capi.init_cameras)(ctx); + let n = (capi.init_camera)(ctx, id_c.as_ptr()); if n < 0 { (capi.destroy_ctx)(ptr::addr_of_mut!(ctx)); bail!("OH Camera: failed to init cameras"); - } else if idx >= n { - (capi.destroy_ctx)(ptr::addr_of_mut!(ctx)); - bail!( - "OH Camera: invalid idx {}, valid num is less than {}", - idx, - n - ); } - if (capi.init_profiles)(ctx) < 0 { + + fmt_cnt = (capi.init_profiles)(ctx); + if fmt_cnt < 0 { (capi.destroy_ctx)(ptr::addr_of_mut!(ctx)); bail!("OH Camera: failed to init profiles"); } } - Ok(Self { ctx, capi }) - } - - pub fn get_fmt_nums(&self, idx: i32) -> Result { - // SAFETY: We call related API sequentially for specified ctx. - let ret = unsafe { (self.capi.get_profile_size)(self.ctx, idx as c_int) }; - if ret < 0 { - bail!("OH Camera: invalid camera idx {}", idx); + if fmt_cnt > i32::from(u8::MAX) { + bail!("Invalid format counts: {fmt_cnt}"); } - Ok(ret) + Ok((Self { ctx, capi }, fmt_cnt)) } pub fn release_camera(&self) { @@ -93,10 +85,10 @@ impl OhCamera { unsafe { (self.capi.destroy_ctx)(ptr::addr_of_mut!(self.ctx)) } } - pub fn set_fmt(&self, cam_idx: i32, profile_idx: i32) -> Result<()> { + pub fn set_fmt(&self, profile_idx: i32) -> Result<()> { let ret = // SAFETY: We call related API sequentially for specified ctx. - unsafe { (self.capi.set_profile)(self.ctx, cam_idx as c_int, profile_idx as c_int) }; + unsafe { (self.capi.set_profile)(self.ctx, profile_idx as c_int) }; if ret < 0 { bail!("OH Camera: failed to get camera profile"); } @@ -110,6 +102,9 @@ impl OhCamera { ) -> Result<()> { // SAFETY: We call related API sequentially for specified ctx. unsafe { + if (self.capi.create_session)(self.ctx) != 0 { + bail!("OH Camera: failed to create session"); + } if (self.capi.pre_start)(self.ctx, buffer_proc, broken_proc) != 0 { bail!("OH Camera: failed to prestart camera stream"); } @@ -120,13 +115,14 @@ impl OhCamera { Ok(()) } - pub fn reset_camera(&self) { + pub fn reset_camera(&self, id: String) -> Result<()> { + let id_cstr = CString::new(id).with_context(|| "failed to create CString id")?; // SAFETY: We call related API sequentially for specified ctx. unsafe { - (self.capi.create_session)(self.ctx); - (self.capi.init_cameras)(self.ctx); + (self.capi.init_camera)(self.ctx, id_cstr.as_ptr()); (self.capi.init_profiles)(self.ctx); } + Ok(()) } pub fn stop_stream(&self) { @@ -137,22 +133,17 @@ impl OhCamera { } } - pub fn get_profile(&self, cam_idx: i32, profile_idx: i32) -> Result<(i32, i32, i32, i32)> { + pub fn get_profile(&self, profile_idx: i32) -> Result<(i32, i32, i32, i32)> { let pr = ProfileRecorder::default(); // SAFETY: We call related API sequentially for specified ctx. unsafe { let ret = (self.capi.get_profile)( self.ctx, - cam_idx as c_int, profile_idx as c_int, ptr::addr_of!(pr) as *mut c_void, ); if ret < 0 { - bail!( - "OH Camera: failed to get camera {} profile {}", - cam_idx, - profile_idx - ); + bail!("OH Camera: failed to get profile {}", profile_idx); } } Ok((pr.fmt, pr.width, pr.height, pr.fps)) diff --git a/util/src/ohos_binding/hwf_adapter/camera.rs b/util/src/ohos_binding/hwf_adapter/camera.rs index 0fdc9ee74a214b02752f2a56a8d6a34b1cd85abc..bbb2074279b2a0e774b4c79c07f91af0efaccbe4 100644 --- a/util/src/ohos_binding/hwf_adapter/camera.rs +++ b/util/src/ohos_binding/hwf_adapter/camera.rs @@ -10,7 +10,7 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. -use std::os::raw::{c_int, c_void}; +use std::os::raw::{c_char, c_int, c_void}; use anyhow::{Context, Result}; use libloading::os::unix::Symbol as RawSymbol; @@ -33,17 +33,16 @@ pub struct ProfileRecorder { pub fps: i32, } -pub type BufferProcessFn = unsafe extern "C" fn(src_buffer: u64, length: i32); -pub type BrokenProcessFn = unsafe extern "C" fn(); +pub type BufferProcessFn = unsafe extern "C" fn(src_buffer: u64, length: i32, camid: *const c_char); +pub type BrokenProcessFn = unsafe extern "C" fn(camid: *const c_char); type OhcamCreateCtxFn = unsafe extern "C" fn() -> *mut OhCameraCtx; type OhcamCreateSessionFn = unsafe extern "C" fn(*mut OhCameraCtx) -> c_int; type OhcamReleaseSessionFn = unsafe extern "C" fn(*mut OhCameraCtx); -type OhcamInitCamerasFn = unsafe extern "C" fn(*mut OhCameraCtx) -> c_int; +type OhcamInitCameraFn = unsafe extern "C" fn(*mut OhCameraCtx, *const c_char) -> c_int; type OhcamInitProfilesFn = unsafe extern "C" fn(*mut OhCameraCtx) -> c_int; -type OhcamGetProfileSizeFn = unsafe extern "C" fn(*mut OhCameraCtx, c_int) -> c_int; -type OhcamGetProfileFn = unsafe extern "C" fn(*mut OhCameraCtx, c_int, c_int, *mut c_void) -> c_int; -type OhcamSetProfileFn = unsafe extern "C" fn(*mut OhCameraCtx, c_int, c_int) -> c_int; +type OhcamGetProfileFn = unsafe extern "C" fn(*mut OhCameraCtx, c_int, *mut c_void) -> c_int; +type OhcamSetProfileFn = unsafe extern "C" fn(*mut OhCameraCtx, c_int) -> c_int; type OhcamPreStartFn = unsafe extern "C" fn(*mut OhCameraCtx, BufferProcessFn, BrokenProcessFn) -> c_int; type OhcamStartFn = unsafe extern "C" fn(*mut OhCameraCtx) -> c_int; @@ -56,9 +55,8 @@ pub struct CamFuncTable { pub create_ctx: RawSymbol, pub create_session: RawSymbol, pub release_session: RawSymbol, - pub init_cameras: RawSymbol, + pub init_camera: RawSymbol, pub init_profiles: RawSymbol, - pub get_profile_size: RawSymbol, pub get_profile: RawSymbol, pub set_profile: RawSymbol, pub pre_start: RawSymbol, @@ -75,9 +73,8 @@ impl CamFuncTable { create_ctx: get_libfn!(library, OhcamCreateCtxFn, OhcamCreateCtx), create_session: get_libfn!(library, OhcamCreateSessionFn, OhcamCreateSession), release_session: get_libfn!(library, OhcamReleaseSessionFn, OhcamReleaseSession), - init_cameras: get_libfn!(library, OhcamInitCamerasFn, OhcamInitCameras), + init_camera: get_libfn!(library, OhcamInitCameraFn, OhcamInitCamera), init_profiles: get_libfn!(library, OhcamInitProfilesFn, OhcamInitProfiles), - get_profile_size: get_libfn!(library, OhcamGetProfileSizeFn, OhcamGetProfileSize), get_profile: get_libfn!(library, OhcamGetProfileFn, OhcamGetProfile), set_profile: get_libfn!(library, OhcamSetProfileFn, OhcamSetProfile), pre_start: get_libfn!(library, OhcamPreStartFn, OhcamPreStart), diff --git a/util/src/ohos_binding/hwf_adapter/mod.rs b/util/src/ohos_binding/hwf_adapter/mod.rs index ffa11457fb2e44c87fc64c283a42d29c75ed1064..14e4e0c24657c12779b823239c6f914eab259b55 100644 --- a/util/src/ohos_binding/hwf_adapter/mod.rs +++ b/util/src/ohos_binding/hwf_adapter/mod.rs @@ -12,6 +12,11 @@ #[cfg(feature = "usb_camera_oh")] pub mod camera; +#[cfg(feature = "usb_host")] +pub mod usb; + +#[cfg(feature = "scream_ohaudio")] +pub mod volume; use std::ffi::OsStr; use std::sync::Arc; @@ -23,6 +28,10 @@ use once_cell::sync::Lazy; #[cfg(feature = "usb_camera_oh")] use camera::CamFuncTable; +#[cfg(feature = "usb_host")] +use usb::UsbFuncTable; +#[cfg(feature = "scream_ohaudio")] +use volume::VolumeFuncTable; static LIB_HWF_ADAPTER: Lazy = Lazy::new(|| // SAFETY: The dynamic library should be always existing. @@ -40,6 +49,10 @@ struct LibHwfAdapter { library: Library, #[cfg(feature = "usb_camera_oh")] camera: Arc, + #[cfg(feature = "usb_host")] + usb: Arc, + #[cfg(feature = "scream_ohaudio")] + volume: Arc, } impl LibHwfAdapter { @@ -52,10 +65,25 @@ impl LibHwfAdapter { CamFuncTable::new(&library).with_context(|| "failed to init camera function table")?, ); + #[cfg(feature = "usb_host")] + let usb = Arc::new( + UsbFuncTable::new(&library).with_context(|| "failed to init usb function table")?, + ); + + #[cfg(feature = "scream_ohaudio")] + let volume = Arc::new( + VolumeFuncTable::new(&library) + .with_context(|| "failed to init volume function table")?, + ); + Ok(Self { library, #[cfg(feature = "usb_camera_oh")] camera, + #[cfg(feature = "usb_host")] + usb, + #[cfg(feature = "scream_ohaudio")] + volume, }) } @@ -63,9 +91,29 @@ impl LibHwfAdapter { fn get_camera_api(&self) -> Arc { self.camera.clone() } + + #[cfg(feature = "usb_host")] + fn get_usb_api(&self) -> Arc { + self.usb.clone() + } + + #[cfg(feature = "scream_ohaudio")] + fn get_volume_api(&self) -> Arc { + self.volume.clone() + } } #[cfg(feature = "usb_camera_oh")] pub fn hwf_adapter_camera_api() -> Arc { LIB_HWF_ADAPTER.get_camera_api() } + +#[cfg(feature = "usb_host")] +pub fn hwf_adapter_usb_api() -> Arc { + LIB_HWF_ADAPTER.get_usb_api() +} + +#[cfg(feature = "scream_ohaudio")] +pub fn hwf_adapter_volume_api() -> Arc { + LIB_HWF_ADAPTER.get_volume_api() +} diff --git a/util/src/ohos_binding/hwf_adapter/usb.rs b/util/src/ohos_binding/hwf_adapter/usb.rs new file mode 100644 index 0000000000000000000000000000000000000000..abb3cc7479b914d571d00935c2bee457dfa7380c --- /dev/null +++ b/util/src/ohos_binding/hwf_adapter/usb.rs @@ -0,0 +1,45 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::os::raw::c_int; + +use anyhow::{Context, Result}; +use libloading::os::unix::Symbol as RawSymbol; +use libloading::Library; + +use crate::get_libfn; + +#[allow(non_snake_case)] +#[repr(C)] +#[derive(Eq, PartialEq, Clone, Copy, Debug)] +pub struct OhusbDevice { + pub busNum: u8, + pub devAddr: u8, + pub fd: c_int, +} + +type OhusbOpenDeviceFn = unsafe extern "C" fn(*mut OhusbDevice) -> c_int; +type OhusbCloseDeviceFn = unsafe extern "C" fn(*mut OhusbDevice) -> c_int; + +pub struct UsbFuncTable { + pub open_device: RawSymbol, + pub close_device: RawSymbol, +} + +impl UsbFuncTable { + pub unsafe fn new(library: &Library) -> Result { + Ok(Self { + open_device: get_libfn!(library, OhusbOpenDeviceFn, OhusbOpenDevice), + close_device: get_libfn!(library, OhusbCloseDeviceFn, OhusbCloseDevice), + }) + } +} diff --git a/util/src/ohos_binding/hwf_adapter/volume.rs b/util/src/ohos_binding/hwf_adapter/volume.rs new file mode 100644 index 0000000000000000000000000000000000000000..b730e143308c5721b93b0b02e1293559f76e712a --- /dev/null +++ b/util/src/ohos_binding/hwf_adapter/volume.rs @@ -0,0 +1,51 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::os::raw::c_int; + +use anyhow::{Context, Result}; +use libloading::os::unix::Symbol as RawSymbol; +use libloading::Library; + +use crate::get_libfn; + +pub type VolumeChangedCallBack = unsafe extern "C" fn(c_int); + +type OhSysAudioGetVolumeFn = unsafe extern "C" fn() -> c_int; +type OhSysAudioGetMaxVolumeFn = unsafe extern "C" fn() -> c_int; +type OhSysAudioGetMinVolumeFn = unsafe extern "C" fn() -> c_int; +type OhSysAudioSetVolumeFn = unsafe extern "C" fn(c_int); +type OhSysAudioRegisterVolumeChangeFn = unsafe extern "C" fn(VolumeChangedCallBack) -> c_int; + +pub struct VolumeFuncTable { + pub get_volume: RawSymbol, + pub get_max_volume: RawSymbol, + pub get_min_volume: RawSymbol, + pub set_volume: RawSymbol, + pub register_volume_change: RawSymbol, +} + +impl VolumeFuncTable { + pub unsafe fn new(library: &Library) -> Result { + Ok(Self { + get_volume: get_libfn!(library, OhSysAudioGetVolumeFn, OhSysAudioGetVolume), + get_max_volume: get_libfn!(library, OhSysAudioGetMaxVolumeFn, OhSysAudioGetMaxVolume), + get_min_volume: get_libfn!(library, OhSysAudioGetMinVolumeFn, OhSysAudioGetMinVolume), + set_volume: get_libfn!(library, OhSysAudioSetVolumeFn, OhSysAudioSetVolume), + register_volume_change: get_libfn!( + library, + OhSysAudioRegisterVolumeChangeFn, + OhSysAudioRegisterVolumeChange + ), + }) + } +} diff --git a/util/src/ohos_binding/misc.rs b/util/src/ohos_binding/misc.rs index 1d9e31a57c95758617cb92e642c7f88a840819d3..27b564557e8f96715d1c2e5e78fd6459f14122b4 100644 --- a/util/src/ohos_binding/misc.rs +++ b/util/src/ohos_binding/misc.rs @@ -34,7 +34,7 @@ ioctl_ior_nr!( ::std::os::raw::c_ulonglong ); -pub fn set_firstcaller_tokenid(id: u64) -> Result<()> { +fn set_firstcaller_tokenid(id: u64) -> Result<()> { let fd = OpenOptions::new() .read(true) .write(true) @@ -56,7 +56,7 @@ pub fn set_firstcaller_tokenid(id: u64) -> Result<()> { Ok(()) } -pub fn get_firstcaller_tokenid() -> Result { +fn get_firstcaller_tokenid() -> Result { let fd = OpenOptions::new() .read(true) .write(true) @@ -78,3 +78,12 @@ pub fn get_firstcaller_tokenid() -> Result { } Ok(id) } + +pub fn bound_tokenid(token_id: u64) -> Result<()> { + if token_id == 0 { + bail!("UI token ID not passed."); + } else if token_id != get_firstcaller_tokenid()? { + set_firstcaller_tokenid(token_id)?; + } + Ok(()) +} diff --git a/util/src/ohos_binding/mod.rs b/util/src/ohos_binding/mod.rs index 5f876ba544f5d1559704a66d9ff93e42a6224ff3..2e6a3cfcc775e570537c87e7a805151998c465ee 100644 --- a/util/src/ohos_binding/mod.rs +++ b/util/src/ohos_binding/mod.rs @@ -15,6 +15,8 @@ pub mod audio; #[cfg(feature = "usb_camera_oh")] pub mod camera; pub mod misc; +#[cfg(feature = "usb_host")] +pub mod usb; #[cfg(feature = "usb_camera_oh")] mod hwf_adapter; diff --git a/util/src/ohos_binding/usb.rs b/util/src/ohos_binding/usb.rs new file mode 100644 index 0000000000000000000000000000000000000000..a8d227bd316b3b9ff6f546f5efc8e26fdd4ef6a7 --- /dev/null +++ b/util/src/ohos_binding/usb.rs @@ -0,0 +1,50 @@ +// Copyright (c) 2024 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +pub use super::hwf_adapter::usb::OhusbDevice; + +use std::sync::Arc; + +use anyhow::{bail, Result}; + +use super::hwf_adapter::hwf_adapter_usb_api; +use super::hwf_adapter::usb::UsbFuncTable; + +#[derive(Clone)] +pub struct OhUsb { + capi: Arc, +} + +impl OhUsb { + pub fn new() -> Result { + let capi = hwf_adapter_usb_api(); + Ok(Self { capi }) + } + + pub fn open_device(&self, dev_handle: *mut OhusbDevice) -> Result { + // SAFETY: We call related API sequentially for specified ctx. + let ret = unsafe { (self.capi.open_device)(dev_handle) }; + if ret < 0 { + bail!("OH USB: open device failed."); + } + Ok(ret) + } + + pub fn close_device(&self, dev_handle: *mut OhusbDevice) -> Result { + // SAFETY: We call related API sequentially for specified ctx. + let ret = unsafe { (self.capi.close_device)(dev_handle) }; + if ret < 0 { + bail!("OH USB: close device failed."); + } + Ok(ret) + } +} diff --git a/util/src/pixman.rs b/util/src/pixman.rs index 50fc5d28f1658986d4a337056a3a3f1375708d0b..2cd2e27bdc3348c4990e9fba53f11951700d9d76 100644 --- a/util/src/pixman.rs +++ b/util/src/pixman.rs @@ -184,21 +184,24 @@ pub enum pixman_op_t { PIXMAN_OP_HSL_LUMINOSITY = 62, } -pub type pixman_image_destroy_func_t = ::std::option::Option< - unsafe extern "C" fn(image: *mut pixman_image_t, data: *mut libc::c_void), ->; +pub type pixman_image_destroy_func_t = + Option; -pub extern "C" fn virtio_gpu_unref_resource_callback( +/// # Safety +/// +/// Caller should has valid image and data. +pub unsafe extern "C" fn virtio_gpu_unref_resource_callback( _image: *mut pixman_image_t, data: *mut libc::c_void, ) { - // SAFETY: The safety of this function is guaranteed by caller. - unsafe { pixman_image_unref(data.cast()) }; + // The safety of this function is guaranteed by caller. + pixman_image_unref(data.cast()); } fn pixman_format_reshift(val: u32, ofs: u32, num: u32) -> u32 { ((val >> (ofs)) & ((1 << (num)) - 1)) << ((val >> 22) & 3) } + pub fn pixman_format_bpp(val: u32) -> u8 { pixman_format_reshift(val, 24, 8) as u8 } @@ -206,17 +209,24 @@ pub fn pixman_format_bpp(val: u32) -> u8 { pub fn pixman_format_a(val: u32) -> u8 { pixman_format_reshift(val, 12, 4) as u8 } + pub fn pixman_format_r(val: u32) -> u8 { pixman_format_reshift(val, 8, 4) as u8 } + pub fn pixman_format_g(val: u32) -> u8 { pixman_format_reshift(val, 4, 4) as u8 } + pub fn pixman_format_b(val: u32) -> u8 { pixman_format_reshift(val, 0, 4) as u8 } + pub fn pixman_format_depth(val: u32) -> u8 { - pixman_format_a(val) + pixman_format_r(val) + pixman_format_g(val) + pixman_format_b(val) + pixman_format_a(val) + .wrapping_add(pixman_format_r(val)) + .wrapping_add(pixman_format_g(val)) + .wrapping_add(pixman_format_b(val)) } extern "C" { diff --git a/util/src/seccomp.rs b/util/src/seccomp.rs index 41206a3bce317608b84695fd1e9af87920b61a70..cce0facd2b52eda3d8b8e2aa648fcd849f52763b 100644 --- a/util/src/seccomp.rs +++ b/util/src/seccomp.rs @@ -213,10 +213,10 @@ impl SeccompData { offset_of!(SeccompData, arch) as u32 } - fn args(num: u32) -> u32 { + fn args(num: u8) -> u32 { let offset_of_u64 = offset_of!(SeccompData, args) - offset_of!(SeccompData, instruction_pointer); - offset_of!(SeccompData, args) as u32 + num * offset_of_u64 as u32 + offset_of!(SeccompData, args) as u32 + u32::from(num) * offset_of_u64 as u32 } } @@ -292,7 +292,7 @@ pub struct BpfRule { /// The first bpf_filter to compare syscall number. header_rule: SockFilter, /// The last args index. - args_idx_last: Option, + args_idx_last: Option, /// The inner rules to limit the arguments of syscall. inner_rules: Vec, /// The last bpf_filter to allow syscall. @@ -321,7 +321,7 @@ impl BpfRule { /// * `args_idx` - The index number of system call's arguments. /// * `args_value` - The value of args_num you want to limit. This value used with `cmp` /// together. - pub fn add_constraint(mut self, cmp: SeccompCmpOpt, args_idx: u32, args_value: u32) -> BpfRule { + pub fn add_constraint(mut self, cmp: SeccompCmpOpt, args_idx: u8, args_value: u32) -> BpfRule { if self.inner_rules.is_empty() { self.tail_rule = bpf_stmt(BPF_LD + BPF_W + BPF_ABS, SeccompData::nr()); } @@ -348,7 +348,7 @@ impl BpfRule { inner_append.push(constraint_filter); inner_append.push(bpf_stmt(BPF_RET + BPF_K, SECCOMP_RET_ALLOW)); - if !self.append(&mut inner_append) { + if !self.append_to_inner(&mut inner_append) { self.start_new_session(); self.add_constraint(cmp, args_idx, args_value) } else { @@ -379,7 +379,8 @@ impl BpfRule { } /// Add bpf_filters to `inner_rules`. - fn append(&mut self, bpf_filters: &mut Vec) -> bool { + fn append_to_inner(&mut self, bpf_filters: &mut Vec) -> bool { + // bpf_filters len is less than u8::MAX. let offset = bpf_filters.len() as u8; if let Some(jf_added) = self.header_rule.jf.checked_add(offset) { @@ -446,7 +447,7 @@ impl SyscallFilter { } let prog = SockFProg { - len: sock_bpf_vec.len() as u16, + len: u16::try_from(sock_bpf_vec.len())?, sock_filter: sock_bpf_vec.as_ptr(), }; let bpf_prog_ptr = &prog as *const SockFProg; diff --git a/util/src/socket.rs b/util/src/socket.rs index 3b1290bbe866cb3a98cc199062f4545df118146b..573b7f4430e8ee7d585fde12b0030aa652166475 100644 --- a/util/src/socket.rs +++ b/util/src/socket.rs @@ -41,6 +41,13 @@ impl SocketStream { } => link_description.clone(), } } + + pub fn set_nonblocking(&mut self, nonblocking: bool) -> IoResult<()> { + match self { + SocketStream::Tcp { stream, .. } => stream.set_nonblocking(nonblocking), + SocketStream::Unix { stream, .. } => stream.set_nonblocking(nonblocking), + } + } } impl AsRawFd for SocketStream { @@ -132,6 +139,7 @@ impl SocketListener { match self { SocketListener::Tcp { listener, address } => { let (stream, sock_addr) = listener.accept()?; + stream.set_nonblocking(true)?; let peer_address = sock_addr.to_string(); let link_description = format!( "{{ protocol: tcp, address: {}, peer: {} }}", @@ -144,6 +152,7 @@ impl SocketListener { } SocketListener::Unix { listener, address } => { let (stream, _) = listener.accept()?; + stream.set_nonblocking(true)?; let link_description = format!("{{ protocol: unix, address: {} }}", address); Ok(SocketStream::Unix { link_description, diff --git a/util/src/tap.rs b/util/src/tap.rs index d59e20a22a950ccf26d47da5de4daada5ca603c4..4752cd6796d6389faeb82e0c5474a739ecb10206 100644 --- a/util/src/tap.rs +++ b/util/src/tap.rs @@ -11,10 +11,13 @@ // See the Mulan PSL v2 for more details. use std::fs::{File, OpenOptions}; -use std::io::{Read, Result as IoResult, Write}; +use std::io::{ErrorKind, Read, Result as IoResult, Write}; use std::os::unix::fs::OpenOptionsExt; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, +}; use anyhow::{anyhow, bail, Context, Result}; use log::error; @@ -22,6 +25,8 @@ use nix::fcntl::{fcntl, FcntlArg, OFlag}; use vmm_sys_util::ioctl::{ioctl_with_mut_ref, ioctl_with_ref, ioctl_with_val}; use vmm_sys_util::{ioctl_ioc_nr, ioctl_ior_nr, ioctl_iow_nr}; +use crate::aio::Iovec; + const IFF_ATTACH_QUEUE: u16 = 0x0200; const IFF_DETACH_QUEUE: u16 = 0x0400; @@ -55,6 +60,8 @@ pub struct IfReq { pub struct Tap { pub file: Arc, pub enabled: bool, + pub upload_stats: Arc, + pub download_stats: Arc, } impl Tap { @@ -108,7 +115,7 @@ impl Tap { )); } - let mut features = 0; + let mut features: u16 = 0; // SAFETY: The parameter of file can be guaranteed to be legal, and other parameters are constant. let ret = unsafe { ioctl_with_mut_ref(&file, TUNGETFEATURES(), &mut features) }; if ret < 0 { @@ -125,13 +132,15 @@ impl Tap { Ok(Tap { file: Arc::new(file), enabled: true, + upload_stats: Arc::new(AtomicU64::new(0)), + download_stats: Arc::new(AtomicU64::new(0)), }) } pub fn set_offload(&self, flags: u32) -> Result<()> { let ret = // SAFETY: The parameter of file can be guaranteed to be legal, and other parameters are constant. - unsafe { ioctl_with_val(self.file.as_ref(), TUNSETOFFLOAD(), flags as libc::c_ulong) }; + unsafe { ioctl_with_val(self.file.as_ref(), TUNSETOFFLOAD(), u64::from(flags)) }; if ret < 0 { return Err(anyhow!("ioctl TUNSETOFFLOAD failed.".to_string())); } @@ -153,7 +162,7 @@ impl Tap { let flags = TUN_F_CSUM | TUN_F_UFO; ( // SAFETY: The parameter of file can be guaranteed to be legal, and other parameters are constant. - unsafe { ioctl_with_val(self.file.as_ref(), TUNSETOFFLOAD(), flags as libc::c_ulong) } + unsafe { ioctl_with_val(self.file.as_ref(), TUNSETOFFLOAD(), u64::from(flags)) } ) >= 0 } @@ -185,11 +194,72 @@ impl Tap { ret } - pub fn read(&mut self, buf: &mut [u8]) -> IoResult { + pub fn receive_packets(&self, iovecs: &[Iovec]) -> isize { + // SAFETY: the arguments of readv has been checked and is correct. + let size = unsafe { + libc::readv( + self.as_raw_fd() as libc::c_int, + iovecs.as_ptr() as *const libc::iovec, + iovecs.len() as libc::c_int, + ) + }; + if size < 0 { + let e = std::io::Error::last_os_error(); + if e.kind() == std::io::ErrorKind::WouldBlock { + return size; + } + + // If the backend tap device is removed, readv returns less than 0. + // At this time, the content in the tap needs to be cleaned up. + // Here, read is called to process, otherwise handle_rx may be triggered all the time. + let mut buf = [0; 1024]; + match self.read(&mut buf) { + Ok(cnt) => error!("Failed to call readv but tap read is ok: cnt {}", cnt), + Err(e) => { + // When the backend tap device is abnormally removed, read return EBADFD. + error!("Failed to read tap: {:?}", e); + } + } + error!("Failed to call readv for net handle_rx: {:?}", e); + } else { + self.download_stats.fetch_add(size as u64, Ordering::SeqCst); + } + + size + } + + pub fn send_packets(&self, iovecs: &[Iovec]) -> i8 { + loop { + // SAFETY: the arguments of writev has been checked and is correct. + let size = unsafe { + libc::writev( + self.as_raw_fd(), + iovecs.as_ptr() as *const libc::iovec, + iovecs.len() as libc::c_int, + ) + }; + if size < 0 { + let e = std::io::Error::last_os_error(); + match e.kind() { + ErrorKind::Interrupted => continue, + ErrorKind::WouldBlock => return -1_i8, + // Ignore other errors which can not be handled. + _ => error!("Failed to call writev for net handle_tx: {:?}", e), + } + } else { + self.upload_stats.fetch_add(size as u64, Ordering::SeqCst); + } + + break; + } + 0_i8 + } + + pub fn read(&self, buf: &mut [u8]) -> IoResult { self.file.as_ref().read(buf) } - pub fn write(&mut self, buf: &[u8]) -> IoResult { + pub fn write(&self, buf: &[u8]) -> IoResult { self.file.as_ref().write(buf) } diff --git a/util/src/thread_pool.rs b/util/src/thread_pool.rs index db059a80845535a8447718db276eb5c1c56372d1..48ac9af95f5066a1ebd0f24b492dad278195653d 100644 --- a/util/src/thread_pool.rs +++ b/util/src/thread_pool.rs @@ -10,6 +10,7 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. +use std::collections::LinkedList; use std::sync::{Arc, Condvar, Mutex}; use std::thread; use std::time::Duration; @@ -17,10 +18,8 @@ use std::time::Duration; use anyhow::{bail, Context, Result}; use log::error; -use crate::link_list::{List, Node}; - const MIN_THREADS: u64 = 1; -const MAX_THREADS: u64 = 64; +const MAX_THREADS: u64 = 10; type PoolTask = Box; pub trait TaskOperation: Sync + Send { @@ -45,7 +44,7 @@ struct PoolState { /// The maximum number of threads that thread pool can create. max_threads: u64, /// List of pending tasks in the thread pool. - req_lists: List, + req_lists: LinkedList, } /// SAFETY: All the operations on req_lists are protected by the mutex, @@ -61,7 +60,7 @@ impl PoolState { pending_threads: 0, min_threads: MIN_THREADS, max_threads: MAX_THREADS, - req_lists: List::new(), + req_lists: LinkedList::new(), } } @@ -131,7 +130,7 @@ impl ThreadPool { if locked_state.spawn_thread_needed() { locked_state.spawn_thread(pool.clone())? } - locked_state.req_lists.add_tail(Box::new(Node::new(task))); + locked_state.req_lists.push_back(task); drop(locked_state); pool.request_cond.notify_one(); @@ -167,7 +166,7 @@ fn worker_thread(pool: Arc) { while locked_state.is_running() { let result; - if locked_state.req_lists.len == 0 { + if locked_state.req_lists.is_empty() { locked_state.blocked_threads += 1; match pool .request_cond @@ -186,7 +185,7 @@ fn worker_thread(pool: Arc) { locked_state.blocked_threads -= 1; if result.timed_out() - && locked_state.req_lists.len == 0 + && locked_state.req_lists.is_empty() && locked_state.total_threads > locked_state.min_threads { // If wait time_out and no pending task and current total number @@ -197,15 +196,15 @@ fn worker_thread(pool: Arc) { continue; } - let mut req = locked_state.req_lists.pop_head().unwrap(); + let mut req = locked_state.req_lists.pop_front().unwrap(); drop(locked_state); - (*req.value).run(); + req.run(); locked_state = pool.pool_state.lock().unwrap(); } locked_state.total_threads -= 1; - trace::thread_pool_exit_thread(&locked_state.total_threads, &locked_state.req_lists.len); + trace::thread_pool_exit_thread(&locked_state.total_threads, &locked_state.req_lists.len()); pool.stop_cond.notify_one(); pool.request_cond.notify_one(); @@ -243,7 +242,7 @@ mod test { } // Waiting for creating. - while pool.pool_state.lock().unwrap().req_lists.len != 0 { + while !pool.pool_state.lock().unwrap().req_lists.is_empty() { thread::sleep(time::Duration::from_millis(10)); let now = time::SystemTime::now(); diff --git a/util/src/time.rs b/util/src/time.rs index e8fd8a568e193a2556350aa50e030993057d2a5b..ed9f4f3383487891fc2949a78108ca9823beb805 100644 --- a/util/src/time.rs +++ b/util/src/time.rs @@ -33,9 +33,9 @@ pub fn mktime64(year: u64, mon: u64, day: u64, hour: u64, min: u64, sec: u64) -> } /// Get wall time. -pub fn gettime() -> Result<(u32, u32)> { +pub fn gettime() -> Result<(i64, i64)> { match clock_gettime(ClockId::CLOCK_REALTIME) { - Ok(ts) => Ok((ts.tv_sec() as u32, ts.tv_nsec() as u32)), + Ok(ts) => Ok((ts.tv_sec(), ts.tv_nsec())), Err(e) => bail!("clock_gettime failed: {:?}", e), } } @@ -50,8 +50,8 @@ pub fn get_format_time(sec: i64) -> [i32; 6] { } [ - ti.tm_year + 1900, - ti.tm_mon + 1, + ti.tm_year.saturating_add(1900), + ti.tm_mon.saturating_add(1), ti.tm_mday, ti.tm_hour, ti.tm_min, diff --git a/util/src/unix.rs b/util/src/unix.rs index d71e3c826e50b190998f5606de1938acab322f24..2fcebce41ab4e28c7cbed39411ec00eb9b340391 100644 --- a/util/src/unix.rs +++ b/util/src/unix.rs @@ -19,8 +19,8 @@ use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned}; use anyhow::{anyhow, bail, Context, Result}; use libc::{ - c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, CMSG_LEN, CMSG_SPACE, MSG_NOSIGNAL, - MSG_WAITALL, SCM_RIGHTS, SOL_SOCKET, + c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, syscall, SYS_mbind, CMSG_LEN, CMSG_SPACE, + MSG_NOSIGNAL, MSG_WAITALL, SCM_RIGHTS, SOL_SOCKET, }; use log::error; use nix::unistd::{sysconf, SysconfVar}; @@ -113,7 +113,7 @@ pub fn do_mmap( // SAFETY: The return value is checked. let hva = unsafe { libc::mmap( - std::ptr::null_mut() as *mut libc::c_void, + std::ptr::null_mut(), len as libc::size_t, prot, flags, @@ -125,15 +125,15 @@ pub fn do_mmap( return Err(std::io::Error::last_os_error()).with_context(|| "Mmap failed."); } if !dump_guest_core { - set_memory_undumpable(hva, len); + // SAFETY: The hva and len are mmap-ed above and are verified. + unsafe { set_memory_undumpable(hva, len) }; } Ok(hva as u64) } -fn set_memory_undumpable(host_addr: *mut libc::c_void, size: u64) { - // SAFETY: host_addr and size are valid and return value is checked. - let ret = unsafe { libc::madvise(host_addr, size as libc::size_t, libc::MADV_DONTDUMP) }; +unsafe fn set_memory_undumpable(host_addr: *mut libc::c_void, size: u64) { + let ret = libc::madvise(host_addr, size as libc::size_t, libc::MADV_DONTDUMP); if ret < 0 { error!( "Syscall madvise(with MADV_DONTDUMP) failed, OS error is {:?}", @@ -142,6 +142,47 @@ fn set_memory_undumpable(host_addr: *mut libc::c_void, size: u64) { } } +/// This function set memory policy for host NUMA node memory range. +/// +/// * Arguments +/// +/// * `addr` - The memory range starting with addr. +/// * `len` - Length of the memory range. +/// * `mode` - Memory policy mode. +/// * `node_mask` - node_mask specifies physical node ID. +/// * `max_node` - The max node. +/// * `flags` - Mode flags. +/// +/// # Safety +/// +/// Caller should has valid params. +pub unsafe fn mbind( + addr: u64, + len: u64, + mode: u32, + node_mask: Vec, + max_node: u64, + flags: u32, +) -> Result<()> { + let res = syscall( + SYS_mbind, + addr as *mut c_void, + len, + mode, + node_mask.as_ptr(), + max_node + 1, + flags, + ); + if res < 0 { + bail!( + "Failed to apply host numa node policy, error is {}", + std::io::Error::last_os_error() + ); + } + + Ok(()) +} + /// Unix socket is a data communication endpoint for exchanging data /// between processes executing on the same host OS. pub struct UnixSock { @@ -187,10 +228,11 @@ impl UnixSock { /// The listener accepts incoming client connections. pub fn accept(&mut self) -> Result<()> { - let (sock, _addr) = self + let listener = self .listener .as_ref() - .unwrap() + .with_context(|| "UnixSock is not bound")?; + let (sock, _addr) = listener .accept() .with_context(|| format!("Failed to accept the socket {}", self.path))?; self.sock = Some(sock); @@ -203,8 +245,12 @@ impl UnixSock { } pub fn server_connection_refuse(&mut self) -> Result<()> { + let listener = self + .listener + .as_ref() + .with_context(|| "UnixSock is not bound")?; // Refuse connection by finishing life cycle of stream fd from listener fd. - self.listener.as_ref().unwrap().accept().with_context(|| { + listener.accept().with_context(|| { format!( "Failed to accept the socket for refused connection {}", self.path @@ -224,18 +270,21 @@ impl UnixSock { } pub fn listen_set_nonblocking(&self, nonblocking: bool) -> Result<()> { - self.listener + let listener = self + .listener .as_ref() - .unwrap() + .with_context(|| "UnixSock is not bound")?; + listener .set_nonblocking(nonblocking) .with_context(|| "couldn't set nonblocking for unix sock listener") } pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> { - self.sock + let sock = self + .sock .as_ref() - .unwrap() - .set_nonblocking(nonblocking) + .with_context(|| "UnixSock is not connected")?; + sock.set_nonblocking(nonblocking) .with_context(|| "couldn't set nonblocking") } @@ -270,7 +319,9 @@ impl UnixSock { let nex_cmsg_pos = (next_cmsg as *mut u8).wrapping_sub(msghdr.msg_control as usize) as u64; // SAFETY: Parameter is constant. - if nex_cmsg_pos.wrapping_add(unsafe { CMSG_LEN(0) } as u64) > msghdr.msg_controllen as u64 { + if nex_cmsg_pos.wrapping_add(u64::from(unsafe { CMSG_LEN(0) })) + > msghdr.msg_controllen as u64 + { null_mut() } else { next_cmsg @@ -287,13 +338,13 @@ impl UnixSock { /// # Errors /// /// The socket file descriptor is broken. - pub fn send_msg(&self, iovecs: &mut [iovec], out_fds: &[RawFd]) -> std::io::Result { + pub fn send_msg(&self, iovecs: &mut [iovec], out_fds: &[RawFd]) -> Result { // SAFETY: We checked the iovecs lens before. let iovecs_len = iovecs.len(); // SAFETY: We checked the out_fds lens before. - let cmsg_len = unsafe { CMSG_LEN((std::mem::size_of_val(out_fds)) as u32) }; + let cmsg_len = unsafe { CMSG_LEN(u32::try_from(std::mem::size_of_val(out_fds))?) }; // SAFETY: We checked the out_fds lens before. - let cmsg_capacity = unsafe { CMSG_SPACE((std::mem::size_of_val(out_fds)) as u32) }; + let cmsg_capacity = unsafe { CMSG_SPACE(u32::try_from(std::mem::size_of_val(out_fds))?) }; let mut cmsg_buffer = vec![0_u64; cmsg_capacity as usize]; // In `musl` toolchain, msghdr has private member `__pad0` and `__pad1`, it can't be @@ -331,17 +382,17 @@ impl UnixSock { msg.msg_controllen = cmsg_capacity as _; } - let write_count = - // SAFETY: msg parameters are valid. - unsafe { sendmsg(self.sock.as_ref().unwrap().as_raw_fd(), &msg, MSG_NOSIGNAL) }; - - if write_count == -1 { - Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!( - "Failed to send msg, err: {}", - std::io::Error::last_os_error() - ), + let sock = self + .sock + .as_ref() + .with_context(|| "UnixSock is not connected")?; + // SAFETY: msg parameters are valid. + let write_count = unsafe { sendmsg(sock.as_raw_fd(), &msg, MSG_NOSIGNAL) }; + + if write_count < 0 { + Err(anyhow!( + "Failed to send msg, err: {}", + std::io::Error::last_os_error() )) } else { Ok(write_count as usize) @@ -358,15 +409,11 @@ impl UnixSock { /// # Errors /// /// The socket file descriptor is broken. - pub fn recv_msg( - &self, - iovecs: &mut [iovec], - in_fds: &mut [RawFd], - ) -> std::io::Result<(usize, usize)> { + pub fn recv_msg(&self, iovecs: &mut [iovec], in_fds: &mut [RawFd]) -> Result<(usize, usize)> { // SAFETY: We check the iovecs lens before. let iovecs_len = iovecs.len(); // SAFETY: We check the in_fds lens before. - let cmsg_capacity = unsafe { CMSG_SPACE((std::mem::size_of_val(in_fds)) as u32) }; + let cmsg_capacity = unsafe { CMSG_SPACE(u32::try_from(std::mem::size_of_val(in_fds))?) }; let mut cmsg_buffer = vec![0_u64; cmsg_capacity as usize]; // In `musl` toolchain, msghdr has private member `__pad0` and `__pad1`, it can't be @@ -386,33 +433,25 @@ impl UnixSock { msg.msg_controllen = cmsg_capacity as _; } + let sock = self + .sock + .as_ref() + .with_context(|| "UnixSock is not connected")?; // SAFETY: msg parameters are valid. - let total_read = unsafe { - recvmsg( - self.sock.as_ref().unwrap().as_raw_fd(), - &mut msg, - MSG_WAITALL, - ) - }; - - if total_read == -1 { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!( - "Failed to recv msg, err: {}", - std::io::Error::last_os_error() - ), - )); + let total_read = unsafe { recvmsg(sock.as_raw_fd(), &mut msg, MSG_WAITALL) }; + + if total_read < 0 { + bail!( + "Failed to recv msg, err: {}", + std::io::Error::last_os_error() + ); } if total_read == 0 && (msg.msg_controllen as u64) < size_of::() as u64 { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!( - "The length of control message is invalid, {} {}", - msg.msg_controllen, - size_of::() - ), - )); + bail!( + "The length of control message is invalid, {} {}", + msg.msg_controllen, + size_of::() + ); } let mut cmsg_ptr = msg.msg_control as *mut cmsghdr; @@ -420,23 +459,29 @@ impl UnixSock { while !cmsg_ptr.is_null() { // SAFETY: The pointer of cmsg_ptr was created in this function and // can be guaranteed not be null. - let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() }; + let cmsg = unsafe { cmsg_ptr.read_unaligned() }; if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS { // SAFETY: Input parameter is constant. - let fd_count = (cmsg.cmsg_len as u64 - unsafe { CMSG_LEN(0) } as u64) as usize + let fd_count = (cmsg.cmsg_len as u64 - u64::from(unsafe { CMSG_LEN(0) })) as usize / size_of::(); + let new_in_fds_count = in_fds_count + .checked_add(fd_count) + .with_context(|| "fds count overflow")?; + if new_in_fds_count > in_fds.len() { + bail!("in_fds is too small"); + } // SAFETY: // 1. the pointer of cmsg_ptr was created in this function and can be guaranteed not be null. // 2. the parameter of in_fds has been checked before. unsafe { copy_nonoverlapping( self.cmsg_data(cmsg_ptr), - in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(), + in_fds[in_fds_count..new_in_fds_count].as_mut_ptr(), fd_count, ); } - in_fds_count += fd_count; + in_fds_count = new_in_fds_count; } cmsg_ptr = self.get_next_cmsg(&msg, &cmsg, cmsg_ptr); @@ -490,7 +535,7 @@ mod tests { assert_ne!(stream.get_stream_raw_fd(), 0); assert!(listener.accept().is_ok()); - assert_eq!(listener.is_accepted(), true); + assert!(listener.is_accepted()); if sock_path.exists() { fs::remove_file("./test_socket1.sock").unwrap(); diff --git a/util/src/v4l2.rs b/util/src/v4l2.rs index 545e264ea34ec2b74b8e58d5909cf0e3618c9693..9abcfd22a19d1d0fef56314152edd0a40ea7c774 100644 --- a/util/src/v4l2.rs +++ b/util/src/v4l2.rs @@ -136,7 +136,7 @@ impl V4l2Backend { // 2. buf can be guaranteed not be null. let ret = unsafe { libc::mmap( - std::ptr::null_mut() as *mut libc::c_void, + std::ptr::null_mut(), buf.length as libc::size_t, libc::PROT_WRITE | libc::PROT_READ, libc::MAP_SHARED, @@ -152,7 +152,7 @@ impl V4l2Backend { ); } locked_buf[i as usize].iov_base = ret as u64; - locked_buf[i as usize].iov_len = buf.length as u64; + locked_buf[i as usize].iov_len = u64::from(buf.length); // Queue buffer to get data. self.queue_buffer(&buf)?; } diff --git a/vfio/Cargo.toml b/vfio/Cargo.toml index ca4a130f78c66348dba05de63334336da01b2130..6b3e135ad6130c2dcae4bf8214f8bed5f9b948ad 100644 --- a/vfio/Cargo.toml +++ b/vfio/Cargo.toml @@ -10,11 +10,11 @@ description = "Virtual function I/O" byteorder = "1.4.3" thiserror = "1.0" anyhow = "1.0" -kvm-bindings = { version = "0.6.0", features = ["fam-wrappers"] } -kvm-ioctls = "0.15.0" +kvm-bindings = { version = "0.7.0", features = ["fam-wrappers"] } +kvm-ioctls = "0.16.0" libc = "0.2" log = "0.4" -vmm-sys-util = "0.11.1" +vmm-sys-util = "0.12.1" vfio-bindings = "0.3" once_cell = "1.18.0" address_space = { path = "../address_space" } @@ -22,3 +22,4 @@ hypervisor = { path = "../hypervisor"} machine_manager = { path = "../machine_manager" } util = { path = "../util" } devices = { path = "../devices" } +clap = { version = "=4.1.4", default-features = false, features = ["std", "derive"] } diff --git a/vfio/src/lib.rs b/vfio/src/lib.rs index 49c77af172c8dff6350a0ae3c869add1aa01c260..3d2705a399a1ef56a06ba927bf255e0fbe6cf2c9 100644 --- a/vfio/src/lib.rs +++ b/vfio/src/lib.rs @@ -22,15 +22,17 @@ pub use vfio_dev::{ VFIO_GROUP_GET_DEVICE_FD, VFIO_GROUP_GET_STATUS, VFIO_GROUP_SET_CONTAINER, VFIO_IOMMU_MAP_DMA, VFIO_IOMMU_UNMAP_DMA, VFIO_SET_IOMMU, }; -pub use vfio_pci::VfioPciDevice; +pub use vfio_pci::{VfioConfig, VfioPciDevice}; use std::collections::HashMap; use std::os::unix::io::RawFd; use std::sync::{Arc, Mutex}; +use anyhow::Result; use kvm_ioctls::DeviceFd; use once_cell::sync::Lazy; +use devices::pci::register_pcidevops_type; use vfio_dev::VfioGroup; pub static KVM_DEVICE_FD: Lazy>> = Lazy::new(|| Mutex::new(None)); @@ -38,3 +40,7 @@ pub static CONTAINERS: Lazy>>>> = Lazy::new(|| Mutex::new(HashMap::new())); pub static GROUPS: Lazy>>> = Lazy::new(|| Mutex::new(HashMap::new())); + +pub fn vfio_register_pcidevops_type() -> Result<()> { + register_pcidevops_type::() +} diff --git a/vfio/src/vfio_dev.rs b/vfio/src/vfio_dev.rs index cdba51c7f52e042fa95ac52d03b1db9f57536a82..5a6fa44478af8048c3b51d5b59900106ba700a31 100644 --- a/vfio/src/vfio_dev.rs +++ b/vfio/src/vfio_dev.rs @@ -32,7 +32,9 @@ use vmm_sys_util::{ioctl_io_nr, ioctl_ioc_nr}; use super::{CONTAINERS, GROUPS, KVM_DEVICE_FD}; use crate::VfioError; -use address_space::{AddressSpace, FlatRange, Listener, ListenerReqType, RegionIoEventFd}; +use address_space::{ + AddressAttr, AddressSpace, FlatRange, Listener, ListenerReqType, RegionIoEventFd, +}; /// Refer to VFIO in https://github.com/torvalds/linux/blob/master/include/uapi/linux/vfio.h const IOMMU_GROUP: &str = "iommu_group"; @@ -227,7 +229,8 @@ impl VfioContainer { let guest_phys_addr = fr.addr_range.base.raw_value(); let memory_size = fr.addr_range.size; - let hva = match fr.owner.get_host_address() { + // SAFETY: memory_size is range's size, so we make sure [hva, hva+size] is in ram range. + let hva = match unsafe { fr.owner.get_host_address(AddressAttr::Ram) } { Some(addr) => addr, None => bail!("Failed to get host address"), }; diff --git a/vfio/src/vfio_pci.rs b/vfio/src/vfio_pci.rs index d354098ee04976b22224e2648c1ecb6f4d51f04f..6cf8d75c5e63d79feb7da72ddbdcdd89bcb55b51 100644 --- a/vfio/src/vfio_pci.rs +++ b/vfio/src/vfio_pci.rs @@ -12,11 +12,12 @@ use std::mem::size_of; use std::os::unix::io::{AsRawFd, RawFd}; -use std::sync::atomic::{AtomicU16, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU16, Ordering}; use std::sync::{Arc, Mutex, Weak}; use anyhow::{anyhow, bail, Context, Result}; use byteorder::{ByteOrder, LittleEndian}; +use clap::{ArgAction, Parser}; use log::error; use vfio_bindings::bindings::vfio; use vmm_sys_util::eventfd::EventFd; @@ -40,15 +41,38 @@ use devices::pci::msix::{ }; use devices::pci::{ init_multifunction, le_read_u16, le_read_u32, le_write_u16, le_write_u32, pci_ext_cap_id, - pci_ext_cap_next, pci_ext_cap_ver, PciBus, PciDevBase, PciDevOps, + pci_ext_cap_next, pci_ext_cap_ver, MsiVector, PciBus, PciDevBase, PciDevOps, }; -use devices::{pci::MsiVector, Device, DeviceBase}; +use devices::{convert_bus_ref, Bus, Device, DeviceBase, PCI_BUS}; +use machine_manager::config::{get_pci_df, parse_bool, valid_id}; +use util::gen_base_func; +use util::loop_context::create_new_eventfd; use util::num_ops::ranges_overlap; use util::unix::host_page_size; const PCI_NUM_BARS: u8 = 6; const PCI_ROM_SLOT: u8 = 6; +#[derive(Parser, Default, Debug)] +#[command(no_binary_name(true))] +#[clap(group = clap::ArgGroup::new("path").args(&["host", "sysfsdev"]).multiple(false).required(true))] +pub struct VfioConfig { + #[arg(long, value_parser = ["vfio-pci"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] + pub id: String, + #[arg(long, value_parser = valid_id)] + pub host: Option, + #[arg(long)] + pub bus: String, + #[arg(long)] + pub sysfsdev: Option, + #[arg(long, value_parser = get_pci_df)] + pub addr: (u8, u8), + #[arg(long, value_parser = parse_bool, action = ArgAction::Append)] + pub multifunction: Option, +} + struct MsixTable { table_bar: u8, table_offset: u64, @@ -101,17 +125,17 @@ impl VfioPciDevice { vfio_device: Arc>, devfn: u8, name: String, - parent_bus: Weak>, + parent_bus: Weak>, multi_func: bool, mem_as: Arc, ) -> Self { Self { // Unknown PCI or PCIe type here, allocate enough space to match the two types. base: PciDevBase { - base: DeviceBase::new(name, true), - config: PciConfig::new(PCIE_CONFIG_SPACE_SIZE, PCI_NUM_BARS), + base: DeviceBase::new(name, true, Some(parent_bus)), + config: PciConfig::new(devfn, PCIE_CONFIG_SPACE_SIZE, PCI_NUM_BARS), devfn, - parent_bus, + bme: Arc::new(AtomicBool::new(false)), }, config_size: 0, config_offset: 0, @@ -162,7 +186,7 @@ impl VfioPciDevice { // Cache the pci config space to avoid overwriting the original config space. Because we // will parse the chain of extended caps in cache config and insert them into original // config space. - let mut config = PciConfig::new(PCIE_CONFIG_SPACE_SIZE, PCI_NUM_BARS); + let mut config = PciConfig::new(self.base.devfn, PCIE_CONFIG_SPACE_SIZE, PCI_NUM_BARS); config.config = config_data; let mut next = PCI_CONFIG_SPACE_SIZE; while (PCI_CONFIG_SPACE_SIZE..PCIE_CONFIG_SPACE_SIZE).contains(&next) { @@ -207,13 +231,13 @@ impl VfioPciDevice { self.vfio_device.lock().unwrap().write_region( data.as_slice(), self.config_offset, - COMMAND as u64, + u64::from(COMMAND), )?; for i in 0..PCI_ROM_SLOT { let offset = BAR_0 as usize + REG_SIZE * i as usize; let v = le_read_u32(&self.base.config.config, offset)?; - if v & BAR_IO_SPACE as u32 != 0 { + if v & u32::from(BAR_IO_SPACE) != 0 { le_write_u32(&mut self.base.config.config, offset, v & !IO_BASE_ADDR_MASK)?; } else { le_write_u32( @@ -251,8 +275,8 @@ impl VfioPciDevice { Ok(VfioMsixInfo { table: MsixTable { table_bar: (table as u16 & MSIX_TABLE_BIR) as u8, - table_offset: (table & MSIX_TABLE_OFFSET) as u64, - table_size: (entries * MSIX_TABLE_ENTRY_SIZE) as u64, + table_offset: u64::from(table & MSIX_TABLE_OFFSET), + table_size: u64::from(entries * MSIX_TABLE_ENTRY_SIZE), }, entries, }) @@ -272,13 +296,13 @@ impl VfioPciDevice { locked_dev.read_region( data.as_mut_slice(), self.config_offset, - (BAR_0 + (REG_SIZE as u8) * i) as u64, + u64::from(BAR_0 + (REG_SIZE as u8) * i), )?; let mut region_type = RegionType::Mem32Bit; let pci_bar = LittleEndian::read_u32(&data); - if pci_bar & BAR_IO_SPACE as u32 != 0 { + if pci_bar & u32::from(BAR_IO_SPACE) != 0 { region_type = RegionType::Io; - } else if pci_bar & BAR_MEM_64BIT as u32 != 0 { + } else if pci_bar & u32::from(BAR_MEM_64BIT) != 0 { region_type = RegionType::Mem64Bit; } let vfio_region = infos.remove(0); @@ -429,7 +453,7 @@ impl VfioPciDevice { } fn unregister_bars(&mut self) -> Result<()> { - let bus = self.base.parent_bus.upgrade().unwrap(); + let bus = self.parent_bus().unwrap().upgrade().unwrap(); self.base.config.unregister_bars(&bus)?; Ok(()) } @@ -450,9 +474,9 @@ impl VfioPciDevice { MSIX_CAP_FUNC_MASK | MSIX_CAP_ENABLE, )?; - let msi_irq_manager = if let Some(pci_bus) = self.base.parent_bus.upgrade() { - let locked_pci_bus = pci_bus.lock().unwrap(); - locked_pci_bus.get_msi_irq_manager() + let msi_irq_manager = if let Some(bus) = self.parent_bus().unwrap().upgrade() { + PCI_BUS!(bus, locked_bus, pci_bus); + pci_bus.get_msi_irq_manager() } else { None }; @@ -484,7 +508,7 @@ impl VfioPciDevice { let cloned_dev = self.vfio_device.clone(); let cloned_gsi_routes = self.gsi_msi_routes.clone(); - let parent_bus = self.base.parent_bus.clone(); + let parent_bus = self.parent_bus().unwrap().clone(); let dev_id = self.dev_id.clone(); let devfn = self.base.devfn; let cloned_msix = msix.clone(); @@ -492,27 +516,28 @@ impl VfioPciDevice { let mut locked_msix = msix.lock().unwrap(); locked_msix.table[offset as usize..(offset as usize + data.len())] .copy_from_slice(data); - let vector = offset / MSIX_TABLE_ENTRY_SIZE as u64; + let vector = offset / u64::from(MSIX_TABLE_ENTRY_SIZE); if locked_msix.is_vector_masked(vector as u16) { return true; } let entry = locked_msix.get_message(vector as u16); - let parent_bus = parent_bus.upgrade().unwrap(); - parent_bus.lock().unwrap().update_dev_id(devfn, &dev_id); + let bus = parent_bus.upgrade().unwrap(); + PCI_BUS!(bus, locked_bus, pci_bus); + pci_bus.update_dev_id(devfn, &dev_id); let msix_vector = MsiVector { msg_addr_lo: entry.address_lo, msg_addr_hi: entry.address_hi, msg_data: entry.data, masked: false, #[cfg(target_arch = "aarch64")] - dev_id: dev_id.load(Ordering::Acquire) as u32, + dev_id: u32::from(dev_id.load(Ordering::Acquire)), }; let mut locked_gsi_routes = cloned_gsi_routes.lock().unwrap(); let gsi_route = locked_gsi_routes.get_mut(vector as usize).unwrap(); if gsi_route.irq_fd.is_none() { - let irq_fd = EventFd::new(libc::EFD_NONBLOCK).unwrap(); + let irq_fd = create_new_eventfd().unwrap(); gsi_route.irq_fd = Some(Arc::new(irq_fd)); } let irq_fd = gsi_route.irq_fd.clone(); @@ -699,8 +724,8 @@ impl VfioPciDevice { fn vfio_enable_msix(&mut self) -> Result<()> { let mut gsi_routes = self.gsi_msi_routes.lock().unwrap(); - if gsi_routes.len() == 0 { - let irq_fd = EventFd::new(libc::EFD_NONBLOCK).unwrap(); + if gsi_routes.is_empty() { + let irq_fd = create_new_eventfd().unwrap(); let gsi_route = GsiMsiRoute { irq_fd: Some(Arc::new(irq_fd)), gsi: -1, @@ -713,7 +738,7 @@ impl VfioPciDevice { let gsi_route = GsiMsiRoute { irq_fd: None, gsi: -1, - nr: i as u32, + nr: u32::from(i), }; gsi_routes.push(gsi_route); } @@ -790,25 +815,16 @@ impl VfioPciDevice { } impl Device for VfioPciDevice { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } - - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base - } -} - -impl PciDevOps for VfioPciDevice { - fn pci_base(&self) -> &PciDevBase { - &self.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); - fn pci_base_mut(&mut self) -> &mut PciDevBase { - &mut self.base + fn reset(&mut self, _reset_child_device: bool) -> Result<()> { + Result::with_context(self.vfio_device.lock().unwrap().reset(), || { + "Fail to reset vfio dev" + }) } - fn realize(mut self) -> Result<()> { + fn realize(mut self) -> Result>> { + let parent_bus = self.parent_bus().unwrap(); self.init_write_mask(false)?; self.init_write_clear_mask(false)?; Result::with_context(self.vfio_device.lock().unwrap().reset(), || { @@ -826,21 +842,17 @@ impl PciDevOps for VfioPciDevice { self.multi_func, &mut self.base.config.config, self.base.devfn, - self.base.parent_bus.clone(), + parent_bus.clone(), ), || "Failed to init vfio device multifunction.", )?; #[cfg(target_arch = "aarch64")] { - let bus_num = self - .base - .parent_bus - .upgrade() - .unwrap() - .lock() - .unwrap() - .number(SECONDARY_BUS_NUM as usize); + let bus = parent_bus.upgrade().unwrap(); + PCI_BUS!(bus, locked_bus, pci_bus); + let bus_num = pci_bus.number(SECONDARY_BUS_NUM as usize); + drop(locked_bus); self.dev_id = Arc::new(AtomicU16::new(self.set_dev_id(bus_num, self.base.devfn))); } @@ -853,22 +865,13 @@ impl PciDevOps for VfioPciDevice { )?)); Result::with_context(self.register_bars(), || "Failed to register bars")?; - let devfn = self.base.devfn; + let devfn = u64::from(self.base.devfn); let dev = Arc::new(Mutex::new(self)); - let pci_bus = dev.lock().unwrap().base.parent_bus.upgrade().unwrap(); - let mut locked_pci_bus = pci_bus.lock().unwrap(); - let pci_device = locked_pci_bus.devices.get(&devfn); - if pci_device.is_none() { - locked_pci_bus.devices.insert(devfn, dev); - } else { - bail!( - "Devfn {:?} has been used by {:?}", - &devfn, - pci_device.unwrap().lock().unwrap().name() - ); - } + let parent_bus = dev.lock().unwrap().parent_bus().unwrap().upgrade().unwrap(); + let mut locked_bus = parent_bus.lock().unwrap(); + locked_bus.attach_child(devfn, dev.clone())?; - Ok(()) + Ok(dev) } fn unrealize(&mut self) -> Result<()> { @@ -878,6 +881,10 @@ impl PciDevOps for VfioPciDevice { } Ok(()) } +} + +impl PciDevOps for VfioPciDevice { + gen_base_func!(pci_base, pci_base_mut, PciDevBase, base); /// Read pci data from pci config if it emulate, otherwise read from vfio device. fn read_config(&mut self, offset: usize, data: &mut [u8]) { @@ -953,19 +960,20 @@ impl PciDevOps for VfioPciDevice { let was_enable = self.base.config.msix.as_ref().map_or(false, |m| { m.lock().unwrap().is_enabled(&self.base.config.config) }); - let parent_bus = self.base.parent_bus.upgrade().unwrap(); - let locked_parent_bus = parent_bus.lock().unwrap(); + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + PCI_BUS!(parent_bus, locked_bus, pci_bus); self.base.config.write( offset, data, self.dev_id.load(Ordering::Acquire), #[cfg(target_arch = "x86_64")] - Some(&locked_parent_bus.io_region), - Some(&locked_parent_bus.mem_region), + Some(&pci_bus.io_region), + Some(&pci_bus.mem_region), ); if ranges_overlap(offset, size, COMMAND as usize, REG_SIZE).unwrap() { - if le_read_u32(&self.base.config.config, offset).unwrap() & COMMAND_MEMORY_SPACE as u32 + if le_read_u32(&self.base.config.config, offset).unwrap() + & u32::from(COMMAND_MEMORY_SPACE) != 0 { if let Err(e) = self.setup_bars_mmap() { @@ -988,12 +996,6 @@ impl PciDevOps for VfioPciDevice { } } } - - fn reset(&mut self, _reset_child_device: bool) -> Result<()> { - Result::with_context(self.vfio_device.lock().unwrap().reset(), || { - "Fail to reset vfio dev" - }) - } } fn get_irq_rawfds(gsi_msi_routes: &[GsiMsiRoute], start: u32, count: u32) -> Vec { @@ -1009,3 +1011,38 @@ fn get_irq_rawfds(gsi_msi_routes: &[GsiMsiRoute], start: u32, count: u32) -> Vec } rawfds } + +#[cfg(test)] +mod tests { + use super::*; + use machine_manager::config::str_slip_to_clap; + + #[test] + fn test_vfio_config_cmdline_parser() { + // Test1: right. + let vfio_cmd1 = "vfio-pci,host=0000:1a:00.3,id=net,bus=pcie.0,addr=0x5,multifunction=on"; + let result = VfioConfig::try_parse_from(str_slip_to_clap(vfio_cmd1, true, false)); + assert!(result.is_ok()); + let vfio_config = result.unwrap(); + assert_eq!(vfio_config.host, Some("0000:1a:00.3".to_string())); + assert_eq!(vfio_config.id, "net"); + assert_eq!(vfio_config.bus, "pcie.0"); + assert_eq!(vfio_config.addr, (5, 0)); + assert_eq!(vfio_config.multifunction, Some(true)); + + // Test2: Missing bus/addr. + let vfio_cmd2 = "vfio-pci,host=0000:1a:00.3,id=net"; + let result = VfioConfig::try_parse_from(str_slip_to_clap(vfio_cmd2, true, false)); + assert!(result.is_err()); + + // Test3: `host` conflicts with `sysfsdev`. + let vfio_cmd3 = "vfio-pci,host=0000:1a:00.3,sysfsdev=/sys/bus/pci/devices/0000:00:02.0,id=net,bus=pcie.0,addr=0x5"; + let result = VfioConfig::try_parse_from(str_slip_to_clap(vfio_cmd3, true, false)); + assert!(result.is_err()); + + // Test4: Missing host/sysfsdev. + let vfio_cmd4 = "vfio-pci,id=net,bus=pcie.0,addr=0x1.0x2"; + let result = VfioConfig::try_parse_from(str_slip_to_clap(vfio_cmd4, true, false)); + assert!(result.is_err()); + } +} diff --git a/virtio/Cargo.toml b/virtio/Cargo.toml index d2890ed09111805ef769f464103b46f0381921fa..b8692b39491280351478e519422301dedbe962d7 100644 --- a/virtio/Cargo.toml +++ b/virtio/Cargo.toml @@ -13,7 +13,7 @@ anyhow = "1.0" libc = "0.2" log = "0.4" serde_json = "1.0" -vmm-sys-util = "0.11.1" +vmm-sys-util = "0.12.1" once_cell = "1.18.0" address_space = { path = "../address_space" } machine_manager = { path = "../machine_manager" } @@ -31,4 +31,10 @@ clap = { version = "=4.1.4", default-features = false, features = ["std", "deriv [features] default = [] virtio_gpu = ["ui", "machine_manager/virtio_gpu", "util/pixman"] +virtio_rng = [] +virtio_scsi = [] ohui_srv = [] +vhost_vsock =[] +vhostuser_block = [] +vhostuser_net = [] +vhost_net = [] diff --git a/virtio/src/device/balloon.rs b/virtio/src/device/balloon.rs index fd6347a283ac77c6cc1f3c1629232e9d1d7b459b..e6ab32cdeef9e4db9e6172fa9c2284bde4f76eb0 100644 --- a/virtio/src/device/balloon.rs +++ b/virtio/src/device/balloon.rs @@ -31,7 +31,8 @@ use crate::{ VIRTIO_TYPE_BALLOON, }; use address_space::{ - AddressSpace, FlatRange, GuestAddress, Listener, ListenerReqType, RegionIoEventFd, RegionType, + AddressAttr, AddressSpace, FlatRange, GuestAddress, Listener, ListenerReqType, RegionIoEventFd, + RegionType, }; use machine_manager::{ config::{get_pci_df, parse_bool, DEFAULT_VIRTQUEUE_SIZE}, @@ -41,17 +42,15 @@ use machine_manager::{ qmp::qmp_channel::QmpChannel, qmp::qmp_schema::BalloonInfo, }; -use util::{ - bitmap::Bitmap, - byte_code::ByteCode, - loop_context::{ - read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, NotifierOperation, - }, - num_ops::round_down, - offset_of, - seccomp::BpfRule, - unix::host_page_size, +use util::bitmap::Bitmap; +use util::byte_code::ByteCode; +use util::loop_context::{ + read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, NotifierOperation, }; +use util::num_ops::round_down; +use util::seccomp::BpfRule; +use util::unix::host_page_size; +use util::{gen_base_func, offset_of}; const VIRTIO_BALLOON_F_DEFLATE_ON_OOM: u32 = 2; const VIRTIO_BALLOON_F_REPORTING: u32 = 5; @@ -162,7 +161,11 @@ fn iov_to_buf( return None; } - match address_space.read_object::(GuestAddress(iov.iov_base.raw_value() + offset)) { + // GPAChecked: the iov has been checked in pop_avail(). + match address_space.read_object::( + GuestAddress(iov.iov_base.raw_value() + offset), + AddressAttr::Ram, + ) { Ok(dat) => Some(dat), Err(ref e) => { error!("Read virtioqueue failed: {:?}", e); @@ -171,9 +174,8 @@ fn iov_to_buf( } } -fn memory_advise(addr: *mut libc::c_void, len: libc::size_t, advice: libc::c_int) { - // SAFETY: The memory to be freed is allocated by guest. - if unsafe { libc::madvise(addr, len, advice) } != 0 { +unsafe fn memory_advise(addr: *mut libc::c_void, len: libc::size_t, advice: libc::c_int) { + if libc::madvise(addr, len, advice) != 0 { let evt_type = match advice { libc::MADV_DONTNEED => "DONTNEED".to_string(), libc::MADV_REMOVE => "REMOVE".to_string(), @@ -220,7 +222,7 @@ impl Request { for elem_iov in iovec { request.iovec.push(GuestIovec { iov_base: elem_iov.addr, - iov_len: elem_iov.len as u64, + iov_len: u64::from(elem_iov.len), }); request.elem_cnt += elem_iov.len; } @@ -239,11 +241,14 @@ impl Request { } else if hva == last_addr + BALLOON_PAGE_SIZE { free_len += 1; } else { - memory_advise( - start_addr as *const libc::c_void as *mut _, - (free_len * BALLOON_PAGE_SIZE) as usize, - libc::MADV_WILLNEED, - ); + // SAFETY: The memory to be freed is allocated by guest. + unsafe { + memory_advise( + start_addr as *const libc::c_void as *mut _, + (free_len * BALLOON_PAGE_SIZE) as usize, + libc::MADV_WILLNEED, + ) + }; free_len = 1; start_addr = hva; } @@ -252,11 +257,14 @@ impl Request { } if free_len != 0 { - memory_advise( - start_addr as *const libc::c_void as *mut _, - (free_len * BALLOON_PAGE_SIZE) as usize, - libc::MADV_WILLNEED, - ); + // SAFETY: The memory to be freed is allocated by guest. + unsafe { + memory_advise( + start_addr as *const libc::c_void as *mut _, + (free_len * BALLOON_PAGE_SIZE) as usize, + libc::MADV_WILLNEED, + ) + }; } } /// Mark balloon page with `MADV_DONTNEED` or `MADV_WILLNEED`. @@ -278,11 +286,11 @@ impl Request { let mut hvaset = Vec::new(); for iov in self.iovec.iter() { - let mut offset = 0; + let mut offset = 0_u64; while let Some(pfn) = iov_to_buf::(address_space, iov, offset) { offset += std::mem::size_of::() as u64; - let gpa: GuestAddress = GuestAddress((pfn as u64) << VIRTIO_BALLOON_PFN_SHIFT); + let gpa: GuestAddress = GuestAddress(u64::from(pfn) << VIRTIO_BALLOON_PFN_SHIFT); let (hva, shared) = match mem.lock().unwrap().get_host_address(gpa) { Some((addr, mem_share)) => (addr, mem_share), None => { @@ -301,7 +309,7 @@ impl Request { } let host_page_size = host_page_size(); - let mut advice = 0; + let mut advice = 0_i32; // If host_page_size equals BALLOON_PAGE_SIZE and have the same share properties, // we can directly call the madvise function without any problem. And if the advice is // MADV_WILLNEED, we just hint the whole host page it lives on, since we can't do @@ -320,11 +328,14 @@ impl Request { } else if hva == last_addr + BALLOON_PAGE_SIZE && last_share == share { free_len += 1; } else { - memory_advise( - start_addr as *const libc::c_void as *mut _, - (free_len * BALLOON_PAGE_SIZE) as usize, - advice, - ); + // SAFETY: The memory to be freed is allocated by guest. + unsafe { + memory_advise( + start_addr as *const libc::c_void as *mut _, + (free_len * BALLOON_PAGE_SIZE) as usize, + advice, + ) + }; free_len = 1; start_addr = hva; last_share = share; @@ -338,11 +349,14 @@ impl Request { last_addr = hva; } if free_len != 0 { - memory_advise( - start_addr as *const libc::c_void as *mut _, - (free_len * BALLOON_PAGE_SIZE) as usize, - advice, - ); + // SAFETY: The memory to be freed is allocated by guest. + unsafe { + memory_advise( + start_addr as *const libc::c_void as *mut _, + (free_len * BALLOON_PAGE_SIZE) as usize, + advice, + ) + }; } } else { let mut host_page_bitmap = BalloonedPageBitmap::new(host_page_size / BALLOON_PAGE_SIZE); @@ -376,11 +390,14 @@ impl Request { } else { advice = libc::MADV_DONTNEED; } - memory_advise( - host_page_bitmap.base_address as *const libc::c_void as *mut _, - host_page_size as usize, - advice, - ); + // SAFETY: The memory to be freed is allocated by guest. + unsafe { + memory_advise( + host_page_bitmap.base_address as *const libc::c_void as *mut _, + host_page_size as usize, + advice, + ) + }; host_page_bitmap = BalloonedPageBitmap::new(host_page_size / BALLOON_PAGE_SIZE); } } @@ -402,11 +419,14 @@ impl Request { } else { libc::MADV_DONTNEED }; - memory_advise( - hva as *const libc::c_void as *mut _, - iov.iov_len as usize, - advice, - ); + // SAFETY: The memory to be freed is allocated by guest. + unsafe { + memory_advise( + hva as *const libc::c_void as *mut _, + iov.iov_len as usize, + advice, + ) + }; } } } @@ -442,17 +462,17 @@ impl BlnMemInfo { fn get_host_address(&self, addr: GuestAddress) -> Option<(u64, bool)> { let all_regions = self.regions.lock().unwrap(); - for i in 0..all_regions.len() { - if addr.raw_value() < all_regions[i].guest_phys_addr + all_regions[i].memory_size - && addr.raw_value() >= all_regions[i].guest_phys_addr + for region in all_regions.iter() { + if addr.raw_value() < region.guest_phys_addr + region.memory_size + && addr.raw_value() >= region.guest_phys_addr { return Some(( - all_regions[i].userspace_addr + addr.raw_value() - - all_regions[i].guest_phys_addr, - all_regions[i].mem_share, + region.userspace_addr + addr.raw_value() - region.guest_phys_addr, + region.mem_share, )); } } + None } @@ -471,7 +491,8 @@ impl BlnMemInfo { fn add_mem_range(&self, fr: &FlatRange) { let guest_phys_addr = fr.addr_range.base.raw_value(); let memory_size = fr.addr_range.size; - if let Some(host_addr) = fr.owner.get_host_address() { + // SAFETY: memory_size is range's size, so we make sure [hva, hva+size] is in ram range. + if let Some(host_addr) = unsafe { fr.owner.get_host_address(AddressAttr::Ram) } { let userspace_addr = host_addr + fr.offset_in_region; let reg_page_size = fr.owner.get_region_page_size(); self.regions.lock().unwrap().push(BlnMemoryRegion { @@ -489,7 +510,8 @@ impl BlnMemInfo { fn delete_mem_range(&self, fr: &FlatRange) { let mut mem_regions = self.regions.lock().unwrap(); - if let Some(host_addr) = fr.owner.get_host_address() { + // SAFETY: memory_size is range's size, so we make sure [hva, hva+size] is in ram range. + if let Some(host_addr) = unsafe { fr.owner.get_host_address(AddressAttr::Ram) } { let reg_page_size = fr.owner.get_region_page_size(); let target = BlnMemoryRegion { guest_phys_addr: fr.addr_range.base.raw_value(), @@ -616,7 +638,7 @@ impl BalloonIoHandler { trace::virtio_receive_request("Balloon".to_string(), "to inflate".to_string()); &self.inf_queue } else { - trace::virtio_receive_request("Balloon".to_string(), "to inflate".to_string()); + trace::virtio_receive_request("Balloon".to_string(), "to deflate".to_string()); &self.def_queue }; let mut locked_queue = queue.lock().unwrap(); @@ -636,7 +658,7 @@ impl BalloonIoHandler { } locked_queue .vring - .add_used(&self.mem_space, req.desc_index, req.elem_cnt) + .add_used(req.desc_index, req.elem_cnt) .with_context(|| "Failed to add balloon response into used queue")?; (self.interrupt_cb)(&VirtioInterruptType::Vring, Some(&locked_queue), false) .with_context(|| { @@ -648,6 +670,7 @@ impl BalloonIoHandler { } fn reporting_evt_handler(&mut self) -> Result<()> { + trace::reporting_evt_handler(); let queue = self .report_queue .as_ref() @@ -670,7 +693,7 @@ impl BalloonIoHandler { } locked_queue .vring - .add_used(&self.mem_space, req.desc_index, req.elem_cnt) + .add_used(req.desc_index, req.elem_cnt) .with_context(|| "Failed to add balloon response into used queue")?; (self.interrupt_cb)(&VirtioInterruptType::Vring, Some(&locked_queue), false) .with_context(|| { @@ -682,6 +705,7 @@ impl BalloonIoHandler { } fn auto_msg_evt_handler(&mut self) -> Result<()> { + trace::auto_msg_evt_handler(); let queue = self .msg_queue .as_ref() @@ -701,14 +725,13 @@ impl BalloonIoHandler { .with_context(|| "Fail to parse available descriptor chain")?; // SAFETY: There is no confliction when writing global variable BALLOON_DEV, in other // words, this function will not be called simultaneously. - if let Some(dev) = unsafe { &BALLOON_DEV } { + if let Some(dev) = unsafe { BALLOON_DEV.as_ref() } { let mut balloon_dev = dev.lock().unwrap(); for iov in req.iovec.iter() { if let Some(stat) = iov_to_buf::(&self.mem_space, iov, 0) { - let ram_size = (balloon_dev.mem_info.lock().unwrap().get_ram_size() - >> VIRTIO_BALLOON_PFN_SHIFT) - as u32; - balloon_dev.set_num_pages(cmp::min(stat._val as u32, ram_size)); + let ram_size = balloon_dev.mem_info.lock().unwrap().get_ram_size() + >> VIRTIO_BALLOON_PFN_SHIFT; + balloon_dev.set_num_pages(cmp::min(stat._val, ram_size) as u32); } } balloon_dev @@ -718,7 +741,7 @@ impl BalloonIoHandler { locked_queue .vring - .add_used(&self.mem_space, req.desc_index, req.elem_cnt) + .add_used(req.desc_index, req.elem_cnt) .with_context(|| "Failed to add balloon response into used queue")?; (self.interrupt_cb)(&VirtioInterruptType::Vring, Some(&locked_queue), false) .with_context(|| { @@ -741,7 +764,7 @@ impl BalloonIoHandler { /// Get the memory size of balloon. fn get_balloon_memory_size(&self) -> u64 { - (self.balloon_actual.load(Ordering::Acquire) as u64) << VIRTIO_BALLOON_PFN_SHIFT + u64::from(self.balloon_actual.load(Ordering::Acquire)) << VIRTIO_BALLOON_PFN_SHIFT } } @@ -882,8 +905,10 @@ impl EventNotifierHelper for BalloonIoHandler { } #[derive(Parser, Debug, Clone, Default)] -#[command(name = "balloon")] +#[command(no_binary_name(true))] pub struct BalloonConfig { + #[arg(long)] + pub classtype: String, #[arg(long, value_parser = valid_id)] pub id: String, #[arg(long)] @@ -914,9 +939,9 @@ impl ConfigCheck for BalloonConfig { { return Err(anyhow!(ConfigError::IllegalValue( "balloon membuf-percent".to_string(), - MEM_BUFFER_PERCENT_MIN as u64, + u64::from(MEM_BUFFER_PERCENT_MIN), false, - MEM_BUFFER_PERCENT_MAX as u64, + u64::from(MEM_BUFFER_PERCENT_MAX), false, ))); } @@ -925,9 +950,9 @@ impl ConfigCheck for BalloonConfig { { return Err(anyhow!(ConfigError::IllegalValue( "balloon monitor-interval".to_string(), - MONITOR_INTERVAL_SECOND_MIN as u64, + u64::from(MONITOR_INTERVAL_SECOND_MIN), false, - MONITOR_INTERVAL_SECOND_MAX as u64, + u64::from(MONITOR_INTERVAL_SECOND_MAX), false, ))); } @@ -1017,11 +1042,11 @@ impl Balloon { if host_page_size > BALLOON_PAGE_SIZE && !self.mem_info.lock().unwrap().has_huge_page() { warn!("Balloon used with backing page size > 4kiB, this may not be reliable"); } - let target = (size >> VIRTIO_BALLOON_PFN_SHIFT) as u32; + let target = size >> VIRTIO_BALLOON_PFN_SHIFT; let address_space_ram_size = - (self.mem_info.lock().unwrap().get_ram_size() >> VIRTIO_BALLOON_PFN_SHIFT) as u32; + self.mem_info.lock().unwrap().get_ram_size() >> VIRTIO_BALLOON_PFN_SHIFT; let vm_target = cmp::min(target, address_space_ram_size); - self.num_pages = address_space_ram_size - vm_target; + self.num_pages = (address_space_ram_size - vm_target) as u32; self.signal_config_change().with_context(|| { "Failed to notify about configuration change after setting balloon memory" })?; @@ -1034,7 +1059,7 @@ impl Balloon { /// Get the size of memory that reclaimed by balloon. fn get_balloon_memory_size(&self) -> u64 { - (self.actual.load(Ordering::Acquire) as u64) << VIRTIO_BALLOON_PFN_SHIFT + u64::from(self.actual.load(Ordering::Acquire)) << VIRTIO_BALLOON_PFN_SHIFT } /// Get the actual memory size of guest. @@ -1048,13 +1073,7 @@ impl Balloon { } impl VirtioDevice for Balloon { - fn virtio_base(&self) -> &VirtioBase { - &self.base - } - - fn virtio_base_mut(&mut self) -> &mut VirtioBase { - &mut self.base - } + gen_base_func!(virtio_base, virtio_base_mut, VirtioBase, base); fn realize(&mut self) -> Result<()> { self.bln_cfg.check()?; @@ -1204,7 +1223,7 @@ impl VirtioDevice for Balloon { pub fn qmp_balloon(target: u64) -> bool { // SAFETY: there is no confliction when writing global variable BALLOON_DEV, in other // words, this function will not be called simultaneously. - if let Some(dev) = unsafe { &BALLOON_DEV } { + if let Some(dev) = unsafe { BALLOON_DEV.as_ref() } { match dev.lock().unwrap().set_guest_memory_size(target) { Ok(()) => { return true; @@ -1222,7 +1241,7 @@ pub fn qmp_balloon(target: u64) -> bool { pub fn qmp_query_balloon() -> Option { // SAFETY: There is no confliction when writing global variable BALLOON_DEV, in other // words, this function will not be called simultaneously. - if let Some(dev) = unsafe { &BALLOON_DEV } { + if let Some(dev) = unsafe { BALLOON_DEV.as_ref() } { let unlocked_dev = dev.lock().unwrap(); return Some(unlocked_dev.get_guest_memory_size()); } @@ -1240,39 +1259,14 @@ pub fn balloon_allow_list(syscall_allow_list: &mut Vec) { #[cfg(test)] mod tests { - pub use super::*; - pub use crate::*; - - use address_space::{AddressRange, HostMemMapping, Region}; + use super::*; + use crate::tests::{address_space_init, MEMORY_SIZE}; + use crate::*; + use address_space::{AddressAttr, AddressRange, HostMemMapping, Region}; + use machine_manager::event_loop::EventLoop; - const MEMORY_SIZE: u64 = 1024 * 1024; const QUEUE_SIZE: u16 = 256; - fn address_space_init() -> Arc { - let root = Region::init_container_region(1 << 36, "space"); - let sys_space = AddressSpace::new(root, "space", None).unwrap(); - let host_mmap = Arc::new( - HostMemMapping::new( - GuestAddress(0), - None, - MEMORY_SIZE, - None, - false, - false, - false, - ) - .unwrap(), - ); - sys_space - .root() - .add_subregion( - Region::init_ram_region(host_mmap.clone(), "space"), - host_mmap.start_address().raw_value(), - ) - .unwrap(); - sys_space - } - fn create_flat_range(addr: u64, size: u64, offset_in_region: u64) -> FlatRange { let mem_mapping = Arc::new( HostMemMapping::new(GuestAddress(addr), None, size, None, false, false, false).unwrap(), @@ -1320,13 +1314,13 @@ mod tests { bln.base.device_features = 1 | 1 << 32; bln.set_driver_features(0, 1); assert_eq!(bln.base.driver_features, 1); - assert_eq!(bln.base.driver_features, bln.driver_features(0) as u64); + assert_eq!(bln.base.driver_features, u64::from(bln.driver_features(0))); bln.base.driver_features = 1 << 32; bln.set_driver_features(1, 1); assert_eq!(bln.base.driver_features, 1 << 32); assert_eq!( bln.base.driver_features, - (bln.driver_features(1) as u64) << 32 + u64::from(bln.driver_features(1)) << 32 ); // Test methods of balloon. @@ -1444,33 +1438,45 @@ mod tests { let mut queue_config_inf = QueueConfig::new(QUEUE_SIZE); queue_config_inf.desc_table = GuestAddress(0x100); - queue_config_inf.addr_cache.desc_table_host = mem_space - .get_host_address(queue_config_inf.desc_table) - .unwrap(); + queue_config_inf.addr_cache.desc_table_host = unsafe { + mem_space + .get_host_address(queue_config_inf.desc_table, AddressAttr::Ram) + .unwrap() + }; queue_config_inf.avail_ring = GuestAddress(0x300); - queue_config_inf.addr_cache.avail_ring_host = mem_space - .get_host_address(queue_config_inf.avail_ring) - .unwrap(); + queue_config_inf.addr_cache.avail_ring_host = unsafe { + mem_space + .get_host_address(queue_config_inf.avail_ring, AddressAttr::Ram) + .unwrap() + }; queue_config_inf.used_ring = GuestAddress(0x600); - queue_config_inf.addr_cache.used_ring_host = mem_space - .get_host_address(queue_config_inf.used_ring) - .unwrap(); + queue_config_inf.addr_cache.used_ring_host = unsafe { + mem_space + .get_host_address(queue_config_inf.used_ring, AddressAttr::Ram) + .unwrap() + }; queue_config_inf.ready = true; queue_config_inf.size = QUEUE_SIZE; let mut queue_config_def = QueueConfig::new(QUEUE_SIZE); queue_config_def.desc_table = GuestAddress(0x1100); - queue_config_def.addr_cache.desc_table_host = mem_space - .get_host_address(queue_config_def.desc_table) - .unwrap(); + queue_config_def.addr_cache.desc_table_host = unsafe { + mem_space + .get_host_address(queue_config_def.desc_table, AddressAttr::Ram) + .unwrap() + }; queue_config_def.avail_ring = GuestAddress(0x1300); - queue_config_def.addr_cache.avail_ring_host = mem_space - .get_host_address(queue_config_def.avail_ring) - .unwrap(); + queue_config_def.addr_cache.avail_ring_host = unsafe { + mem_space + .get_host_address(queue_config_def.avail_ring, AddressAttr::Ram) + .unwrap() + }; queue_config_def.used_ring = GuestAddress(0x1600); - queue_config_def.addr_cache.used_ring_host = mem_space - .get_host_address(queue_config_def.used_ring) - .unwrap(); + queue_config_def.addr_cache.used_ring_host = unsafe { + mem_space + .get_host_address(queue_config_def.used_ring, AddressAttr::Ram) + .unwrap() + }; queue_config_def.ready = true; queue_config_def.size = QUEUE_SIZE; @@ -1484,7 +1490,7 @@ mod tests { driver_features: bln.base.driver_features, mem_space: mem_space.clone(), inf_queue: queue1, - inf_evt: event_inf.clone(), + inf_evt: event_inf, def_queue: queue2, def_evt: event_def, report_queue: None, @@ -1514,7 +1520,11 @@ mod tests { // Set desc table. mem_space - .write_object::(&desc, GuestAddress(queue_config_inf.desc_table.0)) + .write_object::( + &desc, + GuestAddress(queue_config_inf.desc_table.0), + AddressAttr::Ram, + ) .unwrap(); let ele = GuestIovec { @@ -1522,13 +1532,21 @@ mod tests { iov_len: std::mem::size_of::() as u64, }; mem_space - .write_object::(&ele, GuestAddress(0x2000)) + .write_object::(&ele, GuestAddress(0x2000), AddressAttr::Ram) .unwrap(); mem_space - .write_object::(&0, GuestAddress(queue_config_inf.avail_ring.0 + 4 as u64)) + .write_object::( + &0, + GuestAddress(queue_config_inf.avail_ring.0 + 4_u64), + AddressAttr::Ram, + ) .unwrap(); mem_space - .write_object::(&1, GuestAddress(queue_config_inf.avail_ring.0 + 2 as u64)) + .write_object::( + &1, + GuestAddress(queue_config_inf.avail_ring.0 + 2_u64), + AddressAttr::Ram, + ) .unwrap(); assert!(handler.process_balloon_queue(BALLOON_INFLATE_EVENT).is_ok()); @@ -1544,17 +1562,29 @@ mod tests { }; mem_space - .write_object::(&desc, GuestAddress(queue_config_def.desc_table.0)) + .write_object::( + &desc, + GuestAddress(queue_config_def.desc_table.0), + AddressAttr::Ram, + ) .unwrap(); mem_space - .write_object::(&ele, GuestAddress(0x3000)) + .write_object::(&ele, GuestAddress(0x3000), AddressAttr::Ram) .unwrap(); mem_space - .write_object::(&0, GuestAddress(queue_config_def.avail_ring.0 + 4 as u64)) + .write_object::( + &0, + GuestAddress(queue_config_def.avail_ring.0 + 4_u64), + AddressAttr::Ram, + ) .unwrap(); mem_space - .write_object::(&1, GuestAddress(queue_config_def.avail_ring.0 + 2 as u64)) + .write_object::( + &1, + GuestAddress(queue_config_def.avail_ring.0 + 2_u64), + AddressAttr::Ram, + ) .unwrap(); assert!(handler.process_balloon_queue(BALLOON_DEFLATE_EVENT).is_ok()); @@ -1562,6 +1592,8 @@ mod tests { #[test] fn test_balloon_activate() { + EventLoop::object_init(&None).unwrap(); + let mem_space = address_space_init(); let interrupt_evt = EventFd::new(libc::EFD_NONBLOCK).unwrap(); let interrupt_status = Arc::new(AtomicU32::new(0)); @@ -1578,18 +1610,20 @@ mod tests { }, ) as VirtioInterrupt); - let mut queue_config_inf = QueueConfig::new(QUEUE_SIZE); - queue_config_inf.desc_table = GuestAddress(0); - queue_config_inf.avail_ring = GuestAddress(4096); - queue_config_inf.used_ring = GuestAddress(8192); - queue_config_inf.ready = true; - queue_config_inf.size = QUEUE_SIZE; - let mut queues: Vec>> = Vec::new(); - let queue1 = Arc::new(Mutex::new(Queue::new(queue_config_inf, 1).unwrap())); - queues.push(queue1); - let event_inf = Arc::new(EventFd::new(libc::EFD_NONBLOCK).unwrap()); - let queue_evts: Vec> = vec![event_inf.clone()]; + let mut queue_evts: Vec> = Vec::new(); + for i in 0..QUEUE_NUM_BALLOON as u64 { + let mut queue_config_inf = QueueConfig::new(QUEUE_SIZE); + queue_config_inf.desc_table = GuestAddress(12288 * i); + queue_config_inf.avail_ring = GuestAddress(12288 * i + 4096); + queue_config_inf.used_ring = GuestAddress(12288 * i + 8192); + queue_config_inf.ready = true; + queue_config_inf.size = QUEUE_SIZE; + let queue = Arc::new(Mutex::new(Queue::new(queue_config_inf, 1).unwrap())); + queues.push(queue); + let event_inf = Arc::new(EventFd::new(libc::EFD_NONBLOCK).unwrap()); + queue_evts.push(event_inf); + } let bln_cfg = BalloonConfig { id: "bln".to_string(), @@ -1598,7 +1632,9 @@ mod tests { }; let mut bln = Balloon::new(bln_cfg, mem_space.clone()); bln.base.queues = queues; - assert!(bln.activate(mem_space, interrupt_cb, queue_evts).is_err()); + assert!(bln.activate(mem_space, interrupt_cb, queue_evts).is_ok()); + + EventLoop::loop_clean(); } #[test] diff --git a/virtio/src/device/block.rs b/virtio/src/device/block.rs index a7719cf35ada6aba6227e5bc3ce79681b83ac075..25b18036a0b05f4746a6e379ad03b29a0097da02 100644 --- a/virtio/src/device/block.rs +++ b/virtio/src/device/block.rs @@ -21,11 +21,12 @@ use std::sync::{Arc, Mutex}; use anyhow::{anyhow, bail, Context, Result}; use byteorder::{ByteOrder, LittleEndian}; +use clap::Parser; use log::{error, warn}; use vmm_sys_util::{epoll::EventSet, eventfd::EventFd}; use crate::{ - check_config_space_rw, gpa_hva_iovec_map, iov_discard_back, iov_discard_front, iov_to_buf, + check_config_space_rw, gpa_hva_iovec_map, iov_discard_back, iov_discard_front, iov_read_object, read_config_default, report_virtio_error, virtio_has_feature, Element, Queue, VirtioBase, VirtioDevice, VirtioError, VirtioInterrupt, VirtioInterruptType, VIRTIO_BLK_F_DISCARD, VIRTIO_BLK_F_FLUSH, VIRTIO_BLK_F_MQ, VIRTIO_BLK_F_RO, VIRTIO_BLK_F_SEG_MAX, @@ -35,12 +36,15 @@ use crate::{ VIRTIO_BLK_WRITE_ZEROES_FLAG_UNMAP, VIRTIO_F_RING_EVENT_IDX, VIRTIO_F_RING_INDIRECT_DESC, VIRTIO_F_VERSION_1, VIRTIO_TYPE_BLOCK, }; -use address_space::{AddressSpace, GuestAddress}; +use address_space::{AddressAttr, AddressSpace, GuestAddress, RegionCache}; use block_backend::{ create_block_backend, remove_block_backend, BlockDriverOps, BlockIoErrorCallback, BlockProperty, BlockStatus, }; -use machine_manager::config::{BlkDevConfig, ConfigCheck, DriveFile, VmConfig}; +use machine_manager::config::{ + get_pci_df, parse_bool, valid_block_device_virtqueue_size, valid_id, ConfigCheck, ConfigError, + DriveConfig, DriveFile, VmConfig, DEFAULT_VIRTQUEUE_SIZE, MAX_VIRTIO_QUEUE, +}; use machine_manager::event_loop::{register_event_helper, unregister_event_helper, EventLoop}; use migration::{ migration::Migratable, DeviceStateDesc, FieldDesc, MigrationHook, MigrationManager, @@ -54,9 +58,10 @@ use util::aio::{ use util::byte_code::ByteCode; use util::leak_bucket::LeakBucket; use util::loop_context::{ - read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, NotifierOperation, + create_new_eventfd, read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, + NotifierOperation, }; -use util::offset_of; +use util::{gen_base_func, offset_of}; /// Number of virtqueues. const QUEUE_NUM_BLK: usize = 1; @@ -76,6 +81,8 @@ const MAX_NUM_MERGE_BYTES: u64 = i32::MAX as u64; const MAX_ITERATION_PROCESS_QUEUE: u16 = 10; /// Max number sectors of per request. const MAX_REQUEST_SECTORS: u32 = u32::MAX >> SECTOR_SHIFT; +/// Max length of serial number. +const MAX_SERIAL_NUM_LEN: usize = 20; type SenderConfig = ( Option>>>, @@ -86,6 +93,70 @@ type SenderConfig = ( bool, ); +fn valid_serial(s: &str) -> Result { + if s.len() > MAX_SERIAL_NUM_LEN { + return Err(anyhow!(ConfigError::StringLengthTooLong( + "device serial number".to_string(), + MAX_SERIAL_NUM_LEN, + ))); + } + Ok(s.to_string()) +} + +#[derive(Parser, Debug, Clone)] +#[command(no_binary_name(true))] +pub struct VirtioBlkDevConfig { + #[arg(long, value_parser = ["virtio-blk-pci", "virtio-blk-device"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] + pub id: String, + #[arg(long)] + pub bus: Option, + #[arg(long, value_parser = get_pci_df)] + pub addr: Option<(u8, u8)>, + #[arg(long, value_parser = parse_bool)] + pub multifunction: Option, + #[arg(long)] + pub drive: String, + #[arg(long)] + pub bootindex: Option, + #[arg(long, alias = "num-queues", value_parser = clap::value_parser!(u16).range(1..=MAX_VIRTIO_QUEUE as i64))] + pub num_queues: Option, + #[arg(long)] + pub iothread: Option, + #[arg(long, alias = "queue-size", default_value = "256", value_parser = valid_block_device_virtqueue_size)] + pub queue_size: u16, + #[arg(long, value_parser = valid_serial)] + pub serial: Option, +} + +impl Default for VirtioBlkDevConfig { + fn default() -> Self { + Self { + classtype: "".to_string(), + id: "".to_string(), + bus: None, + addr: None, + multifunction: None, + drive: "".to_string(), + num_queues: Some(1), + bootindex: None, + iothread: None, + queue_size: DEFAULT_VIRTQUEUE_SIZE, + serial: None, + } + } +} + +impl ConfigCheck for VirtioBlkDevConfig { + fn check(&self) -> Result<()> { + if self.serial.is_some() { + valid_serial(&self.serial.clone().unwrap())?; + } + Ok(()) + } +} + fn get_serial_num_config(serial_num: &str) -> Vec { let mut id_bytes = vec![0; VIRTIO_BLK_ID_BYTES as usize]; let bytes_to_copy = cmp::min(serial_num.len(), VIRTIO_BLK_ID_BYTES as usize); @@ -157,14 +228,17 @@ impl AioCompleteCb { } fn complete_one_request(&self, req: &Request, status: u8) -> Result<()> { - if let Err(ref e) = self.mem_space.write_object(&status, req.in_header) { + if let Err(ref e) = self + .mem_space + .write_object(&status, req.in_header, AddressAttr::Ram) + { bail!("Failed to write the status (blk io completion) {:?}", e); } let mut queue_lock = self.queue.lock().unwrap(); queue_lock .vring - .add_used(&self.mem_space, req.desc_index, req.in_len) + .add_used(req.desc_index, req.in_len) .with_context(|| { format!( "Failed to add used ring(blk io completion), index {}, len {}", @@ -173,10 +247,7 @@ impl AioCompleteCb { })?; trace::virtio_blk_complete_one_request(req.desc_index, req.in_len); - if queue_lock - .vring - .should_notify(&self.mem_space, self.driver_features) - { + if queue_lock.vring.should_notify(self.driver_features) { (self.interrupt_cb)(&VirtioInterruptType::Vring, Some(&queue_lock), false) .with_context(|| { VirtioError::InterruptTrigger("blk io completion", VirtioInterruptType::Vring) @@ -200,28 +271,25 @@ struct Request { } impl Request { - fn new(handler: &BlockIoHandler, elem: &mut Element, status: &mut u8) -> Result { + fn new( + handler: &BlockIoHandler, + cache: &Option, + elem: &mut Element, + status: &mut u8, + devid: &str, + ) -> Result { if elem.out_iovec.is_empty() || elem.in_iovec.is_empty() { bail!( - "Missed header for block request: out {} in {} desc num {}", + "Missed header for block {} request: out {} in {} desc num {}", + devid, elem.out_iovec.len(), elem.in_iovec.len(), elem.desc_num ); } - let mut out_header = RequestOutHeader::default(); - iov_to_buf( - &handler.mem_space, - &elem.out_iovec, - out_header.as_mut_bytes(), - ) - .and_then(|size| { - if size < size_of::() { - bail!("Invalid out header for block request: length {}", size); - } - Ok(()) - })?; + let mut out_header = + iov_read_object::(&handler.mem_space, &elem.out_iovec, cache)?; out_header.request_type = LittleEndian::read_u32(out_header.request_type.as_bytes()); out_header.sector = LittleEndian::read_u64(out_header.sector.as_bytes()); @@ -233,7 +301,7 @@ impl Request { ); } // Note: addr plus len has been checked not overflow in virtqueue. - let in_header = GuestAddress(in_iov_elem.addr.0 + in_iov_elem.len as u64 - 1); + let in_header = GuestAddress(in_iov_elem.addr.0 + u64::from(in_iov_elem.len) - 1); let mut request = Request { desc_index: elem.index, @@ -266,9 +334,9 @@ impl Request { // Otherwise discard the last "status" byte. _ => iov_discard_back(&mut elem.in_iovec, 1), } - .with_context(|| "Empty data for block request")?; + .with_context(|| format!("Empty data for block {} request", devid))?; - let (data_len, iovec) = gpa_hva_iovec_map(data_iovec, &handler.mem_space)?; + let (data_len, iovec) = gpa_hva_iovec_map(data_iovec, &handler.mem_space, cache)?; request.data_len = data_len; request.iovec = iovec; } @@ -279,7 +347,7 @@ impl Request { } } - if !request.io_range_valid(handler.disk_sectors) { + if !request.io_range_valid(handler.disk_sectors, devid) { *status = VIRTIO_BLK_S_IOERR; } @@ -323,24 +391,41 @@ impl Request { VIRTIO_BLK_T_IN => { locked_backend .read_vectored(iovecs, offset, aiocompletecb) - .with_context(|| "Failed to process block request for reading")?; + .with_context(|| { + format!( + "Failed to process block {} request for reading", + iohandler.devid + ) + })?; } VIRTIO_BLK_T_OUT => { locked_backend .write_vectored(iovecs, offset, aiocompletecb) - .with_context(|| "Failed to process block request for writing")?; + .with_context(|| { + format!( + "Failed to process block {} request for writing", + iohandler.devid + ) + })?; } VIRTIO_BLK_T_FLUSH => { - locked_backend - .datasync(aiocompletecb) - .with_context(|| "Failed to process block request for flushing")?; + locked_backend.datasync(aiocompletecb).with_context(|| { + format!( + "Failed to process block {} request for flushing", + iohandler.devid + ) + })?; } VIRTIO_BLK_T_GET_ID => { let serial = serial_num.clone().unwrap_or_else(|| String::from("")); let serial_vec = get_serial_num_config(&serial); - let status = iov_from_buf_direct(&self.iovec, &serial_vec).map_or_else( + // SAFETY: iovec is generated by address_space. + let status = unsafe { iov_from_buf_direct(&self.iovec, &serial_vec) }.map_or_else( |e| { - error!("Failed to process block request for getting id, {:?}", e); + error!( + "Failed to process block {} request for getting id, {:?}", + iohandler.devid, e + ); VIRTIO_BLK_S_IOERR }, |_| VIRTIO_BLK_S_OK, @@ -388,7 +473,8 @@ impl Request { // Get and check the discard segment. let mut segment = DiscardWriteZeroesSeg::default(); - iov_to_buf_direct(&self.iovec, 0, segment.as_mut_bytes()).and_then(|v| { + // SAFETY: iovec is generated by address_space. + unsafe { iov_to_buf_direct(&self.iovec, 0, segment.as_mut_bytes()) }.and_then(|v| { if v as u64 == size { Ok(()) } else { @@ -398,7 +484,7 @@ impl Request { let sector = LittleEndian::read_u64(segment.sector.as_bytes()); let num_sectors = LittleEndian::read_u32(segment.num_sectors.as_bytes()); if sector - .checked_add(num_sectors as u64) + .checked_add(u64::from(num_sectors)) .filter(|&off| off <= iohandler.disk_sectors) .is_none() || num_sectors > MAX_REQUEST_SECTORS @@ -419,7 +505,7 @@ impl Request { let block_backend = iohandler.block_backend.as_ref().unwrap(); let mut locked_backend = block_backend.lock().unwrap(); let offset = (sector as usize) << SECTOR_SHIFT; - let nbytes = (num_sectors as u64) << SECTOR_SHIFT; + let nbytes = u64::from(num_sectors) << SECTOR_SHIFT; trace::virtio_blk_handle_discard_write_zeroes_req(&opcode, flags, offset, nbytes); if opcode == OpCode::Discard { if flags == VIRTIO_BLK_WRITE_ZEROES_FLAG_UNMAP { @@ -438,7 +524,7 @@ impl Request { Ok(()) } - fn io_range_valid(&self, disk_sectors: u64) -> bool { + fn io_range_valid(&self, disk_sectors: u64, devid: &str) -> bool { match self.out_header.request_type { VIRTIO_BLK_T_IN | VIRTIO_BLK_T_OUT => { if self.data_len % SECTOR_SIZE != 0 { @@ -452,8 +538,8 @@ impl Request { .is_none() { error!( - "offset {} invalid, disk sector {}", - self.out_header.sector, disk_sectors + "devid {} offset {} invalid, disk sector {}", + devid, self.out_header.sector, disk_sectors ); return false; } @@ -470,6 +556,8 @@ impl Request { /// Control block of Block IO. struct BlockIoHandler { + /// Device id of this block device. + devid: String, /// The virtqueue. queue: Arc>, /// Eventfd of the virtqueue for IO event. @@ -514,9 +602,9 @@ impl BlockIoHandler { let mut merge_req_queue = Vec::::new(); let mut last_req: Option<&mut Request> = None; - let mut merged_reqs = 0; - let mut merged_iovs = 0; - let mut merged_bytes = 0; + let mut merged_reqs: u16 = 0; + let mut merged_iovs: usize = 0; + let mut merged_bytes: u64 = 0; for req in req_queue { let req_iovs = req.iovec.len(); @@ -572,7 +660,7 @@ impl BlockIoHandler { // limit io operations if iops is configured if let Some(lb) = self.leak_bucket.as_mut() { if let Some(ctx) = EventLoop::get_ctx(self.iothread.as_ref()) { - if lb.throttled(ctx, 1_u64) { + if lb.throttled(ctx, 1_u32) { queue.vring.push_back(); break; } @@ -581,7 +669,8 @@ impl BlockIoHandler { // Init and put valid request into request queue. let mut status = VIRTIO_BLK_S_OK; - let req = Request::new(self, &mut elem, &mut status)?; + let cache = queue.vring.get_cache(); + let req = Request::new(self, cache, &mut elem, &mut status, &self.devid)?; if status != VIRTIO_BLK_S_OK { let aiocompletecb = AioCompleteCb::new( self.queue.clone(), @@ -620,7 +709,10 @@ impl BlockIoHandler { if let Some(block_backend) = self.block_backend.as_ref() { req_rc.execute(self, block_backend.clone(), aiocompletecb)?; } else { - warn!("Failed to execute block request, block_backend not specified"); + warn!( + "Failed to execute block {} request, block_backend not specified", + &self.devid + ); aiocompletecb.complete_request(VIRTIO_BLK_S_IOERR)?; } } @@ -637,12 +729,7 @@ impl BlockIoHandler { // Do not unlock or drop the locked_status in this function. let status; let mut locked_status; - let len = self - .queue - .lock() - .unwrap() - .vring - .avail_ring_len(&self.mem_space)?; + let len = self.queue.lock().unwrap().vring.avail_ring_len()?; if len > 0 { if let Some(block_backend) = self.block_backend.as_ref() { status = block_backend.lock().unwrap().get_status(); @@ -653,16 +740,9 @@ impl BlockIoHandler { trace::virtio_blk_process_queue_suppress_notify(len); let mut done = false; - let mut iteration = 0; + let mut iteration: u16 = 0; - while self - .queue - .lock() - .unwrap() - .vring - .avail_ring_len(&self.mem_space)? - != 0 - { + while self.queue.lock().unwrap().vring.avail_ring_len()? != 0 { // Do not stuck IO thread. iteration += 1; if iteration > MAX_ITERATION_PROCESS_QUEUE { @@ -671,24 +751,24 @@ impl BlockIoHandler { break; } - self.queue.lock().unwrap().vring.suppress_queue_notify( - &self.mem_space, - self.driver_features, - true, - )?; + self.queue + .lock() + .unwrap() + .vring + .suppress_queue_notify(self.driver_features, true)?; done = self.process_queue_internal()?; - self.queue.lock().unwrap().vring.suppress_queue_notify( - &self.mem_space, - self.driver_features, - false, - )?; + self.queue + .lock() + .unwrap() + .vring + .suppress_queue_notify(self.driver_features, false)?; // See whether we have been throttled. if let Some(lb) = self.leak_bucket.as_mut() { if let Some(ctx) = EventLoop::get_ctx(self.iothread.as_ref()) { - if lb.throttled(ctx, 0) { + if lb.throttled(ctx, 0_u32) { break; } } @@ -955,7 +1035,9 @@ pub struct Block { /// Virtio device base property. base: VirtioBase, /// Configuration of the block device. - blk_cfg: BlkDevConfig, + blk_cfg: VirtioBlkDevConfig, + /// Configuration of the block device's drive. + drive_cfg: DriveConfig, /// Config space of the block device. config_space: VirtioBlkConfig, /// BLock backend opened by the block device. @@ -978,14 +1060,16 @@ pub struct Block { impl Block { pub fn new( - blk_cfg: BlkDevConfig, + blk_cfg: VirtioBlkDevConfig, + drive_cfg: DriveConfig, drive_files: Arc>>, ) -> Block { - let queue_num = blk_cfg.queues as usize; + let queue_num = blk_cfg.num_queues.unwrap_or(1) as usize; let queue_size = blk_cfg.queue_size; Self { base: VirtioBase::new(VIRTIO_TYPE_BLOCK, queue_num, queue_size), blk_cfg, + drive_cfg, req_align: 1, buf_align: 1, drive_files, @@ -997,13 +1081,13 @@ impl Block { // capacity: 64bits self.config_space.capacity = self.disk_sectors; // seg_max = queue_size - 2: 32bits - self.config_space.seg_max = self.queue_size_max() as u32 - 2; + self.config_space.seg_max = u32::from(self.queue_size_max()) - 2; - if self.blk_cfg.queues > 1 { - self.config_space.num_queues = self.blk_cfg.queues; + if self.blk_cfg.num_queues.unwrap_or(1) > 1 { + self.config_space.num_queues = self.blk_cfg.num_queues.unwrap_or(1); } - if self.blk_cfg.discard { + if self.drive_cfg.discard { // Just support one segment per request. self.config_space.max_discard_seg = 1; // The default discard alignment is 1 sector. @@ -1011,7 +1095,7 @@ impl Block { self.config_space.max_discard_sectors = MAX_REQUEST_SECTORS; } - if self.blk_cfg.write_zeroes != WriteZeroesState::Off { + if self.drive_cfg.write_zeroes != WriteZeroesState::Off { // Just support one segment per request. self.config_space.max_write_zeroes_seg = 1; self.config_space.max_write_zeroes_sectors = MAX_REQUEST_SECTORS; @@ -1039,13 +1123,7 @@ impl Block { } impl VirtioDevice for Block { - fn virtio_base(&self) -> &VirtioBase { - &self.base - } - - fn virtio_base_mut(&mut self) -> &mut VirtioBase { - &mut self.base - } + gen_base_func!(virtio_base, virtio_base_mut, VirtioBase, base); fn realize(&mut self) -> Result<()> { // if iothread not found, return err @@ -1058,35 +1136,36 @@ impl VirtioDevice for Block { ); } - if !self.blk_cfg.path_on_host.is_empty() { + if !self.drive_cfg.path_on_host.is_empty() { let drive_files = self.drive_files.lock().unwrap(); - let file = VmConfig::fetch_drive_file(&drive_files, &self.blk_cfg.path_on_host)?; - let alignments = VmConfig::fetch_drive_align(&drive_files, &self.blk_cfg.path_on_host)?; + let file = VmConfig::fetch_drive_file(&drive_files, &self.drive_cfg.path_on_host)?; + let alignments = + VmConfig::fetch_drive_align(&drive_files, &self.drive_cfg.path_on_host)?; self.req_align = alignments.0; self.buf_align = alignments.1; - let drive_id = VmConfig::get_drive_id(&drive_files, &self.blk_cfg.path_on_host)?; + let drive_id = VmConfig::get_drive_id(&drive_files, &self.drive_cfg.path_on_host)?; let mut thread_pool = None; - if self.blk_cfg.aio != AioEngine::Off { + if self.drive_cfg.aio != AioEngine::Off { thread_pool = Some(EventLoop::get_ctx(None).unwrap().thread_pool.clone()); } let aio = Aio::new( Arc::new(BlockIoHandler::complete_func), - self.blk_cfg.aio, + self.drive_cfg.aio, thread_pool, )?; let conf = BlockProperty { id: drive_id, - format: self.blk_cfg.format, + format: self.drive_cfg.format, iothread: self.blk_cfg.iothread.clone(), - direct: self.blk_cfg.direct, + direct: self.drive_cfg.direct, req_align: self.req_align, buf_align: self.buf_align, - discard: self.blk_cfg.discard, - write_zeroes: self.blk_cfg.write_zeroes, - l2_cache_size: self.blk_cfg.l2_cache_size, - refcount_cache_size: self.blk_cfg.refcount_cache_size, + discard: self.drive_cfg.discard, + write_zeroes: self.drive_cfg.write_zeroes, + l2_cache_size: self.drive_cfg.l2_cache_size, + refcount_cache_size: self.drive_cfg.refcount_cache_size, }; let backend = create_block_backend(file, aio, conf)?; let disk_size = backend.lock().unwrap().disk_size()?; @@ -1110,16 +1189,16 @@ impl VirtioDevice for Block { | 1_u64 << VIRTIO_F_RING_EVENT_IDX | 1_u64 << VIRTIO_BLK_F_FLUSH | 1_u64 << VIRTIO_BLK_F_SEG_MAX; - if self.blk_cfg.read_only { + if self.drive_cfg.readonly { self.base.device_features |= 1_u64 << VIRTIO_BLK_F_RO; }; - if self.blk_cfg.queues > 1 { + if self.blk_cfg.num_queues.unwrap_or(1) > 1 { self.base.device_features |= 1_u64 << VIRTIO_BLK_F_MQ; } - if self.blk_cfg.discard { + if self.drive_cfg.discard { self.base.device_features |= 1_u64 << VIRTIO_BLK_F_DISCARD; } - if self.blk_cfg.write_zeroes != WriteZeroesState::Off { + if self.drive_cfg.write_zeroes != WriteZeroesState::Off { self.base.device_features |= 1_u64 << VIRTIO_BLK_F_WRITE_ZEROES; } self.build_device_config_space(); @@ -1130,7 +1209,7 @@ impl VirtioDevice for Block { fn unrealize(&mut self) -> Result<()> { MigrationManager::unregister_device_instance(BlockState::descriptor(), &self.blk_cfg.id); let drive_files = self.drive_files.lock().unwrap(); - let drive_id = VmConfig::get_drive_id(&drive_files, &self.blk_cfg.path_on_host)?; + let drive_id = VmConfig::get_drive_id(&drive_files, &self.drive_cfg.path_on_host)?; remove_block_backend(&drive_id); Ok(()) } @@ -1166,9 +1245,10 @@ impl VirtioDevice for Block { continue; } let (sender, receiver) = channel(); - let update_evt = Arc::new(EventFd::new(libc::EFD_NONBLOCK)?); + let update_evt = Arc::new(create_new_eventfd()?); let driver_features = self.base.driver_features; let handler = BlockIoHandler { + devid: self.blk_cfg.id.clone(), queue: queue.clone(), queue_evt: queue_evts[index].clone(), mem_space: mem_space.clone(), @@ -1176,20 +1256,20 @@ impl VirtioDevice for Block { req_align: self.req_align, buf_align: self.buf_align, disk_sectors: self.disk_sectors, - direct: self.blk_cfg.direct, - serial_num: self.blk_cfg.serial_num.clone(), + direct: self.drive_cfg.direct, + serial_num: self.blk_cfg.serial.clone(), driver_features, receiver, update_evt: update_evt.clone(), device_broken: self.base.broken.clone(), interrupt_cb: interrupt_cb.clone(), iothread: self.blk_cfg.iothread.clone(), - leak_bucket: match self.blk_cfg.iops { + leak_bucket: match self.drive_cfg.iops { Some(iops) => Some(LeakBucket::new(iops)?), None => None, }, - discard: self.blk_cfg.discard, - write_zeroes: self.blk_cfg.write_zeroes, + discard: self.drive_cfg.discard, + write_zeroes: self.drive_cfg.write_zeroes, }; let notifiers = EventNotifierHelper::internal_notifiers(Arc::new(Mutex::new(handler))); @@ -1236,18 +1316,28 @@ impl VirtioDevice for Block { Ok(()) } - fn update_config(&mut self, dev_config: Option>) -> Result<()> { - let is_plug = dev_config.is_some(); - if let Some(conf) = dev_config { - self.blk_cfg = conf + // configs[0]: DriveConfig. configs[1]: VirtioBlkDevConfig. + fn update_config(&mut self, configs: Vec>) -> Result<()> { + let mut is_plug = false; + if configs.len() == 2 { + self.drive_cfg = configs[0] .as_any() - .downcast_ref::() + .downcast_ref::() + .unwrap() + .clone(); + self.blk_cfg = configs[1] + .as_any() + .downcast_ref::() .unwrap() .clone(); // microvm type block device don't support multiple queue. - self.blk_cfg.queues = QUEUE_NUM_BLK as u16; - } else { + self.blk_cfg.num_queues = Some(QUEUE_NUM_BLK as u16); + is_plug = true; + } else if configs.is_empty() { self.blk_cfg = Default::default(); + self.drive_cfg = Default::default(); + } else { + bail!("Invalid update configs."); } if !is_plug { @@ -1296,8 +1386,8 @@ impl VirtioDevice for Block { self.req_align, self.buf_align, self.disk_sectors, - self.blk_cfg.serial_num.clone(), - self.blk_cfg.direct, + self.blk_cfg.serial.clone(), + self.drive_cfg.direct, )) .with_context(|| VirtioError::ChannelSend("image fd".to_string()))?; } @@ -1352,49 +1442,51 @@ mod tests { use vmm_sys_util::tempfile::TempFile; use super::*; + use crate::tests::address_space_init; use crate::*; - use address_space::{AddressSpace, GuestAddress, HostMemMapping, Region}; - use machine_manager::config::{IothreadConfig, VmConfig, DEFAULT_VIRTQUEUE_SIZE}; + use address_space::{AddressAttr, GuestAddress}; + use machine_manager::config::{ + str_slip_to_clap, IothreadConfig, VmConfig, DEFAULT_VIRTQUEUE_SIZE, + }; + use machine_manager::temp_cleaner::TempCleaner; const QUEUE_NUM_BLK: usize = 1; const CONFIG_SPACE_SIZE: usize = 60; const VIRTQ_DESC_F_NEXT: u16 = 0x01; const VIRTQ_DESC_F_WRITE: u16 = 0x02; - const SYSTEM_SPACE_SIZE: u64 = (1024 * 1024) as u64; - - // build dummy address space of vm - fn address_space_init() -> Arc { - let root = Region::init_container_region(1 << 36, "sysmem"); - let sys_space = AddressSpace::new(root, "sysmem", None).unwrap(); - let host_mmap = Arc::new( - HostMemMapping::new( - GuestAddress(0), - None, - SYSTEM_SPACE_SIZE, - None, - false, - false, - false, - ) - .unwrap(), - ); - sys_space - .root() - .add_subregion( - Region::init_ram_region(host_mmap.clone(), "sysmem"), - host_mmap.start_address().raw_value(), - ) - .unwrap(); - sys_space - } fn init_default_block() -> Block { Block::new( - BlkDevConfig::default(), + VirtioBlkDevConfig::default(), + DriveConfig::default(), Arc::new(Mutex::new(HashMap::new())), ) } + #[test] + fn test_virtio_block_config_cmdline_parser() { + // Test1: Right. + let blk_cmd1 = "virtio-blk-pci,id=rootfs,bus=pcie.0,addr=0x1.0x2,drive=rootfs,serial=111111,num-queues=4"; + let blk_config = + VirtioBlkDevConfig::try_parse_from(str_slip_to_clap(blk_cmd1, true, false)).unwrap(); + assert_eq!(blk_config.id, "rootfs"); + assert_eq!(blk_config.bus.unwrap(), "pcie.0"); + assert_eq!(blk_config.addr.unwrap(), (1, 2)); + assert_eq!(blk_config.serial.unwrap(), "111111"); + assert_eq!(blk_config.num_queues.unwrap(), 4); + + // Test2: Default values. + assert_eq!(blk_config.queue_size, DEFAULT_VIRTQUEUE_SIZE); + + // Test3: Illegal values. + let blk_cmd3 = "virtio-blk-pci,id=rootfs,bus=pcie.0,addr=0x1.0x2,drive=rootfs,serial=111111,num-queues=33"; + let result = VirtioBlkDevConfig::try_parse_from(str_slip_to_clap(blk_cmd3, true, false)); + assert!(result.is_err()); + let blk_cmd3 = "virtio-blk-pci,id=rootfs,drive=rootfs,serial=111111111111111111111111111111111111111111111111111111111111111111111"; + let result = VirtioBlkDevConfig::try_parse_from(str_slip_to_clap(blk_cmd3, true, false)); + assert!(result.is_err()); + } + // Use different input parameters to verify block `new()` and `realize()` functionality. #[test] fn test_block_init() { @@ -1410,16 +1502,16 @@ mod tests { assert!(block.senders.is_empty()); // Realize block device: create TempFile as backing file. - block.blk_cfg.read_only = true; - block.blk_cfg.direct = false; + block.drive_cfg.readonly = true; + block.drive_cfg.direct = false; let f = TempFile::new().unwrap(); - block.blk_cfg.path_on_host = f.as_path().to_str().unwrap().to_string(); + block.drive_cfg.path_on_host = f.as_path().to_str().unwrap().to_string(); VmConfig::add_drive_file( &mut block.drive_files.lock().unwrap(), "", - &block.blk_cfg.path_on_host, - block.blk_cfg.read_only, - block.blk_cfg.direct, + &block.drive_cfg.path_on_host, + block.drive_cfg.readonly, + block.drive_cfg.direct, ) .unwrap(); assert!(block.realize().is_ok()); @@ -1470,14 +1562,14 @@ mod tests { let page = 0_u32; block.set_driver_features(page, driver_feature); assert_eq!(block.base.driver_features, 0_u64); - assert_eq!(block.driver_features(page) as u64, 0_u64); + assert_eq!(u64::from(block.driver_features(page)), 0_u64); assert_eq!(block.device_features(0_u32), 0_u32); let driver_feature: u32 = 0xFF; let page = 1_u32; block.set_driver_features(page, driver_feature); assert_eq!(block.base.driver_features, 0_u64); - assert_eq!(block.driver_features(page) as u64, 0_u64); + assert_eq!(u64::from(block.driver_features(page)), 0_u64); assert_eq!(block.device_features(1_u32), 0_u32); // If both the device feature bit and the front-end driver feature bit are @@ -1492,7 +1584,7 @@ mod tests { (1_u64 << VIRTIO_F_RING_INDIRECT_DESC) ); assert_eq!( - block.driver_features(page) as u64, + u64::from(block.driver_features(page)), (1_u64 << VIRTIO_F_RING_INDIRECT_DESC) ); assert_eq!( @@ -1517,18 +1609,18 @@ mod tests { fn test_serial_num_config() { let serial_num = "fldXlNNdCeqMvoIfEFogBxlL"; let serial_num_arr = serial_num.as_bytes(); - let id_bytes = get_serial_num_config(&serial_num); + let id_bytes = get_serial_num_config(serial_num); assert_eq!(id_bytes[..], serial_num_arr[..20]); assert_eq!(id_bytes.len(), 20); let serial_num = "7681194149"; let serial_num_arr = serial_num.as_bytes(); - let id_bytes = get_serial_num_config(&serial_num); + let id_bytes = get_serial_num_config(serial_num); assert_eq!(id_bytes[..10], serial_num_arr[..]); assert_eq!(id_bytes.len(), 20); let serial_num = ""; - let id_bytes_temp = get_serial_num_config(&serial_num); + let id_bytes_temp = get_serial_num_config(serial_num); assert_eq!(id_bytes_temp[..], [0; 20]); assert_eq!(id_bytes_temp.len(), 20); } @@ -1537,29 +1629,31 @@ mod tests { // io request will be handled by this thread. #[test] fn test_iothread() { + TempCleaner::object_init(); let thread_name = "io1".to_string(); // spawn io thread let io_conf = IothreadConfig { + classtype: "iothread".to_string(), id: thread_name.clone(), }; EventLoop::object_init(&Some(vec![io_conf])).unwrap(); let mut block = init_default_block(); let file = TempFile::new().unwrap(); - block.blk_cfg.path_on_host = file.as_path().to_str().unwrap().to_string(); - block.blk_cfg.direct = false; + block.drive_cfg.path_on_host = file.as_path().to_str().unwrap().to_string(); + block.drive_cfg.direct = false; // config iothread and iops block.blk_cfg.iothread = Some(thread_name); - block.blk_cfg.iops = Some(100); + block.drive_cfg.iops = Some(100); VmConfig::add_drive_file( &mut block.drive_files.lock().unwrap(), "", - &block.blk_cfg.path_on_host, - block.blk_cfg.read_only, - block.blk_cfg.direct, + &block.drive_cfg.path_on_host, + block.drive_cfg.readonly, + block.drive_cfg.direct, ) .unwrap(); @@ -1572,7 +1666,7 @@ mod tests { VirtioInterruptType::Config => VIRTIO_MMIO_INT_CONFIG, VirtioInterruptType::Vring => VIRTIO_MMIO_INT_VRING, }; - interrupt_status.fetch_or(status as u32, Ordering::SeqCst); + interrupt_status.fetch_or(status, Ordering::SeqCst); interrupt_evt .write(1) .with_context(|| VirtioError::EventFdWrite)?; @@ -1583,14 +1677,23 @@ mod tests { let mut queue_config = QueueConfig::new(DEFAULT_VIRTQUEUE_SIZE); queue_config.desc_table = GuestAddress(0); - queue_config.addr_cache.desc_table_host = - mem_space.get_host_address(queue_config.desc_table).unwrap(); - queue_config.avail_ring = GuestAddress(16 * DEFAULT_VIRTQUEUE_SIZE as u64); - queue_config.addr_cache.avail_ring_host = - mem_space.get_host_address(queue_config.avail_ring).unwrap(); - queue_config.used_ring = GuestAddress(32 * DEFAULT_VIRTQUEUE_SIZE as u64); - queue_config.addr_cache.used_ring_host = - mem_space.get_host_address(queue_config.used_ring).unwrap(); + queue_config.addr_cache.desc_table_host = unsafe { + mem_space + .get_host_address(queue_config.desc_table, AddressAttr::Ram) + .unwrap() + }; + queue_config.avail_ring = GuestAddress(16 * u64::from(DEFAULT_VIRTQUEUE_SIZE)); + queue_config.addr_cache.avail_ring_host = unsafe { + mem_space + .get_host_address(queue_config.avail_ring, AddressAttr::Ram) + .unwrap() + }; + queue_config.used_ring = GuestAddress(32 * u64::from(DEFAULT_VIRTQUEUE_SIZE)); + queue_config.addr_cache.used_ring_host = unsafe { + mem_space + .get_host_address(queue_config.used_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.size = DEFAULT_VIRTQUEUE_SIZE; queue_config.ready = true; @@ -1610,7 +1713,11 @@ mod tests { next: 1, }; mem_space - .write_object::(&desc, GuestAddress(queue_config.desc_table.0)) + .write_object::( + &desc, + GuestAddress(queue_config.desc_table.0), + AddressAttr::Ram, + ) .unwrap(); // write RequestOutHeader to first desc @@ -1620,7 +1727,7 @@ mod tests { sector: 0, }; mem_space - .write_object::(&req_head, GuestAddress(0x100)) + .write_object::(&req_head, GuestAddress(0x100), AddressAttr::Ram) .unwrap(); // making the second descriptor entry to receive data from device @@ -1633,18 +1740,27 @@ mod tests { mem_space .write_object::( &desc, - GuestAddress(queue_config.desc_table.0 + 16 as u64), + GuestAddress(queue_config.desc_table.0 + 16_u64), + AddressAttr::Ram, ) .unwrap(); // write avail_ring idx mem_space - .write_object::(&0, GuestAddress(queue_config.avail_ring.0 + 4 as u64)) + .write_object::( + &0, + GuestAddress(queue_config.avail_ring.0 + 4_u64), + AddressAttr::Ram, + ) .unwrap(); // write avail_ring id mem_space - .write_object::(&1, GuestAddress(queue_config.avail_ring.0 + 2 as u64)) + .write_object::( + &1, + GuestAddress(queue_config.avail_ring.0 + 2_u64), + AddressAttr::Ram, + ) .unwrap(); // imitating guest OS to send notification. @@ -1662,11 +1778,16 @@ mod tests { // get used_ring data let idx = mem_space - .read_object::(GuestAddress(queue_config.used_ring.0 + 2 as u64)) + .read_object::( + GuestAddress(queue_config.used_ring.0 + 2_u64), + AddressAttr::Ram, + ) .unwrap(); if idx == 1 { break; } } + TempCleaner::clean(); + EventLoop::loop_clean(); } } diff --git a/virtio/src/device/gpu.rs b/virtio/src/device/gpu.rs index 97ca67c14f97d847f617a7496c250ab5ddbb5565..9cc884485ccfa521ac46e5483b3baecdfdebb1c5 100644 --- a/virtio/src/device/gpu.rs +++ b/virtio/src/device/gpu.rs @@ -18,15 +18,16 @@ use std::sync::{Arc, Mutex, Weak}; use std::{ptr, vec}; use anyhow::{anyhow, bail, Context, Result}; +use clap::{ArgAction, Parser}; use log::{error, info, warn}; use vmm_sys_util::{epoll::EventSet, eventfd::EventFd}; use crate::{ - check_config_space_rw, gpa_hva_iovec_map, iov_discard_front, iov_to_buf, read_config_default, - ElemIovec, Element, Queue, VirtioBase, VirtioDevice, VirtioDeviceQuirk, VirtioError, - VirtioInterrupt, VirtioInterruptType, VIRTIO_F_RING_EVENT_IDX, VIRTIO_F_RING_INDIRECT_DESC, - VIRTIO_F_VERSION_1, VIRTIO_GPU_CMD_GET_DISPLAY_INFO, VIRTIO_GPU_CMD_GET_EDID, - VIRTIO_GPU_CMD_MOVE_CURSOR, VIRTIO_GPU_CMD_RESOURCE_ATTACH_BACKING, + check_config_space_rw, gpa_hva_iovec_map, iov_discard_front, iov_read_object, + read_config_default, ElemIovec, Element, Queue, VirtioBase, VirtioDevice, VirtioDeviceQuirk, + VirtioError, VirtioInterrupt, VirtioInterruptType, VIRTIO_F_RING_EVENT_IDX, + VIRTIO_F_RING_INDIRECT_DESC, VIRTIO_F_VERSION_1, VIRTIO_GPU_CMD_GET_DISPLAY_INFO, + VIRTIO_GPU_CMD_GET_EDID, VIRTIO_GPU_CMD_MOVE_CURSOR, VIRTIO_GPU_CMD_RESOURCE_ATTACH_BACKING, VIRTIO_GPU_CMD_RESOURCE_CREATE_2D, VIRTIO_GPU_CMD_RESOURCE_DETACH_BACKING, VIRTIO_GPU_CMD_RESOURCE_FLUSH, VIRTIO_GPU_CMD_RESOURCE_UNREF, VIRTIO_GPU_CMD_SET_SCANOUT, VIRTIO_GPU_CMD_TRANSFER_TO_HOST_2D, VIRTIO_GPU_CMD_UPDATE_CURSOR, VIRTIO_GPU_FLAG_FENCE, @@ -36,13 +37,14 @@ use crate::{ VIRTIO_GPU_RESP_OK_EDID, VIRTIO_GPU_RESP_OK_NODATA, VIRTIO_TYPE_GPU, }; use address_space::{AddressSpace, FileBackend, GuestAddress}; -use machine_manager::config::{GpuDevConfig, DEFAULT_VIRTQUEUE_SIZE, VIRTIO_GPU_MAX_OUTPUTS}; +use machine_manager::config::{get_pci_df, valid_id, DEFAULT_VIRTQUEUE_SIZE}; use machine_manager::event_loop::{register_event_helper, unregister_event_helper}; use migration_derive::ByteCode; use ui::console::{ console_close, console_init, display_cursor_define, display_graphic_update, display_replace_surface, display_set_major_screen, get_run_stage, set_run_stage, ConsoleType, DisplayConsole, DisplayMouse, DisplaySurface, HardWareOperations, VmRunningStage, + DEFAULT_CURSOR_BPP, DEFAULT_CURSOR_HEIGHT, DEFAULT_CURSOR_WIDTH, }; use ui::pixman::{ create_pixman_image, get_image_data, get_image_format, get_image_height, get_image_stride, @@ -51,6 +53,7 @@ use ui::pixman::{ use util::aio::{iov_from_buf_direct, iov_to_buf_direct, Iovec}; use util::byte_code::ByteCode; use util::edid::EdidInfo; +use util::gen_base_func; use util::loop_context::{ read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, NotifierOperation, }; @@ -72,6 +75,49 @@ const VIRTIO_GPU_RES_WIN_FRAMEBUF: u32 = 0x80000000; const VIRTIO_GPU_RES_EFI_FRAMEBUF: u32 = 0x40000000; const VIRTIO_GPU_RES_FRAMEBUF: u32 = VIRTIO_GPU_RES_WIN_FRAMEBUF | VIRTIO_GPU_RES_EFI_FRAMEBUF; +/// The maximum number of outputs. +const VIRTIO_GPU_MAX_OUTPUTS: usize = 16; +/// The default maximum memory 256M. +const VIRTIO_GPU_DEFAULT_MAX_HOSTMEM: u64 = 0x10000000; + +#[derive(Parser, Clone, Debug, Default)] +#[command(no_binary_name(true))] +pub struct GpuDevConfig { + #[arg(long, value_parser = ["virtio-gpu-pci"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] + pub id: String, + #[arg(long)] + pub bus: String, + #[arg(long, value_parser = get_pci_df)] + pub addr: (u8, u8), + #[arg(long, alias = "max_outputs", default_value="1", value_parser = clap::value_parser!(u32).range(1..=VIRTIO_GPU_MAX_OUTPUTS as i64))] + pub max_outputs: u32, + #[arg(long, default_value="true", action = ArgAction::Append)] + pub edid: bool, + #[arg(long, default_value = "1024")] + pub xres: u32, + #[arg(long, default_value = "768")] + pub yres: u32, + // The default max_hostmem is 256M. + #[arg(long, alias = "max_hostmem", default_value="268435456", value_parser = clap::value_parser!(u64).range(1..))] + pub max_hostmem: u64, + #[arg(long, alias = "enable_bar0", default_value="false", action = ArgAction::Append)] + pub enable_bar0: bool, +} + +impl GpuDevConfig { + pub fn check(&self) { + if self.max_hostmem < VIRTIO_GPU_DEFAULT_MAX_HOSTMEM { + warn!( + "max_hostmem should >= {}, allocating less than it may cause \ + the GPU to fail to start or refresh.", + VIRTIO_GPU_DEFAULT_MAX_HOSTMEM + ); + } + } +} + #[derive(Debug)] struct GpuResource { resource_id: u32, @@ -373,13 +419,7 @@ impl VirtioGpuRequest { ); } - let mut header = VirtioGpuCtrlHdr::default(); - iov_to_buf(mem_space, &elem.out_iovec, header.as_mut_bytes()).and_then(|size| { - if size < size_of::() { - bail!("Invalid header for gpu request: len {}.", size) - } - Ok(()) - })?; + let header = iov_read_object::(mem_space, &elem.out_iovec, &None)?; // Size of out_iovec is no less than size of VirtioGpuCtrlHdr, so // it is possible to get none back. @@ -387,8 +427,8 @@ impl VirtioGpuRequest { iov_discard_front(&mut elem.out_iovec, size_of::() as u64) .unwrap_or_default(); - let (out_len, out_iovec) = gpa_hva_iovec_map(data_iovec, mem_space)?; - let (in_len, in_iovec) = gpa_hva_iovec_map(&elem.in_iovec, mem_space)?; + let (out_len, out_iovec) = gpa_hva_iovec_map(data_iovec, mem_space, &None)?; + let (in_len, in_iovec) = gpa_hva_iovec_map(&elem.in_iovec, mem_space, &None)?; // Note: in_iov and out_iov total len is no more than 1<<32, and // out_iov is more than 1, so in_len and out_len will not overflow. @@ -631,8 +671,8 @@ pub fn cal_image_hostmem(format: u32, width: u32, height: u32) -> (Option } }; let bpp = pixman_format_bpp(pixman_format as u32); - let stride = ((width as u64 * bpp as u64 + 0x1f) >> 5) * (size_of::() as u64); - match stride.checked_mul(height as u64) { + let stride = ((u64::from(width) * u64::from(bpp) + 0x1f) >> 5) * (size_of::() as u64); + match stride.checked_mul(u64::from(height)) { None => { error!( "stride * height is overflow: width {} height {} stride {} bpp {}", @@ -675,7 +715,8 @@ impl GpuIoHandler { } fn get_request(&mut self, header: &VirtioGpuRequest, req: &mut T) -> Result<()> { - iov_to_buf_direct(&header.out_iovec, 0, req.as_mut_bytes()).and_then(|size| { + // SAFETY: out_iovec is generated by address_space. + unsafe { iov_to_buf_direct(&header.out_iovec, 0, req.as_mut_bytes()) }.and_then(|size| { if size == size_of::() { Ok(()) } else { @@ -687,20 +728,14 @@ impl GpuIoHandler { fn complete_one_request(&mut self, index: u16, len: u32) -> Result<()> { let mut queue_lock = self.ctrl_queue.lock().unwrap(); - queue_lock - .vring - .add_used(&self.mem_space, index, len) - .with_context(|| { - format!( - "Failed to add used ring(gpu ctrl), index {}, len {}", - index, len, - ) - })?; + queue_lock.vring.add_used(index, len).with_context(|| { + format!( + "Failed to add used ring(gpu ctrl), index {}, len {}", + index, len, + ) + })?; - if queue_lock - .vring - .should_notify(&self.mem_space, self.driver_features) - { + if queue_lock.vring.should_notify(self.driver_features) { (self.interrupt_cb)(&VirtioInterruptType::Vring, Some(&queue_lock), false) .with_context(|| "Failed to trigger interrupt(gpu ctrl)")?; trace::virtqueue_send_interrupt("Gpu", &*queue_lock as *const _ as u64); @@ -721,7 +756,8 @@ impl GpuIoHandler { header.ctx_id = req.header.ctx_id; } - let len = iov_from_buf_direct(&req.in_iovec, resp.as_bytes())?; + // SAFETY: in_iovec is generated by address_space. + let len = unsafe { iov_from_buf_direct(&req.in_iovec, resp.as_bytes())? }; if len != size_of::() { error!( "GuestError: An incomplete response will be used instead of the expected: expected \ @@ -758,6 +794,17 @@ impl GpuIoHandler { let scanout = &mut self.scanouts[scanout_id]; display_replace_surface(&scanout.con, None) .unwrap_or_else(|e| error!("Error occurs during surface switching: {:?}", e)); + + let mouse = DisplayMouse { + height: DEFAULT_CURSOR_WIDTH as u32, + width: DEFAULT_CURSOR_HEIGHT as u32, + hot_x: 0, + hot_y: 0, + data: vec![0_u8; DEFAULT_CURSOR_WIDTH * DEFAULT_CURSOR_HEIGHT * DEFAULT_CURSOR_BPP], + }; + display_cursor_define(&scanout.con, &mouse) + .unwrap_or_else(|e| error!("Error occurs during display_cursor_define: {:?}", e)); + scanout.clear(); } @@ -1133,7 +1180,7 @@ impl GpuIoHandler { } let pixman_format = get_image_format(res.pixman_image); - let bpp = (pixman_format_bpp(pixman_format as u32) as u32 + 8 - 1) / 8; + let bpp = (u32::from(pixman_format_bpp(pixman_format as u32)) + 8 - 1) / 8; let pixman_stride = get_image_stride(res.pixman_image); let offset = info_set_scanout.rect.x_coord * bpp + info_set_scanout.rect.y_coord * pixman_stride as u32; @@ -1182,6 +1229,12 @@ impl GpuIoHandler { scanout.width = info_set_scanout.rect.width; scanout.height = info_set_scanout.rect.height; + if (self.driver_features & (1 << VIRTIO_GPU_F_EDID)) == 0 + && (info_set_scanout.resource_id & VIRTIO_GPU_RES_WIN_FRAMEBUF) != 0 + { + self.change_run_stage()?; + } + self.response_nodata(VIRTIO_GPU_RESP_OK_NODATA, req) } @@ -1256,10 +1309,10 @@ impl GpuIoHandler { let extents = pixman_region_extents(final_reg_ptr); display_graphic_update( &scanout.con, - (*extents).x1 as i32, - (*extents).y1 as i32, - ((*extents).x2 - (*extents).x1) as i32, - ((*extents).y2 - (*extents).y1) as i32, + i32::from((*extents).x1), + i32::from((*extents).y1), + i32::from((*extents).x2 - (*extents).x1), + i32::from((*extents).y2 - (*extents).y1), )?; pixman_region_fini(rect_reg_ptr); pixman_region_fini(final_reg_ptr); @@ -1314,12 +1367,13 @@ impl GpuIoHandler { let res = &mut self.resources_list[res_idx]; let pixman_format = get_image_format(res.pixman_image); let width = get_image_width(res.pixman_image) as u32; - let bpp = (pixman_format_bpp(pixman_format as u32) as u32 + 8 - 1) / 8; + let bpp = (u32::from(pixman_format_bpp(pixman_format as u32)) + 8 - 1) / 8; let stride = get_image_stride(res.pixman_image) as u32; let data = get_image_data(res.pixman_image).cast() as *mut u8; if res.format == VIRTIO_GPU_FORMAT_MONOCHROME { - let v = iov_to_buf_direct(&res.iov, 0, &mut res.monochrome_cursor)?; + // SAFETY: iov is generated by address_space. + let v = unsafe { iov_to_buf_direct(&res.iov, 0, &mut res.monochrome_cursor)? }; if v != res.monochrome_cursor.len() { error!("No enough data is copied for transfer_to_host_2d with monochrome"); } @@ -1332,7 +1386,8 @@ impl GpuIoHandler { let trans_size = (trans_info.rect.height * stride) as usize; // SAFETY: offset_dst and trans_size do not exceeds data size. let dst = unsafe { from_raw_parts_mut(data.add(offset_dst), trans_size) }; - iov_to_buf_direct(&res.iov, trans_info.offset, dst).map(|v| { + // SAFETY: iov is generated by address_space. + unsafe { iov_to_buf_direct(&res.iov, trans_info.offset, dst) }.map(|v| { if v < trans_size { warn!("No enough data is copied for transfer_to_host_2d"); } @@ -1349,7 +1404,8 @@ impl GpuIoHandler { for _ in 0..trans_info.rect.height { // SAFETY: offset_dst and line_size do not exceeds data size. let dst = unsafe { from_raw_parts_mut(data.add(offset_dst), line_size) }; - iov_to_buf_direct(&res.iov, offset_src as u64, dst).map(|v| { + // SAFETY: iov is generated by address_space. + unsafe { iov_to_buf_direct(&res.iov, offset_src as u64, dst) }.map(|v| { if v < line_size { warn!("No enough data is copied for transfer_to_host_2d"); } @@ -1407,9 +1463,9 @@ impl GpuIoHandler { } let entries = info_attach_backing.nr_entries; - let ents_size = size_of::() as u64 * entries as u64; + let ents_size = size_of::() as u64 * u64::from(entries); let head_size = size_of::() as u64; - if (req.out_len as u64) < (ents_size + head_size) { + if u64::from(req.out_len) < (ents_size + head_size) { error!( "GuestError: The nr_entries {} in resource attach backing request is larger than total len {}.", info_attach_backing.nr_entries, req.out_len, @@ -1424,7 +1480,8 @@ impl GpuIoHandler { let ents_buf = // SAFETY: ents is guaranteed not be null and the range of ents_size has been limited. unsafe { from_raw_parts_mut(ents.as_mut_ptr() as *mut u8, ents_size as usize) }; - let v = iov_to_buf_direct(&req.out_iovec, head_size, ents_buf)?; + // SAFETY: out_iovec is generated by address_space. + let v = unsafe { iov_to_buf_direct(&req.out_iovec, head_size, ents_buf)? }; if v as u64 != ents_size { error!( "Virtio-GPU: Load no enough ents buf when attach backing, {} vs {}", @@ -1440,7 +1497,7 @@ impl GpuIoHandler { len: ent.length, }); } - match gpa_hva_iovec_map(&elemiovec, &self.mem_space) { + match gpa_hva_iovec_map(&elemiovec, &self.mem_space, &None) { Ok((_, iov)) => { res.iov = iov; self.response_nodata(VIRTIO_GPU_RESP_OK_NODATA, req) @@ -1552,17 +1609,11 @@ impl GpuIoHandler { } }; - queue - .vring - .add_used(&self.mem_space, elem.index, 0) - .with_context(|| { - format!("Failed to add used ring(cursor), index {}", elem.index) - })?; + queue.vring.add_used(elem.index, 0).with_context(|| { + format!("Failed to add used ring(cursor), index {}", elem.index) + })?; - if queue - .vring - .should_notify(&self.mem_space, self.driver_features) - { + if queue.vring.should_notify(self.driver_features) { (self.interrupt_cb)(&VirtioInterruptType::Vring, Some(&queue), false) .with_context(|| { VirtioError::InterruptTrigger("gpu cursor", VirtioInterruptType::Vring) @@ -1683,13 +1734,7 @@ impl Gpu { } impl VirtioDevice for Gpu { - fn virtio_base(&self) -> &VirtioBase { - &self.base - } - - fn virtio_base_mut(&mut self) -> &mut VirtioBase { - &mut self.base - } + gen_base_func!(virtio_base, virtio_base_mut, VirtioBase, base); fn device_quirk(&self) -> Option { if self.cfg.enable_bar0 { @@ -1843,3 +1888,50 @@ impl VirtioDevice for Gpu { result } } + +#[cfg(test)] +mod tests { + use super::*; + use machine_manager::config::str_slip_to_clap; + + #[test] + fn test_parse_virtio_gpu_pci_cmdline() { + // Test1: Right. + let gpu_cmd = "virtio-gpu-pci,id=gpu_1,bus=pcie.0,addr=0x4.0x0,max_outputs=5,edid=false,\ + xres=2048,yres=800,enable_bar0=true,max_hostmem=268435457"; + let gpu_cfg = GpuDevConfig::try_parse_from(str_slip_to_clap(gpu_cmd, true, false)).unwrap(); + assert_eq!(gpu_cfg.id, "gpu_1"); + assert_eq!(gpu_cfg.bus, "pcie.0"); + assert_eq!(gpu_cfg.addr, (4, 0)); + assert_eq!(gpu_cfg.max_outputs, 5); + assert_eq!(gpu_cfg.xres, 2048); + assert_eq!(gpu_cfg.yres, 800); + assert!(!gpu_cfg.edid); + assert_eq!(gpu_cfg.max_hostmem, 268435457); + assert!(gpu_cfg.enable_bar0); + + // Test2: Default. + let gpu_cmd2 = "virtio-gpu-pci,id=gpu_1,bus=pcie.0,addr=0x4.0x0"; + let gpu_cfg = + GpuDevConfig::try_parse_from(str_slip_to_clap(gpu_cmd2, true, false)).unwrap(); + assert_eq!(gpu_cfg.max_outputs, 1); + assert_eq!(gpu_cfg.xres, 1024); + assert_eq!(gpu_cfg.yres, 768); + assert!(gpu_cfg.edid); + assert_eq!(gpu_cfg.max_hostmem, VIRTIO_GPU_DEFAULT_MAX_HOSTMEM); + assert!(!gpu_cfg.enable_bar0); + + // Test3/4: max_outputs is illegal. + let gpu_cmd3 = "virtio-gpu-pci,id=gpu_1,bus=pcie.0,addr=0x4.0x0,max_outputs=17"; + let result = GpuDevConfig::try_parse_from(str_slip_to_clap(gpu_cmd3, true, false)); + assert!(result.is_err()); + let gpu_cmd4 = "virtio-gpu-pci,id=gpu_1,bus=pcie.0,addr=0x4.0x0,max_outputs=0"; + let result = GpuDevConfig::try_parse_from(str_slip_to_clap(gpu_cmd4, true, false)); + assert!(result.is_err()); + + // Test5: max_hostmem is illegal. + let gpu_cmd5 = "virtio-gpu-pci,id=gpu_1,bus=pcie.0,addr=0x4.0x0,max_hostmem=0"; + let result = GpuDevConfig::try_parse_from(str_slip_to_clap(gpu_cmd5, true, false)); + assert!(result.is_err()); + } +} diff --git a/virtio/src/device/input.rs b/virtio/src/device/input.rs new file mode 100644 index 0000000000000000000000000000000000000000..721b4e3df9692e5c9a83cc65c97ec98086ff0c7d --- /dev/null +++ b/virtio/src/device/input.rs @@ -0,0 +1,634 @@ +// Copyright (c) 2025 Huawei Technologies Co.,Ltd. All rights reserved. +// +// StratoVirt is licensed under Mulan PSL v2. +// You can use this software according to the terms and conditions of the Mulan +// PSL v2. +// You may obtain a copy of Mulan PSL v2 at: +// http://license.coscl.org.cn/MulanPSL2 +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +// See the Mulan PSL v2 for more details. + +use std::collections::BTreeMap; +use std::fs::{File, OpenOptions}; +use std::io::{Read, Write}; +use std::mem::size_of; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::prelude::OpenOptionsExt; +use std::rc::Rc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; + +use anyhow::{anyhow, bail, Context, Result}; +use clap::{ArgAction, Parser}; +use libc::c_int; +use log::{error, warn}; +use vmm_sys_util::{epoll::EventSet, eventfd::EventFd}; + +use crate::{ + check_config_space_rw, error::*, iov_read_object, read_config_default, report_virtio_error, + Queue, VirtioBase, VirtioDevice, VirtioInterrupt, VirtioInterruptType, VIRTIO_F_VERSION_1, + VIRTIO_TYPE_INPUT, +}; +use address_space::AddressSpace; +use machine_manager::{ + config::{get_pci_df, parse_bool, valid_id, DEFAULT_VIRTQUEUE_SIZE}, + event_loop::{register_event_helper, unregister_event_helper}, +}; +use util::byte_code::ByteCode; +use util::evdev::*; +use util::loop_context::{ + read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, NotifierOperation, +}; + +/// Unset select cfg. +const VIRTIO_INPUT_CFG_UNSET: u8 = 0x00; +/// Returns the name of the device +const VIRTIO_INPUT_CFG_ID_NAME: u8 = 0x01; +/// Returns the serial number of the device. +const VIRTIO_INPUT_CFG_ID_SERIAL: u8 = 0x02; +/// Returns ID information of the device. +const VIRTIO_INPUT_CFG_ID_DEVIDS: u8 = 0x03; +/// Returns input properties of the device. +const VIRTIO_INPUT_CFG_PROP_BITS: u8 = 0x10; +/// subsel specifies the event type using EV_* constants in the underlying evdev implementation. +const VIRTIO_INPUT_CFG_EV_BITS: u8 = 0x11; +/// subsel specifies the absolute axis using ABS_* constants in the underlying evdev implementation. +const VIRTIO_INPUT_CFG_ABS_INFO: u8 = 0x12; + +/// Number of virtqueues. +const QUEUE_NUM_INPUT: usize = 2; + +#[derive(Parser, Debug, Clone, Default)] +#[command(no_binary_name(true))] +pub struct InputConfig { + #[arg(long, value_parser = ["virtio-input-device", "virtio-input-pci"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] + pub id: String, + #[arg(long)] + pub bus: Option, + #[arg(long, value_parser=get_pci_df)] + pub addr: Option<(u8, u8)>, + #[arg(long, value_parser=parse_bool, action = ArgAction::Append)] + pub multifunction: Option, + #[arg(long)] + pub evdev: String, +} + +#[derive(Copy, Clone, Default)] +#[repr(C)] +struct virtio_input_device_ids { + bustype: [u8; size_of::()], + vendor: [u8; size_of::()], + product: [u8; size_of::()], + version: [u8; size_of::()], +} + +impl virtio_input_device_ids { + fn from_evdevid(evdev_id: EvdevId) -> Self { + Self { + bustype: evdev_id.bustype.to_le_bytes(), + vendor: evdev_id.vendor.to_le_bytes(), + product: evdev_id.product.to_le_bytes(), + version: evdev_id.version.to_le_bytes(), + } + } +} + +impl ByteCode for virtio_input_device_ids {} + +#[derive(Copy, Clone, Default)] +#[repr(C)] +struct VirtioInputAbsInfo { + min: [u8; size_of::()], + max: [u8; size_of::()], + fuzz: [u8; size_of::()], + flat: [u8; size_of::()], +} + +impl VirtioInputAbsInfo { + fn from_absinfo(absinfo: InputAbsInfo) -> Self { + Self { + min: absinfo.minimum.to_le_bytes(), + max: absinfo.maximum.to_le_bytes(), + fuzz: absinfo.fuzz.to_le_bytes(), + flat: absinfo.flat.to_le_bytes(), + } + } +} + +impl ByteCode for VirtioInputAbsInfo {} + +#[repr(C)] +#[derive(Copy, Clone)] +struct VirtioInputConfig { + select: u8, + subsel: u8, + size: u8, + reserved: [u8; 5], + payload: [u8; VIRTIO_INPUT_CFG_PAYLOAD_SIZE], +} + +impl VirtioInputConfig { + fn new() -> Self { + Self { + select: VIRTIO_INPUT_CFG_UNSET, + subsel: 0, + size: 0, + reserved: [0_u8; 5], + payload: [0_u8; VIRTIO_INPUT_CFG_PAYLOAD_SIZE], + } + } + + fn set_payload(&mut self, payload: &[u8]) { + let len = (&mut self.payload[..]).write(payload).unwrap(); + self.size = len as u8; + } +} + +impl Default for VirtioInputConfig { + fn default() -> Self { + Self::new() + } +} + +impl ByteCode for VirtioInputConfig {} + +#[repr(C)] +#[derive(Copy, Clone, Default)] +struct VirtioInputEvent { + ev_type: [u8; size_of::()], + code: [u8; size_of::()], + value: [u8; size_of::()], +} + +impl VirtioInputEvent { + fn to_evt(self) -> InputEvent { + use byteorder::{ByteOrder, LittleEndian}; + InputEvent { + timestamp: [0_u64, 2], + ev_type: LittleEndian::read_u16(&self.ev_type), + code: LittleEndian::read_u16(&self.code), + value: LittleEndian::read_i32(&self.value), + } + } + + fn from_evt(evt: &InputEvent) -> Self { + Self { + ev_type: evt.ev_type.to_le_bytes(), + code: evt.code.to_le_bytes(), + value: evt.value.to_le_bytes(), + } + } +} + +impl ByteCode for VirtioInputEvent {} + +struct EvdevConfig { + /// config select + select: u8, + /// config sub select + subsel: u8, + /// ID information of the device + device_ids: virtio_input_device_ids, + /// Name of the device + name: Vec, + /// Serial of the device + serial: Vec, + /// Properties of the device + properties: EvdevBuf, + /// Events supported of the device + event_supported: BTreeMap, + /// Axis information of the device + abs_info: BTreeMap, +} + +impl EvdevConfig { + fn new(fd: &File) -> Result { + if evdev_ioctl(fd, EVIOCGVERSION(), size_of::()).len == 0 { + bail!("It's not an evdev device"); + } + + let id = EvdevId::from_buf(evdev_ioctl(fd, EVIOCGID(), size_of::())); + Ok(Self { + select: VIRTIO_INPUT_CFG_UNSET, + subsel: 0, + device_ids: virtio_input_device_ids::from_evdevid(id), + name: evdev_ioctl(fd, EVIOCGNAME(), 0).to_vec(), + serial: evdev_ioctl(fd, EVIOCGUNIQ(), 0).to_vec(), + properties: evdev_ioctl(fd, EVIOCGPROP(), 0), + event_supported: evdev_evt_supported(fd)?, + abs_info: evdev_abs(fd)?, + }) + } + + fn get_device_config(&self) -> VirtioInputConfig { + let mut cfg = VirtioInputConfig { + select: self.select, + subsel: self.subsel, + ..Default::default() + }; + + match self.select { + VIRTIO_INPUT_CFG_ID_NAME => { + cfg.set_payload(self.name.as_slice()); + } + VIRTIO_INPUT_CFG_ID_SERIAL => { + cfg.set_payload(self.serial.as_slice()); + } + VIRTIO_INPUT_CFG_ID_DEVIDS => { + cfg.set_payload(self.device_ids.as_bytes()); + } + VIRTIO_INPUT_CFG_PROP_BITS => { + cfg.set_payload(self.properties.to_vec().as_slice()); + } + VIRTIO_INPUT_CFG_EV_BITS => { + if let Some(bitmap) = self.event_supported.get(&self.subsel) { + cfg.set_payload(bitmap.as_bytes()); + } + } + VIRTIO_INPUT_CFG_ABS_INFO => { + if let Some(absinfo) = self.abs_info.get(&self.subsel) { + cfg.set_payload(VirtioInputAbsInfo::from_absinfo(*absinfo).as_bytes()); + } + } + VIRTIO_INPUT_CFG_UNSET => {} + _ => { + log::warn!("select type {} is not supported", self.select); + } + } + cfg + } +} + +struct InputIoHandler { + /// The features of driver + driver_features: u64, + /// Address space + mem_space: Arc, + /// event queue. + event_queue: Arc>, + /// event queue EventFd + event_queue_evt: Arc, + /// status queue + status_queue: Arc>, + /// status queue EventFd + status_queue_evt: Arc, + /// Used to cache events + event_buf: Vec, + /// Device is broken or not + device_broken: Arc, + /// The interrupt call back function. + interrupt_cb: Arc, + /// fd of the evdev file + evdev_fd: Option>, +} + +impl InputIoHandler { + fn process_status_queue(&mut self) -> Result<()> { + let mut locked_status_queue = self.status_queue.lock().unwrap(); + loop { + let elem = locked_status_queue + .vring + .pop_avail(&self.mem_space, self.driver_features) + .with_context(|| "Failed to pop avail ring for process input status queue")?; + if elem.desc_num == 0 { + break; + } + let evt = iov_read_object::( + &self.mem_space, + &elem.out_iovec, + locked_status_queue.vring.get_cache(), + )? + .to_evt(); + match &self.evdev_fd.clone() { + Some(evdev_fd) => { + let _ = evdev_fd.as_ref().write(evt.as_bytes()); + } + None => {} + } + locked_status_queue + .vring + .add_used(elem.index, 0) + .with_context(|| { + format!( + "Failed to add input response into used status queue, index {}, len {}", + elem.index, 0 + ) + })?; + (self.interrupt_cb)( + &VirtioInterruptType::Vring, + Some(&locked_status_queue), + false, + ) + .with_context(|| VirtioError::InterruptTrigger("Input", VirtioInterruptType::Vring))?; + } + Ok(()) + } + + fn input_event_send(&mut self, evt: &InputEvent) -> Result<()> { + let mut locked_event_queue = self.event_queue.lock().unwrap(); + self.event_buf.push(VirtioInputEvent::from_evt(evt)); + if evt.ev_type != EV_SYN || evt.code != SYN_REPORT { + return Ok(()); + } + let mut event_index_list = Vec::new(); + for event in self.event_buf.iter() { + let elem = locked_event_queue + .vring + .pop_avail(&self.mem_space, self.driver_features) + .with_context(|| "Failed to pop avail ring for process input queue")?; + if elem.desc_num == 0 { + warn!("event queue buffer is full, drop current events"); + for _ in event_index_list.iter() { + locked_event_queue.vring.push_back(); + } + self.event_buf.clear(); + return Ok(()); + } + self.mem_space.write_object( + event, + elem.in_iovec[0].addr, + address_space::AddressAttr::Ram, + )?; + event_index_list.push(elem.index); + } + for index in event_index_list.iter() { + locked_event_queue + .vring + .add_used(*index, size_of::() as u32) + .with_context(|| "Failed to add input response into used queue")?; + } + (self.interrupt_cb)( + &VirtioInterruptType::Vring, + Some(&locked_event_queue), + false, + ) + .with_context(|| VirtioError::InterruptTrigger("input", VirtioInterruptType::Vring))?; + self.event_buf.clear(); + Ok(()) + } + + fn do_event(&mut self) { + let event_fd = &self.evdev_fd.clone().unwrap(); + loop { + let mut evt = InputEvent::default(); + match event_fd.as_ref().read(evt.as_mut_bytes()) { + Ok(sz) => { + if sz != size_of::() { + warn!("mismatch InputEvent length"); + return; + } + if let Err(e) = self.input_event_send(&evt) { + error!("Failed to send event: {:?}", e); + report_virtio_error( + self.interrupt_cb.clone(), + self.driver_features, + &self.device_broken, + ); + return; + } + } + Err(e) => { + error!("Failed to read event from evdev_fd: {:?}", e); + return; + } + } + } + } +} + +/// Create a new EventNotifier. +/// +/// # Arguments +/// +/// * `fd` - Raw file descriptor. +/// * `handler` - Handle function. +fn build_event_notifier(fd: RawFd, handler: Rc) -> EventNotifier { + EventNotifier::new( + NotifierOperation::AddShared, + fd, + None, + EventSet::IN, + vec![handler], + ) +} + +impl EventNotifierHelper for InputIoHandler { + fn internal_notifiers(input: Arc>) -> Vec { + let mut notifiers = Vec::new(); + let locked_input = input.lock().unwrap(); + // register event notifier for event queue. + let handler: Rc = Rc::new(move |_, fd: RawFd| { + read_fd(fd); + // Do nothing. + None + }); + notifiers.push(build_event_notifier( + locked_input.event_queue_evt.as_raw_fd(), + handler, + )); + // register event notifier for status queue. + let local_input = input.clone(); + let handler: Rc = Rc::new(move |_, fd: RawFd| { + read_fd(fd); + let mut locked_local_input = local_input.lock().unwrap(); + if locked_local_input.device_broken.load(Ordering::SeqCst) { + return None; + } + if locked_local_input.process_status_queue().is_err() { + report_virtio_error( + locked_local_input.interrupt_cb.clone(), + locked_local_input.driver_features, + &locked_local_input.device_broken, + ); + }; + None + }); + notifiers.push(build_event_notifier( + locked_input.status_queue_evt.as_raw_fd(), + handler, + )); + + // register evdev fd handler + if let Some(fd) = &locked_input.evdev_fd { + let local_input = input.clone(); + let handler: Rc = Rc::new(move |_, _| { + let mut locked_local_input = local_input.lock().unwrap(); + if locked_local_input.device_broken.load(Ordering::SeqCst) { + // The virtio-input device has broken, drop event + let event_fd = &locked_local_input.evdev_fd.clone().unwrap(); + let mut evt = InputEvent::default(); + let _ = event_fd.as_ref().read(evt.as_mut_bytes()); + return None; + } + locked_local_input.do_event(); + None + }); + notifiers.push(build_event_notifier(fd.as_raw_fd(), handler)); + }; + notifiers + } +} + +pub struct Input { + /// Virtio device base property. + base: VirtioBase, + /// Interrupt callback function. + interrupt_cb: Option>, + /// Input device config data. + evdev_cfg: EvdevConfig, + /// EventFd for device deactivate. + deactivate_evts: Vec, + /// Event file fd. + fd: Option>, +} + +impl Input { + pub fn new(option: InputConfig) -> Result { + let fd = OpenOptions::new() + .read(true) + .write(true) + .custom_flags(libc::O_NONBLOCK) + .open(option.evdev.clone()) + .with_context(|| { + format!( + "Open evdev {} failed({:?})", + option.evdev, + std::io::Error::last_os_error() + ) + })?; + let evdev_cfg = EvdevConfig::new(&fd)?; + Ok(Self { + base: VirtioBase::new(VIRTIO_TYPE_INPUT, QUEUE_NUM_INPUT, DEFAULT_VIRTQUEUE_SIZE), + interrupt_cb: None, + evdev_cfg, + deactivate_evts: Vec::new(), + fd: Some(Arc::new(fd)), + }) + } +} + +impl VirtioDevice for Input { + fn virtio_base(&self) -> &VirtioBase { + &self.base + } + + fn virtio_base_mut(&mut self) -> &mut VirtioBase { + &mut self.base + } + + fn realize(&mut self) -> Result<()> { + self.init_config_features() + } + + fn init_config_features(&mut self) -> Result<()> { + self.base.device_features = 1u64 << VIRTIO_F_VERSION_1; + Ok(()) + } + + fn read_config(&self, offset: u64, data: &mut [u8]) -> Result<()> { + let config = self.evdev_cfg.get_device_config(); + read_config_default(config.as_bytes(), offset, data) + } + + fn write_config(&mut self, offset: u64, data: &[u8]) -> Result<()> { + let mut config = self.evdev_cfg.get_device_config(); + let config_slice = config.as_mut_bytes(); + check_config_space_rw(config_slice, offset, data)?; + config_slice[(offset as usize)..(offset as usize + data.len())].copy_from_slice(data); + + self.evdev_cfg.select = config.select; + self.evdev_cfg.subsel = config.subsel; + Ok(()) + } + + fn activate( + &mut self, + mem_space: Arc, + interrupt_cb: Arc, + queue_evts: Vec>, + ) -> Result<()> { + let queues = &self.base.queues; + if queues.len() != self.queue_num() { + return Err(anyhow!(VirtioError::IncorrectQueueNum( + self.queue_num(), + queues.len() + ))); + } + + let event_queue = queues[0].clone(); + let event_queue_evt = queue_evts[0].clone(); + let status_queue = queues[1].clone(); + let status_queue_evt = queue_evts[1].clone(); + + self.interrupt_cb = Some(interrupt_cb.clone()); + let handler = InputIoHandler { + driver_features: self.base.driver_features, + mem_space, + event_queue, + event_queue_evt, + status_queue, + status_queue_evt, + event_buf: Vec::new(), + device_broken: self.base.broken.clone(), + interrupt_cb: interrupt_cb.clone(), + evdev_fd: self.fd.clone(), + }; + register_event_helper( + EventNotifierHelper::internal_notifiers(Arc::new(Mutex::new(handler))), + None, + &mut self.deactivate_evts, + ) + .with_context(|| "Failed to register input handler to Mainloop")?; + self.base.broken.store(false, Ordering::SeqCst); + Ok(()) + } + + fn deactivate(&mut self) -> Result<()> { + unregister_event_helper(None, &mut self.deactivate_evts) + } + + fn reset(&mut self) -> Result<()> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use machine_manager::config::str_slip_to_clap; + + #[test] + fn test_input_config_cmdline_parse() { + // Test1: virtio-input-device(mmio). + let input_cmd = "virtio-input-device,id=input0,evdev=/dev/input/event0"; + let input_config = + InputConfig::try_parse_from(str_slip_to_clap(input_cmd, true, false)).unwrap(); + assert_eq!(input_config.multifunction, None); + + // Test2: virtio-input-pci. + let input_cmd = "virtio-input-pci,bus=pcie.0,addr=0x1,id=input0,evdev=/dev/input/event0"; + let input_config = + InputConfig::try_parse_from(str_slip_to_clap(input_cmd, true, false)).unwrap(); + assert_eq!(input_config.bus.unwrap(), "pcie.0"); + assert_eq!(input_config.addr.unwrap(), (1, 0)); + assert_eq!(input_config.evdev, "/dev/input/event0"); + } + + #[test] + fn test_input_init() { + let input_config = InputConfig { + classtype: "virtio-input-pci".to_string(), + id: "input0".to_string(), + evdev: "/evdev/path".to_string(), + bus: Some("pcie.0".to_string()), + addr: Some((3, 0)), + ..Default::default() + }; + let input = Input::new(input_config); + assert!(input.is_err()); + } +} diff --git a/virtio/src/device/mod.rs b/virtio/src/device/mod.rs index 11ea8ed939819b24ac42233a3eaf260d3df21173..f8914b3a9974b28a0028be294763eb8e55f1e90a 100644 --- a/virtio/src/device/mod.rs +++ b/virtio/src/device/mod.rs @@ -14,7 +14,10 @@ pub mod balloon; pub mod block; #[cfg(feature = "virtio_gpu")] pub mod gpu; +pub mod input; pub mod net; +#[cfg(feature = "virtio_rng")] pub mod rng; +#[cfg(feature = "virtio_scsi")] pub mod scsi_cntlr; pub mod serial; diff --git a/virtio/src/device/net.rs b/virtio/src/device/net.rs index 4e605faa84bf55d21ef6e29f5f4a63a6a267618e..4d74508977aaeb3a09767d5956d76322eaa7a97c 100644 --- a/virtio/src/device/net.rs +++ b/virtio/src/device/net.rs @@ -11,26 +11,26 @@ // See the Mulan PSL v2 for more details. use std::collections::HashMap; -use std::io::ErrorKind; use std::os::unix::io::{AsRawFd, RawFd}; use std::path::Path; use std::rc::Rc; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::{channel, Receiver, Sender}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use std::{cmp, fs, mem}; use anyhow::{bail, Context, Result}; use byteorder::{ByteOrder, LittleEndian}; use log::{error, warn}; use once_cell::sync::Lazy; +use util::aio::Iovec; use vmm_sys_util::{epoll::EventSet, eventfd::EventFd}; use crate::{ - check_config_space_rw, iov_discard_front, iov_to_buf, mem_to_buf, read_config_default, - report_virtio_error, virtio_has_feature, ElemIovec, Element, Queue, VirtioBase, VirtioDevice, - VirtioError, VirtioInterrupt, VirtioInterruptType, VirtioNetHdr, VIRTIO_F_RING_EVENT_IDX, - VIRTIO_F_RING_INDIRECT_DESC, VIRTIO_F_VERSION_1, VIRTIO_NET_CTRL_MAC, + check_config_space_rw, gpa_hva_iovec_map, iov_discard_front, iov_to_buf, mem_to_buf, + read_config_default, report_virtio_error, virtio_has_feature, ElemIovec, Element, Queue, + VirtioBase, VirtioDevice, VirtioError, VirtioInterrupt, VirtioInterruptType, VirtioNetHdr, + VIRTIO_F_RING_EVENT_IDX, VIRTIO_F_RING_INDIRECT_DESC, VIRTIO_F_VERSION_1, VIRTIO_NET_CTRL_MAC, VIRTIO_NET_CTRL_MAC_ADDR_SET, VIRTIO_NET_CTRL_MAC_TABLE_SET, VIRTIO_NET_CTRL_MQ, VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MAX, VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MIN, VIRTIO_NET_CTRL_MQ_VQ_PAIRS_SET, VIRTIO_NET_CTRL_RX, VIRTIO_NET_CTRL_RX_ALLMULTI, @@ -43,11 +43,11 @@ use crate::{ VIRTIO_NET_F_HOST_TSO4, VIRTIO_NET_F_HOST_TSO6, VIRTIO_NET_F_HOST_UFO, VIRTIO_NET_F_MAC, VIRTIO_NET_F_MQ, VIRTIO_NET_OK, VIRTIO_TYPE_NET, }; -use address_space::{AddressSpace, RegionCache}; -use machine_manager::event_loop::{register_event_helper, unregister_event_helper}; -use machine_manager::{ - config::{ConfigCheck, NetworkInterfaceConfig}, - event_loop::EventLoop, +use address_space::{AddressAttr, AddressSpace}; +use machine_manager::config::{ConfigCheck, NetDevcfg, NetworkInterfaceConfig}; +use machine_manager::event_loop::{register_event_helper, unregister_event_helper, EventLoop}; +use machine_manager::state_query::{ + register_state_query_callback, unregister_state_query_callback, }; use migration::{ migration::Migratable, DeviceStateDesc, FieldDesc, MigrationHook, MigrationManager, @@ -55,9 +55,10 @@ use migration::{ }; use migration_derive::{ByteCode, Desc}; use util::byte_code::ByteCode; -use util::loop_context::gen_delete_notifiers; +use util::gen_base_func; use util::loop_context::{ - read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, NotifierOperation, + create_new_eventfd, read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, + NotifierOperation, }; use util::num_ops::str_to_num; use util::tap::{ @@ -222,7 +223,7 @@ impl CtrlInfo { data_iovec: &mut Vec, ) -> Result { let ack = VIRTIO_NET_OK; - let mut mac_table_len = 0; + let mut mac_table_len: usize = 0; // Default for unicast. let mut overflow = &mut self.mac_info.uni_mac_of; let mut mac_table = &mut self.mac_info.uni_mac_table; @@ -243,7 +244,7 @@ impl CtrlInfo { continue; } - let size = entries as u64 * MAC_ADDR_LEN as u64; + let size = u64::from(entries) * MAC_ADDR_LEN as u64; let res_len = Element::iovec_size(data_iovec); if size > res_len { bail!("Invalid request for setting mac table."); @@ -365,7 +366,7 @@ impl CtrlInfo { data_iovec: &mut Vec, ) -> u8 { let mut ack = VIRTIO_NET_OK; - if cmd as u16 == VIRTIO_NET_CTRL_MQ_VQ_PAIRS_SET { + if u16::from(cmd) == VIRTIO_NET_CTRL_MQ_VQ_PAIRS_SET { let mut queue_pairs: u16 = 0; *data_iovec = get_buf_and_discard(mem_space, data_iovec, queue_pairs.as_mut_bytes()) .unwrap_or_else(|e| { @@ -467,7 +468,7 @@ fn get_buf_and_discard( iovec: &mut [ElemIovec], buf: &mut [u8], ) -> Result> { - iov_to_buf(mem_space, iovec, buf).and_then(|size| { + iov_to_buf(mem_space, &None, iovec, buf).and_then(|size| { if size < buf.len() { error!("Invalid length {}, expected length {}", size, buf.len()); bail!("Invalid length {}, expected length {}", size, buf.len()); @@ -613,19 +614,17 @@ impl NetCtrlHandler { // Write result to the device writable iovec. let status = elem .in_iovec - .get(0) + .first() .with_context(|| "Failed to get device writable iovec")?; - self.mem_space.write_object::(&ack, status.addr)?; + self.mem_space + .write_object::(&ack, status.addr, AddressAttr::Ram)?; locked_queue .vring - .add_used(&self.mem_space, elem.index, mem::size_of_val(&ack) as u32) + .add_used(elem.index, mem::size_of_val(&ack) as u32) .with_context(|| format!("Failed to add used ring {}", elem.index))?; - if locked_queue - .vring - .should_notify(&self.mem_space, self.driver_features) - { + if locked_queue.vring.should_notify(self.driver_features) { (self.interrupt_cb)(&VirtioInterruptType::Vring, Some(&locked_queue), false) .with_context(|| { VirtioError::InterruptTrigger("ctrl", VirtioInterruptType::Vring) @@ -671,149 +670,75 @@ impl EventNotifierHelper for NetCtrlHandler { } } -struct TxVirtio { +struct RTxVirtio { queue: Arc>, queue_evt: Arc, } -impl TxVirtio { +impl RTxVirtio { fn new(queue: Arc>, queue_evt: Arc) -> Self { TxVirtio { queue, queue_evt } } } -struct RxVirtio { - queue_full: bool, - queue: Arc>, - queue_evt: Arc, - recv_evt: Arc, -} +type RxVirtio = RTxVirtio; +type TxVirtio = RTxVirtio; -impl RxVirtio { - fn new(queue: Arc>, queue_evt: Arc, recv_evt: Arc) -> Self { - RxVirtio { - queue_full: false, - queue, - queue_evt, - recv_evt, - } - } -} - -struct NetIoHandler { +struct NetIoQueue { rx: RxVirtio, tx: TxVirtio, - tap: Option, - tap_fd: RawFd, + ctrl_info: Arc>, mem_space: Arc, interrupt_cb: Arc, + listen_state: Arc>, driver_features: u64, - receiver: Receiver, - update_evt: Arc, - device_broken: Arc, - is_listening: bool, - ctrl_info: Arc>, queue_size: u16, } -impl NetIoHandler { - fn read_from_tap(iovecs: &[libc::iovec], tap: &mut Tap) -> i32 { - // SAFETY: the arguments of readv has been checked and is correct. - let size = unsafe { - libc::readv( - tap.as_raw_fd() as libc::c_int, - iovecs.as_ptr() as *const libc::iovec, - iovecs.len() as libc::c_int, - ) - } as i32; - if size < 0 { - let e = std::io::Error::last_os_error(); - if e.kind() == std::io::ErrorKind::WouldBlock { - return size; - } - - // If the backend tap device is removed, readv returns less than 0. - // At this time, the content in the tap needs to be cleaned up. - // Here, read is called to process, otherwise handle_rx may be triggered all the time. - let mut buf = [0; 1024]; - match tap.read(&mut buf) { - Ok(cnt) => error!("Failed to call readv but tap read is ok: cnt {}", cnt), - Err(e) => { - // When the backend tap device is abnormally removed, read return EBADFD. - error!("Failed to read tap: {:?}", e); - } - } - error!("Failed to call readv for net handle_rx: {:?}", e); - } - - size - } - - fn get_libc_iovecs( - mem_space: &Arc, - cache: &Option, - elem_iovecs: &[ElemIovec], - ) -> Vec { - let mut iovecs = Vec::new(); - for elem_iov in elem_iovecs.iter() { - // elem_iov.addr has been checked in pop_avail(). - let mut len = elem_iov.len; - let mut start = elem_iov.addr; - loop { - let io_vec = mem_space - .get_host_address_from_cache(start, cache) - .map(|(hva, fr_len)| libc::iovec { - iov_base: hva as *mut libc::c_void, - iov_len: std::cmp::min(elem_iov.len, fr_len as u32) as libc::size_t, - }) - .unwrap(); - start = start.unchecked_add(io_vec.iov_len as u64); - len -= io_vec.iov_len as u32; - iovecs.push(io_vec); - if len == 0 { - break; - } - } - } - iovecs - } - - fn handle_rx(&mut self) -> Result<()> { +impl NetIoQueue { + fn handle_rx(&self, tap: &Arc>>) -> Result<()> { trace::virtio_receive_request("Net".to_string(), "to rx".to_string()); - if self.tap.is_none() { + if tap.read().unwrap().is_none() { return Ok(()); } let mut queue = self.rx.queue.lock().unwrap(); - let mut rx_packets = 0; + let mut rx_packets: u16 = 0; loop { let elem = queue .vring .pop_avail(&self.mem_space, self.driver_features) .with_context(|| "Failed to pop avail ring for net rx")?; if elem.desc_num == 0 { - self.rx.queue_full = true; + queue + .vring + .suppress_queue_notify(self.driver_features, false) + .with_context(|| "Failed to enable rx queue notify")?; + self.listen_state.lock().unwrap().set_queue_avail(false); break; } else if elem.in_iovec.is_empty() { bail!("The length of in iovec is 0"); } - let iovecs = NetIoHandler::get_libc_iovecs( - &self.mem_space, - queue.vring.get_cache(), - &elem.in_iovec, - ); + let (_, iovecs) = + gpa_hva_iovec_map(&elem.in_iovec, &self.mem_space, queue.vring.get_cache())?; if MigrationManager::is_active() { // FIXME: mark dirty page needs to be managed by `AddressSpace` crate. for iov in iovecs.iter() { // Mark vmm dirty page manually if live migration is active. - MigrationManager::mark_dirty_log(iov.iov_base as u64, iov.iov_len as u64); + MigrationManager::mark_dirty_log(iov.iov_base, iov.iov_len); } } // Read the data from the tap device. - let size = NetIoHandler::read_from_tap(&iovecs, self.tap.as_mut().unwrap()); - if size < (NET_HDR_LENGTH + ETHERNET_HDR_LENGTH + VLAN_TAG_LENGTH) as i32 { + let locked_tap = tap.read().unwrap(); + let size = if locked_tap.is_some() { + locked_tap.as_ref().unwrap().receive_packets(&iovecs) + } else { + -1 + }; + drop(locked_tap); + if size < (NET_HDR_LENGTH + ETHERNET_HDR_LENGTH + VLAN_TAG_LENGTH) as isize { queue.vring.push_back(); break; } @@ -841,7 +766,7 @@ impl NetIoHandler { queue .vring - .add_used(&self.mem_space, elem.index, size as u32) + .add_used(elem.index, u32::try_from(size)?) .with_context(|| { format!( "Failed to add used ring for net rx, index: {}, len: {}", @@ -849,10 +774,7 @@ impl NetIoHandler { ) })?; - if queue - .vring - .should_notify(&self.mem_space, self.driver_features) - { + if queue.vring.should_notify(self.driver_features) { (self.interrupt_cb)(&VirtioInterruptType::Vring, Some(&queue), false) .with_context(|| { VirtioError::InterruptTrigger("net", VirtioInterruptType::Vring) @@ -863,9 +785,9 @@ impl NetIoHandler { rx_packets += 1; if rx_packets >= self.queue_size { self.rx - .recv_evt + .queue_evt .write(1) - .with_context(|| "Failed to trigger tap queue event".to_string())?; + .with_context(|| "Failed to trigger rx queue event".to_string())?; break; } } @@ -873,35 +795,11 @@ impl NetIoHandler { Ok(()) } - fn send_packets(&self, tap_fd: libc::c_int, iovecs: &[libc::iovec]) -> i8 { - loop { - // SAFETY: the arguments of writev has been checked and is correct. - let size = unsafe { - libc::writev( - tap_fd, - iovecs.as_ptr() as *const libc::iovec, - iovecs.len() as libc::c_int, - ) - }; - if size < 0 { - let e = std::io::Error::last_os_error(); - match e.kind() { - ErrorKind::Interrupted => continue, - ErrorKind::WouldBlock => return -1_i8, - // Ignore other errors which can not be handled. - _ => error!("Failed to call writev for net handle_tx: {:?}", e), - } - } - break; - } - 0_i8 - } - - fn handle_tx(&mut self) -> Result<()> { + fn handle_tx(&self, tap: &Arc>>) -> Result<()> { trace::virtio_receive_request("Net".to_string(), "to tx".to_string()); let mut queue = self.tx.queue.lock().unwrap(); - let mut tx_packets = 0; + let mut tx_packets: u16 = 0; loop { let elem = queue .vring @@ -913,33 +811,26 @@ impl NetIoHandler { bail!("The length of out iovec is 0"); } - let iovecs = NetIoHandler::get_libc_iovecs( - &self.mem_space, - queue.vring.get_cache(), - &elem.out_iovec, - ); - let tap_fd = if let Some(tap) = self.tap.as_mut() { - tap.as_raw_fd() as libc::c_int - } else { - -1_i32 - }; - if tap_fd != -1 && self.send_packets(tap_fd, &iovecs) == -1 { + let (_, iovecs) = + gpa_hva_iovec_map(&elem.out_iovec, &self.mem_space, queue.vring.get_cache())?; + let locked_tap = tap.read().unwrap(); + if locked_tap.is_none() || locked_tap.as_ref().unwrap().send_packets(&iovecs) == -1 { queue.vring.push_back(); - self.tx.queue_evt.write(1).with_context(|| { - "Failed to trigger tx queue event when writev blocked".to_string() - })?; - return Ok(()); + queue + .vring + .suppress_queue_notify(self.driver_features, true) + .with_context(|| "Failed to suppress tx queue notify")?; + self.listen_state.lock().unwrap().set_tap_full(true); + break; } + drop(locked_tap); queue .vring - .add_used(&self.mem_space, elem.index, 0) + .add_used(elem.index, 0) .with_context(|| format!("Net tx: Failed to add used ring {}", elem.index))?; - if queue - .vring - .should_notify(&self.mem_space, self.driver_features) - { + if queue.vring.should_notify(self.driver_features) { (self.interrupt_cb)(&VirtioInterruptType::Vring, Some(&queue), false) .with_context(|| { VirtioError::InterruptTrigger("net", VirtioInterruptType::Vring) @@ -958,49 +849,99 @@ impl NetIoHandler { Ok(()) } +} - fn update_evt_handler(net_io: &Arc>) -> Vec { - let mut locked_net_io = net_io.lock().unwrap(); - locked_net_io.tap = match locked_net_io.receiver.recv() { - Ok(tap) => tap, - Err(e) => { - error!("Failed to receive the tap {:?}", e); - None - } - }; - let old_tap_fd = locked_net_io.tap_fd; - locked_net_io.tap_fd = -1; - if let Some(tap) = locked_net_io.tap.as_ref() { - locked_net_io.tap_fd = tap.as_raw_fd(); +struct ListenState { + queue_avail: bool, + tap_full: bool, + is_listening: bool, + has_changed: bool, +} + +impl ListenState { + fn new() -> Self { + Self { + queue_avail: true, + tap_full: false, + is_listening: true, + has_changed: false, } + } - let mut notifiers_fds = vec![ - locked_net_io.update_evt.as_raw_fd(), - locked_net_io.rx.queue_evt.as_raw_fd(), - locked_net_io.rx.recv_evt.as_raw_fd(), - locked_net_io.tx.queue_evt.as_raw_fd(), - ]; - if old_tap_fd != -1 { - notifiers_fds.push(old_tap_fd); + fn set_tap_full(&mut self, value: bool) { + if self.tap_full == value { + return; + } + self.tap_full = value; + self.has_changed = true; + } + + fn set_queue_avail(&mut self, value: bool) { + if self.queue_avail == value { + return; + } + self.queue_avail = value; + self.has_changed = true; + } + + fn tap_fd_handler(&mut self, tap: &Tap) -> Vec { + let mut notifiers = Vec::new(); + + if !self.is_listening && (self.queue_avail || self.tap_full) { + notifiers.push(EventNotifier::new( + NotifierOperation::Resume, + tap.as_raw_fd(), + None, + EventSet::empty(), + Vec::new(), + )); + self.is_listening = true; + } + + if !self.is_listening { + return notifiers; } - let mut notifiers = gen_delete_notifiers(¬ifiers_fds); - drop(locked_net_io); - notifiers.append(&mut EventNotifierHelper::internal_notifiers(net_io.clone())); + // NOTE: We want to poll for OUT event when the tap is full, and for IN event when the + // virtio queue is available. + let tap_events = match (self.queue_avail, self.tap_full) { + (true, true) => EventSet::OUT | EventSet::IN | EventSet::EDGE_TRIGGERED, + (false, true) => EventSet::OUT | EventSet::EDGE_TRIGGERED, + (true, false) => EventSet::IN | EventSet::EDGE_TRIGGERED, + (false, false) => EventSet::empty(), + }; + + let tap_operation = if tap_events.is_empty() { + self.is_listening = false; + NotifierOperation::Park + } else { + NotifierOperation::Modify + }; + + notifiers.push(EventNotifier::new( + tap_operation, + tap.as_raw_fd(), + None, + tap_events, + Vec::new(), + )); notifiers } } -fn get_net_header(iovec: &[libc::iovec], buf: &mut [u8]) -> Result { +fn get_net_header(iovec: &[Iovec], buf: &mut [u8]) -> Result { let mut start: usize = 0; let mut end: usize = 0; for elem in iovec { end = start - .checked_add(elem.iov_len) + .checked_add(elem.iov_len as usize) .with_context(|| "Overflow when getting the net header")?; end = cmp::min(end, buf.len()); - mem_to_buf(&mut buf[start..end], elem.iov_base as u64)?; + // SAFETY: iovec is generated by address_space and len is not less than buf's. + unsafe { + mem_to_buf(&mut buf[start..end], elem.iov_base)?; + } if end >= buf.len() { break; } @@ -1022,147 +963,245 @@ fn build_event_notifier( EventNotifier::new(op, fd, None, event, handlers) } -impl EventNotifierHelper for NetIoHandler { - fn internal_notifiers(net_io: Arc>) -> Vec { - // Register event notifier for update_evt. - let locked_net_io = net_io.lock().unwrap(); - let cloned_net_io = net_io.clone(); +struct NetIoHandler { + /// The context name of iothread for tap and rx virtio queue. + /// Since we placed the handlers of RxVirtio, TxVirtio and tap_fd in different threads, + /// thread name is needed to change the monitoring status of tap_fd. + rx_iothread: Option, + /// Virtio queue used for net io. + net_queue: Arc, + /// The context of tap device. + tap: Arc>>, + /// Device is broken or not. + device_broken: Arc, + /// The receiver half of Rust's channel to recv tap information. + receiver: Receiver, + /// Eventfd for config space update. + update_evt: Arc, +} + +impl NetIoHandler { + fn update_evt_handler(&mut self) -> Result<()> { + let mut locked_tap = self.tap.write().unwrap(); + let old_tap_fd = if locked_tap.is_some() { + locked_tap.as_ref().unwrap().as_raw_fd() + } else { + -1 + }; + + *locked_tap = match self.receiver.recv() { + Ok(tap) => tap, + Err(e) => { + error!("Failed to receive the tap {:?}", e); + None + } + }; + drop(locked_tap); + + if old_tap_fd != -1 { + unregister_event_helper(self.rx_iothread.as_ref(), &mut vec![old_tap_fd])?; + } + if self.tap.read().unwrap().is_some() { + EventLoop::update_event(self.tap_notifier(), self.rx_iothread.as_ref())?; + } + Ok(()) + } + + /// Register event notifier for update_evt. + fn update_evt_notifier(&self, net_io: Arc>) -> Vec { + let device_broken = self.device_broken.clone(); let handler: Rc = Rc::new(move |_, fd: RawFd| { read_fd(fd); - if cloned_net_io - .lock() - .unwrap() - .device_broken - .load(Ordering::SeqCst) - { + + if device_broken.load(Ordering::SeqCst) { return None; } - Some(NetIoHandler::update_evt_handler(&cloned_net_io)) + + if let Err(e) = net_io.lock().unwrap().update_evt_handler() { + error!("Update net events failed: {:?}", e); + } + + None }); - let mut notifiers = vec![build_event_notifier( - locked_net_io.update_evt.as_raw_fd(), + let notifiers = vec![build_event_notifier( + self.update_evt.as_raw_fd(), Some(handler), NotifierOperation::AddShared, EventSet::IN, )]; + notifiers + } - // Register event notifier for rx. - let cloned_net_io = net_io.clone(); + /// Register event notifier for rx. + fn rx_virtio_notifier(&self) -> Vec { + let net_queue = self.net_queue.clone(); + let device_broken = self.device_broken.clone(); + let tap = self.tap.clone(); + let rx_iothread = self.rx_iothread.as_ref().cloned(); let handler: Rc = Rc::new(move |_, fd: RawFd| { read_fd(fd); - let mut locked_net_io = cloned_net_io.lock().unwrap(); - if locked_net_io.device_broken.load(Ordering::SeqCst) { + + if device_broken.load(Ordering::SeqCst) { return None; } - if let Err(ref e) = locked_net_io.rx.recv_evt.write(1) { - error!("Failed to trigger tap receive event, {:?}", e); + net_queue.listen_state.lock().unwrap().set_queue_avail(true); + let mut locked_queue = net_queue.rx.queue.lock().unwrap(); + + if let Err(ref err) = locked_queue + .vring + .suppress_queue_notify(net_queue.driver_features, true) + { + error!("Failed to suppress rx queue notify: {:?}", err); report_virtio_error( - locked_net_io.interrupt_cb.clone(), - locked_net_io.driver_features, - &locked_net_io.device_broken, + net_queue.interrupt_cb.clone(), + net_queue.driver_features, + &device_broken, + ); + return None; + }; + + drop(locked_queue); + + if let Err(ref err) = net_queue.handle_rx(&tap) { + error!("Failed to handle receive queue event: {:?}", err); + report_virtio_error( + net_queue.interrupt_cb.clone(), + net_queue.driver_features, + &device_broken, ); + return None; } - if let Some(tap) = locked_net_io.tap.as_ref() { - if !locked_net_io.is_listening { - let notifier = vec![EventNotifier::new( - NotifierOperation::Resume, - tap.as_raw_fd(), - None, - EventSet::IN | EventSet::EDGE_TRIGGERED, - Vec::new(), - )]; - locked_net_io.is_listening = true; - locked_net_io.rx.queue_full = false; - return Some(notifier); - } + let mut locked_listen = net_queue.listen_state.lock().unwrap(); + let locked_tap = tap.read().unwrap(); + if locked_tap.is_none() || !locked_listen.has_changed { + return None; + } + + let notifiers = locked_listen.tap_fd_handler(locked_tap.as_ref().unwrap()); + locked_listen.has_changed = false; + drop(locked_tap); + drop(locked_listen); + + if let Err(e) = EventLoop::update_event(notifiers, rx_iothread.as_ref()) { + error!("Update tap notifiers failed in handle rx: {:?}", e); } None }); - let rx_fd = locked_net_io.rx.queue_evt.as_raw_fd(); - notifiers.push(build_event_notifier( + let rx_fd = self.net_queue.rx.queue_evt.as_raw_fd(); + let notifiers = vec![build_event_notifier( rx_fd, Some(handler), NotifierOperation::AddShared, EventSet::IN, - )); + )]; + notifiers + } - // Register event notifier for tx. - let cloned_net_io = net_io.clone(); + /// Register event notifier for tx. + fn tx_virtio_notifier(&self) -> Vec { + let net_queue = self.net_queue.clone(); + let device_broken = self.device_broken.clone(); + let tap = self.tap.clone(); + let rx_iothread = self.rx_iothread.as_ref().cloned(); let handler: Rc = Rc::new(move |_, fd: RawFd| { read_fd(fd); - let mut locked_net_io = cloned_net_io.lock().unwrap(); - if locked_net_io.device_broken.load(Ordering::SeqCst) { + + if device_broken.load(Ordering::SeqCst) { return None; } - if let Err(ref e) = locked_net_io.handle_tx() { + + if let Err(ref e) = net_queue.handle_tx(&tap) { error!("Failed to handle tx(tx event) for net, {:?}", e); report_virtio_error( - locked_net_io.interrupt_cb.clone(), - locked_net_io.driver_features, - &locked_net_io.device_broken, + net_queue.interrupt_cb.clone(), + net_queue.driver_features, + &device_broken, ); } + + let mut locked_listen = net_queue.listen_state.lock().unwrap(); + let locked_tap = tap.read().unwrap(); + if locked_tap.is_none() || !locked_listen.has_changed { + return None; + } + + let notifiers = locked_listen.tap_fd_handler(locked_tap.as_ref().unwrap()); + locked_listen.has_changed = false; + drop(locked_tap); + drop(locked_listen); + + if let Err(e) = EventLoop::update_event(notifiers, rx_iothread.as_ref()) { + error!("Update tap notifiers failed in handle tx: {:?}", e); + } + None }); - let tx_fd = locked_net_io.tx.queue_evt.as_raw_fd(); - notifiers.push(build_event_notifier( + let tx_fd = self.net_queue.tx.queue_evt.as_raw_fd(); + let notifiers = vec![build_event_notifier( tx_fd, Some(handler), NotifierOperation::AddShared, EventSet::IN, - )); + )]; + notifiers + } - // Register event notifier for tap. - let cloned_net_io = net_io.clone(); - if let Some(tap) = locked_net_io.tap.as_ref() { - let handler: Rc = Rc::new(move |_, _| { - let mut locked_net_io = cloned_net_io.lock().unwrap(); - if locked_net_io.device_broken.load(Ordering::SeqCst) { - return None; - } + /// Register event notifier for tap. + fn tap_notifier(&self) -> Vec { + let tap = self.tap.clone(); + let net_queue = self.net_queue.clone(); + let device_broken = self.device_broken.clone(); + let locked_tap = self.tap.read().unwrap(); + if locked_tap.is_none() { + return vec![]; + } + let handler: Rc = Rc::new(move |events: EventSet, _| { + if device_broken.load(Ordering::SeqCst) { + return None; + } - if let Err(ref e) = locked_net_io.handle_rx() { - error!("Failed to handle rx(tap event), {:?}", e); + if events.contains(EventSet::OUT) { + net_queue.listen_state.lock().unwrap().set_tap_full(false); + net_queue + .tx + .queue_evt + .write(1) + .unwrap_or_else(|e| error!("Failed to notify tx thread: {:?}", e)); + } + + if events.contains(EventSet::IN) { + if let Err(ref err) = net_queue.handle_rx(&tap) { + error!("Failed to handle receive queue event: {:?}", err); report_virtio_error( - locked_net_io.interrupt_cb.clone(), - locked_net_io.driver_features, - &locked_net_io.device_broken, + net_queue.interrupt_cb.clone(), + net_queue.driver_features, + &device_broken, ); return None; } + } - if let Some(tap) = locked_net_io.tap.as_ref() { - if locked_net_io.rx.queue_full && locked_net_io.is_listening { - let notifier = vec![EventNotifier::new( - NotifierOperation::Park, - tap.as_raw_fd(), - None, - EventSet::IN | EventSet::EDGE_TRIGGERED, - Vec::new(), - )]; - locked_net_io.is_listening = false; - return Some(notifier); - } - } - None - }); - let tap_fd = tap.as_raw_fd(); - notifiers.push(build_event_notifier( - tap_fd, - Some(handler.clone()), - NotifierOperation::AddShared, - EventSet::IN | EventSet::EDGE_TRIGGERED, - )); - let recv_evt_fd = locked_net_io.rx.recv_evt.as_raw_fd(); - notifiers.push(build_event_notifier( - recv_evt_fd, - Some(handler), - NotifierOperation::AddShared, - EventSet::IN | EventSet::EDGE_TRIGGERED, - )); - } + let mut locked_listen = net_queue.listen_state.lock().unwrap(); + let locked_tap = tap.read().unwrap(); + if !locked_listen.has_changed || locked_tap.is_none() { + return None; + } + let tap_notifiers = locked_listen.tap_fd_handler(locked_tap.as_ref().unwrap()); + locked_listen.has_changed = false; + drop(locked_tap); + drop(locked_listen); + + Some(tap_notifiers) + }); + let tap_fd = locked_tap.as_ref().unwrap().as_raw_fd(); + let notifiers = vec![build_event_notifier( + tap_fd, + Some(handler), + NotifierOperation::AddShared, + EventSet::IN | EventSet::EDGE_TRIGGERED, + )]; notifiers } @@ -1190,6 +1229,8 @@ pub struct Net { base: VirtioBase, /// Configuration of the network device. net_cfg: NetworkInterfaceConfig, + /// Configuration of the network device. + netdev_cfg: NetDevcfg, /// Virtio net configurations. config_space: Arc>, /// Tap device opened. @@ -1200,12 +1241,16 @@ pub struct Net { update_evts: Vec>, /// The information about control command. ctrl_info: Option>>, + /// The deactivate events for receiving. + rx_deactivate_evts: Vec, + /// The deactivate events for transporting. + tx_deactivate_evts: Vec, } impl Net { - pub fn new(net_cfg: NetworkInterfaceConfig) -> Self { + pub fn new(net_cfg: NetworkInterfaceConfig, netdev_cfg: NetDevcfg) -> Self { let queue_num = if net_cfg.mq { - (net_cfg.queues + 1) as usize + (netdev_cfg.queues + 1) as usize } else { QUEUE_NUM_NET }; @@ -1214,6 +1259,7 @@ impl Net { Self { base: VirtioBase::new(VIRTIO_TYPE_NET, queue_num, queue_size), net_cfg, + netdev_cfg, ..Default::default() } } @@ -1378,13 +1424,7 @@ fn get_tap_offload_flags(features: u64) -> u32 { } impl VirtioDevice for Net { - fn virtio_base(&self) -> &VirtioBase { - &self.base - } - - fn virtio_base_mut(&mut self) -> &mut VirtioBase { - &mut self.base - } + gen_base_func!(virtio_base, virtio_base_mut, VirtioBase, base); fn realize(&mut self) -> Result<()> { // if iothread not found, return err @@ -1397,11 +1437,11 @@ impl VirtioDevice for Net { ); } - let queue_pairs = self.net_cfg.queues / 2; - if !self.net_cfg.host_dev_name.is_empty() { - self.taps = create_tap(None, Some(&self.net_cfg.host_dev_name), queue_pairs) + let queue_pairs = self.netdev_cfg.queues / 2; + if !self.netdev_cfg.ifname.is_empty() { + self.taps = create_tap(None, Some(&self.netdev_cfg.ifname), queue_pairs) .with_context(|| "Failed to open tap with file path")?; - } else if let Some(fds) = self.net_cfg.tap_fds.as_mut() { + } else if let Some(fds) = self.netdev_cfg.tap_fds.as_mut() { let mut created_fds = 0; if let Some(taps) = &self.taps { for (index, tap) in taps.iter().enumerate() { @@ -1419,6 +1459,21 @@ impl VirtioDevice for Net { self.taps = None; } + if let Some(ref taps) = self.taps { + for (idx, tap) in taps.iter().enumerate() { + let upload_stats = tap.upload_stats.clone(); + let download_stats = tap.download_stats.clone(); + register_state_query_callback( + format!("tap-{}", idx), + Arc::new(move || { + let upload = upload_stats.load(Ordering::SeqCst); + let download = download_stats.load(Ordering::SeqCst); + format!("upload: {} download: {}", upload, download) + }), + ) + } + } + self.init_config_features()?; Ok(()) @@ -1444,7 +1499,7 @@ impl VirtioDevice for Net { let mut locked_config = self.config_space.lock().unwrap(); - let queue_pairs = self.net_cfg.queues / 2; + let queue_pairs = self.netdev_cfg.queues / 2; if self.net_cfg.mq && (VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MIN..=VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MAX) .contains(&queue_pairs) @@ -1478,6 +1533,11 @@ impl VirtioDevice for Net { } fn unrealize(&mut self) -> Result<()> { + if let Some(ref taps) = self.taps { + for (idx, _) in taps.iter().enumerate() { + unregister_state_query_callback(&format!("tap-{}", idx)); + } + } mark_mac_table(&self.config_space.lock().unwrap().mac, false); MigrationManager::unregister_device_instance( VirtioNetState::descriptor(), @@ -1543,7 +1603,7 @@ impl VirtioDevice for Net { // The features about offload is included in bits 0 to 31. let features = self.driver_features(0_u32); - let flags = get_tap_offload_flags(features as u64); + let flags = get_tap_offload_flags(u64::from(features)); let mut senders = Vec::new(); let queue_pairs = queue_num / 2; @@ -1561,33 +1621,53 @@ impl VirtioDevice for Net { .with_context(|| "Failed to set tap offload")?; } - let update_evt = Arc::new(EventFd::new(libc::EFD_NONBLOCK)?); - let recv_evt = Arc::new(EventFd::new(libc::EFD_NONBLOCK)?); - let mut handler = NetIoHandler { - rx: RxVirtio::new(rx_queue, rx_queue_evt, recv_evt), + let update_evt = Arc::new(create_new_eventfd()?); + let net_queue = Arc::new(NetIoQueue { + rx: RxVirtio::new(rx_queue, rx_queue_evt), tx: TxVirtio::new(tx_queue, tx_queue_evt), - tap: self.taps.as_ref().map(|t| t[index].clone()), - tap_fd: -1, + ctrl_info: ctrl_info.clone(), mem_space: mem_space.clone(), interrupt_cb: interrupt_cb.clone(), driver_features, + listen_state: Arc::new(Mutex::new(ListenState::new())), + queue_size: self.queue_size_max(), + }); + let tap = Arc::new(RwLock::new(self.taps.as_ref().map(|t| t[index].clone()))); + let net_io = Arc::new(Mutex::new(NetIoHandler { + rx_iothread: self.net_cfg.rx_iothread.as_ref().cloned(), + net_queue, + tap, + device_broken: self.base.broken.clone(), receiver, update_evt: update_evt.clone(), - device_broken: self.base.broken.clone(), - is_listening: true, - ctrl_info: ctrl_info.clone(), - queue_size: self.queue_size_max(), - }; - if let Some(tap) = &handler.tap { - handler.tap_fd = tap.as_raw_fd(); - } - - let notifiers = EventNotifierHelper::internal_notifiers(Arc::new(Mutex::new(handler))); + })); + let cloned_net_io = net_io.clone(); + let locked_net_io = net_io.lock().unwrap(); + let update_evt_notifiers = locked_net_io.update_evt_notifier(cloned_net_io); + let rx_notifiers = locked_net_io.rx_virtio_notifier(); + let tx_notifiers = locked_net_io.tx_virtio_notifier(); + let tap_notifiers = locked_net_io.tap_notifier(); + drop(locked_net_io); register_event_helper( - notifiers, + update_evt_notifiers, self.net_cfg.iothread.as_ref(), &mut self.base.deactivate_evts, )?; + register_event_helper( + rx_notifiers, + self.net_cfg.rx_iothread.as_ref(), + &mut self.rx_deactivate_evts, + )?; + register_event_helper( + tap_notifiers, + self.net_cfg.rx_iothread.as_ref(), + &mut self.rx_deactivate_evts, + )?; + register_event_helper( + tx_notifiers, + self.net_cfg.tx_iothread.as_ref(), + &mut self.tx_deactivate_evts, + )?; self.update_evts.push(update_evt); } self.senders = Some(senders); @@ -1596,9 +1676,15 @@ impl VirtioDevice for Net { Ok(()) } - fn update_config(&mut self, dev_config: Option>) -> Result<()> { - if let Some(conf) = dev_config { - self.net_cfg = conf + // configs[0]: NetDevcfg. configs[1]: NetworkInterfaceConfig. + fn update_config(&mut self, dev_config: Vec>) -> Result<()> { + if dev_config.len() == 2 { + self.netdev_cfg = dev_config[0] + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + self.net_cfg = dev_config[1] .as_any() .downcast_ref::() .unwrap() @@ -1607,7 +1693,7 @@ impl VirtioDevice for Net { // Set tap offload. // The features about offload is included in bits 0 to 31. let features = self.driver_features(0_u32); - let flags = get_tap_offload_flags(features as u64); + let flags = get_tap_offload_flags(u64::from(features)); if let Some(taps) = &self.taps { for (_, tap) in taps.iter().enumerate() { tap.set_offload(flags) @@ -1653,10 +1739,28 @@ impl VirtioDevice for Net { self.net_cfg.iothread.as_ref(), &mut self.base.deactivate_evts, )?; + unregister_event_helper( + self.net_cfg.rx_iothread.as_ref(), + &mut self.rx_deactivate_evts, + )?; + unregister_event_helper( + self.net_cfg.tx_iothread.as_ref(), + &mut self.tx_deactivate_evts, + )?; self.update_evts.clear(); self.ctrl_info = None; Ok(()) } + + fn reset(&mut self) -> Result<()> { + if let Some(ref mut taps) = self.taps { + for tap in taps.iter_mut() { + tap.download_stats.store(0, Ordering::SeqCst); + tap.upload_stats.store(0, Ordering::SeqCst); + } + } + Ok(()) + } } // SAFETY: Send and Sync is not auto-implemented for `Sender` type. @@ -1698,21 +1802,21 @@ impl MigrationHook for Net {} #[cfg(test)] mod tests { - pub use super::*; + use super::*; #[test] fn test_net_init() { // test net new method - let mut net = Net::new(NetworkInterfaceConfig::default()); + let mut net = Net::new(NetworkInterfaceConfig::default(), NetDevcfg::default()); assert_eq!(net.base.device_features, 0); assert_eq!(net.base.driver_features, 0); - assert_eq!(net.taps.is_none(), true); - assert_eq!(net.senders.is_none(), true); - assert_eq!(net.net_cfg.mac.is_none(), true); - assert_eq!(net.net_cfg.tap_fds.is_none(), true); - assert_eq!(net.net_cfg.vhost_type.is_none(), true); - assert_eq!(net.net_cfg.vhost_fds.is_none(), true); + assert!(net.taps.is_none()); + assert!(net.senders.is_none()); + assert!(net.net_cfg.mac.is_none()); + assert!(net.netdev_cfg.tap_fds.is_none()); + assert!(net.netdev_cfg.vhost_type().is_none()); + assert!(net.netdev_cfg.vhost_fds.is_none()); // test net realize method net.realize().unwrap(); @@ -1740,25 +1844,25 @@ mod tests { let mut data: Vec = vec![0; 10]; let offset: u64 = len + 1; - assert_eq!(net.read_config(offset, &mut data).is_ok(), false); + assert!(net.read_config(offset, &mut data).is_err()); let offset: u64 = len; - assert_eq!(net.read_config(offset, &mut data).is_ok(), false); + assert!(net.read_config(offset, &mut data).is_err()); let offset: u64 = 0; - assert_eq!(net.read_config(offset, &mut data).is_ok(), true); + assert!(net.read_config(offset, &mut data).is_ok()); let offset: u64 = len; let mut data: Vec = vec![0; 1]; - assert_eq!(net.write_config(offset, &mut data).is_ok(), false); + assert!(net.write_config(offset, &mut data).is_err()); let offset: u64 = len - 1; let mut data: Vec = vec![0; 1]; - assert_eq!(net.write_config(offset, &mut data).is_ok(), false); + assert!(net.write_config(offset, &mut data).is_err()); let offset: u64 = 0; let mut data: Vec = vec![0; len as usize]; - assert_eq!(net.write_config(offset, &mut data).is_ok(), false); + assert!(net.write_config(offset, &mut data).is_err()); } #[test] @@ -1769,8 +1873,8 @@ mod tests { // Test create tap with net_fds and host_dev_name. let net_fds = vec![32, 33]; let tap_name = "tap0"; - if let Err(err) = create_tap(Some(&net_fds), Some(&tap_name), 1) { - let err_msg = format!("Failed to create tap, index is 0"); + if let Err(err) = create_tap(Some(&net_fds), Some(tap_name), 1) { + let err_msg = "Failed to create tap, index is 0".to_string(); assert_eq!(err.to_string(), err_msg); } else { assert!(false); @@ -1778,7 +1882,7 @@ mod tests { // Test create tap with empty net_fds. if let Err(err) = create_tap(Some(&vec![]), None, 1) { - let err_msg = format!("Failed to get fd from index 0"); + let err_msg = "Failed to get fd from index 0".to_string(); assert_eq!(err.to_string(), err_msg); } else { assert!(false); @@ -1787,7 +1891,7 @@ mod tests { // Test create tap with tap_name which is not exist. if let Err(err) = create_tap(None, Some("the_tap_is_not_exist"), 1) { let err_msg = - format!("Failed to create tap with name the_tap_is_not_exist, index is 0"); + "Failed to create tap with name the_tap_is_not_exist, index is 0".to_string(); assert_eq!(err.to_string(), err_msg); } else { assert!(false); @@ -1803,14 +1907,14 @@ mod tests { 0x00, 0x00, ]; // It has no vla vid, the packet is filtered. - assert_eq!(ctrl_info.filter_packets(&buf), true); + assert!(ctrl_info.filter_packets(&buf)); // It has valid vlan id, the packet is not filtered. let vid: u16 = 1023; buf[ETHERNET_HDR_LENGTH] = u16::to_be_bytes(vid)[0]; buf[ETHERNET_HDR_LENGTH + 1] = u16::to_be_bytes(vid)[1]; ctrl_info.vlan_map.insert(vid >> 5, 1 << (vid & 0x1f)); - assert_eq!(ctrl_info.filter_packets(&buf), false); + assert!(!ctrl_info.filter_packets(&buf)); } #[test] @@ -1818,12 +1922,12 @@ mod tests { let mut net_config = VirtioNetConfig::default(); // Parsing the normal mac address. let mac = "52:54:00:12:34:56"; - let ret = build_device_config_space(&mut net_config, &mac); + let ret = build_device_config_space(&mut net_config, mac); assert_eq!(ret, 1 << VIRTIO_NET_F_MAC); // Parsing the abnormale mac address. let mac = "52:54:00:12:34:"; - let ret = build_device_config_space(&mut net_config, &mac); + let ret = build_device_config_space(&mut net_config, mac); assert_eq!(ret, 0); } @@ -1864,7 +1968,9 @@ mod tests { #[test] fn test_iothread() { - let mut net = Net::new(NetworkInterfaceConfig::default()); + EventLoop::object_init(&None).unwrap(); + + let mut net = Net::new(NetworkInterfaceConfig::default(), NetDevcfg::default()); net.net_cfg.iothread = Some("iothread".to_string()); if let Err(err) = net.realize() { let err_msg = format!( @@ -1875,5 +1981,7 @@ mod tests { } else { assert!(false); } + + EventLoop::loop_clean(); } } diff --git a/virtio/src/device/rng.rs b/virtio/src/device/rng.rs index c44105241db0f450ff7ab1bf0f14968f61b6cdb0..68f17d84ac43a79b596115884e60f5cad9904801 100644 --- a/virtio/src/device/rng.rs +++ b/virtio/src/device/rng.rs @@ -18,7 +18,8 @@ use std::path::Path; use std::rc::Rc; use std::sync::{Arc, Mutex}; -use anyhow::{bail, Context, Result}; +use anyhow::{anyhow, bail, Context, Result}; +use clap::Parser; use log::error; use vmm_sys_util::epoll::EventSet; use vmm_sys_util::eventfd::EventFd; @@ -28,9 +29,9 @@ use crate::{ ElemIovec, Queue, VirtioBase, VirtioDevice, VirtioInterrupt, VirtioInterruptType, VIRTIO_F_VERSION_1, VIRTIO_TYPE_RNG, }; -use address_space::AddressSpace; +use address_space::{AddressAttr, AddressSpace}; use machine_manager::{ - config::{RngConfig, DEFAULT_VIRTQUEUE_SIZE}, + config::{get_pci_df, valid_id, ConfigError, RngObjConfig, DEFAULT_VIRTQUEUE_SIZE}, event_loop::EventLoop, event_loop::{register_event_helper, unregister_event_helper}, }; @@ -38,6 +39,7 @@ use migration::{DeviceStateDesc, FieldDesc, MigrationHook, MigrationManager, Sta use migration_derive::{ByteCode, Desc}; use util::aio::raw_read; use util::byte_code::ByteCode; +use util::gen_base_func; use util::leak_bucket::LeakBucket; use util::loop_context::{ read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, NotifierOperation, @@ -46,6 +48,62 @@ use util::loop_context::{ const QUEUE_NUM_RNG: usize = 1; const RNG_SIZE_MAX: u32 = 1 << 20; +const MIN_BYTES_PER_SEC: u64 = 64; +const MAX_BYTES_PER_SEC: u64 = 1_000_000_000; + +/// Config structure for virtio-rng. +#[derive(Parser, Debug, Clone, Default)] +#[command(no_binary_name(true))] +pub struct RngConfig { + #[arg(long, value_parser = ["virtio-rng-device", "virtio-rng-pci"])] + pub classtype: String, + #[arg(long, default_value = "", value_parser = valid_id)] + pub id: String, + #[arg(long)] + pub rng: String, + #[arg(long, alias = "max-bytes")] + pub max_bytes: Option, + #[arg(long)] + pub period: Option, + #[arg(long)] + pub bus: Option, + #[arg(long, value_parser = get_pci_df)] + pub addr: Option<(u8, u8)>, + #[arg(long)] + pub multifunction: Option, +} + +impl RngConfig { + pub fn bytes_per_sec(&self) -> Result> { + if self.max_bytes.is_some() != self.period.is_some() { + bail!("\"max_bytes\" and \"period\" should be configured or not configured Simultaneously."); + } + + if let Some(max) = self.max_bytes { + let peri = self.period.unwrap(); + let mul = max + .checked_mul(1000) + .with_context(|| format!("Illegal max-bytes arguments: {:?}", max))?; + let bytes_per_sec = mul + .checked_div(peri) + .with_context(|| format!("Illegal period arguments: {:?}", peri))?; + + if !(MIN_BYTES_PER_SEC..=MAX_BYTES_PER_SEC).contains(&bytes_per_sec) { + return Err(anyhow!(ConfigError::IllegalValue( + "The bytes per second of rng device".to_string(), + MIN_BYTES_PER_SEC, + true, + MAX_BYTES_PER_SEC, + true, + ))); + } + + return Ok(Some(bytes_per_sec)); + } + Ok(None) + } +} + fn get_req_data_size(in_iov: &[ElemIovec]) -> Result { let mut size = 0_u32; for iov in in_iov { @@ -81,7 +139,8 @@ impl RngHandler { .write( &mut buffer[offset..].as_ref(), iov.addr, - min(size - offset as u32, iov.len) as u64, + u64::from(min(size - offset as u32, iov.len)), + AddressAttr::Ram, ) .with_context(|| "Failed to write request data for virtio rng")?; @@ -108,19 +167,22 @@ impl RngHandler { get_req_data_size(&elem.in_iovec).with_context(|| "Failed to get request size")?; if let Some(leak_bucket) = self.leak_bucket.as_mut() { - if leak_bucket.throttled(EventLoop::get_ctx(None).unwrap(), size as u64) { + if leak_bucket.throttled(EventLoop::get_ctx(None).unwrap(), size) { queue_lock.vring.push_back(); break; } } let mut buffer = vec![0_u8; size as usize]; - let ret = raw_read( - self.random_file.as_raw_fd(), - buffer.as_mut_ptr() as u64, - size as usize, - 0, - ); + // SAFETY: buffer is valid and large enough. + let ret = unsafe { + raw_read( + self.random_file.as_raw_fd(), + buffer.as_mut_ptr() as u64, + size as usize, + 0, + ) + }; if ret < 0 { bail!("Failed to read random file, size: {}", size); } @@ -130,7 +192,7 @@ impl RngHandler { queue_lock .vring - .add_used(&self.mem_space, elem.index, size) + .add_used(elem.index, size) .with_context(|| { format!( "Failed to add used ring, index: {}, size: {}", @@ -216,34 +278,37 @@ pub struct RngState { pub struct Rng { /// Virtio device base property. base: VirtioBase, - /// Configuration of virtio rng device + /// Configuration of virtio rng device. rng_cfg: RngConfig, + /// Configuration of rng-random. + rngobj_cfg: RngObjConfig, /// The file descriptor of random number generator random_file: Option, } impl Rng { - pub fn new(rng_cfg: RngConfig) -> Self { + pub fn new(rng_cfg: RngConfig, rngobj_cfg: RngObjConfig) -> Self { Rng { base: VirtioBase::new(VIRTIO_TYPE_RNG, QUEUE_NUM_RNG, DEFAULT_VIRTQUEUE_SIZE), rng_cfg, + rngobj_cfg, ..Default::default() } } fn check_random_file(&self) -> Result<()> { - let path = Path::new(&self.rng_cfg.random_file); + let path = Path::new(&self.rngobj_cfg.filename); if !path.exists() { bail!( "The path of random file {} is not existed", - self.rng_cfg.random_file + self.rngobj_cfg.filename ); } if !path.metadata().unwrap().file_type().is_char_device() { bail!( "The type of random file {} is not a character special file", - self.rng_cfg.random_file + self.rngobj_cfg.filename ); } @@ -252,18 +317,12 @@ impl Rng { } impl VirtioDevice for Rng { - fn virtio_base(&self) -> &VirtioBase { - &self.base - } - - fn virtio_base_mut(&mut self) -> &mut VirtioBase { - &mut self.base - } + gen_base_func!(virtio_base, virtio_base_mut, VirtioBase, base); fn realize(&mut self) -> Result<()> { self.check_random_file() .with_context(|| "Failed to check random file")?; - let file = File::open(&self.rng_cfg.random_file) + let file = File::open(&self.rngobj_cfg.filename) .with_context(|| "Failed to open file of random number generator")?; self.random_file = Some(file); self.init_config_features()?; @@ -271,7 +330,7 @@ impl VirtioDevice for Rng { } fn init_config_features(&mut self) -> Result<()> { - self.base.device_features = 1 << VIRTIO_F_VERSION_1 as u64; + self.base.device_features = 1 << u64::from(VIRTIO_F_VERSION_1); Ok(()) } @@ -308,7 +367,7 @@ impl VirtioDevice for Rng { .unwrap() .try_clone() .with_context(|| "Failed to clone random file for virtio rng")?, - leak_bucket: match self.rng_cfg.bytes_per_sec { + leak_bucket: match self.rng_cfg.bytes_per_sec()? { Some(bps) => Some(LeakBucket::new(bps)?), None => None, }, @@ -359,55 +418,69 @@ mod tests { use vmm_sys_util::tempfile::TempFile; use super::*; + use crate::tests::address_space_init; use crate::*; - use address_space::{AddressSpace, GuestAddress, HostMemMapping, Region}; - use machine_manager::config::{RngConfig, DEFAULT_VIRTQUEUE_SIZE}; + use address_space::AddressAttr; + use address_space::GuestAddress; + use machine_manager::config::{str_slip_to_clap, VmConfig, DEFAULT_VIRTQUEUE_SIZE}; const VIRTQ_DESC_F_NEXT: u16 = 0x01; const VIRTQ_DESC_F_WRITE: u16 = 0x02; - const SYSTEM_SPACE_SIZE: u64 = (1024 * 1024) as u64; - - // build dummy address space of vm - fn address_space_init() -> Arc { - let root = Region::init_container_region(1 << 36, "sysmem"); - let sys_space = AddressSpace::new(root, "sysmem", None).unwrap(); - let host_mmap = Arc::new( - HostMemMapping::new( - GuestAddress(0), - None, - SYSTEM_SPACE_SIZE, - None, - false, - false, - false, - ) - .unwrap(), - ); - sys_space - .root() - .add_subregion( - Region::init_ram_region(host_mmap.clone(), "sysmem"), - host_mmap.start_address().raw_value(), - ) - .unwrap(); - sys_space + + #[test] + fn test_rng_config_cmdline_parse() { + // Test1: Right rng-random. + let mut vm_config = VmConfig::default(); + assert!(vm_config + .add_object("rng-random,id=objrng0,filename=/path/to/random_file") + .is_ok()); + let rngobj_cfg = vm_config.object.rng_object.remove("objrng0").unwrap(); + assert_eq!(rngobj_cfg.filename, "/path/to/random_file"); + + // Test2: virtio-rng-device + let rng_cmd = "virtio-rng-device,rng=objrng0"; + let rng_config = RngConfig::try_parse_from(str_slip_to_clap(rng_cmd, true, false)).unwrap(); + assert_eq!(rng_config.bytes_per_sec().unwrap(), None); + assert_eq!(rng_config.multifunction, None); + + // Test3: virtio-rng-pci. + let rng_cmd = "virtio-rng-pci,bus=pcie.0,addr=0x1,rng=objrng0,max-bytes=1234,period=1000"; + let rng_config = RngConfig::try_parse_from(str_slip_to_clap(rng_cmd, true, false)).unwrap(); + assert_eq!(rng_config.bytes_per_sec().unwrap(), Some(1234)); + assert_eq!(rng_config.bus.unwrap(), "pcie.0"); + assert_eq!(rng_config.addr.unwrap(), (1, 0)); + + // Test4: Illegal max-bytes/period. + let rng_cmd = "virtio-rng-device,rng=objrng0,max-bytes=63,period=1000"; + let rng_config = RngConfig::try_parse_from(str_slip_to_clap(rng_cmd, true, false)).unwrap(); + assert!(rng_config.bytes_per_sec().is_err()); + + let rng_cmd = "virtio-rng-device,rng=objrng0,max-bytes=1000000001,period=1000"; + let rng_config = RngConfig::try_parse_from(str_slip_to_clap(rng_cmd, true, false)).unwrap(); + assert!(rng_config.bytes_per_sec().is_err()); } #[test] fn test_rng_init() { - let file = TempFile::new().unwrap(); - let random_file = file.as_path().to_str().unwrap().to_string(); + let rngobj_config = RngObjConfig { + classtype: "rng-random".to_string(), + id: "rng0".to_string(), + filename: "".to_string(), + }; let rng_config = RngConfig { - id: "".to_string(), - random_file: random_file.clone(), - bytes_per_sec: Some(64), + classtype: "virtio-rng-pci".to_string(), + rng: "rng0".to_string(), + max_bytes: Some(64), + period: Some(1000), + bus: Some("pcie.0".to_string()), + addr: Some((3, 0)), + ..Default::default() }; - let rng = Rng::new(rng_config); + let rng = Rng::new(rng_config, rngobj_config); assert!(rng.random_file.is_none()); assert_eq!(rng.base.driver_features, 0_u64); assert_eq!(rng.base.device_features, 0_u64); - assert_eq!(rng.rng_cfg.random_file, random_file); - assert_eq!(rng.rng_cfg.bytes_per_sec, Some(64)); + assert_eq!(rng.rng_cfg.bytes_per_sec().unwrap().unwrap(), 64); assert_eq!(rng.queue_num(), QUEUE_NUM_RNG); assert_eq!(rng.queue_size_max(), DEFAULT_VIRTQUEUE_SIZE); @@ -416,18 +489,9 @@ mod tests { #[test] fn test_rng_features() { - let random_file = TempFile::new() - .unwrap() - .as_path() - .to_str() - .unwrap() - .to_string(); - let rng_config = RngConfig { - id: "".to_string(), - random_file, - bytes_per_sec: Some(64), - }; - let mut rng = Rng::new(rng_config); + let rng_config = RngConfig::default(); + let rngobj_cfg = RngObjConfig::default(); + let mut rng = Rng::new(rng_config, rngobj_cfg); // If the device feature is 0, all driver features are not supported. rng.base.device_features = 0; @@ -435,30 +499,30 @@ mod tests { let page = 0_u32; rng.set_driver_features(page, driver_feature); assert_eq!(rng.base.driver_features, 0_u64); - assert_eq!(rng.driver_features(page) as u64, 0_u64); + assert_eq!(u64::from(rng.driver_features(page)), 0_u64); assert_eq!(rng.device_features(0_u32), 0_u32); let driver_feature: u32 = 0xFF; let page = 1_u32; rng.set_driver_features(page, driver_feature); assert_eq!(rng.base.driver_features, 0_u64); - assert_eq!(rng.driver_features(page) as u64, 0_u64); + assert_eq!(u64::from(rng.driver_features(page)), 0_u64); assert_eq!(rng.device_features(1_u32), 0_u32); // If both the device feature bit and the front-end driver feature bit are // supported at the same time, this driver feature bit is supported. rng.base.device_features = - 1_u64 << VIRTIO_F_VERSION_1 | 1_u64 << VIRTIO_F_RING_INDIRECT_DESC as u64; + 1_u64 << VIRTIO_F_VERSION_1 | 1_u64 << u64::from(VIRTIO_F_RING_INDIRECT_DESC); let driver_feature: u32 = 1_u32 << VIRTIO_F_RING_INDIRECT_DESC; let page = 0_u32; rng.set_driver_features(page, driver_feature); assert_eq!( rng.base.driver_features, - (1_u64 << VIRTIO_F_RING_INDIRECT_DESC as u64) + (1_u64 << u64::from(VIRTIO_F_RING_INDIRECT_DESC)) ); assert_eq!( - rng.driver_features(page) as u64, - (1_u64 << VIRTIO_F_RING_INDIRECT_DESC as u64) + u64::from(rng.driver_features(page)), + (1_u64 << u64::from(VIRTIO_F_RING_INDIRECT_DESC)) ); assert_eq!( rng.device_features(page), @@ -484,7 +548,7 @@ mod tests { len: u32::max_value(), }, ElemIovec { - addr: GuestAddress(u32::max_value() as u64), + addr: GuestAddress(u64::from(u32::max_value())), len: 1_u32, }, ]; @@ -498,7 +562,7 @@ mod tests { len, }, ElemIovec { - addr: GuestAddress(u32::max_value() as u64), + addr: GuestAddress(u64::from(u32::max_value())), len, }, ]; @@ -530,14 +594,23 @@ mod tests { let mut queue_config = QueueConfig::new(DEFAULT_VIRTQUEUE_SIZE); queue_config.desc_table = GuestAddress(0); - queue_config.addr_cache.desc_table_host = - mem_space.get_host_address(queue_config.desc_table).unwrap(); - queue_config.avail_ring = GuestAddress(16 * DEFAULT_VIRTQUEUE_SIZE as u64); - queue_config.addr_cache.avail_ring_host = - mem_space.get_host_address(queue_config.avail_ring).unwrap(); - queue_config.used_ring = GuestAddress(32 * DEFAULT_VIRTQUEUE_SIZE as u64); - queue_config.addr_cache.used_ring_host = - mem_space.get_host_address(queue_config.used_ring).unwrap(); + queue_config.addr_cache.desc_table_host = unsafe { + mem_space + .get_host_address(queue_config.desc_table, AddressAttr::Ram) + .unwrap() + }; + queue_config.avail_ring = GuestAddress(16 * u64::from(DEFAULT_VIRTQUEUE_SIZE)); + queue_config.addr_cache.avail_ring_host = unsafe { + mem_space + .get_host_address(queue_config.avail_ring, AddressAttr::Ram) + .unwrap() + }; + queue_config.used_ring = GuestAddress(32 * u64::from(DEFAULT_VIRTQUEUE_SIZE)); + queue_config.addr_cache.used_ring_host = unsafe { + mem_space + .get_host_address(queue_config.used_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.size = DEFAULT_VIRTQUEUE_SIZE; queue_config.ready = true; @@ -561,15 +634,23 @@ mod tests { }; // write table descriptor for queue mem_space - .write_object(&desc, queue_config.desc_table) + .write_object(&desc, queue_config.desc_table, AddressAttr::Ram) .unwrap(); // write avail_ring idx mem_space - .write_object::(&0, GuestAddress(queue_config.avail_ring.0 + 4 as u64)) + .write_object::( + &0, + GuestAddress(queue_config.avail_ring.0 + 4_u64), + AddressAttr::Ram, + ) .unwrap(); // write avail_ring idx mem_space - .write_object::(&1, GuestAddress(queue_config.avail_ring.0 + 2 as u64)) + .write_object::( + &1, + GuestAddress(queue_config.avail_ring.0 + 2_u64), + AddressAttr::Ram, + ) .unwrap(); let buffer = vec![1_u8; data_len as usize]; @@ -580,13 +661,17 @@ mod tests { .read( &mut read_buffer.as_mut_slice(), GuestAddress(0x40000), - data_len as u64 + u64::from(data_len), + AddressAttr::Ram ) .is_ok()); assert_eq!(read_buffer, buffer); let idx = mem_space - .read_object::(GuestAddress(queue_config.used_ring.0 + 2 as u64)) + .read_object::( + GuestAddress(queue_config.used_ring.0 + 2_u64), + AddressAttr::Ram, + ) .unwrap(); assert_eq!(idx, 1); assert_eq!(cloned_interrupt_evt.read().unwrap(), 1); @@ -613,14 +698,23 @@ mod tests { let mut queue_config = QueueConfig::new(DEFAULT_VIRTQUEUE_SIZE); queue_config.desc_table = GuestAddress(0); - queue_config.addr_cache.desc_table_host = - mem_space.get_host_address(queue_config.desc_table).unwrap(); - queue_config.avail_ring = GuestAddress(16 * DEFAULT_VIRTQUEUE_SIZE as u64); - queue_config.addr_cache.avail_ring_host = - mem_space.get_host_address(queue_config.avail_ring).unwrap(); - queue_config.used_ring = GuestAddress(32 * DEFAULT_VIRTQUEUE_SIZE as u64); - queue_config.addr_cache.used_ring_host = - mem_space.get_host_address(queue_config.used_ring).unwrap(); + queue_config.addr_cache.desc_table_host = unsafe { + mem_space + .get_host_address(queue_config.desc_table, AddressAttr::Ram) + .unwrap() + }; + queue_config.avail_ring = GuestAddress(16 * u64::from(DEFAULT_VIRTQUEUE_SIZE)); + queue_config.addr_cache.avail_ring_host = unsafe { + mem_space + .get_host_address(queue_config.avail_ring, AddressAttr::Ram) + .unwrap() + }; + queue_config.used_ring = GuestAddress(32 * u64::from(DEFAULT_VIRTQUEUE_SIZE)); + queue_config.addr_cache.used_ring_host = unsafe { + mem_space + .get_host_address(queue_config.used_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.size = DEFAULT_VIRTQUEUE_SIZE; queue_config.ready = true; @@ -644,7 +738,7 @@ mod tests { }; // write table descriptor for queue mem_space - .write_object(&desc, queue_config.desc_table) + .write_object(&desc, queue_config.desc_table, AddressAttr::Ram) .unwrap(); let desc = SplitVringDesc { @@ -658,16 +752,25 @@ mod tests { .write_object( &desc, GuestAddress(queue_config.desc_table.0 + size_of::() as u64), + AddressAttr::Ram, ) .unwrap(); // write avail_ring idx mem_space - .write_object::(&0, GuestAddress(queue_config.avail_ring.0 + 4 as u64)) + .write_object::( + &0, + GuestAddress(queue_config.avail_ring.0 + 4_u64), + AddressAttr::Ram, + ) .unwrap(); // write avail_ring idx mem_space - .write_object::(&1, GuestAddress(queue_config.avail_ring.0 + 2 as u64)) + .write_object::( + &1, + GuestAddress(queue_config.avail_ring.0 + 2_u64), + AddressAttr::Ram, + ) .unwrap(); let mut buffer1 = vec![1_u8; data_len as usize]; @@ -683,7 +786,8 @@ mod tests { .read( &mut read_buffer.as_mut_slice(), GuestAddress(0x40000), - data_len as u64 + u64::from(data_len), + AddressAttr::Ram ) .is_ok()); assert_eq!(read_buffer, buffer1_check); @@ -691,13 +795,17 @@ mod tests { .read( &mut read_buffer.as_mut_slice(), GuestAddress(0x50000), - data_len as u64 + u64::from(data_len), + AddressAttr::Ram ) .is_ok()); assert_eq!(read_buffer, buffer2_check); let idx = mem_space - .read_object::(GuestAddress(queue_config.used_ring.0 + 2 as u64)) + .read_object::( + GuestAddress(queue_config.used_ring.0 + 2_u64), + AddressAttr::Ram, + ) .unwrap(); assert_eq!(idx, 1); assert_eq!(cloned_interrupt_evt.read().unwrap(), 1); diff --git a/virtio/src/device/scsi_cntlr.rs b/virtio/src/device/scsi_cntlr.rs index d3d6cf427fffcf79566ebb199d072c268a9fd3eb..1c03b471d175782306e3f3ebf49694e77e525ed0 100644 --- a/virtio/src/device/scsi_cntlr.rs +++ b/virtio/src/device/scsi_cntlr.rs @@ -17,32 +17,42 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use anyhow::{bail, Context, Result}; +use clap::Parser; use log::{error, info, warn}; use vmm_sys_util::{epoll::EventSet, eventfd::EventFd}; use crate::{ - check_config_space_rw, gpa_hva_iovec_map, iov_discard_front, iov_to_buf, read_config_default, - report_virtio_error, Element, Queue, VirtioBase, VirtioDevice, VirtioError, VirtioInterrupt, - VirtioInterruptType, VIRTIO_F_RING_EVENT_IDX, VIRTIO_F_RING_INDIRECT_DESC, VIRTIO_F_VERSION_1, - VIRTIO_TYPE_SCSI, + check_config_space_rw, gpa_hva_iovec_map, iov_discard_front, iov_read_object, + read_config_default, report_virtio_error, Element, Queue, VirtioBase, VirtioDevice, + VirtioError, VirtioInterrupt, VirtioInterruptType, VIRTIO_F_RING_EVENT_IDX, + VIRTIO_F_RING_INDIRECT_DESC, VIRTIO_F_VERSION_1, VIRTIO_TYPE_SCSI, }; -use address_space::{AddressSpace, GuestAddress}; +use address_space::{AddressAttr, AddressSpace, GuestAddress}; use block_backend::BlockIoErrorCallback; use devices::ScsiBus::{ ScsiBus, ScsiRequest, ScsiRequestOps, ScsiSense, ScsiXferMode, CHECK_CONDITION, EMULATE_SCSI_OPS, SCSI_CMD_BUF_SIZE, SCSI_SENSE_INVALID_OPCODE, }; -use machine_manager::event_loop::{register_event_helper, unregister_event_helper}; -use machine_manager::{ - config::{ScsiCntlrConfig, VIRTIO_SCSI_MAX_LUN, VIRTIO_SCSI_MAX_TARGET}, - event_loop::EventLoop, +use devices::ScsiDisk::ScsiDevice; +use devices::{convert_device_ref, Bus, SCSI_DEVICE}; +use machine_manager::config::{ + get_pci_df, parse_bool, valid_block_device_virtqueue_size, valid_id, MAX_VIRTIO_QUEUE, }; +use machine_manager::event_loop::{register_event_helper, unregister_event_helper, EventLoop}; use util::aio::Iovec; use util::byte_code::ByteCode; +use util::gen_base_func; use util::loop_context::{ read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, NotifierOperation, }; +/// According to Virtio Spec. +/// Max_channel should be 0. +/// Max_target should be less than or equal to 255. +const VIRTIO_SCSI_MAX_TARGET: u16 = 255; +/// Max_lun should be less than or equal to 16383 (2^14 - 1). +const VIRTIO_SCSI_MAX_LUN: u32 = 16383; + /// Virtio Scsi Controller has 1 ctrl queue, 1 event queue and at least 1 cmd queue. const SCSI_CTRL_QUEUE_NUM: usize = 1; const SCSI_EVENT_QUEUE_NUM: usize = 1; @@ -88,6 +98,27 @@ const VIRTIO_SCSI_S_BAD_TARGET: u8 = 3; /// with a response equal to VIRTIO_SCSI_S_FAILURE. const VIRTIO_SCSI_S_FAILURE: u8 = 9; +#[derive(Parser, Debug, Clone, Default)] +#[command(no_binary_name(true))] +pub struct ScsiCntlrConfig { + #[arg(long, value_parser = ["virtio-scsi-pci"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] + pub id: String, + #[arg(long)] + pub bus: String, + #[arg(long, value_parser = get_pci_df)] + pub addr: (u8, u8), + #[arg(long, value_parser = parse_bool)] + pub multifunction: Option, + #[arg(long, alias = "num-queues", value_parser = clap::value_parser!(u32).range(1..=MAX_VIRTIO_QUEUE as i64))] + pub num_queues: Option, + #[arg(long)] + pub iothread: Option, + #[arg(long, alias = "queue-size", default_value = "256", value_parser = valid_block_device_virtqueue_size)] + pub queue_size: u16, +} + #[repr(C, packed)] #[derive(Copy, Clone, Debug, Default)] struct VirtioScsiConfig { @@ -121,7 +152,8 @@ pub struct ScsiCntlr { impl ScsiCntlr { pub fn new(config: ScsiCntlrConfig) -> ScsiCntlr { // Note: config.queues <= MAX_VIRTIO_QUEUE(32). - let queue_num = config.queues as usize + SCSI_CTRL_QUEUE_NUM + SCSI_EVENT_QUEUE_NUM; + let queue_num = + config.num_queues.unwrap() as usize + SCSI_CTRL_QUEUE_NUM + SCSI_EVENT_QUEUE_NUM; let queue_size = config.queue_size; Self { @@ -141,13 +173,7 @@ impl ScsiCntlr { } impl VirtioDevice for ScsiCntlr { - fn virtio_base(&self) -> &VirtioBase { - &self.base - } - - fn virtio_base_mut(&mut self) -> &mut VirtioBase { - &mut self.base - } + gen_base_func!(virtio_base, virtio_base_mut, VirtioBase, base); fn realize(&mut self) -> Result<()> { // If iothread not found, return err. @@ -164,16 +190,15 @@ impl VirtioDevice for ScsiCntlr { } fn init_config_features(&mut self) -> Result<()> { - self.config_space.num_queues = self.config.queues; self.config_space.max_sectors = 0xFFFF_u32; // cmd_per_lun: maximum number of linked commands can be sent to one LUN. 32bit. self.config_space.cmd_per_lun = 128; // seg_max: queue size - 2, 32 bit. - self.config_space.seg_max = self.queue_size_max() as u32 - 2; + self.config_space.seg_max = u32::from(self.queue_size_max()) - 2; self.config_space.max_target = VIRTIO_SCSI_MAX_TARGET; - self.config_space.max_lun = VIRTIO_SCSI_MAX_LUN as u32; + self.config_space.max_lun = VIRTIO_SCSI_MAX_LUN; // num_queues: request queues number. - self.config_space.num_queues = self.config.queues; + self.config_space.num_queues = self.config.num_queues.unwrap(); self.base.device_features |= (1_u64 << VIRTIO_F_VERSION_1) | (1_u64 << VIRTIO_F_RING_EVENT_IDX) @@ -273,11 +298,11 @@ impl VirtioDevice for ScsiCntlr { // Register event notifier for device aio. let bus = self.bus.as_ref().unwrap(); let locked_bus = bus.lock().unwrap(); - for device in locked_bus.devices.values() { - let locked_device = device.lock().unwrap(); + for device in locked_bus.child_devices().values() { + SCSI_DEVICE!(device, locked_dev, scsi_dev); let err_cb = self.gen_error_cb(interrupt_cb.clone()); // SAFETY: the disk_image is assigned after device realized. - let disk_image = locked_device.block_backend.as_ref().unwrap(); + let disk_image = scsi_dev.block_backend.as_ref().unwrap(); let mut locked_backend = disk_image.lock().unwrap(); locked_backend.register_io_event(self.base.broken.clone(), err_cb)?; } @@ -291,10 +316,10 @@ impl VirtioDevice for ScsiCntlr { )?; let bus = self.bus.as_ref().unwrap(); let locked_bus = bus.lock().unwrap(); - for device in locked_bus.devices.values() { - let locked_dev = device.lock().unwrap(); + for device in locked_bus.child_devices().values() { + SCSI_DEVICE!(device, locked_dev, scsi_dev); // SAFETY: the disk_image is assigned after device realized. - let disk_image = locked_dev.block_backend.as_ref().unwrap(); + let disk_image = scsi_dev.block_backend.as_ref().unwrap(); let mut locked_backend = disk_image.lock().unwrap(); locked_backend.unregister_io_event()?; } @@ -456,43 +481,24 @@ impl VirtioScsiReq elem.desc_num ); } + let locked_queue = queue.lock().unwrap(); + let cache = locked_queue.vring.get_cache(); // Get request from virtqueue Element. - let mut req = T::default(); - iov_to_buf(mem_space, &elem.out_iovec, req.as_mut_bytes()).and_then(|size| { - if size < size_of::() { - bail!( - "Invalid length for request: get {}, expected {}", - size, - size_of::(), - ); - } - Ok(()) - })?; - + let req = iov_read_object::(mem_space, &elem.out_iovec, cache)?; // Get response from virtqueue Element. - let mut resp = U::default(); - iov_to_buf(mem_space, &elem.in_iovec, resp.as_mut_bytes()).and_then(|size| { - if size < size_of::() { - bail!( - "Invalid length for response: get {}, expected {}", - size, - size_of::(), - ); - } - Ok(()) - })?; + let resp = iov_read_object::(mem_space, &elem.in_iovec, cache)?; let mut request = VirtioScsiRequest { mem_space: mem_space.clone(), - queue, + queue: queue.clone(), desc_index: elem.index, iovec: Vec::with_capacity(elem.desc_num as usize), data_len: 0, mode: ScsiXferMode::ScsiXferNone, interrupt_cb, driver_features, - // Safety: in_iovec will not be empty since it has been checked after "iov_to_buf". + // Safety: in_iovec will not be empty since it has been checked after "iov_read_object". resp_addr: elem.in_iovec[0].addr, req, resp, @@ -501,12 +507,12 @@ impl VirtioScsiReq // Get possible dataout buffer from virtqueue Element. let mut iovec = elem.out_iovec.clone(); let elemiov = iov_discard_front(&mut iovec, size_of::() as u64).unwrap_or_default(); - let (out_len, out_iovec) = gpa_hva_iovec_map(elemiov, mem_space)?; + let (out_len, out_iovec) = gpa_hva_iovec_map(elemiov, mem_space, cache)?; // Get possible dataout buffer from virtqueue Element. let mut iovec = elem.in_iovec.clone(); let elemiov = iov_discard_front(&mut iovec, size_of::() as u64).unwrap_or_default(); - let (in_len, in_iovec) = gpa_hva_iovec_map(elemiov, mem_space)?; + let (in_len, in_iovec) = gpa_hva_iovec_map(elemiov, mem_space, cache)?; if out_len > 0 && in_len > 0 { warn!("Wrong scsi request! Don't support both datain and dataout buffer"); @@ -529,7 +535,7 @@ impl VirtioScsiReq fn complete(&self) -> Result<()> { self.mem_space - .write_object(&self.resp, self.resp_addr) + .write_object(&self.resp, self.resp_addr, AddressAttr::Ram) .with_context(|| "Failed to write the scsi response")?; let mut queue_lock = self.queue.lock().unwrap(); @@ -538,11 +544,7 @@ impl VirtioScsiReq // DESC_CHAIN_MAX_TOTAL_LEN(1 << 32). So, it will not overflow here. queue_lock .vring - .add_used( - &self.mem_space, - self.desc_index, - self.data_len + (size_of::() as u32), - ) + .add_used(self.desc_index, self.data_len + (size_of::() as u32)) .with_context(|| { format!( "Failed to add used ring(scsi completion), index {}, len {}", @@ -550,10 +552,7 @@ impl VirtioScsiReq ) })?; - if queue_lock - .vring - .should_notify(&self.mem_space, self.driver_features) - { + if queue_lock.vring.should_notify(self.driver_features) { (self.interrupt_cb)(&VirtioInterruptType::Vring, Some(&queue_lock), false) .with_context(|| { VirtioError::InterruptTrigger( @@ -611,11 +610,11 @@ impl ScsiCtrlQueueHandler { let ctrl_desc = elem .out_iovec - .get(0) + .first() .with_context(|| "Error request in ctrl queue. Empty dataout buf!")?; let ctrl_type = self .mem_space - .read_object::(ctrl_desc.addr) + .read_object::(ctrl_desc.addr, AddressAttr::Ram) .with_context(|| "Failed to get control queue descriptor")?; match ctrl_type { @@ -747,7 +746,7 @@ impl ScsiRequestOps for CmdQueueRequest { // | Byte 0 | Byte 1 | Byte 2 | Byte 3 | Byte 4 | Byte 5 | Byte 6 | Byte 7 | // | 1 | target | lun | 0 | fn virtio_scsi_get_lun_id(lun: [u8; 8]) -> u16 { - (((lun[2] as u16) << 8) | (lun[3] as u16)) & 0x3FFF + ((u16::from(lun[2]) << 8) | u16::from(lun[3])) & 0x3FFF } struct ScsiCmdQueueHandler { @@ -927,7 +926,7 @@ impl ScsiCmdQueueHandler { } let sreq = scsi_req.unwrap(); - if sreq.cmd.xfer > sreq.datalen && sreq.cmd.mode != ScsiXferMode::ScsiXferNone { + if sreq.cmd.xfer > u64::from(sreq.datalen) && sreq.cmd.mode != ScsiXferMode::ScsiXferNone { // Wrong virtio scsi request which doesn't provide enough datain/dataout buffer. qrequest.resp.response = VIRTIO_SCSI_S_OVERRUN; qrequest.complete()?; @@ -965,3 +964,55 @@ pub fn scsi_cntlr_create_scsi_bus( locked_scsi_cntlr.bus = Some(Arc::new(Mutex::new(bus))); Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use machine_manager::config::str_slip_to_clap; + + #[test] + fn test_scsi_cntlr_config_cmdline_parser() { + // Test1: Right. + let cmdline1 = "virtio-scsi-pci,id=scsi0,bus=pcie.0,addr=0x3,multifunction=on,iothread=iothread1,num-queues=3,queue-size=128"; + let device_cfg = + ScsiCntlrConfig::try_parse_from(str_slip_to_clap(cmdline1, true, false)).unwrap(); + assert_eq!(device_cfg.id, "scsi0"); + assert_eq!(device_cfg.bus, "pcie.0"); + assert_eq!(device_cfg.addr, (3, 0)); + assert_eq!(device_cfg.multifunction, Some(true)); + assert_eq!(device_cfg.iothread.unwrap(), "iothread1"); + assert_eq!(device_cfg.num_queues.unwrap(), 3); + assert_eq!(device_cfg.queue_size, 128); + + // Test2: Default value. + let cmdline2 = "virtio-scsi-pci,id=scsi0,bus=pcie.0,addr=0x3.0x1"; + let device_cfg = + ScsiCntlrConfig::try_parse_from(str_slip_to_clap(cmdline2, true, false)).unwrap(); + assert_eq!(device_cfg.addr, (3, 1)); + assert_eq!(device_cfg.multifunction, None); + assert_eq!(device_cfg.num_queues, None); + assert_eq!(device_cfg.queue_size, 256); + + // Test3: Illegal value. + let cmdline3 = "virtio-scsi-pci,id=scsi0,bus=pcie.0,addr=0x3.0x1,num-queues=33"; + let result = ScsiCntlrConfig::try_parse_from(str_slip_to_clap(cmdline3, true, false)); + assert!(result.is_err()); + let cmdline3 = "virtio-scsi-pci,id=scsi0,bus=pcie.0,addr=0x3.0x1,queue-size=1025"; + let result = ScsiCntlrConfig::try_parse_from(str_slip_to_clap(cmdline3, true, false)); + assert!(result.is_err()); + let cmdline3 = "virtio-scsi-pci,id=scsi0,bus=pcie.0,addr=0x3.0x1,queue-size=65"; + let result = ScsiCntlrConfig::try_parse_from(str_slip_to_clap(cmdline3, true, false)); + assert!(result.is_err()); + + // Test4: Missing necessary parameters. + let cmdline4 = "virtio-scsi-pci,id=scsi0"; + let result = ScsiCntlrConfig::try_parse_from(str_slip_to_clap(cmdline4, true, false)); + assert!(result.is_err()); + let cmdline4 = "virtio-scsi-pci,bus=pcie.0,addr=0x3.0x1"; + let result = ScsiCntlrConfig::try_parse_from(str_slip_to_clap(cmdline4, true, false)); + assert!(result.is_err()); + let cmdline4 = "virtio-scsi-pci,id=scsi0,addr=0x3.0x1"; + let result = ScsiCntlrConfig::try_parse_from(str_slip_to_clap(cmdline4, true, false)); + assert!(result.is_err()); + } +} diff --git a/virtio/src/device/serial.rs b/virtio/src/device/serial.rs index 5bf1c101f1c753dedf6d94f5dec7d8503c0f835a..3141672bb910b298082c69b87ac324bcdc0ea7f4 100644 --- a/virtio/src/device/serial.rs +++ b/virtio/src/device/serial.rs @@ -20,18 +20,19 @@ use std::{cmp, usize}; use anyhow::{anyhow, bail, Context, Result}; use byteorder::{ByteOrder, LittleEndian}; use log::{error, info, warn}; +use machine_manager::config::ChardevConfig; use vmm_sys_util::epoll::EventSet; use vmm_sys_util::eventfd::EventFd; use crate::{ - gpa_hva_iovec_map, iov_discard_front, iov_to_buf, read_config_default, report_virtio_error, + gpa_hva_iovec_map, iov_read_object, iov_to_buf, read_config_default, report_virtio_error, Element, Queue, VirtioBase, VirtioDevice, VirtioError, VirtioInterrupt, VirtioInterruptType, VIRTIO_CONSOLE_F_MULTIPORT, VIRTIO_CONSOLE_F_SIZE, VIRTIO_F_VERSION_1, VIRTIO_TYPE_CONSOLE, }; -use address_space::AddressSpace; +use address_space::{AddressAttr, AddressSpace}; use chardev_backend::chardev::{Chardev, ChardevNotifyDevice, ChardevStatus, InputReceiver}; use machine_manager::{ - config::{ChardevType, VirtioSerialInfo, VirtioSerialPort, DEFAULT_VIRTQUEUE_SIZE}, + config::{ChardevType, VirtioSerialInfo, VirtioSerialPortCfg, DEFAULT_VIRTQUEUE_SIZE}, event_loop::EventLoop, event_loop::{register_event_helper, unregister_event_helper}, }; @@ -39,6 +40,7 @@ use migration::{DeviceStateDesc, FieldDesc, MigrationHook, MigrationManager, Sta use migration_derive::{ByteCode, Desc}; use util::aio::iov_from_buf_direct; use util::byte_code::ByteCode; +use util::gen_base_func; use util::loop_context::{ read_fd, EventNotifier, EventNotifierHelper, NotifierCallback, NotifierOperation, }; @@ -191,7 +193,7 @@ impl Serial { } pub fn get_max_nr(ports: &Arc>>>>) -> u32 { - let mut max = 0; + let mut max: u32 = 0; for port in ports.lock().unwrap().iter() { let nr = port.lock().unwrap().nr; if nr > max { @@ -214,13 +216,7 @@ pub fn find_port_by_nr( } impl VirtioDevice for Serial { - fn virtio_base(&self) -> &VirtioBase { - &self.base - } - - fn virtio_base_mut(&mut self) -> &mut VirtioBase { - &mut self.base - } + gen_base_func!(virtio_base, virtio_base_mut, VirtioBase, base); fn realize(&mut self) -> Result<()> { self.init_config_features()?; @@ -354,17 +350,21 @@ pub struct SerialPort { } impl SerialPort { - pub fn new(port_cfg: VirtioSerialPort) -> Self { + pub fn new(port_cfg: VirtioSerialPortCfg, chardev_cfg: ChardevConfig) -> Self { // Console is default host connected. And pty chardev has opened by default in realize() // function. - let host_connected = port_cfg.is_console || port_cfg.chardev.backend == ChardevType::Pty; + let is_console = matches!(port_cfg.classtype.as_str(), "virtconsole"); + let mut host_connected = is_console; + if let ChardevType::Pty { .. } = chardev_cfg.classtype { + host_connected = true; + } SerialPort { name: Some(port_cfg.id), paused: false, - chardev: Arc::new(Mutex::new(Chardev::new(port_cfg.chardev))), - nr: port_cfg.nr, - is_console: port_cfg.is_console, + chardev: Arc::new(Mutex::new(Chardev::new(chardev_cfg))), + nr: port_cfg.nr.unwrap(), + is_console, guest_connected: false, host_connected, ctrl_handler: None, @@ -455,6 +455,14 @@ impl SerialPortHandler { let mut queue_lock = self.output_queue.lock().unwrap(); loop { + if let Some(port) = self.port.as_ref() { + let locked_port = port.lock().unwrap(); + let locked_cdev = locked_port.chardev.lock().unwrap(); + if locked_cdev.outbuf_is_full() { + break; + } + } + let elem = queue_lock .vring .pop_avail(&self.mem_space, self.driver_features)?; @@ -464,40 +472,36 @@ impl SerialPortHandler { // Discard requests when there is no port using this queue. Popping elements without // processing means discarding the request. - if self.port.is_some() { - let mut iovec = elem.out_iovec; - let mut iovec_size = Element::iovec_size(&iovec); - while iovec_size > 0 { - let mut buffer = [0_u8; BUF_SIZE]; - let size = iov_to_buf(&self.mem_space, &iovec, &mut buffer)? as u64; - - self.write_chardev_msg(&buffer, size as usize); - - iovec = iov_discard_front(&mut iovec, size) - .unwrap_or_default() - .to_vec(); - // Safety: iovec follows the iov_discard_front operation and - // iovec_size always equals Element::iovec_size(&iovec). - iovec_size -= size; - trace::virtio_serial_output_data(iovec_size, size); + if let Some(port) = self.port.as_ref() { + let iovec = elem.out_iovec; + let iovec_size = Element::iovec_size(&iovec); + let mut buf = vec![0u8; iovec_size as usize]; + let cache = queue_lock.vring.get_cache(); + let size = iov_to_buf(&self.mem_space, cache, &iovec, &mut buf[..])? as u64; + + let locked_port = port.lock().unwrap(); + if locked_port.host_connected { + if let Err(e) = locked_port + .chardev + .lock() + .unwrap() + .fill_outbuf(buf, Some(self.output_queue_evt.clone())) + { + error!("Failed to append elem buffer to chardev with error {:?}", e); + } } + trace::virtio_serial_output_data(iovec_size, size); } - queue_lock - .vring - .add_used(&self.mem_space, elem.index, 0) - .with_context(|| { - format!( - "Failed to add used ring for virtio serial port output, index: {} len: {}", - elem.index, 0, - ) - })?; + queue_lock.vring.add_used(elem.index, 0).with_context(|| { + format!( + "Failed to add used ring for virtio serial port output, index: {} len: {}", + elem.index, 0, + ) + })?; } - if queue_lock - .vring - .should_notify(&self.mem_space, self.driver_features) - { + if queue_lock.vring.should_notify(self.driver_features) { (self.interrupt_cb)(&VirtioInterruptType::Vring, Some(&queue_lock), false) .with_context(|| { VirtioError::InterruptTrigger( @@ -511,30 +515,6 @@ impl SerialPortHandler { Ok(()) } - fn write_chardev_msg(&self, buffer: &[u8], write_len: usize) { - let port_locked = self.port.as_ref().unwrap().lock().unwrap(); - // Discard output buffer if this port's chardev is not connected. - if !port_locked.host_connected { - return; - } - - if let Some(output) = &mut port_locked.chardev.lock().unwrap().output { - let mut locked_output = output.lock().unwrap(); - // To do: - // If the buffer is not fully written to chardev, the incomplete part will be discarded. - // This may occur when chardev is abnormal. Consider optimizing this logic in the - // future. - if let Err(e) = locked_output.write_all(&buffer[..write_len]) { - error!("Port {} failed to write msg to chardev: {:?}", self.nr, e); - } - if let Err(e) = locked_output.flush() { - error!("Port {} failed to flush msg to chardev: {:?}", self.nr, e); - } - } else { - error!("Port {} failed to get output fd", self.nr); - }; - } - fn get_input_avail_bytes(&mut self, max_size: usize) -> usize { let port = self.port.as_ref(); if port.is_none() || !port.unwrap().lock().unwrap().guest_connected { @@ -569,10 +549,9 @@ impl SerialPortHandler { } let mut queue_lock = self.input_queue.lock().unwrap(); - let _ = - queue_lock - .vring - .suppress_queue_notify(&self.mem_space, self.driver_features, !enable); + let _ = queue_lock + .vring + .suppress_queue_notify(self.driver_features, !enable); } fn input_handle_internal(&mut self, buffer: &[u8]) -> Result<()> { @@ -603,8 +582,14 @@ impl SerialPortHandler { let write_end = written_count + len; let mut source_slice = &buffer[written_count..write_end]; + // GPAChecked: the elem_iov has been checked in pop_avail(). self.mem_space - .write(&mut source_slice, elem_iov.addr, len as u64) + .write( + &mut source_slice, + elem_iov.addr, + len as u64, + AddressAttr::Ram, + ) .with_context(|| { format!( "Failed to write slice for virtio serial port input: addr {:X} len {}", @@ -622,7 +607,7 @@ impl SerialPortHandler { queue_lock .vring - .add_used(&self.mem_space, elem.index, once_count as u32) + .add_used(elem.index, once_count as u32) .with_context(|| { format!( "Failed to add used ring for virtio serial port input: index {} len {}", @@ -630,10 +615,7 @@ impl SerialPortHandler { ) })?; - if queue_lock - .vring - .should_notify(&self.mem_space, self.driver_features) - { + if queue_lock.vring.should_notify(self.driver_features) { (self.interrupt_cb)(&VirtioInterruptType::Vring, Some(&queue_lock), false) .with_context(|| { VirtioError::InterruptTrigger( @@ -754,17 +736,11 @@ impl SerialControlHandler { break; } - let mut req = VirtioConsoleControl::default(); - iov_to_buf(&self.mem_space, &elem.out_iovec, req.as_mut_bytes()).and_then(|size| { - if size < size_of::() { - bail!( - "Invalid length for request: get {}, expected {}", - size, - size_of::(), - ); - } - Ok(()) - })?; + let mut req = iov_read_object::( + &self.mem_space, + &elem.out_iovec, + queue_lock.vring.get_cache(), + )?; req.id = LittleEndian::read_u32(req.id.as_bytes()); req.event = LittleEndian::read_u16(req.event.as_bytes()); req.value = LittleEndian::read_u16(req.value.as_bytes()); @@ -775,21 +751,15 @@ impl SerialControlHandler { ); self.handle_control_message(&mut req); - queue_lock - .vring - .add_used(&self.mem_space, elem.index, 0) - .with_context(|| { - format!( - "Failed to add used ring for control port, index: {} len: {}.", - elem.index, 0 - ) - })?; + queue_lock.vring.add_used(elem.index, 0).with_context(|| { + format!( + "Failed to add used ring for control port, index: {} len: {}.", + elem.index, 0 + ) + })?; } - if queue_lock - .vring - .should_notify(&self.mem_space, self.driver_features) - { + if queue_lock.vring.should_notify(self.driver_features) { (self.interrupt_cb)(&VirtioInterruptType::Vring, Some(&queue_lock), false) .with_context(|| { VirtioError::InterruptTrigger( @@ -904,7 +874,8 @@ impl SerialControlHandler { return Ok(()); } - let (in_size, ctrl_vec) = gpa_hva_iovec_map(&elem.in_iovec, &self.mem_space)?; + let cache = queue_lock.vring.get_cache(); + let (in_size, ctrl_vec) = gpa_hva_iovec_map(&elem.in_iovec, &self.mem_space, cache)?; let len = size_of::() + extra.len(); if in_size < len as u64 { bail!( @@ -921,7 +892,8 @@ impl SerialControlHandler { msg_data.extend(extra); } - iov_from_buf_direct(&ctrl_vec, &msg_data).and_then(|size| { + // SAFETY: ctrl_vec is generated by address_space. + unsafe { iov_from_buf_direct(&ctrl_vec, &msg_data) }.and_then(|size| { if size != len { bail!( "Expected send msg length is {}, actual send length {}.", @@ -934,7 +906,7 @@ impl SerialControlHandler { queue_lock .vring - .add_used(&self.mem_space, elem.index, len as u32) + .add_used(elem.index, len as u32) .with_context(|| { format!( "Failed to add used ring(serial input control queue), index {}, len {}", @@ -942,10 +914,7 @@ impl SerialControlHandler { ) })?; - if queue_lock - .vring - .should_notify(&self.mem_space, self.driver_features) - { + if queue_lock.vring.should_notify(self.driver_features) { (self.interrupt_cb)(&VirtioInterruptType::Vring, Some(&queue_lock), false) .with_context(|| { VirtioError::InterruptTrigger( @@ -1018,20 +987,17 @@ impl ChardevNotifyDevice for SerialPort { #[cfg(test)] mod tests { - pub use super::*; - - use machine_manager::config::PciBdf; + use super::*; #[test] fn test_set_driver_features() { let mut serial = Serial::new(VirtioSerialInfo { + classtype: "virtio-serial-pci".to_string(), id: "serial".to_string(), - pci_bdf: Some(PciBdf { - bus: "pcie.0".to_string(), - addr: (0, 0), - }), - multifunction: false, + multifunction: Some(false), max_ports: 31, + bus: Some("pcie.0".to_string()), + addr: Some((0, 0)), }); // If the device feature is 0, all driver features are not supported. @@ -1040,13 +1006,13 @@ mod tests { let page = 0_u32; serial.set_driver_features(page, driver_feature); assert_eq!(serial.base.driver_features, 0_u64); - assert_eq!(serial.driver_features(page) as u64, 0_u64); + assert_eq!(u64::from(serial.driver_features(page)), 0_u64); let driver_feature: u32 = 0xFF; let page = 1_u32; serial.set_driver_features(page, driver_feature); assert_eq!(serial.base.driver_features, 0_u64); - assert_eq!(serial.driver_features(page) as u64, 0_u64); + assert_eq!(u64::from(serial.driver_features(page)), 0_u64); // If both the device feature bit and the front-end driver feature bit are // supported at the same time, this driver feature bit is supported. @@ -1059,7 +1025,7 @@ mod tests { (1_u64 << VIRTIO_CONSOLE_F_SIZE) ); assert_eq!( - serial.driver_features(page) as u64, + u64::from(serial.driver_features(page)), (1_u64 << VIRTIO_CONSOLE_F_SIZE) ); serial.base.driver_features = 0; @@ -1094,25 +1060,24 @@ mod tests { fn test_read_config() { let max_ports: u8 = 31; let serial = Serial::new(VirtioSerialInfo { + classtype: "virtio-serial-pci".to_string(), id: "serial".to_string(), - pci_bdf: Some(PciBdf { - bus: "pcie.0".to_string(), - addr: (0, 0), - }), - multifunction: false, - max_ports: max_ports as u32, + multifunction: Some(false), + max_ports: u32::from(max_ports), + bus: Some("pcie.0".to_string()), + addr: Some((0, 0)), }); // The offset of configuration that needs to be read exceeds the maximum. let offset = size_of::() as u64; let mut read_data: Vec = vec![0; 8]; - assert_eq!(serial.read_config(offset, &mut read_data).is_ok(), false); + assert!(serial.read_config(offset, &mut read_data).is_err()); // Check the configuration that needs to be read. let offset = 0_u64; let mut read_data: Vec = vec![0; 12]; let expect_data: Vec = vec![0, 0, 0, 0, max_ports, 0, 0, 0, 0, 0, 0, 0]; - assert_eq!(serial.read_config(offset, &mut read_data).is_ok(), true); + assert!(serial.read_config(offset, &mut read_data).is_ok()); assert_eq!(read_data, expect_data); } } diff --git a/virtio/src/lib.rs b/virtio/src/lib.rs index 9ef9dde8b6496d03be6baa65db4cd9dc0a1c2ce6..2b48c9fb5f291aae5c4b1626ae743581413a036e 100644 --- a/virtio/src/lib.rs +++ b/virtio/src/lib.rs @@ -33,11 +33,14 @@ mod queue; mod transport; pub use device::balloon::*; -pub use device::block::{Block, BlockState, VirtioBlkConfig}; +pub use device::block::{Block, BlockState, VirtioBlkConfig, VirtioBlkDevConfig}; #[cfg(feature = "virtio_gpu")] pub use device::gpu::*; +pub use device::input::*; pub use device::net::*; -pub use device::rng::{Rng, RngState}; +#[cfg(feature = "virtio_rng")] +pub use device::rng::{Rng, RngConfig, RngState}; +#[cfg(feature = "virtio_scsi")] pub use device::scsi_cntlr as ScsiCntlr; pub use device::serial::{find_port_by_nr, get_max_nr, Serial, SerialPort, VirtioSerialState}; pub use error::VirtioError; @@ -49,6 +52,7 @@ pub use vhost::user as VhostUser; use std::cmp; use std::io::Write; +use std::mem::size_of; use std::os::unix::prelude::RawFd; use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU8, Ordering}; use std::sync::{Arc, Mutex}; @@ -57,10 +61,13 @@ use anyhow::{anyhow, bail, Context, Result}; use log::{error, warn}; use vmm_sys_util::eventfd::EventFd; -use address_space::AddressSpace; +use address_space::{AddressSpace, RegionCache}; +use devices::pci::register_pcidevops_type; +use devices::sysbus::register_sysbusdevops_type; use machine_manager::config::ConfigCheck; use migration_derive::ByteCode; use util::aio::{mem_to_buf, Iovec}; +use util::byte_code::ByteCode; use util::num_ops::{read_u32, write_u32}; use util::AsAny; @@ -77,6 +84,7 @@ pub const VIRTIO_TYPE_RNG: u32 = 4; pub const VIRTIO_TYPE_BALLOON: u32 = 5; pub const VIRTIO_TYPE_SCSI: u32 = 8; pub const VIRTIO_TYPE_GPU: u32 = 16; +pub const VIRTIO_TYPE_INPUT: u32 = 18; pub const VIRTIO_TYPE_VSOCK: u32 = 19; pub const VIRTIO_TYPE_FS: u32 = 26; @@ -543,9 +551,9 @@ pub trait VirtioDevice: Send + AsAny { } let features = if page == 0 { - (self.driver_features(1) as u64) << 32 | (v as u64) + u64::from(self.driver_features(1)) << 32 | u64::from(v) } else { - (v as u64) << 32 | (self.driver_features(0) as u64) + u64::from(v) << 32 | u64::from(self.driver_features(0)) }; self.virtio_base_mut().driver_features = features; } @@ -729,8 +737,9 @@ pub trait VirtioDevice: Send + AsAny { /// /// # Arguments /// - /// * `_file_path` - The related backend file path. - fn update_config(&mut self, _dev_config: Option>) -> Result<()> { + /// * `_configs` - The related configs for device. + /// eg: DriveConfig and VirtioBlkDevConfig for virtio blk device. + fn update_config(&mut self, _configs: Vec>) -> Result<()> { bail!("Unsupported to update configuration") } @@ -787,17 +796,42 @@ pub fn report_virtio_error( broken.store(true, Ordering::SeqCst); } +/// Read object typed `T` from iovec. +pub fn iov_read_object( + mem_space: &Arc, + iovec: &[ElemIovec], + cache: &Option, +) -> Result { + let mut obj = T::default(); + let count = iov_to_buf(mem_space, cache, iovec, obj.as_mut_bytes())?; + let size = size_of::(); + if count < size { + bail!("Read length error: expected {}, read {}.", size, count); + } + Ok(obj) +} + /// Read iovec to buf and return the read number of bytes. -pub fn iov_to_buf(mem_space: &AddressSpace, iovec: &[ElemIovec], buf: &mut [u8]) -> Result { +pub fn iov_to_buf( + mem_space: &AddressSpace, + cache: &Option, + iovec: &[ElemIovec], + buf: &mut [u8], +) -> Result { let mut start: usize = 0; let mut end: usize = 0; + // Note: iovec is part of elem.in_iovec/out_iovec which has been checked + // in pop_avail(). The sum of iov_len is not greater than u32::MAX. for iov in iovec { let mut addr_map = Vec::new(); - mem_space.get_address_map(iov.addr, iov.len as u64, &mut addr_map)?; + mem_space.get_address_map(cache, iov.addr, u64::from(iov.len), &mut addr_map)?; for addr in addr_map.into_iter() { end = cmp::min(start + addr.iov_len as usize, buf.len()); - mem_to_buf(&mut buf[start..end], addr.iov_base)?; + // SAFETY: addr_map is generated by address_space and len is not less than buf's. + unsafe { + mem_to_buf(&mut buf[start..end], addr.iov_base)?; + } if end >= buf.len() { return Ok(end); } @@ -810,12 +844,12 @@ pub fn iov_to_buf(mem_space: &AddressSpace, iovec: &[ElemIovec], buf: &mut [u8]) /// Discard "size" bytes of the front of iovec. pub fn iov_discard_front(iovec: &mut [ElemIovec], mut size: u64) -> Option<&mut [ElemIovec]> { for (index, iov) in iovec.iter_mut().enumerate() { - if iov.len as u64 > size { + if u64::from(iov.len) > size { iov.addr.0 += size; iov.len -= size as u32; return Some(&mut iovec[index..]); } - size -= iov.len as u64; + size -= u64::from(iov.len); } None } @@ -824,11 +858,11 @@ pub fn iov_discard_front(iovec: &mut [ElemIovec], mut size: u64) -> Option<&mut pub fn iov_discard_back(iovec: &mut [ElemIovec], mut size: u64) -> Option<&mut [ElemIovec]> { let len = iovec.len(); for (index, iov) in iovec.iter_mut().rev().enumerate() { - if iov.len as u64 > size { + if u64::from(iov.len) > size { iov.len -= size as u32; return Some(&mut iovec[..(len - index)]); } - size -= iov.len as u64; + size -= u64::from(iov.len); } None } @@ -838,14 +872,85 @@ pub fn iov_discard_back(iovec: &mut [ElemIovec], mut size: u64) -> Option<&mut [ fn gpa_hva_iovec_map( gpa_elemiovec: &[ElemIovec], mem_space: &AddressSpace, + cache: &Option, ) -> Result<(u64, Vec)> { - let mut iov_size = 0; + let mut iov_size: u64 = 0; let mut hva_iovec = Vec::with_capacity(gpa_elemiovec.len()); + // Note: gpa_elemiovec is part of elem.in_iovec/out_iovec which has been checked + // in pop_avail(). The sum of iov_len is not greater than u32::MAX. for elem in gpa_elemiovec.iter() { - mem_space.get_address_map(elem.addr, elem.len as u64, &mut hva_iovec)?; - iov_size += elem.len as u64; + mem_space.get_address_map(cache, elem.addr, u64::from(elem.len), &mut hva_iovec)?; + iov_size += u64::from(elem.len); } Ok((iov_size, hva_iovec)) } + +pub fn virtio_register_sysbusdevops_type() -> Result<()> { + register_sysbusdevops_type::() +} + +pub fn virtio_register_pcidevops_type() -> Result<()> { + register_pcidevops_type::() +} + +#[cfg(test)] +mod tests { + use std::sync::{Arc, Mutex}; + + use address_space::{AddressSpace, GuestAddress, HostMemMapping, Region}; + use devices::sysbus::{SysBus, IRQ_BASE, IRQ_MAX}; + + pub const MEMORY_SIZE: u64 = 1024 * 1024; + + pub fn sysbus_init() -> Arc> { + let sys_mem = AddressSpace::new( + Region::init_container_region(u64::max_value(), "sys_mem"), + "sys_mem", + None, + ) + .unwrap(); + #[cfg(target_arch = "x86_64")] + let sys_io = AddressSpace::new( + Region::init_container_region(1 << 16, "sys_io"), + "sys_io", + None, + ) + .unwrap(); + let free_irqs: (i32, i32) = (IRQ_BASE, IRQ_MAX); + let mmio_region: (u64, u64) = (0x0A00_0000, 0x1000_0000); + Arc::new(Mutex::new(SysBus::new( + #[cfg(target_arch = "x86_64")] + &sys_io, + &sys_mem, + free_irqs, + mmio_region, + ))) + } + + pub fn address_space_init() -> Arc { + let root = Region::init_container_region(1 << 36, "root"); + let sys_space = AddressSpace::new(root, "sys_space", None).unwrap(); + let host_mmap = Arc::new( + HostMemMapping::new( + GuestAddress(0), + None, + MEMORY_SIZE, + None, + false, + false, + false, + ) + .unwrap(), + ); + sys_space + .root() + .add_subregion( + Region::init_ram_region(host_mmap.clone(), "region_1"), + host_mmap.start_address().raw_value(), + ) + .unwrap(); + sys_space + } +} diff --git a/virtio/src/queue/mod.rs b/virtio/src/queue/mod.rs index 7f581a0441d02a5b562ef45b34819d145be81cc5..44f164731ced5c54480a5031a5bb0fe543099ed1 100644 --- a/virtio/src/queue/mod.rs +++ b/virtio/src/queue/mod.rs @@ -21,6 +21,7 @@ use vmm_sys_util::eventfd::EventFd; use address_space::{AddressSpace, GuestAddress, RegionCache}; use machine_manager::config::DEFAULT_VIRTQUEUE_SIZE; +use util::loop_context::create_new_eventfd; /// Split Virtqueue. pub const QUEUE_TYPE_SPLIT_VRING: u16 = 1; @@ -90,7 +91,9 @@ impl Element { pub fn iovec_size(iovec: &[ElemIovec]) -> u64 { let mut size: u64 = 0; for elem in iovec.iter() { - size += elem.len as u64; + // Note: iovec is part of elem.in_iovec/out_iovec which has been checked + // in pop_avail(). The sum of iov_len is not greater than u32::MAX. + size += u64::from(elem.len); } size } @@ -123,32 +126,24 @@ pub trait VringOps { /// /// # Arguments /// - /// * `sys_mem` - Address space to which the vring belongs. /// * `index` - Index of descriptor in the virqueue descriptor table. /// * `len` - Total length of the descriptor chain which was used (written to). - fn add_used(&mut self, sys_mem: &Arc, index: u16, len: u32) -> Result<()>; + fn add_used(&mut self, index: u16, len: u32) -> Result<()>; /// Return true if guest needed to be notified. /// /// # Arguments /// - /// * `sys_mem` - Address space to which the vring belongs. /// * `features` - Bit mask of features negotiated by the backend and the frontend. - fn should_notify(&mut self, sys_mem: &Arc, features: u64) -> bool; + fn should_notify(&mut self, features: u64) -> bool; /// Give guest a hint to suppress virtqueue notification. /// /// # Arguments /// - /// * `sys_mem` - Address space to which the vring belongs. /// * `features` - Bit mask of features negotiated by the backend and the frontend. /// * `suppress` - Suppress virtqueue notification or not. - fn suppress_queue_notify( - &mut self, - sys_mem: &Arc, - features: u64, - suppress: bool, - ) -> Result<()>; + fn suppress_queue_notify(&mut self, features: u64, suppress: bool) -> Result<()>; /// Get the actual size of the vring. fn actual_size(&self) -> u16; @@ -157,13 +152,13 @@ pub trait VringOps { fn get_queue_config(&self) -> QueueConfig; /// The number of descriptor chains in the available ring. - fn avail_ring_len(&mut self, sys_mem: &Arc) -> Result; + fn avail_ring_len(&mut self) -> Result; /// Get the avail index of the vring. - fn get_avail_idx(&self, sys_mem: &Arc) -> Result; + fn get_avail_idx(&self) -> Result; /// Get the used index of the vring. - fn get_used_idx(&self, sys_mem: &Arc) -> Result; + fn get_used_idx(&self) -> Result; /// Get the region cache information of the SplitVring. fn get_cache(&self) -> &Option; @@ -226,7 +221,7 @@ impl NotifyEventFds { pub fn new(queue_num: usize) -> Self { let mut events = Vec::new(); for _i in 0..queue_num { - events.push(Arc::new(EventFd::new(libc::EFD_NONBLOCK).unwrap())); + events.push(Arc::new(create_new_eventfd().unwrap())); } NotifyEventFds { events } diff --git a/virtio/src/queue/split.rs b/virtio/src/queue/split.rs index 31b2fe306c88ea54db614d40e971c343e54e256f..30a906ffc827a672762620acbeb8fa3068c52327 100644 --- a/virtio/src/queue/split.rs +++ b/virtio/src/queue/split.rs @@ -11,6 +11,7 @@ // See the Mulan PSL v2 for more details. use std::cmp::min; +use std::io::Write; use std::mem::size_of; use std::num::Wrapping; use std::ops::{Deref, DerefMut}; @@ -27,7 +28,8 @@ use super::{ use crate::{ report_virtio_error, virtio_has_feature, VirtioError, VirtioInterrupt, VIRTIO_F_RING_EVENT_IDX, }; -use address_space::{AddressSpace, GuestAddress, RegionCache, RegionType}; +use address_space::{AddressAttr, AddressSpace, GuestAddress, RegionCache, RegionType}; +use migration::{migration::Migratable, MigrationManager}; use util::byte_code::ByteCode; /// When host consumes a buffer, don't interrupt the guest. @@ -52,6 +54,51 @@ const VRING_IDX_POSITION: u64 = size_of::() as u64; /// The length of virtio descriptor. const DESCRIPTOR_LEN: u64 = size_of::() as u64; +/// Read some data from memory to form an object via host address. +/// +/// # Arguments +/// +/// * `hoat_addr` - The start host address where the data will be read from. +/// +/// # Safety +/// +/// Make true that host_addr and std::mem::size_of::() are in the range of ram. +/// +/// # Note +/// To use this method, it is necessary to implement `ByteCode` trait for your object. +unsafe fn read_object_direct(host_addr: u64) -> Result { + trace::virtio_read_object_direct(host_addr, std::mem::size_of::()); + let mut obj = T::default(); + let mut dst = obj.as_mut_bytes(); + let src = std::slice::from_raw_parts_mut(host_addr as *mut u8, std::mem::size_of::()); + dst.write_all(src) + .with_context(|| "Failed to read object via host address")?; + + Ok(obj) +} + +/// Write an object to memory via host address. +/// +/// # Arguments +/// +/// * `data` - The object that will be written to the memory. +/// * `host_addr` - The start host address where the object will be written to. +/// +/// # Safety +/// +/// Make true that host_addr and std::mem::size_of::() are in the range of ram. +/// +/// # Note +/// To use this method, it is necessary to implement `ByteCode` trait for your object. +unsafe fn write_object_direct(data: &T, host_addr: u64) -> Result<()> { + trace::virtio_write_object_direct(host_addr, std::mem::size_of::()); + // Mark vmm dirty page manually if live migration is active. + MigrationManager::mark_dirty_log(host_addr, data.as_bytes().len() as u64); + let mut dst = std::slice::from_raw_parts_mut(host_addr as *mut u8, std::mem::size_of::()); + dst.write_all(data.as_bytes()) + .with_context(|| "Failed to write object via host address") +} + #[derive(Default, Clone, Copy)] pub struct VirtioAddrCache { /// Host virtual address of the descriptor table. @@ -116,7 +163,7 @@ impl QueueConfig { } fn get_desc_size(&self) -> u64 { - min(self.size, self.max_size) as u64 * DESCRIPTOR_LEN + u64::from(min(self.size, self.max_size)) * DESCRIPTOR_LEN } fn get_used_size(&self, features: u64) -> u64 { @@ -126,7 +173,7 @@ impl QueueConfig { 0_u64 }; - size + VRING_FLAGS_AND_IDX_LEN + (min(self.size, self.max_size) as u64) * USEDELEM_LEN + size + VRING_FLAGS_AND_IDX_LEN + u64::from(min(self.size, self.max_size)) * USEDELEM_LEN } fn get_avail_size(&self, features: u64) -> u64 { @@ -137,7 +184,7 @@ impl QueueConfig { }; size + VRING_FLAGS_AND_IDX_LEN - + (min(self.size, self.max_size) as u64) * (size_of::() as u64) + + u64::from(min(self.size, self.max_size)) * (size_of::() as u64) } pub fn reset(&mut self) { @@ -151,41 +198,44 @@ impl QueueConfig { features: u64, broken: &Arc, ) { - self.addr_cache.desc_table_host = - if let Some((addr, size)) = mem_space.addr_cache_init(self.desc_table) { - if size < self.get_desc_size() { - report_virtio_error(interrupt_cb.clone(), features, broken); - 0_u64 - } else { - addr - } - } else { + self.addr_cache.desc_table_host = if let Some((addr, size)) = + mem_space.addr_cache_init(self.desc_table, AddressAttr::Ram) + { + if size < self.get_desc_size() { + report_virtio_error(interrupt_cb.clone(), features, broken); 0_u64 - }; - - self.addr_cache.avail_ring_host = - if let Some((addr, size)) = mem_space.addr_cache_init(self.avail_ring) { - if size < self.get_avail_size(features) { - report_virtio_error(interrupt_cb.clone(), features, broken); - 0_u64 - } else { - addr - } } else { - 0_u64 - }; + addr + } + } else { + 0_u64 + }; - self.addr_cache.used_ring_host = - if let Some((addr, size)) = mem_space.addr_cache_init(self.used_ring) { - if size < self.get_used_size(features) { - report_virtio_error(interrupt_cb.clone(), features, broken); - 0_u64 - } else { - addr - } + self.addr_cache.avail_ring_host = if let Some((addr, size)) = + mem_space.addr_cache_init(self.avail_ring, AddressAttr::Ram) + { + if size < self.get_avail_size(features) { + report_virtio_error(interrupt_cb.clone(), features, broken); + 0_u64 } else { + addr + } + } else { + 0_u64 + }; + + self.addr_cache.used_ring_host = if let Some((addr, size)) = + mem_space.addr_cache_init(self.used_ring, AddressAttr::Ram) + { + if size < self.get_used_size(features) { + report_virtio_error(interrupt_cb.clone(), features, broken); 0_u64 - }; + } else { + addr + } + } else { + 0_u64 + }; } } @@ -265,9 +315,11 @@ impl SplitVringDesc { u64::from(index) * DESCRIPTOR_LEN, ) })?; - let desc = sys_mem - .read_object_direct::(desc_addr) - .with_context(|| VirtioError::ReadObjectErr("a descriptor", desc_addr))?; + // SAFETY: dest_addr has been checked in SplitVringDesc::is_valid() and is guaranteed to be within the ram range. + let desc = unsafe { + read_object_direct::(desc_addr) + .with_context(|| VirtioError::ReadObjectErr("a descriptor", desc_addr)) + }?; if desc.is_valid(sys_mem, queue_size, cache) { Ok(desc) @@ -290,7 +342,7 @@ impl SplitVringDesc { let mut miss_cached = true; if let Some(reg_cache) = cache { let base = self.addr.0; - let offset = self.len as u64; + let offset = u64::from(self.len); let end = match base.checked_add(offset) { Some(addr) => addr, None => { @@ -298,11 +350,12 @@ impl SplitVringDesc { return false; } }; + // GPAChecked: the vring desc [addr, addr+len] must locate in guest ram. if base > reg_cache.start && end < reg_cache.end { miss_cached = false; } } else { - let gotten_cache = sys_mem.get_region_cache(self.addr); + let gotten_cache = sys_mem.get_region_cache(self.addr, AddressAttr::Ram); if let Some(obtained_cache) = gotten_cache { if obtained_cache.reg_type == RegionType::Ram { *cache = gotten_cache; @@ -311,6 +364,7 @@ impl SplitVringDesc { } if miss_cached { + // GPAChecked: the vring desc addr must locate in guest ram. if let Err(ref e) = checked_offset_mem(sys_mem, self.addr, u64::from(self.len)) { error!("The memory of descriptor is invalid, {:?} ", e); return false; @@ -361,7 +415,7 @@ impl SplitVringDesc { fn is_valid_indirect_desc(&self) -> bool { if self.len == 0 || u64::from(self.len) % DESCRIPTOR_LEN != 0 - || u64::from(self.len) / DESCRIPTOR_LEN > u16::MAX as u64 + || u64::from(self.len) / DESCRIPTOR_LEN > u64::from(u16::MAX) { error!("The indirect descriptor is invalid, len: {}", self.len); return false; @@ -435,7 +489,11 @@ impl SplitVringDesc { elem.out_iovec.push(iovec); } elem.desc_num += 1; - desc_total_len += iovec.len as u64; + // Note: iovec.addr + iovec.len is located in RAM, and iovec.len is not greater than the + // VM RAM size. The number of iovec is not greater than 'queue_size * 2 - 1' which with + // a indirect table. Currently, the max value of queue_size is 1024. So, desc_total_len + // must not overflow. + desc_total_len += u64::from(iovec.len); if desc.has_next() { desc = Self::next_desc(sys_mem, desc_table_host, queue_size, desc.next, cache)?; @@ -495,73 +553,77 @@ impl SplitVring { } /// Get the flags and idx of the available ring from guest memory. - fn get_avail_flags_idx(&self, sys_mem: &Arc) -> Result { - sys_mem - .read_object_direct::(self.addr_cache.avail_ring_host) - .with_context(|| { - VirtioError::ReadObjectErr("avail flags idx", self.avail_ring.raw_value()) - }) + fn get_avail_flags_idx(&self) -> Result { + // SAFETY: avail_ring_host is checked when addr_cache inited. + unsafe { + read_object_direct::(self.addr_cache.avail_ring_host).with_context( + || VirtioError::ReadObjectErr("avail flags idx", self.avail_ring.raw_value()), + ) + } } /// Get the idx of the available ring from guest memory. - fn get_avail_idx(&self, sys_mem: &Arc) -> Result { - let flags_idx = self.get_avail_flags_idx(sys_mem)?; + fn get_avail_idx(&self) -> Result { + let flags_idx = self.get_avail_flags_idx()?; Ok(flags_idx.idx) } /// Get the flags of the available ring from guest memory. - fn get_avail_flags(&self, sys_mem: &Arc) -> Result { - let flags_idx = self.get_avail_flags_idx(sys_mem)?; + fn get_avail_flags(&self) -> Result { + let flags_idx = self.get_avail_flags_idx()?; Ok(flags_idx.flags) } /// Get the flags and idx of the used ring from guest memory. - fn get_used_flags_idx(&self, sys_mem: &Arc) -> Result { + fn get_used_flags_idx(&self) -> Result { // Make sure the idx read from sys_mem is new. fence(Ordering::SeqCst); - sys_mem - .read_object_direct::(self.addr_cache.used_ring_host) - .with_context(|| { - VirtioError::ReadObjectErr("used flags idx", self.used_ring.raw_value()) - }) + // SAFETY: used_ring_host has been checked in set_addr_cache() and is guaranteed to be within the ram range. + unsafe { + read_object_direct::(self.addr_cache.used_ring_host).with_context( + || VirtioError::ReadObjectErr("used flags idx", self.used_ring.raw_value()), + ) + } } /// Get the index of the used ring from guest memory. - fn get_used_idx(&self, sys_mem: &Arc) -> Result { - let flag_idx = self.get_used_flags_idx(sys_mem)?; + fn get_used_idx(&self) -> Result { + let flag_idx = self.get_used_flags_idx()?; Ok(flag_idx.idx) } /// Set the used flags to suppress virtqueue notification or not - fn set_used_flags(&self, sys_mem: &Arc, suppress: bool) -> Result<()> { - let mut flags_idx = self.get_used_flags_idx(sys_mem)?; + fn set_used_flags(&self, suppress: bool) -> Result<()> { + let mut flags_idx = self.get_used_flags_idx()?; if suppress { flags_idx.flags |= VRING_USED_F_NO_NOTIFY; } else { flags_idx.flags &= !VRING_USED_F_NO_NOTIFY; } - sys_mem - .write_object_direct::(&flags_idx, self.addr_cache.used_ring_host) - .with_context(|| { - format!( - "Failed to set used flags, used_ring: 0x{:X}", - self.used_ring.raw_value() - ) - })?; + // SAFETY: used_ring_host has been checked when addr_cache inited. + unsafe { + write_object_direct::(&flags_idx, self.addr_cache.used_ring_host) + .with_context(|| { + format!( + "Failed to set used flags, used_ring: 0x{:X}", + self.used_ring.raw_value() + ) + }) + }?; // Make sure the data has been set. fence(Ordering::SeqCst); Ok(()) } /// Set the avail idx to the field of the event index for the available ring. - fn set_avail_event(&self, sys_mem: &Arc, event_idx: u16) -> Result<()> { + fn set_avail_event(&self, event_idx: u16) -> Result<()> { trace::virtqueue_set_avail_event(self as *const _ as u64, event_idx); let avail_event_offset = VRING_FLAGS_AND_IDX_LEN + USEDELEM_LEN * u64::from(self.actual_size()); - - sys_mem - .write_object_direct( + // SAFETY: used_ring_host has been checked in set_addr_cache(). + unsafe { + write_object_direct( &event_idx, self.addr_cache.used_ring_host + avail_event_offset, ) @@ -571,14 +633,15 @@ impl SplitVring { self.used_ring.raw_value(), avail_event_offset, ) - })?; + }) + }?; // Make sure the data has been set. fence(Ordering::SeqCst); Ok(()) } /// Get the event index of the used ring from guest memory. - fn get_used_event(&self, sys_mem: &Arc) -> Result { + fn get_used_event(&self) -> Result { let used_event_offset = VRING_FLAGS_AND_IDX_LEN + AVAILELEM_LEN * u64::from(self.actual_size()); // Make sure the event idx read from sys_mem is new. @@ -586,16 +649,18 @@ impl SplitVring { // The GPA of avail_ring_host with avail table length has been checked in // is_invalid_memory which must not be overflowed. let used_event_addr = self.addr_cache.avail_ring_host + used_event_offset; - let used_event = sys_mem - .read_object_direct::(used_event_addr) - .with_context(|| VirtioError::ReadObjectErr("used event id", used_event_addr))?; + // SAFETY: used_event_addr is protected by virtio calculations and is guaranteed to be within the ram range. + let used_event = unsafe { + read_object_direct::(used_event_addr) + .with_context(|| VirtioError::ReadObjectErr("used event id", used_event_addr)) + }?; Ok(used_event) } /// Return true if VRING_AVAIL_F_NO_INTERRUPT is set. - fn is_avail_ring_no_interrupt(&self, sys_mem: &Arc) -> bool { - match self.get_avail_flags(sys_mem) { + fn is_avail_ring_no_interrupt(&self) -> bool { + match self.get_avail_flags() { Ok(avail_flags) => (avail_flags & VRING_AVAIL_F_NO_INTERRUPT) != 0, Err(ref e) => { warn!( @@ -608,9 +673,9 @@ impl SplitVring { } /// Return true if it's required to trigger interrupt for the used vring. - fn used_ring_need_event(&mut self, sys_mem: &Arc) -> bool { + fn used_ring_need_event(&mut self) -> bool { let old = self.last_signal_used; - let new = match self.get_used_idx(sys_mem) { + let new = match self.get_used_idx() { Ok(used_idx) => Wrapping(used_idx), Err(ref e) => { error!("Failed to get the status for notifying used vring: {:?}", e); @@ -618,7 +683,7 @@ impl SplitVring { } }; - let used_event_idx = match self.get_used_event(sys_mem) { + let used_event_idx = match self.get_used_event() { Ok(idx) => Wrapping(idx), Err(ref e) => { error!("Failed to get the status for notifying used vring: {:?}", e); @@ -642,6 +707,7 @@ impl SplitVring { } fn is_invalid_memory(&self, sys_mem: &Arc, actual_size: u64) -> bool { + // GPAChecked: the desc ring table must locate in guest ram. let desc_table_end = match checked_offset_mem(sys_mem, self.desc_table, DESCRIPTOR_LEN * actual_size) { Ok(addr) => addr, @@ -656,6 +722,7 @@ impl SplitVring { } }; + // GPAChecked: the avail ring table must locate in guest ram. let desc_avail_end = match checked_offset_mem( sys_mem, self.avail_ring, @@ -673,6 +740,7 @@ impl SplitVring { } }; + // GPAChecked: the used ring table must locate in guest ram. let desc_used_end = match checked_offset_mem( sys_mem, self.used_ring, @@ -745,11 +813,12 @@ impl SplitVring { // The GPA of avail_ring_host with avail table length has been checked in // is_invalid_memory which must not be overflowed. let desc_index_addr = self.addr_cache.avail_ring_host + index_offset; - let desc_index = sys_mem - .read_object_direct::(desc_index_addr) - .with_context(|| { + // SAFETY: dest_index_addr is protected by virtio calculations and is guaranteed to be within the ram range. + let desc_index = unsafe { + read_object_direct::(desc_index_addr).with_context(|| { VirtioError::ReadObjectErr("the index of descriptor", desc_index_addr) - })?; + }) + }?; let desc = SplitVringDesc::new( sys_mem, @@ -761,7 +830,7 @@ impl SplitVring { // Suppress queue notification related to current processing desc chain. if virtio_has_feature(features, VIRTIO_F_RING_EVENT_IDX) { - self.set_avail_event(sys_mem, (next_avail + Wrapping(1)).0) + self.set_avail_event((next_avail + Wrapping(1)).0) .with_context(|| "Failed to set avail event for popping avail ring")?; } @@ -819,7 +888,7 @@ impl VringOps for SplitVring { fn pop_avail(&mut self, sys_mem: &Arc, features: u64) -> Result { let mut element = Element::new(0); - if !self.is_enabled() || self.avail_ring_len(sys_mem)? == 0 { + if !self.is_enabled() || self.avail_ring_len()? == 0 { return Ok(element); } @@ -842,7 +911,7 @@ impl VringOps for SplitVring { self.next_avail -= Wrapping(1); } - fn add_used(&mut self, sys_mem: &Arc, index: u16, len: u32) -> Result<()> { + fn add_used(&mut self, index: u16, len: u32) -> Result<()> { if index >= self.size { return Err(anyhow!(VirtioError::QueueIndex(index, self.size))); } @@ -855,19 +924,23 @@ impl VringOps for SplitVring { id: u32::from(index), len, }; - sys_mem - .write_object_direct::(&used_elem, used_elem_addr) - .with_context(|| "Failed to write object for used element")?; + // SAFETY: used_elem_addr is guaranteed to be within ram range. + unsafe { + write_object_direct::(&used_elem, used_elem_addr) + .with_context(|| "Failed to write object for used element") + }?; // Make sure used element is filled before updating used idx. fence(Ordering::Release); self.next_used += Wrapping(1); - sys_mem - .write_object_direct( + // SAFETY: used_ring_host has been checked when addr_cache inited. + unsafe { + write_object_direct( &(self.next_used.0), self.addr_cache.used_ring_host + VRING_IDX_POSITION, ) - .with_context(|| "Failed to write next used idx")?; + .with_context(|| "Failed to write next used idx") + }?; // Make sure used index is exposed before notifying guest. fence(Ordering::SeqCst); @@ -878,24 +951,23 @@ impl VringOps for SplitVring { Ok(()) } - fn should_notify(&mut self, sys_mem: &Arc, features: u64) -> bool { + fn should_notify(&mut self, features: u64) -> bool { if virtio_has_feature(features, VIRTIO_F_RING_EVENT_IDX) { - self.used_ring_need_event(sys_mem) + self.used_ring_need_event() } else { - !self.is_avail_ring_no_interrupt(sys_mem) + !self.is_avail_ring_no_interrupt() } } - fn suppress_queue_notify( - &mut self, - sys_mem: &Arc, - features: u64, - suppress: bool, - ) -> Result<()> { + fn suppress_queue_notify(&mut self, features: u64, suppress: bool) -> Result<()> { + if !self.is_enabled() { + bail!("queue is not ready"); + } + if virtio_has_feature(features, VIRTIO_F_RING_EVENT_IDX) { - self.set_avail_event(sys_mem, self.get_avail_idx(sys_mem)?)?; + self.set_avail_event(self.get_avail_idx()?)?; } else { - self.set_used_flags(sys_mem, suppress)?; + self.set_used_flags(suppress)?; } Ok(()) } @@ -911,18 +983,18 @@ impl VringOps for SplitVring { } /// The number of descriptor chains in the available ring. - fn avail_ring_len(&mut self, sys_mem: &Arc) -> Result { - let avail_idx = self.get_avail_idx(sys_mem).map(Wrapping)?; + fn avail_ring_len(&mut self) -> Result { + let avail_idx = self.get_avail_idx().map(Wrapping)?; Ok((avail_idx - self.next_avail).0) } - fn get_avail_idx(&self, sys_mem: &Arc) -> Result { - SplitVring::get_avail_idx(self, sys_mem) + fn get_avail_idx(&self) -> Result { + SplitVring::get_avail_idx(self) } - fn get_used_idx(&self, sys_mem: &Arc) -> Result { - SplitVring::get_used_idx(self, sys_mem) + fn get_used_idx(&self) -> Result { + SplitVring::get_used_idx(self) } fn get_cache(&self) -> &Option { @@ -942,7 +1014,7 @@ impl VringOps for SplitVring { let mut avail_bytes = 0_usize; let mut avail_idx = self.next_avail; - let end_idx = self.get_avail_idx(sys_mem).map(Wrapping)?; + let end_idx = self.get_avail_idx().map(Wrapping)?; while (end_idx - avail_idx).0 > 0 { let desc_info = self.get_desc_info(sys_mem, avail_idx, 0)?; @@ -975,33 +1047,9 @@ impl VringOps for SplitVring { #[cfg(test)] mod tests { use super::*; + use crate::tests::address_space_init; use crate::{Queue, QUEUE_TYPE_PACKED_VRING, QUEUE_TYPE_SPLIT_VRING}; - use address_space::{AddressSpace, GuestAddress, HostMemMapping, Region}; - - fn address_space_init() -> Arc { - let root = Region::init_container_region(1 << 36, "sysmem"); - let sys_space = AddressSpace::new(root, "sysmem", None).unwrap(); - let host_mmap = Arc::new( - HostMemMapping::new( - GuestAddress(0), - None, - SYSTEM_SPACE_SIZE, - None, - false, - false, - false, - ) - .unwrap(), - ); - sys_space - .root() - .add_subregion( - Region::init_ram_region(host_mmap.clone(), "sysmem"), - host_mmap.start_address().raw_value(), - ) - .unwrap(); - sys_space - } + use address_space::{AddressAttr, AddressSpace, GuestAddress}; trait VringOpsTest { fn set_desc( @@ -1050,7 +1098,7 @@ mod tests { return Err(anyhow!(VirtioError::QueueIndex(index, self.size))); } - let desc_addr_offset = DESCRIPTOR_LEN * index as u64; + let desc_addr_offset = DESCRIPTOR_LEN * u64::from(index); let desc = SplitVringDesc { addr, len, @@ -1060,22 +1108,29 @@ mod tests { sys_mem.write_object::( &desc, GuestAddress(self.desc_table.0 + desc_addr_offset), + AddressAttr::Ram, )?; Ok(()) } fn set_avail_ring_idx(&self, sys_mem: &Arc, idx: u16) -> Result<()> { - let avail_idx_offset = 2 as u64; - sys_mem - .write_object::(&idx, GuestAddress(self.avail_ring.0 + avail_idx_offset))?; + let avail_idx_offset = 2_u64; + sys_mem.write_object::( + &idx, + GuestAddress(self.avail_ring.0 + avail_idx_offset), + AddressAttr::Ram, + )?; Ok(()) } fn set_avail_ring_flags(&self, sys_mem: &Arc, flags: u16) -> Result<()> { - let avail_idx_offset = 0 as u64; - sys_mem - .write_object::(&flags, GuestAddress(self.avail_ring.0 + avail_idx_offset))?; + let avail_idx_offset = 0_u64; + sys_mem.write_object::( + &flags, + GuestAddress(self.avail_ring.0 + avail_idx_offset), + AddressAttr::Ram, + )?; Ok(()) } @@ -1085,47 +1140,61 @@ mod tests { avail_pos: u16, desc_index: u16, ) -> Result<()> { - let avail_idx_offset = VRING_FLAGS_AND_IDX_LEN + AVAILELEM_LEN * (avail_pos as u64); + let avail_idx_offset = VRING_FLAGS_AND_IDX_LEN + AVAILELEM_LEN * u64::from(avail_pos); sys_mem.write_object::( &desc_index, GuestAddress(self.avail_ring.0 + avail_idx_offset), + AddressAttr::Ram, )?; Ok(()) } fn get_avail_event(&self, sys_mem: &Arc) -> Result { let avail_event_idx_offset = - VRING_FLAGS_AND_IDX_LEN + USEDELEM_LEN * (self.actual_size() as u64); - let event_idx = sys_mem - .read_object::(GuestAddress(self.used_ring.0 + avail_event_idx_offset))?; + VRING_FLAGS_AND_IDX_LEN + USEDELEM_LEN * u64::from(self.actual_size()); + let event_idx = sys_mem.read_object::( + GuestAddress(self.used_ring.0 + avail_event_idx_offset), + AddressAttr::Ram, + )?; Ok(event_idx) } fn get_used_elem(&self, sys_mem: &Arc, index: u16) -> Result { - let used_elem_offset = VRING_FLAGS_AND_IDX_LEN + USEDELEM_LEN * (index as u64); - let used_elem = sys_mem - .read_object::(GuestAddress(self.used_ring.0 + used_elem_offset))?; + let used_elem_offset = VRING_FLAGS_AND_IDX_LEN + USEDELEM_LEN * u64::from(index); + let used_elem = sys_mem.read_object::( + GuestAddress(self.used_ring.0 + used_elem_offset), + AddressAttr::Ram, + )?; Ok(used_elem) } fn get_used_ring_idx(&self, sys_mem: &Arc) -> Result { let used_idx_offset = VRING_IDX_POSITION; - let idx = - sys_mem.read_object::(GuestAddress(self.used_ring.0 + used_idx_offset))?; + let idx = sys_mem.read_object::( + GuestAddress(self.used_ring.0 + used_idx_offset), + AddressAttr::Ram, + )?; Ok(idx) } fn set_used_ring_idx(&self, sys_mem: &Arc, idx: u16) -> Result<()> { let used_idx_offset = VRING_IDX_POSITION; - sys_mem.write_object::(&idx, GuestAddress(self.used_ring.0 + used_idx_offset))?; + sys_mem.write_object::( + &idx, + GuestAddress(self.used_ring.0 + used_idx_offset), + AddressAttr::Ram, + )?; Ok(()) } fn set_used_event_idx(&self, sys_mem: &Arc, idx: u16) -> Result<()> { let event_idx_offset = - VRING_FLAGS_AND_IDX_LEN + AVAILELEM_LEN * (self.actual_size() as u64); - sys_mem - .write_object::(&idx, GuestAddress(self.avail_ring.0 + event_idx_offset))?; + VRING_FLAGS_AND_IDX_LEN + AVAILELEM_LEN * u64::from(self.actual_size()); + sys_mem.write_object::( + &idx, + GuestAddress(self.avail_ring.0 + event_idx_offset), + AddressAttr::Ram, + )?; Ok(()) } } @@ -1144,12 +1213,12 @@ mod tests { flags, next, }; - sys_mem.write_object::(&desc, desc_addr)?; + sys_mem.write_object::(&desc, desc_addr, AddressAttr::Ram)?; Ok(()) } const SYSTEM_SPACE_SIZE: u64 = (1024 * 1024) as u64; - const QUEUE_SIZE: u16 = 256 as u16; + const QUEUE_SIZE: u16 = 256_u16; fn align(size: u64, alignment: u64) -> u64 { let align_adjust = if size % alignment != 0 { @@ -1157,7 +1226,7 @@ mod tests { } else { 0 }; - (size + align_adjust) as u64 + size + align_adjust } #[test] @@ -1174,38 +1243,38 @@ mod tests { // it is valid queue_config.desc_table = GuestAddress(0); - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); queue_config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64), + + AVAILELEM_LEN * u64::from(QUEUE_SIZE), 4096, )); queue_config.ready = true; queue_config.size = QUEUE_SIZE; let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), true); + assert!(queue.is_valid(&sys_space)); // it is invalid when the status is not ready queue_config.ready = false; let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), false); + assert!(!queue.is_valid(&sys_space)); queue_config.ready = true; // it is invalid when the size of virtual ring is more than the max size queue_config.size = QUEUE_SIZE + 1; let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), false); + assert!(!queue.is_valid(&sys_space)); // it is invalid when the size of virtual ring is zero queue_config.size = 0; let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), false); + assert!(!queue.is_valid(&sys_space)); // it is invalid when the size of virtual ring isn't power of 2 queue_config.size = 15; let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), false); + assert!(!queue.is_valid(&sys_space)); } #[test] @@ -1214,58 +1283,58 @@ mod tests { let mut queue_config = QueueConfig::new(QUEUE_SIZE); queue_config.desc_table = GuestAddress(0); - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); queue_config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64), + + AVAILELEM_LEN * u64::from(QUEUE_SIZE), 4096, )); queue_config.ready = true; queue_config.size = QUEUE_SIZE; let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), true); + assert!(queue.is_valid(&sys_space)); // it is invalid when the address of descriptor table is out of bound queue_config.desc_table = - GuestAddress(SYSTEM_SPACE_SIZE - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + 1 as u64); + GuestAddress(SYSTEM_SPACE_SIZE - u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + 1_u64); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), false); + assert!(!queue.is_valid(&sys_space)); // recover the address for valid queue queue_config.desc_table = GuestAddress(0); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), true); + assert!(queue.is_valid(&sys_space)); // it is invalid when the address of avail ring is out of bound queue_config.avail_ring = GuestAddress( SYSTEM_SPACE_SIZE - - (VRING_AVAIL_LEN_EXCEPT_AVAILELEM + AVAILELEM_LEN * (QUEUE_SIZE as u64)) - + 1 as u64, + - (VRING_AVAIL_LEN_EXCEPT_AVAILELEM + AVAILELEM_LEN * u64::from(QUEUE_SIZE)) + + 1_u64, ); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), false); + assert!(!queue.is_valid(&sys_space)); // recover the address for valid queue - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), true); + assert!(queue.is_valid(&sys_space)); // it is invalid when the address of used ring is out of bound queue_config.used_ring = GuestAddress( SYSTEM_SPACE_SIZE - - (VRING_USED_LEN_EXCEPT_USEDELEM + USEDELEM_LEN * (QUEUE_SIZE as u64)) - + 1 as u64, + - (VRING_USED_LEN_EXCEPT_USEDELEM + USEDELEM_LEN * u64::from(QUEUE_SIZE)) + + 1_u64, ); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), false); + assert!(!queue.is_valid(&sys_space)); // recover the address for valid queue queue_config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64), + + AVAILELEM_LEN * u64::from(QUEUE_SIZE), 4096, )); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), true); + assert!(queue.is_valid(&sys_space)); } #[test] @@ -1274,69 +1343,69 @@ mod tests { let mut queue_config = QueueConfig::new(QUEUE_SIZE); queue_config.desc_table = GuestAddress(0); - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); queue_config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64), + + AVAILELEM_LEN * u64::from(QUEUE_SIZE), 4096, )); queue_config.ready = true; queue_config.size = QUEUE_SIZE; let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), true); + assert!(queue.is_valid(&sys_space)); // it is invalid when the address of descriptor table is equal to the address of avail ring queue_config.avail_ring = GuestAddress(0); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), false); + assert!(!queue.is_valid(&sys_space)); // recover the address for valid queue - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), true); + assert!(queue.is_valid(&sys_space)); // it is invalid when the address of descriptor table is overlapped to the address of avail // ring. - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN - 1); + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN - 1); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), false); + assert!(!queue.is_valid(&sys_space)); // recover the address for valid queue - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), true); + assert!(queue.is_valid(&sys_space)); // it is invalid when the address of avail ring is equal to the address of used ring - queue_config.used_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); + queue_config.used_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), false); + assert!(!queue.is_valid(&sys_space)); // recover the address for valid queue queue_config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64), + + AVAILELEM_LEN * u64::from(QUEUE_SIZE), 4096, )); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), true); + assert!(queue.is_valid(&sys_space)); // it is invalid when the address of avail ring is overlapped to the address of used ring queue_config.used_ring = GuestAddress( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64) + + AVAILELEM_LEN * u64::from(QUEUE_SIZE) - 1, ); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), false); + assert!(!queue.is_valid(&sys_space)); // recover the address for valid queue queue_config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64), + + AVAILELEM_LEN * u64::from(QUEUE_SIZE), 4096, )); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), true); + assert!(queue.is_valid(&sys_space)); } #[test] @@ -1345,54 +1414,54 @@ mod tests { let mut queue_config = QueueConfig::new(QUEUE_SIZE); queue_config.desc_table = GuestAddress(0); - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); queue_config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64), + + AVAILELEM_LEN * u64::from(QUEUE_SIZE), 4096, )); queue_config.ready = true; queue_config.size = QUEUE_SIZE; let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), true); + assert!(queue.is_valid(&sys_space)); // it is invalid when the address of descriptor table is not aligned to 16 - queue_config.desc_table = GuestAddress(15 as u64); + queue_config.desc_table = GuestAddress(15_u64); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), false); + assert!(!queue.is_valid(&sys_space)); // recover the address for valid queue queue_config.desc_table = GuestAddress(0); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), true); + assert!(queue.is_valid(&sys_space)); // it is invalid when the address of avail ring is not aligned to 2 - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN + 1); + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + 1); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), false); + assert!(!queue.is_valid(&sys_space)); // recover the address for valid queue - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), true); + assert!(queue.is_valid(&sys_space)); // it is invalid when the address of used ring is not aligned to 4 queue_config.used_ring = GuestAddress( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64) + + AVAILELEM_LEN * u64::from(QUEUE_SIZE) + 3, ); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), false); + assert!(!queue.is_valid(&sys_space)); // recover the address for valid queue queue_config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64), + + AVAILELEM_LEN * u64::from(QUEUE_SIZE), 4096, )); let queue = Queue::new(queue_config, QUEUE_TYPE_SPLIT_VRING).unwrap(); - assert_eq!(queue.is_valid(&sys_space), true); + assert!(queue.is_valid(&sys_space)); } #[test] @@ -1401,23 +1470,32 @@ mod tests { let mut queue_config = QueueConfig::new(QUEUE_SIZE); queue_config.desc_table = GuestAddress(0); - queue_config.addr_cache.desc_table_host = - sys_space.get_host_address(queue_config.desc_table).unwrap(); - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); - queue_config.addr_cache.avail_ring_host = - sys_space.get_host_address(queue_config.avail_ring).unwrap(); + queue_config.addr_cache.desc_table_host = unsafe { + sys_space + .get_host_address(queue_config.desc_table, AddressAttr::Ram) + .unwrap() + }; + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); + queue_config.addr_cache.avail_ring_host = unsafe { + sys_space + .get_host_address(queue_config.avail_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64), + + AVAILELEM_LEN * u64::from(QUEUE_SIZE), 4096, )); - queue_config.addr_cache.used_ring_host = - sys_space.get_host_address(queue_config.used_ring).unwrap(); + queue_config.addr_cache.used_ring_host = unsafe { + sys_space + .get_host_address(queue_config.used_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.ready = true; queue_config.size = QUEUE_SIZE; let mut vring = SplitVring::new(queue_config); - assert_eq!(vring.is_valid(&sys_space), true); + assert!(vring.is_valid(&sys_space)); // it is ok when the descriptor chain is normal // set the information of index 0 for descriptor @@ -1454,7 +1532,7 @@ mod tests { // set 1 to the idx of avail ring vring.set_avail_ring_idx(&sys_space, 1).unwrap(); - let features = 1 << VIRTIO_F_RING_EVENT_IDX as u64; + let features = 1 << u64::from(VIRTIO_F_RING_EVENT_IDX); let elem = match vring.pop_avail(&sys_space, features) { Ok(ret) => ret, Err(_) => Element { @@ -1467,11 +1545,11 @@ mod tests { assert_eq!(elem.index, 0); assert_eq!(elem.desc_num, 3); assert_eq!(elem.out_iovec.len(), 1); - let elem_iov = elem.out_iovec.get(0).unwrap(); + let elem_iov = elem.out_iovec.first().unwrap(); assert_eq!(elem_iov.addr, GuestAddress(0x111)); assert_eq!(elem_iov.len, 16); assert_eq!(elem.in_iovec.len(), 2); - let elem_iov = elem.in_iovec.get(0).unwrap(); + let elem_iov = elem.in_iovec.first().unwrap(); assert_eq!(elem_iov.addr, GuestAddress(0x222)); assert_eq!(elem_iov.len, 32); let elem_iov = elem.in_iovec.get(1).unwrap(); @@ -1481,7 +1559,7 @@ mod tests { // the event idx of avail ring is equal to get_avail_event let event_idx = vring.get_avail_event(&sys_space).unwrap(); assert_eq!(event_idx, 1); - let avail_idx = vring.get_avail_idx(&sys_space).unwrap(); + let avail_idx = vring.get_avail_idx().unwrap(); assert_eq!(avail_idx, 1); } @@ -1491,23 +1569,32 @@ mod tests { let mut queue_config = QueueConfig::new(QUEUE_SIZE); queue_config.desc_table = GuestAddress(0); - queue_config.addr_cache.desc_table_host = - sys_space.get_host_address(queue_config.desc_table).unwrap(); - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); - queue_config.addr_cache.avail_ring_host = - sys_space.get_host_address(queue_config.avail_ring).unwrap(); + queue_config.addr_cache.desc_table_host = unsafe { + sys_space + .get_host_address(queue_config.desc_table, AddressAttr::Ram) + .unwrap() + }; + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); + queue_config.addr_cache.avail_ring_host = unsafe { + sys_space + .get_host_address(queue_config.avail_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64), + + AVAILELEM_LEN * u64::from(QUEUE_SIZE), 4096, )); - queue_config.addr_cache.used_ring_host = - sys_space.get_host_address(queue_config.used_ring).unwrap(); + queue_config.addr_cache.used_ring_host = unsafe { + sys_space + .get_host_address(queue_config.used_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.ready = true; queue_config.size = QUEUE_SIZE; let mut vring = SplitVring::new(queue_config); - assert_eq!(vring.is_valid(&sys_space), true); + assert!(vring.is_valid(&sys_space)); // it is ok when the descriptor chain is indirect // set the information for indirect descriptor @@ -1560,7 +1647,7 @@ mod tests { // set 1 to the idx of avail ring vring.set_avail_ring_idx(&sys_space, 1).unwrap(); - let features = 1 << VIRTIO_F_RING_EVENT_IDX as u64; + let features = 1 << u64::from(VIRTIO_F_RING_EVENT_IDX); let elem = match vring.pop_avail(&sys_space, features) { Ok(ret) => ret, Err(_) => Element { @@ -1573,14 +1660,14 @@ mod tests { assert_eq!(elem.index, 0); assert_eq!(elem.desc_num, 3); assert_eq!(elem.out_iovec.len(), 2); - let elem_iov = elem.out_iovec.get(0).unwrap(); + let elem_iov = elem.out_iovec.first().unwrap(); assert_eq!(elem_iov.addr, GuestAddress(0x444)); assert_eq!(elem_iov.len, 100); let elem_iov = elem.out_iovec.get(1).unwrap(); assert_eq!(elem_iov.addr, GuestAddress(0x555)); assert_eq!(elem_iov.len, 200); assert_eq!(elem.in_iovec.len(), 1); - let elem_iov = elem.in_iovec.get(0).unwrap(); + let elem_iov = elem.in_iovec.first().unwrap(); assert_eq!(elem_iov.addr, GuestAddress(0x666)); assert_eq!(elem_iov.len, 300); } @@ -1591,28 +1678,37 @@ mod tests { let mut queue_config = QueueConfig::new(QUEUE_SIZE); queue_config.desc_table = GuestAddress(0); - queue_config.addr_cache.desc_table_host = - sys_space.get_host_address(queue_config.desc_table).unwrap(); - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); - queue_config.addr_cache.avail_ring_host = - sys_space.get_host_address(queue_config.avail_ring).unwrap(); + queue_config.addr_cache.desc_table_host = unsafe { + sys_space + .get_host_address(queue_config.desc_table, AddressAttr::Ram) + .unwrap() + }; + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); + queue_config.addr_cache.avail_ring_host = unsafe { + sys_space + .get_host_address(queue_config.avail_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64), + + AVAILELEM_LEN * u64::from(QUEUE_SIZE), 4096, )); - queue_config.addr_cache.used_ring_host = - sys_space.get_host_address(queue_config.used_ring).unwrap(); + queue_config.addr_cache.used_ring_host = unsafe { + sys_space + .get_host_address(queue_config.used_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.ready = true; queue_config.size = QUEUE_SIZE; let mut vring = SplitVring::new(queue_config); - assert_eq!(vring.is_valid(&sys_space), true); + assert!(vring.is_valid(&sys_space)); // it is error when the idx of avail ring which is equal to next_avail // set 0 to the idx of avail ring which is equal to next_avail vring.set_avail_ring_idx(&sys_space, 0).unwrap(); - let features = 1 << VIRTIO_F_RING_EVENT_IDX as u64; + let features = 1 << u64::from(VIRTIO_F_RING_EVENT_IDX); if let Ok(elem) = vring.pop_avail(&sys_space, features) { if elem.desc_num != 0 { assert!(false); @@ -1652,7 +1748,7 @@ mod tests { 0, ) .unwrap(); - if let Ok(_) = vring.pop_avail(&sys_space, features) { + if vring.pop_avail(&sys_space, features).is_ok() { assert!(false); } @@ -1786,23 +1882,32 @@ mod tests { let mut queue_config = QueueConfig::new(QUEUE_SIZE); queue_config.desc_table = GuestAddress(0); - queue_config.addr_cache.desc_table_host = - sys_space.get_host_address(queue_config.desc_table).unwrap(); - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); - queue_config.addr_cache.avail_ring_host = - sys_space.get_host_address(queue_config.avail_ring).unwrap(); + queue_config.addr_cache.desc_table_host = unsafe { + sys_space + .get_host_address(queue_config.desc_table, AddressAttr::Ram) + .unwrap() + }; + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); + queue_config.addr_cache.avail_ring_host = unsafe { + sys_space + .get_host_address(queue_config.avail_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64), + + AVAILELEM_LEN * u64::from(QUEUE_SIZE), 4096, )); - queue_config.addr_cache.used_ring_host = - sys_space.get_host_address(queue_config.used_ring).unwrap(); + queue_config.addr_cache.used_ring_host = unsafe { + sys_space + .get_host_address(queue_config.used_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.ready = true; queue_config.size = QUEUE_SIZE; let mut vring = SplitVring::new(queue_config); - assert_eq!(vring.is_valid(&sys_space), true); + assert!(vring.is_valid(&sys_space)); // Set the information of index 0 for normal descriptor. vring @@ -1855,7 +1960,7 @@ mod tests { // Set 1 to the idx of avail ring. vring.set_avail_ring_idx(&sys_space, 1).unwrap(); - let features = 1 << VIRTIO_F_RING_EVENT_IDX as u64; + let features = 1 << u64::from(VIRTIO_F_RING_EVENT_IDX); if let Err(err) = vring.pop_avail(&sys_space, features) { assert_eq!(err.to_string(), "Failed to get vring element"); } else { @@ -1921,7 +2026,7 @@ mod tests { // Two elem for reading. assert_eq!(elem.out_iovec.len(), 2); - let elem_iov = elem.out_iovec.get(0).unwrap(); + let elem_iov = elem.out_iovec.first().unwrap(); assert_eq!(elem_iov.addr, GuestAddress(0x111)); assert_eq!(elem_iov.len, 16); let elem_iov = elem.out_iovec.get(1).unwrap(); @@ -1930,7 +2035,7 @@ mod tests { // Two elem for writing. assert_eq!(elem.in_iovec.len(), 2); - let elem_iov = elem.in_iovec.get(0).unwrap(); + let elem_iov = elem.in_iovec.first().unwrap(); assert_eq!(elem_iov.addr, GuestAddress(0x444)); assert_eq!(elem_iov.len, 100); let elem_iov = elem.in_iovec.get(1).unwrap(); @@ -1940,7 +2045,7 @@ mod tests { // The event idx of avail ring is equal to get_avail_event. let event_idx = vring.get_avail_event(&sys_space).unwrap(); assert_eq!(event_idx, 1); - let avail_idx = vring.get_avail_idx(&sys_space).unwrap(); + let avail_idx = vring.get_avail_idx().unwrap(); assert_eq!(avail_idx, 1); } @@ -1950,38 +2055,44 @@ mod tests { let mut queue_config = QueueConfig::new(QUEUE_SIZE); queue_config.desc_table = GuestAddress(0); - queue_config.addr_cache.desc_table_host = - sys_space.get_host_address(queue_config.desc_table).unwrap(); - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); - queue_config.addr_cache.avail_ring_host = - sys_space.get_host_address(queue_config.avail_ring).unwrap(); + queue_config.addr_cache.desc_table_host = unsafe { + sys_space + .get_host_address(queue_config.desc_table, AddressAttr::Ram) + .unwrap() + }; + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); + queue_config.addr_cache.avail_ring_host = unsafe { + sys_space + .get_host_address(queue_config.avail_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64), + + AVAILELEM_LEN * u64::from(QUEUE_SIZE), 4096, )); - queue_config.addr_cache.used_ring_host = - sys_space.get_host_address(queue_config.used_ring).unwrap(); + queue_config.addr_cache.used_ring_host = unsafe { + sys_space + .get_host_address(queue_config.used_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.ready = true; queue_config.size = QUEUE_SIZE; let mut vring = SplitVring::new(queue_config); - assert_eq!(vring.is_valid(&sys_space), true); + assert!(vring.is_valid(&sys_space)); // it is false when the index is more than the size of queue - if let Err(err) = vring.add_used(&sys_space, QUEUE_SIZE, 100) { + if let Err(err) = vring.add_used(QUEUE_SIZE, 100) { if let Some(e) = err.downcast_ref::() { - match e { - VirtioError::QueueIndex(offset, size) => { - assert_eq!(*offset, 256); - assert_eq!(*size, 256); - } - _ => (), + if let VirtioError::QueueIndex(offset, size) = e { + assert_eq!(*offset, 256); + assert_eq!(*size, 256); } } } - assert!(vring.add_used(&sys_space, 10, 100).is_ok()); + assert!(vring.add_used(10, 100).is_ok()); let elem = vring.get_used_elem(&sys_space, 0).unwrap(); assert_eq!(elem.id, 10); assert_eq!(elem.len, 100); @@ -1994,57 +2105,66 @@ mod tests { let mut queue_config = QueueConfig::new(QUEUE_SIZE); queue_config.desc_table = GuestAddress(0); - queue_config.addr_cache.desc_table_host = - sys_space.get_host_address(queue_config.desc_table).unwrap(); - queue_config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * DESCRIPTOR_LEN); - queue_config.addr_cache.avail_ring_host = - sys_space.get_host_address(queue_config.avail_ring).unwrap(); + queue_config.addr_cache.desc_table_host = unsafe { + sys_space + .get_host_address(queue_config.desc_table, AddressAttr::Ram) + .unwrap() + }; + queue_config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN); + queue_config.addr_cache.avail_ring_host = unsafe { + sys_space + .get_host_address(queue_config.avail_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * DESCRIPTOR_LEN + u64::from(QUEUE_SIZE) * DESCRIPTOR_LEN + VRING_AVAIL_LEN_EXCEPT_AVAILELEM - + AVAILELEM_LEN * (QUEUE_SIZE as u64), + + AVAILELEM_LEN * u64::from(QUEUE_SIZE), 4096, )); - queue_config.addr_cache.used_ring_host = - sys_space.get_host_address(queue_config.used_ring).unwrap(); + queue_config.addr_cache.used_ring_host = unsafe { + sys_space + .get_host_address(queue_config.used_ring, AddressAttr::Ram) + .unwrap() + }; queue_config.ready = true; queue_config.size = QUEUE_SIZE; let mut vring = SplitVring::new(queue_config); - assert_eq!(vring.is_valid(&sys_space), true); + assert!(vring.is_valid(&sys_space)); // it's true when the feature of event idx and no interrupt for the avail ring is closed - let features = 0 as u64; + let features = 0_u64; assert!(vring.set_avail_ring_flags(&sys_space, 0).is_ok()); - assert_eq!(vring.should_notify(&sys_space, features), true); + assert!(vring.should_notify(features)); // it's false when the feature of event idx is closed and the feature of no interrupt for // the avail ring is open - let features = 0 as u64; + let features = 0_u64; assert!(vring .set_avail_ring_flags(&sys_space, VRING_AVAIL_F_NO_INTERRUPT) .is_ok()); - assert_eq!(vring.should_notify(&sys_space, features), false); + assert!(!vring.should_notify(features)); // it's true when the feature of event idx is open and // (new - event_idx - Wrapping(1) < new -old) - let features = 1 << VIRTIO_F_RING_EVENT_IDX as u64; + let features = 1 << u64::from(VIRTIO_F_RING_EVENT_IDX); vring.last_signal_used = Wrapping(5); // old assert!(vring.set_used_ring_idx(&sys_space, 10).is_ok()); // new assert!(vring.set_used_event_idx(&sys_space, 6).is_ok()); // event_idx - assert_eq!(vring.should_notify(&sys_space, features), true); + assert!(vring.should_notify(features)); // it's false when the feature of event idx is open and // (new - event_idx - Wrapping(1) > new - old) vring.last_signal_used = Wrapping(5); // old assert!(vring.set_used_ring_idx(&sys_space, 10).is_ok()); // new assert!(vring.set_used_event_idx(&sys_space, 1).is_ok()); // event_idx - assert_eq!(vring.should_notify(&sys_space, features), false); + assert!(!vring.should_notify(features)); // it's false when the feature of event idx is open and // (new - event_idx - Wrapping(1) = new -old) vring.last_signal_used = Wrapping(5); // old assert!(vring.set_used_ring_idx(&sys_space, 10).is_ok()); // new assert!(vring.set_used_event_idx(&sys_space, 4).is_ok()); // event_idx - assert_eq!(vring.should_notify(&sys_space, features), false); + assert!(!vring.should_notify(features)); } } diff --git a/virtio/src/transport/virtio_mmio.rs b/virtio/src/transport/virtio_mmio.rs index 55381c5ea2c9f7eedd62057ab156379366a29c32..eb081e1338cbc5fa879577de3afeda1480f4ddf2 100644 --- a/virtio/src/transport/virtio_mmio.rs +++ b/virtio/src/transport/virtio_mmio.rs @@ -26,13 +26,13 @@ use crate::{ QUEUE_TYPE_PACKED_VRING, VIRTIO_F_RING_PACKED, VIRTIO_MMIO_INT_CONFIG, VIRTIO_MMIO_INT_VRING, }; use address_space::{AddressRange, AddressSpace, GuestAddress, RegionIoEventFd}; -use devices::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType, SysRes}; -use devices::{Device, DeviceBase}; -#[cfg(target_arch = "x86_64")] -use machine_manager::config::{BootSource, Param}; +use devices::sysbus::{SysBus, SysBusDevBase, SysBusDevOps, SysBusDevType}; +use devices::{convert_bus_mut, Device, DeviceBase, MUT_SYS_BUS}; use migration::{DeviceStateDesc, FieldDesc, MigrationHook, MigrationManager, StateTransfer}; use migration_derive::{ByteCode, Desc}; use util::byte_code::ByteCode; +use util::gen_base_func; +use util::loop_context::create_new_eventfd; /// Registers of virtio-mmio device refer to Virtio Spec. /// Magic value - Read Only. @@ -105,7 +105,7 @@ impl HostNotifyInfo { fn new(queue_num: usize) -> Self { let mut events = Vec::new(); for _i in 0..queue_num { - events.push(Arc::new(EventFd::new(libc::EFD_NONBLOCK).unwrap())); + events.push(Arc::new(create_new_eventfd().unwrap())); } HostNotifyInfo { events } @@ -134,55 +134,34 @@ pub struct VirtioMmioDevice { } impl VirtioMmioDevice { - pub fn new(mem_space: &Arc, device: Arc>) -> Self { - let device_clone = device.clone(); - let queue_num = device_clone.lock().unwrap().queue_num(); - - VirtioMmioDevice { + pub fn new( + mem_space: &Arc, + name: String, + device: Arc>, + sysbus: &Arc>, + region_base: u64, + region_size: u64, + ) -> Result { + if region_base >= sysbus.lock().unwrap().mmio_region.1 { + bail!("Mmio region space exhausted."); + } + let queue_num = device.lock().unwrap().queue_num(); + let mut mmio_device = VirtioMmioDevice { base: SysBusDevBase { + base: DeviceBase::new(name, false, None), dev_type: SysBusDevType::VirtioMmio, - interrupt_evt: Some(Arc::new(EventFd::new(libc::EFD_NONBLOCK).unwrap())), + interrupt_evt: Some(Arc::new(create_new_eventfd()?)), ..Default::default() }, device, host_notify_info: HostNotifyInfo::new(queue_num), mem_space: mem_space.clone(), interrupt_cb: None, - } - } - - pub fn realize( - mut self, - sysbus: &mut SysBus, - region_base: u64, - region_size: u64, - #[cfg(target_arch = "x86_64")] bs: &Arc>, - ) -> Result>> { - if region_base >= sysbus.mmio_region.1 { - bail!("Mmio region space exhausted."); - } - self.set_sys_resource(sysbus, region_base, region_size)?; - self.assign_interrupt_cb(); - self.device - .lock() - .unwrap() - .realize() - .with_context(|| "Failed to realize virtio.")?; + }; + mmio_device.set_sys_resource(sysbus, region_base, region_size, "VirtioMmio")?; + mmio_device.set_parent_bus(sysbus.clone()); - let dev = Arc::new(Mutex::new(self)); - sysbus.attach_device(&dev, region_base, region_size, "VirtioMmio")?; - - #[cfg(target_arch = "x86_64")] - bs.lock().unwrap().kernel_cmdline.push(Param { - param_type: "virtio_mmio.device".to_string(), - value: format!( - "{}@0x{:08x}:{}", - region_size, - region_base, - dev.lock().unwrap().base.res.irq - ), - }); - Ok(dev) + Ok(mmio_device) } /// Activate the virtio device, this function is called by vcpu thread when frontend @@ -292,10 +271,10 @@ impl VirtioMmioDevice { .map(|config| u32::from(config.max_size))?, QUEUE_READY_REG => locked_device .queue_config() - .map(|config| config.ready as u32)?, + .map(|config| u32::from(config.ready))?, INTERRUPT_STATUS_REG => locked_device.interrupt_status(), STATUS_REG => locked_device.device_status(), - CONFIG_GENERATION_REG => locked_device.config_generation() as u32, + CONFIG_GENERATION_REG => u32::from(locked_device.config_generation()), // SHM_SEL is unimplemented. According to the Virtio v1.2 spec: Reading from a non-existent // region(i.e. where the ID written to SHMSel is unused) results in a length of -1. SHM_LEN_LOW | SHM_LEN_HIGH => u32::MAX, @@ -382,23 +361,27 @@ impl VirtioMmioDevice { } impl Device for VirtioMmioDevice { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); + + fn realize(mut self) -> Result>> { + self.assign_interrupt_cb(); + self.device + .lock() + .unwrap() + .realize() + .with_context(|| "Failed to realize virtio.")?; - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + MUT_SYS_BUS!(parent_bus, locked_bus, sysbus); + let dev = Arc::new(Mutex::new(self)); + sysbus.attach_device(&dev)?; + + Ok(dev) } } impl SysBusDevOps for VirtioMmioDevice { - fn sysbusdev_base(&self) -> &SysBusDevBase { - &self.base - } - - fn sysbusdev_base_mut(&mut self) -> &mut SysBusDevBase { - &mut self.base - } + gen_base_func!(sysbusdev_base, sysbusdev_base_mut, SysBusDevBase, base); /// Read data by virtio driver from VM. fn read(&mut self, data: &mut [u8], _base: GuestAddress, offset: u64) -> bool { @@ -417,7 +400,7 @@ impl SysBusDevOps for VirtioMmioDevice { }; LittleEndian::write_u32(data, value); } - 0x100..=0xfff => { + 0x100..=0x1ff => { if let Err(ref e) = self .device .lock() @@ -477,7 +460,7 @@ impl SysBusDevOps for VirtioMmioDevice { self.device.lock().unwrap().set_device_activated(true); } } - 0x100..=0xfff => { + 0x100..=0x1ff => { let mut locked_device = self.device.lock().unwrap(); if locked_device.check_device_status(CONFIG_STATUS_DRIVER, CONFIG_STATUS_FAILED) { if let Err(ref e) = locked_device.write_config(offset - 0x100, data) { @@ -521,10 +504,6 @@ impl SysBusDevOps for VirtioMmioDevice { } ret } - - fn get_sys_resource_mut(&mut self) -> Option<&mut SysRes> { - Some(&mut self.base.res) - } } impl acpi::AmlBuilder for VirtioMmioDevice { @@ -591,51 +570,26 @@ impl MigrationHook for VirtioMmioDevice { #[cfg(test)] mod tests { use super::*; + use crate::tests::{address_space_init, sysbus_init}; use crate::{ check_config_space_rw, read_config_default, VirtioBase, QUEUE_TYPE_SPLIT_VRING, VIRTIO_TYPE_BLOCK, }; - use address_space::{AddressSpace, GuestAddress, HostMemMapping, Region}; - - fn address_space_init() -> Arc { - let root = Region::init_container_region(1 << 36, "sysmem"); - let sys_space = AddressSpace::new(root, "sysmem", None).unwrap(); - let host_mmap = Arc::new( - HostMemMapping::new( - GuestAddress(0), - None, - SYSTEM_SPACE_SIZE, - None, - false, - false, - false, - ) - .unwrap(), - ); - sys_space - .root() - .add_subregion( - Region::init_ram_region(host_mmap.clone(), "sysmem"), - host_mmap.start_address().raw_value(), - ) - .unwrap(); - sys_space - } + use address_space::{AddressSpace, GuestAddress}; - const SYSTEM_SPACE_SIZE: u64 = (1024 * 1024) as u64; const CONFIG_SPACE_SIZE: usize = 16; const QUEUE_NUM: usize = 2; const QUEUE_SIZE: u16 = 256; - pub struct VirtioDeviceTest { + struct VirtioDeviceTest { base: VirtioBase, - pub config_space: Vec, - pub b_active: bool, - pub b_realized: bool, + config_space: Vec, + b_active: bool, + b_realized: bool, } impl VirtioDeviceTest { - pub fn new() -> Self { + fn new() -> Self { let mut config_space = Vec::new(); for i in 0..CONFIG_SPACE_SIZE { config_space.push(i as u8); @@ -651,13 +605,7 @@ mod tests { } impl VirtioDevice for VirtioDeviceTest { - fn virtio_base(&self) -> &VirtioBase { - &self.base - } - - fn virtio_base_mut(&mut self) -> &mut VirtioBase { - &mut self.base - } + gen_base_func!(virtio_base, virtio_base_mut, VirtioBase, base); fn realize(&mut self) -> Result<()> { self.b_realized = true; @@ -677,7 +625,7 @@ mod tests { check_config_space_rw(&self.config_space, offset, data)?; let data_len = data.len(); self.config_space[(offset as usize)..(offset as usize + data_len)] - .copy_from_slice(&data[..]); + .copy_from_slice(data); Ok(()) } @@ -692,14 +640,29 @@ mod tests { } } - #[test] - fn test_virtio_mmio_device_new() { + fn virtio_mmio_test_init() -> (Arc>, VirtioMmioDevice) { let virtio_device = Arc::new(Mutex::new(VirtioDeviceTest::new())); + let sys_space = address_space_init(); - let virtio_mmio_device = VirtioMmioDevice::new(&sys_space, virtio_device.clone()); + let sysbus = sysbus_init(); + let virtio_mmio_device = VirtioMmioDevice::new( + &sys_space, + "test_virtio_mmio_device".to_string(), + virtio_device.clone(), + &sysbus, + 0x0A00_0000, + 0x0000_0200, + ) + .unwrap(); + + (virtio_device, virtio_mmio_device) + } + #[test] + fn test_virtio_mmio_device_new() { + let (virtio_device, virtio_mmio_device) = virtio_mmio_test_init(); let locked_device = virtio_device.lock().unwrap(); - assert_eq!(locked_device.device_activated(), false); + assert!(!locked_device.device_activated()); assert_eq!( virtio_mmio_device.host_notify_info.events.len(), locked_device.queue_num() @@ -714,41 +677,27 @@ mod tests { #[test] fn test_virtio_mmio_device_read_01() { - let virtio_device = Arc::new(Mutex::new(VirtioDeviceTest::new())); - let sys_space = address_space_init(); - let mut virtio_mmio_device = VirtioMmioDevice::new(&sys_space, virtio_device.clone()); + let (virtio_device, mut virtio_mmio_device) = virtio_mmio_test_init(); let addr = GuestAddress(0); // read the register of magic value let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, MAGIC_VALUE_REG), - true - ); + assert!(virtio_mmio_device.read(&mut buf[..], addr, MAGIC_VALUE_REG)); assert_eq!(LittleEndian::read_u32(&buf[..]), MMIO_MAGIC_VALUE); // read the register of version let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, VERSION_REG), - true - ); + assert!(virtio_mmio_device.read(&mut buf[..], addr, VERSION_REG)); assert_eq!(LittleEndian::read_u32(&buf[..]), MMIO_VERSION); // read the register of device id let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, DEVICE_ID_REG), - true - ); + assert!(virtio_mmio_device.read(&mut buf[..], addr, DEVICE_ID_REG)); assert_eq!(LittleEndian::read_u32(&buf[..]), VIRTIO_TYPE_BLOCK); // read the register of vendor id let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, VENDOR_ID_REG), - true - ); + assert!(virtio_mmio_device.read(&mut buf[..], addr, VENDOR_ID_REG)); assert_eq!(LittleEndian::read_u32(&buf[..]), VENDOR_ID); // read the register of the features @@ -756,45 +705,31 @@ mod tests { let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; virtio_device.lock().unwrap().set_hfeatures_sel(0); virtio_device.lock().unwrap().base.device_features = 0x0000_00f8_0000_00fe; - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, DEVICE_FEATURES_REG), - true - ); + assert!(virtio_mmio_device.read(&mut buf[..], addr, DEVICE_FEATURES_REG)); assert_eq!(LittleEndian::read_u32(&buf[..]), 0x0000_00fe); // get high 32bit of the features for device which supports VirtIO Version 1 let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; virtio_device.lock().unwrap().set_hfeatures_sel(1); - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, DEVICE_FEATURES_REG), - true - ); + assert!(virtio_mmio_device.read(&mut buf[..], addr, DEVICE_FEATURES_REG)); assert_eq!(LittleEndian::read_u32(&buf[..]), 0x0000_00f9); } #[test] fn test_virtio_mmio_device_read_02() { - let virtio_device = Arc::new(Mutex::new(VirtioDeviceTest::new())); - let sys_space = address_space_init(); - let mut virtio_mmio_device = VirtioMmioDevice::new(&sys_space, virtio_device.clone()); + let (virtio_device, mut virtio_mmio_device) = virtio_mmio_test_init(); let addr = GuestAddress(0); // read the register representing max size of the queue // for queue_select as 0 let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; virtio_device.lock().unwrap().set_queue_select(0); - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, QUEUE_NUM_MAX_REG), - true - ); - assert_eq!(LittleEndian::read_u32(&buf[..]), QUEUE_SIZE as u32); + assert!(virtio_mmio_device.read(&mut buf[..], addr, QUEUE_NUM_MAX_REG)); + assert_eq!(LittleEndian::read_u32(&buf[..]), u32::from(QUEUE_SIZE)); // for queue_select as 1 let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; virtio_device.lock().unwrap().set_queue_select(1); - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, QUEUE_NUM_MAX_REG), - true - ); - assert_eq!(LittleEndian::read_u32(&buf[..]), QUEUE_SIZE as u32); + assert!(virtio_mmio_device.read(&mut buf[..], addr, QUEUE_NUM_MAX_REG)); + assert_eq!(LittleEndian::read_u32(&buf[..]), u32::from(QUEUE_SIZE)); // read the register representing the status of queue // for queue_select as 0 @@ -805,15 +740,9 @@ mod tests { .unwrap() .set_device_status(CONFIG_STATUS_FEATURES_OK); LittleEndian::write_u32(&mut buf[..], 1); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, QUEUE_READY_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, QUEUE_READY_REG)); let mut data: Vec = vec![0xff, 0xff, 0xff, 0xff]; - assert_eq!( - virtio_mmio_device.read(&mut data[..], addr, QUEUE_READY_REG), - true - ); + assert!(virtio_mmio_device.read(&mut data[..], addr, QUEUE_READY_REG)); assert_eq!(LittleEndian::read_u32(&data[..]), 1); // for queue_select as 1 let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; @@ -822,73 +751,50 @@ mod tests { .lock() .unwrap() .set_device_status(CONFIG_STATUS_FEATURES_OK); - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, QUEUE_READY_REG), - true - ); + assert!(virtio_mmio_device.read(&mut buf[..], addr, QUEUE_READY_REG)); assert_eq!(LittleEndian::read_u32(&buf[..]), 0); // read the register representing the status of interrupt let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, INTERRUPT_STATUS_REG), - true - ); + assert!(virtio_mmio_device.read(&mut buf[..], addr, INTERRUPT_STATUS_REG)); assert_eq!(LittleEndian::read_u32(&buf[..]), 0); let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; virtio_device .lock() .unwrap() .set_interrupt_status(0b10_1111); - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, INTERRUPT_STATUS_REG), - true - ); + assert!(virtio_mmio_device.read(&mut buf[..], addr, INTERRUPT_STATUS_REG)); assert_eq!(LittleEndian::read_u32(&buf[..]), 0b10_1111); // read the register representing the status of device let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; virtio_device.lock().unwrap().set_device_status(0); - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, STATUS_REG), - true - ); + assert!(virtio_mmio_device.read(&mut buf[..], addr, STATUS_REG)); assert_eq!(LittleEndian::read_u32(&buf[..]), 0); let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; virtio_device.lock().unwrap().set_device_status(5); - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, STATUS_REG), - true - ); + assert!(virtio_mmio_device.read(&mut buf[..], addr, STATUS_REG)); assert_eq!(LittleEndian::read_u32(&buf[..]), 5); } #[test] fn test_virtio_mmio_device_read_03() { - let virtio_device = Arc::new(Mutex::new(VirtioDeviceTest::new())); - let sys_space = address_space_init(); - let mut virtio_mmio_device = VirtioMmioDevice::new(&sys_space, virtio_device.clone()); + let (virtio_device, mut virtio_mmio_device) = virtio_mmio_test_init(); let addr = GuestAddress(0); // read the configuration atomic value let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, CONFIG_GENERATION_REG), - true - ); + assert!(virtio_mmio_device.read(&mut buf[..], addr, CONFIG_GENERATION_REG)); assert_eq!(LittleEndian::read_u32(&buf[..]), 0); let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; virtio_device.lock().unwrap().set_config_generation(10); - assert_eq!( - virtio_mmio_device.read(&mut buf[..], addr, CONFIG_GENERATION_REG), - true - ); + assert!(virtio_mmio_device.read(&mut buf[..], addr, CONFIG_GENERATION_REG)); assert_eq!(LittleEndian::read_u32(&buf[..]), 10); // read the unknown register let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; - assert_eq!(virtio_mmio_device.read(&mut buf[..], addr, 0xf1), false); - assert_eq!(virtio_mmio_device.read(&mut buf[..], addr, 0xfff + 1), true); + assert!(!virtio_mmio_device.read(&mut buf[..], addr, 0xf1)); + assert!(virtio_mmio_device.read(&mut buf[..], addr, 0x1ff + 1)); assert_eq!(buf, [0xff, 0xff, 0xff, 0xff]); // read the configuration space of virtio device @@ -902,29 +808,24 @@ mod tests { .copy_from_slice(&result); let mut data: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - assert_eq!(virtio_mmio_device.read(&mut data[..], addr, 0x100), true); + assert!(virtio_mmio_device.read(&mut data[..], addr, 0x100)); assert_eq!(data, result); let mut data: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0]; let result: Vec = vec![9, 10, 11, 12, 13, 14, 15, 16]; - assert_eq!(virtio_mmio_device.read(&mut data[..], addr, 0x108), true); + assert!(virtio_mmio_device.read(&mut data[..], addr, 0x108)); assert_eq!(data, result); } #[test] fn test_virtio_mmio_device_write_01() { - let virtio_device = Arc::new(Mutex::new(VirtioDeviceTest::new())); - let sys_space = address_space_init(); - let mut virtio_mmio_device = VirtioMmioDevice::new(&sys_space, virtio_device.clone()); + let (virtio_device, mut virtio_mmio_device) = virtio_mmio_test_init(); let addr = GuestAddress(0); // write the selector for device features let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; LittleEndian::write_u32(&mut buf[..], 2); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, DEVICE_FEATURES_SEL_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, DEVICE_FEATURES_SEL_REG)); assert_eq!(virtio_device.lock().unwrap().hfeatures_sel(), 2); // write the device features @@ -934,25 +835,16 @@ mod tests { .lock() .unwrap() .set_device_status(CONFIG_STATUS_FEATURES_OK); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, DRIVER_FEATURES_REG), - false - ); + assert!(!virtio_mmio_device.write(&buf[..], addr, DRIVER_FEATURES_REG)); virtio_device .lock() .unwrap() .set_device_status(CONFIG_STATUS_FAILED); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, DRIVER_FEATURES_REG), - false - ); + assert!(!virtio_mmio_device.write(&buf[..], addr, DRIVER_FEATURES_REG)); virtio_device.lock().unwrap().set_device_status( CONFIG_STATUS_FEATURES_OK | CONFIG_STATUS_FAILED | CONFIG_STATUS_DRIVER, ); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, DRIVER_FEATURES_REG), - false - ); + assert!(!virtio_mmio_device.write(&buf[..], addr, DRIVER_FEATURES_REG)); // it is ok to write the low 32bit of device features virtio_device .lock() @@ -962,10 +854,7 @@ mod tests { virtio_device.lock().unwrap().set_gfeatures_sel(0); LittleEndian::write_u32(&mut buf[..], 0x0000_00fe); virtio_device.lock().unwrap().base.device_features = 0x0000_00fe; - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, DRIVER_FEATURES_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, DRIVER_FEATURES_REG)); assert_eq!( virtio_device.lock().unwrap().base.driver_features as u32, 0x0000_00fe @@ -975,35 +864,26 @@ mod tests { virtio_device.lock().unwrap().set_gfeatures_sel(1); LittleEndian::write_u32(&mut buf[..], 0x0000_00ff); virtio_device.lock().unwrap().base.device_features = 0x0000_00ff_0000_0000; - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, DRIVER_FEATURES_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, DRIVER_FEATURES_REG)); assert_eq!( virtio_device.lock().unwrap().queue_type(), QUEUE_TYPE_PACKED_VRING ); assert_eq!( - virtio_device.lock().unwrap().base.driver_features >> 32 as u32, + virtio_device.lock().unwrap().base.driver_features >> 32_u32, 0x0000_00ff ); // write the selector of driver features let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; LittleEndian::write_u32(&mut buf[..], 0x00ff_0000); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, DRIVER_FEATURES_SEL_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, DRIVER_FEATURES_SEL_REG)); assert_eq!(virtio_device.lock().unwrap().gfeatures_sel(), 0x00ff_0000); // write the selector of queue let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; LittleEndian::write_u32(&mut buf[..], 0x0000_ff00); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, QUEUE_SEL_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, QUEUE_SEL_REG)); assert_eq!(virtio_device.lock().unwrap().queue_select(), 0x0000_ff00); // write the size of queue @@ -1014,10 +894,7 @@ mod tests { .unwrap() .set_device_status(CONFIG_STATUS_FEATURES_OK); LittleEndian::write_u32(&mut buf[..], 128); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, QUEUE_NUM_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, QUEUE_NUM_REG)); if let Ok(config) = virtio_device.lock().unwrap().queue_config() { assert_eq!(config.size, 128); } else { @@ -1027,9 +904,7 @@ mod tests { #[test] fn test_virtio_mmio_device_write_02() { - let virtio_device = Arc::new(Mutex::new(VirtioDeviceTest::new())); - let sys_space = address_space_init(); - let mut virtio_mmio_device = VirtioMmioDevice::new(&sys_space, virtio_device.clone()); + let (virtio_device, mut virtio_mmio_device) = virtio_mmio_test_init(); let addr = GuestAddress(0); // write the ready status of queue @@ -1040,15 +915,9 @@ mod tests { .unwrap() .set_device_status(CONFIG_STATUS_FEATURES_OK); LittleEndian::write_u32(&mut buf[..], 1); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, QUEUE_READY_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, QUEUE_READY_REG)); let mut data: Vec = vec![0xff, 0xff, 0xff, 0xff]; - assert_eq!( - virtio_mmio_device.read(&mut data[..], addr, QUEUE_READY_REG), - true - ); + assert!(virtio_mmio_device.read(&mut data[..], addr, QUEUE_READY_REG)); assert_eq!(LittleEndian::read_u32(&data[..]), 1); let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; @@ -1058,15 +927,9 @@ mod tests { .unwrap() .set_device_status(CONFIG_STATUS_FEATURES_OK); LittleEndian::write_u32(&mut buf[..], 2); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, QUEUE_READY_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, QUEUE_READY_REG)); let mut data: Vec = vec![0xff, 0xff, 0xff, 0xff]; - assert_eq!( - virtio_mmio_device.read(&mut data[..], addr, QUEUE_READY_REG), - true - ); + assert!(virtio_mmio_device.read(&mut data[..], addr, QUEUE_READY_REG)); assert_eq!(LittleEndian::read_u32(&data[..]), 0); // write the interrupt status @@ -1080,23 +943,15 @@ mod tests { .unwrap() .set_interrupt_status(0b10_1111); LittleEndian::write_u32(&mut buf[..], 0b111); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, INTERRUPT_ACK_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, INTERRUPT_ACK_REG)); let mut data: Vec = vec![0xff, 0xff, 0xff, 0xff]; - assert_eq!( - virtio_mmio_device.read(&mut data[..], addr, INTERRUPT_STATUS_REG), - true - ); + assert!(virtio_mmio_device.read(&mut data[..], addr, INTERRUPT_STATUS_REG)); assert_eq!(LittleEndian::read_u32(&data[..]), 0b10_1000); } #[test] fn test_virtio_mmio_device_write_03() { - let virtio_device = Arc::new(Mutex::new(VirtioDeviceTest::new())); - let sys_space = address_space_init(); - let mut virtio_mmio_device = VirtioMmioDevice::new(&sys_space, virtio_device.clone()); + let (virtio_device, mut virtio_mmio_device) = virtio_mmio_test_init(); let addr = GuestAddress(0); // write the low 32bit of queue's descriptor table address @@ -1107,10 +962,7 @@ mod tests { .set_device_status(CONFIG_STATUS_FEATURES_OK); let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; LittleEndian::write_u32(&mut buf[..], 0xffff_fefe); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, QUEUE_DESC_LOW_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, QUEUE_DESC_LOW_REG)); if let Ok(config) = virtio_mmio_device.device.lock().unwrap().queue_config() { assert_eq!(config.desc_table.0 as u32, 0xffff_fefe) } else { @@ -1125,10 +977,7 @@ mod tests { .set_device_status(CONFIG_STATUS_FEATURES_OK); let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; LittleEndian::write_u32(&mut buf[..], 0xfcfc_ffff); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, QUEUE_DESC_HIGH_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, QUEUE_DESC_HIGH_REG)); if let Ok(config) = virtio_device.lock().unwrap().queue_config() { assert_eq!((config.desc_table.0 >> 32) as u32, 0xfcfc_ffff) } else { @@ -1143,10 +992,7 @@ mod tests { .set_device_status(CONFIG_STATUS_FEATURES_OK); let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; LittleEndian::write_u32(&mut buf[..], 0xfcfc_fafa); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, QUEUE_AVAIL_LOW_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, QUEUE_AVAIL_LOW_REG)); if let Ok(config) = virtio_device.lock().unwrap().queue_config() { assert_eq!(config.avail_ring.0 as u32, 0xfcfc_fafa) } else { @@ -1161,10 +1007,7 @@ mod tests { .set_device_status(CONFIG_STATUS_FEATURES_OK); let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; LittleEndian::write_u32(&mut buf[..], 0xecec_fafa); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, QUEUE_AVAIL_HIGH_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, QUEUE_AVAIL_HIGH_REG)); if let Ok(config) = virtio_device.lock().unwrap().queue_config() { assert_eq!((config.avail_ring.0 >> 32) as u32, 0xecec_fafa) } else { @@ -1179,10 +1022,7 @@ mod tests { .set_device_status(CONFIG_STATUS_FEATURES_OK); let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; LittleEndian::write_u32(&mut buf[..], 0xacac_fafa); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, QUEUE_USED_LOW_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, QUEUE_USED_LOW_REG)); if let Ok(config) = virtio_device.lock().unwrap().queue_config() { assert_eq!(config.used_ring.0 as u32, 0xacac_fafa) } else { @@ -1197,10 +1037,7 @@ mod tests { .set_device_status(CONFIG_STATUS_FEATURES_OK); let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; LittleEndian::write_u32(&mut buf[..], 0xcccc_fafa); - assert_eq!( - virtio_mmio_device.write(&buf[..], addr, QUEUE_USED_HIGH_REG), - true - ); + assert!(virtio_mmio_device.write(&buf[..], addr, QUEUE_USED_HIGH_REG)); if let Ok(config) = virtio_device.lock().unwrap().queue_config() { assert_eq!((config.used_ring.0 >> 32) as u32, 0xcccc_fafa) } else { @@ -1214,14 +1051,12 @@ mod tests { } else { 0 }; - (size + align_adjust) as u64 + size + align_adjust } #[test] fn test_virtio_mmio_device_write_04() { - let virtio_device = Arc::new(Mutex::new(VirtioDeviceTest::new())); - let sys_space = address_space_init(); - let mut virtio_mmio_device = VirtioMmioDevice::new(&sys_space, virtio_device.clone()); + let (virtio_device, mut virtio_mmio_device) = virtio_mmio_test_init(); let addr = GuestAddress(0); virtio_mmio_device.assign_interrupt_cb(); @@ -1230,9 +1065,9 @@ mod tests { locked_device.set_device_status(CONFIG_STATUS_FEATURES_OK); if let Ok(config) = locked_device.queue_config_mut(true) { config.desc_table = GuestAddress(0); - config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * 16); + config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * 16); config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * 16 + 8 + 2 * (QUEUE_SIZE as u64), + u64::from(QUEUE_SIZE) * 16 + 8 + 2 * u64::from(QUEUE_SIZE), 4096, )); config.size = QUEUE_SIZE; @@ -1241,9 +1076,9 @@ mod tests { locked_device.set_queue_select(1); if let Ok(config) = locked_device.queue_config_mut(true) { config.desc_table = GuestAddress(0); - config.avail_ring = GuestAddress((QUEUE_SIZE as u64) * 16); + config.avail_ring = GuestAddress(u64::from(QUEUE_SIZE) * 16); config.used_ring = GuestAddress(align( - (QUEUE_SIZE as u64) * 16 + 8 + 2 * (QUEUE_SIZE as u64), + u64::from(QUEUE_SIZE) * 16 + 8 + 2 * u64::from(QUEUE_SIZE), 4096, )); config.size = QUEUE_SIZE / 2; @@ -1254,13 +1089,10 @@ mod tests { // write the device status let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; LittleEndian::write_u32(&mut buf[..], CONFIG_STATUS_ACKNOWLEDGE); - assert_eq!(virtio_mmio_device.write(&buf[..], addr, STATUS_REG), true); - assert_eq!(virtio_device.lock().unwrap().device_activated(), false); + assert!(virtio_mmio_device.write(&buf[..], addr, STATUS_REG)); + assert!(!virtio_device.lock().unwrap().device_activated()); let mut data: Vec = vec![0xff, 0xff, 0xff, 0xff]; - assert_eq!( - virtio_mmio_device.read(&mut data[..], addr, STATUS_REG), - true - ); + assert!(virtio_mmio_device.read(&mut data[..], addr, STATUS_REG)); assert_eq!(LittleEndian::read_u32(&data[..]), CONFIG_STATUS_ACKNOWLEDGE); let mut buf: Vec = vec![0xff, 0xff, 0xff, 0xff]; @@ -1271,15 +1103,12 @@ mod tests { | CONFIG_STATUS_DRIVER_OK | CONFIG_STATUS_FEATURES_OK, ); - assert_eq!(virtio_device.lock().unwrap().b_active, false); - assert_eq!(virtio_mmio_device.write(&buf[..], addr, STATUS_REG), true); - assert_eq!(virtio_device.lock().unwrap().device_activated(), true); - assert_eq!(virtio_device.lock().unwrap().b_active, true); + assert!(!virtio_device.lock().unwrap().b_active); + assert!(virtio_mmio_device.write(&buf[..], addr, STATUS_REG)); + assert!(virtio_device.lock().unwrap().device_activated()); + assert!(virtio_device.lock().unwrap().b_active); let mut data: Vec = vec![0xff, 0xff, 0xff, 0xff]; - assert_eq!( - virtio_mmio_device.read(&mut data[..], addr, STATUS_REG), - true - ); + assert!(virtio_mmio_device.read(&mut data[..], addr, STATUS_REG)); assert_eq!( LittleEndian::read_u32(&data[..]), CONFIG_STATUS_ACKNOWLEDGE diff --git a/virtio/src/transport/virtio_pci.rs b/virtio/src/transport/virtio_pci.rs index 041356f9a562412279df094bddd17bce9f69301a..405927b85b9111e468f8e3cadd057e8046daa813 100644 --- a/virtio/src/transport/virtio_pci.rs +++ b/virtio/src/transport/virtio_pci.rs @@ -12,7 +12,7 @@ use std::cmp::{max, min}; use std::mem::size_of; -use std::sync::atomic::{AtomicU16, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU16, Ordering}; use std::sync::{Arc, Mutex, Weak}; use anyhow::{anyhow, bail, Context, Result}; @@ -31,11 +31,13 @@ use crate::{ CONFIG_STATUS_FEATURES_OK, CONFIG_STATUS_NEEDS_RESET, INVALID_VECTOR_NUM, QUEUE_TYPE_PACKED_VRING, QUEUE_TYPE_SPLIT_VRING, VIRTIO_F_RING_PACKED, VIRTIO_F_VERSION_1, VIRTIO_MMIO_INT_CONFIG, VIRTIO_MMIO_INT_VRING, VIRTIO_TYPE_BLOCK, VIRTIO_TYPE_CONSOLE, - VIRTIO_TYPE_FS, VIRTIO_TYPE_GPU, VIRTIO_TYPE_NET, VIRTIO_TYPE_SCSI, + VIRTIO_TYPE_FS, VIRTIO_TYPE_GPU, VIRTIO_TYPE_NET, VIRTIO_TYPE_SCSI, VIRTIO_TYPE_VSOCK, }; #[cfg(feature = "virtio_gpu")] use address_space::HostMemMapping; -use address_space::{AddressRange, AddressSpace, GuestAddress, Region, RegionIoEventFd, RegionOps}; +use address_space::{ + AddressAttr, AddressRange, AddressSpace, GuestAddress, Region, RegionIoEventFd, RegionOps, +}; use devices::pci::config::{ RegionType, BAR_SPACE_UNMAPPED, DEVICE_ID, MINIMUM_BAR_SIZE_FOR_MMIO, PCIE_CONFIG_SPACE_SIZE, PCI_SUBDEVICE_ID_QEMU, PCI_VENDOR_ID_REDHAT_QUMRANET, REG_SIZE, REVISION_ID, STATUS, @@ -46,15 +48,14 @@ use devices::pci::{ config::PciConfig, init_intx, init_msix, init_multifunction, le_write_u16, le_write_u32, PciBus, PciDevBase, PciDevOps, PciError, }; -use devices::{Device, DeviceBase}; +use devices::{convert_bus_ref, Bus, Device, DeviceBase, PCI_BUS}; #[cfg(feature = "virtio_gpu")] use machine_manager::config::VIRTIO_GPU_ENABLE_BAR0_SIZE; use migration::{DeviceStateDesc, FieldDesc, MigrationHook, MigrationManager, StateTransfer}; use migration_derive::{ByteCode, Desc}; use util::byte_code::ByteCode; -use util::num_ops::ranges_overlap; -use util::num_ops::{read_data_u32, write_data_u32}; -use util::offset_of; +use util::num_ops::{ranges_overlap, read_data_u32, write_data_u32}; +use util::{gen_base_func, offset_of}; const VIRTIO_QUEUE_MAX: u32 = 1024; @@ -171,7 +172,7 @@ fn get_virtio_class_id(device_type: u32, _device_quirk: Option VIRTIO_PCI_CLASS_ID_BLOCK, VIRTIO_TYPE_FS => VIRTIO_PCI_CLASS_ID_STORAGE_OTHER, VIRTIO_TYPE_NET => VIRTIO_PCI_CLASS_ID_NET, - VIRTIO_TYPE_CONSOLE => VIRTIO_PCI_CLASS_ID_COMMUNICATION_OTHER, + VIRTIO_TYPE_CONSOLE | VIRTIO_TYPE_VSOCK => VIRTIO_PCI_CLASS_ID_COMMUNICATION_OTHER, #[cfg(target_arch = "x86_64")] VIRTIO_TYPE_GPU => VIRTIO_PCI_CLASS_ID_DISPLAY_VGA, #[cfg(target_arch = "aarch64")] @@ -326,17 +327,17 @@ impl VirtioPciDevice { devfn: u8, sys_mem: Arc, device: Arc>, - parent_bus: Weak>, + parent_bus: Weak>, multi_func: bool, need_irqfd: bool, ) -> Self { let queue_num = device.lock().unwrap().queue_num(); VirtioPciDevice { base: PciDevBase { - base: DeviceBase::new(name, true), - config: PciConfig::new(PCIE_CONFIG_SPACE_SIZE, VIRTIO_PCI_BAR_MAX), + base: DeviceBase::new(name, true, Some(parent_bus)), + config: PciConfig::new(devfn, PCIE_CONFIG_SPACE_SIZE, VIRTIO_PCI_BAR_MAX), devfn, - parent_bus, + bme: Arc::new(AtomicBool::new(false)), }, device, dev_id: Arc::new(AtomicU16::new(0)), @@ -390,7 +391,9 @@ impl VirtioPciDevice { let mut locked_msix = cloned_msix.lock().unwrap(); if locked_msix.enabled { - locked_msix.notify(vector, dev_id.load(Ordering::Acquire)); + if vector != INVALID_VECTOR_NUM { + locked_msix.notify(vector, dev_id.load(Ordering::Acquire)); + } } else { cloned_intx.lock().unwrap().notify(1); } @@ -468,11 +471,9 @@ impl VirtioPciDevice { } locked_dev.virtio_base_mut().queues = queues; - let parent = self.base.parent_bus.upgrade().unwrap(); - parent - .lock() - .unwrap() - .update_dev_id(self.base.devfn, &self.dev_id); + let bus = self.parent_bus().unwrap().upgrade().unwrap(); + PCI_BUS!(bus, locked_bus, pci_bus); + pci_bus.update_dev_id(self.base.devfn, &self.dev_id); if self.need_irqfd { let mut queue_num = locked_dev.queue_num(); // No need to create call event for control queue. @@ -576,11 +577,11 @@ impl VirtioPciDevice { 0 } } - COMMON_MSIX_REG => locked_device.config_vector() as u32, + COMMON_MSIX_REG => u32::from(locked_device.config_vector()), COMMON_NUMQ_REG => locked_device.virtio_base().queues_config.len() as u32, COMMON_STATUS_REG => locked_device.device_status(), - COMMON_CFGGENERATION_REG => locked_device.config_generation() as u32, - COMMON_Q_SELECT_REG => locked_device.queue_select() as u32, + COMMON_CFGGENERATION_REG => u32::from(locked_device.config_generation()), + COMMON_Q_SELECT_REG => u32::from(locked_device.queue_select()), COMMON_Q_SIZE_REG => locked_device .queue_config() .map(|config| u32::from(config.size))?, @@ -590,7 +591,7 @@ impl VirtioPciDevice { COMMON_Q_ENABLE_REG => locked_device .queue_config() .map(|config| u32::from(config.ready))?, - COMMON_Q_NOFF_REG => locked_device.queue_select() as u32, + COMMON_Q_NOFF_REG => u32::from(locked_device.queue_select()), COMMON_Q_DESCLO_REG => locked_device .queue_config() .map(|config| config.desc_table.0 as u32)?, @@ -647,7 +648,7 @@ impl VirtioPciDevice { locked_device.set_driver_features(gfeatures_sel, value); if gfeatures_sel == 1 { - let features = (locked_device.driver_features(1) as u64) << 32; + let features = u64::from(locked_device.driver_features(1)) << 32; if virtio_has_feature(features, VIRTIO_F_RING_PACKED) { locked_device.set_queue_type(QUEUE_TYPE_PACKED_VRING); } else { @@ -665,7 +666,7 @@ impl VirtioPciDevice { } COMMON_STATUS_REG => { if value & CONFIG_STATUS_FEATURES_OK != 0 && value & CONFIG_STATUS_DRIVER_OK == 0 { - let features = (locked_device.driver_features(1) as u64) << 32; + let features = u64::from(locked_device.driver_features(1)) << 32; if !virtio_has_feature(features, VIRTIO_F_VERSION_1) { error!( "Device {} is modern only, but the driver not support VIRTIO_F_VERSION_1", self.base.base.id @@ -774,7 +775,7 @@ impl VirtioPciDevice { }; let common_write = move |data: &[u8], _addr: GuestAddress, offset: u64| -> bool { - let mut value = 0; + let mut value: u32 = 0; if !read_data_u32(data, &mut value) { return false; } @@ -926,8 +927,8 @@ impl VirtioPciDevice { warn!("The offset {} of VirtioPciCfgAccessCap is not aligned", off); return; } - if (off as u64) - .checked_add(len as u64) + if u64::from(off) + .checked_add(u64::from(len)) .filter(|&end| end <= self.base.config.bars[bar as usize].size) .is_none() { @@ -937,12 +938,20 @@ impl VirtioPciDevice { let result = if is_write { let mut data = self.base.config.config[pci_cfg_data_offset..].as_ref(); - self.sys_mem - .write(&mut data, GuestAddress(bar_base + off as u64), len as u64) + self.sys_mem.write( + &mut data, + GuestAddress(bar_base + u64::from(off)), + u64::from(len), + AddressAttr::MMIO, + ) } else { let mut data = self.base.config.config[pci_cfg_data_offset..].as_mut(); - self.sys_mem - .read(&mut data, GuestAddress(bar_base + off as u64), len as u64) + self.sys_mem.read( + &mut data, + GuestAddress(bar_base + u64::from(off)), + u64::from(len), + AddressAttr::MMIO, + ) }; if let Err(e) = result { error!( @@ -957,7 +966,7 @@ impl VirtioPciDevice { // its own request completion. i.e, If the vq is not enough, vcpu A will // receive completion of request that submitted by vcpu B, then A needs // to IPI B. - min(queues_max as u16 - queues_fixed, nr_cpus as u16) + min(queues_max as u16 - queues_fixed, u16::from(nr_cpus)) } fn queues_register_irqfd(&self, call_fds: &[Arc]) -> bool { @@ -1003,26 +1012,24 @@ impl VirtioPciDevice { } impl Device for VirtioPciDevice { - fn device_base(&self) -> &DeviceBase { - &self.base.base - } - - fn device_base_mut(&mut self) -> &mut DeviceBase { - &mut self.base.base - } -} + gen_base_func!(device_base, device_base_mut, DeviceBase, base.base); -impl PciDevOps for VirtioPciDevice { - fn pci_base(&self) -> &PciDevBase { - &self.base - } + fn reset(&mut self, _reset_child_device: bool) -> Result<()> { + info!("func: reset, id: {:?}", &self.base.base.id); + self.deactivate_device(); + self.device + .lock() + .unwrap() + .reset() + .with_context(|| "Failed to reset virtio device")?; + self.base.config.reset()?; - fn pci_base_mut(&mut self) -> &mut PciDevBase { - &mut self.base + Ok(()) } - fn realize(mut self) -> Result<()> { + fn realize(mut self) -> Result>> { info!("func: realize, id: {:?}", &self.base.base.id); + let parent_bus = self.parent_bus().unwrap(); self.init_write_mask(false)?; self.init_write_clear_mask(false)?; @@ -1062,7 +1069,7 @@ impl PciDevOps for VirtioPciDevice { self.multi_func, &mut self.base.config.config, self.base.devfn, - self.base.parent_bus.clone(), + parent_bus.clone(), )?; #[cfg(target_arch = "aarch64")] self.base.config.set_interrupt_pin(); @@ -1134,7 +1141,7 @@ impl PciDevOps for VirtioPciDevice { init_intx( self.name(), &mut self.base.config, - self.base.parent_bus.clone(), + parent_bus.clone(), self.base.devfn, )?; @@ -1152,11 +1159,11 @@ impl PciDevOps for VirtioPciDevice { .with_context(|| "Failed to realize virtio device")?; let name = self.name(); - let devfn = self.base.devfn; + let devfn = u64::from(self.base.devfn); let dev = Arc::new(Mutex::new(self)); - let mut mem_region_size = ((VIRTIO_PCI_CAP_NOTIFY_OFFSET + VIRTIO_PCI_CAP_NOTIFY_LENGTH) - as u64) - .next_power_of_two(); + let mut mem_region_size = + u64::from(VIRTIO_PCI_CAP_NOTIFY_OFFSET + VIRTIO_PCI_CAP_NOTIFY_LENGTH) + .next_power_of_two(); mem_region_size = max(mem_region_size, MINIMUM_BAR_SIZE_FOR_MMIO as u64); let modern_mem_region = Region::init_container_region(mem_region_size, "VirtioPciModernMem"); @@ -1171,22 +1178,16 @@ impl PciDevOps for VirtioPciDevice { )?; // Register device to pci bus. - let pci_bus = dev.lock().unwrap().base.parent_bus.upgrade().unwrap(); - let mut locked_pci_bus = pci_bus.lock().unwrap(); - let pci_device = locked_pci_bus.devices.get(&devfn); - if pci_device.is_none() { - locked_pci_bus.devices.insert(devfn, dev.clone()); - } else { - bail!( - "Devfn {:?} has been used by {:?}", - &devfn, - pci_device.unwrap().lock().unwrap().name() - ); - } + let bus = parent_bus.upgrade().unwrap(); + bus.lock().unwrap().attach_child(devfn, dev.clone())?; - MigrationManager::register_transport_instance(VirtioPciState::descriptor(), dev, &name); + MigrationManager::register_transport_instance( + VirtioPciState::descriptor(), + dev.clone(), + &name, + ); - Ok(()) + Ok(dev) } fn unrealize(&mut self) -> Result<()> { @@ -1197,7 +1198,7 @@ impl PciDevOps for VirtioPciDevice { .unrealize() .with_context(|| "Failed to unrealize the virtio device")?; - let bus = self.base.parent_bus.upgrade().unwrap(); + let bus = self.parent_bus().unwrap().upgrade().unwrap(); self.base.config.unregister_bars(&bus)?; MigrationManager::unregister_device_instance(MsixState::descriptor(), &self.name()); @@ -1205,6 +1206,10 @@ impl PciDevOps for VirtioPciDevice { Ok(()) } +} + +impl PciDevOps for VirtioPciDevice { + gen_base_func!(pci_base, pci_base_mut, PciDevBase, base); fn read_config(&mut self, offset: usize, data: &mut [u8]) { trace::virtio_tpt_read_config(&self.base.base.id, offset as u64, data.len()); @@ -1224,34 +1229,21 @@ impl PciDevOps for VirtioPciDevice { } trace::virtio_tpt_write_config(&self.base.base.id, offset as u64, data); - let parent_bus = self.base.parent_bus.upgrade().unwrap(); - let locked_parent_bus = parent_bus.lock().unwrap(); + let bus = self.parent_bus().unwrap().upgrade().unwrap(); + PCI_BUS!(bus, locked_bus, pci_bus); self.base.config.write( offset, data, self.dev_id.clone().load(Ordering::Acquire), #[cfg(target_arch = "x86_64")] - Some(&locked_parent_bus.io_region), - Some(&locked_parent_bus.mem_region), + Some(&pci_bus.io_region), + Some(&pci_bus.mem_region), ); self.do_cfg_access(offset, end, true); } - fn reset(&mut self, _reset_child_device: bool) -> Result<()> { - info!("func: reset, id: {:?}", &self.base.base.id); - self.deactivate_device(); - self.device - .lock() - .unwrap() - .reset() - .with_context(|| "Failed to reset virtio device")?; - self.base.config.reset()?; - - Ok(()) - } - fn get_dev_path(&self) -> Option { - let parent_bus = self.base.parent_bus.upgrade().unwrap(); + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); match self.device.lock().unwrap().device_type() { VIRTIO_TYPE_BLOCK => { // The virtio blk device is identified as a single-channel SCSI device, @@ -1341,12 +1333,12 @@ impl MigrationHook for VirtioPciDevice { } // Reregister ioevents for notifies. - let parent_bus = self.base.parent_bus.upgrade().unwrap(); - let locked_parent_bus = parent_bus.lock().unwrap(); + let parent_bus = self.parent_bus().unwrap().upgrade().unwrap(); + PCI_BUS!(parent_bus, locked_bus, pci_bus); if let Err(e) = self.base.config.update_bar_mapping( #[cfg(target_arch = "x86_64")] - Some(&locked_parent_bus.io_region), - Some(&locked_parent_bus.mem_region), + Some(&pci_bus.io_region), + Some(&pci_bus.mem_region), ) { bail!("Failed to update bar, error is {:?}", e); } @@ -1383,8 +1375,9 @@ mod tests { use vmm_sys_util::eventfd::EventFd; use super::*; + use crate::tests::address_space_init; use crate::VirtioBase; - use address_space::{AddressSpace, GuestAddress, HostMemMapping}; + use address_space::{AddressSpace, GuestAddress}; use devices::pci::{ config::{HEADER_TYPE, HEADER_TYPE_MULTIFUNC}, le_read_u16, @@ -1394,13 +1387,13 @@ mod tests { const VIRTIO_DEVICE_QUEUE_NUM: usize = 2; const VIRTIO_DEVICE_QUEUE_SIZE: u16 = 256; - pub struct VirtioDeviceTest { + struct VirtioDeviceTest { base: VirtioBase, - pub is_activated: bool, + is_activated: bool, } impl VirtioDeviceTest { - pub fn new() -> Self { + fn new() -> Self { let mut base = VirtioBase::new( VIRTIO_DEVICE_TEST_TYPE, VIRTIO_DEVICE_QUEUE_NUM, @@ -1415,13 +1408,7 @@ mod tests { } impl VirtioDevice for VirtioDeviceTest { - fn virtio_base(&self) -> &VirtioBase { - &self.base - } - - fn virtio_base_mut(&mut self) -> &mut VirtioBase { - &mut self.base - } + gen_base_func!(virtio_base, virtio_base_mut, VirtioBase, base); fn realize(&mut self) -> Result<()> { self.init_config_features()?; @@ -1466,31 +1453,40 @@ mod tests { }; } - #[test] - fn test_common_config_dev_feature() { + fn virtio_pci_test_init( + multi_func: bool, + ) -> ( + Arc>, + Arc>, + VirtioPciDevice, + ) { let virtio_dev = Arc::new(Mutex::new(VirtioDeviceTest::new())); - let sys_mem = AddressSpace::new( - Region::init_container_region(u64::max_value(), "sysmem"), - "sysmem", - None, - ) - .unwrap(); + let sys_mem = address_space_init(); let parent_bus = Arc::new(Mutex::new(PciBus::new( String::from("test bus"), #[cfg(target_arch = "x86_64")] Region::init_container_region(1 << 16, "parent_bus"), sys_mem.root().clone(), ))); - let mut virtio_pci = VirtioPciDevice::new( + let virtio_pci = VirtioPciDevice::new( String::from("test device"), 0, sys_mem, virtio_dev.clone(), - Arc::downgrade(&parent_bus), - false, + Arc::downgrade(&(parent_bus.clone() as Arc>)), + multi_func, false, ); + // Note: if parent_bus is used in the code execution during the testing process, a variable needs to + // be used to maintain the count and avoid rust from automatically releasing this `Arc`. + (virtio_dev, parent_bus, virtio_pci) + } + + #[test] + fn test_common_config_dev_feature() { + let (virtio_dev, _, mut virtio_pci) = virtio_pci_test_init(false); + // Read virtio device features virtio_dev.lock().unwrap().set_hfeatures_sel(0_u32); com_cfg_read_test!(virtio_pci, COMMON_DF_REG, 0xFFFF_FFF0_u32); @@ -1526,28 +1522,7 @@ mod tests { #[test] fn test_common_config_queue() { - let virtio_dev = Arc::new(Mutex::new(VirtioDeviceTest::new())); - let sys_mem = AddressSpace::new( - Region::init_container_region(u64::max_value(), "sysmem"), - "sysmem", - None, - ) - .unwrap(); - let parent_bus = Arc::new(Mutex::new(PciBus::new( - String::from("test bus"), - #[cfg(target_arch = "x86_64")] - Region::init_container_region(1 << 16, "parent_bus"), - sys_mem.root().clone(), - ))); - let virtio_pci = VirtioPciDevice::new( - String::from("test device"), - 0, - sys_mem, - virtio_dev.clone(), - Arc::downgrade(&parent_bus), - false, - false, - ); + let (virtio_dev, _, virtio_pci) = virtio_pci_test_init(false); // Read Queue's Descriptor Table address virtio_dev @@ -1556,50 +1531,28 @@ mod tests { .set_queue_select(VIRTIO_DEVICE_QUEUE_NUM as u16 - 1); let queue_select = virtio_dev.lock().unwrap().queue_select(); virtio_dev.lock().unwrap().virtio_base_mut().queues_config[queue_select as usize] - .desc_table = GuestAddress(0xAABBCCDD_FFEEDDAA); + .desc_table = GuestAddress(0xAABB_CCDD_FFEE_DDAA); com_cfg_read_test!(virtio_pci, COMMON_Q_DESCLO_REG, 0xFFEEDDAA_u32); com_cfg_read_test!(virtio_pci, COMMON_Q_DESCHI_REG, 0xAABBCCDD_u32); // Read Queue's Available Ring address virtio_dev.lock().unwrap().set_queue_select(0); virtio_dev.lock().unwrap().virtio_base_mut().queues_config[0].avail_ring = - GuestAddress(0x11223344_55667788); + GuestAddress(0x1122_3344_5566_7788); com_cfg_read_test!(virtio_pci, COMMON_Q_AVAILLO_REG, 0x55667788_u32); com_cfg_read_test!(virtio_pci, COMMON_Q_AVAILHI_REG, 0x11223344_u32); // Read Queue's Used Ring address virtio_dev.lock().unwrap().set_queue_select(0); virtio_dev.lock().unwrap().virtio_base_mut().queues_config[0].used_ring = - GuestAddress(0x55667788_99AABBCC); + GuestAddress(0x5566_7788_99AA_BBCC); com_cfg_read_test!(virtio_pci, COMMON_Q_USEDLO_REG, 0x99AABBCC_u32); com_cfg_read_test!(virtio_pci, COMMON_Q_USEDHI_REG, 0x55667788_u32); } #[test] fn test_common_config_queue_error() { - let virtio_dev = Arc::new(Mutex::new(VirtioDeviceTest::new())); - let sys_mem = AddressSpace::new( - Region::init_container_region(u64::max_value(), "sysmem"), - "sysmem", - None, - ) - .unwrap(); - let parent_bus = Arc::new(Mutex::new(PciBus::new( - String::from("test bus"), - #[cfg(target_arch = "x86_64")] - Region::init_container_region(1 << 16, "parent_bus"), - sys_mem.root().clone(), - ))); - let cloned_virtio_dev = virtio_dev.clone(); - let mut virtio_pci = VirtioPciDevice::new( - String::from("test device"), - 0, - sys_mem, - cloned_virtio_dev, - Arc::downgrade(&parent_bus), - false, - false, - ); + let (virtio_dev, _, mut virtio_pci) = virtio_pci_test_init(false); assert!(init_msix( &mut virtio_pci.base, @@ -1634,7 +1587,7 @@ mod tests { .unwrap() .virtio_base() .queues_config - .get(0) + .first() .unwrap() .ready ); @@ -1652,29 +1605,8 @@ mod tests { #[test] fn test_virtio_pci_config_access() { - let virtio_dev: Arc> = - Arc::new(Mutex::new(VirtioDeviceTest::new())); - let sys_mem = AddressSpace::new( - Region::init_container_region(u64::max_value(), "sysmem"), - "sysmem", - None, - ) - .unwrap(); - let parent_bus = Arc::new(Mutex::new(PciBus::new( - String::from("test bus"), - #[cfg(target_arch = "x86_64")] - Region::init_container_region(1 << 16, "parent_bus"), - sys_mem.root().clone(), - ))); - let mut virtio_pci = VirtioPciDevice::new( - String::from("test device"), - 0, - sys_mem, - virtio_dev, - Arc::downgrade(&parent_bus), - false, - false, - ); + let (_, _parent_bus, mut virtio_pci) = virtio_pci_test_init(false); + virtio_pci.init_write_mask(false).unwrap(); virtio_pci.init_write_clear_mask(false).unwrap(); @@ -1693,69 +1625,14 @@ mod tests { #[test] fn test_virtio_pci_realize() { - let virtio_dev: Arc> = - Arc::new(Mutex::new(VirtioDeviceTest::new())); - let sys_mem = AddressSpace::new( - Region::init_container_region(u64::max_value(), "sysmem"), - "sysmem", - None, - ) - .unwrap(); - let parent_bus = Arc::new(Mutex::new(PciBus::new( - String::from("test bus"), - #[cfg(target_arch = "x86_64")] - Region::init_container_region(1 << 16, "parent_bus"), - sys_mem.root().clone(), - ))); - let virtio_pci = VirtioPciDevice::new( - String::from("test device"), - 0, - sys_mem, - virtio_dev, - Arc::downgrade(&parent_bus), - false, - false, - ); + let (_, _parent_bus, virtio_pci) = virtio_pci_test_init(false); assert!(virtio_pci.realize().is_ok()); } #[test] fn test_device_activate() { - let sys_mem = AddressSpace::new( - Region::init_container_region(u64::max_value(), "sysmem"), - "sysmem", - None, - ) - .unwrap(); - let mem_size: u64 = 1024 * 1024; - let host_mmap = Arc::new( - HostMemMapping::new(GuestAddress(0), None, mem_size, None, false, false, false) - .unwrap(), - ); - sys_mem - .root() - .add_subregion( - Region::init_ram_region(host_mmap.clone(), "sysmem"), - host_mmap.start_address().raw_value(), - ) - .unwrap(); + let (virtio_dev, _parent_bus, mut virtio_pci) = virtio_pci_test_init(false); - let virtio_dev = Arc::new(Mutex::new(VirtioDeviceTest::new())); - let parent_bus = Arc::new(Mutex::new(PciBus::new( - String::from("test bus"), - #[cfg(target_arch = "x86_64")] - Region::init_container_region(1 << 16, "parent_bus"), - sys_mem.root().clone(), - ))); - let mut virtio_pci = VirtioPciDevice::new( - String::from("test device"), - 0, - sys_mem, - virtio_dev.clone(), - Arc::downgrade(&parent_bus), - false, - false, - ); #[cfg(target_arch = "aarch64")] virtio_pci.base.config.set_interrupt_pin(); @@ -1769,10 +1646,11 @@ mod tests { ) .unwrap(); + let parent_bus = virtio_pci.parent_bus().unwrap(); init_intx( virtio_pci.name(), &mut virtio_pci.base.config, - virtio_pci.base.parent_bus.clone(), + parent_bus.clone(), virtio_pci.base.devfn, ) .unwrap(); @@ -1788,7 +1666,7 @@ mod tests { .iter_mut() { queue_cfg.desc_table = GuestAddress(0); - queue_cfg.avail_ring = GuestAddress((VIRTIO_DEVICE_QUEUE_SIZE as u64) * 16); + queue_cfg.avail_ring = GuestAddress(u64::from(VIRTIO_DEVICE_QUEUE_SIZE) * 16); queue_cfg.used_ring = GuestAddress(2 * 4096); queue_cfg.ready = true; queue_cfg.size = VIRTIO_DEVICE_QUEUE_SIZE; @@ -1800,7 +1678,7 @@ mod tests { let status = (CONFIG_STATUS_ACKNOWLEDGE | CONFIG_STATUS_DRIVER | CONFIG_STATUS_FEATURES_OK) .as_bytes(); (common_cfg_ops.write)(status, GuestAddress(0), COMMON_STATUS_REG); - assert_eq!(virtio_dev.lock().unwrap().device_activated(), false); + assert!(!virtio_dev.lock().unwrap().device_activated()); // Device status is not ok, failed to activate virtio device let status = (CONFIG_STATUS_ACKNOWLEDGE | CONFIG_STATUS_DRIVER @@ -1808,7 +1686,7 @@ mod tests { | CONFIG_STATUS_FEATURES_OK) .as_bytes(); (common_cfg_ops.write)(status, GuestAddress(0), COMMON_STATUS_REG); - assert_eq!(virtio_dev.lock().unwrap().device_activated(), false); + assert!(!virtio_dev.lock().unwrap().device_activated()); // Status is ok, virtio device is activated. let status = (CONFIG_STATUS_ACKNOWLEDGE | CONFIG_STATUS_DRIVER @@ -1816,48 +1694,27 @@ mod tests { | CONFIG_STATUS_FEATURES_OK) .as_bytes(); (common_cfg_ops.write)(status, GuestAddress(0), COMMON_STATUS_REG); - assert_eq!(virtio_dev.lock().unwrap().device_activated(), true); + assert!(virtio_dev.lock().unwrap().device_activated()); // If device status(not zero) is set to zero, reset the device (common_cfg_ops.write)(0_u32.as_bytes(), GuestAddress(0), COMMON_STATUS_REG); - assert_eq!(virtio_dev.lock().unwrap().device_activated(), false); + assert!(!virtio_dev.lock().unwrap().device_activated()); } #[test] fn test_multifunction() { - let virtio_dev: Arc> = - Arc::new(Mutex::new(VirtioDeviceTest::new())); - let sys_mem = AddressSpace::new( - Region::init_container_region(u64::max_value(), "sysmem"), - "sysmem", - None, - ) - .unwrap(); - let parent_bus = Arc::new(Mutex::new(PciBus::new( - String::from("test bus"), - #[cfg(target_arch = "x86_64")] - Region::init_container_region(1 << 16, "parent_bus"), - sys_mem.root().clone(), - ))); - let mut virtio_pci = VirtioPciDevice::new( - String::from("test device"), - 24, - sys_mem, - virtio_dev, - Arc::downgrade(&parent_bus), - true, - false, - ); + let (_, _parent_bus, mut virtio_pci) = virtio_pci_test_init(true); + let parent_bus = virtio_pci.parent_bus().unwrap(); assert!(init_multifunction( virtio_pci.multi_func, &mut virtio_pci.base.config.config, virtio_pci.base.devfn, - virtio_pci.base.parent_bus.clone() + parent_bus, ) .is_ok()); let header_type = le_read_u16(&virtio_pci.base.config.config, HEADER_TYPE as usize).unwrap(); - assert_eq!(header_type, HEADER_TYPE_MULTIFUNC as u16); + assert_eq!(header_type, u16::from(HEADER_TYPE_MULTIFUNC)); } } diff --git a/virtio/src/vhost/kernel/mod.rs b/virtio/src/vhost/kernel/mod.rs index 02d55cadbc1e10bbcc10ad208da18e47bf1b708f..e04b1589abbe2ff6e4c6c680e313de846921c8bb 100644 --- a/virtio/src/vhost/kernel/mod.rs +++ b/virtio/src/vhost/kernel/mod.rs @@ -10,11 +10,15 @@ // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. // See the Mulan PSL v2 for more details. +#[cfg(feature = "vhost_net")] mod net; +#[cfg(feature = "vhost_vsock")] mod vsock; +#[cfg(feature = "vhost_net")] pub use net::Net; -pub use vsock::{Vsock, VsockState}; +#[cfg(feature = "vhost_vsock")] +pub use vsock::{Vsock, VsockConfig, VsockState}; use std::fs::{File, OpenOptions}; use std::os::unix::fs::OpenOptionsExt; @@ -30,7 +34,8 @@ use super::super::QueueConfig; use super::VhostOps; use crate::VirtioError; use address_space::{ - AddressSpace, FlatRange, GuestAddress, Listener, ListenerReqType, RegionIoEventFd, RegionType, + AddressAttr, AddressSpace, FlatRange, GuestAddress, Listener, ListenerReqType, RegionIoEventFd, + RegionType, }; use util::byte_code::ByteCode; @@ -154,7 +159,9 @@ impl VhostMemInfo { fn add_mem_range(&self, fr: &FlatRange) { let guest_phys_addr = fr.addr_range.base.raw_value(); let memory_size = fr.addr_range.size; - let userspace_addr = fr.owner.get_host_address().unwrap() + fr.offset_in_region; + let userspace_addr = + // SAFETY: memory_size is range's size, so we make sure [hva, hva+size] is in ram range. + unsafe { fr.owner.get_host_address(AddressAttr::Ram).unwrap() } + fr.offset_in_region; self.regions.lock().unwrap().push(VhostMemoryRegion { guest_phys_addr, @@ -169,7 +176,9 @@ impl VhostMemInfo { let target = VhostMemoryRegion { guest_phys_addr: fr.addr_range.base.raw_value(), memory_size: fr.addr_range.size, - userspace_addr: fr.owner.get_host_address().unwrap() + fr.offset_in_region, + // SAFETY: memory_size is range's size, so we make sure [hva, hva+size] is in ram range. + userspace_addr: unsafe { fr.owner.get_host_address(AddressAttr::Ram).unwrap() } + + fr.offset_in_region, flags_padding: 0_u64, }; for (index, mr) in mem_regions.iter().enumerate() { diff --git a/virtio/src/vhost/kernel/net.rs b/virtio/src/vhost/kernel/net.rs index 55a49724204a2dce3aad596d1f0bb7dd2f6bc80d..1131bac37dcf6d6ef06a499a8fee3f6b254b8ccd 100644 --- a/virtio/src/vhost/kernel/net.rs +++ b/virtio/src/vhost/kernel/net.rs @@ -31,10 +31,11 @@ use crate::{ VIRTIO_NET_F_HOST_TSO4, VIRTIO_NET_F_HOST_UFO, VIRTIO_NET_F_MQ, VIRTIO_TYPE_NET, }; use address_space::AddressSpace; -use machine_manager::config::NetworkInterfaceConfig; +use machine_manager::config::{NetDevcfg, NetworkInterfaceConfig}; use machine_manager::event_loop::{register_event_helper, unregister_event_helper}; use util::byte_code::ByteCode; -use util::loop_context::EventNotifierHelper; +use util::gen_base_func; +use util::loop_context::{create_new_eventfd, EventNotifierHelper}; use util::tap::Tap; /// Number of virtqueues. @@ -79,6 +80,8 @@ pub struct Net { base: VirtioBase, /// Configuration of the network device. net_cfg: NetworkInterfaceConfig, + /// Configuration of the backend netdev. + netdev_cfg: NetDevcfg, /// Virtio net configurations. config_space: Arc>, /// Tap device opened. @@ -94,17 +97,22 @@ pub struct Net { } impl Net { - pub fn new(cfg: &NetworkInterfaceConfig, mem_space: &Arc) -> Self { - let queue_num = if cfg.mq { - (cfg.queues + 1) as usize + pub fn new( + net_cfg: &NetworkInterfaceConfig, + netdev_cfg: NetDevcfg, + mem_space: &Arc, + ) -> Self { + let queue_num = if net_cfg.mq { + (netdev_cfg.queues + 1) as usize } else { QUEUE_NUM_NET }; - let queue_size = cfg.queue_size; + let queue_size = net_cfg.queue_size; Net { base: VirtioBase::new(VIRTIO_TYPE_NET, queue_num, queue_size), - net_cfg: cfg.clone(), + net_cfg: net_cfg.clone(), + netdev_cfg, config_space: Default::default(), taps: None, backends: None, @@ -116,19 +124,13 @@ impl Net { } impl VirtioDevice for Net { - fn virtio_base(&self) -> &VirtioBase { - &self.base - } - - fn virtio_base_mut(&mut self) -> &mut VirtioBase { - &mut self.base - } + gen_base_func!(virtio_base, virtio_base_mut, VirtioBase, base); fn realize(&mut self) -> Result<()> { - let queue_pairs = self.net_cfg.queues / 2; + let queue_pairs = self.netdev_cfg.queues / 2; let mut backends = Vec::with_capacity(queue_pairs as usize); for index in 0..queue_pairs { - let fd = if let Some(fds) = self.net_cfg.vhost_fds.as_mut() { + let fd = if let Some(fds) = self.netdev_cfg.vhost_fds.as_mut() { fds.get(index as usize).copied() } else { None @@ -142,12 +144,12 @@ impl VirtioDevice for Net { backends.push(backend); } - let host_dev_name = match self.net_cfg.host_dev_name.as_str() { + let host_dev_name = match self.netdev_cfg.ifname.as_str() { "" => None, - _ => Some(self.net_cfg.host_dev_name.as_str()), + _ => Some(self.netdev_cfg.ifname.as_str()), }; - self.taps = create_tap(self.net_cfg.tap_fds.as_ref(), host_dev_name, queue_pairs) + self.taps = create_tap(self.netdev_cfg.tap_fds.as_ref(), host_dev_name, queue_pairs) .with_context(|| "Failed to create tap for vhost net")?; self.backends = Some(backends); @@ -174,7 +176,7 @@ impl VirtioDevice for Net { let mut locked_config = self.config_space.lock().unwrap(); - let queue_pairs = self.net_cfg.queues / 2; + let queue_pairs = self.netdev_cfg.queues / 2; if self.net_cfg.mq && (VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MIN..=VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MAX) .contains(&queue_pairs) @@ -320,8 +322,7 @@ impl VirtioDevice for Net { let event = if self.call_events.is_empty() { let host_notify = VhostNotify { notify_evt: Arc::new( - EventFd::new(libc::EFD_NONBLOCK) - .with_context(|| VirtioError::EventFdCreate)?, + create_new_eventfd().with_context(|| VirtioError::EventFdCreate)?, ), queue: queue_mutex.clone(), }; @@ -385,7 +386,7 @@ impl VirtioDevice for Net { } fn reset(&mut self) -> Result<()> { - let queue_pairs = self.net_cfg.queues / 2; + let queue_pairs = self.netdev_cfg.queues / 2; for index in 0..queue_pairs as usize { let backend = match &self.backends { None => return Err(anyhow!("Failed to get backend for vhost net")), @@ -409,78 +410,48 @@ mod tests { use std::fs::File; use super::*; - use address_space::*; + use crate::tests::address_space_init; use machine_manager::config::DEFAULT_VIRTQUEUE_SIZE; - const SYSTEM_SPACE_SIZE: u64 = (1024 * 1024) as u64; - - fn vhost_address_space_init() -> Arc { - let root = Region::init_container_region(1 << 36, "sysmem"); - let sys_space = AddressSpace::new(root, "sysmem", None).unwrap(); - let host_mmap = Arc::new( - HostMemMapping::new( - GuestAddress(0), - None, - SYSTEM_SPACE_SIZE, - None, - false, - false, - false, - ) - .unwrap(), - ); - sys_space - .root() - .add_subregion( - Region::init_ram_region(host_mmap.clone(), "sysmem"), - host_mmap.start_address().raw_value(), - ) - .unwrap(); - sys_space - } - #[test] fn test_vhost_net_realize() { - let net1 = NetworkInterfaceConfig { - id: "eth1".to_string(), - host_dev_name: "tap1".to_string(), - mac: Some("1F:2C:3E:4A:5B:6D".to_string()), - vhost_type: Some("vhost-kernel".to_string()), + let netdev_cfg1 = NetDevcfg { + netdev_type: "tap".to_string(), + id: "net1".to_string(), tap_fds: Some(vec![4]), + vhost_kernel: true, vhost_fds: Some(vec![5]), - iothread: None, + ifname: "tap1".to_string(), queues: 2, + ..Default::default() + }; + let vhost_net_conf = NetworkInterfaceConfig { + id: "eth1".to_string(), + mac: Some("1F:2C:3E:4A:5B:6D".to_string()), + iothread: None, mq: false, - socket_path: None, queue_size: DEFAULT_VIRTQUEUE_SIZE, + ..Default::default() }; - let conf = vec![net1]; - let confs = Some(conf); - let vhost_net_confs = confs.unwrap(); - let vhost_net_conf = vhost_net_confs[0].clone(); - let vhost_net_space = vhost_address_space_init(); - let mut vhost_net = Net::new(&vhost_net_conf, &vhost_net_space); + let vhost_net_space = address_space_init(); + let mut vhost_net = Net::new(&vhost_net_conf, netdev_cfg1, &vhost_net_space); // the tap_fd and vhost_fd attribute of vhost-net can't be assigned. - assert_eq!(vhost_net.realize().is_ok(), false); + assert!(vhost_net.realize().is_err()); - let net1 = NetworkInterfaceConfig { - id: "eth0".to_string(), - host_dev_name: "".to_string(), - mac: Some("1A:2B:3C:4D:5E:6F".to_string()), - vhost_type: Some("vhost-kernel".to_string()), - tap_fds: None, - vhost_fds: None, - iothread: None, + let netdev_cfg2 = NetDevcfg { + netdev_type: "tap".to_string(), + id: "net2".to_string(), + vhost_kernel: true, queues: 2, - mq: false, - socket_path: None, + ..Default::default() + }; + let net_cfg2 = NetworkInterfaceConfig { + id: "eth2".to_string(), + mac: Some("1A:2B:3C:4D:5E:6F".to_string()), queue_size: DEFAULT_VIRTQUEUE_SIZE, + ..Default::default() }; - let conf = vec![net1]; - let confs = Some(conf); - let vhost_net_confs = confs.unwrap(); - let vhost_net_conf = vhost_net_confs[0].clone(); - let mut vhost_net = Net::new(&vhost_net_conf, &vhost_net_space); + let mut vhost_net = Net::new(&net_cfg2, netdev_cfg2, &vhost_net_space); // if fail to open vhost-net device, no need to continue. if let Err(_e) = File::open("/dev/vhost-net") { @@ -488,14 +459,14 @@ mod tests { } // without assigned value of tap_fd and vhost_fd, // vhost-net device can be realized successfully. - assert_eq!(vhost_net.realize().is_ok(), true); + assert!(vhost_net.realize().is_ok()); // test for get/set_driver_features vhost_net.base.device_features = 0; let page: u32 = 0x0; let value: u32 = 0xff; vhost_net.set_driver_features(page, value); - assert_eq!(vhost_net.driver_features(page) as u64, 0_u64); + assert_eq!(u64::from(vhost_net.driver_features(page)), 0_u64); let new_page = vhost_net.device_features(page); assert_eq!(new_page, page); @@ -503,7 +474,7 @@ mod tests { let page: u32 = 0x0; let value: u32 = 0xff; vhost_net.set_driver_features(page, value); - assert_eq!(vhost_net.driver_features(page) as u64, 0xff_u64); + assert_eq!(u64::from(vhost_net.driver_features(page)), 0xff_u64); let new_page = vhost_net.device_features(page); assert_ne!(new_page, page); @@ -511,22 +482,22 @@ mod tests { let len = vhost_net.config_space.lock().unwrap().as_bytes().len() as u64; let offset: u64 = 0; let data: Vec = vec![1; len as usize]; - assert_eq!(vhost_net.write_config(offset, &data).is_ok(), true); + assert!(vhost_net.write_config(offset, &data).is_ok()); let mut read_data: Vec = vec![0; len as usize]; - assert_eq!(vhost_net.read_config(offset, &mut read_data).is_ok(), true); + assert!(vhost_net.read_config(offset, &mut read_data).is_ok()); assert_ne!(read_data, data); let offset: u64 = 1; let data: Vec = vec![1; len as usize]; - assert_eq!(vhost_net.write_config(offset, &data).is_ok(), true); + assert!(vhost_net.write_config(offset, &data).is_ok()); let offset: u64 = len + 1; let mut read_data: Vec = vec![0; len as usize]; - assert_eq!(vhost_net.read_config(offset, &mut read_data).is_ok(), false); + assert!(vhost_net.read_config(offset, &mut read_data).is_err()); let offset: u64 = len - 1; let mut read_data: Vec = vec![0; len as usize]; - assert_eq!(vhost_net.read_config(offset, &mut read_data).is_ok(), false); + assert!(vhost_net.read_config(offset, &mut read_data).is_err()); } } diff --git a/virtio/src/vhost/kernel/vsock.rs b/virtio/src/vhost/kernel/vsock.rs index 8d782c6236c10d278ebfaab33db76f117a7738ec..b67d21aeb2f7a3c3cb8ccad11b909d6e968a6693 100644 --- a/virtio/src/vhost/kernel/vsock.rs +++ b/virtio/src/vhost/kernel/vsock.rs @@ -16,22 +16,24 @@ use std::sync::{Arc, Mutex}; use anyhow::{anyhow, bail, Context, Result}; use byteorder::{ByteOrder, LittleEndian}; +use clap::{ArgAction, Parser}; use vmm_sys_util::eventfd::EventFd; use vmm_sys_util::ioctl::ioctl_with_ref; use super::super::{VhostIoHandler, VhostNotify, VhostOps}; use super::{VhostBackend, VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING}; use crate::{ - check_config_space_rw, Queue, VirtioBase, VirtioDevice, VirtioError, VirtioInterrupt, - VirtioInterruptType, VIRTIO_F_ACCESS_PLATFORM, VIRTIO_TYPE_VSOCK, + Queue, VirtioBase, VirtioDevice, VirtioError, VirtioInterrupt, VirtioInterruptType, + VIRTIO_F_ACCESS_PLATFORM, VIRTIO_TYPE_VSOCK, }; -use address_space::AddressSpace; -use machine_manager::config::{VsockConfig, DEFAULT_VIRTQUEUE_SIZE}; +use address_space::{AddressAttr, AddressSpace}; +use machine_manager::config::{get_pci_df, parse_bool, valid_id, DEFAULT_VIRTQUEUE_SIZE}; use machine_manager::event_loop::{register_event_helper, unregister_event_helper}; use migration::{DeviceStateDesc, FieldDesc, MigrationHook, MigrationManager, StateTransfer}; use migration_derive::{ByteCode, Desc}; use util::byte_code::ByteCode; -use util::loop_context::EventNotifierHelper; +use util::gen_base_func; +use util::loop_context::{create_new_eventfd, EventNotifierHelper}; /// Number of virtqueues. const QUEUE_NUM_VSOCK: usize = 3; @@ -40,6 +42,29 @@ const VHOST_PATH: &str = "/dev/vhost-vsock"; /// Event transport reset const VIRTIO_VSOCK_EVENT_TRANSPORT_RESET: u32 = 0; +const MAX_GUEST_CID: u64 = 4_294_967_295; +const MIN_GUEST_CID: u64 = 3; + +/// Config structure for virtio-vsock. +#[derive(Parser, Debug, Clone, Default)] +#[command(no_binary_name(true))] +pub struct VsockConfig { + #[arg(long, value_parser = ["vhost-vsock-pci", "vhost-vsock-device"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] + pub id: String, + #[arg(long)] + pub bus: Option, + #[arg(long, value_parser = get_pci_df)] + pub addr: Option<(u8, u8)>, + #[arg(long, value_parser = parse_bool, action = ArgAction::Append)] + pub multifunction: Option, + #[arg(long, alias = "guest-cid", value_parser = clap::value_parser!(u64).range(MIN_GUEST_CID..=MAX_GUEST_CID))] + pub guest_cid: u64, + #[arg(long, alias = "vhostfd")] + pub vhost_fd: Option, +} + trait VhostVsockBackend { /// Each guest should have an unique CID which is used to route data to the guest. fn set_guest_cid(&self, cid: u64) -> Result<()>; @@ -143,12 +168,12 @@ impl Vsock { .write_object( &VIRTIO_VSOCK_EVENT_TRANSPORT_RESET, element.in_iovec[0].addr, + AddressAttr::Ram, ) .with_context(|| "Failed to write buf for virtio vsock event")?; event_queue_locked .vring .add_used( - &self.mem_space, element.index, VIRTIO_VSOCK_EVENT_TRANSPORT_RESET.as_bytes().len() as u32, ) @@ -169,13 +194,7 @@ impl Vsock { } impl VirtioDevice for Vsock { - fn virtio_base(&self) -> &VirtioBase { - &self.base - } - - fn virtio_base_mut(&mut self) -> &mut VirtioBase { - &mut self.base - } + gen_base_func!(virtio_base, virtio_base_mut, VirtioBase, base); fn realize(&mut self) -> Result<()> { let vhost_fd: Option = self.vsock_cfg.vhost_fd; @@ -215,10 +234,7 @@ impl VirtioDevice for Vsock { Ok(()) } - fn write_config(&mut self, offset: u64, data: &[u8]) -> Result<()> { - check_config_space_rw(&self.config_space, offset, data)?; - let data_len = data.len(); - self.config_space[(offset as usize)..(offset as usize + data_len)].copy_from_slice(data); + fn write_config(&mut self, _offset: u64, _data: &[u8]) -> Result<()> { Ok(()) } @@ -287,8 +303,7 @@ impl VirtioDevice for Vsock { let event = if self.call_events.is_empty() { let host_notify = VhostNotify { notify_evt: Arc::new( - EventFd::new(libc::EFD_NONBLOCK) - .with_context(|| VirtioError::EventFdCreate)?, + create_new_eventfd().with_context(|| VirtioError::EventFdCreate)?, ), queue: queue_mutex.clone(), }; @@ -390,25 +405,37 @@ impl MigrationHook for Vsock { #[cfg(test)] mod tests { - pub use super::super::*; - pub use super::*; - pub use address_space::*; - - fn vsock_address_space_init() -> Arc { - let root = Region::init_container_region(u64::max_value(), "sysmem"); - let sys_mem = AddressSpace::new(root, "sysmem", None).unwrap(); - sys_mem - } + use super::*; + use crate::tests::address_space_init; + use machine_manager::config::str_slip_to_clap; fn vsock_create_instance() -> Vsock { let vsock_conf = VsockConfig { id: "test_vsock_1".to_string(), guest_cid: 3, vhost_fd: None, + ..Default::default() }; - let sys_mem = vsock_address_space_init(); - let vsock = Vsock::new(&vsock_conf, &sys_mem); - vsock + let sys_mem = address_space_init(); + + Vsock::new(&vsock_conf, &sys_mem) + } + + #[test] + fn test_vsock_config_cmdline_parser() { + let vsock_cmd = "vhost-vsock-device,id=test_vsock,guest-cid=3"; + let vsock_config = + VsockConfig::try_parse_from(str_slip_to_clap(vsock_cmd, true, false)).unwrap(); + assert_eq!(vsock_config.id, "test_vsock"); + assert_eq!(vsock_config.guest_cid, 3); + assert_eq!(vsock_config.vhost_fd, None); + + let vsock_cmd = "vhost-vsock-device,id=test_vsock,guest-cid=3,vhostfd=4"; + let vsock_config = + VsockConfig::try_parse_from(str_slip_to_clap(vsock_cmd, true, false)).unwrap(); + assert_eq!(vsock_config.id, "test_vsock"); + assert_eq!(vsock_config.guest_cid, 3); + assert_eq!(vsock_config.vhost_fd, Some(4)); } #[test] @@ -437,32 +464,32 @@ mod tests { vsock.base.device_features = 0x0123_4567_89ab_cdef; // check for unsupported feature vsock.set_driver_features(0, 0x7000_0000); - assert_eq!(vsock.driver_features(0) as u64, 0_u64); + assert_eq!(u64::from(vsock.driver_features(0)), 0_u64); assert_eq!(vsock.base.device_features, 0x0123_4567_89ab_cdef); // check for supported feature vsock.set_driver_features(0, 0x8000_0000); - assert_eq!(vsock.driver_features(0) as u64, 0x8000_0000_u64); + assert_eq!(u64::from(vsock.driver_features(0)), 0x8000_0000_u64); assert_eq!(vsock.base.device_features, 0x0123_4567_89ab_cdef); // test vsock read_config let mut buf: [u8; 8] = [0; 8]; - assert_eq!(vsock.read_config(0, &mut buf).is_ok(), true); + assert!(vsock.read_config(0, &mut buf).is_ok()); let value = LittleEndian::read_u64(&buf); assert_eq!(value, vsock.vsock_cfg.guest_cid); let mut buf: [u8; 4] = [0; 4]; - assert_eq!(vsock.read_config(0, &mut buf).is_ok(), true); + assert!(vsock.read_config(0, &mut buf).is_ok()); let value = LittleEndian::read_u32(&buf); assert_eq!(value, vsock.vsock_cfg.guest_cid as u32); let mut buf: [u8; 4] = [0; 4]; - assert_eq!(vsock.read_config(4, &mut buf).is_ok(), true); + assert!(vsock.read_config(4, &mut buf).is_ok()); let value = LittleEndian::read_u32(&buf); assert_eq!(value, (vsock.vsock_cfg.guest_cid >> 32) as u32); let mut buf: [u8; 4] = [0; 4]; - assert_eq!(vsock.read_config(5, &mut buf).is_err(), true); - assert_eq!(vsock.read_config(3, &mut buf).is_err(), true); + assert!(vsock.read_config(5, &mut buf).is_err()); + assert!(vsock.read_config(3, &mut buf).is_err()); } #[test] @@ -481,12 +508,9 @@ mod tests { // test vsock set_guest_cid let backend = vsock.backend.unwrap(); - assert_eq!(backend.set_guest_cid(3).is_ok(), true); - assert_eq!( - backend.set_guest_cid(u32::max_value() as u64).is_ok(), - false - ); - assert_eq!(backend.set_guest_cid(2).is_ok(), false); - assert_eq!(backend.set_guest_cid(0).is_ok(), false); + assert!(backend.set_guest_cid(3).is_ok()); + assert!(backend.set_guest_cid(u64::from(u32::max_value())).is_err()); + assert!(backend.set_guest_cid(2).is_err()); + assert!(backend.set_guest_cid(0).is_err()); } } diff --git a/virtio/src/vhost/user/block.rs b/virtio/src/vhost/user/block.rs index 25d030b778c854f8ddca2e5df78647f255419d94..4e0757f9fc0e309a0b032d6f71d660dc18bee843 100644 --- a/virtio/src/vhost/user/block.rs +++ b/virtio/src/vhost/user/block.rs @@ -13,6 +13,7 @@ use std::sync::{Arc, Mutex}; use anyhow::{anyhow, bail, Context, Result}; +use clap::Parser; use vmm_sys_util::eventfd::EventFd; use super::client::VhostUserClient; @@ -30,14 +31,42 @@ use crate::{ VIRTIO_BLK_F_TOPOLOGY, VIRTIO_BLK_F_WRITE_ZEROES, VIRTIO_F_VERSION_1, VIRTIO_TYPE_BLOCK, }; use address_space::AddressSpace; -use machine_manager::{config::BlkDevConfig, event_loop::unregister_event_helper}; +use machine_manager::config::{ + get_chardev_socket_path, get_pci_df, valid_block_device_virtqueue_size, valid_id, + ChardevConfig, MAX_VIRTIO_QUEUE, +}; +use machine_manager::event_loop::unregister_event_helper; use util::byte_code::ByteCode; +use util::gen_base_func; + +#[derive(Parser, Debug, Clone, Default)] +#[command(no_binary_name(true))] +pub struct VhostUserBlkDevConfig { + #[arg(long, value_parser = ["vhost-user-blk-device", "vhost-user-blk-pci"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] + pub id: String, + #[arg(long)] + pub bus: Option, + #[arg(long, value_parser = get_pci_df)] + pub addr: Option<(u8, u8)>, + #[arg(long, alias = "num-queues", value_parser = clap::value_parser!(u16).range(1..=MAX_VIRTIO_QUEUE as i64))] + pub num_queues: Option, + #[arg(long)] + pub chardev: String, + #[arg(long, alias = "queue-size", default_value = "256", value_parser = valid_block_device_virtqueue_size)] + pub queue_size: u16, + #[arg(long)] + pub bootindex: Option, +} pub struct Block { /// Virtio device base property. base: VirtioBase, /// Configuration of the block device. - blk_cfg: BlkDevConfig, + blk_cfg: VhostUserBlkDevConfig, + /// Configuration of the vhost user blk's socket chardev. + chardev_cfg: ChardevConfig, /// Config space of the block device. config_space: VirtioBlkConfig, /// System address space. @@ -51,13 +80,18 @@ pub struct Block { } impl Block { - pub fn new(cfg: &BlkDevConfig, mem_space: &Arc) -> Self { - let queue_num = cfg.queues as usize; + pub fn new( + cfg: &VhostUserBlkDevConfig, + chardev_cfg: ChardevConfig, + mem_space: &Arc, + ) -> Self { + let queue_num = cfg.num_queues.unwrap_or(1) as usize; let queue_size = cfg.queue_size; Block { base: VirtioBase::new(VIRTIO_TYPE_BLOCK, queue_num, queue_size), blk_cfg: cfg.clone(), + chardev_cfg, config_space: Default::default(), mem_space: mem_space.clone(), client: None, @@ -68,12 +102,7 @@ impl Block { /// Connect with spdk and register update event. fn init_client(&mut self) -> Result<()> { - let socket_path = self - .blk_cfg - .socket_path - .as_ref() - .map(|path| path.to_string()) - .with_context(|| "vhost-user: socket path is not found")?; + let socket_path = get_chardev_socket_path(self.chardev_cfg.clone())?; let client = VhostUserClient::new( &self.mem_space, &socket_path, @@ -91,13 +120,7 @@ impl Block { } impl VirtioDevice for Block { - fn virtio_base(&self) -> &VirtioBase { - &self.base - } - - fn virtio_base_mut(&mut self) -> &mut VirtioBase { - &mut self.base - } + gen_base_func!(virtio_base, virtio_base_mut, VirtioBase, base); fn realize(&mut self) -> Result<()> { self.init_client()?; @@ -123,7 +146,7 @@ impl VirtioDevice for Block { .set_protocol_features(self.protocol_features) .with_context(|| "Failed to set protocol features for vhost-user blk")?; - if virtio_has_feature(protocol_features, VHOST_USER_PROTOCOL_F_CONFIG as u32) { + if virtio_has_feature(protocol_features, u32::from(VHOST_USER_PROTOCOL_F_CONFIG)) { let config = locked_client .get_virtio_blk_config() .with_context(|| "Failed to get config for vhost-user blk")?; @@ -135,7 +158,7 @@ impl VirtioDevice for Block { ); } - if virtio_has_feature(protocol_features, VHOST_USER_PROTOCOL_F_MQ as u32) { + if virtio_has_feature(protocol_features, u32::from(VHOST_USER_PROTOCOL_F_MQ)) { let max_queue_num = locked_client .get_max_queue_num() .with_context(|| "Failed to get queue num for vhost-user blk")?; @@ -146,10 +169,10 @@ impl VirtioDevice for Block { ); } - if self.blk_cfg.queues > 1 { - self.config_space.num_queues = self.blk_cfg.queues; + if self.blk_cfg.num_queues.unwrap_or(1) > 1 { + self.config_space.num_queues = self.blk_cfg.num_queues.unwrap_or(1); } - } else if self.blk_cfg.queues > 1 { + } else if self.blk_cfg.num_queues.unwrap_or(1) > 1 { bail!( "spdk doesn't support multi queue, spdk protocol features: {:#b}", protocol_features @@ -169,7 +192,7 @@ impl VirtioDevice for Block { | 1_u64 << VIRTIO_BLK_F_WRITE_ZEROES | 1_u64 << VIRTIO_BLK_F_SEG_MAX | 1_u64 << VIRTIO_BLK_F_RO; - if self.blk_cfg.queues > 1 { + if self.blk_cfg.num_queues.unwrap_or(1) > 1 { self.base.device_features |= 1_u64 << VIRTIO_BLK_F_MQ; } self.base.device_features &= features; @@ -217,13 +240,7 @@ impl VirtioDevice for Block { if !self.enable_irqfd { let queue_num = self.base.queues.len(); - listen_guest_notifier( - &mut self.base, - &mut client, - self.blk_cfg.iothread.as_ref(), - queue_num, - interrupt_cb, - )?; + listen_guest_notifier(&mut self.base, &mut client, None, queue_num, interrupt_cb)?; } client.activate_vhost_user()?; @@ -235,10 +252,7 @@ impl VirtioDevice for Block { if let Some(client) = &self.client { client.lock().unwrap().reset_vhost_user(false); } - unregister_event_helper( - self.blk_cfg.iothread.as_ref(), - &mut self.base.deactivate_evts, - )?; + unregister_event_helper(None, &mut self.base.deactivate_evts)?; Ok(()) } diff --git a/virtio/src/vhost/user/client.rs b/virtio/src/vhost/user/client.rs index b32bf72cc1bfeeaafef468fa6f3baed0c9799b09..d4c86a780dd49b610968044d467e7009ec6df744 100644 --- a/virtio/src/vhost/user/client.rs +++ b/virtio/src/vhost/user/client.rs @@ -20,7 +20,7 @@ use std::sync::{Arc, Mutex}; use std::time::Duration; use anyhow::{bail, Context, Result}; -use log::{error, info, warn}; +use log::{debug, error, info, warn}; use vmm_sys_util::{epoll::EventSet, eventfd::EventFd}; use super::super::VhostOps; @@ -33,7 +33,8 @@ use crate::device::block::VirtioBlkConfig; use crate::VhostUser::message::VhostUserConfig; use crate::{virtio_has_feature, Queue, QueueConfig}; use address_space::{ - AddressSpace, FileBackend, FlatRange, GuestAddress, Listener, ListenerReqType, RegionIoEventFd, + AddressAttr, AddressSpace, FileBackend, FlatRange, GuestAddress, Listener, ListenerReqType, + RegionIoEventFd, }; use machine_manager::event_loop::{register_event_helper, unregister_event_helper, EventLoop}; use util::loop_context::{ @@ -234,14 +235,15 @@ impl VhostUserMemInfo { let guest_phys_addr = fr.addr_range.base.raw_value(); let memory_size = fr.addr_range.size; - let host_address = match fr.owner.get_host_address() { + // SAFETY: memory_size is range's size, so we make sure [hva, hva+size] is in ram range. + let host_address = match unsafe { fr.owner.get_host_address(AddressAttr::Ram) } { Some(addr) => addr, None => bail!("Failed to get host address to add mem range for vhost user device"), }; let file_back = match fr.owner.get_file_backend() { Some(file_back_) => file_back_, _ => { - info!("It is not share memory for vhost user device"); + debug!("It is not share memory for vhost user device"); return Ok(()); } }; @@ -281,13 +283,14 @@ impl VhostUserMemInfo { let file_back = match fr.owner.get_file_backend() { None => { - info!("fr {:?} backend is not file, ignored", fr); + debug!("fr {:?} backend is not file, ignored", fr); return Ok(()); } Some(fb) => fb, }; let mut mem_regions = self.regions.lock().unwrap(); - let host_address = match fr.owner.get_host_address() { + // SAFETY: memory_size is range's size, so we make sure [hva, hva+size] is in ram range. + let host_address = match unsafe { fr.owner.get_host_address(AddressAttr::Ram) } { Some(addr) => addr, None => bail!("Failed to get host address to del mem range for vhost user device"), }; @@ -402,7 +405,6 @@ pub struct VhostUserClient { client: Arc>, mem_info: VhostUserMemInfo, delete_evts: Vec, - mem_space: Arc, queues: Vec>>, queue_evts: Vec>, call_events: Vec>, @@ -438,7 +440,6 @@ impl VhostUserClient { client, mem_info, delete_evts: Vec::new(), - mem_space: mem_space.clone(), queues: Vec::new(), queue_evts: Vec::new(), call_events: Vec::new(), @@ -482,7 +483,7 @@ impl VhostUserClient { .with_context(|| "Failed to get protocol features for vhost-user blk")?; if virtio_has_feature( protocol_feature, - VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD as u32, + u32::from(VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD), ) { if self.inflight.is_none() { // Expect 1 fd. @@ -568,7 +569,7 @@ impl VhostUserClient { })?; // When spdk/ovs has been killed, stratovirt can not get the last avail // index in spdk/ovs, it can only use used index as last avail index. - let last_avail_idx = queue.vring.get_used_idx(&self.mem_space)?; + let last_avail_idx = queue.vring.get_used_idx()?; self.set_vring_base(queue_index, last_avail_idx) .with_context(|| { format!( @@ -903,7 +904,7 @@ impl VhostOps for VhostUserClient { size_of::() as u32, ); let payload_opt: Option<&[u8]> = None; - let vring_state = VhostUserVringState::new(queue_idx as u32, num as u32); + let vring_state = VhostUserVringState::new(queue_idx as u32, u32::from(num)); client .sock .send_msg(Some(&hdr), Some(&vring_state), payload_opt, &[]) @@ -982,7 +983,7 @@ impl VhostOps for VhostUserClient { size_of::() as u32, ); let payload_opt: Option<&[u8]> = None; - let vring_state = VhostUserVringState::new(queue_idx as u32, last_avail_idx as u32); + let vring_state = VhostUserVringState::new(queue_idx as u32, u32::from(last_avail_idx)); client .sock .send_msg(Some(&hdr), Some(&vring_state), payload_opt, &[]) @@ -1058,7 +1059,7 @@ impl VhostOps for VhostUserClient { size_of::() as u32, ); let payload_opt: Option<&[u8]> = None; - let vring_state = VhostUserVringState::new(queue_idx as u32, status as u32); + let vring_state = VhostUserVringState::new(queue_idx as u32, u32::from(status)); client .sock .send_msg(Some(&hdr), Some(&vring_state), payload_opt, &[]) diff --git a/virtio/src/vhost/user/fs.rs b/virtio/src/vhost/user/fs.rs index a1c5ed8d31bf77e4cc26b7e17310f0b60dbf8f56..f08ddd399fea540c2fd78cc140905ebfb0898218 100644 --- a/virtio/src/vhost/user/fs.rs +++ b/virtio/src/vhost/user/fs.rs @@ -19,7 +19,8 @@ const VIRTIO_FS_QUEUE_SIZE: u16 = 128; use std::sync::{Arc, Mutex}; -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, bail, Context, Result}; +use clap::Parser; use vmm_sys_util::eventfd::EventFd; use super::super::super::{VirtioDevice, VIRTIO_TYPE_FS}; @@ -27,9 +28,45 @@ use super::super::VhostOps; use super::{listen_guest_notifier, VhostBackendType, VhostUserClient}; use crate::{read_config_default, VirtioBase, VirtioInterrupt}; use address_space::AddressSpace; -use machine_manager::config::{FsConfig, MAX_TAG_LENGTH}; +use machine_manager::config::{ + get_pci_df, parse_bool, valid_id, ChardevConfig, ConfigError, SocketType, +}; use machine_manager::event_loop::unregister_event_helper; use util::byte_code::ByteCode; +use util::gen_base_func; + +const MAX_TAG_LENGTH: usize = 36; + +/// Config struct for `fs`. +/// Contains fs device's attr. +#[derive(Parser, Debug, Clone)] +#[command(no_binary_name(true))] +pub struct FsConfig { + #[arg(long, value_parser = ["vhost-user-fs-pci", "vhost-user-fs-device"])] + pub classtype: String, + #[arg(long, value_parser = valid_id)] + pub id: String, + #[arg(long)] + pub chardev: String, + #[arg(long, value_parser = valid_tag)] + pub tag: String, + #[arg(long)] + pub bus: Option, + #[arg(long, value_parser = get_pci_df)] + pub addr: Option<(u8, u8)>, + #[arg(long, value_parser = parse_bool)] + pub multifunction: Option, +} + +fn valid_tag(tag: &str) -> Result { + if tag.len() >= MAX_TAG_LENGTH { + return Err(anyhow!(ConfigError::StringLengthTooLong( + "fs device tag".to_string(), + MAX_TAG_LENGTH - 1, + ))); + } + Ok(tag.to_string()) +} #[derive(Copy, Clone)] #[repr(C, packed)] @@ -52,6 +89,7 @@ impl ByteCode for VirtioFsConfig {} pub struct Fs { base: VirtioBase, fs_cfg: FsConfig, + chardev_cfg: ChardevConfig, config_space: VirtioFsConfig, client: Option>>, mem_space: Arc, @@ -64,14 +102,16 @@ impl Fs { /// # Arguments /// /// `fs_cfg` - The config of this Fs device. + /// `chardev_cfg` - The config of this Fs device's chardev. /// `mem_space` - The address space of this Fs device. - pub fn new(fs_cfg: FsConfig, mem_space: Arc) -> Self { + pub fn new(fs_cfg: FsConfig, chardev_cfg: ChardevConfig, mem_space: Arc) -> Self { let queue_num = VIRIOT_FS_HIGH_PRIO_QUEUE_NUM + VIRTIO_FS_REQ_QUEUES_NUM; let queue_size = VIRTIO_FS_QUEUE_SIZE; Fs { base: VirtioBase::new(VIRTIO_TYPE_FS, queue_num, queue_size), fs_cfg, + chardev_cfg, config_space: VirtioFsConfig::default(), client: None, mem_space, @@ -81,19 +121,19 @@ impl Fs { } impl VirtioDevice for Fs { - fn virtio_base(&self) -> &VirtioBase { - &self.base - } - - fn virtio_base_mut(&mut self) -> &mut VirtioBase { - &mut self.base - } + gen_base_func!(virtio_base, virtio_base_mut, VirtioBase, base); fn realize(&mut self) -> Result<()> { let queues_num = VIRIOT_FS_HIGH_PRIO_QUEUE_NUM + VIRTIO_FS_REQ_QUEUES_NUM; + + let socket_path = match self.chardev_cfg.classtype.socket_type()? { + SocketType::Unix { path } => path, + _ => bail!("Vhost-user-fs Chardev backend should be unix-socket type."), + }; + let client = VhostUserClient::new( &self.mem_space, - &self.fs_cfg.sock, + &socket_path, queues_num as u64, VhostBackendType::TypeFs, ) @@ -194,3 +234,24 @@ impl VirtioDevice for Fs { self.realize() } } + +#[cfg(test)] +mod tests { + use super::*; + use machine_manager::config::str_slip_to_clap; + + #[test] + fn test_vhostuserfs_cmdline_parser() { + // Test1: Right. + let fs_cmd = "vhost-user-fs-device,id=fs0,chardev=chardev0,tag=tag0"; + let fs_config = FsConfig::try_parse_from(str_slip_to_clap(fs_cmd, true, false)).unwrap(); + assert_eq!(fs_config.id, "fs0"); + assert_eq!(fs_config.chardev, "chardev0"); + assert_eq!(fs_config.tag, "tag0"); + + // Test2: Illegal value. + let fs_cmd = "vhost-user-fs-device,id=fs0,chardev=chardev0,tag=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; + let result = FsConfig::try_parse_from(str_slip_to_clap(fs_cmd, true, false)); + assert!(result.is_err()); + } +} diff --git a/virtio/src/vhost/user/mod.rs b/virtio/src/vhost/user/mod.rs index 2f0dc96f07253b31e9ed8a2b044a181ef5c583aa..68a0b166e63eecd668481ef86cfbc3400b615f00 100644 --- a/virtio/src/vhost/user/mod.rs +++ b/virtio/src/vhost/user/mod.rs @@ -12,16 +12,20 @@ pub mod fs; +#[cfg(feature = "vhostuser_block")] mod block; mod client; mod message; +#[cfg(feature = "vhostuser_net")] mod net; mod sock; -pub use self::block::Block; +#[cfg(feature = "vhostuser_block")] +pub use self::block::{Block, VhostUserBlkDevConfig}; pub use self::client::*; pub use self::fs::*; pub use self::message::*; +#[cfg(feature = "vhostuser_net")] pub use self::net::Net; pub use self::sock::*; diff --git a/virtio/src/vhost/user/net.rs b/virtio/src/vhost/user/net.rs index e7f0964bf56992bc33b14b559f77b8a52ae6a98d..fc98344638fee0dc706ed85270db05dd98f7d88e 100644 --- a/virtio/src/vhost/user/net.rs +++ b/virtio/src/vhost/user/net.rs @@ -28,9 +28,10 @@ use crate::{ VIRTIO_NET_F_MRG_RXBUF, VIRTIO_TYPE_NET, }; use address_space::AddressSpace; -use machine_manager::config::NetworkInterfaceConfig; +use machine_manager::config::{NetDevcfg, NetworkInterfaceConfig}; use machine_manager::event_loop::{register_event_helper, unregister_event_helper}; use util::byte_code::ByteCode; +use util::gen_base_func; use util::loop_context::EventNotifierHelper; /// Number of virtqueues. @@ -42,6 +43,10 @@ pub struct Net { base: VirtioBase, /// Configuration of the vhost user network device. net_cfg: NetworkInterfaceConfig, + /// Configuration of the backend netdev. + netdev_cfg: NetDevcfg, + /// path of the socket chardev. + sock_path: String, /// Virtio net configurations. config_space: Arc>, /// System address space. @@ -53,18 +58,25 @@ pub struct Net { } impl Net { - pub fn new(cfg: &NetworkInterfaceConfig, mem_space: &Arc) -> Self { - let queue_num = if cfg.mq { + pub fn new( + net_cfg: &NetworkInterfaceConfig, + netdev_cfg: NetDevcfg, + sock_path: String, + mem_space: &Arc, + ) -> Self { + let queue_num = if net_cfg.mq { // If support multi-queue, it should add 1 control queue. - (cfg.queues + 1) as usize + (netdev_cfg.queues + 1) as usize } else { QUEUE_NUM_NET }; - let queue_size = cfg.queue_size; + let queue_size = net_cfg.queue_size; Net { base: VirtioBase::new(VIRTIO_TYPE_NET, queue_num, queue_size), - net_cfg: cfg.clone(), + net_cfg: net_cfg.clone(), + netdev_cfg, + sock_path, config_space: Default::default(), mem_space: mem_space.clone(), client: None, @@ -106,23 +118,12 @@ impl Net { } impl VirtioDevice for Net { - fn virtio_base(&self) -> &VirtioBase { - &self.base - } - - fn virtio_base_mut(&mut self) -> &mut VirtioBase { - &mut self.base - } + gen_base_func!(virtio_base, virtio_base_mut, VirtioBase, base); fn realize(&mut self) -> Result<()> { - let socket_path = self - .net_cfg - .socket_path - .as_ref() - .with_context(|| "vhost-user: socket path is not found")?; let client = VhostUserClient::new( &self.mem_space, - socket_path, + &self.sock_path, self.queue_num() as u64, VhostBackendType::TypeNet, ) @@ -158,7 +159,7 @@ impl VirtioDevice for Net { let mut locked_config = self.config_space.lock().unwrap(); - let queue_pairs = self.net_cfg.queues / 2; + let queue_pairs = self.netdev_cfg.queues / 2; if self.net_cfg.mq && (VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MIN..=VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MAX) .contains(&queue_pairs)