/** * The ThreadPool class. * Keeps a set of threads constantly waiting to execute incoming jobs. * source: http://roar11.com/2016/01/a-platform-independent-thread-pool-using-c14/ */ #pragma once #ifndef THREADPOOL_HPP #define THREADPOOL_HPP #include "ThreadSafeQueue.hpp" #include #include #include #include #include #include #include #include #include #include namespace TP { class ThreadPool { private: class IThreadTask { public: IThreadTask(void) = default; virtual ~IThreadTask(void) = default; IThreadTask(const IThreadTask& rhs) = delete; IThreadTask& operator=(const IThreadTask& rhs) = delete; IThreadTask(IThreadTask&& other) = default; IThreadTask& operator=(IThreadTask&& other) = default; /** * Run the task. */ virtual void execute() = 0; }; template class ThreadTask: public IThreadTask { public: ThreadTask(Func&& func) :m_func{std::move(func)} { } ~ThreadTask(void) override = default; ThreadTask(const ThreadTask& rhs) = delete; ThreadTask& operator=(const ThreadTask& rhs) = delete; ThreadTask(ThreadTask&& other) = default; ThreadTask& operator=(ThreadTask&& other) = default; /** * Run the task. */ void execute() override { m_func(); } private: Func m_func; }; public: /** * A wrapper around a std::future that adds the behavior of futures returned from std::async. * Specifically, this object will block and wait for execution to finish before going out of scope. */ template class TaskFuture { public: TaskFuture(std::future&& future) :m_future{std::move(future)} { } TaskFuture(const TaskFuture& rhs) = delete; TaskFuture& operator=(const TaskFuture& rhs) = delete; TaskFuture(TaskFuture&& other) = default; TaskFuture& operator=(TaskFuture&& other) = default; ~TaskFuture(void) { if(m_future.valid()) { m_future.get(); } } auto get(void) { return m_future.get(); } private: std::future m_future; }; public: /** * Constructor. */ ThreadPool(void) :ThreadPool{std::max(std::thread::hardware_concurrency()/2, 2u) - 1u} { /* * Always create at least one thread. If hardware_concurrency() returns 0, * subtracting one would turn it to UINT_MAX, so get the maximum of * hardware_concurrency() and 2 before subtracting 1. */ } /** * Constructor. */ explicit ThreadPool(const std::uint32_t numThreads) :m_done{false}, m_workQueue{}, m_threads{} { try { for(std::uint32_t i = 0u; i < numThreads; ++i) { m_threads.emplace_back(&ThreadPool::worker, this); } } catch(...) { destroy(); throw; } } /** * Non-copyable. */ ThreadPool(const ThreadPool& rhs) = delete; /** * Non-assignable. */ ThreadPool& operator=(const ThreadPool& rhs) = delete; /** * Destructor. */ ~ThreadPool(void) { destroy(); } auto queueSize() const { return m_workQueue.size(); } /** * Submit a job to be run by the thread pool. */ template auto submit(Func&& func, Args&&... args) { auto boundTask = std::bind(std::forward(func), std::forward(args)...); using ResultType = std::result_of_t; using PackagedTask = std::packaged_task; using TaskType = ThreadTask; PackagedTask task{std::move(boundTask)}; TaskFuture result{task.get_future()}; m_workQueue.push(std::make_unique(std::move(task))); return result; } private: /** * Constantly running function each thread uses to acquire work items from the queue. */ void worker(void) { while(!m_done) { std::unique_ptr pTask{nullptr}; if(m_workQueue.waitPop(pTask)) { pTask->execute(); } } } /** * Invalidates the queue and joins all running threads. */ void destroy(void) { m_done = true; m_workQueue.invalidate(); for(auto& thread : m_threads) { if(thread.joinable()) { thread.join(); } } } private: std::atomic_bool m_done; ThreadSafeQueue> m_workQueue; std::vector m_threads; }; namespace DefaultThreadPool { /** * Get the default thread pool for the application. * This pool is created with std::thread::hardware_concurrency() - 1 threads. */ inline ThreadPool& getThreadPool(void) { static ThreadPool defaultPool; return defaultPool; } inline auto queueSize() { return getThreadPool().queueSize(); } /** * Submit a job to the default thread pool. */ template inline auto submitJob(Func&& func, Args&&... args) { return getThreadPool().submit( std::forward(func), std::forward(args)...); } } } #endif