[source code analysis] deep learning distributed training framework horovod --- start on spark

Posted by updwebmaster on Tue, 25 Jan 2022 09:59:13 +0100

[source code analysis] deep learning distributed training framework horovod (9) - start on spark

0x00 summary

Horovod is an easy-to-use, high-performance distributed training framework released by Uber in 2017, and it has been widely adopted in industry.

This series helps you understand Horovod through source code analysis. This group of articles covers how Horovod runs on Spark. This ninth article describes how Horovod is started on Spark.

Other articles in this series are as follows:

[source code analysis] deep learning distributed training framework Horovod - (1) basic knowledge

[source code analysis] deep learning distributed training framework horovod (2) - from the perspective of users

[source code analysis] deep learning distributed training framework horovod (3) - what is behind Horovodrun

[source code analysis] deep learning distributed training framework horovod (4) - Network Foundation & Driver

[source code analysis] deep learning distributed training framework horovod (5) - Fusion Framework

[source code analysis] deep learning distributed training framework horovod (6) - background thread architecture

[source code analysis] deep learning distributed training framework horovod (7) - distributed optimizer

[source code analysis] deep learning distributed training framework horovod (8) - on spark

0x01 overall architecture diagram

As usual, we start with the overall architecture so that the rest of the article can be followed against it.

In general, the overall logic of Horovod on Spark is divided into the following stages:

  • Start the SparkDriverService and use _make_spark_thread to launch the Spark tasks. Horovod then waits for the tasks to finish starting;
  • Spark tasks are started in the Spark executors in multiple threads, and one SparkTaskService runs in each task. Each SparkTaskService registers with the SparkDriverService in the host process and waits for the instruction to run the next step;
  • Once Horovod has heard from all tasks (that is, all of them have registered), it notifies each task to enter the next stage;
  • Horovod calls mpi_run (which also uses mpirun_rsh.py) to start an orted process on each Spark executor, bringing up the MPI cluster;
  • orted runs the training code on each executor;

Let's take a look at how to start.

0x02 phase I: Horovod start

The main logic of this part is to start the SparkDriverService, use _make_spark_thread to launch the Spark tasks, and then wait for the tasks to finish starting.

2.1 Driver service: SparkDriverService

SparkDriverService inherits from driver_service.BasicDriverService, so it starts a socket server internally for network communication.

Horovod uses SparkDriverService to interact with the Spark executors, through the SparkTaskService running in each of them. For example, it collects information and tells Spark to start the training job. This is an RPC mechanism.

To understand what SparkDriverService does, look at the requests it handles internally. For example:

  • CodeRequest: sent by a SparkTaskService to request the user code;
  • TaskHostHashIndicesRequest: returns the indices of the tasks running on a given host, identified by its host hash;
  • TaskIndexByRankRequest: returns the task index for a rank;
  • SetLocalRankToRankRequest: looks up the task index for a host and local rank, and records that index for the given rank;
  • WaitForTaskShutdownRequest: waits for task shutdown;

It is similar to the HorovodRunDriverService described earlier.

Its member variable _fn is the training function. Later, when a SparkTaskService requests the code, the driver sends _fn back directly in a CodeResponse. This solves the problem of distributing the code.

class SparkDriverService(driver_service.BasicDriverService):
    NAME = 'driver service'

    def __init__(self, initial_np, num_proc, fn, args, kwargs, key, nics):
        super(SparkDriverService, self).__init__(num_proc,
                                                 SparkDriverService.NAME,
                                                 key, nics)
        self._initial_np = initial_np
        self._fn = fn # Save user code
        self._args = args # User parameters
        self._kwargs = kwargs 
        self._key = key
        self._nics = nics # Network card information
        self._ranks_to_indices = {}
        self._spark_job_failed = False
        self._lock = threading.Lock()
        self._task_shutdown = threading.Event()

    def _handle(self, req, client_address):

        if isinstance(req, TaskHostHashIndicesRequest): # Get task host address
            return TaskHostHashIndicesResponse(self._task_host_hash_indices[req.host_hash])

        if isinstance(req, SetLocalRankToRankRequest): # Get rank information from local rank
            self._lock.acquire()

            try:
                # get index for host and local_rank
                indices = self._task_host_hash_indices[req.host]
                index = indices[req.local_rank]

                values = list(self._ranks_to_indices.values())
                prev_pos = values.index(index) if index in values else None
                if prev_pos is not None:
                    prev_rank = list(self._ranks_to_indices.keys())[prev_pos]
                    del self._ranks_to_indices[prev_rank]

                # memorize rank's index
                self._ranks_to_indices[req.rank] = index
            finally:
                self._lock.release()
            return SetLocalRankToRankResponse(index)

        if isinstance(req, TaskIndexByRankRequest): # Get task index from rank
            self._lock.acquire()
            try:
                return TaskIndexByRankResponse(self._ranks_to_indices[req.rank])
            finally:
                self._lock.release()

        if isinstance(req, CodeRequest): # SparkTaskService will be used to request user code
            return CodeResponse(self._fn, self._args, self._kwargs)

        if isinstance(req, WaitForTaskShutdownRequest): # Wait for the end of the task
            self._task_shutdown.wait()
            return network.AckResponse()

        return super(SparkDriverService, self)._handle(req, client_address)
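The sketch below isolates this dispatch pattern. It assumes nothing about Horovod's real message classes. CodeRequest, CodeResponse and TaskIndexByRankRequest are hypothetical dataclasses that only borrow the names, and ToyDriverService is a made-up class. The "network" is a direct method call. The sketch shows how isinstance-based dispatch works and how the stored _fn, _args and _kwargs return to the task that asked for them.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for Horovod's request/response messages.
@dataclass
class CodeRequest:
    pass


@dataclass
class CodeResponse:
    fn: object
    args: tuple
    kwargs: dict


@dataclass
class TaskIndexByRankRequest:
    rank: int


class ToyDriverService:
    """Dispatches on request type, in the style of SparkDriverService._handle."""

    def __init__(self, fn, args, kwargs):
        self._fn = fn                # user training function to hand out
        self._args = args
        self._kwargs = kwargs
        self._ranks_to_indices = {}  # rank -> Spark task index

    def remember_rank(self, rank, index):
        self._ranks_to_indices[rank] = index

    def handle(self, req):
        if isinstance(req, CodeRequest):
            # Send the stored user code back to the task that asked for it.
            return CodeResponse(self._fn, self._args, self._kwargs)
        if isinstance(req, TaskIndexByRankRequest):
            return self._ranks_to_indices[req.rank]
        raise NotImplementedError(type(req).__name__)


def train(x, scale=1):
    # Stand-in for the user's training function.
    return x * scale


driver = ToyDriverService(train, (21,), {'scale': 2})
driver.remember_rank(0, 3)

resp = driver.handle(CodeRequest())                   # what a task would receive
print(resp.fn(*resp.args, **resp.kwargs))             # 42
print(driver.handle(TaskIndexByRankRequest(rank=0)))  # 3
```

In the real service, the CodeResponse is serialized and sent over the socket. The user function therefore has to be serializable.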

2.2 start spark task: _make_spark_thread

In horovod.spark.run, _make_spark_thread creates a thread. The key code is:

mapper = _make_mapper(driver.addresses(), settings, use_gloo, is_elastic)
result = procs.mapPartitionsWithIndex(mapper).collect()

mapPartitionsWithIndex makes Spark run the mapper function on multiple executors and collects the results.

_make_spark_thread creates settings.num_proc Spark tasks (settings.max_np in elastic mode). Each task runs the mapper, which calls _task_fn, and the outer run function waits for these results. If you actually needed to work with an RDD's data, you might use foreachPartition instead, so that each node holds one partition of the RDD in memory.

def _make_spark_thread(spark_context, spark_job_group, driver, result_queue,
                       settings, use_gloo, is_elastic):
    """Creates `settings.num_proc` Spark tasks in a parallel thread."""
    
    def run_spark():
        """Creates `settings.num_proc` Spark tasks, each executing `_task_fn` and waits for them to terminate."""
        try:
            spark_context.setJobGroup(spark_job_group, "Horovod Spark Run", interruptOnCancel=True)
            procs = spark_context.range(0, numSlices=settings.max_np if settings.elastic else settings.num_proc)
            # We assume that folks caring about security will enable Spark RPC encryption,
            # thus ensuring that key that is passed here remains secret.
            mapper = _make_mapper(driver.addresses(), settings, use_gloo, is_elastic)
            # Make Spark run the mapper function among multiple executors and get the running results
            result = procs.mapPartitionsWithIndex(mapper).collect()
            result_queue.put(result)
        except:
            driver.notify_spark_job_failed()
            raise

    spark_thread = in_thread(target=run_spark, daemon=False)
    return spark_thread
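You can try the thread-plus-collect pattern without a Spark cluster. The sketch below is a stand-in built on several assumptions. fake_map_partitions_with_index, make_mapper and make_spark_thread are hypothetical helpers, and a thread pool plays the role of the executors. The sketch mirrors one detail of the code above. As we read it, spark_context.range(0, numSlices=n) creates an RDD with no elements but n partitions. Each partition's iterator is therefore empty, and only the partition index matters to the mapper.

```python
import queue
import threading
from concurrent.futures import ThreadPoolExecutor


def fake_map_partitions_with_index(num_slices, mapper):
    """Call mapper(index, iterator) once per (empty) partition and concatenate
    the yielded values, like rdd.mapPartitionsWithIndex(mapper).collect()."""
    def run_partition(index):
        return list(mapper(index, iter(())))

    with ThreadPoolExecutor(max_workers=num_slices) as pool:
        futures = [pool.submit(run_partition, i) for i in range(num_slices)]
        results = []
        for future in futures:  # keep partition order, as collect() does
            results.extend(future.result())
        return results


def make_mapper(task_fn, key):
    # Like _make_mapper: capture arguments in a closure, yield one value per partition.
    def _mapper(index, _partition):
        yield task_fn(index, key)
    return _mapper


def make_spark_thread(num_proc, mapper, result_queue, on_failure):
    # Like _make_spark_thread: run the "job" in a separate non-daemon thread.
    def run():
        try:
            result_queue.put(fake_map_partitions_with_index(num_proc, mapper))
        except Exception:
            on_failure()  # plays the role of driver.notify_spark_job_failed()
            raise

    thread = threading.Thread(target=run, daemon=False)
    thread.start()
    return thread


def task_fn(index, key):
    return (index, key)


results = queue.Queue()
failed = []
t = make_spark_thread(3, make_mapper(task_fn, 'secret'), results,
                      lambda: failed.append(True))
t.join()
collected = results.get()
print(collected)  # [(0, 'secret'), (1, 'secret'), (2, 'secret')]
```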

2.3 wait for spark task to start and finish

After launching the Spark tasks, the Horovod main process makes the following call and waits for all tasks to start:

# wait for all tasks to register, notify them and initiate task-to-task address registration
_notify_and_register_task_addresses(driver, settings)

In the run function, after _make_spark_thread, the Horovod main process calls _notify_and_register_task_addresses. That function calls driver.wait_for_initial_registration(settings.start_timeout) to wait for all the tasks.

It waits until all num_proc tasks have registered. The main Horovod process continues only when every Spark task is ready.

2.3.1 _notify_and_register_task_addresses

In the Horovod main process, _notify_and_register_task_addresses waits for these Spark tasks to register. It does so by calling driver.wait_for_initial_registration(settings.start_timeout).

Meanwhile, after sending its registration request, each Spark task calls task.wait_for_initial_registration. It uses this call to wait for Horovod to signal the start of the next phase.

_notify_and_register_task_addresses in the Horovod main process is also fairly involved:

  • Call driver.wait_for_initial_registration to wait for the tasks to register. It has to wait for num_proc tasks;
  • Use notify_and_register to register the tasks and notify each task to start the next step;

The specific codes are as follows:

def _notify_and_register_task_addresses(driver, settings, notify=True):
    # wait for num_proc tasks to register
    # Wait for task to register. You need to wait for num_proc tasks
    driver.wait_for_initial_registration(settings.start_timeout) 

    def notify_and_register(index): # Register tasks and notify each task to start the next step
        task_client = task_service.SparkTaskClient(index,
                                                   driver.task_addresses_for_driver(index),
                                                   settings.key, settings.verbose)

        if notify:
            task_client.notify_initial_registration_complete()

        next_task_index = (index + 1) % settings.num_proc
        next_task_addresses = driver.all_task_addresses(next_task_index)
        task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses)
        driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses)

    for index in driver.task_indices():
        in_thread(notify_and_register, (index,))  # notify and register each task in its own thread

    driver.wait_for_task_to_task_address_updates(settings.start_timeout)
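You can sketch the ring step at the end of _notify_and_register_task_addresses on plain data. Each task index probes task (index + 1) % num_proc, and the driver records which of the next task's addresses were reachable. Several parts of the sketch are assumptions rather than Horovod's API. register_ring is a hypothetical helper, and addresses are simplified to (interface, ip) pairs. The reachable predicate stands in for the real probing done through GetTaskToTaskAddressesRequest.

```python
def register_ring(all_task_addresses, reachable):
    """Sketch of the ring step in _notify_and_register_task_addresses.

    all_task_addresses: task index -> list of (interface, ip) pairs it registered.
    reachable(index, addr): hypothetical predicate standing in for task `index`
        probing one address of the next task.
    Returns: next task index -> addresses its predecessor could reach.
    """
    num_proc = len(all_task_addresses)
    task_to_task = {}
    for index in range(num_proc):
        next_index = (index + 1) % num_proc  # ring neighbour
        task_to_task[next_index] = [
            addr for addr in all_task_addresses[next_index] if reachable(index, addr)
        ]
    return task_to_task


addresses = {
    0: [('eth0', '10.0.0.1'), ('docker0', '172.17.0.1')],
    1: [('eth0', '10.0.0.2'), ('docker0', '172.17.0.1')],
    2: [('eth0', '10.0.0.3')],
}
# Assumption for the example: only eth0 addresses are routable between hosts.
ring = register_ring(addresses, lambda index, addr: addr[0] == 'eth0')
print(ring[0])  # [('eth0', '10.0.0.1')], as probed by task 2
```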

Here we only look at the first step: waiting for registration.

2.3.2 driver.wait_for_initial_registration

Here, SparkDriverService first waits for all Spark executors to register.

BasicDriverService(network.BasicService) contains the code below. The main Horovod process continues only after all num_proc tasks have registered, that is, once all Spark tasks are ready.

The key is the loop while len(self._all_task_addresses) < self._num_proc. It waits until the number of entries in self._all_task_addresses reaches self._num_proc.

class BasicDriverService(network.BasicService):

    def wait_for_initial_registration(self, timeout):
        self._wait_cond.acquire()
        try:
            # Wait until len(self._all_task_addresses) reaches self._num_proc
            while len(self._all_task_addresses) < self._num_proc:
                self._wait_cond.wait(timeout.remaining())
                timeout.check_time_out_for('tasks to start')
        finally:
            self._wait_cond.release()
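Barrier 1 is an ordinary condition-variable wait with a deadline. The sketch below reproduces the loop above with threading.Condition. The Timeout class is a hypothetical minimal stand-in for the timeout object Horovod passes in, and it models only remaining() and check_time_out_for(). ToyDriver is also made up for illustration.

```python
import threading
import time


class Timeout:
    """Hypothetical stand-in for the timeout object used above."""

    def __init__(self, seconds):
        self._deadline = time.monotonic() + seconds

    def remaining(self):
        return max(0.0, self._deadline - time.monotonic())

    def check_time_out_for(self, activity):
        if time.monotonic() >= self._deadline:
            raise TimeoutError('Timed out waiting for ' + activity)


class ToyDriver:
    def __init__(self, num_proc):
        self._num_proc = num_proc
        self._all_task_addresses = {}
        self._wait_cond = threading.Condition()

    def register_task(self, index, task_addresses):
        # Same idea as the RegisterTaskRequest branch: record, then wake waiters.
        with self._wait_cond:
            self._all_task_addresses[index] = task_addresses
            self._wait_cond.notify_all()

    def wait_for_initial_registration(self, timeout):
        # Barrier 1: block until num_proc tasks have registered or time runs out.
        with self._wait_cond:
            while len(self._all_task_addresses) < self._num_proc:
                self._wait_cond.wait(timeout.remaining())
                timeout.check_time_out_for('tasks to start')


driver = ToyDriver(num_proc=3)
for i in range(3):
    # Each "task" registers a little later than the previous one.
    threading.Timer(0.05 * (i + 1), driver.register_task,
                    args=(i, ['10.0.0.%d' % i])).start()
driver.wait_for_initial_registration(Timeout(5))
print(sorted(driver._all_task_addresses))  # [0, 1, 2]
```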

2.4 waiting

The waiting logic deserves a closer look. The figure at the end of this section shows the details.

There are two kinds of wait_for_initial_registration here, and you can think of them as two barriers:

  • Barrier 1: SparkDriverService waits for all SparkTaskServices to be ready;
  • Barrier 2: all SparkTaskServices need to start running together, so each SparkTaskService waits at barrier 2. SparkDriverService then notifies all of them to proceed together;

2.4.1 Barrier 1 in Driver

In the run function, after _make_spark_thread, the Horovod main process calls _notify_and_register_task_addresses. That function calls driver.wait_for_initial_registration(settings.start_timeout) to wait for all the tasks.

It waits until all num_proc tasks have registered. The main Horovod process continues only when every Spark task is ready. The key line is:

while len(self._all_task_addresses) < self._num_proc

In other words, it waits until the number of entries in self._all_task_addresses reaches self._num_proc.

def wait_for_initial_registration(self, timeout):
    self._wait_cond.acquire()
    try:
        while len(self._all_task_addresses) < self._num_proc:
            self._wait_cond.wait(timeout.remaining())
            timeout.check_time_out_for('tasks to start')
    finally:
        self._wait_cond.release()

When BasicDriverService receives a registration request from a Spark executor, it handles the request. The most important line is:

self._all_task_addresses[req.index] = req.task_addresses

Once all Spark executors have registered, the wait succeeds.

2.4.2 Barrier 2 in task

Each Spark task runs _task_fn, so everything below happens inside a Spark task. This also shows the overall flow of a Spark task:

  • First, call register_task;
  • Next, call task.wait_for_initial_registration(settings.start_timeout);
  • Then call wait_for_command_termination to wait for the end;

task.wait_for_initial_registration waits until self._initial_registration_complete becomes True. The driver sets this flag only after all register_task registrations have completed.

Each Spark executor has a SparkTaskService, so each Spark task has its own _initial_registration_complete flag.

The main process of horovod.spark.run then sets _initial_registration_complete on each SparkTaskService, one by one.

In other words, once registration is complete, the driver notifies each SparkTaskService by setting that service's own _initial_registration_complete. Only then can the SparkTaskService actually proceed.

2.4.3 overall waiting process

The figure below shows the overall waiting process, and the numbers give the order of execution:

  1. SparkDriverService calls driver.wait_for_initial_registration to wait for the SparkTaskServices to register. This is barrier 1;
  2. SparkTaskService 1 registers, then calls its own task.wait_for_initial_registration to wait for Horovod to signal the next phase. This is barrier 2;
  3. SparkTaskService 2 registers, then calls its own task.wait_for_initial_registration to wait for Horovod to signal the next phase. This is barrier 2;
  4. When the horovod.spark.run main process sees that all tasks have registered, barrier 1 is released. The main process then sets _initial_registration_complete on each SparkTaskService in turn. The two SparkTaskServices can continue with steps 5 and 6 only after step 4 is complete;
  5. SparkTaskService 1 passes barrier 2 and continues;
  6. SparkTaskService 2 passes barrier 2 and continues;

    SparkTaskSerivce 1          SparkTaskSerivce 2            SparkDriverSerivce

            +                           +                             +
            |                           |                             |
            |                           |                             |
            |                           |                             |
            |                           |                             |   1
            |                           |                             |
            |                           |                             |
            |                           |                             v
            |                           |
            |                           |         +--------------------------------------+
            |                           |         | barrier 1                            |
            |                           |   2     |                                      |
            |          3                +-------> |                                      |
            |                           |         |                                      |
            +-----------------------------------> | driver.wait_for_initial_registration |
            |                           |         |                                      |
            |                           |         |                                      |
            |                           |         |                                      |
            |                           |         +--------------------+-----------------+
            |                           |                              |
            |                           |                              |
+-----------+----------------------+    |                  4           |
|barrier 2                         | <---------------------------------+
|                                  |    |                              |
|task.wait_for_initial_registration|    |                              |
|                                  |    |                              |
+-----------+----------------------+    |                              |
            |                           |                              |
            |             +-------------+----------------------+       |
            |             | barrier 2                          |   4   |
            | 6           |                                    +<------+
            |             | task.wait_for_initial_registration |       |
            |             |                                    |       |
            |             +-------------+----------------------+       |
            |                           |                              |
            |                           |                              |
            |                           |  5                           |
            |                           |                              |
            v                           v                              v
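You can simulate how the two barriers interact in a single process. This is a hypothetical sketch, and ToyTask, ToyDriver and run are names used only for illustration. Threads stand in for Spark tasks, and direct method calls stand in for RegisterTaskRequest and NotifyInitialRegistrationCompleteRequest. Condition.wait_for replaces the explicit while-loops. The step numbers in the comments follow the numbered list above.

```python
import threading


class ToyTask:
    """Barrier 2: each task waits on its own flag, which only the driver sets."""

    def __init__(self, index):
        self.index = index
        self._cond = threading.Condition()
        self._initial_registration_complete = False

    def notify_initial_registration_complete(self):
        # Role of handling NotifyInitialRegistrationCompleteRequest.
        with self._cond:
            self._initial_registration_complete = True
            self._cond.notify_all()

    def wait_for_initial_registration(self, timeout_s):
        with self._cond:
            done = self._cond.wait_for(lambda: self._initial_registration_complete, timeout_s)
        if not done:
            raise TimeoutError('task %d was never notified' % self.index)


class ToyDriver:
    """Barrier 1: the driver waits until every task has registered."""

    def __init__(self, num_proc):
        self._num_proc = num_proc
        self._cond = threading.Condition()
        self._tasks = {}

    def register_task(self, task):
        # Role of handling RegisterTaskRequest.
        with self._cond:
            self._tasks[task.index] = task
            self._cond.notify_all()

    def wait_for_initial_registration(self, timeout_s):
        with self._cond:
            done = self._cond.wait_for(lambda: len(self._tasks) >= self._num_proc, timeout_s)
        if not done:
            raise TimeoutError('tasks did not register in time')

    def notify_all_tasks(self):
        # Step 4: release barrier 2 for each task, one by one.
        for task in list(self._tasks.values()):
            task.notify_initial_registration_complete()


def run(num_proc=2, timeout_s=5.0):
    driver = ToyDriver(num_proc)
    log, log_lock = [], threading.Lock()

    def task_main(index):
        task = ToyTask(index)
        driver.register_task(task)                     # steps 2 and 3
        task.wait_for_initial_registration(timeout_s)  # barrier 2
        with log_lock:
            log.append(('task continues', index))      # steps 5 and 6

    threads = [threading.Thread(target=task_main, args=(i,)) for i in range(num_proc)]
    for t in threads:
        t.start()
    driver.wait_for_initial_registration(timeout_s)    # step 1, barrier 1
    with log_lock:
        log.append(('all registered', None))
    driver.notify_all_tasks()                          # step 4
    for t in threads:
        t.join()
    return log


log = run()
print(log[0])  # ('all registered', None)
```

No task can log before the driver has seen every registration. Each task is released only by the driver's explicit notification, which happens after barrier 1.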

Next, we look in detail at how the tasks start and at what the driver does afterwards.

0x03 phase II: Spark Task start

In this stage, we will introduce the startup process of Spark Task in detail.

This part works as follows. Spark tasks are started in multiple threads in the Spark executors, and each Spark task runs the _task_fn function. _task_fn runs a SparkTaskService. The SparkTaskService registers with the SparkDriverService in the Horovod main process and waits for the instruction to start the next step.

At this point, the program running inside the Spark executor is the SparkTaskService, not the training program. Let's see how SparkTaskService is started and run in a Spark executor.

3.1 specific spark start logic: _task_fn

From the thread created earlier, Horovod uses _make_mapper to make Spark run _task_fn.

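def _make_mapper(driver_addresses, settings, use_gloo, is_elastic):
    # capture the secret key so it can be passed to each task explicitly;
    # _task_fn puts it back into settings.key
    key = settings.key

    def _mapper(index, _):
        yield _task_fn(index, driver_addresses, key, settings, use_gloo, is_elastic)

    return _mapper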

_task_fn runs inside each Spark task and registers that task with Horovod. In other words, it starts a SparkTaskService in each Spark task (executor).

Note that these SparkTaskServices run in the Spark executors and talk to the SparkDriverService in Horovod over the network.

As you can see, the overall logic of _task_fn is:

  • Start a SparkTaskService;
  • Register with the driver in Horovod via driver_service.SparkDriverClient.register_task;
  • Call task.wait_for_initial_registration(settings.start_timeout) to wait for the signal to start the next step;
  • Once the next step has started, call task.wait_for_command_termination() to wait for the end;

The details are as follows:

def _task_fn(index, driver_addresses, key, settings, use_gloo, is_elastic):
    settings.key = key
    hosthash = host_hash(salt='{}-{}'.format(index, time.time()) if is_elastic else None)
    os.environ['HOROVOD_HOSTNAME'] = hosthash
    # Start the SparkTaskService. The SparkTaskService itself includes a socket server that can interact with the driver
    task = task_service.SparkTaskService(index, settings.key, settings.nics,...)
    try:
        driver_client = driver_service.SparkDriverClient(driver_addresses, settings.key, settings.verbose)
        # Register with Driver in horovod
        driver_client.register_task(index, task.addresses(), hosthash)

        # Still inside the Spark task but outside SparkTaskService: this part only does auxiliary work and finally just waits
        if not is_elastic:
            # Wait for the start indication of the next start
            task.wait_for_initial_registration(settings.start_timeout)
            task_indices_on_this_host = driver_client.task_host_hash_indices(hosthash)
            local_rank_zero_index = task_indices_on_this_host[0]
        else:
            local_rank_zero_index = None

        if is_elastic:
            ......  # elastic mode is covered in a later article
        elif use_gloo or index == local_rank_zero_index:
            # Either Gloo or first task with MPI.
            # Use Gloo or the first task of MPI to make this task operate
            task.wait_for_command_start(settings.start_timeout)
            # Wait for the end
            task.wait_for_command_termination()
        else:
            # The other tasks with MPI need to wait for the first task to finish.
            # Let other tasks wait for the end of the first task
            first_task_addresses = driver_client.all_task_addresses(local_rank_zero_index)
            first_task_client = \
                task_service.SparkTaskClient(local_rank_zero_index,
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            # Wait until the first task's command terminates
            first_task_client.wait_for_command_termination()

        return task.fn_result()
    finally:
        task.shutdown()
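The non-elastic branching at the end of _task_fn comes down to a small decision, which the hypothetical task_role helper below mirrors. With Gloo, every task runs the command. With MPI, only the task whose index comes first among the tasks sharing its host hash runs the command. The others block until that task terminates. The returned labels are made up and are not Horovod values.

```python
def task_role(index, task_indices_on_this_host, use_gloo):
    """Which non-elastic branch of _task_fn a task would take (sketch).

    task_indices_on_this_host: sorted indices of the tasks sharing this
    task's host hash, as driver_client.task_host_hash_indices(hosthash)
    would return them.
    """
    local_rank_zero_index = task_indices_on_this_host[0]
    if use_gloo or index == local_rank_zero_index:
        # wait_for_command_start(...) then wait_for_command_termination()
        return ('run command', index)
    # Other MPI tasks block until the first task on this host terminates.
    return ('wait for task', local_rank_zero_index)


host_hash_indices = {'host-a': [0, 2], 'host-b': [1, 3]}
for index, host in [(0, 'host-a'), (1, 'host-b'), (2, 'host-a'), (3, 'host-b')]:
    print(index, task_role(index, host_hash_indices[host], use_gloo=False))
# 0 ('run command', 0)
# 1 ('run command', 1)
# 2 ('wait for task', 0)
# 3 ('wait for task', 1)
```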

3.2 SparkTaskService

To re-emphasize, look at this line:

task = task_service.SparkTaskService(index, settings.key, settings.nics,...)

Each _task_fn creates a SparkTaskService. That is, each Spark executor runs one (or more) SparkTaskServices, each working inside a Spark task.

3.2.1 SparkTaskService definition

SparkTaskService is defined as follows. Because it inherits from BasicTaskService, it eventually starts a socket server internally to interact with the SparkDriverService in Horovod:

class SparkTaskService(task_service.BasicTaskService):
    NAME_FORMAT = 'task service #%d'

    def __init__(self, index, key, nics, minimum_command_lifetime_s, verbose=0):
        # on a Spark cluster we need our train function to see the Spark worker environment
        # this includes PYTHONPATH, HADOOP_TOKEN_FILE_LOCATION and _HOROVOD_SECRET_KEY
        env = os.environ.copy()

        # we inject the secret key here
        env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(key)

        # we also need to provide the current working dir to mpirun_exec_fn.py
        env['HOROVOD_SPARK_WORK_DIR'] = os.getcwd()

        super(SparkTaskService, self).__init__(SparkTaskService.NAME_FORMAT % index,
                                               index, key, nics, env, verbose)
        self._key = key
        self._minimum_command_lifetime_s = minimum_command_lifetime_s
        self._minimum_command_lifetime = None

3.2.2 basic functions

The basic functions of SparkTaskService are as follows:

  • _run_command starts the training job inside Spark;
  • _handle processes GetTaskToTaskAddressesRequest to obtain task addresses, and ResourcesRequest to return resources;
  • _get_resources returns the Spark resources;
  • wait_for_command_termination waits for the command to finish;

The specific codes are as follows:

def _run_command(self, command, env, event,
                 stdout=None, stderr=None, index=None,
                 prefix_output_with_timestamp=False):
    # Start the training job in spark
    super(SparkTaskService, self)._run_command(command, env, event,
                                               stdout, stderr, index,
                                               prefix_output_with_timestamp)

    if self._minimum_command_lifetime_s is not None:
        self._minimum_command_lifetime = timeout.Timeout(self._minimum_command_lifetime_s,
                                                         message='Just measuring runtime')

def _handle(self, req, client_address):
    # Return resource
    if isinstance(req, ResourcesRequest):
        return ResourcesResponse(self._get_resources())

    # Get task address  
    if isinstance(req, GetTaskToTaskAddressesRequest):
        next_task_index = req.task_index
        next_task_addresses = req.all_task_addresses
        # We request interface matching to weed out all the NAT'ed interfaces.
        next_task_client = \
            SparkTaskClient(next_task_index, next_task_addresses,
                            self._key, self._verbose,
                            match_intf=True)
        return GetTaskToTaskAddressesResponse(next_task_client.addresses())

    return super(SparkTaskService, self)._handle(req, client_address)

def _get_resources(self):
    # Return spark resource
    if LooseVersion(pyspark.__version__) >= LooseVersion('3.0.0'):
        task_context = pyspark.TaskContext.get()
        if task_context:
            return task_context.resources()
        else:
            print("Not running inside Spark worker, no resources available")
    return dict()

def wait_for_command_termination(self):
    """
    Waits for command termination. Ensures this method takes at least
    self._minimum_command_lifetime_s seconds to return after command started.
    """
    try:
        # Wait for the end of command execution
        return super(SparkTaskService, self).wait_for_command_termination()
    finally:
        # command terminated, make sure this method takes at least
        # self._minimum_command_lifetime_s seconds after command started
        # the client that started the command needs some time to connect again
        # to wait for the result (see horovod.spark.driver.rsh).
        if self._minimum_command_lifetime is not None:
            time.sleep(self._minimum_command_lifetime.remaining())
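The minimum-lifetime rule in wait_for_command_termination follows a "pad to at least N seconds" pattern. Below is a simplified sketch of it in a hypothetical MinLifetimeTask class. It differs from SparkTaskService in two ways: it runs the command synchronously, and it starts the clock just before the command instead of right after launching it. Its finally block pads the call in the same way, using time.monotonic instead of Horovod's timeout helper.

```python
import time


class MinLifetimeTask:
    """Hypothetical sketch of the _minimum_command_lifetime_s idea."""

    def __init__(self, minimum_s=None):
        self._minimum_s = minimum_s
        self._started_at = None
        self._result = None

    def run_command(self, command):
        # Simplification: run synchronously and start the clock just before it.
        self._started_at = time.monotonic()
        self._result = command()

    def wait_for_command_termination(self):
        try:
            return self._result
        finally:
            # Pad this call so it returns no earlier than minimum_s after start.
            if self._minimum_s is not None and self._started_at is not None:
                elapsed = time.monotonic() - self._started_at
                time.sleep(max(0.0, self._minimum_s - elapsed))


task = MinLifetimeTask(minimum_s=0.2)
start = time.monotonic()
task.run_command(lambda: 'done')
result = task.wait_for_command_termination()
elapsed = time.monotonic() - start
print(result, round(elapsed, 1))
```

According to the comment in the original code, the padding gives the client that started the command time to reconnect and wait for the result.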

3.3 registering tasks

The next step is to register the task with the Driver.

driver_client.register_task(index, task.addresses(), hosthash)

3.3.1 send registration request

Registration works as follows. It calls the _send function from network.py, which exchanges messages over a socket between the Spark executor and the Horovod driver:

class BasicDriverClient(network.BasicClient):

    def register_task(self, index, task_addresses, host_hash):
        self._send(RegisterTaskRequest(index, task_addresses, host_hash))

3.3.2 Driver handling

Now let's look at the driver running in Horovod, as a preview of the next section.

When BasicDriverService receives a RegisterTaskRequest, it processes the request. The most important line is:

self._all_task_addresses[req.index] = req.task_addresses

As a result, the number of entries in self._all_task_addresses grows.

As mentioned earlier, the key to the driver's wait_for_initial_registration in Horovod is:

while len(self._all_task_addresses) < self._num_proc

Once the number of entries in self._all_task_addresses reaches self._num_proc, driver.wait_for_initial_registration returns and execution proceeds.

The code that handles RegisterTaskRequest is shown below. BasicDriverService keeps several member variables to hold the information it needs. These were explained in detail in the earlier article [source code analysis] deep learning distributed training framework horovod (4) - Network Foundation & Driver. The RegisterTaskRequest branch of _handle updates these member variables:

class BasicDriverService(network.BasicService):

    def _handle(self, req, client_address):
        if isinstance(req, RegisterTaskRequest):
            self._wait_cond.acquire()
            try:

                self._all_task_addresses[req.index] = req.task_addresses
                # Just use source address for service for fast probing.
                self._task_addresses_for_driver[req.index] = \
                    self._filter_by_ip(req.task_addresses, client_address[0])
                  
                # Remove host hash earlier registered under this index.
                if req.index in self._task_index_host_hash:
                    earlier_host_hash = self._task_index_host_hash[req.index]
                    if earlier_host_hash != req.host_hash:
                        self._task_host_hash_indices[earlier_host_hash].remove(req.index)

                # Make index -> host hash map.
                self._task_index_host_hash[req.index] = req.host_hash

                # Make host hash -> indices map.
                if req.host_hash not in self._task_host_hash_indices:
                    self._task_host_hash_indices[req.host_hash] = []
                self._task_host_hash_indices[req.host_hash].append(req.index)
                # TODO: this sorting is a problem in elastic horovod
                self._task_host_hash_indices[req.host_hash].sort()
            finally:
                self._wait_cond.notify_all()
                self._wait_cond.release()
                
            return network.AckResponse()
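You can replay the bookkeeping in the RegisterTaskRequest branch on plain dictionaries. In the hypothetical register helper below, state holds the three maps updated above. The sketch deliberately differs from the original in one way: a guard skips re-appending an index that is already listed under the same host hash.

```python
def new_state():
    return {
        'all_task_addresses': {},      # index -> registered addresses
        'task_index_host_hash': {},    # index -> host hash
        'task_host_hash_indices': {},  # host hash -> sorted indices
    }


def register(state, index, task_addresses, host_hash):
    """Replay of the RegisterTaskRequest branch on plain dicts (sketch)."""
    state['all_task_addresses'][index] = task_addresses
    index_to_host = state['task_index_host_hash']
    host_to_indices = state['task_host_hash_indices']

    # If this index was registered under a different host hash before,
    # drop it from that host's list.
    earlier = index_to_host.get(index)
    if earlier is not None and earlier != host_hash:
        host_to_indices[earlier].remove(index)

    index_to_host[index] = host_hash
    indices = host_to_indices.setdefault(host_hash, [])
    if index not in indices:  # guard added in this sketch
        indices.append(index)
    indices.sort()


state = new_state()
register(state, 0, ['10.0.0.1'], 'host-a')
register(state, 1, ['10.0.0.2'], 'host-b')
register(state, 2, ['10.0.0.1'], 'host-a')
register(state, 2, ['10.0.0.2'], 'host-b')  # task 2 re-registers on another host
print(state['task_host_hash_indices'])  # {'host-a': [0], 'host-b': [1, 2]}
```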

3.4 the task waits for the next notification

As mentioned earlier, after a Spark task sends its registration request to the driver, it calls task.wait_for_initial_registration(settings.start_timeout). It waits there for the signal to start the next step, that is, for the driver to confirm that this round of registration is complete and that the task can move on.

task.wait_for_initial_registration waits until self._initial_registration_complete becomes True. The driver sets this flag only after all register_task registrations have completed.

class BasicTaskService(network.BasicService):

    def wait_for_initial_registration(self, timeout):
        self._wait_cond.acquire()
        try:
            while not self._initial_registration_complete:
                self._wait_cond.wait(timeout.remaining())
                timeout.check_time_out_for('tasks to start')
        finally:
            self._wait_cond.release()

Each Spark executor has a SparkTaskService, so each Spark task has its own _initial_registration_complete flag.

The main process of horovod.spark.run notifies each SparkTaskService in turn by setting that service's own _initial_registration_complete.

The main process does this by sending a NotifyInitialRegistrationCompleteRequest:

def notify_initial_registration_complete(self):
    self._send(NotifyInitialRegistrationCompleteRequest())

BasicTaskService waits for a NotifyInitialRegistrationCompleteRequest. When it receives one, it sets _initial_registration_complete to True and notifies the condition variable, so that wait_for_initial_registration can return:

if isinstance(req, NotifyInitialRegistrationCompleteRequest):
    self._wait_cond.acquire()
    try:
        self._initial_registration_complete = True
    finally:
        self._wait_cond.notify_all()
        self._wait_cond.release()
    return network.AckResponse()

This means that once a task has registered with Horovod and received this notification, its Spark task has started successfully.

+-------------------------------------+             +----------------------------------------------------+
| Horovod Main thread                 |             | Spark Executor                                     |
|                                     |             |                     _task_fn                       |
|                                     |             |                        +                           |
|                                     |             |                        |                           |
|                                     |             |                        |                           |
|                                     |             |                        v                           |
| +-------------------------------+   |             |  +---------------------+------------------------+  |
| | SparkDriverService            |   |             |  | SparkTaskService                             |  |
| |                               |   |             |  |               +                              |  |
| |                               |   |  1 register |  |               |                              |  |
| |  self._all_task_addresses <----------------------------------------+                              |  |
| |                               |   |             |  |               |                              |  |
| |              +                |   |             |  |               |                              |  |
| |              |                |   |             |  |               |                              |  |
| |              | 3              |   |             |  |               |                              |  |
| |              |                |   |             |  |               | 2                            |  |
| |              v                |   |             |  |               |                              |  |
| |  self._wait_cond.notify_all() |   |             |  |               |                              |  |
| |              +                |   |             |  |               v                              |  |
| |              |                |   |             |  |     +---------+---------------------------+  |  |
| |              |                |   |             |  |     |                                     |  |  |
| |              |                |   |             |  |     | task.wait_for_initial_registration  |  |  |
| |              |                |   |             |  |     |                                     |  |  |
| |              |                |   |             |  |     +-------------------------------------+  |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              |                |   |             |  |                                              |  |
| |              v                |   |             |  |                                              |  |
| |                               |   |             |  |                                              |  |
| |                               |   |             |  |                                              |  |
| |                               |   |             |  |                                              |  |
| +-------------------------------+   |             |  +----------------------------------------------+  |
+-------------------------------------+             +----------------------------------------------------+


0x04 stage 3: Driver notifies task of successful registration

In this stage, after Horovod has received the registrations of all tasks, it notifies each task to enter the next stage.

4.1 _notify_and_register_task_addresses

As mentioned earlier, the horovod main process uses _notify_and_register_task_addresses to wait for the Spark tasks to register: it calls driver.wait_for_initial_registration(settings.start_timeout) to wait for all of them.

Note that, after sending its registration request, each Spark task likewise calls task.wait_for_initial_registration and waits for horovod to signal the start of the next phase.

_notify_and_register_task_addresses itself involves several steps:

  • Call driver.wait_for_initial_registration to wait for the tasks to register (this step has already been covered);
  • Use notify_and_register to register the tasks and notify each task to start the next step (covered in the next two subsections);
  • Call driver.wait_for_task_to_task_address_updates to confirm once more that all tasks are OK.
def _notify_and_register_task_addresses(driver, settings, notify=True):
    # wait for num_proc tasks to register
    driver.wait_for_initial_registration(settings.start_timeout)

    def notify_and_register(index):
        # Register tasks and notify each task to start the next step
        task_client = task_service.SparkTaskClient(index,
                                                   driver.task_addresses_for_driver(index),
                                                   settings.key, settings.verbose)

        if notify:
            task_client.notify_initial_registration_complete()

        next_task_index = (index + 1) % settings.num_proc
        next_task_addresses = driver.all_task_addresses(next_task_index)
        task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses)
        driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses)

    for index in driver.task_indices():
        in_thread(notify_and_register, (index,)) # Register tasks and notify each task to start the next step

    # Confirm again that all tasks are OK
    driver.wait_for_task_to_task_address_updates(settings.start_timeout)
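The index arithmetic in notify_and_register pairs every task with its successor in a ring: task i is asked about the addresses of task (i + 1) % num_proc (presumably the ones it can reach), and the result is registered for task i + 1. The sketch below simulates that fan-out with one thread per task. The dict registry is a hypothetical stand-in for the driver's bookkeeping, and `probe` is a placeholder for get_task_addresses_for_task.

```python
import threading


def ring_pairs(num_proc):
    """Return (index, next_index) pairs: every task checks its successor in the ring."""
    return [(index, (index + 1) % num_proc) for index in range(num_proc)]


def register_all(all_task_addresses, probe):
    """Simulate the fan-out of notify_and_register, one thread per task.

    all_task_addresses: dict task_index -> list of addresses (hypothetical layout).
    probe(index, addresses): placeholder for get_task_addresses_for_task; returns
    the subset of `addresses` that task `index` can use.
    """
    num_proc = len(all_task_addresses)
    registry = {}            # stand-in for driver.register_task_to_task_addresses
    lock = threading.Lock()  # threads write to the shared dict concurrently

    def notify_and_register(index, next_index):
        usable = probe(index, all_task_addresses[next_index])
        with lock:
            registry[next_index] = usable

    threads = [threading.Thread(target=notify_and_register, args=pair)
               for pair in ring_pairs(num_proc)]
    for t in threads:
        t.start()
    # The real code does not join; it waits on the driver's condition
    # variable in wait_for_task_to_task_address_updates instead.
    for t in threads:
        t.join()
    return registry
```

For example, `ring_pairs(3)` gives `[(0, 1), (1, 2), (2, 0)]`. Each task index appears exactly once as a successor, so the registry ends up with exactly one entry per task, which is the count that wait_for_task_to_task_address_updates waits for.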

4.2 notify_and_register

As you can see, notify_and_register does the following:

  • Calls task_client.notify_initial_registration_complete() to tell the Spark task that registration succeeded, which lets every Spark executor blocked in wait_for_initial_registration proceed to the next phase.
  • Asks the task, via get_task_addresses_for_task, which addresses of the next task in the ring (index (index + 1) % settings.num_proc) it can use, then calls driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses) so that the driver records them and completes the registration.
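The source here does not show how get_task_addresses_for_task decides which addresses are usable. One plausible way to implement such a check, assumed here and not taken from Horovod, is to attempt a TCP connection to each candidate address and keep the ones that succeed. `reachable_addresses` is a hypothetical helper name.

```python
import socket


def reachable_addresses(candidates, timeout_s=1.0):
    """Return the (host, port) pairs from `candidates` that accept a TCP connection.

    Hypothetical helper: an assumed stand-in for the probing behind
    get_task_addresses_for_task, not Horovod's implementation.
    """
    reachable = []
    for host, port in candidates:
        try:
            # create_connection raises OSError (e.g. ConnectionRefusedError or a
            # socket timeout) when the address cannot be reached.
            with socket.create_connection((host, port), timeout=timeout_s):
                reachable.append((host, port))
        except OSError:
            pass
    return reachable
```

Probing with a short timeout keeps the check bounded even when an interface silently drops packets, which matters because every task runs it during start-up.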

4.3 wait_for_task_to_task_address_updates

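Here the driver confirms once more that all Spark tasks are OK: it blocks until the task-to-task addresses of all self._initial_np tasks have been registered, checking for a Spark job failure each time it wakes up and giving up once settings.start_timeout expires.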

def wait_for_task_to_task_address_updates(self, timeout):
    self._wait_cond.acquire()
    try:
        while len(self._task_addresses_for_tasks) < self._initial_np:
            self.check_for_spark_job_failure()
            self._wait_cond.wait(timeout.remaining())
            timeout.check_time_out_for('Spark tasks to update task-to-task addresses')
    finally:
        self._wait_cond.release()
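The driver side of this check can be sketched as a counter behind a condition variable: each registration adds an entry and notifies, and the waiter loops until initial_np entries exist, checking for job failure on every wake-up. This is a simplified, hypothetical model: DriverSketch, notify_job_failed and check_for_job_failure are illustrative names, and the failure flag stands in for whatever the real check_for_spark_job_failure inspects.

```python
import threading
import time


class DriverSketch:
    """Hypothetical model of the driver's task-to-task bookkeeping."""

    def __init__(self, initial_np):
        self._initial_np = initial_np
        self._task_addresses_for_tasks = {}
        self._job_failed = False
        self._wait_cond = threading.Condition()

    def register_task_to_task_addresses(self, index, addresses):
        with self._wait_cond:
            self._task_addresses_for_tasks[index] = addresses
            self._wait_cond.notify_all()  # let the waiter re-count

    def notify_job_failed(self):
        # Hypothetical: lets a monitoring thread wake the waiter on failure.
        with self._wait_cond:
            self._job_failed = True
            self._wait_cond.notify_all()

    def check_for_job_failure(self):
        if self._job_failed:
            raise RuntimeError('Spark job has failed')

    def wait_for_task_to_task_address_updates(self, timeout_s):
        deadline = time.monotonic() + timeout_s
        with self._wait_cond:
            while len(self._task_addresses_for_tasks) < self._initial_np:
                self.check_for_job_failure()
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    raise TimeoutError('Spark tasks to update task-to-task addresses')
                self._wait_cond.wait(remaining)
```

Checking for failure inside the loop means a crashed Spark job surfaces as an error as soon as the waiter is woken, instead of only after the full start timeout has elapsed.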

4.4 Waiting in the task

Once the Spark task receives the start notification, it calls wait_for_command_termination and waits for the command to terminate.

In fact, this step means the Spark executor's own logic is finished: from here on, SparkTaskService acts independently and is responsible for launching the training code. Since _task_fn has nothing left to do, it simply waits.

4.4.1 wait_for_command_termination

In horovod-master/horovod/spark/task/task_service.py:

def wait_for_command_termination(self):
    """
    Waits for command termination. Ensures this method takes at least
    self._minimum_command_lifetime_s seconds to return after command started.
    """
    try:
        return super(SparkTaskService, self).wait_for_command_termination()
    finally:
        # command terminated, make sure this method takes at least
        # self._minimum_command_lifetime_s seconds after command started
        # the client that started the command needs some time to connect again
        # to wait for the result (see horovod.spark.driver.rsh).
        if self._minimum_command_lifetime is not None:
            time.sleep(self._minimum_command_lifetime.remaining())
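The finally block guarantees a minimum lifetime measured from when the command started, so that the rsh client has time to reconnect and fetch the result. Here is a minimal sketch of the same idea with plain threading; CommandRunner is a hypothetical class, and min_lifetime_s is a hypothetical parameter standing in for Horovod's _minimum_command_lifetime timeout object.

```python
import threading
import time


class CommandRunner:
    """Hypothetical sketch: join a command thread, but return no earlier than
    min_lifetime_s seconds after the command was started."""

    def __init__(self, min_lifetime_s):
        self._min_lifetime_s = min_lifetime_s
        self._started_at = None
        self._command_thread = None

    def run_command(self, fn):
        # Record the start time, like creating the lifetime timeout at command start.
        self._started_at = time.monotonic()
        self._command_thread = threading.Thread(target=fn)
        self._command_thread.start()

    def wait_for_command_termination(self):
        try:
            self._command_thread.join()
        finally:
            # Sleep only for whatever part of the minimum lifetime is left.
            elapsed = time.monotonic() - self._started_at
            time.sleep(max(0.0, self._min_lifetime_s - elapsed))
```

With `CommandRunner(min_lifetime_s=0.2)`, a command that finishes instantly still makes wait_for_command_termination take about 0.2 seconds, while a command that runs longer than the minimum adds no extra delay.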

In horovod-master/horovod/runner/common/service/task_service.py, you can see that it waits for the thread running the training code to finish:

def wait_for_command_termination(self):
    self._command_thread.join() # It will be explained soon

4.4.2 _command_thread

Here is a brief description of _command_thread.

When the SparkTaskService handles a RunCommandRequest, it starts a new thread to run the command and assigns that thread to _command_thread.

class BasicTaskService(network.BasicService):
    def _handle(self, req, client_address):
      
        if isinstance(req, RunCommandRequest): # Run command request
            self._wait_cond.acquire()
            try:
                if self._command_thread is None:

                    if self._command_env:
                        env = self._command_env.copy()
                        self._add_envs(env, req.env)
                        req.env = env

                    self._command_abort = threading.Event()
                    self._command_stdout = Pipe() if req.capture_stdout else None
                    self._command_stderr = Pipe() if req.capture_stderr else None
                    # Configure various parameter information
                    args = (req.command, req.env, self._command_abort,
                            self._command_stdout, self._command_stderr,
                            self._index,
                            req.prefix_output_with_timestamp)
                    # Start a new thread to run the command
                    self._command_thread = in_thread(self._run_command, args)
            finally:
                self._wait_cond.notify_all()
                self._wait_cond.release()
            return network.AckResponse()  
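The key detail is the `if self._command_thread is None` guard: only the first RunCommandRequest starts a thread, and later requests are still acknowledged but otherwise ignored. Below is a minimal sketch under the assumption that the command is an argument list run with subprocess (Horovod's own _run_command takes a command string plus environment, abort event and output pipes). CommandService and python_command are hypothetical names.

```python
import subprocess
import sys
import threading


def python_command(code):
    """Build an argv that runs a Python snippet with the current interpreter."""
    return [sys.executable, '-c', code]


class CommandService:
    """Hypothetical sketch of the run-once command handling."""

    def __init__(self):
        self._wait_cond = threading.Condition()
        self._command_thread = None
        self.output = None

    def _run_command(self, argv):
        # Runs in the background thread; captures stdout like capture_stdout=True.
        self.output = subprocess.run(argv, capture_output=True, text=True).stdout

    def handle_run_command(self, argv):
        with self._wait_cond:
            if self._command_thread is None:  # only the first request starts a command
                self._command_thread = threading.Thread(
                    target=self._run_command, args=(argv,), daemon=True)
                self._command_thread.start()
            self._wait_cond.notify_all()
        return 'ack'  # stand-in for network.AckResponse()

    def wait_for_command_termination(self):
        # Same idea as BasicTaskService: block until the command thread ends.
        self._command_thread.join()
```

After `svc.handle_run_command(python_command('print("hello")'))`, a second call with a different command leaves `svc._command_thread` unchanged, and `svc.wait_for_command_termination()` returns once the first command has finished, with its stdout in `svc.output`.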

The logic is as follows:

+-------------------------------------+             +----------------------------------------------------+
| Horovod Main thread                 |             | Spark Executor                                     |
|                                     |             |                     _task_fn                       |
|                                     |             |                        +                           |
|                                     |             |                        |                           |
|                                     |             |                        |                           |
|                                     |             |                        v                           |
| +-------------------------------+   |             |  +---------------------+------------------------+  |
| | SparkDriverService            |   |             |  | SparkTaskService                             |  |
| |                               |   |             |  |               +                              |  |
| |                               |   |  1 register |  |               |                              |  |
| |  self._all_task_addresses <----------------------------------------+                              |  |
| |                               |   |             |  |               |                              |  |
| |              +                |   |             |  |               |                              |  |
| |              |                |   |             |  |               |                              |  |
| |              | 3              |   |             |  |               |                              |  |
| |              |                |   |             |  |               | 2                            |  |
| |              v                |   |             |  |               |                              |  |
| |  self._wait_cond.notify_all() |   |             |  |               |                              |  |
| |              +                |   |             |  |               v                              |  |
| |              |                |   +             +  +     +---------+---------------------------+  |  |
| |              |            4   |  RegistrationComplete    |                                     |  |  |
| |              |  +-----------------+-------------+--+---> | task.wait_for_initial_registration  |  |  |
| |              |                |   |             |  |     |                                     |  |  |
| |              |                |   |             |  |     +---------+---------------------------+  |  |
| |              |                |   |             |  |               |                              |  |
| |              |                |   |             |  |               |                              |  |
| |              |                |   |             |  |               | 5                            |  |
| |              |                |   |             |  |               |                              |  |
| |              |                |   |             |  |               |                              |  |
| |              |                |   |             |  |               v                              |  |
| |              |                |   |             |  |        wait_for_command_termination          |  |
| |              |                | 6 |  RunCommand |  |               +                              |  |
| |              |                |   |             |  |               |                              |  |
| |              +----------------------------------------------->     | 7                            |  |
| |              |                |   |             |  |               v                              |  |
| |              v                |   |             |  |        self._command_thread.join()           |  |
| |                               |   |             |  |                                              |  |
| |                               |   |             |  |                                              |  |
| |                               |   |             |  |                                              |  |
| +-------------------------------+   |             |  +----------------------------------------------+  |
+-------------------------------------+             +----------------------------------------------------+


At this point, the start-up phase is complete. The remaining stages will be covered in the next article.

0x05 summary

In general, the overall logic of Horovod on Spark is divided into the following stages:

  • Start the SparkDriverService and use _make_spark_thread to launch the Spark tasks; horovod then waits for the start-up to finish;
  • Multiple threads start Spark tasks in the Spark executors. Each task runs a SparkTaskService, which registers with the SparkDriverService in the host process and then waits for the instruction to start the next step;
  • Once Horovod has received the registrations of all tasks, it notifies each task to enter the next stage;
  • Horovod calls mpi_run (which also uses mpirun_rsh.py) to start orted on each Spark executor, bringing up the MPI cluster;
  • orted runs the training code on each executor.

This article has covered the first three stages, that is, the start-up phase. The remaining two stages will be described in the next article.

0xEE personal information

★★★★★★★ thinking about life and technology ★★★★★★

Wechat public account: Rossi's thinking

