Skip to content

Commit

Permalink
Small improvements/deletions
Browse files Browse the repository at this point in the history
  • Loading branch information
tostenzel committed Jan 4, 2024
1 parent 43016fc commit 8a5a7a6
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
5 changes: 4 additions & 1 deletion edugrad/data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Defines the TensorData class, a container for tensor data (tensor.Tensor.data), represented as numpy arrays.
It facilitates direct manipulation of tensor data through a range of basic operation ("low-level ops"). These operations
are building blocks for defining forward and backward passes of differentiable function.Functions. ("mid-level ops")
are building blocks for defining forward and backward passes of differentiable function.Functions. ("mid-level ops").
For simplicity and to ensure that compatible dtypes operate with each other, we enforce two of the three supported
dtypes (bool and float32) with a typecast in each elementwise operation.
The ops are executed immediately on the CPU using numpy. This approach contrasts with deferred computation models that
analyze subsequent delayed operations in order to find an optimized equivalent final optimization at the point where
Expand Down
4 changes: 2 additions & 2 deletions edugrad/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def from_py(x) -> DType:

# Definition of various data types
bool: Final[DType] = DType(0, 1, "bool", np.bool_)
float32: Final[DType] = DType(10, 4, "float", np.float32)
int32: Final[DType] = DType(5, 4, "int", np.int32)
float32: Final[DType] = DType(2, 4, "float", np.float32)
int32: Final[DType] = DType(1, 4, "int", np.int32)

only_float: ClassVar[DType] = float32
only_int: ClassVar[DType] = int32
Expand Down
7 changes: 2 additions & 5 deletions edugrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from __future__ import annotations
import time
import math
from typing import ClassVar, Sequence, Any
from typing import ClassVar, Sequence, Any, Type

import numpy as np

Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(

# --------------------------------------------------------------------------------------------------------------
# Handles Tensor(x) for x with different data types.
# We cast x = list(y) up to float32 for every case
# We cast x = list(y) up to float32 (default_type) for every type that y can have

if isinstance(data, TensorData):
assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
Expand All @@ -87,9 +87,6 @@ def __init__(

elif data is None:
data = TensorData.loadop(LoadOps.EMPTY, (0,), dtype or dtypes.only_float)

elif isinstance(data, bytes):
data = TensorData(np.frombuffer(data, np.uint8))

elif isinstance(data, np.ndarray):
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
Expand Down

0 comments on commit 8a5a7a6

Please sign in to comment.