Shortcuts

Source code for pytorch_lightning.utilities.distributed

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities that can be used with distributed training."""

from typing import Any, Callable, Dict, Optional

import torch
from torch.nn.parallel.distributed import DistributedDataParallel

from lightning_lite.utilities.distributed import all_gather_ddp_if_available as new_all_gather_ddp_if_available
from lightning_lite.utilities.distributed import distributed_available as new_distributed_available
from lightning_lite.utilities.distributed import gather_all_tensors as new_gather_all_tensors
from lightning_lite.utilities.distributed import (
    get_default_process_group_backend_for_device as new_get_default_process_group_backend_for_device,
)
from lightning_lite.utilities.distributed import init_dist_connection as new_init_dist_connection
from lightning_lite.utilities.distributed import sync_ddp as new_sync_ddp
from lightning_lite.utilities.distributed import sync_ddp_if_available as new_sync_ddp_if_available
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_info


def _describe_callable(fn: Callable) -> str:
    """Return a printable name for ``fn`` to be used in log messages.

    Plain functions expose ``__qualname__``; objects such as ``functools.partial`` instances or instances of
    classes defining ``__call__`` do not, so fall back to ``repr`` instead of raising ``AttributeError``.
    """
    return getattr(fn, "__qualname__", repr(fn))


def register_ddp_comm_hook(
    model: DistributedDataParallel,
    ddp_comm_state: Optional[object] = None,
    ddp_comm_hook: Optional[Callable] = None,
    ddp_comm_wrapper: Optional[Callable] = None,
) -> None:
    """Register a communication hook on a DDP model.

    See https://pytorch.org/docs/master/ddp_comm_hooks.html.

    If ``ddp_comm_hook`` is ``None`` nothing is registered and the function returns immediately; in that case
    ``ddp_comm_wrapper`` is ignored as well, because there is no hook to wrap.

    Args:
        model:
            DDP model
        ddp_comm_state:
            state is passed to the hook and can be used to maintain
            and update any state information that users would like to
            maintain as part of the training process. Examples: error
            feedback in gradient compression, peers to communicate with
            next in GossipGrad etc.
        ddp_comm_hook:
            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future

            This callable function is called once the bucket is ready. The
            hook can perform whatever processing is needed and return
            a Future indicating completion of any async work (ex: allreduce).
            If the hook doesn't perform any communication, it can also
            just return a completed Future. The Future should hold the
            new value of grad bucket's tensors. Once a bucket is ready,
            c10d reducer would call this hook and use the tensors returned
            by the Future and copy grads to individual parameters.
        ddp_comm_wrapper:
            communication hook wrapper to support a communication hook such
            as FP16 compression as wrapper, which could be combined with
            ddp_comm_hook

    Examples:

        >>> from torch.distributed.algorithms.ddp_comm_hooks import (  # doctest: +SKIP
        ...     default_hooks as default,
        ...     powerSGD_hook as powerSGD,
        ...     post_localSGD_hook as post_localSGD,
        ... )
        >>>
        >>> # fp16_compress_hook for compress gradients
        >>> ddp_model = ...
        >>> register_ddp_comm_hook(  # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_hook=default.fp16_compress_hook,
        ... )
        >>>
        >>> # powerSGD_hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook(  # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ... )
        >>>
        >>> # post_localSGD_hook
        >>> subgroup, _ = torch.distributed.new_subgroups()  # doctest: +SKIP
        >>> ddp_model = ...
        >>> register_ddp_comm_hook(  # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=post_localSGD.PostLocalSGDState(
        ...         process_group=None,
        ...         subgroup=subgroup,
        ...         start_localSGD_iter=1_000,
        ...     ),
        ...     ddp_comm_hook=post_localSGD.post_localSGD_hook,
        ... )
        >>>
        >>> # fp16_compress_wrapper combined with other communication hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook(  # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ...     ddp_comm_wrapper=default.fp16_compress_wrapper,
        ... )
    """
    if ddp_comm_hook is None:
        return

    # A separate local keeps the narrowed (non-Optional) type without re-annotating the parameter itself.
    hook: Callable = ddp_comm_hook

    if ddp_comm_wrapper is not None:
        rank_zero_info(
            f"DDP comm wrapper is provided, apply {_describe_callable(ddp_comm_wrapper)}({_describe_callable(hook)})."
        )
        # The wrapper returns a new hook that decorates the user-supplied one.
        hook = ddp_comm_wrapper(hook)

    rank_zero_debug(f"Registering DDP comm hook: {_describe_callable(hook)}.")
    model.register_comm_hook(state=ddp_comm_state, hook=hook)  # type: ignore[operator]
def _broadcast_object_list(obj: Any, rank: int) -> Any:
    """Broadcast ``obj`` from process ``rank`` and return the broadcast value.

    Args:
        obj: The local object. Only the value held by process ``rank`` is sent; on every other process a
            ``None`` placeholder is put into the list passed to the collective.
        rank: The rank of the process that acts as the source of the broadcast.

    Returns:
        The first (and only) element of the object list after ``torch.distributed.broadcast_object_list``
        has run.
    """
    objects = [obj if torch.distributed.get_rank() == rank else None]
    torch.distributed.broadcast_object_list(objects, src=rank)
    return objects[0]


# TODO: Refactor with the Strategy Collectives once finalized.
def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
    """This distributed utility collects dictionary state across all processes.

    Args:
        state: Dictionary containing the state of the current process

    Returns:
        states: A dictionary where the primary keys are the process rank and the values their associated
        states. Despite the function name, the dictionary is built with one broadcast per rank, so it is
        returned by every process that calls this function, not only by global rank 0. If distributed is not
        available, ``{0: state}`` is returned.
    """
    if not new_distributed_available():
        return {0: state}
    # One broadcast per source rank; all processes must call this function for the collectives to complete.
    return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())}


def all_gather_ddp_if_available(*args: Any, **kwargs: Any) -> Any:
    """Deprecated alias: emit a deprecation message, then forward all arguments to the ``lightning_lite``
    implementation of ``all_gather_ddp_if_available``."""
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.all_gather_ddp_if_available` has been deprecated in v1.8.0 and will"
        " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.all_gather_ddp_if_available` instead."
    )
    return new_all_gather_ddp_if_available(*args, **kwargs)


def distributed_available() -> Any:
    """Deprecated alias: emit a deprecation message, then return the result of the ``lightning_lite``
    implementation of ``distributed_available``."""
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.distributed_available` has been deprecated in v1.8.0 and will"
        " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.distributed_available` instead."
    )
    return new_distributed_available()


def gather_all_tensors(*args: Any, **kwargs: Any) -> Any:
    """Deprecated alias: emit a deprecation message, then forward all arguments to the ``lightning_lite``
    implementation of ``gather_all_tensors``."""
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.gather_all_tensors` has been deprecated in v1.8.0 and will"
        " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.gather_all_tensors` instead."
    )
    return new_gather_all_tensors(*args, **kwargs)


def get_default_process_group_backend_for_device(*args: Any, **kwargs: Any) -> Any:
    """Deprecated alias: emit a deprecation message, then forward all arguments to the ``lightning_lite``
    implementation of ``get_default_process_group_backend_for_device``."""
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.get_default_process_group_backend_for_device` has been deprecated"
        " in v1.8.0 and will be removed in v1.10.0. Please use"
        " `lightning_lite.utilities.distributed.get_default_process_group_backend_for_device` instead."
    )
    return new_get_default_process_group_backend_for_device(*args, **kwargs)


def init_dist_connection(*args: Any, **kwargs: Any) -> Any:
    """Deprecated alias: emit a deprecation message, then forward all arguments to the ``lightning_lite``
    implementation of ``init_dist_connection``."""
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.init_dist_connection` has been deprecated in v1.8.0 and will"
        " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.init_dist_connection` instead."
    )
    return new_init_dist_connection(*args, **kwargs)


def sync_ddp(*args: Any, **kwargs: Any) -> Any:
    """Deprecated alias: emit a deprecation message, then forward all arguments to the ``lightning_lite``
    implementation of ``sync_ddp``."""
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.sync_ddp` has been deprecated in v1.8.0 and will"
        " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.sync_ddp` instead."
    )
    return new_sync_ddp(*args, **kwargs)


def sync_ddp_if_available(*args: Any, **kwargs: Any) -> Any:
    """Deprecated alias: emit a deprecation message, then forward all arguments to the ``lightning_lite``
    implementation of ``sync_ddp_if_available``."""
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.sync_ddp_if_available` has been deprecated in v1.8.0 and will"
        " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.sync_ddp_if_available` instead."
    )
    return new_sync_ddp_if_available(*args, **kwargs)


def tpu_distributed() -> bool:
    """Deprecated alias: emit a deprecation message, then return the result of
    ``lightning_lite.accelerators.tpu.tpu_distributed``."""
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.tpu_distributed` has been deprecated in v1.8.0 and will"
        " be removed in v1.10.0. Please use `lightning_lite.accelerators.tpu.tpu_distributed` instead."
    )
    # Imported inside the function, so the TPU accelerator module is only loaded when this alias is called.
    # Within this scope the imported name shadows the enclosing function's own name.
    from lightning_lite.accelerators.tpu import tpu_distributed

    return tpu_distributed()

© Copyright (c) 2018-2022, Lightning AI et al. Revision 59c65db8.

Built with Sphinx using a theme provided by Read the Docs.