Skip to content

Commit ae76aad

Browse files
jmoralez and StrikerRUS authored
[python-package] do not copy column-major numpy arrays when creating Dataset (#6721)
* do not copy column-major numpy arrays when creating Dataset
* fix logic
* lint
* code review
* update test
* move dataset test to basic
* increase features
* assert single layout

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
1 parent 33764e1 commit ae76aad

File tree

2 files changed

+57
-7
lines changed

2 files changed

+57
-7
lines changed

python-package/lightgbm/basic.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,23 @@ def _get_sample_count(total_nrow: int, params: str) -> int:
188188
return sample_cnt.value
189189

190190

191+
def _np2d_to_np1d(mat: np.ndarray) -> Tuple[np.ndarray, int]:
    """Flatten a 2-D numpy array into a 1-D array plus its memory-layout flag.

    Parameters
    ----------
    mat : numpy.ndarray
        Input matrix. float32 and float64 data keep their dtype; every other
        dtype is converted to float32.

    Returns
    -------
    flat : numpy.ndarray
        1-D array holding the elements of ``mat`` in the chosen memory order.
    layout : int
        ``_C_API_IS_COL_MAJOR`` when ``mat`` is flagged F-contiguous,
        otherwise ``_C_API_IS_ROW_MAJOR``.
    """
    target_dtype = mat.dtype if mat.dtype in (np.float32, np.float64) else np.float32

    # An F-contiguous input is kept column-major so it does not need to be
    # rearranged; everything else is handled as row-major.
    if mat.flags["F_CONTIGUOUS"]:
        mem_order, layout = "F", _C_API_IS_COL_MAJOR
    else:
        mem_order, layout = "C", _C_API_IS_ROW_MAJOR

    # Enforce dtype and order; this copies only if either one does not match.
    converted = np.asarray(mat, dtype=target_dtype, order=mem_order)

    # Flattening in the same order as the buffer does not copy.
    return converted.ravel(order=mem_order), layout
206+
207+
191208
class _MissingType(Enum):
192209
NONE = "None"
193210
NAN = "NaN"
@@ -684,7 +701,8 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va
684701
_C_API_DTYPE_INT32 = 2
685702
_C_API_DTYPE_INT64 = 3
686703

687-
"""Matrix is row major in Python"""
704+
"""Macro definition of data order in matrix"""
705+
_C_API_IS_COL_MAJOR = 0
688706
_C_API_IS_ROW_MAJOR = 1
689707

690708
"""Macro definition of prediction type in C API of LightGBM"""
@@ -2297,19 +2315,15 @@ def __init_from_np2d(
22972315
raise ValueError("Input numpy.ndarray must be 2 dimensional")
22982316

22992317
self._handle = ctypes.c_void_p()
2300-
if mat.dtype == np.float32 or mat.dtype == np.float64:
2301-
data = np.asarray(mat.reshape(mat.size), dtype=mat.dtype)
2302-
else: # change non-float data to float data, need to copy
2303-
data = np.asarray(mat.reshape(mat.size), dtype=np.float32)
2304-
2318+
data, layout = _np2d_to_np1d(mat)
23052319
ptr_data, type_ptr_data, _ = _c_float_array(data)
23062320
_safe_call(
23072321
_LIB.LGBM_DatasetCreateFromMat(
23082322
ptr_data,
23092323
ctypes.c_int(type_ptr_data),
23102324
ctypes.c_int32(mat.shape[0]),
23112325
ctypes.c_int32(mat.shape[1]),
2312-
ctypes.c_int(_C_API_IS_ROW_MAJOR),
2326+
ctypes.c_int(layout),
23132327
_c_str(params_str),
23142328
ref_dataset,
23152329
ctypes.byref(self._handle),

tests/python_package_test/test_basic.py

+36
Original file line numberDiff line numberDiff line change
@@ -947,3 +947,39 @@ def test_max_depth_warning_is_raised_if_max_depth_gte_5_and_num_leaves_omitted(c
947947
"in params. Alternatively, pass (max_depth=-1) and just use 'num_leaves' to constrain model complexity."
948948
)
949949
assert expected_warning in capsys.readouterr().out
950+
951+
952+
@pytest.mark.parametrize("order", ["C", "F"])
@pytest.mark.parametrize("dtype", ["float32", "int64"])
def test_no_copy_in_dataset_from_numpy_2d(rng, order, dtype):
    """Check the layout flag and copy behaviour of ``_np2d_to_np1d`` for each order/dtype pair."""
    X = np.require(rng.random(size=(100, 3)), dtype=dtype, requirements=order)
    X1d, layout = lgb.basic._np2d_to_np1d(X)

    # The reported layout must follow the memory order that was requested.
    expected_layout = lgb.basic._C_API_IS_COL_MAJOR if order == "F" else lgb.basic._C_API_IS_ROW_MAJOR
    assert layout == expected_layout

    # float32 input must be flattened without copying; int64 input has to be
    # converted, which makes a copy.
    reuses_input_buffer = np.shares_memory(X, X1d)
    if dtype == "float32":
        assert reuses_input_buffer
    else:
        assert not reuses_input_buffer
967+
968+
969+
def test_equal_datasets_from_row_major_and_col_major_data(tmp_path):
    """Datasets built from row-major and column-major copies of the same data must dump to identical text."""
    # row-major dataset: must be C-contiguous only, so exactly one layout applies
    X_row, y = make_blobs(n_samples=1_000, n_features=3, centers=2)
    row_flags = X_row.flags
    assert row_flags["C_CONTIGUOUS"] and not row_flags["F_CONTIGUOUS"]
    row_dump = tmp_path / "ds_row.txt"
    lgb.Dataset(X_row, y)._dump_text(row_dump)

    # col-major dataset: same values, Fortran-ordered only
    X_col = np.asfortranarray(X_row)
    col_flags = X_col.flags
    assert col_flags["F_CONTIGUOUS"] and not col_flags["C_CONTIGUOUS"]
    col_dump = tmp_path / "ds_col.txt"
    lgb.Dataset(X_col, y)._dump_text(col_dump)

    # check datasets are equal
    assert filecmp.cmp(row_dump, col_dump)

0 commit comments

Comments
 (0)