from contextlib import contextmanager
import faulthandler
import io
import os
import sys
import time
import logging
import multiprocessing
from multiprocessing import Process, Queue, Event
from queue import Empty
import collections
from twisted.scripts.trial import Options, _getSuite
from twisted.trial.runner import TrialRunner
from django.conf import settings
from django.test.runner import DiscoverRunner
from django.db import connections, DEFAULT_DB_ALIAS
from django.core.exceptions import ImproperlyConfigured
from .base import setup_test_db
# "auto" - one worker per Django application
# "cpu" - one worker per process core
[docs]
WORKER_MAX = getattr(settings, "TEST_RUNNER_WORKER_MAX", 3)
[docs]
WORKER_COUNT = getattr(settings, "TEST_RUNNER_WORKER_COUNT", "auto")
[docs]
NOT_THREAD_SAFE = getattr(settings, "TEST_RUNNER_NOT_THREAD_SAFE", None)
[docs]
PARENT_TIMEOUT = getattr(settings, "TEST_RUNNER_PARENT_TIMEOUT", 10)
[docs]
WORKER_TIMEOUT = getattr(settings, "TEST_RUNNER_WORKER_TIMEOUT", 10)
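
# Example (illustrative only): a project could tune the runner from its Django
# settings module. The setting names come from the getattr() lookups above;
# the values shown here are hypothetical.
#
#   TEST_RUNNER_WORKER_COUNT = "cpu"     # one worker per processor core
#   TEST_RUNNER_WORKER_MAX = 4           # never spawn more than four workers
#   TEST_RUNNER_NOT_THREAD_SAFE = ["geonode.layers"]  # run this group in the parent process
#   TEST_RUNNER_PARENT_TIMEOUT = 30      # seconds the parent waits on the results queue
#   TEST_RUNNER_WORKER_TIMEOUT = 30      # seconds a worker waits for the next test group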
# amqplib spits out a lot of log messages which just add a lot of noise.
logger = logging.getLogger(__name__)

null_file = open(os.devnull, "w")
class GeoNodeBaseSuiteDiscoverRunner(DiscoverRunner):
def __init__(
self,
pattern=None,
top_level=None,
verbosity=1,
interactive=True,
failfast=True,
keepdb=False,
reverse=False,
debug_mode=False,
debug_sql=False,
parallel=0,
tags=None,
exclude_tags=None,
test_name_patterns=None,
pdb=False,
buffer=False,
enable_faulthandler=True,
timing=False,
**kwargs,
):
        self.pattern = pattern
        self.top_level = top_level
        self.verbosity = verbosity
        self.interactive = interactive
        self.failfast = failfast
        self.keepdb = keepdb
        self.reverse = reverse
        self.debug_mode = debug_mode
        self.debug_sql = debug_sql
        self.parallel = parallel
        self.tags = set(tags or [])
        self.exclude_tags = set(exclude_tags or [])
        self.pdb = pdb
        self.buffer = buffer
        if not faulthandler.is_enabled() and enable_faulthandler:
            try:
                faulthandler.enable(file=sys.stderr.fileno())
            except (AttributeError, io.UnsupportedOperation):
                faulthandler.enable(file=sys.__stderr__.fileno())
        if self.pdb and self.parallel > 1:
            raise ValueError("You cannot use --pdb with parallel tests; pass --parallel=1 to use it.")
        if self.buffer and self.parallel > 1:
            raise ValueError("You cannot use -b/--buffer with parallel tests; pass --parallel=1 to use it.")
        self.test_name_patterns = None
        self.time_keeper = TimeKeeper() if timing else NullTimeKeeper()
        if test_name_patterns:
            # unittest does not export the _convert_select_pattern function
            # that converts command-line arguments to patterns, so replicate it:
            # bare names are wrapped in wildcards.
            self.test_name_patterns = {p if "*" in p else f"*{p}*" for p in test_name_patterns}
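            # For example (illustrative): test_name_patterns={"test_upload"} becomes
            # {"*test_upload*"}, while an explicit pattern such as "test_upload*"
            # is kept unchanged.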

class BufferWritesDevice:
    def __init__(self):
        self._data = []

    def write(self, string):
        self._data.append(string)
    def read(self):
        return "".join(self._data)

    def flush(self, *args, **kwargs):
        pass

    def isatty(self):
        return False
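
# BufferWritesDevice is a small in-memory file-like object: each worker's test
# runner writes its report into one of these instead of a real stream, and the
# parent process later reads the buffered text back (see TestResult.output and
# _print_result below). Illustrative usage, not executed here:
#
#   stream = BufferWritesDevice()
#   stream.write("OK (10 tests)\n")
#   print(stream.read(), file=sys.stderr)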
# Redirect stdout to /dev/null because we don't want to see all the repeated
# "database creation" logging statements from all the workers.
# All the test output is printed to stderr, so this is not problematic.
sys.stdout = null_file

class ParallelTestSuiteRunner:
def __init__(
self,
pattern=None,
top_level=None,
verbosity=1,
interactive=True,
failfast=True,
keepdb=False,
reverse=False,
debug_mode=False,
debug_sql=False,
parallel=0,
tags=None,
exclude_tags=None,
**kwargs,
):
        self.pattern = pattern
        self.top_level = top_level
        self.verbosity = verbosity
        self.interactive = interactive
        self.failfast = failfast
        self.keepdb = keepdb
        self.reverse = reverse
        self.debug_mode = debug_mode
        self.debug_sql = debug_sql
        self.parallel = parallel
        self.tags = set(tags or [])
        self.exclude_tags = set(exclude_tags or [])
        self._keyboard_interrupt_intercepted = False
        self._worker_max = kwargs.get("worker_max", WORKER_MAX)
        self._worker_count = kwargs.get("worker_count", WORKER_COUNT)
        self._not_thread_safe = kwargs.get("not_thread_safe", NOT_THREAD_SAFE) or []
        self._parent_timeout = kwargs.get("parent_timeout", PARENT_TIMEOUT)
        self._worker_timeout = kwargs.get("worker_timeout", WORKER_TIMEOUT)
        self._database_names = self._get_database_names()

def _get_database_names(self):
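        """Map each database connection alias to its configured database name.

        Illustrative return value (the names are hypothetical):
        {"default": "geonode", "datastore": "geonode_data"}.
        """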
database_names = {}
for alias in connections:
connection = connections[alias]
database_name = connection.settings_dict["NAME"]
database_names[alias] = database_name
return database_names

    def run_tests(self, test_labels, **kwargs):
        return self._run_tests(tests=test_labels)

    def _run_tests(self, tests, **kwargs):
        # tests = dict where the keys are test group names and the values are
        # the tests to run in each group
tests_queue = Queue()
results_queue = Queue()
stop_event = Event()
pending_tests = {}
pending_not_thread_safe_tests = {}
completed_tests = {}
failures = 0
errors = 0
start_time = time.time()
        # First run tests which are not thread safe in the main process
        for group in self._not_thread_safe:
            if group not in tests:
                continue
            group_tests = tests[group]
            del tests[group]
            logger.debug(f"Running tests in the main process: {group_tests}")
            pending_not_thread_safe_tests[group] = group_tests
            result = self._tests_func(tests=group_tests, worker_index=None)
            results_queue.put((group, result), block=False)
        for group, group_tests in tests.items():
            tests_queue.put((group, group_tests), block=False)
            pending_tests[group] = group_tests
        worker_count = self._worker_count
        if worker_count == "auto":
            worker_count = len(pending_tests)
        elif worker_count == "cpu":
            worker_count = multiprocessing.cpu_count()
        if worker_count > len(pending_tests):
            # No need to spawn more workers than there are test groups.
            worker_count = len(pending_tests)
        worker_max = self._worker_max
        if worker_max == "auto":
            worker_max = len(pending_tests)
        elif worker_max == "cpu":
            worker_max = multiprocessing.cpu_count()
        if worker_count > worker_max:
            # Never spawn more workers than the configured maximum.
            worker_count = worker_max
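        # Example (illustrative): with WORKER_COUNT = "auto", WORKER_MAX = 3 and
        # five pending test groups, worker_count first resolves to 5 and is then
        # capped to 3, so three worker processes are spawned.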
worker_args = (tests_queue, results_queue, stop_event)
logger.debug(f"Number of workers {worker_count} ")
workers = self._create_worker_pool(
pool_size=worker_count, target_func=self._run_tests_worker, worker_args=worker_args
)
for index, worker in enumerate(workers):
logger.debug(f"Staring worker {index}")
worker.start()
if workers:
while pending_tests:
try:
try:
group, result = results_queue.get(timeout=self._parent_timeout, block=True)
except Exception:
raise Empty
try:
if group not in pending_not_thread_safe_tests:
pending_tests.pop(group)
else:
pending_not_thread_safe_tests.pop(group)
except KeyError:
logger.debug(f"Got a result for unknown group: {group}")
else:
completed_tests[group] = result
self._print_result(result)
if result.failures or result.errors:
failures += len(result.failures)
errors += len(result.errors)
if self.failfast:
# failfast is enabled, kill all the active workers
# and stop
for worker in workers:
if worker.is_alive():
worker.terminate()
break
except Empty:
worker_left = False
for worker in workers:
if worker.is_alive():
worker_left = True
break
if not worker_left:
break
        # We are done, signal all the workers to stop
stop_event.set()
end_time = time.time()
self._exit(start_time, end_time, failures, errors)
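
    # Worker protocol: each worker repeatedly pulls a (group, tests) pair from
    # tests_queue, runs the group via _tests_func and pushes (group, TestResult)
    # onto results_queue; the loop in _run_tests above drains results_queue until
    # every pending group has reported or all the workers have exited.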
def _run_tests_worker(self, index, tests_queue, results_queue, stop_event):
def pop_item():
group, tests = tests_queue.get(timeout=self._worker_timeout)
return group, tests
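
        # iter(pop_item, None) keeps calling pop_item() until it returns None,
        # which it never does; the loop actually ends when tests_queue.get()
        # raises queue.Empty after self._worker_timeout seconds, or when
        # stop_event has been set by the parent process.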
try:
try:
for group, tests in iter(pop_item, None):
if stop_event.is_set():
# We should stop
break
try:
logger.debug(f"Worker {index} is running tests {tests}")
result = self._tests_func(tests=tests, worker_index=index)
results_queue.put((group, result))
logger.debug(f"Worker {index} has finished running tests {tests}")
except (KeyboardInterrupt, SystemExit):
if isinstance(self, TwistedParallelTestSuiteRunner):
# Twisted raises KeyboardInterrupt when the tests
# have completed
pass
else:
raise
except Exception as e:
logger.debug(f"Running tests failed, reason: {e}")
result = TestResult().from_exception(e)
results_queue.put((group, result))
except Empty:
logger.debug(f"Worker {index} timed out while waiting for tests to run")
finally:
tests_queue.close()
results_queue.close()
logger.debug(f"Worker {index} is stopping")
    def _pre_tests_func(self):
        # This method gets called before _tests_func is called
        pass

    def _post_tests_func(self):
        # This method gets called after _tests_func has completed and the
        # _print_result function has been called
        pass

    def _tests_func(self, tests, worker_index):
        raise NotImplementedError("_tests_func is not implemented; subclasses must override it")

    def _print_result(self, result):
        print(result.output, file=sys.stderr)

def _exit(self, start_time, end_time, failure_count, error_count):
time_difference = end_time - start_time
print(f"Total run time: {time_difference} seconds", file=sys.stderr)
try:
sys.exit(failure_count + error_count)
except Exception:
pass
    def _group_by_app(self, test_labels):
        """
        Group tests by app. This helps to partition tests so they can be run
        in separate worker processes.

        @TODO: Better partitioning of tests based on previous runs - measure the
        test suite run time and partition the tests so we can spawn as many
        workers as it makes sense to get the maximum performance benefit.
        """
        tests = {}
        for test_label in test_labels:
            if "." not in test_label:
                app = test_label
            else:
                parts = test_label.split(".")
                app = f"{parts[0]}.{parts[1]}"
            if not tests.get(app):
                tests[app] = [test_label]
            else:
                tests[app].append(test_label)
        return tests
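
    # Example (illustrative):
    #   _group_by_app(["geonode.layers.tests", "geonode.layers.tests.LayerTest", "geonode.maps.tests"])
    #   -> {"geonode.layers": ["geonode.layers.tests", "geonode.layers.tests.LayerTest"],
    #       "geonode.maps": ["geonode.maps.tests"]}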

    def _group_by_file(self, test_names):
        tests = {}
        for test_name in test_names:
            tests[test_name] = test_name
        return tests

    def _create_worker_pool(self, pool_size, target_func, worker_args):
        workers = []
        for index in range(pool_size):
            args = (index,) + worker_args
            worker = Process(target=target_func, args=args)
            workers.append(worker)
        return workers
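
    # Note: the Process objects are only created here; _run_tests starts each
    # one explicitly after the pool has been built.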

    def log(self, string):
        if self.verbosity >= 3:
            print(string)

class DjangoParallelTestSuiteRunner(ParallelTestSuiteRunner, DiscoverRunner):
def __init__(
self,
pattern=None,
top_level=None,
verbosity=1,
interactive=True,
failfast=True,
keepdb=False,
reverse=False,
debug_mode=False,
debug_sql=False,
parallel=0,
tags=None,
exclude_tags=None,
**kwargs,
):
        self.pattern = pattern
        self.top_level = top_level
        self.verbosity = verbosity
        self.interactive = interactive
        self.failfast = failfast
        self.keepdb = keepdb
        self.reverse = reverse
        self.debug_mode = debug_mode
        self.debug_sql = debug_sql
        self.parallel = parallel
        self.tags = set(tags or [])
        self.exclude_tags = set(exclude_tags or [])
        # DiscoverRunner.build_suite() expects this attribute on recent Django
        # versions; no test name filtering is applied here.
        self.test_name_patterns = None
        self._keyboard_interrupt_intercepted = False
        self._worker_max = kwargs.get("worker_max", WORKER_MAX)
        self._worker_count = kwargs.get("worker_count", WORKER_COUNT)
        self._not_thread_safe = kwargs.get("not_thread_safe", NOT_THREAD_SAFE) or []
        self._parent_timeout = kwargs.get("parent_timeout", PARENT_TIMEOUT)
        self._worker_timeout = kwargs.get("worker_timeout", WORKER_TIMEOUT)
        self._database_names = self._get_database_names()

def run_tests(self, test_labels, extra_tests=None, **kwargs):
app_tests = self._group_by_app(test_labels)
return self._run_tests(tests=app_tests)

    def run_suite(self, suite, **kwargs):
        return DjangoParallelTestRunner(verbosity=self.verbosity, failfast=self.failfast).run_suite(suite)

    def setup_databases(self, **kwargs):
        return self.setup_databases_18(**kwargs)

def dependency_ordered(self, test_databases, dependencies):
"""
Reorder test_databases into an order that honors the dependencies
described in TEST[DEPENDENCIES].
"""
ordered_test_databases = []
resolved_databases = set()
# Maps db signature to dependencies of all its aliases
dependencies_map = {}
# Check that no database depends on its own alias
for sig, (_, aliases) in test_databases:
all_deps = set()
for alias in aliases:
all_deps.update(dependencies.get(alias, []))
if not all_deps.isdisjoint(aliases):
raise ImproperlyConfigured(
f"Circular dependency: databases {aliases!r} depend on each other, but are aliases."
)
dependencies_map[sig] = all_deps
while test_databases:
changed = False
deferred = []
# Try to find a DB that has all its dependencies met
for signature, (db_name, aliases) in test_databases:
if dependencies_map[signature].issubset(resolved_databases):
resolved_databases.update(aliases)
ordered_test_databases.append((signature, (db_name, aliases)))
changed = True
else:
deferred.append((signature, (db_name, aliases)))
if not changed:
raise ImproperlyConfigured("Circular dependency in TEST[DEPENDENCIES]")
test_databases = deferred
return ordered_test_databases
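
    # Example (illustrative): with test_databases = [("sig_a", ("geonode", ["default"])),
    # ("sig_b", ("uploads", ["uploaded"]))] and dependencies = {"uploaded": ["default"]},
    # dependency_ordered() returns the "default" entry first, so its test database
    # exists before the dependent "uploaded" alias is set up.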
def setup_databases_18(self, **kwargs):
mirrored_aliases = {}
test_databases = {}
dependencies = {}
worker_index = kwargs.get("worker_index", None)
for alias in connections:
connection = connections[alias]
database_name = f"test_{worker_index}_{connection.settings_dict['NAME']}"
connection.settings_dict["TEST_NAME"] = database_name
item = test_databases.setdefault(
connection.creation.test_db_signature(), (connection.settings_dict["NAME"], [])
)
item[1].append(alias)
if alias != DEFAULT_DB_ALIAS:
dependencies[alias] = connection.settings_dict.get("TEST_DEPENDENCIES", [DEFAULT_DB_ALIAS])
old_names = []
mirrors = []
for signature, (db_name, aliases) in self.dependency_ordered(list(test_databases.items()), dependencies):
connection = connections[aliases[0]]
old_names.append((connection, db_name, True))
test_db_name = connection.creation.create_test_db(verbosity=0, autoclobber=not self.interactive)
for alias in aliases[1:]:
connection = connections[alias]
if db_name:
old_names.append((connection, db_name, False))
connection.settings_dict["NAME"] = test_db_name
else:
old_names.append((connection, db_name, True))
connection.creation.create_test_db(verbosity=0, autoclobber=not self.interactive)
for alias, mirror_alias in mirrored_aliases.items():
mirrors.append((alias, connections[alias].settings_dict["NAME"]))
connections[alias].settings_dict["NAME"] = connections[mirror_alias].settings_dict["NAME"]
return old_names, mirrors
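
    # Each worker gets its own set of test databases: with worker_index=0 and a
    # configured NAME of "geonode" (hypothetical), the per-worker TEST_NAME is
    # set to "test_0_geonode", so parallel workers do not share database state.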
def _tests_func(self, tests, worker_index):
self.setup_test_environment()
suite = self.build_suite(tests, [])
old_config = self.setup_databases(worker_index=worker_index)
result = self.run_suite(suite)
self.teardown_databases(old_config)
self.teardown_test_environment()
result = TestResult().from_django_result(result)
return result
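
    # Per-worker lifecycle: set up the test environment, build the suite for the
    # given group, create this worker's databases, run the suite, tear everything
    # down again and wrap the outcome in a picklable TestResult so it can cross
    # the process boundary via the results queue.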

class DjangoParallelTestRunner(DiscoverRunner):
    def __init__(self, verbosity=2, failfast=True, **kwargs):
        stream = BufferWritesDevice()
        super().__init__(stream=stream, verbosity=verbosity, failfast=failfast)

class TwistedParallelTestSuiteRunner(ParallelTestSuiteRunner):
    def __init__(self, config, verbosity=1, interactive=False, failfast=True, **kwargs):
        self.config = config
        super().__init__(verbosity=verbosity, interactive=interactive, failfast=failfast, **kwargs)

    def run_tests(self, test_labels, extra_tests=None, **kwargs):
        app_tests = self._group_by_app(test_labels)
        return self._run_tests(tests=app_tests)

def run_suite(self):
# config = self.config
tests = self.config.opts["tests"]
tests = self._group_by_file(tests)
self._run_tests(tests=tests)
def _tests_func(self, tests, worker_index):
if not isinstance(tests, (list, set)):
tests = [tests]
args = ["-e"]
args.extend(tests)
config = Options()
config.parseOptions(args)
stream = BufferWritesDevice()
runner = self._make_runner(config=config, stream=stream)
suite = _getSuite(config)
result = setup_test_db(worker_index, None, runner.run, suite)
result = TestResult().from_trial_result(result)
return result
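
    # The "-e" switch is trial's short flag for --rterrors (print tracebacks as
    # soon as they occur); the remaining arguments are the test names for this
    # group.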
def _make_runner(self, config, stream):
# Based on twisted.scripts.trial._makeRunner
mode = None
if config["debug"]:
mode = TrialRunner.DEBUG
if config["dry-run"]:
mode = TrialRunner.DRY_RUN
return TrialRunner(
config["reporter"],
mode=mode,
stream=stream,
profile=config["profile"],
logfile=config["logfile"],
tracebackFormat=config["tbformat"],
realTimeErrors=config["rterrors"],
uncleanWarnings=config["unclean-warnings"],
workingDirectory=config["temp-directory"],
forceGarbageCollection=config["force-gc"],
)

class TestResult:
    # Picklable snapshot of a test run. Defaults are provided for every
    # attribute so the parent process can inspect and print a result no matter
    # which of the factory methods below produced it.
    dots = False
    errors = None
    failures = None
    output = ""
    exception = None

    def from_django_result(self, result_obj):
self.dots = result_obj.dots
self.errors = result_obj.errors
self.failures = self._format_failures(result_obj.failures)
try:
self.output = result_obj.stream.read()
except Exception:
pass
return self
def from_trial_result(self, result_obj):
self.errors = self._format_failures(result_obj.errors)
self.failures = self._format_failures(result_obj.failures)
try:
self.output = result_obj.stream.read()
except Exception:
pass
return self
def from_exception(self, exception):
self.exception = str(exception)
return self
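
    def _format_failures(self, failures):
        # Minimal sketch of the helper assumed by from_django_result() and
        # from_trial_result() above: unittest/trial failure entries pair a test
        # object with a traceback string, and the test objects are not picklable,
        # so stringify them before the result crosses the multiprocessing queue.
        # This is an assumed implementation, not necessarily the original one.
        if not failures:
            return failures
        formatted = []
        for test, traceback_text in failures:
            formatted.append((str(test), traceback_text))
        return formatted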

class NullTimeKeeper:
    @contextmanager
    def timed(self, name):
        yield

    def print_results(self):
        pass

class TimeKeeper:
    def __init__(self):
        self.records = collections.defaultdict(list)

    @contextmanager
    def timed(self, name):
        self.records[name]
        start_time = time.perf_counter()
        try:
            yield
        finally:
            elapsed = time.perf_counter() - start_time
            self.records[name].append(elapsed)

def print_results(self):
for name, end_times in self.records.items():
for record_time in end_times:
record = f"{name} took {record_time:.3f}s"
sys.stderr.write(record + os.linesep)
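
# Illustrative TimeKeeper usage (hypothetical label):
#
#   keeper = TimeKeeper()
#   with keeper.timed("setup_databases"):
#       ...  # timed block
#   keeper.print_results()  # e.g. "setup_databases took 1.234s"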