diff --git a/drivers/virtio/src/blk.rs b/drivers/virtio/src/blk.rs index 469821808efbe644ef4470d16cf0ee6edae3deea..b619951fb0e88d215d97c43f3e43147c04c6e985 100644 --- a/drivers/virtio/src/blk.rs +++ b/drivers/virtio/src/blk.rs @@ -5,13 +5,18 @@ //! VirtIO block driver adapter. use block::BlockDriverOps; use driver_base::{DeviceKind, DriverOps, DriverResult}; -use virtio_drivers::{Hal, device::blk::VirtIOBlk as InnerDev, transport::Transport}; +use virtio_drivers::{ + Hal, + device::blk::{SECTOR_SIZE, VirtIOBlk as InnerDev}, + transport::Transport, +}; use crate::as_driver_error; /// The VirtIO block device driver. pub struct VirtIoBlkDev { - inner: InnerDev, + device: InnerDev, + sector_size: usize, } unsafe impl Send for VirtIoBlkDev {} @@ -21,10 +26,28 @@ impl VirtIoBlkDev { /// Creates a new driver instance and initializes the device, or returns /// an error if any step fails. pub fn try_new(transport: T) -> DriverResult { + let device = Self::init_device(transport)?; Ok(Self { - inner: InnerDev::new(transport).map_err(as_driver_error)?, + device, + sector_size: SECTOR_SIZE, }) } + + fn init_device(transport: T) -> DriverResult> { + InnerDev::new(transport).map_err(as_driver_error) + } + + fn read_sector(&mut self, sector: u64, out_buf: &mut [u8]) -> DriverResult { + self.device + .read_blocks(sector as usize, out_buf) + .map_err(as_driver_error) + } + + fn write_sector(&mut self, sector: u64, in_buf: &[u8]) -> DriverResult { + self.device + .write_blocks(sector as usize, in_buf) + .map_err(as_driver_error) + } } impl DriverOps for VirtIoBlkDev { @@ -40,24 +63,20 @@ impl DriverOps for VirtIoBlkDev { impl BlockDriverOps for VirtIoBlkDev { #[inline] fn num_blocks(&self) -> u64 { - self.inner.capacity() + self.device.capacity() } #[inline] fn block_size(&self) -> usize { - virtio_drivers::device::blk::SECTOR_SIZE + self.sector_size } fn read_block(&mut self, block_id: u64, buf: &mut [u8]) -> DriverResult { - self.inner - .read_blocks(block_id as _, buf) - .map_err(as_driver_error) + self.read_sector(block_id, buf) } fn write_block(&mut self, block_id: u64, buf: &[u8]) -> DriverResult { - self.inner - .write_blocks(block_id as _, buf) - .map_err(as_driver_error) + self.write_sector(block_id, buf) } fn flush(&mut self) -> DriverResult { diff --git a/drivers/virtio/src/gpu.rs b/drivers/virtio/src/gpu.rs index a3cf6609591cd8addbdd620014062a8c30409b5d..fffeb78a3cb2da9b50de55a767c56373da044a20 100644 --- a/drivers/virtio/src/gpu.rs +++ b/drivers/virtio/src/gpu.rs @@ -7,38 +7,31 @@ use display::{DisplayDriverOps, DisplayInfo, FrameBuffer}; use driver_base::{DeviceKind, DriverOps, DriverResult}; use virtio_drivers::{Hal, device::gpu::VirtIOGpu as InnerDev, transport::Transport}; -use crate::as_driver_error; - /// The VirtIO GPU device driver. pub struct VirtIoGpuDev { - inner: InnerDev, info: DisplayInfo, + inner: InnerDev, } unsafe impl Send for VirtIoGpuDev {} unsafe impl Sync for VirtIoGpuDev {} impl VirtIoGpuDev { - /// Creates a new driver instance and initializes the device, or returns - /// an error if any step fails. pub fn try_new(transport: T) -> DriverResult { - let mut virtio = InnerDev::new(transport).unwrap(); - - // get framebuffer - let fbuffer = virtio.setup_framebuffer().unwrap(); - let fb_base_vaddr = fbuffer.as_mut_ptr() as usize; - let fb_size = fbuffer.len(); - let (width, height) = virtio.resolution().unwrap(); - let info = DisplayInfo { - width, - height, - fb_base_vaddr, - fb_size, - }; + let mut device = InnerDev::new(transport).map_err(crate::as_driver_error)?; + let framebuffer = device.setup_framebuffer().map_err(crate::as_driver_error)?; + let fb_base_vaddr = framebuffer.as_mut_ptr() as usize; + let fb_size = framebuffer.len(); + let (width, height) = device.resolution().map_err(crate::as_driver_error)?; Ok(Self { - inner: virtio, - info, + info: DisplayInfo { + width, + height, + fb_base_vaddr, + fb_size, + }, + inner: device, }) } } @@ -69,6 +62,6 @@ impl DisplayDriverOps for VirtIoGpuDev { } fn flush(&mut self) -> DriverResult { - self.inner.flush().map_err(as_driver_error) + self.inner.flush().map_err(crate::as_driver_error) } } diff --git a/drivers/virtio/src/input.rs b/drivers/virtio/src/input.rs index 92203285d32a3606cd5b95e294885631285aeaf7..09bba2ef5e0675b3cddd01a1a7e54a125f84e379 100644 --- a/drivers/virtio/src/input.rs +++ b/drivers/virtio/src/input.rs @@ -3,7 +3,7 @@ // See LICENSES for license details. //! VirtIO input driver adapter. -use alloc::{borrow::ToOwned, string::String}; +use alloc::string::String; use driver_base::{DeviceKind, DriverError, DriverOps, DriverResult}; use input::{Event, EventType, InputDeviceId, InputDriverOps}; @@ -26,30 +26,34 @@ unsafe impl Send for VirtIoInputDev {} unsafe impl Sync for VirtIoInputDev {} impl VirtIoInputDev { - /// Creates a new driver instance and initializes the device, or returns - /// an error if any step fails. pub fn try_new(transport: T) -> DriverResult { - let mut virtio = InnerDev::new(transport).map_err(as_driver_error)?; - let name = virtio.name().unwrap_or_else(|_| "".to_owned()); - let device_id = virtio.ids().map_err(as_driver_error)?; - let device_id = InputDeviceId { - bus_type: device_id.bustype, - vendor: device_id.vendor, - product: device_id.product, - version: device_id.version, - }; + let mut device = InnerDev::new(transport).map_err(as_driver_error)?; + let name = device.name().unwrap_or_else(|_| String::from("")); + let ids = device.ids().map_err(as_driver_error)?; Ok(Self { - inner: virtio, - device_id, + inner: device, + device_id: InputDeviceId { + bus_type: ids.bustype, + vendor: ids.vendor, + product: ids.product, + version: ids.version, + }, name, }) } + + fn load_event_bits(&mut self, event_type: EventType, out: &mut [u8]) -> DriverResult { + let written = + self.inner + .query_config_select(InputConfigSelect::EvBits, event_type as u8, out); + Ok(written != 0) + } } impl DriverOps for VirtIoInputDev { fn name(&self) -> &str { - &self.name + self.name.as_str() } fn device_kind(&self) -> DeviceKind { @@ -63,32 +67,26 @@ impl InputDriverOps for VirtIoInputDev { } fn physical_location(&self) -> &str { - // TODO: unique physical location "virtio0/input0" } fn unique_id(&self) -> &str { - // TODO: unique ID "virtio" } fn get_event_bits(&mut self, ty: EventType, out: &mut [u8]) -> DriverResult { - let read = self - .inner - .query_config_select(InputConfigSelect::EvBits, ty as u8, out); - Ok(read != 0) + self.load_event_bits(ty, out) } fn read_event(&mut self) -> DriverResult { - self.inner.ack_interrupt(); - self.inner - .pop_pending_event() - .map(|e| Event { - event_type: e.event_type, - code: e.code, - value: e.value, - }) - .ok_or(DriverError::WouldBlock) + let Some(event) = self.inner.pop_pending_event() else { + return Err(DriverError::WouldBlock); + }; + Ok(Event { + event_type: event.event_type, + code: event.code, + value: event.value, + }) } } diff --git a/drivers/virtio/src/socket.rs b/drivers/virtio/src/socket.rs index 7f2a15b27600f28333d7127f5a7c1ed440f4d855..47b08e6c7a7f945eb7f668371574f359d68bfc33 100644 --- a/drivers/virtio/src/socket.rs +++ b/drivers/virtio/src/socket.rs @@ -18,6 +18,23 @@ use crate::as_driver_error; /// Default buffer size for VirtIO socket device (32KB). const DEFAULT_BUFFER_SIZE: usize = 32 * 1024; +struct ConnectionArgs { + peer_addr: VsockAddr, + host_port: u32, +} + +impl ConnectionArgs { + fn from_conn_id(conn_id: VsockConnId) -> Self { + Self { + peer_addr: VsockAddr { + cid: conn_id.peer_addr.cid as _, + port: conn_id.peer_addr.port as _, + }, + host_port: conn_id.local_port, + } + } +} + /// The VirtIO socket device driver. pub struct VirtIoSocketDev { inner: InnerDev, @@ -30,11 +47,44 @@ impl VirtIoSocketDev { /// Creates a new driver instance and initializes the device, or returns /// an error if any step fails. pub fn try_new(transport: T) -> DriverResult { - let virtio_socket = VirtIOSocket::::new(transport).map_err(as_driver_error)?; + let virtio_socket = Self::open_socket(transport)?; Ok(Self { inner: InnerDev::new_with_capacity(virtio_socket, DEFAULT_BUFFER_SIZE as u32), }) } + + fn open_socket(transport: T) -> DriverResult> { + VirtIOSocket::::new(transport).map_err(as_driver_error) + } + + fn translate_event(event: VsockEvent) -> VsockDriverEventType { + let connection = VsockConnId { + peer_addr: vsock::VsockAddr { + cid: event.source.cid as _, + port: event.source.port as _, + }, + local_port: event.destination.port, + }; + + match event.event_type { + VsockEventType::ConnectionRequest => { + VsockDriverEventType::ConnectionRequest(connection) + } + VsockEventType::Connected => VsockDriverEventType::Connected(connection), + VsockEventType::Received { length } => { + VsockDriverEventType::Received(connection, length) + } + VsockEventType::Disconnected { .. } => VsockDriverEventType::Disconnected(connection), + VsockEventType::CreditUpdate => VsockDriverEventType::CreditUpdate(connection), + _ => VsockDriverEventType::Unknown, + } + } + + fn refresh_credit(&mut self, connection: &ConnectionArgs) { + let _ = self + .inner + .update_credit(connection.peer_addr, connection.host_port); + } } impl DriverOps for VirtIoSocketDev { @@ -47,16 +97,6 @@ impl DriverOps for VirtIoSocketDev { } } -fn extract_addr_and_port(cid: VsockConnId) -> (VsockAddr, u32) { - ( - VsockAddr { - cid: cid.peer_addr.cid as _, - port: cid.peer_addr.port as _, - }, - cid.local_port, - ) -} - impl VsockDriverOps for VirtIoSocketDev { fn guest_cid(&self) -> u64 { self.inner.guest_cid() @@ -67,88 +107,58 @@ impl VsockDriverOps for VirtIoSocketDev { } fn connect(&mut self, cid: VsockConnId) -> DriverResult<()> { - let (peer_addr, src_port) = extract_addr_and_port(cid); + let connection = ConnectionArgs::from_conn_id(cid); self.inner - .connect(peer_addr, src_port) + .connect(connection.peer_addr, connection.host_port) .map_err(as_driver_error) } fn send(&mut self, cid: VsockConnId, buf: &[u8]) -> DriverResult { - let (peer_addr, src_port) = extract_addr_and_port(cid); - match self.inner.send(peer_addr, src_port, buf) { + let connection = ConnectionArgs::from_conn_id(cid); + match self + .inner + .send(connection.peer_addr, connection.host_port, buf) + { Ok(()) => Ok(buf.len()), Err(e) => Err(as_driver_error(e)), } } fn recv(&mut self, cid: VsockConnId, buf: &mut [u8]) -> DriverResult { - let (peer_addr, src_port) = extract_addr_and_port(cid); + let connection = ConnectionArgs::from_conn_id(cid); let res = self .inner - .recv(peer_addr, src_port, buf) + .recv(connection.peer_addr, connection.host_port, buf) .map_err(as_driver_error); - let _ = self.inner.update_credit(peer_addr, src_port); + self.refresh_credit(&connection); res } fn recv_avail(&mut self, cid: VsockConnId) -> DriverResult { - let (peer_addr, src_port) = extract_addr_and_port(cid); + let connection = ConnectionArgs::from_conn_id(cid); self.inner - .recv_buffer_available_bytes(peer_addr, src_port) + .recv_buffer_available_bytes(connection.peer_addr, connection.host_port) .map_err(as_driver_error) } fn disconnect(&mut self, cid: VsockConnId) -> DriverResult<()> { - let (peer_addr, src_port) = extract_addr_and_port(cid); + let connection = ConnectionArgs::from_conn_id(cid); self.inner - .shutdown(peer_addr, src_port) + .shutdown(connection.peer_addr, connection.host_port) .map_err(as_driver_error) } fn abort(&mut self, cid: VsockConnId) -> DriverResult<()> { - let (peer_addr, src_port) = extract_addr_and_port(cid); + let connection = ConnectionArgs::from_conn_id(cid); self.inner - .force_close(peer_addr, src_port) + .force_close(connection.peer_addr, connection.host_port) .map_err(as_driver_error) } fn poll_event(&mut self) -> DriverResult> { - match self.inner.poll() { - Ok(None) => { - // no event - Ok(None) - } - Ok(Some(event)) => { - // translate event - let result = translate_virtio_event(event, &mut self.inner)?; - Ok(Some(result)) - } - Err(e) => { - // error - Err(as_driver_error(e)) - } - } - } -} - -fn translate_virtio_event( - event: VsockEvent, - _inner: &mut InnerDev, -) -> DriverResult { - let cid = VsockConnId { - peer_addr: vsock::VsockAddr { - cid: event.source.cid as _, - port: event.source.port as _, - }, - local_port: event.destination.port, - }; - - match event.event_type { - VsockEventType::ConnectionRequest => Ok(VsockDriverEventType::ConnectionRequest(cid)), - VsockEventType::Connected => Ok(VsockDriverEventType::Connected(cid)), - VsockEventType::Received { length } => Ok(VsockDriverEventType::Received(cid, length)), - VsockEventType::Disconnected { reason: _ } => Ok(VsockDriverEventType::Disconnected(cid)), - VsockEventType::CreditUpdate => Ok(VsockDriverEventType::CreditUpdate(cid)), - _ => Ok(VsockDriverEventType::Unknown), + let Some(event) = self.inner.poll().map_err(as_driver_error)? else { + return Ok(None); + }; + Ok(Some(Self::translate_event(event))) } }