Skip to content

Commit

Permalink
Update Neuron initializations (#7952)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffhataws committed Sep 6, 2024
1 parent 900296a commit 901c3a3
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 18 deletions.
4 changes: 4 additions & 0 deletions test/neuron/run_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/bash
# -x: echo each command, -u: error on unset variables, -e: exit on first failure.
set -xue

# Run the Neuron helper-utility unit tests.
python3 test/neuron/test_neuron_utils.py
57 changes: 57 additions & 0 deletions test/neuron/test_neuron_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os
import pytest
import unittest
from torch_xla._internal.neuron_utils import *


class NeuronUtilsTest(unittest.TestCase):
  """Unit tests for NEURON_RT_VISIBLE_CORES parsing and per-rank remapping."""

  def test_get_visible_cores_list(self):
    # Valid specs: single ids, comma lists, inclusive ranges, and mixes.
    valid_specs = [
        ("1", [1]),
        ("1,2,3", [1, 2, 3]),
        ("1-3", [1, 2, 3]),
        ("1-3,5-8", [1, 2, 3, 5, 6, 7, 8]),
        ("1,3,5-8", [1, 3, 5, 6, 7, 8]),
    ]
    for spec, expected in valid_specs:
      os.environ["NEURON_RT_VISIBLE_CORES"] = spec
      assert (get_visible_cores_list() == expected)

    # Overlapping ranges and malformed range bounds raise ValueError.
    for spec in ("1-3,5-8,3-5", "1-3,5-8-5"):
      os.environ["NEURON_RT_VISIBLE_CORES"] = spec
      with pytest.raises(ValueError):
        get_visible_cores_list()

    # Non-numeric specs raise (from the int() conversion).
    for spec in ("a-b,5-8-5", "a"):
      os.environ["NEURON_RT_VISIBLE_CORES"] = spec
      with pytest.raises(Exception):
        get_visible_cores_list()

  def test_remap_visible_cores(self):
    # (spec, local_rank, local_world_size, expected remapped value)
    remap_cases = [
        ("1", 0, 1, "1"),
        ("1,2,3", 1, 3, "2"),
        ("1-3", 2, 3, "3"),
        ("1-3,5-8", 5, 7, "7"),
        ("1,3,5-8", 5, 6, "8"),
    ]
    for spec, rank, world, expected in remap_cases:
      os.environ["NEURON_RT_VISIBLE_CORES"] = spec
      remap_visible_cores(rank, world)
      assert (os.environ['NEURON_RT_VISIBLE_CORES'] == expected)

    # After the last remap only one core remains visible, so any world size
    # other than 1 (or an out-of-range rank) must be rejected.
    with pytest.raises(ValueError):
      remap_visible_cores(5, 9)
    with pytest.raises(ValueError):
      remap_visible_cores(6, 6)


if __name__ == "__main__":
  import sys  # Not imported at file top; needed for the explicit exit below.

  # Run the suite without letting unittest call sys.exit() itself
  # (the default exit=True would terminate before the line below runs,
  # making the wasSuccessful() check dead code).
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
2 changes: 1 addition & 1 deletion test/pjrt/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_pjrt_default_device(self, env_vars, expect_using_pjrt):
reload(torch_xla)
logs_context = contextlib.nullcontext()
if expect_using_pjrt:
self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU'])
self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU', 'NEURON'])
else:
self.assertIsNone(xr.device_type())

Expand Down
28 changes: 24 additions & 4 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,31 @@ def _summarize_fn_tracker():

def _aws_ec2_inf_trn_init() -> bool:
  """Initialize Neuron support on AWS EC2 inf/trn instances.

  Returns:
    True when libneuronxla is importable (Neuron PJRT plugin present),
    False otherwise.
  """
  try:
    from libneuronxla.libneuronpjrt_path import libneuronpjrt_path
  except ImportError:
    # Did not find libneuronxla: not a Neuron environment.
    return False

  # Need to set NEURON_LIBRARY_PATH here for proper Neuron Cache behavior.
  os.environ.setdefault('NEURON_LIBRARY_PATH', libneuronpjrt_path())

  # Enable additional features and overrides when torch-neuronx is present.
  try:
    from torch_neuronx import xla
  except ImportError:
    pass
  else:
    xla.init()

  # Basic initializations if torch-neuronx is not available.
  from ._internal import neuron
  # neuron_parallel_compile manages its own environment; skip for that tool.
  if os.path.basename(sys.argv[0]) != 'neuron_parallel_compile':
    import libneuronxla
    libneuronxla.configure_environment()
    neuron.set_envvar_defaults()
    neuron.configure_pjrt_environment()
  # Found libneuronxla.
  return True


def _setup_tpu_vm_library_path() -> bool:
Expand Down Expand Up @@ -179,7 +199,7 @@ def _check_deprecated_env_var():
# Locate libtpu for Cloud TPU VMs.
_found_libtpu = _setup_tpu_vm_library_path()

# Setup Neuron library for AWS EC2 inf/trn instances. The result is kept so
# torch_xla.runtime can default PJRT_DEVICE=NEURON when the library is found.
_found_libneuronxla = _aws_ec2_inf_trn_init()


def _prepare_to_exit():
Expand Down
108 changes: 95 additions & 13 deletions torch_xla/_internal/neuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,115 @@

from torch_xla.experimental import plugins

import sys
import torch.distributed as dist

from .neuron_utils import get_visible_cores_list, remap_visible_cores

# Module-level logger; basicConfig() ensures warnings surface even if the
# application has not configured logging itself.
logging.basicConfig()
logger = logging.getLogger(__name__)


# Set root communication address/port
def set_rt_root_comm_id():
  """Derive NEURON_RT_ROOT_COMM_ID from MASTER_ADDR and a fixed port.

  No-op when NEURON_RT_ROOT_COMM_ID is already set. IPv6 MASTER_ADDR values
  get brackets added when missing so the host:port split stays unambiguous.
  """
  if os.environ.get('NEURON_RT_ROOT_COMM_ID', None) is None:
    if 'MASTER_ADDR' not in os.environ:
      # Use the module logger (was logging.warning) for consistency with the
      # rest of this module.
      logger.warning(
          "MASTER_ADDR environment variable is not set, defaulting to localhost"
      )
    root_port = 62182
    root_addr = os.environ.get('MASTER_ADDR', 'localhost')
    # Heuristic: two or more ':' separators means an IPv6 literal.
    is_ipv6 = len(root_addr.split(":")) >= 3
    if is_ipv6:
      modified = False
      if not root_addr.startswith("["):
        root_addr = "[" + root_addr
        modified = True
      if not root_addr.endswith("]"):
        root_addr = root_addr + "]"
        modified = True
      if modified:
        logger.warning(
            "IPv6 address detected for MASTER_ADDR and missing brackets added: {}"
            .format(root_addr))
    os.environ['NEURON_RT_ROOT_COMM_ID'] = '{}:{}'.format(root_addr, root_port)


def set_envvar_defaults():
  """Apply default values for Neuron-related environment variables."""
  # Default gradient all-reduce bucket size to 50 MB unless user-overridden.
  if 'ALLREDUCE_GRADIENTS_BUCKET_SIZE_MB' not in os.environ:
    os.environ['ALLREDUCE_GRADIENTS_BUCKET_SIZE_MB'] = '50'


def configure_pjrt_environment():
  """
  Setting all necessary PJRT default environment variables.

  Only acts when GSPMD is off and the process was launched by torchelastic
  (torchrun); otherwise the distributed PJRT variables are left untouched.
  """
  from torch.distributed import is_torchelastic_launched

  # Set root communication address/port
  set_rt_root_comm_id()

  # Set env variables if we don't use GSPMD, using PJRT, and using torchrun
  if os.environ.get('XLA_USE_SPMD', '0') != '1' \
      and is_torchelastic_launched():
    # Env variables that only need to be set once
    # NEURON_PJRT_PROCESSES_NUM_DEVICES is a list of core counts and is too long for very large cluster,
    # so use NEURON_PJRT_WORLD_SIZE to pass world size and use core count of 1 per process in PJRT client.
    # NOTE(review): indentation restored from a whitespace-mangled paste; the
    # nesting of this once-only section should be confirmed against upstream.
    if 'NEURON_PJRT_PROCESSES_NUM_DEVICES' not in os.environ and 'NEURON_PJRT_WORLD_SIZE' not in os.environ:
      if 'WORLD_SIZE' not in os.environ:
        logger.warning(
            'WORLD_SIZE environment variable not set, defaulting to 1.')
      os.environ["NEURON_PJRT_WORLD_SIZE"] = os.environ.get("WORLD_SIZE", "1")
      if 'LOCAL_WORLD_SIZE' not in os.environ:
        logger.warning(
            'LOCAL_WORLD_SIZE environment variable not set, defaulting to 1.')
      os.environ['PJRT_LOCAL_PROCESS_COUNT'] = os.environ.get(
          'LOCAL_WORLD_SIZE', '1')

    # Env variables that need to be set once per process
    if not os.environ.get('NEURON_RT_VISIBLE_CORES', None):
      # No user-provided core list: each process sees the core of its rank.
      os.environ['NEURON_RT_VISIBLE_CORES'] = os.environ.get('LOCAL_RANK', '0')
    else:
      # User-provided core list: pick this rank's core out of it.
      local_rank = int(os.environ.get('LOCAL_RANK', '0'))
      local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', '1'))
      remap_visible_cores(local_rank, local_world_size)

    if 'RANK' not in os.environ:
      logger.warning('RANK environment variable is not set, defaulting to 0.')
    os.environ['NEURON_PJRT_PROCESS_INDEX'] = os.environ.get('RANK', '0')
    if 'LOCAL_RANK' not in os.environ:
      logger.warning(
          'LOCAL RANK environment variable is not set, defaulting to 0.')
    os.environ['PJRT_LOCAL_PROCESS_RANK'] = os.environ.get('LOCAL_RANK', '0')


def num_local_processes() -> int:
  """Return the number of local Neuron processes to spawn.

  Driven by NEURONCORE_NUM_DEVICES (default 1). Also ensures
  NEURON_RT_ROOT_COMM_ID is set and advertises one core per process via
  NEURON_PJRT_PROCESSES_NUM_DEVICES.
  """
  set_rt_root_comm_id()
  if "NEURONCORE_NUM_DEVICES" not in os.environ:
    # Module logger (was logging.warning) for consistency with this module.
    logger.warning("NEURONCORE_NUM_DEVICES not set, defaulting to 1")
  num_processes = int(os.environ.get("NEURONCORE_NUM_DEVICES", "1"))
  # One core per PJRT process.
  os.environ['NEURON_PJRT_PROCESSES_NUM_DEVICES'] = ','.join(
      ['1' for _ in range(num_processes)])

  return num_processes


# When torchrun is used, setting these environments causes the
# second instance in a 2-node cluster to think it is node 0 instead of node 1.
# Need to skip these settings and let configure_pjrt_environment
# set the distributed PJRT environment variables.
# If NEURONCORE_NUM_DEVICES is used, then go ahead and set the environments.
def initialize_env(local_rank, local_world_size):
  """Set per-process Neuron env vars for xmp.spawn (non-torchelastic) runs.

  Under torchelastic this is a no-op; configure_pjrt_environment handles the
  distributed setup there instead.
  """
  from torch.distributed import is_torchelastic_launched
  if not is_torchelastic_launched():
    os.environ["NEURON_PJRT_PROCESS_INDEX"] = str(local_rank)
    if not os.environ.get('NEURON_RT_VISIBLE_CORES', None):
      # No user-provided core list: this process sees the core of its rank.
      os.environ["NEURON_RT_VISIBLE_CORES"] = str(local_rank)
    else:
      # User-provided core list: pick this rank's core out of it.
      remap_visible_cores(local_rank, local_world_size)


class NeuronPlugin(plugins.DevicePlugin):

def library_path(self):
return os.environ.get("NEURON_LIBRARY_PATH", "libneuronpjrt.so")
from libneuronxla.libneuronpjrt_path import libneuronpjrt_path
return os.environ.get("NEURON_LIBRARY_PATH", libneuronpjrt_path())

  def configure_multiprocess(self, local_rank, local_world_size):
    # Delegate per-process setup (process index, visible cores) to
    # initialize_env; it no-ops under torchelastic launches.
    initialize_env(local_rank, local_world_size)
Expand Down
66 changes: 66 additions & 0 deletions torch_xla/_internal/neuron_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import logging
# Module-level logger; basicConfig() ensures warnings surface even if the
# application has not configured logging itself.
logging.basicConfig()
logger = logging.getLogger(__name__)


def convert_range(range_spec):
  """Parse one NEURON_RT_VISIBLE_CORES spec into a sequence of core ids.

  Accepts "<int>" (returns a one-element list) or
  "<lower int>-<upper int>" (returns the inclusive range).
  Raises ValueError for malformed or inverted specs.
  """
  try:
    bounds = [int(part) for part in range_spec.split("-")]
  except Exception as e:
    print(f"ERROR: Malformed range specs in NEURON_RT_VISIBLE_CORES;" +
          f"expecting <int> or <lower int>-<upper int> (got {range_spec})")
    raise e
  # Reject more than one dash, or an inverted lower/upper pair.
  if len(bounds) > 2 or (len(bounds) == 2 and bounds[0] > bounds[1]):
    raise ValueError(
        f"ERROR: Range specs in NEURON_RT_VISIBLE_CORES should be of " +
        f"the form <int> or <lower int>-<upper int> (got {range_spec})")
  if len(bounds) == 2:
    return range(bounds[0], bounds[1] + 1)
  return bounds


def get_visible_cores_list():
  """Parse NEURON_RT_VISIBLE_CORES into a flat list of core ids.

  Returns:
    A list of ints, or None when the variable is unset or empty.

  Raises:
    ValueError: if any range specs overlap (a core may be listed only once).
  """
  # Note: `os` comes from the module-level import; the previous
  # function-level re-import was redundant.
  range_list = os.environ.get("NEURON_RT_VISIBLE_CORES")
  if not range_list:
    return None
  cores_list = []
  for spec in range_list.split(","):
    new = convert_range(spec)
    # Overlapping specs would assign the same core to multiple workers.
    if set(cores_list) & set(new):
      raise ValueError(
          "ERROR: Please ensure the ranges in NEURON_RT_VISIBLE_CORES are mutually exclusive."
      )
    cores_list.extend(new)
  return cores_list


def remap_visible_cores(local_rank, local_world_size):
  """Replace NEURON_RT_VISIBLE_CORES with the single core for this rank.

  The user-supplied NEURON_RT_VISIBLE_CORES lists the cores available to the
  whole local node; each worker keeps only cores_list[local_rank].

  Raises:
    ValueError: if local_world_size is not positive, does not match the
      number of listed cores, or local_rank is out of range.
  """
  # NOTE(review): assumes NEURON_RT_VISIBLE_CORES is set — get_visible_cores_list
  # returns None when unset, which would fail len(); confirm with callers.
  cores_list = get_visible_cores_list()
  count = len(cores_list)
  # Explicit validation instead of `assert`, which is stripped under python -O.
  if local_world_size <= 0:
    raise ValueError("Local world size should be non-zero")
  if count <= 1 and local_world_size == 1:
    # Allow user to pass NEURON_RT_VISIBLE_CORES for single-core workload
    pass
  elif local_world_size != count:
    raise ValueError(
        f"LOCAL_WORLD_SIZE (torchrun) or PJRT_LOCAL_PROCESS_COUNT (xmp.spawn) value of {local_world_size} "
        +
        f"is not equal to count {count} from NEURON_RT_VISIBLE_CORES {cores_list}"
    )
  elif local_rank >= count:
    raise ValueError(
        f"LOCAL_RANK (torchrun) or PJRT_LOCAL_PROCESS_RANK (xmp.spawn) value of {local_rank} is higher than "
        + f"count {count} from NEURON_RT_VISIBLE_CORES {cores_list}")
  else:
    remapped_core = cores_list[local_rank]
    logger.warning(f"Remapping NEURON_RT_VISIBLE_CORES {cores_list} to " +
                   f"NEURON_RT_VISIBLE_CORES[LOCAL_RANK]={remapped_core}")
    os.environ['NEURON_RT_VISIBLE_CORES'] = str(remapped_core)
3 changes: 3 additions & 0 deletions torch_xla/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def _maybe_select_default_device():
+ num_devices_str)
os.environ[xenv.PJRT_DEVICE] = 'CUDA'
os.environ[xenv.GPU_NUM_DEVICES] = num_devices_str
elif torch_xla._found_libneuronxla:
logging.warning('Found libneuronpjrt.so. Setting PJRT_DEVICE=NEURON.')
os.environ[xenv.PJRT_DEVICE] = 'NEURON'
else:
logging.warning('Defaulting to PJRT_DEVICE=CPU')
os.environ[xenv.PJRT_DEVICE] = 'CPU'
Expand Down

0 comments on commit 901c3a3

Please sign in to comment.