[Source code analysis] PyTorch distributed (6) -- DistributedDataParallel -- initialize & store


0x00 summary

This article is the sixth in the PyTorch distributed series. It introduces the initialization methods and the Store that DistributedDataParallel depends on.

Other articles in this series are as follows:

[ Source code analysis] PyTorch distributed (1) -- history and overview

[ Source code analysis] how PyTorch uses GPU

[ Source code analysis] PyTorch distributed (2) -- DataParallel (Part 1)

[ Source code analysis] PyTorch distributed (3) -- DataParallel (Part 2)

[ Source code analysis] PyTorch distributed (4) -- basic concepts of distributed applications

[ Source code analysis] PyTorch distributed (5) -- overview of DistributedDataParallel & how to use it

0x01 review

1.1 basic concepts

For distributed communication, PyTorch provides several concepts: process group, backend, initialization and Store.

  • Process group: DDP is true distributed training, able to use multiple machines to form a parallel computing task. To enable communication between DDP workers, PyTorch provides the concept of a process group.
  • Backend: the backend is a logical concept; in essence, a backend is an IPC communication mechanism.
  • Initialization: even with the concepts of backend and process group, how do workers find each other before the process group is established? This requires an initialization method that tells each process how to contact the processes on other machines.
  • Store: it can be considered a distributed key-value store, used to share information between processes in the group and to initialize the distributed package (by explicitly creating a store as an alternative to init_method).

1.2 initialization process group

You must call torch.distributed.init_process_group() before calling any other DDP method. This method initializes the default distributed process group and the distributed package, and blocks until all processes have joined. The function signature is as follows:

init_process_group(backend,
                   init_method=None,
                   timeout=default_pg_timeout,
                   world_size=-1,
                   rank=-1,
                   store=None,
                   group_name='',
                   pg_options=None)

There are two main ways to initialize a process group:

  1. Explicitly specify store, rank, and world_size.
  2. Specify init_method (a URL string) that indicates where / how peers are found.

If neither is specified, init_method is assumed to be "env://". So you can see that store and init_method are mutually exclusive.
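
As a minimal single-process sketch (the address and ports here are arbitrary choices, not anything PyTorch mandates), the two paths look like this:

# Sketch: the two mutually exclusive initialization paths.
import torch.distributed as dist
from datetime import timedelta

# Path 1: explicit store + rank + world_size.
store = dist.TCPStore("127.0.0.1", 29500, 1, True, timedelta(seconds=30))
dist.init_process_group("gloo", store=store, rank=0, world_size=1)
dist.destroy_process_group()

# Path 2: an init_method URL; rank/world_size are passed as arguments here.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
dist.destroy_process_group()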

init_process_group parameters are as follows:

  • backend – the backend to use. Valid values include mpi, gloo, and nccl. This field should be given as a lowercase string (e.g. "gloo"), which can also be accessed via the Backend attributes (e.g. Backend.GLOO). If multiple processes per machine are used with the nccl backend, each process must have exclusive access to every GPU it uses, because sharing GPUs between processes may cause deadlock.
  • init_method – URL specifying how to initialize the process group. If neither init_method nor store is specified, it defaults to "env://". Mutually exclusive with store.
  • world_size – number of processes participating in the job. Required if store is specified.
  • rank – rank of the current process (a number between 0 and world_size-1). Required if store is specified.
  • store – a key/value store accessible to all workers, used to exchange connection/address information. Mutually exclusive with init_method.
  • timeout – timeout for operations executed against the process group. The default value is 30 minutes. This applies to the gloo backend. For nccl, it applies only when the environment variable NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1.
  • group_name – group name.
  • pg_options (process group options, optional) – process group options specifying additional options to pass in during the construction of specific process groups.

0x02 initialization

2.1 initialization method

At present, the DDP module supports three initialization modes:

  • Environment variable initialization
  • Shared file-system initialization: init_method='file:///mnt/nfs/sharedfile'
  • TCP initialization: init_method='tcp://10.1.1.20:23456'

Environment variable

This method reads the configuration from environment variables and allows full customization of how the information is obtained. By setting the following four environment variables on all machines, all processes can connect to the master (that is, the rank 0 process) normally, obtain information about the other processes, and finally shake hands with them.

  • MASTER_PORT: port on the machine that hosts the rank 0 process.
  • MASTER_ADDR: IP address of the machine that hosts the rank 0 process.
  • WORLD_SIZE: total number of processes, so the master knows how many workers to wait for.
  • RANK: rank of each process, so each process knows whether it is the master.
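
As a minimal sketch, a single-process run of this mode might look as follows (the address and port are arbitrary; in real jobs a launcher such as torchrun exports these variables):

import os
import torch.distributed as dist

# These four variables would normally be set on every machine by the launcher.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"

dist.init_process_group("gloo")  # init_method defaults to "env://"
dist.destroy_process_group()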

Shared file system

The shared file-system approach requires that all processes have access to a shared file system, and coordinates them through a shared file. This means each process opens the file, writes its own information, and waits until every process has done so; after that, all required information is available to all processes. To avoid race conditions, the file system must support locking via fcntl.

dist.init_process_group(
    init_method='file:///mnt/nfs/sharedfile',
    rank=args.rank,
    world_size=4)

TCP

TCP initialization works by providing the IP address and port of the rank 0 process. Here, all workers can connect to the rank 0 process and exchange information on how to reach one another.

dist.init_process_group(
    init_method='tcp://10.1.1.20:23456',
    rank=args.rank,
    world_size=4)

2.2 init_method VS store

One may wonder: why are there both init_method and store?

By looking at the init_process_group code, we can find the following rules:

  • With the MPI backend, init_method is not used.

  • With non-MPI backends, if no store argument is given, init_method is used to build a store.

Therefore, everything ultimately comes down to the store; it is the entity that actually does the work.

        if store is None:
            rendezvous_iterator = rendezvous(
                init_method, rank, world_size, timeout=timeout
            )
            store, rank, world_size = next(rendezvous_iterator)
            store.set_timeout(timeout)

The init_process_group code is as follows:

def init_process_group(backend,
                       init_method=None,
                       timeout=default_pg_timeout,
                       world_size=-1,
                       rank=-1,
                       store=None,
                       group_name='',
                       pg_options=None):

    global _pg_group_ranks
    global _backend
    global _default_pg_init_method

    if store is not None:
        assert world_size > 0, 'world_size must be positive if using store'
        assert rank >= 0, 'rank must be non-negative if using store'
    elif init_method is None:
        init_method = "env://"

    backend = Backend(backend)

    if backend == Backend.MPI:
        default_pg = _new_process_group_helper(
            -1,
            -1,
            [],
            Backend.MPI,
            None,
            group_name=group_name,
            timeout=timeout)
        _update_default_pg(default_pg)
    else:
        # backward compatible API
        if store is None:
            # If there is no store, use init_method to build a store.
            rendezvous_iterator = rendezvous(
                init_method, rank, world_size, timeout=timeout
            )
            store, rank, world_size = next(rendezvous_iterator)
            store.set_timeout(timeout)

        default_pg = _new_process_group_helper(
            world_size,
            rank,
            [],
            backend,
            store,
            pg_options=pg_options,
            group_name=group_name,
            timeout=timeout)
        _update_default_pg(default_pg)

    _pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())}  # type: ignore[attr-defined, index]
    _backend = _pg_map[GroupMember.WORLD][0]  # type: ignore[index]
    _default_pg_init_method = init_method

    # ellipsis

2.3 rendezvous

rendezvous is mentioned in the above code. Let's take a look at this concept.

Before the collective algorithms can run, the participating processes need to find each other and exchange information in order to communicate. We call this process rendezvous. The result of the rendezvous procedure is a triple containing a shared key/value store, the rank of the process, and the total number of participating processes. If none of the built-in rendezvous methods is suitable for your execution environment, you can register your own rendezvous handler: choose a unique name for it and identify it with a URL scheme when calling the rendezvous function.

The rendezvous method selects a handler according to the URL scheme.

def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):

    # Append node-specific arguments.
    result = urlparse(url)
    if rank != -1 or world_size != -1:
        query_dict: Dict[str, Union[int, str]] = dict(
            # mypy doesn't allow dict() to accept List of values (#257)
            pair.split("=") for pair in filter(None, result.query.split("&"))  # type: ignore[arg-type, misc]
        )
        if rank != -1:
            query_dict["rank"] = rank
        if world_size != -1:
            query_dict["world_size"] = world_size

        result = result._replace(
            query="{}".format("&".join(["{}={}".format(k, v) for k, v in query_dict.items()]))
        )
        url = urlunparse(result)

    return _rendezvous_handlers[result.scheme](url, **kwargs)

The handlers are as follows. You will find that they correspond exactly to the three initialization methods:

register_rendezvous_handler("tcp", _tcp_rendezvous_handler)
register_rendezvous_handler("env", _env_rendezvous_handler)
register_rendezvous_handler("file", _file_rendezvous_handler)

2.4 summary

From the current analysis results, we have reached the following conclusions:

  • init_method ultimately comes down to the store; the store is the effective entity.
  • Participating processes need to find each other and exchange information before they can communicate. This process is called rendezvous.

0x03 Store

Let us give a formal definition. A Store is a distributed key-value store provided by the distributed package. All workers access this store to share information and to initialize the distributed package. Users can explicitly create a store as an alternative to init_method. There are currently three kinds of key-value stores: TCPStore, FileStore, and HashStore.
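
As a quick single-process sketch of the three variants (the file path is a throwaway temp file; TCPStore gets its own section 0x04 below):

import os
import tempfile
import torch.distributed as dist

hash_store = dist.HashStore()              # lives in-process, no server needed
hash_store.set("k", "v1")

file_path = os.path.join(tempfile.mkdtemp(), "store")
file_store = dist.FileStore(file_path, 1)  # coordinates through a shared file
file_store.set("k", "v2")

print(hash_store.get("k"), file_store.get("k"))  # b'v1' b'v2'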

We continue with the handler concept from the previous section.

3.1 _rendezvous_handlers

PyTorch defines a global variable _rendezvous_handlers, which records how to obtain a store for each URL scheme; it can be regarded as a factory registry.

_rendezvous_handlers = {}

The specific registration method is:

register_rendezvous_handler("tcp", _tcp_rendezvous_handler)
register_rendezvous_handler("env", _env_rendezvous_handler)
register_rendezvous_handler("file", _file_rendezvous_handler)

The registration code is as follows: insert a handler into the global variable.

def register_rendezvous_handler(scheme, handler):
    """Registers a new rendezvous handler.
    Args:
        scheme (str): URL scheme to identify your rendezvous handler.
        handler (function): Handler that is invoked when the
            `rendezvous()` function is called with a URL that uses
            the corresponding scheme. It must be a generator function
            that yields the triplet.
    """
    global _rendezvous_handlers
    if scheme in _rendezvous_handlers:
        raise RuntimeError(
            "Rendezvous handler for {}:// already registered".format(scheme)
        )
    _rendezvous_handlers[scheme] = handler
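
As the docstring above says, a handler must be a generator that yields the triplet. The following sketch registers a custom handler; the "demo" scheme is made up, and HashStore is only a stand-in (a real handler would have to return a store shared by all processes):

import torch.distributed as dist
from urllib.parse import urlparse, parse_qs

def _demo_rendezvous_handler(url, **kwargs):
    query = {k: v[0] for k, v in parse_qs(urlparse(url).query).items()}
    store = dist.HashStore()  # stand-in store, valid only within one process
    yield (store, int(query["rank"]), int(query["world_size"]))
    raise RuntimeError("Unable to perform re-rendezvous using demo:// method")

dist.register_rendezvous_handler("demo", _demo_rendezvous_handler)
store, rank, world_size = next(dist.rendezvous("demo://", rank=0, world_size=1))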

3.2 handlers

If you look closely at the code of the handlers, you will find that each returns a different store. For example, _tcp_rendezvous_handler uses the given information to create a TCPStore and then returns it.

Non-critical code is omitted from the snippets below.

3.2.1 _file_rendezvous_handler

The FileStore is returned here.

def _file_rendezvous_handler(url: str, **kwargs):

    result = urlparse(url)
    path = result.path
    query: Dict[str, str]
    # mypy doesn't allow dict() to accept List of values (#257)
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]

    rank = int(query["rank"])
    world_size = int(query["world_size"])
    store = FileStore(path, world_size)
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using file:// method")

3.2.2 _tcp_rendezvous_handler

TCPStore is returned here.

def _tcp_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):
    result = urlparse(url)
    query: Dict[str, Union[int, str]]
    # mypy doesn't allow dict() to accept List of values (#257)
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]

    rank = int(query["rank"])
    world_size = int(query["world_size"])
    start_daemon = rank == 0
    assert result.hostname is not None
    store = TCPStore(result.hostname, result.port, world_size, start_daemon, timeout)
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using tcp:// method")

3.2.3 _env_rendezvous_handler

A TCPStore is also returned here, but the required information is extracted from environment variables.

def _env_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):

    result = urlparse(url)
    query: Dict[str, Union[int, str]]
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) 
    rank: Optional[Union[str, int]]
    world_size: Optional[Union[str, int]]
    master_port: Optional[Union[str, int]]

    if "rank" in query:
        rank = int(query["rank"])
    else:
        rank = int(_get_env_or_raise("RANK"))

    if "world_size" in query:
        world_size = int(query["world_size"])
    else:
        world_size = int(_get_env_or_raise("WORLD_SIZE"))

    master_addr = _get_env_or_raise("MASTER_ADDR")
    master_port = int(_get_env_or_raise("MASTER_PORT"))

    use_torchelastic_store = os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None)

    if use_torchelastic_store == str(True):
        worker_process_prefix = "/worker"
        # When TORCHELASTIC_USE_AGENT_STORE is set up, the worker process is assumed
        # to be invoked by the torchelastic agent. Torchelastic agent creates a tcp daemon thread
        # on the GROUP_RANK=0, as a result all user worker processes should create store with: daemon=False
        tcp_store = TCPStore(master_addr, master_port, world_size, False, timeout)
        yield (PrefixStore(worker_process_prefix, tcp_store), rank, world_size)
    else:
        # Start the TCP store daemon on the rank 0
        start_daemon = rank == 0
        store = TCPStore(master_addr, master_port, world_size, start_daemon, timeout)
        yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using env:// method")

3.3 use

3.3.1 using handler

How is a handler used? init_process_group contains the following:

rendezvous_iterator = rendezvous(
    init_method, rank, world_size, timeout=timeout
)
store, rank, world_size = next(rendezvous_iterator)

rendezvous selects a _rendezvous_handler based on init_method, and the _rendezvous_handler returns a store.

def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
    # Append node-specific arguments.
    result = urlparse(url)
    if rank != -1 or world_size != -1:
        query_dict: Dict[str, Union[int, str]] = dict(
            # mypy doesn't allow dict() to accept List of values (#257)
            pair.split("=") for pair in filter(None, result.query.split("&"))  # type: ignore[arg-type, misc]
        )
        if rank != -1:
            query_dict["rank"] = rank
        if world_size != -1:
            query_dict["world_size"] = world_size

        result = result._replace(
            query="{}".format("&".join(["{}={}".format(k, v) for k, v in query_dict.items()]))
        )
        url = urlunparse(result)

    return _rendezvous_handlers[result.scheme](url, **kwargs)

3.3.2 using the store

Let's continue to see how the store is used. In the init_process_group code, the store is used to initialize the process group.

default_pg = _new_process_group_helper(
    world_size,
    rank,
    [],
    backend,
    store,
    pg_options=pg_options,
    group_name=group_name,
    timeout=timeout)
_update_default_pg(default_pg)

3.3.2.1 _new_process_group_helper

Before continuing with _new_process_group_helper, let's first look at a few global variables. They store ProcessGroup information globally, e.g. _pg_map[pg] = (Backend.NCCL, store).

# Cached process groups
# For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
# For MPI pg, it is a map from ProcessGroup to (Backend, None)
_pg_map: Dict[ProcessGroup, Tuple[str, Optional[Store]]] = {}
# Process group's names, map from ProcessGroup to str
_pg_names: Dict[ProcessGroup, str] = {}
# Process group's global rank to local rank mapping
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}

After _new_process_group_helper obtains the store parameter, it wraps it in a PrefixStore, and then uses this prefix_store to build, for example, a ProcessGroupGloo. The _new_process_group_helper code is as follows:

def _new_process_group_helper(world_size,
                              rank,
                              group_ranks,
                              backend,
                              store,
                              pg_options=None,
                              group_name=None,
                              timeout=default_pg_timeout):
    """
    Create a new distributed process group.

    This function must be called by ALL processes in the global group, even if
    the calling process is not part of the newly created group. In that case,
    this function returns GroupMember.NON_GROUP_MEMBER.

    This function is called with ``group_ranks == []`` for the default group.
    """
    global _pg_map
    global _group_count
    global _pg_names

    if not group_name:
        group_name = str(_group_count)
        _group_count += 1

    # The list of group ranks is empty if we're creating the default group.
    is_default_group = (len(group_ranks) == 0)

    backend = Backend(backend)
    pg: Union[ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL]
    if backend == Backend.MPI: # store not used
        pg = ProcessGroupMPI.create(group_ranks)
        if not pg:
            return GroupMember.NON_GROUP_MEMBER
        _pg_map[pg] = (Backend.MPI, None)
        _pg_names[pg] = group_name
    else:
        # store will be used here

        # If this is a subgroup (which means group_ranks is specified),
        # we check if the current process is a member of the new group.
        if not is_default_group:
            global_rank = _get_default_group().rank()
            if global_rank not in group_ranks:
                return GroupMember.NON_GROUP_MEMBER

        # Use the group name as prefix in the default store, such that
        # a single store can be reused by multiple groups.
        
        prefix_store = PrefixStore(group_name, store) # Built PrefixStore

        if backend == Backend.GLOO:
            pg = ProcessGroupGloo(
                prefix_store, # Building process groups using PrefixStore
                rank,
                world_size,
                timeout=timeout)
            _pg_map[pg] = (Backend.GLOO, store)
            _pg_names[pg] = group_name
        elif backend == Backend.NCCL:
            if pg_options is not None:
                assert isinstance(pg_options, ProcessGroupNCCL.Options), \
                    "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
            else:
                # default pg_options for NCCL
                pg_options = ProcessGroupNCCL.Options()
                pg_options.is_high_priority_stream = False
                pg_options._timeout = timeout

            pg = ProcessGroupNCCL(
                prefix_store, # Building process groups using PrefixStore
                rank,
                world_size,
                pg_options)
            _pg_map[pg] = (Backend.NCCL, store)
            _pg_names[pg] = group_name
        else:
            pg = getattr(Backend, backend.upper())(
                prefix_store,
                rank,
                world_size,
                timeout)
            _pg_map[pg] = (backend, store)
            _pg_names[pg] = group_name

    return pg

3.3.2.2 ProcessGroupGloo

ProcessGroupGloo makes concrete use of the store, for example generating a GlooStore on top of the PrefixStore and using the PrefixStore to establish the network.

ProcessGroupGloo::ProcessGroupGloo(
    const c10::intrusive_ptr<Store>& store,
    int rank,
    int size,
    c10::intrusive_ptr<Options> options)
    : ProcessGroup(rank, size),
      store_(new GlooStore(store)), // A GlooStore is generated on top of the PrefixStore
      options_(options),
      stop_(false),
      collectiveCounter_(0) {
  auto& devices = options->devices;

  contexts_.reserve(options->devices.size());
  for (size_t i = 0; i < options->devices.size(); i++) {
    auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_);
    // Another PrefixStore is generated
    auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_);
    context->setTimeout(options->timeout);
    // Network with PrefixStore
    context->connectFullMesh(store, options->devices[i]);
    contexts_.push_back(std::move(context));
  }

  // Every worker thread stores the AsyncWork object it's currently
  // working on in the workInProgress_ vector. It must have size equal
  // to the number of workers such that they can simply index into it
  // using the worker index they are started with.
  workInProgress_.resize(options->threads);

  threads_.resize(options->threads);
  for (size_t i = 0; i < threads_.size(); i++) {
    threads_[i] = std::thread(&ProcessGroupGloo::runLoop, this, i);
  }
}

The following code also shows uses of store_, such as waiting and reading values.

void ProcessGroupGloo::setSequenceNumberForGroup() {
  if (rank_ == 0) {
    // Create and broadcast sequence number
    auto seq = 1 + rand();
    sequenceNum_ = c10d::SequenceNum(seq);
    std::vector<char> values = c10d::toVec<char>(seq, kBytes);
    store_->set(kSeqNumStoreKey, values); // Save value
  } else {
    // Read rank 0's sequence number from store.
    sequenceNum_ = c10d::SequenceNum();
    store_->wait({kSeqNumStoreKey}, options_->timeout); // wait for
    std::vector<char> values = store_->get(kSeqNumStoreKey); // Take value
    uint64_t num = c10d::fromVec<char>(values);
    sequenceNum_->set(num);
  }
}  

3.4 summary

According to the current analysis results, our extended conclusions are as follows:

  • init_method ultimately comes down to the store; the store is the effective entity.
  • Participating processes need to find each other and exchange information before they can communicate. This process is called rendezvous.
  • rendezvous actually returns some kind of store for subsequent communication.
  • Within the process group, the store is used to establish communication, to wait, to access values, and so on.

Next, we pick TCPStore for in-depth analysis.

0x04 TCPStore

TCPStore is a TCP-based distributed key-value store implementation. The server stores the data, while store clients can connect to the server over TCP and perform operations such as set() to insert a key-value pair and get() to retrieve a value. There must be an initialized TCPStore server in the system, because store clients wait for the server before establishing their connections.

The parameters of TCPStore are as follows:

  • host_name (str) – host name or IP address on which the storage server runs.
  • port (int) – port on which the storage server listens for incoming requests.
  • world_size (int, optional) – total number of store users.
    • world_size = number of clients + 1, where the 1 is the server.
    • The default value is -1 (a negative value indicates a non-fixed number of users).
  • is_master (bool, optional) – True when initializing the storage server, False when initializing a storage client. The default value is False.
  • timeout (timedelta, optional) – timeout used by the store during initialization and for the get() and wait() methods. The default is timedelta(seconds=300).
  • wait_for_workers (bool, optional) – whether to wait for all workers to connect to the storage server. This applies only when world_size is a fixed value. The default value is True.

Examples are as follows:

import torch.distributed as dist
from datetime import timedelta
# Run on process 1 (server)
server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))
# Run on process 2 (client)
client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)
# Use any of the store methods from either the client or server after initialization
server_store.set("first_key", "first_value")
client_store.get("first_key")

Or:

    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> # This will throw an exception after 10 seconds
    >>> store.wait(["bad_key"], timedelta(seconds=10))

From the examples, this is a simple server/client (i.e. master/worker) relationship. Let's analyze it carefully next.

4.1 TCPStore in Python

In the Python world, TCPStore is just a binding stub that exposes host and port.

class TCPStore(Store):
    def __init__(self, host_name, port, world_size=-1, is_master=False, timeout=None, *args, **kwargs): # real signature unknown; NOTE: unreliably restored from __doc__ 
        pass

    host = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default
    """Gets the hostname on which the store listens for requests."""

    port = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default
    """Gets the port number on which the store listens for requests."""

We need to go deeper, into the C++ world.

4.2 TCPStore in C++

4.2.1 API interface

First, TCPStore in C++ can be regarded as an API interface, defined as follows:

class TCPStore : public Store {
 public:
  explicit TCPStore(
      const std::string& masterAddr,
      PortType masterPort,
      c10::optional<int> numWorkers = c10::nullopt_t(-1),
      bool isServer = false,
      const std::chrono::milliseconds& timeout = kDefaultTimeout,
      bool waitWorkers = true);

  virtual ~TCPStore();

  void set(const std::string& key, const std::vector<uint8_t>& value) override;
  std::vector<uint8_t> compareSet(
      const std::string& key,
      const std::vector<uint8_t>& expectedValue,
      const std::vector<uint8_t>& desiredValue) override;
  std::vector<uint8_t> get(const std::string& key) override;
  int64_t add(const std::string& key, int64_t value) override;
  bool deleteKey(const std::string& key) override;

  // NOTE: calling other TCPStore APIs inside the callback is NOT threadsafe
  // watchKey() is a blocking operation. It will register the socket on
  // TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon. It will
  // return once it has verified the callback is registered on both background
  // threads. Only one thread can call watchKey() at a time.
  void watchKey(const std::string& key, WatchKeyCallback callback) override;
  bool check(const std::vector<std::string>& keys) override;
  int64_t getNumKeys() override;
  void wait(const std::vector<std::string>& keys) override;
  void wait(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout) override;
  // Waits for all workers to join.
  void waitForWorkers();
  // Returns the hostname used by the TCPStore.
  const std::string& getHost() const noexcept;
  // Returns the port used by the TCPStore.
  PortType getPort() const noexcept;

 private:
  int64_t addHelper_(const std::string& key, int64_t value);
  std::vector<uint8_t> getHelper_(const std::string& key);
  void waitHelper_(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout);

  std::mutex watchKeyMutex_;
  bool isServer_;
  int storeSocket_ = -1; // 
  int listenSocket_ = -1; // 
  int masterListenSocket_ = -1; // master is listening here

  std::string tcpStoreAddr_;
  PortType tcpStorePort_;

  c10::optional<int> numWorkers_;
  const std::string initKey_;
  const std::string regularPrefix_;

  std::unique_ptr<TCPStoreMasterDaemon> tcpStoreMasterDaemon_ = nullptr;
  std::unique_ptr<TCPStoreWorkerDaemon> tcpStoreWorkerDaemon_ = nullptr;
};

4.2.2 socket usage

The most important member variables are the three sockets; they are the essence of the store.

  int storeSocket_ = -1; // 
  int listenSocket_ = -1; // 
  int masterListenSocket_ = -1; // master is listening here

4.2.2.1 division of responsibilities

The specific explanation is as follows (the analysis continues in combination with the code later):

  • masterListenSocket_ listens on the masterPort.
    • tcpStoreMasterDaemon_ is the master, the server that provides the service for the whole TCPStore.
    • tcpStoreMasterDaemon_ uses tcputil::addPollfd(fds, storeListenSocket_, POLLIN) to listen on masterListenSocket_.
    • The key-value store itself is std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_.
  • storeSocket_ lives on the worker side and is connected to masterListenSocket_ on the masterPort.
    • storeSocket_ encapsulates the operations directed at the master port; users only call set, get and so on, without needing to know about the master port.
    • set(key, data) sends a request through storeSocket_ asking the master to set key : value.
    • tcpStoreMasterDaemon_ responds when it detects activity on the socket.
    • tcpStoreMasterDaemon_ internally adds key : value to std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_.
  • listenSocket_ lives on the worker side in tcpStoreWorkerDaemon_ and is also connected to masterListenSocket_ on the masterPort. As the comment above describes, there is a decoupling here: watchKey() registers the socket on TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon.
    • listenSocket_ encapsulates the handling of watchKey. The store client uses watchKey(const std::string& key, WatchKeyCallback callback) to register, that is:
      • The worker requests registration: tcpStoreWorkerDaemon_->setCallback(regKey, callback) adds a callback to tcpStoreWorkerDaemon_'s std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_.
      • The worker sends the request: it sends a (key, WATCH_KEY) message through listenSocket_, telling the master to call this callback if the key's value changes.
    • The master performs the registration: upon receiving the WATCH_KEY message, it calls watchHandler and uses watchedSockets_[key].push_back(socket) to record that when the key changes, a message should be sent to this socket.
    • The master notifies the worker: in TCPStoreMasterDaemon::setHandler, when a new value is set, sendKeyUpdatesToClients traverses watchedSockets_[key] and sends a change notification to each registered socket.
    • The worker executes the callback: so when the key changes, the callback is invoked inside tcpStoreWorkerDaemon_.

4.2.2.2 Set example

Let's first look at the set example: the worker sets a value on the master through the socket.

                                                                          +
+----------------------------------------------------------------------+  |  +----------------------------------------------+
| TCPStore                                                      Master |  |  | TCPStore                              Worker |
|                                                                      |  |  |                                              |
|                                                                      |  |  |                                              |
|                                                                      |  |  |                                              |
|   +------------------------------------------------------------+     |  |  |                                              |
|   | TcpStoreMasterDaemon_                            MasterPort|     |  |  |                                              |
|   |                                                            |     |  |  |                                              |
|   |    TCPStore.masterListenSocket_                            |     |  |  |      +---------------------------------+     |
|   |                                                            |     |  |  |      | set(key, value)                 |     |
|   |                                                            |     |  |  |      |                                 |     |
|   |    tcpStore_[key] = value  <------------------------------------------------+ |    storeSocket_                 |     |
|   |                                                            |     |  |  |      |                                 |     |
|   |                                                            |     |  |  |      +---------------------------------+     |
|   |                                                            |     |  |  |                                              |
|   +------------------------------------------------------------+     |  |  |                                              |
|                                                                      |  |  |                                              |
+----------------------------------------------------------------------+  |  +----------------------------------------------+
                                                                          +


4.2.2.3 combination of set and watchKey

The schematic of the combination of set and watchKey is as follows (the worker requests registration and eventually executes the callback; the master performs the registration and notifies the worker to execute the callback):

  1. The worker requests registration. The store client calls watchKey(const std::string& key, WatchKeyCallback callback), which uses tcpStoreWorkerDaemon_->setCallback(regKey, callback) to add a callback to tcpStoreWorkerDaemon_'s std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_.
  2. The worker sends the request. It sends a (key, WATCH_KEY) message through listenSocket_, telling the master to call this callback if the key's value changes.
  3. The master performs the registration. Upon receiving the WATCH_KEY message, it calls watchHandler and uses watchedSockets_[key].push_back(socket) to record that when the key changes, a message should be sent to this socket.
  4. Next, suppose the store client (assumed here to be the same worker; in practice it may be a different one) sets a value.
  5. The master notifies the worker. In TCPStoreMasterDaemon::setHandler, when a new value is set, sendKeyUpdatesToClients traverses watchedSockets_[key] and sends a change notification to each registered socket.
  6. The worker executes the callback. When the key changes, the callback is invoked inside tcpStoreWorkerDaemon_.

+----------------------------------------------------------------------+  +  +------------------------------------------------------------------------+
| TCPStore                                                      Master |  |  | TCPStore                                                        Worker |
|                                                                      |  |  |                                                                        |
|   +------------------------------------------------------------+     |  |  |                                                                        |
|   | TcpStoreMasterDaemon_                            MasterPort|     |  |  |      +---------------------------------+                               |
|   |                                                            |     |  |  |      |                                 |                               |
|   |                                                  2         |     |  |  |      | watchKey(key, callback) +----------------------+                |
|   |           TCPStore.masterListenSocket_   <----------------------------------+ |                                 |              |                |
|   |                       +                                    |     |  |  |      |    listenSocket_                |              |                |
|   |                       | 3                                  |     |  |  |      |                                 |            1 |                |
|   |                       v                                    |     |  |  |      |                                 |              |                |
|   |           watchedSockets_[key] = socket                    |     |  |  |      +---------------------------------+              |                |
|   |                                                            |     |  |  |                                                       |                |
|   |  +-------------------------------------------------+       |     |  |  |                                                       |                |
|   |  |                                                 |       |     |  |  |                                                       |                |
|   |  |    setHandler                                   |       |     |  |  |   +----------------------------------------------------------------+   |
|   |  |                                                 |       |     |  |  |   | TCPStoreWorkerDaemon                              |            |   |
|   |  |                                                 |       |     |  |  |   |                                                   v            |   |
|   |  |       tcpStore_[key] = newData                  |       |     |  |  |   |   unordered_map<string, WatchKeyCallback> keyToCallbacks_      |   |
|   |  |                   +                             |       |     |  |  |   |                                                                |   |
|   |  |                   |                             |       |     |  |  |   |   TCPStore.listenSocket_                                       |   |
|   |  |                   |                             |       |     |  |  |   |                                                                |   |
|   |  |                   v                             |       |     |  |  |   |  +----------------------------------------------------------+  |   |
|   |  |       sendKeyUpdatesToClients                   |       |     |  |  |   |  | run                                                      |  |   |
|   |  |                   +                             |       |  5  |  |  |   |  |                                                          |  |   |
|   |  |                   |                             |  +---------------------->+                                        6                 |  |   |
|   |  |                   |                             |  |    |     |  |  |   |  |       callbackHandler +-----> keyToCallbacks_(callback)  |  |   |
|   |  |                   v                             |  |    |     |  |  |   |  |                                                          |  |   |
|   |  |                                                 |  |    |     |  |  |   |  +----------------------------------------------------------+  |   |
|   |  |    for (int socket : watchedSockets_[key]){     |  |    |     |  |  |   +----------------------------------------------------------------+   |
|   |  |       tcputil::sendString(socket, key, true) +-----+    |     |  |  |                                                                        |
|   |  |    }                                            |       |     |  |  |                                                                        |
|   |  |                                                 |       |     |  |  |       +------------------------+                                       |
|   |  |                                                 |       |  4  |  |  |       | set(key, newData)      |                                       |
|   |  |                                                 | <-----------------------+ |                        |                                       |
|   |  +-------------------------------------------------+       |     |  |  |       |                        |                                       |
|   |                                                            |     |  |  |       +------------------------+                                       |
|   +------------------------------------------------------------+     |  |  |                                                                        |
|                                                                      |  |  |                                                                        |
+----------------------------------------------------------------------+  +  +------------------------------------------------------------------------+


4.2.3 function

TCPStore provides several functions.

void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
  std::string regKey = regularPrefix_ + key;
  tcputil::sendValue<QueryType>(storeSocket_, QueryType::SET);
  tcputil::sendString(storeSocket_, regKey, true);
  tcputil::sendVector<uint8_t>(storeSocket_, data);
}

std::vector<uint8_t> TCPStore::get(const std::string& key) {
  std::string regKey = regularPrefix_ + key;
  return getHelper_(regKey);
}

int64_t TCPStore::add(const std::string& key, int64_t value) {
  std::string regKey = regularPrefix_ + key;
  return addHelper_(regKey, value);
}

int64_t TCPStore::addHelper_(const std::string& key, int64_t value) {
  tcputil::sendValue<QueryType>(storeSocket_, QueryType::ADD);
  tcputil::sendString(storeSocket_, key, true);
  tcputil::sendValue<int64_t>(storeSocket_, value);
  return tcputil::recvValue<int64_t>(storeSocket_);
}

These functions call the following basic functions to send and receive.

// this is only for convenience when sending rvalues
template <typename T>
void sendValue(int socket, const T& value, bool moreData = false) {
  sendBytes<T>(socket, &value, 1, moreData);
}

template <typename T>
T recvValue(int socket) {
  T value;
  recvBytes<T>(socket, &value, 1);
  return value;
}
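
A sketch of how these primitives behave from Python (two stores in one process stand in for a server and a client; the port is arbitrary, and wait_for_workers=False keeps this single-process setup from blocking):

import torch.distributed as dist
from datetime import timedelta

server = dist.TCPStore("127.0.0.1", 29501, 2, True,
                       timedelta(seconds=30), wait_for_workers=False)
client = dist.TCPStore("127.0.0.1", 29501, 2, False,
                       timedelta(seconds=30), wait_for_workers=False)

client.set("weight", "1.0")   # travels over storeSocket_ as QueryType::SET
print(server.get("weight"))   # b'1.0', read back from the master's tcpStore_
print(client.add("step", 1))  # 1; add() is an atomic increment
print(server.add("step", 5))  # 6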

4.2.4 constructor

We can see the following from the constructor:

  • For the storage-server role, it mainly starts tcpStoreMasterDaemon_. Note that after the daemon is started, the server blocks waiting for workers and does not yet run the code below that starts tcpStoreWorkerDaemon_.
  • For storage clients, tcpStoreWorkerDaemon_ is started.

// TCPStore class methods
TCPStore::TCPStore(
    const std::string& masterAddr,
    PortType masterPort,
    c10::optional<int> numWorkers,
    bool isServer,
    const std::chrono::milliseconds& timeout,
    bool waitWorkers)
    : Store(timeout),
      isServer_(isServer),
      tcpStoreAddr_(masterAddr),
      tcpStorePort_(masterPort),
      numWorkers_(numWorkers),
      initKey_("init/"),
      regularPrefix_("/") {
  tcputil::socketInitialize();
  if (isServer_) { // If server is set, listen on the masterPort
    // Opening up the listening socket
    std::tie(masterListenSocket_, tcpStorePort_) = tcputil::listen(masterPort);
  }
  try {
    if (isServer_) { // If server is set, start tcpStoreMasterDaemon_
      // Now start the daemon
      tcpStoreMasterDaemon_ =
          std::make_unique<TCPStoreMasterDaemon>(masterListenSocket_);
    }
    // Connect to the daemon
    // The worker will establish contact with the master port
    storeSocket_ = tcputil::connect(
        tcpStoreAddr_, tcpStorePort_, /* wait= */ true, timeout_);
    if (numWorkers.value_or(-1) >= 0 && waitWorkers) {
      waitForWorkers(); // server waiting for worker
    }

    // socket to handle requests from server, because the master will also send messages to the worker
    listenSocket_ = tcputil::connect(
        tcpStoreAddr_, tcpStorePort_, /* wait= */ true, timeout_);
    // Start the worker daemon
    tcpStoreWorkerDaemon_ =
        std::make_unique<TCPStoreWorkerDaemon>(listenSocket_);
  } catch (const std::exception&) {
    if (isServer_) {
      tcpStoreMasterDaemon_ = nullptr;
      tcputil::closeSocket(masterListenSocket_);
    }
    tcpStoreWorkerDaemon_ = nullptr;
    if (listenSocket_ != -1) {
      tcputil::closeSocket(listenSocket_);
    }
    if (storeSocket_ != -1) {
      tcputil::closeSocket(storeSocket_);
    }
    throw;
  }
}

The server uses the following function to wait for workers:

void TCPStore::waitForWorkers() {
  addHelper_(initKey_, 1);
  // Let server block until all workers have completed, this ensures that
  // the server daemon thread is always running until the very end
  if (isServer_) {
    const auto start = std::chrono::steady_clock::now();
    while (true) {
      std::vector<uint8_t> value = getHelper_(initKey_);
      auto buf = reinterpret_cast<const char*>(value.data());
      auto len = value.size();
      int numWorkersCompleted = std::stoi(std::string(buf, len));
      if (numWorkersCompleted >= numWorkers_.value_or(-1)) {
        break;
      }
      const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
          std::chrono::steady_clock::now() - start);
      if (timeout_ != kNoTimeout && elapsed > timeout_) {
        break;
      }
      /* sleep override */
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
}
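
A rough Python analogue of this counting logic (a sketch, not PyTorch API: every process increments a shared init key, and the polling loop spins until all workers have checked in):

import time

def wait_for_workers(store, num_workers, init_key="init/", poll_s=0.01):
    store.add(init_key, 1)  # count this process in
    while int(store.get(init_key)) < num_workers:
        time.sleep(poll_s)  # poll, like the C++ loop above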

4.2.5 TCPStoreWorkerDaemon

This daemon thread currently only handles callbacks registered via watchKey().

// Separate thread that is launched on all instances (including master)
// Right now only handles callbacks registered from watchKey()
class TCPStoreWorkerDaemon : public BackgroundThread {
 public:
  explicit TCPStoreWorkerDaemon(int listenSocket);
  // Set the callback to run key change
  void setCallback(std::string key, WatchKeyCallback cb);
  void waitForCallbackRegistration() {
    // Block until callback has been registered successfully
    std::unique_lock<std::mutex> callbackRegistrationLock(
        callbackRegistrationMutex_);
    callbackRegisteredCV_.wait(
        callbackRegistrationLock, [&] { return callbackRegisteredData_; });

    // Reset payload for next callback
    callbackRegisteredData_ = false;
  }
  void setCallbackRegistered() {
    callbackRegisteredData_ = true;
    callbackRegisteredCV_.notify_one();
  }

 private:
  void run();
  void callbackHandler(int socket);
  // List of callbacks map each watched key
  std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_;
  std::mutex keyToCallbacksMutex_;
  std::mutex callbackRegistrationMutex_;
  std::condition_variable callbackRegisteredCV_;
  bool callbackRegisteredData_ = false;
};


Its constructor just creates a thread.

// TCPStoreListener class methods
TCPStoreWorkerDaemon::TCPStoreWorkerDaemon(int listenSocket)
    : BackgroundThread(listenSocket) {
  daemonThread_ = std::thread(&TCPStoreWorkerDaemon::run, this);
}

4.2.5.1 watchKey

The store client uses watchKey(const std::string& key, WatchKeyCallback callback) to register a watched key with the master:

  • The worker requests registration: tcpStoreWorkerDaemon_->setCallback(regKey, callback) adds a callback to tcpStoreWorkerDaemon_'s std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_.
  • The worker sends the request: it sends a (key, WATCH_KEY) message through listenSocket_, telling the master to call this callback if the key's value changes.
  • It then uses waitForCallbackRegistration to wait until the registration completes.

void TCPStore::watchKey(const std::string& key, WatchKeyCallback callback) {
  // Only allow one thread to perform watchKey() at a time
  const std::lock_guard<std::mutex> watchKeyLock(watchKeyMutex_);

  // Register callback with TCPStoreMasterDaemon to call TCPStoreWorkerDaemon on
  // key change
  std::string regKey = regularPrefix_ + key;
  tcpStoreWorkerDaemon_->setCallback(regKey, callback);
  tcputil::sendValue<QueryType>(listenSocket_, QueryType::WATCH_KEY);
  tcputil::sendString(listenSocket_, regKey);

  // Block until callback has been registered successfully
  tcpStoreWorkerDaemon_->waitForCallbackRegistration();
}

4.2.5.2 operation

Its run method has separate implementations for Windows and other systems, but both mainly receive messages on the socket and then perform the related processing.

  • The master performs the registration. Upon receiving the WATCH_KEY message, it calls watchHandler and uses watchedSockets_[key].push_back(socket) to record that when the key changes, a message should be sent to this socket.
  • The master notifies the worker. In TCPStoreMasterDaemon::setHandler, when a new value is set, sendKeyUpdatesToClients traverses watchedSockets_[key] and sends a change notification to each registered socket.
  • The worker executes the callback. So when the key changes, the callback is invoked inside tcpStoreWorkerDaemon_.

#ifdef _WIN32 
void TCPStoreWorkerDaemon::run() { // This is the windows system
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_, POLLIN);

  while (true) {
    // Check control and exit early if triggered
    int res;
    SYSCHECK_ERR_RETURN_NEG1(
        res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
    if (res == 0) {
      auto rvPoll = WaitForSingleObject(ghStopEvent_, 0);
      if (rvPoll != WAIT_TIMEOUT) {
        break;
      }
      continue;
    }

    // if connection is closed gracefully by master, peeked data will return 0
    char data;
    int ret = recv(fds[0].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      auto rvData = WaitForSingleObject(ghStopEvent_, 0);
      if (rvData != WAIT_TIMEOUT) {
        break;
      }
      continue;
    }

    // valid request, perform callback logic
    callbackHandler(fds[0].fd); // Business processing
  }
}
#else
void TCPStoreWorkerDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, controlPipeFd_[0], POLLHUP);
  tcputil::addPollfd(fds, storeListenSocket_, POLLIN);

  while (true) {
    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));

    // Check control and exit early if triggered
    // The pipe receives an event which tells us to shutdown the listener thread
    if (fds[0].revents != 0) {
      // Will be POLLUP when the pipe is closed
      if (fds[0].revents ^ POLLHUP) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[0].revents));
      }
      break;
    }

    // if connection is closed gracefully by master, peeked data will return 0
    char data;
    int ret = recv(fds[1].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      continue;
    }

    // valid request, perform callback logic
    callbackHandler(fds[1].fd); // Business processing
  }
}
#endif

4.2.6 TCPStoreMasterDaemon

Here, std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_ is the real key-value store.

TCPStoreMasterDaemon is therefore responsible for the key-value operations, such as reads and writes.

// Separate thread that is only launched on master
class TCPStoreMasterDaemon : public BackgroundThread {
 public:
  explicit TCPStoreMasterDaemon(int storeListenSocket);

 private:
  void run();
  void queryFds(std::vector<struct pollfd>& fds);
  void query(int socket);

  // The master runs on a single thread so only
  // one handler can be executed at a time
  void setHandler(int socket);
  void compareSetHandler(int socket);
  void addHandler(int socket);
  void getHandler(int socket) const;
  void checkHandler(int socket) const;
  void getNumKeysHandler(int socket) const;
  void deleteHandler(int socket);
  void waitHandler(int socket);
  void watchHandler(int socket);

  bool checkKeys(const std::vector<std::string>& keys) const;
  // Helper function to alerts waiting workers, used in setHandler, getHandler
  void wakeupWaitingClients(const std::string& key);
  // Helper function used when the key is changed
  // used in setHandler, addHandler, getHandler, deleteHandler
  void sendKeyUpdatesToClients(
      const std::string& key,
      const enum WatchResponseType& type,
      std::vector<uint8_t>& oldData,
      std::vector<uint8_t>& newData);
  std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
  // From key -> the list of sockets waiting on the key
  std::unordered_map<std::string, std::vector<int>> waitingSockets_;
  // From socket -> number of keys awaited
  std::unordered_map<int, size_t> keysAwaited_;
  // From key -> the list of sockets watching the key
  std::unordered_map<std::string, std::vector<int>> watchedSockets_;
};
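
waitingSockets_ and wakeupWaitingClients are what give wait() its blocking semantics: a client's wait() parks its socket with the master until some set() on that key wakes it up. A single-process sketch of that behavior from Python (the port is arbitrary; the timer thread plays the role of a second worker):

import threading
import torch.distributed as dist
from datetime import timedelta

server = dist.TCPStore("127.0.0.1", 29502, 1, True, timedelta(seconds=30))

def late_setter():
    client = dist.TCPStore("127.0.0.1", 29502, 1, False, timedelta(seconds=30))
    client.set("ready", "yes")  # master's setHandler fires wakeupWaitingClients

t = threading.Timer(0.5, late_setter)
t.start()
server.wait(["ready"], timedelta(seconds=10))  # parked until the set() above lands
print(server.get("ready"))                     # b'yes'
t.join()
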
4.2.6.1 operation

TCPStoreMasterDaemon waits on the socket: masterListenSocket_ listens on the masterPort.

  • tcpStoreMasterDaemon_ uses tcputil::addPollfd(fds, storeListenSocket_, POLLIN) to listen on masterListenSocket_.
  • tcpStoreMasterDaemon_ acts as the master; it is the server that provides the service for the whole TCPStore.
  • The key-value store itself is std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_.

#ifdef _WIN32
void TCPStoreMasterDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_, POLLIN);

  // receive the queries
  bool finished = false;
  while (!finished) {
    for (size_t i = 0; i < sockets_.size(); i++) {
      fds[i].revents = 0;
    }

    int res;
    SYSCHECK_ERR_RETURN_NEG1(
        res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
    if (res == 0) {
      auto rv = WaitForSingleObject(ghStopEvent_, 0);
      if (rv != WAIT_TIMEOUT) {
        finished = true;
        break;
      }
      continue;
    }

    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) { // Got the message
      if (!(fds[0].revents & POLLIN)) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
      sockets_.push_back(sockFd);
      tcputil::addPollfd(fds, sockFd, POLLIN);
    }
    queryFds(fds); // Business processing
  }
}
#else

void TCPStoreMasterDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_, POLLIN);
  // Push the read end of the pipe to signal the stopping of the daemon run
  tcputil::addPollfd(fds, controlPipeFd_[0], POLLHUP);

  // receive the queries
  bool finished = false;
  while (!finished) {
    for (size_t i = 0; i < sockets_.size(); i++) {
      fds[i].revents = 0;
    }

    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));

    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) {
      if (fds[0].revents ^ POLLIN) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
      sockets_.push_back(sockFd);
      tcputil::addPollfd(fds, sockFd, POLLIN);
    }

    // The pipe receives an event which tells us to shutdown the daemon
    if (fds[1].revents != 0) { // Got the message
      // Will be POLLHUP when the pipe is closed
      if (fds[1].revents ^ POLLHUP) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[1].revents));
      }
      finished = true;
      break;
    }
    queryFds(fds); // Dispatch requests from connected clients
  }
}
#endif
4.2.6.2 dispatching requests

queryFds walks the poll results and routes each connected socket that has a pending event into the request-processing function query(). If processing a socket throws, all tracking state for that socket is removed.

void TCPStoreMasterDaemon::queryFds(std::vector<struct pollfd>& fds) {
  // Skipping the fds[0] and fds[1],
  // fds[0] is master's listening socket
  // fds[1] is control pipe's reading fd, it is not for Windows platform
  for (size_t fdIdx = CONNECT_SOCKET_OFFSET; fdIdx < fds.size(); ++fdIdx) {
    if (fds[fdIdx].revents == 0) {
      continue;
    }

    // Now query the socket that has the event
    try {
      query(fds[fdIdx].fd); // Handle this client's pending request
    } catch (...) {
      tcputil::closeSocket(fds[fdIdx].fd);

      // Remove all the tracking state of the close FD
      for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
        for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
          if (*vecIt == fds[fdIdx].fd) {
            vecIt = it->second.erase(vecIt);
          } else {
            ++vecIt;
          }
        }
        if (it->second.size() == 0) {
          it = waitingSockets_.erase(it);
        } else {
          ++it;
        }
      }
      for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) {
        if (it->first == fds[fdIdx].fd) {
          it = keysAwaited_.erase(it);
        } else {
          ++it;
        }
      }
      fds.erase(fds.begin() + fdIdx);
      sockets_.erase(sockets_.begin() + fdIdx - CONNECT_SOCKET_OFFSET);
      --fdIdx;
      continue;
    }
  }
}
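
To make the loop-plus-dispatch structure concrete, here is a minimal single-threaded poll server in Python that mirrors the same pattern (purely illustrative and POSIX-only; none of these names come from PyTorch): the listening socket plays the role of fds[0], new clients join the poll set, and a closed client has its state dropped, just as queryFds cleans up waitingSockets_ and keysAwaited_.

import select
import socket

listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)  # ~ masterListenSocket_
listener.bind(("127.0.0.1", 0))
listener.listen()

poller = select.poll()                     # ~ ::poll in run()
poller.register(listener, select.POLLIN)
clients = {}                               # fd -> socket, ~ sockets_

while True:
    for fd, revents in poller.poll():
        if fd == listener.fileno():
            conn, _ = listener.accept()    # ~ tcputil::accept + addPollfd
            clients[conn.fileno()] = conn
            poller.register(conn, select.POLLIN)
            continue
        data = clients[fd].recv(4096)      # ~ query(): read and handle a request
        if not data:
            poller.unregister(fd)          # ~ the cleanup branch in queryFds
            clients.pop(fd).close()
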

4.2.6.3 processing requests

query() reads the message type from the socket, then dispatches to the handler for that type.

// query communicates with the worker. The format
// of the query is as follows:
// type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
// or, in the case of wait
// type of query | number of args | size of arg1 | arg1 | ...
void TCPStoreMasterDaemon::query(int socket) {
  QueryType qt;
  tcputil::recvBytes<QueryType>(socket, &qt, 1);
  if (qt == QueryType::SET) {
    setHandler(socket);

  } else if (qt == QueryType::COMPARE_SET) {
    compareSetHandler(socket);

  } else if (qt == QueryType::ADD) {
    addHandler(socket);

  } else if (qt == QueryType::GET) {
    getHandler(socket);

  } else if (qt == QueryType::CHECK) {
    checkHandler(socket);

  } else if (qt == QueryType::WAIT) {
    waitHandler(socket);

  } else if (qt == QueryType::GETNUMKEYS) {
    getNumKeysHandler(socket);

  } else if (qt == QueryType::DELETE_KEY) {
    deleteHandler(socket);

  } else if (qt == QueryType::WATCH_KEY) {
    watchHandler(socket);

  } else {
    throw std::runtime_error("Unexpected query type");
  }
}
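
Each branch above corresponds to one public Store operation issued by a worker. The sketch below (ours, assuming a master is already listening at the illustrative address used earlier) shows the rough mapping; CHECK and WATCH_KEY have no documented Python entry point, so we leave them out.

from datetime import timedelta
import torch.distributed as dist

store = dist.TCPStore(
    "127.0.0.1", 29500, world_size=2, is_master=False,
    timeout=timedelta(seconds=30))

store.set("key", "value")                  # QueryType::SET         -> setHandler
store.get("key")                           # QueryType::GET         -> getHandler
store.add("counter", 1)                    # QueryType::ADD         -> addHandler
store.compare_set("key", "value", "new")   # QueryType::COMPARE_SET -> compareSetHandler
store.wait(["key"])                        # QueryType::WAIT        -> waitHandler
store.num_keys()                           # QueryType::GETNUMKEYS  -> getNumKeysHandler
store.delete_key("key")                    # QueryType::DELETE_KEY  -> deleteHandler
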

set

setHandler handles storing a value: it creates the key if it is new and overwrites it otherwise, then wakes waiting clients and notifies watchers.

void TCPStoreMasterDaemon::setHandler(int socket) {
  std::string key = tcputil::recvString(socket);
  std::vector<uint8_t> newData = tcputil::recvVector<uint8_t>(socket);
  std::vector<uint8_t> oldData;
  bool newKey = true;
  auto it = tcpStore_.find(key);
  if (it != tcpStore_.end()) {
    oldData = it->second;
    newKey = false;
  }
  tcpStore_[key] = newData;
  // On "set", wake up all clients that have been waiting
  wakeupWaitingClients(key);
  // Send key update to all watching clients
  newKey ? sendKeyUpdatesToClients(
               key, WatchResponseType::KEY_CREATED, oldData, newData)
         : sendKeyUpdatesToClients(
               key, WatchResponseType::KEY_UPDATED, oldData, newData);
}
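
From a client's point of view (continuing the worker sketch above), setHandler makes set() behave as create-or-overwrite, and the ternary at the end picks KEY_CREATED or KEY_UPDATED for any watchers:

store.set("answer", "41")    # key absent:  stored, watchers get KEY_CREATED
store.set("answer", "42")    # key present: overwritten, watchers get KEY_UPDATED
assert store.get("answer") == b"42"
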
get

getHandler handles reading a value by key.

void TCPStoreMasterDaemon::getHandler(int socket) const {
  std::string key = tcputil::recvString(socket);
  auto data = tcpStore_.at(key);
  tcputil::sendVector<uint8_t>(socket, data);
}
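
Note that tcpStore_.at(key) assumes the key already exists when GET arrives. On the worker side, as far as we can tell from the client code, TCPStore::get first waits for the key before issuing GET, which is why a plain get() blocks rather than failing:

value = store.get("producer_done")  # blocks (up to the store timeout) until some process sets the key
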

watchKey

watchHandler registers a key to be monitored: for WATCH_KEY, the client's socket is recorded under that key as the destination for future change notifications.

void TCPStoreMasterDaemon::watchHandler(int socket) {
  std::string key = tcputil::recvString(socket);

  // Record the socket to respond to when the key is updated
  watchedSockets_[key].push_back(socket);

  // Send update to TCPStoreWorkerDaemon on client
  tcputil::sendValue<WatchResponseType>(
      socket, WatchResponseType::KEY_CALLBACK_REGISTERED);
}

notify

When a key changes, sendKeyUpdatesToClients pushes the old and new values to every client watching that key.

void TCPStoreMasterDaemon::sendKeyUpdatesToClients(
    const std::string& key,
    const enum WatchResponseType& type,
    std::vector<uint8_t>& oldData,
    std::vector<uint8_t>& newData) {
  for (int socket : watchedSockets_[key]) {
    tcputil::sendValue<WatchResponseType>(socket, type);
    tcputil::sendString(socket, key, true);
    tcputil::sendVector<uint8_t>(socket, oldData);
    tcputil::sendVector<uint8_t>(socket, newData);
  }
}

4.2.7 summary

We can summarize with the diagram below:

  • The Master listens for requests on MasterPort.
  • Accessing values:
    • In the Worker, storeSocket_ is used to set / get values, which corresponds to number 1 in the figure below.
    • On the Master side this maps onto tcpStore_.
  • Watching keys:
    • In the Worker, listenSocket_ is used to tell the Master "I want to watch this key", which corresponds to number 2 in the figure below; at the same time the Worker registers a callback for that key, which corresponds to number 3.
    • On the Master side, watching corresponds to watchedSockets_[key] = socket_.
    • In the Master, when a set touches a watched key, the sockets in watchedSockets_[key] are notified, which corresponds to number 4 in the figure below.
    • The Worker then invokes the registered callback.
                                                                          +
+----------------------------------------------------------------------+  |  +------------------------------------------------------------------------+
| TCPStore                                                      Master |  |  | TCPStore                                                        Worker |
|                                                                      |  |  |                                                                        |
|   storeSocket_                                                       |  |  |                                                                        |
|                                                                      |  |  |                                                                        |
|   +------------------------------------------------------------+     |  |  |                                                                        |
|   | TcpStoreMasterDaemon_                            MasterPort|     |  |  |  1   +---------------------------------+                               |
|   |                                                            | <--------------+ | set(key, value)                 |                               |
|   |   unordered_map<string, vector<uint8_t> > tcpStore_+---+   |     |  |  |      |                                 |                               |
|   |                                                        |   |     |  |  |      |    storeSocket_                 |                               |
|   |   TCPStore.masterListenSocket_                         |   |     |  |  |      |                                 |                               |
|   |                                                        |   |     |  |  |      +---------------------------------+                               |
|   |   +-----------------------------------------------+    |   |     |  |  |                                                                        |
|   |   |  run                                          |    |   |     |  |  |  2   +---------------------------------+                               |
|   |   |                                               |    |   | <--------------+ |                                 |                               |
|   |   |    queryFds     query                         |    |   |     |  |  |      | watchKey(key, callback) +-------------------------------+       |
|   |   |                                               |    |   |     |  |  |      |                                 |        3              |       |
|   |   |    setHandler   getHandler                    |    |   |     |  |  |      |    listenSocket_                |                       |       |
|   |   |                                               |    |   |     |  |  |      |                                 |                       |       |
|   |   +-----------------------------------------------+    |   |     |  |  |      |                                 |                       |       |
|   |                                                        |   |     |  |  |      +---------------------------------+                       |       |
|   +------------------------------------------------------------+     |  |  |                                                                |       |
|                                                            |         |  |  |                                                                |       |
|                                                            |         |  |  |                                                                |       |
|                                                            |         |  |  |   +----------------------------------------------------------------+   |
|                                                            |         |  |  |   | TCPStoreWorkerDaemon                                       |   |   |
|                                                            |         |  |  |   |                                                            |   |   |
|                                                            |         |  |  |   |   unordered_map<string, WatchKeyCallback> keyToCallbacks_  |   |   |
|                                                            |         |  |  |   |                                                            |   |   |
|                                                            |         |  |  |   |   TCPStore.listenSocket_                              +----+   |   |
|                                                            |         |  |  |   |                                                       |        |   |
|                                                            |         |  |  |   |  +----------------------------------------------------------+  |   |
|                                                            |         |  |  |   |  | run                                                |     |  |   |
|                                                            |     4   |  |  |   |  |                                                    |     |  |   |
|                                                            +--------------------->+                                                    v     |  |   |
|                                                                      |  |  |   |  |       callbackHandler +-----> keyToCallbacks_(callback)  |  |   |
|                                                                      |  |  |   |  |                                                          |  |   |
|                                                                      |  |  |   |  +----------------------------------------------------------+  |   |
|                                                                      |  |  |   +----------------------------------------------------------------+   |
+----------------------------------------------------------------------+  +  +------------------------------------------------------------------------+

So far, we have walked through the two concepts of initialization method and Store; in the end it is the Store that does the real work during initialization. Through the analysis of TCPStore we have also seen which capabilities a Store must provide, such as setting key-value pairs and watching for changes to a key; it is these capabilities that let several processes learn of each other's existence.
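
To close the loop, here is a minimal end-to-end rendezvous sketch built only on the operations analyzed above (a sketch of ours; names, address and port are purely illustrative):

from datetime import timedelta
import torch.distributed as dist

def rendezvous(rank: int, world_size: int) -> dist.TCPStore:
    # Every rank dials the same host/port; rank 0 also runs TCPStoreMasterDaemon.
    store = dist.TCPStore(
        "127.0.0.1", 29500, world_size=world_size, is_master=(rank == 0),
        timeout=timedelta(seconds=60))
    if rank == 0:
        store.set("addr0", "10.0.0.1:12345")  # SET  -> setHandler wakes waiters
    else:
        store.wait(["addr0"])                 # WAIT -> waitHandler blocks this rank
        print(store.get("addr0"))             # GET  -> getHandler, prints b'10.0.0.1:12345'
    return store

This wait/set/get handshake is the whole rendezvous in miniature.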

In the next article we will introduce the concept of process groups. Stay tuned.
