# Source: mirror of https://github.com/zeldaret/tp.git (226 lines, 7.2 KiB, Python)
import bz2
import gc
import os
import pickle
import sys
import tempfile
import time
from multiprocessing import Manager, Pool, Process, Queue
from multiprocessing import managers as m
from queue import Empty
from typing import Any, Dict, List, Optional, Tuple

import rich
from rich.progress import Progress

from .context import Context, MainContext
from .globals import *
from .symbol_table import GlobalSymbolTable

class TimeCode:
    """Context manager that measures the wall-clock time of a block and
    reports it through the given context's debug channel.

    Usage:
        with TimeCode(context, "label"):
            ...  # timed work
    """

    def __init__(self, context: "Context", text: str):
        # context: destination for the timing message (anything with .debug()).
        # text: human-readable label included in the timing message.
        self.context = context
        self.text = text
        self.start = 0.0
        self.end = 0.0

    def __enter__(self) -> "TimeCode":
        self.start = time.perf_counter()
        # Fix: return self so `with TimeCode(...) as tc:` binds the instance
        # instead of None (call sites in this file use the `as tc` form).
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Report elapsed time even if the block raised; returning None
        # (falsy) lets any exception propagate unchanged.
        self.end = time.perf_counter()
        self.context.debug(f"'{self.text}': {self.end-self.start} seconds")
def _process_entrypoint(input: Queue, output: Queue, shared_file: str):
    """
    Function running on processes created by 'execute_tasks'.

    Pulls (index, task) pairs from the 'input' queue, runs each task, and
    talks back to the main process through the 'output' queue (results, log
    messages, captured exceptions).

    shared_file: path to a pickle dump of the shared keyword arguments, or
    a falsy value when there is no shared state.
    """

    def _capture_traceback():
        # Build a rich Traceback for the active exception so the main
        # process can render it; drop the top frame (this wrapper).
        exc_type, exc_value, tb = sys.exc_info()
        return rich.traceback.Traceback.from_exception(
            exc_type,
            exc_value,
            tb.tb_next if tb else tb,
        )

    # load shared data if any
    shared = {}
    if shared_file:
        context = Context(index=-1, output=output)
        try:
            with TimeCode(context, "load_shared"):
                # Disabling the GC speeds up unpickling of large object
                # graphs; fix: re-enable it even when loading fails.
                gc.disable()
                try:
                    with open(shared_file, 'rb') as file:
                        shared = pickle.load(file)
                finally:
                    gc.enable()
        except:  # noqa: E722 -- intentionally broad: report anything back to the main process
            context.exception(_capture_traceback())

    while True:
        try:
            # get available task
            i, task = input.get(block=False)
            context = Context(index=i, output=output)
            try:
                # execute task
                result = task[0](context, *task[1], **shared)
                context.complete(result)
            except SystemExit:
                # A task requested termination: notify the main process,
                # then re-raise via sys.exit to end this worker.
                context.exit()
                sys.exit(1)
            except:  # noqa: E722 -- capture anything the task raised
                # exception inside task, capture exception information and
                # send it back to the main process
                context.exception(_capture_traceback())
        except KeyboardInterrupt:
            break
        except BrokenPipeError:
            # the main process went away; nothing left to report to
            break
        except Empty:
            # no more tasks, exit
            break
def execute_tasks(process_count: int,
                  input_tasks: List[Tuple[Any, Any]],
                  shared: Optional[Dict[str, Any]] = None,
                  callback: Any = None) -> List[Any]:
    """
    Creates 'process_count' worker processes that together execute the
    provided tasks.

    process_count: number of worker processes; 0 runs everything serially
        in this process (also forced when there is exactly one task).
    input_tasks: list of (callable, args) pairs; each is invoked as
        callable(context, *args, **shared).
    shared: keyword arguments made available to every task; distributed to
        workers once via a pickled temp file instead of per task.
    callback: optional callable(command, *args) invoked for every message
        coming back from a worker; returning a falsy value suppresses the
        default handling of that message.

    Returns the task results, in the same order as 'input_tasks'; a slot
    stays None when its task raised.
    """

    # Fix: avoid a mutable default argument.
    if shared is None:
        shared = {}

    manager = Manager()
    results = [None] * len(input_tasks)

    # a single task gains nothing from multiprocessing overhead
    if len(input_tasks) == 1:
        process_count = 0

    if process_count == 0:
        # serial fallback: run every task in this process
        output = manager.Queue()
        for i, task in enumerate(input_tasks):
            context = MainContext(i, output)
            results[i] = task[0](context, *task[1], **shared)
            # Fix: only invoke the callback when one was provided
            # (previously raised TypeError with the default callback=None).
            if callback:
                callback("complete", i)
        # drain any messages the tasks queued through their context
        while not output.empty():
            command = output.get(block=True)
            if callback:
                callback(command[0], *command[1])

        return results

    input = manager.Queue()
    output = manager.Queue()
    timeout = 5 * 60  # if one single task takes more than 5 minutes, something is wrong

    # instead of copying state for each task, shared state is written to a
    # file which is loaded once per process.
    shared_file = None
    temp_file = None
    if len(shared) > 0:
        context = MainContext(0, None)
        with TimeCode(context, "create_shared"):
            # NOTE(review): workers re-open this file by name while it is
            # still held open here; that works on POSIX but may fail on
            # Windows with delete=True -- confirm if portability matters.
            temp_file = tempfile.NamedTemporaryFile(
                "wb", suffix='.dump', prefix="mp_shared", delete=True)
            shared_file = temp_file.name
            debug(f"shared file: '{temp_file.name}'")
            pickle_data = pickle.dumps(shared)
            temp_file.write(pickle_data)
            temp_file.flush()

    # add tasks to the task queue
    for i, task in enumerate(input_tasks):
        try:
            input.put((i, task))
        except:  # noqa: E722 -- an unpicklable task is fatal; report and abort
            get_console().print_exception()
            error(i)
            error(task)
            fatal_exit()

    # create the processes
    processors = [
        Process(target=_process_entrypoint, args=(input, output, shared_file))
        for i in range(process_count)
    ]

    # start the processes
    for process in processors:
        process.start()

    # receive messages until every task has reported completion or failure
    waiting = len(input_tasks)
    while waiting > 0:
        try:
            command = output.get(block=True, timeout=timeout)
            # the callback may claim a message by returning a falsy value
            processing = True
            if callback:
                processing = callback(command[0], *command[1])
            if processing:
                if command[0] == 'debug':
                    debug(*command[1])
                elif command[0] == 'warning':
                    warning(*command[1])
                elif command[0] == 'error':
                    error(*command[1])
                elif command[0] == 'info':
                    info(*command[1])
                elif command[0] == 'complete':
                    results[command[1][0]] = command[1][1]
                    waiting -= 1
                elif command[0] == 'exception':
                    # task failed; its slot in 'results' stays None
                    waiting -= 1
                    print(command[1][1])
                elif command[0] == 'exit':
                    sys.exit(1)
                else:
                    warning(f"unknown command: {command}")
        except Empty:
            # Fix: error-message typo ("to long" -> "too long").
            error(f"task took too long to complete (+{timeout} seconds)")
            fatal_exit()

    # wait for all processes to finish
    for process in processors:
        process.join()

    # TODO: Maybe we don't need to clear the queue
    while not output.empty():
        command = output.get(block=False)
        warning(f"skipped command: {command}")

    if temp_file:
        temp_file.close()

    return results
def apply(process_count: int, func: Any, data: List[Any], shared: Optional[Dict[str, Any]] = None, callback=None) -> List[Any]:
    """Helper for running 'execute_tasks' where every task uses the same function.

    Each element of 'data' becomes one task invoked as
    func(context, *element, **shared).
    """
    # Fix: avoid a mutable default argument (previous default was a shared {}).
    if shared is None:
        shared = {}
    return execute_tasks(process_count, [(func, x) for x in data], shared=shared, callback=callback)
def progress(process_count: int, func: Any, data: List[Any], shared: Optional[Dict[str, Any]] = None) -> List[Any]:
    """
    Helper for running 'execute_tasks' where every task uses the same function.
    Displays a transient progress bar with tasks completed.
    """
    # Fix: avoid a mutable default argument (previous default was a shared {}).
    if shared is None:
        shared = {}

    with Progress(console=get_console(), transient=True, refresh_per_second=1) as progress:
        task = progress.add_task("processing...", total=len(data))

        def callback(command, *args):
            # advance on both outcomes so the bar always reaches 100%;
            # returning True lets execute_tasks keep its default handling
            if command == 'complete' or command == 'exception':
                progress.update(task, advance=1)
            return True

        return execute_tasks(
            process_count,
            [(func, x) for x in data],
            shared=shared,
            callback=callback)