Source code for geonode.tests.suite.runner

from contextlib import contextmanager
import faulthandler
import io
import os
import sys
import time
import logging
import multiprocessing

from multiprocessing import Process, Queue, Event
from queue import Empty
from typing import Collection

from twisted.scripts.trial import Options, _getSuite
from twisted.trial.runner import TrialRunner

from django.conf import settings

from django.test.runner import DiscoverRunner
from django.db import connections, DEFAULT_DB_ALIAS
from django.core.exceptions import ImproperlyConfigured

from .base import setup_test_db

# "auto" - one worker per Django application
# "cpu" - one worker per process core
# Runner tunables, each overridable from the Django settings module.
# WORKER_MAX / WORKER_COUNT accept an integer, "auto" or "cpu".
WORKER_MAX = getattr(settings, "TEST_RUNNER_WORKER_MAX", 3)
WORKER_COUNT = getattr(settings, "TEST_RUNNER_WORKER_COUNT", "auto")
# Names of test groups that must run in the main process (None means "no such groups").
NOT_THREAD_SAFE = getattr(settings, "TEST_RUNNER_NOT_THREAD_SAFE", None)
# Timeouts (seconds) used when the parent waits for results / a worker waits for tests.
PARENT_TIMEOUT = getattr(settings, "TEST_RUNNER_PARENT_TIMEOUT", 10)
WORKER_TIMEOUT = getattr(settings, "TEST_RUNNER_WORKER_TIMEOUT", 10)

# amqplib spits out a lot of log messages which just add a lot of noise.
logger = logging.getLogger(__name__)

# Sink used to silence stdout. ``os.devnull`` resolves to the platform's null
# device, so the module also imports on systems without a literal "/dev/null".
# The handle is intentionally kept open for the lifetime of the process.
null_file = open(os.devnull, "w")
class GeoNodeBaseSuiteDiscoverRunner(DiscoverRunner):
    """``DiscoverRunner`` subclass that stores its options itself.

    The constructor assigns every option directly on the instance and does
    not delegate to ``DiscoverRunner.__init__``.
    """

    def __init__(
        self,
        pattern=None,
        top_level=None,
        verbosity=1,
        interactive=True,
        failfast=True,
        keepdb=False,
        reverse=False,
        debug_mode=False,
        debug_sql=False,
        parallel=0,
        tags=None,
        exclude_tags=None,
        test_name_patterns=None,
        pdb=False,
        buffer=False,
        enable_faulthandler=True,
        timing=False,
        **kwargs,
    ):
        """Record the runner options and validate incompatible combinations.

        Raises:
            ValueError: if ``pdb`` or ``buffer`` is requested together with
                ``parallel`` greater than 1.
        """
        # Plain option storage.
        self.pattern = pattern
        self.top_level = top_level
        self.verbosity = verbosity
        self.interactive = interactive
        self.failfast = failfast
        self.keepdb = keepdb
        self.reverse = reverse
        self.debug_mode = debug_mode
        self.debug_sql = debug_sql
        self.parallel = parallel
        self.tags = set(tags or [])
        self.exclude_tags = set(exclude_tags or [])

        # Turn on faulthandler once, falling back to the original stderr when
        # the current one has no usable file descriptor.
        if enable_faulthandler and not faulthandler.is_enabled():
            try:
                faulthandler.enable(file=sys.stderr.fileno())
            except (AttributeError, io.UnsupportedOperation):
                faulthandler.enable(file=sys.__stderr__.fileno())

        self.pdb = pdb
        if self.pdb and self.parallel > 1:
            raise ValueError("You cannot use --pdb with parallel tests; pass --parallel=1 to use it.")

        self.buffer = buffer
        if self.buffer and self.parallel > 1:
            raise ValueError("You cannot use -b/--buffer with parallel tests; pass " "--parallel=1 to use it.")

        self.test_name_patterns = None
        if timing:
            self.time_keeper = TimeKeeper()
        else:
            self.time_keeper = NullTimeKeeper()

        if test_name_patterns:
            # unittest does not export the _convert_select_pattern function
            # that converts command-line arguments to patterns.
            converted = set()
            for name_pattern in test_name_patterns:
                if "*" in name_pattern:
                    converted.add(name_pattern)
                else:
                    converted.add(f"*{name_pattern}*")
            self.test_name_patterns = converted
class BufferWritesDevice:
    """Minimal in-memory stream: collects written strings for later reading."""

    def __init__(self):
        # Written chunks, kept in the order they were received.
        self._data = list()

    def write(self, string):
        """Store ``string`` at the end of the buffer."""
        chunks = self._data
        chunks.append(string)

    def read(self):
        """Return everything written so far as one string (buffer is kept)."""
        separator = ""
        return separator.join(self._data)

    def flush(self, *args, **kwargs):
        """Accept and ignore any arguments; there is nothing to flush."""
        return None

    def isatty(self):
        """Report that this stream is not a terminal."""
        return False
# Redirect stdout to /dev/null because we don't want to see all the repeated # "database creation" logging statements from all the workers. # All the test output is printed to stderr to this is not problematic. sys.stdout = null_file
class ParallelTestSuiteRunner:
    """Run groups of tests in parallel worker processes.

    Test groups are pushed on a ``multiprocessing`` queue, consumed by worker
    processes that call :meth:`_tests_func`, and the results are sent back to
    the parent over a second queue. Groups listed in ``not_thread_safe`` are
    run in the parent process before the workers start.

    Subclasses must implement :meth:`_tests_func`.
    """

    def __init__(
        self,
        pattern=None,
        top_level=None,
        verbosity=1,
        interactive=True,
        failfast=True,
        keepdb=False,
        reverse=False,
        debug_mode=False,
        debug_sql=False,
        parallel=0,
        tags=None,
        exclude_tags=None,
        **kwargs,
    ):
        """Store the runner options.

        Recognised extra keyword arguments (each defaults to the matching
        module-level constant): ``worker_max``, ``worker_count``,
        ``not_thread_safe``, ``parent_timeout`` and ``worker_timeout``.
        """
        self.pattern = pattern
        self.top_level = top_level
        self.verbosity = verbosity
        self.interactive = interactive
        self.failfast = failfast
        self.keepdb = keepdb
        self.reverse = reverse
        self.debug_mode = debug_mode
        self.debug_sql = debug_sql
        self.parallel = parallel
        self.tags = set(tags or [])
        self.exclude_tags = set(exclude_tags or [])
        self._keyboard_interrupt_intercepted = False
        self._worker_max = kwargs.get("worker_max", WORKER_MAX)
        self._worker_count = kwargs.get("worker_count", WORKER_COUNT)
        self._not_thread_safe = kwargs.get("not_thread_safe", NOT_THREAD_SAFE) or []
        self._parent_timeout = kwargs.get("parent_timeout", PARENT_TIMEOUT)
        self._worker_timeout = kwargs.get("worker_timeout", WORKER_TIMEOUT)
        self._database_names = self._get_database_names()

    def _get_database_names(self):
        """Return a dict mapping each connection alias to its configured ``NAME``."""
        return {alias: connections[alias].settings_dict["NAME"] for alias in connections}

    def run_tests(self, test_labels, **kwargs):
        """Run ``test_labels`` as-is; subclasses group the labels first."""
        return self._run_tests(tests=test_labels)

    @staticmethod
    def _resolve_worker_setting(value, group_count):
        """Translate the "auto"/"cpu" keywords of a worker setting into a number."""
        if value == "auto":
            return group_count
        if value == "cpu":
            return multiprocessing.cpu_count()
        return value

    def _run_tests(self, tests, **kwargs):
        """Run every test group and exit the process with the failure count.

        Args:
            tests: dict where the key is a test group name and the value are
                the tests to run. Groups that are run in the main process are
                removed from this dict.

        This method does not return normally: it finishes by calling
        :meth:`_exit`, which raises ``SystemExit``.
        """
        tests_queue = Queue()
        results_queue = Queue()
        stop_event = Event()

        pending_tests = {}
        pending_not_thread_safe_tests = {}
        failures = 0
        errors = 0

        start_time = time.time()

        # First run the tests which are not thread safe in the main process.
        for group in self._not_thread_safe:
            if group not in tests:
                continue

            group_tests = tests.pop(group)
            logger.debug("Running tests in a main process: %s", group_tests)
            pending_not_thread_safe_tests[group] = group_tests
            result = self._tests_func(tests=group_tests, worker_index=None)
            results_queue.put((group, result), block=False)

        # Everything else is handed to the workers.
        for group, group_tests in tests.items():
            tests_queue.put((group, group_tests), block=False)
            pending_tests[group] = group_tests

        group_count = len(pending_tests)
        worker_count = self._resolve_worker_setting(self._worker_count, group_count)
        # No need to spawn more workers than there are test groups.
        worker_count = min(worker_count, group_count)
        # Never exceed the configured upper bound either.
        worker_max = self._resolve_worker_setting(self._worker_max, group_count)
        worker_count = min(worker_count, worker_max)

        worker_args = (tests_queue, results_queue, stop_event)
        logger.debug("Number of workers %s ", worker_count)
        workers = self._create_worker_pool(
            pool_size=worker_count, target_func=self._run_tests_worker, worker_args=worker_args
        )

        for index, worker in enumerate(workers):
            logger.debug("Staring worker %s", index)
            worker.start()

        # Drain results until every group (worker-run or main-process-run) has
        # reported. Main-process results are drained even when no worker was
        # spawned; otherwise their failures would never be printed or counted.
        while pending_tests or pending_not_thread_safe_tests:
            try:
                group, result = results_queue.get(block=True, timeout=self._parent_timeout)
            except Exception:
                # Any failure to fetch a result is treated like a timeout:
                # keep waiting only while some worker can still deliver one.
                if not any(worker.is_alive() for worker in workers):
                    break
                continue

            if group in pending_not_thread_safe_tests:
                pending_not_thread_safe_tests.pop(group)
            elif group in pending_tests:
                pending_tests.pop(group)
            else:
                logger.debug("Got a result for unknown group: %s", group)
                continue

            self._print_result(result)

            # ``failures``/``errors`` may be None (e.g. a result built from an
            # exception), so guard the len() calls.
            failure_count = len(result.failures or ())
            error_count = len(result.errors or ())
            if getattr(result, "exception", None):
                # The group could not be run at all; count that as an error so
                # the exit status reflects it.
                error_count += 1

            if failure_count or error_count:
                failures += failure_count
                errors += error_count

                if self.failfast:
                    # failfast is enabled, kill all the active workers
                    # and stop
                    for worker in workers:
                        if worker.is_alive():
                            worker.terminate()
                    break

        # We are done, signalize all the workers to stop
        stop_event.set()
        end_time = time.time()
        self._exit(start_time, end_time, failures, errors)

    def _run_tests_worker(self, index, tests_queue, results_queue, stop_event):
        """Worker process body: run queued test groups until told to stop.

        Args:
            index: position of this worker in the pool; passed to
                :meth:`_tests_func` as ``worker_index``.
            tests_queue: queue of ``(group, tests)`` items to run.
            results_queue: queue on which ``(group, result)`` items are put.
            stop_event: event set by the parent when workers should stop.
        """

        def pop_item():
            # Raises queue.Empty when nothing arrives within the timeout.
            group, tests = tests_queue.get(timeout=self._worker_timeout)
            return group, tests

        try:
            try:
                # pop_item never returns the None sentinel, so this loop only
                # ends via ``break`` or the Empty timeout.
                for group, tests in iter(pop_item, None):
                    if stop_event.is_set():
                        # We should stop
                        break

                    try:
                        logger.debug("Worker %s is running tests %s", index, tests)
                        result = self._tests_func(tests=tests, worker_index=index)
                        results_queue.put((group, result))
                        logger.debug("Worker %s has finished running tests %s", index, tests)
                    except (KeyboardInterrupt, SystemExit):
                        # Twisted raises KeyboardInterrupt when the tests
                        # have completed; only re-raise for other runners.
                        if not isinstance(self, TwistedParallelTestSuiteRunner):
                            raise
                    except Exception as e:
                        logger.debug("Running tests failed, reason: %s", e)
                        result = TestResult().from_exception(e)
                        results_queue.put((group, result))
            except Empty:
                logger.debug("Worker %s timed out while waiting for tests to run", index)
        finally:
            tests_queue.close()
            results_queue.close()
            logger.debug("Worker %s is stopping", index)

    def _pre_tests_func(self):
        """Hook intended to run before ``_tests_func``; no-op by default."""
        pass

    def _post_tests_func(self):
        """Hook intended to run after ``_tests_func`` and ``_print_result``; no-op by default."""
        pass

    def _tests_func(self, tests, worker_index):
        """Run ``tests`` and return a result object; must be overridden.

        The signature matches how the runner invokes it
        (``tests=..., worker_index=...``) and how subclasses define it.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("_tests_func not implements")

    def _print_result(self, result):
        """Print the captured output of ``result`` (and its exception, if any) to stderr."""
        print(result.output, file=sys.stderr)
        exception = getattr(result, "exception", None)
        if exception:
            print(f"Test run raised an exception: {exception}", file=sys.stderr)

    def _exit(self, start_time, end_time, failure_count, error_count):
        """Print the total run time and exit with the number of problems.

        The exit status is capped at 255: larger values are truncated by the
        operating system, so e.g. 256 problems would otherwise be reported as
        a successful (0) exit.

        Raises:
            SystemExit: always.
        """
        time_difference = end_time - start_time
        print(f"Total run time: {time_difference} seconds", file=sys.stderr)
        sys.exit(min(failure_count + error_count, 255))

    def _group_by_app(self, test_labels):
        """
        Groups tests by an app. This helps to partition tests so they can be
        run in separate worker processes.

        The group key is the label itself when it contains no dot; otherwise
        it is the first two dotted components joined without a separator
        (``"geonode.base.tests"`` -> ``"geonodebase"``).

        @TODO: Better partitioning of tests based on the previous runs - measure
        test suite run time and partition tests so we can spawn as much workers
        as it makes sense to get the maximum performance benefits.
        """
        tests = {}
        for test_label in test_labels:
            parts = test_label.split(".")
            # A label without a dot has a single component; indexing parts[1]
            # would fail, so the label itself is the group.
            app = test_label if len(parts) < 2 else parts[0] + parts[1]
            tests.setdefault(app, []).append(test_label)
        return tests

    def _group_by_file(self, test_names):
        """Return a dict where every test name is its own single-item group."""
        return {test_name: test_name for test_name in test_names}

    def _create_worker_pool(self, pool_size, target_func, worker_args):
        """Create (but do not start) ``pool_size`` processes.

        Each process calls ``target_func(index, *worker_args)``.
        """
        return [Process(target=target_func, args=(index,) + worker_args) for index in range(pool_size)]

    def log(self, string):
        """Print ``string`` to stdout when verbosity is 3 or higher."""
        if self.verbosity >= 3:
            print(string)
class DjangoParallelTestSuiteRunner(ParallelTestSuiteRunner, DiscoverRunner):
    """Parallel runner that groups Django test labels by app.

    Every group is run through :meth:`_tests_func`, which creates test
    databases named after the worker index, runs the suite and tears the
    databases down again.
    """

    def __init__(
        self,
        pattern=None,
        top_level=None,
        verbosity=1,
        interactive=True,
        failfast=True,
        keepdb=False,
        reverse=False,
        debug_mode=False,
        debug_sql=False,
        parallel=0,
        tags=None,
        exclude_tags=None,
        **kwargs,
    ):
        """Store the runner options.

        Delegates to ``ParallelTestSuiteRunner.__init__`` (next in the MRO),
        which performs exactly the assignments this constructor used to
        duplicate. ``DiscoverRunner.__init__`` is not called.
        """
        super().__init__(
            pattern=pattern,
            top_level=top_level,
            verbosity=verbosity,
            interactive=interactive,
            failfast=failfast,
            keepdb=keepdb,
            reverse=reverse,
            debug_mode=debug_mode,
            debug_sql=debug_sql,
            parallel=parallel,
            tags=tags,
            exclude_tags=exclude_tags,
            **kwargs,
        )

    def run_tests(self, test_labels, extra_tests=None, **kwargs):
        """Group ``test_labels`` by app and run the groups in parallel."""
        app_tests = self._group_by_app(test_labels)
        return self._run_tests(tests=app_tests)

    def run_suite(self, suite, **kwargs):
        """Run ``suite`` with a ``DjangoParallelTestRunner`` and return its result."""
        return DjangoParallelTestRunner(verbosity=self.verbosity, failfast=self.failfast).run_suite(suite)

    def setup_databases(self, **kwargs):
        """Create the test databases; see :meth:`setup_databases_18`."""
        return self.setup_databases_18(**kwargs)

    def dependency_ordered(self, test_databases, dependencies):
        """
        Reorder test_databases into an order that honors the dependencies
        described in TEST[DEPENDENCIES].

        Args:
            test_databases: list of ``(signature, (db_name, aliases))`` items.
            dependencies: dict mapping an alias to the aliases it depends on.

        Raises:
            ImproperlyConfigured: on a circular dependency.
        """
        ordered_test_databases = []
        resolved_databases = set()

        # Maps db signature to dependencies of all its aliases
        dependencies_map = {}

        # Check that no database depends on its own alias
        for sig, (_, aliases) in test_databases:
            all_deps = set()
            for alias in aliases:
                all_deps.update(dependencies.get(alias, []))
            if not all_deps.isdisjoint(aliases):
                raise ImproperlyConfigured(
                    f"Circular dependency: databases {aliases!r} depend on each other, but are aliases."
                )
            dependencies_map[sig] = all_deps

        while test_databases:
            changed = False
            deferred = []

            # Try to find a DB that has all its dependencies met
            for signature, (db_name, aliases) in test_databases:
                if dependencies_map[signature].issubset(resolved_databases):
                    resolved_databases.update(aliases)
                    ordered_test_databases.append((signature, (db_name, aliases)))
                    changed = True
                else:
                    deferred.append((signature, (db_name, aliases)))

            # A full pass that resolves nothing means the rest can never resolve.
            if not changed:
                raise ImproperlyConfigured("Circular dependency in TEST[DEPENDENCIES]")
            test_databases = deferred
        return ordered_test_databases

    def setup_databases_18(self, **kwargs):
        """Create the test databases for one worker.

        Args:
            **kwargs: ``worker_index`` (optional) is embedded in the
                ``TEST_NAME`` written to each connection's settings.

        Returns:
            ``(old_names, mirrors)`` where ``old_names`` is a list of
            ``(connection, original_db_name, destroy)`` tuples and ``mirrors``
            a list of ``(alias, original_name)`` tuples.
        """
        # Never populated in this method, so the mirror loop below is
        # currently a no-op and ``mirrors`` is always empty.
        mirrored_aliases = {}
        test_databases = {}
        dependencies = {}

        worker_index = kwargs.get("worker_index", None)
        for alias in connections:
            connection = connections[alias]
            database_name = f"test_{worker_index}_{connection.settings_dict['NAME']}"
            connection.settings_dict["TEST_NAME"] = database_name

            # Aliases sharing a signature share one test database.
            item = test_databases.setdefault(
                connection.creation.test_db_signature(), (connection.settings_dict["NAME"], [])
            )
            item[1].append(alias)
            if alias != DEFAULT_DB_ALIAS:
                dependencies[alias] = connection.settings_dict.get("TEST_DEPENDENCIES", [DEFAULT_DB_ALIAS])

        old_names = []
        mirrors = []
        for _signature, (db_name, aliases) in self.dependency_ordered(list(test_databases.items()), dependencies):
            # Actually create the database for the first alias.
            connection = connections[aliases[0]]
            old_names.append((connection, db_name, True))
            test_db_name = connection.creation.create_test_db(verbosity=0, autoclobber=not self.interactive)

            for alias in aliases[1:]:
                connection = connections[alias]
                if db_name:
                    # Point the remaining aliases at the database just created.
                    old_names.append((connection, db_name, False))
                    connection.settings_dict["NAME"] = test_db_name
                else:
                    old_names.append((connection, db_name, True))
                    connection.creation.create_test_db(verbosity=0, autoclobber=not self.interactive)

        for alias, mirror_alias in mirrored_aliases.items():
            mirrors.append((alias, connections[alias].settings_dict["NAME"]))
            connections[alias].settings_dict["NAME"] = connections[mirror_alias].settings_dict["NAME"]

        return old_names, mirrors

    def _tests_func(self, tests, worker_index):
        """Run ``tests`` against freshly created databases and return a ``TestResult``.

        The databases and the test environment are torn down even when
        building or running the suite raises, so a crashing group does not
        leave test databases behind.
        """
        self.setup_test_environment()
        try:
            suite = self.build_suite(tests, [])
            old_config = self.setup_databases(worker_index=worker_index)
            try:
                result = self.run_suite(suite)
            finally:
                # NOTE(review): ``old_config`` is the (old_names, mirrors)
                # tuple returned above - confirm the installed Django's
                # ``teardown_databases`` accepts that shape.
                self.teardown_databases(old_config)
        finally:
            self.teardown_test_environment()

        return TestResult().from_django_result(result)
class DjangoParallelTestRunner(DiscoverRunner):
    """``DiscoverRunner`` constructed with an in-memory ``stream`` option."""

    def __init__(self, verbosity=2, failfast=True, **kwargs):
        """Initialise the parent runner; extra ``kwargs`` are accepted but not forwarded."""
        runner_options = {
            "stream": BufferWritesDevice(),
            "verbosity": verbosity,
            "failfast": failfast,
        }
        super().__init__(**runner_options)
class TwistedParallelTestSuiteRunner(ParallelTestSuiteRunner):
    """Parallel runner that executes test groups with Twisted trial."""

    def __init__(self, config, verbosity=1, interactive=False, failfast=True, **kwargs):
        """Store ``config`` and initialise the base runner.

        Args:
            config: trial options object; :meth:`run_suite` reads
                ``config.opts["tests"]`` from it.
            verbosity, interactive, failfast: forwarded to the base runner.
            **kwargs: forwarded to the base runner.
        """
        self.config = config
        # Pass by keyword: the base signature starts with ``pattern`` and
        # ``top_level``, so positional arguments would land on the wrong
        # parameters (verbosity -> pattern, failfast -> verbosity, ...).
        super().__init__(verbosity=verbosity, interactive=interactive, failfast=failfast, **kwargs)

    def run_tests(self, test_labels, extra_tests=None, **kwargs):
        """Group ``test_labels`` by app and run the groups in parallel."""
        app_tests = self._group_by_app(test_labels)
        return self._run_tests(tests=app_tests)

    def run_suite(self):
        """Run the tests named in the trial config, one group per test name."""
        tests = self.config.opts["tests"]
        tests = self._group_by_file(tests)
        self._run_tests(tests=tests)

    def _tests_func(self, tests, worker_index):
        """Run ``tests`` with a trial runner and return a ``TestResult``.

        ``tests`` may be a single test name or a list/set of names.
        """
        if not isinstance(tests, (list, set)):
            tests = [tests]

        # Build a fresh trial configuration for just these tests.
        args = ["-e"]
        args.extend(tests)

        config = Options()
        config.parseOptions(args)

        # Trial output is captured in memory so it can be returned to the parent.
        stream = BufferWritesDevice()
        runner = self._make_runner(config=config, stream=stream)
        suite = _getSuite(config)
        result = setup_test_db(worker_index, None, runner.run, suite)
        return TestResult().from_trial_result(result)

    def _make_runner(self, config, stream):
        """Build a ``TrialRunner`` from ``config`` that writes to ``stream``."""
        # Based on twisted.scripts.trial._makeRunner
        mode = None
        if config["debug"]:
            mode = TrialRunner.DEBUG
        # dry-run is checked last, so it wins when both options are set.
        if config["dry-run"]:
            mode = TrialRunner.DRY_RUN
        return TrialRunner(
            config["reporter"],
            mode=mode,
            stream=stream,
            profile=config["profile"],
            logfile=config["logfile"],
            tracebackFormat=config["tbformat"],
            realTimeErrors=config["rterrors"],
            uncleanWarnings=config["unclean-warnings"],
            workingDirectory=config["temp-directory"],
            forceGarbageCollection=config["force-gc"],
        )
class TestResult:
    """Plain summary of a test run.

    Instances are what the runners in this module put on the results queue,
    so failure and error entries are reduced to ``(str, message)`` tuples
    instead of carrying the original (non-picklable) objects.
    """

    # Defaults shared by all instances until one of the from_* methods sets them.
    dots = False
    errors = None
    failures = None
    exception = None
    output = None

    def from_django_result(self, result_obj):
        """Populate this object from a Django/unittest result and return ``self``."""
        self.dots = result_obj.dots
        # Errors are formatted exactly like failures; the raw entries hold
        # class instances which are not picklable.
        self.errors = self._format_failures(result_obj.errors)
        self.failures = self._format_failures(result_obj.failures)
        self._read_output(result_obj)
        return self

    def from_trial_result(self, result_obj):
        """Populate this object from a Twisted trial result and return ``self``."""
        self.errors = self._format_failures(result_obj.errors)
        self.failures = self._format_failures(result_obj.failures)
        self._read_output(result_obj)
        return self

    def from_exception(self, exception):
        """Record ``exception`` (as its string form) and return ``self``."""
        self.exception = str(exception)
        return self

    def _read_output(self, result_obj):
        """Copy the captured output from ``result_obj.stream``, if it can be read."""
        try:
            self.output = result_obj.stream.read()
        except Exception:
            # Best effort: not every result exposes a readable stream, in
            # which case ``output`` keeps its previous value.
            pass

    def _format_failures(self, failures):
        """Convert ``(object, message)`` pairs to ``(str(object), message)`` pairs.

        An empty or ``None`` value is returned unchanged.
        """
        # errors and failures attributes by default contain values which are not
        # picklable (class instance)
        if not failures:
            return failures

        return [(str(klass), message) for klass, message in failures]
class NullTimeKeeper:
    """Stand-in time keeper that measures and reports nothing."""

    @contextmanager
    def timed(self, name):
        """Context manager that runs the wrapped block without timing it."""
        yield None

    def print_results(self):
        """Do nothing; there are no results to print."""
        return None
class TimeKeeper:
    """Collect the durations of named operations and print them to stderr."""

    def __init__(self):
        # Imported here so the class does not depend on a module-level import;
        # ``typing.Collection`` (what the file imports) has no ``defaultdict``.
        from collections import defaultdict

        # Maps an operation name to the list of its measured durations (seconds).
        self.records = defaultdict(list)

    @contextmanager
    def timed(self, name):
        """Context manager recording how long the wrapped block took under ``name``.

        The duration is recorded even when the block raises.
        """
        # Touch the key first so ``name`` is registered when the block starts
        # rather than when it ends.
        self.records[name]
        start_time = time.perf_counter()
        try:
            yield
        finally:
            end_time = time.perf_counter() - start_time
            self.records[name].append(end_time)

    def print_results(self):
        """Write one ``"<name> took <seconds>s"`` line per record to stderr."""
        for name, end_times in self.records.items():
            for record_time in end_times:
                record = f"{name} took {record_time:.3f}s"
                sys.stderr.write(record + os.linesep)