
File: Broken/Core/BrokenWorker.py

Broken.Core.BrokenWorker

WorkerType

WorkerType: TypeAlias = Union[Thread, Process]

Either stdlib parallelization primitive (Thread or Process)

MANAGER

MANAGER = Manager()

Global multiprocessing manager

BrokenWorker

A semi-complete Thread and Process manager for easy parallelization primitives, smart task queueing, caching results and more.

References:
- Independently reinvented https://en.wikipedia.org/wiki/Thread_pool
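
A minimal usage sketch, not part of the source: subclass BrokenWorker and override main() with a generator that consumes tasks and yields one result per task. Names like SquareWorker are illustrative, and the import path follows the module name above.

from typing import Generator, Iterable

from Broken.Core.BrokenWorker import BrokenWorker

class SquareWorker(BrokenWorker):
    def main(self, tasks: Iterable) -> Generator:
        # Each task pulled from the queue is an integer here; yielding its
        # square lets the supervisor store it in the cache under hash(task)
        for task in tasks:
            yield task * task

worker = SquareWorker(size=4)
print(worker.map(1, 2, 3))  # [1, 4, 9]
worker.close()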

Source code in Broken/Core/BrokenWorker.py
@define(eq=False)
class BrokenWorker:
    """
    A semi-complete Thread and Process manager for easy parallelization primitives, smart task
    queueing, caching results and more.

    References:
    - Independently reinvented https://en.wikipedia.org/wiki/Thread_pool
    """

    # # Static utilities

    @staticmethod
    def _spawn(
        target: Callable,
        *args: Any,
        daemon: bool=True,
        _type: WorkerType=Thread,
        **kwargs,
    ) -> WorkerType:
        worker = _type(
            target=target,
            daemon=daemon,
            kwargs=kwargs,
            args=args,
        )
        worker.start()
        return worker

    @classmethod
    @functools.wraps(_spawn)
    def thread(cls, *args, **kwargs) -> Thread:
        return cls._spawn(*args, **kwargs, _type=Thread)

    @classmethod
    @functools.wraps(_spawn)
    def process(cls, *args, **kwargs) -> Process:
        return cls._spawn(*args, **kwargs, _type=Process)

    # # Easy lock

    @staticmethod
    @functools.cache
    def easy_lock(method: Callable) -> Callable:
        """Get a wrapper with a common threading.Lock for a method, multi-call safe"""

        shared_lock = Lock()

        @functools.wraps(method)
        def wrapped(*args, **kwargs) -> Any:
            with shared_lock:
                return method(*args, **kwargs)

        return wrapped

    # # Initialization

    type: WorkerType = Thread
    """The primitive to use for parallelization"""

    size: int = field(default=1, converter=int)
    """How many workers to keep alive"""

    workers: set[WorkerType] = Factory(set)
    """The currently alive workers"""

    queue: Union[ThreadQueue, ProcessQueue] = None
    """The list of tasks to be processed"""

    @property
    def queue_type(self) -> type[Queue]:
        if (self.type is Thread):
            return ThreadQueue
        return ProcessQueue

    @property
    def diskcache_enabled(self) -> bool:
        return (self.cache_size and self.cache_path)

    @property
    def cache_dict_type(self) -> type[dict]:
        return (dict if (self.type is Thread) else MANAGER.dict)

    def __attrs_post_init__(self):

        # Initialize DiskCache or dict cache
        if (self.diskcache_enabled):
            self.cache_data = DiskCache(
                directory=Path(self.cache_path),
                size_limit=int(self.cache_size)*MB,
            )
        else:
            self.cache_data = self.cache_dict_type()

        # Initialize remaining items
        self.queue = self.queue_type()
        BrokenWorker.thread(self.keep_alive_thread)

    # # Worker management

    @property
    def alive(self) -> Iterable[WorkerType]:
        """Iterates over the alive workers"""
        for worker in self.workers:
            if worker.is_alive():
                yield worker

    @property
    def still_alive(self) -> int:
        """Believe me, I am still alive"""
        return sum(1 for _ in self.alive)

    def sanitize(self) -> None:
        """Removes dead workers on the set"""
        self.workers = set(self.alive)

    def join_workers(self, timeout: Optional[float]=None) -> None:
        """Waits for all workers to finish"""
        for worker in copy.copy(self.workers):
            worker.join(timeout)

    # # Caching

    cache_data: Union[dict, DiskCache] = None
    """The cached results database"""

    cache_path: Path = None
    """(DiskCache) Path to the cache directory, disabled if None"""

    cache_size: int = 500
    """(DiskCache) Maximum size of the cache in megabytes"""

    def clear_cache(self) -> None:
        self.cache_data.clear()

    # # Serde middleware for Process

    def __serialize__(self, object: Any) -> Any:
        if (self.type is Process):
            return dill.dumps(object, recurse=True)
        return object

    def __deserialize__(self, object: Any) -> Any:
        if (self.type is Process):
            return dill.loads(object)
        return object

    # # Tasks

    def join_tasks(self) -> None:
        """Waits for all tasks to finish"""
        self.queue.join()

    def put(self, task: Hashable) -> Hashable:
        """Submit a new task directly to the queue"""
        return (self.queue.put(self.__serialize__(task)) or task)

    @abstractmethod
    def get(self, task: Hashable) -> Optional[Any]:
        """Get the result of a task, keeping it on cache (non-blocking)"""
        result = self.cache_data.get(hash(task), None)

        # Remove errors from cache to allow re-queueing
        if isinstance(result, Exception):
            return self.pop(task)

        return result

    def get_blocking(self, task: Hashable) -> Any:
        """Get the result of a task, keeping it on cache (waits to finish)"""
        while (result := self.get(hash(task))) is None:
            time.sleep(0.1)
        return result

    def pop(self, task: Hashable) -> Any:
        """Get the result of a task, removing it from cache"""
        return self.cache_data.pop(hash(task))

    def call(self, method: Callable, *args, **kwargs) -> Hashable:
        """Submit a new task to call a method with args and kwargs"""
        return self.put(functools.partial(method, *args, **kwargs))

    def get_smart(self, task: Hashable) -> Any:
        """Queues the task if not on cache, returns the result (blocking)"""
        if (result := self.get(task)) is None:
            return self.get_blocking(self.put(task))
        return result

    def map(self, *tasks: Hashable) -> List:
        """Puts all tasks in the queue and returns the results in order"""
        tasks = flatten(tasks)

        # Queues tasks not present in cache
        for task, result in zip(tasks, map(self.get, tasks)):
            if (result is None):
                self.put(task)

        # Returns the results in order
        return list(map(self.get_blocking, tasks))

    def map_call(self, method: Callable, inputs: Iterable, **kwargs) -> List:
        """Maps a method to a list of inputs, returns the results in order"""
        return self.map((
            functools.partial(method, item, **kwargs)
            for item in inputs
        ))

    # # Context

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args) -> None:
        self.close()

    def close(self) -> None:
        self.join_tasks()
        self.size = 0

        # Poison pill until it all ends
        while self.still_alive:
            while (self.queue.qsize() > 0):
                if (not self.still_alive):
                    break
                time.sleep(0.001)
            self.queue.put(None)

        # Avoid queue leftovers next use
        self.queue = self.queue_type()
        self.join_workers()

    # # Automation

    @easyloop
    def keep_alive_thread(self) -> None:
        """Ensures 'size' workers are running the supervisor"""
        while (self.still_alive < self.size):
            self.workers.add(self._spawn(
                target=self.__supervisor__,
                _type=self.type
            ))
        time.sleep(0.5)

    def __supervisor__(self) -> None:
        """Automatically handle getting tasks and storing results"""
        task: Any = None

        # Tracks new current 'task's, stops on None
        def get_tasks() -> Generator:
            nonlocal task

            while True:
                try:
                    if (task := self.queue.get(block=True)) is not None:
                        yield (task := self.__deserialize__(task))
                        continue
                    break
                finally:
                    self.queue.task_done()

        # Optional results are 'yielded', fail on non-generator main
        if not inspect.isgeneratorfunction(self.main):
            raise TypeError((
                f"{type(self).__name__}.main() function must be a generator, "
                "either 'yield result' or 'yield None' on the code."
            ))

        try:
            # Wraps 'main' outputs and store results
            for result in self.main(get_tasks()):
                self.store(task, result)
        except GeneratorExit:
            pass
        except Exception as error:
            self.store(task, error)
            raise error

    def store(self, task: Hashable, result: Optional[Any]) -> None:
        if (result is not None):
            self.cache_data[hash(task)] = result

    # # Specific implementations

    @abstractmethod
    def main(self, tasks: Iterable) -> Generator:
        """A worker gets tasks and yields optional results to be cached"""
        log.success(f"Worker {self.type.__name__} started")

        for task in tasks:
            yield task()

thread

thread(*args, **kwargs) -> Thread
Source code in Broken/Core/BrokenWorker.py
@classmethod
@functools.wraps(_spawn)
def thread(cls, *args, **kwargs) -> Thread:
    return cls._spawn(*args, **kwargs, _type=Thread)

process

process(*args, **kwargs) -> Process
Source code in Broken/Core/BrokenWorker.py
@classmethod
@functools.wraps(_spawn)
def process(cls, *args, **kwargs) -> Process:
    return cls._spawn(*args, **kwargs, _type=Process)
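
A small illustrative sketch of the spawn helpers (greet is a hypothetical target): both classmethods forward to _spawn(), so positional and keyword arguments reach the target, and the returned worker is already started as a daemon.

def greet(name: str, punctuation: str="!") -> None:
    print(f"Hello, {name}{punctuation}")

# thread() returns a started daemon Thread; process() does the same with a Process
worker = BrokenWorker.thread(greet, "world", punctuation="?")
worker.join()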

easy_lock

easy_lock(method: Callable) -> Callable

Get a wrapper with a common threading.Lock for a method, multi-call safe

Source code in Broken/Core/BrokenWorker.py
@staticmethod
@functools.cache
def easy_lock(method: Callable) -> Callable:
    """Get a wrapper with a common threading.Lock for a method, multi-call safe"""

    shared_lock = Lock()

    @functools.wraps(method)
    def wrapped(*args, **kwargs) -> Any:
        with shared_lock:
            return method(*args, **kwargs)

    return wrapped
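
A usage sketch (write_log is a hypothetical function): every call to the returned wrapper acquires the same threading.Lock, and functools.cache ensures repeated easy_lock(write_log) calls return that same wrapper.

def write_log(line: str) -> None:
    with open("app.log", "a") as file:
        file.write(line + "\n")

# Concurrent calls serialize on the shared lock inside the wrapper
safe_write = BrokenWorker.easy_lock(write_log)
workers = [BrokenWorker.thread(safe_write, f"from thread {i}") for i in range(4)]
for worker in workers:
    worker.join()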

type

type: WorkerType = Thread

The primitive to use for parallelization

size

size: int = field(default=1, converter=int)

How many workers to keep alive

workers

workers: set[WorkerType] = Factory(set)

The currently alive workers

queue

queue: Union[ThreadQueue, ProcessQueue] = None

The list of tasks to be processed

queue_type

queue_type: type[Queue]

diskcache_enabled

diskcache_enabled: bool

cache_dict_type

cache_dict_type: type[dict]

__attrs_post_init__

__attrs_post_init__()
Source code in Broken/Core/BrokenWorker.py
def __attrs_post_init__(self):

    # Initialize DiskCache or dict cache
    if (self.diskcache_enabled):
        self.cache_data = DiskCache(
            directory=Path(self.cache_path),
            size_limit=int(self.cache_size)*MB,
        )
    else:
        self.cache_data = self.cache_dict_type()

    # Initialize remaining items
    self.queue = self.queue_type()
    BrokenWorker.thread(self.keep_alive_thread)

alive

alive: Iterable[WorkerType]

Iterates over the alive workers

still_alive

still_alive: int

Believe me, I am still alive

sanitize

sanitize() -> None

Removes dead workers from the set

Source code in Broken/Core/BrokenWorker.py
def sanitize(self) -> None:
    """Removes dead workers on the set"""
    self.workers = set(self.alive)

join_workers

join_workers(timeout: Optional[float] = None) -> None

Waits for all workers to finish

Source code in Broken/Core/BrokenWorker.py
def join_workers(self, timeout: Optional[float]=None) -> None:
    """Waits for all workers to finish"""
    for worker in copy.copy(self.workers):
        worker.join(timeout)

cache_data

cache_data: Union[dict, DiskCache] = None

The cached results database

cache_path

cache_path: Path = None

(DiskCache) Path to the cache directory, disabled if None

cache_size

cache_size: int = 500

(DiskCache) Maximum size of the cache in megabytes
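
A configuration sketch (the path is illustrative): when cache_path is set, __attrs_post_init__ backs cache_data with a DiskCache limited to cache_size megabytes instead of an in-memory dict, so results persist between runs.

# Results are cached on disk under /tmp/broken-cache, capped at 100 MB
worker = BrokenWorker(size=2, cache_path="/tmp/broken-cache", cache_size=100)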

clear_cache

clear_cache() -> None
Source code in Broken/Core/BrokenWorker.py
def clear_cache(self) -> None:
    self.cache_data.clear()

__serialize__

__serialize__(object: Any) -> Any
Source code in Broken/Core/BrokenWorker.py
def __serialize__(self, object: Any) -> Any:
    if (self.type is Process):
        return dill.dumps(object, recurse=True)
    return object

__deserialize__

__deserialize__(object: Any) -> Any
Source code in Broken/Core/BrokenWorker.py
def __deserialize__(self, object: Any) -> Any:
    if (self.type is Process):
        return dill.loads(object)
    return object

join_tasks

join_tasks() -> None

Waits for all tasks to finish

Source code in Broken/Core/BrokenWorker.py
def join_tasks(self) -> None:
    """Waits for all tasks to finish"""
    self.queue.join()

put

put(task: Hashable) -> Hashable

Submit a new task directly to the queue

Source code in Broken/Core/BrokenWorker.py
def put(self, task: Hashable) -> Hashable:
    """Submit a new task directly to the queue"""
    return (self.queue.put(self.__serialize__(task)) or task)

get

get(task: Hashable) -> Optional[Any]

Get the result of a task, keeping it on cache (non-blocking)

Source code in Broken/Core/BrokenWorker.py
@abstractmethod
def get(self, task: Hashable) -> Optional[Any]:
    """Get the result of a task, keeping it on cache (non-blocking)"""
    result = self.cache_data.get(hash(task), None)

    # Remove errors from cache to allow re-queueing
    if isinstance(result, Exception):
        return self.pop(task)

    return result

get_blocking

get_blocking(task: Hashable) -> Any

Get the result of a task, keeping it on cache (waits to finish)

Source code in Broken/Core/BrokenWorker.py
def get_blocking(self, task: Hashable) -> Any:
    """Get the result of a task, keeping it on cache (waits to finish)"""
    while (result := self.get(hash(task))) is None:
        time.sleep(0.1)
    return result

pop

pop(task: Hashable) -> Any

Get the result of a task, removing it from cache

Source code in Broken/Core/BrokenWorker.py
def pop(self, task: Hashable) -> Any:
    """Get the result of a task, removing it from cache"""
    return self.cache_data.pop(hash(task))

call

call(method: Callable, *args, **kwargs) -> Hashable

Submit a new task to call a method with args and kwargs

Source code in Broken/Core/BrokenWorker.py
def call(self, method: Callable, *args, **kwargs) -> Hashable:
    """Submit a new task to call a method with args and kwargs"""
    return self.put(functools.partial(method, *args, **kwargs))
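
A short sketch tying call() to the getters (slow_square is hypothetical, and this assumes the default main() shown at the bottom of the page, which calls each task): call() wraps the method and arguments into a functools.partial, queues it, and returns the task handle that get_blocking() then polls the cache with.

def slow_square(x: int) -> int:
    return x * x

worker = BrokenWorker(size=1)
task = worker.call(slow_square, 7)
print(worker.get_blocking(task))  # 49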

get_smart

get_smart(task: Hashable) -> Any

Queues the task if not on cache, returns the result (blocking)

Source code in Broken/Core/BrokenWorker.py
def get_smart(self, task: Hashable) -> Any:
    """Queues the task if not on cache, returns the result (blocking)"""
    if (result := self.get(task)) is None:
        return self.get_blocking(self.put(task))
    return result

map

map(*tasks: Hashable) -> List

Puts all tasks in the queue and returns the results in order

Source code in Broken/Core/BrokenWorker.py
def map(self, *tasks: Hashable) -> List:
    """Puts all tasks in the queue and returns the results in order"""
    tasks = flatten(tasks)

    # Queues tasks not present in cache
    for task, result in zip(tasks, map(self.get, tasks)):
        if (result is None):
            self.put(task)

    # Returns the results in order
    return list(map(self.get_blocking, tasks))

map_call

map_call(
    method: Callable, inputs: Iterable, **kwargs
) -> List

Maps a method to a list of inputs, returns the results in order

Source code in Broken/Core/BrokenWorker.py
def map_call(self, method: Callable, inputs: Iterable, **kwargs) -> List:
    """Maps a method to a list of inputs, returns the results in order"""
    return self.map((
        functools.partial(method, item, **kwargs)
        for item in inputs
    ))

__enter__

__enter__() -> Self
Source code in Broken/Core/BrokenWorker.py
def __enter__(self) -> Self:
    return self

__exit__

__exit__(*args) -> None
Source code in Broken/Core/BrokenWorker.py
def __exit__(self, *args) -> None:
    self.close()

close

close() -> None
Source code in Broken/Core/BrokenWorker.py
def close(self) -> None:
    self.join_tasks()
    self.size = 0

    # Poison pill until it all ends
    while self.still_alive:
        while (self.queue.qsize() > 0):
            if (not self.still_alive):
                break
            time.sleep(0.001)
        self.queue.put(None)

    # Avoid queue leftovers next use
    self.queue = self.queue_type()
    self.join_workers()
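
A context-manager sketch (reusing the hypothetical SquareWorker from the top of the page): __exit__ calls close(), which waits for pending tasks, poison-pills the remaining workers with None, and joins them before the queue is recreated for later use.

with SquareWorker(size=4) as pool:
    print(pool.map(*range(1, 6)))  # [1, 4, 9, 16, 25]
# All workers have been stopped and joined at this point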

keep_alive_thread

keep_alive_thread() -> None

Ensures 'size' workers are running the supervisor

Source code in Broken/Core/BrokenWorker.py
@easyloop
def keep_alive_thread(self) -> None:
    """Ensures 'size' workers are running the supervisor"""
    while (self.still_alive < self.size):
        self.workers.add(self._spawn(
            target=self.__supervisor__,
            _type=self.type
        ))
    time.sleep(0.5)

__supervisor__

__supervisor__() -> None

Automatically handle getting tasks and storing results

Source code in Broken/Core/BrokenWorker.py
def __supervisor__(self) -> None:
    """Automatically handle getting tasks and storing results"""
    task: Any = None

    # Tracks new current 'task's, stops on None
    def get_tasks() -> Generator:
        nonlocal task

        while True:
            try:
                if (task := self.queue.get(block=True)) is not None:
                    yield (task := self.__deserialize__(task))
                    continue
                break
            finally:
                self.queue.task_done()

    # Optional results are 'yielded', fail on non-generator main
    if not inspect.isgeneratorfunction(self.main):
        raise TypeError((
            f"{type(self).__name__}.main() function must be a generator, "
            "either 'yield result' or 'yield None' on the code."
        ))

    try:
        # Wraps 'main' outputs and store results
        for result in self.main(get_tasks()):
            self.store(task, result)
    except GeneratorExit:
        pass
    except Exception as error:
        self.store(task, error)
        raise error

store

store(task: Hashable, result: Optional[Any]) -> None
Source code in Broken/Core/BrokenWorker.py
def store(self, task: Hashable, result: Optional[Any]) -> None:
    if (result is not None):
        self.cache_data[hash(task)] = result

main

main(tasks: Iterable) -> Generator

A worker gets tasks and yields optional results to be cached

Source code in Broken/Core/BrokenWorker.py
@abstractmethod
def main(self, tasks: Iterable) -> Generator:
    """A worker gets tasks and yields optional results to be cached"""
    log.success(f"Worker {self.type.__name__} started")

    for task in tasks:
        yield task()
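
Because this default main() treats every task as a callable, a worker can also be used without subclassing by submitting functools.partial tasks, which is exactly what call() and map_call() produce. A hedged sketch (slow_square is the hypothetical function from earlier):

with BrokenWorker(size=4) as pool:
    print(pool.map_call(slow_square, range(1, 6)))  # [1, 4, 9, 16, 25]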