IOS: Keep copy of m_devices in USBHost

This gets rid of the ugly direct access to USBScanner::m_devices that
was introduced by the previous commit.

This also fixes a potential thread safety issue.
USB_HIDv4::TriggerDeviceChangeReply loops through m_devices and calls
GetDeviceEntry for each device. If USB_HIDv4::TriggerDeviceChangeReply
is called after a new device is added to m_devices but before hooks are
dispatched, GetDeviceEntry crashes, because the hook that's supposed to
update m_device_ids hasn't run yet. With this commit, this issue can no
longer happen, because USBHost::m_devices_mutex doesn't get unlocked in
between updating m_devices and dispatching the hooks.
This commit is contained in:
JosJuice 2025-04-06 14:43:45 +02:00
parent 427e9c5ad2
commit 50a8ae9d90
5 changed files with 38 additions and 19 deletions

View file

@ -4,7 +4,9 @@
#include "Core/IOS/USB/Host.h" #include "Core/IOS/USB/Host.h"
#include <functional> #include <functional>
#include <map>
#include <memory> #include <memory>
#include <mutex>
#include <optional> #include <optional>
#include <string> #include <string>
@ -61,7 +63,11 @@ void USBHost::DoState(PointerWrap& p)
std::shared_ptr<USB::Device> USBHost::GetDeviceById(const u64 device_id) const std::shared_ptr<USB::Device> USBHost::GetDeviceById(const u64 device_id) const
{ {
return m_usb_scanner.GetDeviceById(device_id); std::lock_guard lk(m_devices_mutex);
const auto it = m_devices.find(device_id);
if (it == m_devices.end())
return nullptr;
return it->second;
} }
void USBHost::OnDeviceChange(ChangeEvent event, std::shared_ptr<USB::Device> changed_device) void USBHost::OnDeviceChange(ChangeEvent event, std::shared_ptr<USB::Device> changed_device)
@ -85,13 +91,22 @@ void USBHost::Update()
void USBHost::DispatchHooks(const DeviceChangeHooks& hooks) void USBHost::DispatchHooks(const DeviceChangeHooks& hooks)
{ {
for (const auto& hook : hooks) std::lock_guard lk(m_devices_mutex);
for (const auto& [device, event] : hooks)
{ {
INFO_LOG_FMT(IOS_USB, "{} - {} device: {:04x}:{:04x}", GetDeviceName(), INFO_LOG_FMT(IOS_USB, "{} - {} device: {:04x}:{:04x}", GetDeviceName(),
hook.second == ChangeEvent::Inserted ? "New" : "Removed", hook.first->GetVid(), event == ChangeEvent::Inserted ? "New" : "Removed", device->GetVid(),
hook.first->GetPid()); device->GetPid());
OnDeviceChange(hook.second, hook.first);
if (event == ChangeEvent::Inserted)
m_devices.emplace(device->GetId(), device);
else if (event == ChangeEvent::Removed)
m_devices.erase(device->GetId());
OnDeviceChange(event, device);
} }
if (!hooks.empty()) if (!hooks.empty())
OnDeviceChangeEnd(); OnDeviceChangeEnd();
} }

View file

@ -4,7 +4,9 @@
#pragma once #pragma once
#include <functional> #include <functional>
#include <map>
#include <memory> #include <memory>
#include <mutex>
#include <optional> #include <optional>
#include <string> #include <string>
@ -45,6 +47,9 @@ protected:
std::optional<IPCReply> HandleTransfer(std::shared_ptr<USB::Device> device, u32 request, std::optional<IPCReply> HandleTransfer(std::shared_ptr<USB::Device> device, u32 request,
std::function<s32()> submit) const; std::function<s32()> submit) const;
std::map<u64, std::shared_ptr<USB::Device>> m_devices;
mutable std::recursive_mutex m_devices_mutex;
USBScanner m_usb_scanner{this}; USBScanner m_usb_scanner{this};
private: private:

View file

@ -73,7 +73,7 @@ std::optional<IPCReply> OH0::IOCtlV(const IOCtlVRequest& request)
void OH0::DoState(PointerWrap& p) void OH0::DoState(PointerWrap& p)
{ {
if (p.IsReadMode() && !m_usb_scanner.m_devices.empty()) if (p.IsReadMode() && !m_devices.empty())
{ {
Core::DisplayMessage("It is suggested that you unplug and replug all connected USB devices.", Core::DisplayMessage("It is suggested that you unplug and replug all connected USB devices.",
5000); 5000);
@ -114,8 +114,8 @@ IPCReply OH0::GetDeviceList(const IOCtlVRequest& request) const
const u8 interface_class = memory.Read_U8(request.in_vectors[1].address); const u8 interface_class = memory.Read_U8(request.in_vectors[1].address);
u8 entries_count = 0; u8 entries_count = 0;
std::lock_guard lk(m_usb_scanner.m_devices_mutex); std::lock_guard lk(m_devices_mutex);
for (const auto& device : m_usb_scanner.m_devices) for (const auto& device : m_devices)
{ {
if (entries_count >= max_entries_count) if (entries_count >= max_entries_count)
break; break;
@ -231,14 +231,13 @@ std::optional<IPCReply> OH0::RegisterClassChangeHook(const IOCtlVRequest& reques
bool OH0::HasDeviceWithVidPid(const u16 vid, const u16 pid) const bool OH0::HasDeviceWithVidPid(const u16 vid, const u16 pid) const
{ {
return std::ranges::any_of(m_usb_scanner.m_devices, [=](const auto& device) { return std::ranges::any_of(m_devices, [=](const auto& device) {
return device.second->GetVid() == vid && device.second->GetPid() == pid; return device.second->GetVid() == vid && device.second->GetPid() == pid;
}); });
} }
void OH0::OnDeviceChange(const ChangeEvent event, std::shared_ptr<USB::Device> device) void OH0::OnDeviceChange(const ChangeEvent event, std::shared_ptr<USB::Device> device)
{ {
std::lock_guard lk(m_usb_scanner.m_devices_mutex);
if (event == ChangeEvent::Inserted) if (event == ChangeEvent::Inserted)
TriggerHook(m_insertion_hooks, {device->GetVid(), device->GetPid()}, IPC_SUCCESS); TriggerHook(m_insertion_hooks, {device->GetVid(), device->GetPid()}, IPC_SUCCESS);
else if (event == ChangeEvent::Removed) else if (event == ChangeEvent::Removed)
@ -259,10 +258,10 @@ void OH0::TriggerHook(std::map<T, u32>& hooks, T value, const ReturnCode return_
std::pair<ReturnCode, u64> OH0::DeviceOpen(const u16 vid, const u16 pid) std::pair<ReturnCode, u64> OH0::DeviceOpen(const u16 vid, const u16 pid)
{ {
std::lock_guard lk(m_usb_scanner.m_devices_mutex); std::lock_guard lk(m_devices_mutex);
bool has_device_with_vid_pid = false; bool has_device_with_vid_pid = false;
for (const auto& device : m_usb_scanner.m_devices) for (const auto& device : m_devices)
{ {
if (device.second->GetVid() != vid || device.second->GetPid() != pid) if (device.second->GetVid() != vid || device.second->GetPid() != pid)
continue; continue;

View file

@ -32,8 +32,6 @@ public:
void WaitForFirstScan(); void WaitForFirstScan();
bool UpdateDevices(bool always_add_hooks = false); bool UpdateDevices(bool always_add_hooks = false);
std::shared_ptr<USB::Device> GetDeviceById(u64 device_id) const;
enum class ChangeEvent enum class ChangeEvent
{ {
Inserted, Inserted,
@ -41,9 +39,6 @@ public:
}; };
using DeviceChangeHooks = std::map<std::shared_ptr<USB::Device>, ChangeEvent>; using DeviceChangeHooks = std::map<std::shared_ptr<USB::Device>, ChangeEvent>;
std::map<u64, std::shared_ptr<USB::Device>> m_devices;
mutable std::mutex m_devices_mutex;
private: private:
bool AddDevice(std::unique_ptr<USB::Device> device); bool AddDevice(std::unique_ptr<USB::Device> device);
bool AddNewDevices(std::set<u64>& new_devices, DeviceChangeHooks& hooks, bool always_add_hooks); bool AddNewDevices(std::set<u64>& new_devices, DeviceChangeHooks& hooks, bool always_add_hooks);
@ -53,6 +48,11 @@ private:
void CheckAndAddDevice(std::unique_ptr<USB::Device> device, std::set<u64>& new_devices, void CheckAndAddDevice(std::unique_ptr<USB::Device> device, std::set<u64>& new_devices,
DeviceChangeHooks& hooks, bool always_add_hooks); DeviceChangeHooks& hooks, bool always_add_hooks);
std::shared_ptr<USB::Device> GetDeviceById(u64 device_id) const;
std::map<u64, std::shared_ptr<USB::Device>> m_devices;
mutable std::mutex m_devices_mutex;
USBHost* m_host = nullptr; USBHost* m_host = nullptr;
Common::Flag m_thread_running; Common::Flag m_thread_running;
std::thread m_thread; std::thread m_thread;

View file

@ -207,10 +207,10 @@ void USB_HIDv4::TriggerDeviceChangeReply()
auto& memory = system.GetMemory(); auto& memory = system.GetMemory();
{ {
std::lock_guard lk(m_usb_scanner.m_devices_mutex); std::lock_guard lk(m_devices_mutex);
const u32 dest = m_devicechange_hook_request->buffer_out; const u32 dest = m_devicechange_hook_request->buffer_out;
u32 offset = 0; u32 offset = 0;
for (const auto& device : m_usb_scanner.m_devices) for (const auto& device : m_devices)
{ {
const std::vector<u8> device_section = GetDeviceEntry(*device.second.get()); const std::vector<u8> device_section = GetDeviceEntry(*device.second.get());
if (offset + device_section.size() > m_devicechange_hook_request->buffer_out_size - 1) if (offset + device_section.size() > m_devicechange_hook_request->buffer_out_size - 1)