Skip to content

Commit

Permalink
[BUG] Fix databricks-aws user profiling tool error with `--gpu_cluste…
Browse files Browse the repository at this point in the history
…r` argument (#707)


Fixes #695
---------

Signed-off-by: cindyyuanjiang <cindyj@nvidia.com>
  • Loading branch information
cindyyuanjiang authored Jan 18, 2024
1 parent 2f1c7ee commit 5b49ced
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions data_validation/src/spark_rapids_validation_tool/utilities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -110,7 +110,7 @@ def gen_random_string(str_length: int) -> str:


def get_gpu_device_list():
return ['T4', 'V100', 'K80', 'A100', 'P100']
return ['T4', 'V100', 'K80', 'A100', 'P100', 'A10', 'A10G']


def is_valid_gpu_device(val):
Expand Down
6 changes: 4 additions & 2 deletions user_tools/src/spark_rapids_pytools/cloud_api/sp_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -50,6 +50,7 @@ class GpuDevice(EnumeratedType):
P4 = 'P4'
L4 = 'l4'
A10 = 'a10'
A10G = 'a10g'

@classmethod
def get_default_gpu(cls):
Expand All @@ -64,7 +65,8 @@ def get_gpu_mem(self) -> list:
self.K80: [12288],
self.V100: [16384],
self.P100: [16384],
self.A10: [24576]
self.A10: [24576],
self.A10G: [24576]
}
return memory_hash.get(self)

Expand Down

0 comments on commit 5b49ced

Please sign in to comment.