"""
Custom test case class to replace unittest's TestCase class.
"""

# Revision keyword string; presumably expanded by the version-control system.
__version__ = "$Revision: 11265 $"

from unittest import TestCase as _TestCase

# Only the TestCase subclass below is public; the stdlib base class is kept
# private under its underscore alias.
__all__ = ["TestCase"]


class TestCase(_TestCase):

    """Custom test case class to replace unittest's TestCase class.

    This class is responsible for running methods marked with the @before
    and @after decorators.  While ``run``, ``debug`` or
    ``debug_with_teardown`` executes, the instance's ``setUp``/``tearDown``
    are replaced by wrappers that call the original hook and then every
    method carrying a ``_unittest_before`` / ``_unittest_after`` attribute
    (presumably set by those decorators, which are defined elsewhere).

    """

    # Names of the unittest hooks that are temporarily wrapped per run.
    _WRAPPED_HOOKS = ("setUp", "tearDown")

    def _install_custom_setup_and_teardown(self):
        """Point this instance's setUp/tearDown at the decorator-aware wrappers.

        The hooks being replaced are kept in ``_old_setup``/``_old_teardown``
        so the wrappers can chain to them.  Hooks that were already
        per-instance overrides are also remembered so that restoring puts
        them back verbatim.  Not reentrant: installing twice without an
        intervening restore would make the wrappers chain to themselves.
        """
        own_attrs = vars(self)
        self._saved_instance_hooks = dict(
            (name, own_attrs[name])
            for name in self._WRAPPED_HOOKS if name in own_attrs)
        self._old_setup = self.setUp
        self._old_teardown = self.tearDown
        self.setUp = self._my_setup
        self.tearDown = self._my_teardown

    def _restore_original_setup_and_teardown(self):
        """Undo ``_install_custom_setup_and_teardown``.

        Hooks that originally resolved through the class are removed from
        the instance dict instead of being re-assigned as bound methods.
        Re-assigning would leave ``self`` referencing itself through those
        bound methods, a reference cycle that keeps finished tests alive
        until the cyclic garbage collector runs.
        """
        saved = self._saved_instance_hooks
        own_attrs = vars(self)
        for name in self._WRAPPED_HOOKS:
            if name in saved:
                setattr(self, name, saved[name])
            else:
                own_attrs.pop(name, None)
        del self._old_setup, self._old_teardown, self._saved_instance_hooks

    def _run_tests(self, cb):
        """Call ``cb(self)`` with the wrapped hooks installed.

        Returns whatever ``cb`` returns.  The original hooks are restored
        even if ``cb`` raises.
        """
        self._install_custom_setup_and_teardown()
        try:
            return cb(self)
        finally:
            self._restore_original_setup_and_teardown()

    def run(self, result=None):
        """Run the test like unittest's ``run``, honouring @before/@after.

        The value returned by ``unittest.TestCase.run`` is passed through.
        On Python 3.3+ that is the TestResult, which ``__call__`` also
        returns.
        """
        return self._run_tests(lambda o: _TestCase.run(o, result))

    def debug(self):
        """Run the test via unittest's ``debug``, honouring @before/@after.

        Exceptions propagate to the caller.  Unlike ``debug_with_teardown``,
        teardown is not guaranteed to run when the test fails.
        """
        self._run_tests(_TestCase.debug)

    def debug_with_teardown(self):
        """Run the test without collecting errors in a TestResult.

        As opposed to debug, this ensures that tearDown (and with it the
        @after methods) is called even when the test method raises.  If
        setUp itself raises, tearDown is not called.

        """
        def run(test):
            test.setUp()
            try:
                getattr(test, test._testMethodName)()
            finally:
                test.tearDown()
        self._run_tests(run)

    def _my_setup(self):
        """Call the original setUp, then every @before method.

        The subclass-first list is reversed, so base-class methods run
        before subclass ones.  Within a single class, methods run in
        reverse ``vars(cls)`` order.
        """
        self._old_setup()
        for method in reversed(self._get_decorated_methods(
                "_unittest_before")):
            method(self)

    def _my_teardown(self):
        """Call the original tearDown, then every @after method.

        @after methods run subclass-first.  If one of them raises, the
        remaining ones are skipped.
        """
        self._old_teardown()
        for method in self._get_decorated_methods("_unittest_after"):
            method(self)

    def _get_decorated_methods(self, attrname):
        """Return all methods of this class that have a certain attribute.

        Methods are returned in subclass-first order (following the MRO).
        Within one class they keep the order of that class's ``__dict__``.
        The raw class-dict values are returned, so callers invoke them as
        ``method(self)``.

        """
        # A flat comprehension avoids the quadratic ``sum(lists, [])`` idiom.
        return [value
                for cls in type(self).mro()
                for value in vars(cls).values()
                if callable(value) and hasattr(value, attrname)]
