A Brief Introduction to the MMDetection Framework: The Registration Mechanism in Detail

Posted by dinsdale on Mon, 08 Nov 2021 18:59:23 +0100

The previous post introduced configuration files in MMDetection. It explained that once the model, dataset, training schedule and so on are set in a configuration file, the Config class manages those parameters as a dictionary, and MMDetection parses them automatically to build the whole algorithm pipeline. MMDetection relies on a registration mechanism to turn configuration parameters into algorithm modules. This post walks through the source code and explains the registration mechanism in MMCV in detail.

1. Registry

The registration mechanism is a central concept in MMCV. If you want to add your own algorithm module or pipeline step to MMCV, you register it through this mechanism.

1.1 Registry Class

Before getting into the registration mechanism itself, let's introduce the Registry class.

MMCV uses registries to manage groups of modules with similar functionality. For example, ResNet, FPN and RoIHead are model components, while SGD and Adam are optimizers. Internally, a registry maintains a lookup table whose keys are strings and whose values are classes.
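At its core, such a lookup table is just a Python dict. The sketch below is a plain-Python illustration of the idea, not MMCV code; ResNet and FPN here are placeholder classes rather than the real models.

```python
# A plain dict captures the core idea of a registry's lookup table:
# string keys map to classes, and a class is instantiated after lookup.
# ResNet/FPN are placeholder classes, not the real mmdet models.


class ResNet:
    def __init__(self, depth):
        self.depth = depth


class FPN:
    def __init__(self, in_channel):
        self.in_channel = in_channel


lookup_table = {'ResNet': ResNet, 'FPN': FPN}

# Look up the class by its string name, then instantiate it
cls = lookup_table['ResNet']
backbone = cls(depth=50)
print(type(backbone).__name__, backbone.depth)  # ResNet 50
```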

Simply put, a registry can be seen as a mapping from strings to classes. With a registry, the user can look up a class by its string name and instantiate it. With this in mind, the source code of the Registry class is easy to follow. First, the constructor sets the registry's name and build function and initializes the dictionary-type lookup table _module_dict:

# Excerpt from mmcv/utils/registry.py (MMCV 1.x); user code imports it with
# from mmcv.utils import Registry

class Registry:
    # Constructor
    def __init__(self, name, build_func=None, parent=None, scope=None):
        """
        name (str): Name of the registry
        build_func (callable): Function used to build instances from this registry
        parent (Registry): Parent registry
        scope (str): Scope (namespace) of the registry
        """
        self._name = name
        # _module_dict stores the string-to-class mapping
        self._module_dict = dict()
        self._children = dict()
        # If scope is not given, infer it from the package in which the registry is defined, e.g. mmdet, mmseg
        self._scope = self.infer_scope() if scope is None else scope

        # build_func is initialized with the following priority:
        # 1. build_func: use the explicitly specified function first
        # 2. parent.build_func: otherwise inherit the parent registry's build_func
        # 3. build_from_cfg: by default, instantiate objects from a config dict
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
            
        # Set parent-child affiliation
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None
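To see the build_func priority and the parent-child link in isolation, here is a minimal sketch using a hypothetical MiniRegistry class. It is not MMCV's implementation: scope inference is omitted, the hypothetical default_build stands in for build_from_cfg, and children are keyed by name for simplicity.

```python
# Hypothetical MiniRegistry: a simplified stand-in for mmcv's Registry that
# reproduces only the build_func priority and the parent/child link.


def default_build(cfg, registry):
    """Stand-in for build_from_cfg: look up cfg['type'] and instantiate it."""
    args = dict(cfg)
    cls = registry.module_dict[args.pop('type')]
    return cls(**args)


class MiniRegistry:
    def __init__(self, name, build_func=None, parent=None):
        self.name = name
        self.module_dict = {}
        self.children = {}
        # Priority: explicit build_func > parent's build_func > default_build
        if build_func is not None:
            self.build_func = build_func
        elif parent is not None:
            self.build_func = parent.build_func
        else:
            self.build_func = default_build
        # Link child to parent (keyed by name here for simplicity)
        self.parent = parent
        if parent is not None:
            parent.children[name] = self

    def build(self, cfg):
        return self.build_func(cfg, registry=self)


def logging_build(cfg, registry):
    print(f'building {cfg["type"]} from {registry.name}')
    return default_build(cfg, registry)


class Dummy:
    def __init__(self, x):
        self.x = x


root = MiniRegistry('root', build_func=logging_build)
child = MiniRegistry('child', parent=root)  # inherits logging_build
plain = MiniRegistry('plain')               # falls back to default_build

child.module_dict['Dummy'] = Dummy
obj = child.build({'type': 'Dummy', 'x': 3})  # prints: building Dummy from child
```

The child never names a build function itself, yet it builds through the parent's logging_build; this is the same inheritance that lets a child registry reuse its parent's build logic.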

For example, suppose we want to use a registry to manage our models. We first create a Registry instance MODELS, then call its register_module() method to register the ResNet and FPN classes. Printing MODELS shows both classes (the items in the printed output correspond to self._module_dict), which confirms that the registration succeeded. For brevity, the recommended way is to call register_module() through Python's decorator syntax, @MODELS.register_module(). The models can then be instantiated with the build() function.

from mmcv.utils import Registry

# Instantiate a registry to manage models
MODELS = Registry('myModels')

# Option 1: register with a decorator when the class is defined (recommended)
@MODELS.register_module()
class ResNet(object):
    def __init__(self, depth):
        self.depth = depth
        print('Initialize ResNet{}'.format(depth))

# Option 2: define the class first, then call register_module() explicitly (not recommended)
class FPN(object):
    def __init__(self, in_channel):
        self.in_channel = in_channel
        print('Initialize FPN{}'.format(in_channel))
MODELS.register_module(name='FPN', module=FPN)

print(MODELS)
""" Print results as:
Registry(name=myModels, items={'ResNet': <class '__main__.ResNet'>, 'FPN': <class '__main__.FPN'>})
"""

# Configuration parameters; in practice cfg usually comes from a config file
backbone_cfg = dict(type='ResNet', depth=101)
neck_cfg = dict(type='FPN', in_channel=256)
# Build the models: 'type' selects the class, the remaining keys go to its constructor
my_backbone = MODELS.build(backbone_cfg)
my_neck = MODELS.build(neck_cfg)
print(my_backbone, my_neck)
""" Print results as:
Initialize ResNet101
Initialize FPN256
<__main__.ResNet object at 0x000001E68E99E198> <__main__.FPN object at 0x000001E695044B38>
"""

1.2 register_module and build functions

After a Registry object has been created, classes are registered and instantiated with the register_module() and build() functions, so let's look at the source code of these two functions.

Internally, register_module() calls self._register_module(). This function is also simple: it stores the name and class of the registered module in the _module_dict lookup table as a key->value pair.

def _register_module(self, module_class, module_name=None, force=False):
    """
    module_class (class): Class to register
    module_name (str | list[str]): Name(s) to register it under
    force (bool): Whether to overwrite an existing module with the same name
    """
    if not inspect.isclass(module_class):
        raise TypeError('module must be a class, '
                        f'but got {type(module_class)}')

    # Use the class name if no module name is given
    if module_name is None:
        module_name = module_class.__name__
    # module_name may be a list, so one class can be registered under several names
    if isinstance(module_name, str):
        module_name = [module_name]
    for name in module_name:
        # With force=False, registering the same name twice raises an error;
        # with force=True, the latest registration overwrites the previous one
        if not force and name in self._module_dict:
            raise KeyError(f'{name} is already registered '
                           f'in {self.name}')
        # Add the registered module to the lookup table
        self._module_dict[name] = module_class
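The force flag and the list form of module_name are easiest to see by running the same logic on a plain dict. The register() function below is a hypothetical standalone copy of _register_module(), not an MMCV API:

```python
import inspect

# Hypothetical standalone copy of the _register_module() logic, operating on
# a plain dict so its force/alias behaviour can be tried without mmcv.


def register(table, module_class, module_name=None, force=False):
    if not inspect.isclass(module_class):
        raise TypeError(f'module must be a class, but got {type(module_class)}')
    if module_name is None:
        module_name = module_class.__name__
    if isinstance(module_name, str):
        module_name = [module_name]
    for name in module_name:
        if not force and name in table:
            raise KeyError(f'{name} is already registered')
        table[name] = module_class


class ResNet:
    pass


class ResNetV2:
    pass


table = {}
register(table, ResNet)                       # key defaults to 'ResNet'
register(table, ResNet, ['R50', 'ResNet50'])  # one class, two aliases

try:
    register(table, ResNetV2, 'ResNet')       # duplicate name with force=False
    duplicate_rejected = False
except KeyError as e:
    duplicate_rejected = True
    print('KeyError:', e)

register(table, ResNetV2, 'ResNet', force=True)  # overwrite the old entry
print(table['ResNet'].__name__)                  # ResNetV2
```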

The build() function calls self.build_func() (see the Registry constructor). A custom build_func can be specified when the registry is created, but most registries don't specify one, so build() usually ends up calling build_from_cfg(), either directly or through a build_func inherited from the parent registry. build_from_cfg() looks up the class obj_cls that corresponds to the type value in the configuration, instantiates it with the remaining parameters in cfg and default_args, and returns the instantiated object to the calling build() function.

def build_from_cfg(cfg, registry, default_args=None):
    """
    cfg (dict): Configuration parameters
    registry (Registry): Registry to look the module up in
    default_args (dict, optional): Default arguments for building the object
    """
    # cfg must be a dict
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    # The type field is required, either in cfg or in default_args
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    # registry must be a Registry instance
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    # default_args must be a dict or None
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()

    # Merge default_args into args; keys already present in cfg take precedence
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    # Pop the module name (or the class itself) from args
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        # Look up the class by its name
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        # The type value is already the class itself
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
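Because of args.setdefault(), a key present in cfg always wins over default_args, and cfg.copy() keeps the caller's config intact. The sketch below is a trimmed-down, hypothetical build_from_dict() that uses a plain dict as the registry and drops the type checks; ToyHead is a placeholder class.

```python
# Hypothetical, trimmed-down copy of build_from_cfg() using a plain dict as
# the registry, to show how cfg, default_args and 'type' interact.
# The type checks of the real function are omitted.


def build_from_dict(cfg, registry, default_args=None):
    args = cfg.copy()                     # the caller's cfg is not modified
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)  # values in cfg win over defaults
    obj_type = args.pop('type')
    # 'type' may be a registered name or the class itself
    obj_cls = registry[obj_type] if isinstance(obj_type, str) else obj_type
    return obj_cls(**args)


class ToyHead:
    def __init__(self, num_classes, in_channels=256):
        self.num_classes = num_classes
        self.in_channels = in_channels


REGISTRY = {'ToyHead': ToyHead}
cfg = dict(type='ToyHead', num_classes=80)

head = build_from_dict(cfg, REGISTRY,
                       default_args=dict(num_classes=20, in_channels=512))
print(head.num_classes, head.in_channels)  # 80 512
print(cfg)  # {'type': 'ToyHead', 'num_classes': 80}

head2 = build_from_dict(dict(type=ToyHead, num_classes=3), REGISTRY)
print(head2.num_classes)  # 3
```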

2. Registration mechanism in MMDetection

As the previous section showed, MMCV uses registries to manage modules with similar functionality. Each registry maintains a lookup table, and modules registered with it are stored in that table as key-value pairs. The registry also provides a build method that returns an instantiated object given the module name.

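MMDetection defines many commonly used registries together with matching helper functions, e.g. DETECTORS with build_detector() and DATASETS with build_dataset(). Whatever the build_xxx() helper, it ultimately instantiates the module through the registry, via Registry.build() or build_from_cfg(). As the code below shows, BACKBONES, NECKS, HEADS, DETECTORS and the other model registries are all aliases of a single MODELS registry. Most of the time we only need to use an existing registry.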

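# Registries in MMDetection 2.x (mmdet/models/builder.py, mmdet/datasets/builder.py)
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry

MODELS = Registry('models', parent=MMCV_MODELS)  # child of MMCV's MODELS registry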
BACKBONES = MODELS
NECKS = MODELS
ROI_EXTRACTORS = MODELS
SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS
DETECTORS = MODELS
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
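Since BACKBONES, DETECTORS and the other model registries are the same object, a class registered under one name can be built through any other. A build_xxx() helper then only needs to add default arguments before delegating to the registry. The sketch below illustrates both points with a hypothetical SketchRegistry and build_detector_sketch(). It is modeled on the idea behind MMDetection's build_detector(), not copied from it.

```python
# Hypothetical sketch: aliased registries plus a build_xxx()-style helper.
# SketchRegistry and build_detector_sketch are illustrative names only.


class SketchRegistry:
    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        def _register(cls):
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg, default_args=None):
        args = dict(cfg)
        for key, value in (default_args or {}).items():
            args.setdefault(key, value)
        return self._module_dict[args.pop('type')](**args)


MODELS = SketchRegistry('models')
BACKBONES = MODELS  # aliases: the very same registry object
DETECTORS = MODELS


@DETECTORS.register_module()
class ToyDetector:
    def __init__(self, backbone, train_cfg=None, test_cfg=None):
        self.backbone = backbone
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg


def build_detector_sketch(cfg, train_cfg=None, test_cfg=None):
    # The helper only supplies default_args, then defers to the registry
    return DETECTORS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))


print(BACKBONES is DETECTORS)  # True
det = build_detector_sketch(dict(type='ToyDetector', backbone='resnet50'),
                            test_cfg=dict(score_thr=0.05))
print(det.test_cfg)  # {'score_thr': 0.05}
```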
