
Restrict dtypes, bugfixes, test basic tensor functions #14

Merged · 12 commits · Jan 4, 2024
Empty file added applications/__init__.py
Empty file.
43 changes: 23 additions & 20 deletions applications/learn_mnist.py
@@ -13,17 +13,18 @@ def parse(file):

# parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
BASE = os.path.dirname(__file__) + "/datasets"

X_train = parse(BASE + "/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28 * 28)).astype(np.float32)
Y_train = parse(BASE + "/mnist/train-labels-idx1-ubyte.gz")[8:]
Y_train = parse(BASE + "/mnist/train-labels-idx1-ubyte.gz")[8:].astype(np.int32)
X_test = parse(BASE + "/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28 * 28)).astype(np.float32)
Y_test = parse(BASE + "/mnist/t10k-labels-idx1-ubyte.gz")[8:]
Y_test = parse(BASE + "/mnist/t10k-labels-idx1-ubyte.gz")[8:].astype(np.int32)
if for_convolution:
X_train = X_train.reshape(-1, 1, 28, 28)
X_test = X_test.reshape(-1, 1, 28, 28)
return X_train, Y_train, X_test, Y_test


class TinyConvNet:
class ConvNet:
def __init__(self):
# https://keras.io/examples/vision/mnist_convnet/
kernel_sz = 3
@@ -39,21 +40,16 @@ def __call__(self, x: Tensor):
return x.dot(self.l1).log_softmax()


if __name__ == "__main__":
NUM_STEPS = 100
BS = 128
LR = 0.001

def train_and_evaluate_mnist(num_steps=100, batch_size=128, learning_rate=0.001):
X_train, Y_train, X_test, Y_test = fetch_mnist()
model = TinyConvNet()
opt = optimizer.Adam([model.c1, model.c2, model.l1], lr=LR)
model = ConvNet()
opt = optimizer.Adam([model.c1, model.c2, model.l1], lr=learning_rate)

with Tensor.train():
for step in range(NUM_STEPS):
# Get sample batches
samp = np.random.randint(0, X_train.shape[0], size=(BS))
for step in range(num_steps):
samp = np.random.randint(0, X_train.shape[0], size=(batch_size))
xb, yb = Tensor(X_train[samp], requires_grad=False), Tensor(Y_train[samp])
# Train

out = model(xb)
loss = out.sparse_categorical_crossentropy(yb)
opt.zero_grad()
@@ -66,11 +62,18 @@ def __call__(self, x: Tensor):
print(f"Step {step+1:<3} | Loss: {loss.numpy():.4f} | Train Acc: {acc:.3f}")

# Evaluate Test
acc = 0
for i in range(0, len(Y_test), BS):
xb, yb = Tensor(X_test[i : i + BS], requires_grad=False), Tensor(Y_test[i : i + BS])
test_accuracy = 0
for i in range(0, len(Y_test), batch_size):
xb, yb = Tensor(X_test[i : i + batch_size], requires_grad=False), Tensor(Y_test[i : i + batch_size])
out = model(xb)
preds = out.argmax(axis=-1)
acc += (preds == yb).sum().numpy()
acc /= len(Y_test)
print(f"Test Acc: {acc:.3f}")
test_accuracy += (preds == yb).sum().numpy()
test_accuracy /= len(Y_test)
return test_accuracy



if __name__ == "__main__":
# Only execute if this script is run directly
test_accuracy = train_and_evaluate_mnist()
print(f"Test Acc: {test_accuracy:.3f}")
2 changes: 1 addition & 1 deletion edugrad/_tensor/tensor_broadcasted_binary_mlops.py
@@ -2,7 +2,7 @@

import math

from edugrad.helpers import dtypes
from edugrad.dtypes import dtypes
import edugrad.function as function


2 changes: 1 addition & 1 deletion edugrad/_tensor/tensor_combine_segment.py
@@ -7,7 +7,7 @@


def cat(tensor, *args, dim) -> Tensor:
from edugrad._tensor import Tensor
from edugrad.tensor import Tensor

dim = (dim + len(tensor.shape)) if dim < 0 else dim
assert all(
3 changes: 2 additions & 1 deletion edugrad/_tensor/tensor_create.py
@@ -2,7 +2,8 @@
import time
import math

from edugrad.helpers import argfix, DType, prod, shape_int, dtypes
from edugrad.dtypes import DType, dtypes
from edugrad.helpers import argfix, prod, shape_int
from edugrad.data import TensorData
from edugrad.ops import LoadOps

10 changes: 6 additions & 4 deletions edugrad/_tensor/tensor_index_slice.py
@@ -1,7 +1,8 @@
from typing import Sequence, Optional, Tuple
from collections import defaultdict

from edugrad.helpers import shape_int, dtypes
from edugrad.dtypes import dtypes
from edugrad.helpers import shape_int
from edugrad._tensor.tensor_reshape import pad, _flatten


@@ -35,7 +36,7 @@
def __getitem__(
tensor: "Tensor", val
) -> "Tensor": # val: Union[int, slice, Tensor, None, Ellipsis, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]]
from edugrad._tensor import Tensor
from edugrad.tensor import Tensor

def normalize_int(e, i, dim_sz):
if -dim_sz <= e < dim_sz:
@@ -141,10 +142,11 @@ def __setitem__(tensor: "Tensor", s, v):


# NOTE: using slice is discouraged and things should migrate to pad and shrink
def slice(tensor: "Tensor", arg: Sequence[Optional[Tuple[int, shape_int]]], value: float) -> "Tensor":
def tslice(tensor: "Tensor", arg: Sequence[Optional[Tuple[int, shape_int]]], value: float = 0) -> "Tensor":
from edugrad.tensor import Tensor
arg_ = tuple([a if a is not None else (0, s) for s, a in zip(tensor.shape, arg)])
padding = tuple([(max(0, -p[0]), max(0, p[1] - tensor.shape[i])) for i, p in enumerate(arg_)])
return pad(tensor, padding, value=value).shrink(
return tensor.pad(padding, value=value).shrink(
tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i, p in enumerate(arg_)])
)
# FIXME: tensor.pad(padding, value=value)... returns None...
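
The renamed tslice composes two primitives: pad the tensor so that every requested range falls inside the array, then shrink to the requested window, filling out-of-range positions with value. A standalone NumPy sketch of the same pad-then-shrink idea (helper name and shapes are illustrative, not edugrad API):

import numpy as np

def slice_with_padding(x, ranges, value=0.0):
    # ranges[i] = (start, stop) may reach outside the array; out-of-range
    # positions are filled with `value`, mirroring pad followed by shrink.
    pad_width = [(max(0, -start), max(0, stop - size))
                 for (start, stop), size in zip(ranges, x.shape)]
    padded = np.pad(x, pad_width, constant_values=value)
    window = tuple(slice(start + pw[0], stop + pw[0])
                   for (start, stop), pw in zip(ranges, pad_width))
    return padded[window]

x = np.arange(6, dtype=np.float32).reshape(2, 3)
print(slice_with_padding(x, [(0, 2), (-1, 4)]))  # one padded column on each side of dim 1
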
3 changes: 2 additions & 1 deletion edugrad/_tensor/tensor_nn.py
@@ -3,7 +3,8 @@
from __future__ import annotations
import math

from edugrad.helpers import make_pair, flatten, dtypes, all_int, shape_int
from edugrad.dtypes import dtypes
from edugrad.helpers import make_pair, flatten, all_int, shape_int


# processing ops
12 changes: 6 additions & 6 deletions edugrad/_tensor/tensor_reduce.py
@@ -2,7 +2,9 @@

from __future__ import annotations

from edugrad.helpers import dtypes, prod, all_int

from edugrad.dtypes import dtypes
from edugrad.helpers import prod, all_int
from edugrad.function import Function
import edugrad.function as function

@@ -44,9 +46,9 @@ def _reduce(self, fxn: type[Function], axis: int | tuple[int, ...] | None, keepd
return ret if keepdim else ret.reshape(shape=shape)


# ----------------------------------------------------------------------------------------------------------------------
# Functions that use the generic _reduce method for specific reduction operations.


def tsum(tensor: Tensor, axis, keepdim):
"""Computes the sum of elements over the specified axis."""
return tensor._reduce(function.Sum, axis, keepdim)
@@ -59,8 +61,7 @@ def tmax(tensor: Tensor, axis, keepdim):

def tmin(tensor: Tensor, axis, keepdim):
"""Computes the minimum value of elements over the specified axis."""
return -((-tensor).tmax((-tensor), axis=axis, keepdim=keepdim))

return -tmax((-tensor), axis=axis, keepdim=keepdim)

def mean(tensor: Tensor, axis, keepdim):
"""Computes the mean of elements over the specified axis."""
@@ -75,10 +76,9 @@ def std(tensor: Tensor, axis, keepdim, correction):
square_sum = ((tensor - tensor.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
return square_sum.div(prod(tensor.shape) / prod(square_sum.shape) - correction).sqrt()


# ----------------------------------------------------------------------------------------------------------------------
# Functions for softmax and its logarithmic variant, as well as argmax and argmin operations.


def _softmax(tensor: Tensor, axis):
"""Helper function to compute softmax components."""
m = tensor - tensor.max(axis=axis, keepdim=True)
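
Two details in this file are worth a standalone check: the corrected tmin relies on the identity min(x) = -max(-x), so only Max needs a dedicated backward pass, and _softmax subtracts the per-axis maximum before exponentiating to avoid overflow. A short NumPy verification of both (illustrative only):

import numpy as np

x = np.array([[3.0, -1.0, 2.0], [0.5, 7.0, -4.0]], dtype=np.float32)

# min expressed through max, as in the corrected tmin
assert np.allclose(-np.max(-x, axis=1), np.min(x, axis=1))

# the max-shift used by _softmax keeps exp() finite without changing the result
shifted = x - x.max(axis=1, keepdims=True)
softmax = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
assert np.allclose(softmax.sum(axis=1), 1.0)
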
3 changes: 2 additions & 1 deletion edugrad/_tensor/tensor_reshape.py
@@ -44,10 +44,11 @@ def shrink(tensor: Tensor, arg: tuple[tuple[shape_int, shape_int] | None, ...])


def pad(tensor: Tensor, arg: tuple[tuple[int, int] | None, ...], value: float) -> Tensor:
from edugrad.tensor import Tensor
if all(x is None or x == (0, 0) for x in arg):
return tensor
ret = function.Pad.apply(tensor, arg=(narg := tuple(x if x is not None else (0, 0) for x in arg)))
return ret if 0 == value else ret + function.Pad.apply("Tensor".ones_like(tensor), arg=narg).where(0, value)
return ret if 0 == value else ret + function.Pad.apply(Tensor.ones_like(tensor), arg=narg).where(0, value)


# (padding_left, padding_right, padding_top, padding_bottom)
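
The fix replaces the string literal "Tensor" with the actual Tensor class. The non-zero branch pads with zeros first, then uses a padded mask of ones to write value into the newly created border. The same trick in plain NumPy (a sketch of the idea, not the edugrad implementation):

import numpy as np

def pad_with_value(x, pad_width, value):
    # pad with zeros, then use a padded ones-mask to place `value`
    # only in the border cells that the padding introduced
    padded = np.pad(x, pad_width)              # zero padding
    mask = np.pad(np.ones_like(x), pad_width)  # 1 inside, 0 in the border
    return np.where(mask == 0, value, padded)

x = np.ones((2, 2), dtype=np.float32)
print(pad_with_value(x, ((1, 1), (1, 1)), value=-5.0))  # -5 border around a block of ones
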
34 changes: 19 additions & 15 deletions edugrad/data.py
@@ -1,7 +1,10 @@
"""Defines the TensorData class, a container for tensor data (tensor.Tensor.data), represented as numpy arrays.

It facilitates direct manipulation of tensor data through a range of basic operations ("low-level ops"). These operations
are building blocks for defining forward and backward passes of differentiable function.Functions. ("mid-level ops")
are building blocks for defining forward and backward passes of differentiable function.Functions. ("mid-level ops").

For simplicity and to ensure that compatible dtypes operate with each other, we enforce two of the three supported
dtypes (bool and float32) with a typecast in each elementwise operation.

The ops are executed immediately on the CPU using numpy. This approach contrasts with deferred computation models that
analyze subsequent delayed operations in order to find an optimized equivalent at the point where
@@ -13,7 +16,8 @@
from typing import Tuple
import numpy as np
from edugrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, LoadOps # consider reading the docs there
from edugrad.helpers import DType, dtypes, DEBUG
from edugrad.helpers import DEBUG
from edugrad.dtypes import DType, dtypes


class TensorData:
@@ -76,22 +80,22 @@ def cast(self, dtype: DType, bitcast: bool = False) -> "TensorData":
def elementwise(self, op, *srcs: "TensorData"):
"""Perform a unary, binary, or ternary elementwise operation on the data."""
unary_ops = {
UnaryOps.NEG: np.negative,
UnaryOps.EXP2: np.exp2,
UnaryOps.LOG2: np.log2,
UnaryOps.SIN: np.sin,
UnaryOps.SQRT: np.sqrt,
}
UnaryOps.NEG: lambda x: np.negative(x).astype(dtypes.only_float.np),
UnaryOps.EXP2: lambda x: np.exp2(x).astype(dtypes.only_float.np),
UnaryOps.LOG2: lambda x: np.log2(x).astype(dtypes.only_float.np),
UnaryOps.SIN: lambda x: np.sin(x).astype(dtypes.only_float.np),
UnaryOps.SQRT: lambda x: np.sqrt(x).astype(dtypes.only_float.np),
}
binary_ops = {
BinaryOps.ADD: np.add,
BinaryOps.SUB: np.subtract,
BinaryOps.MUL: np.multiply,
BinaryOps.DIV: np.divide,
BinaryOps.MAX: np.maximum,
BinaryOps.CMPLT: np.less,
BinaryOps.ADD: lambda x, y: np.add(x, y).astype(dtypes.only_float.np),
BinaryOps.SUB: lambda x, y: np.subtract(x, y).astype(dtypes.only_float.np),
BinaryOps.MUL: lambda x, y: np.multiply(x, y).astype(dtypes.only_float.np),
BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(dtypes.only_float.np),
BinaryOps.MAX: lambda x, y: np.maximum(x, y).astype(dtypes.only_float.np),
BinaryOps.CMPLT: lambda x, y: np.less(x, y).astype(np.bool_),
}
ternary_ops = {
TernaryOps.WHERE: np.where,
TernaryOps.WHERE: lambda x, y, z: np.where(x, y, z).astype(dtypes.only_float.np),
}

if op in unary_ops:
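
Wrapping the numpy ufuncs in lambdas that cast every result enforces the restricted dtype set at the lowest level: arithmetic always yields the canonical float32 (dtypes.only_float), and comparisons always yield bool. A small NumPy illustration of the policy (a sketch of the effect, not the TensorData code):

import numpy as np

CANONICAL_FLOAT = np.float32  # what dtypes.only_float.np resolves to

a = np.array([1, 2, 3], dtype=np.int32)
b = np.array([4, 5, 6], dtype=np.int32)

total = np.add(a, b).astype(CANONICAL_FLOAT)  # arithmetic is forced to float32
mask = np.less(a, b).astype(np.bool_)         # comparisons are forced to bool
print(total.dtype, mask.dtype)                # float32 bool
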
67 changes: 67 additions & 0 deletions edugrad/dtypes.py
@@ -0,0 +1,67 @@
from typing import ClassVar, Dict, Optional, Final
import numpy as np
from dataclasses import dataclass


@dataclass(frozen=True, order=True)
class DType:
"""Data type class for managing different data types."""

priority: int # Priority for upcasting
itemsize: int # Size of the data type in bytes
name: str # Name of the data type
np: Optional[type] # Corresponding numpy data type
sz: int = 1 # Size factor

def __repr__(self):
return f"dtypes.{self.name}"


class dtypes:
"""Container for different data types and utility methods.
We need this because some layer operations might use different trade-offs between precision and efficiency. In such
cases, we have to translate between dtypes.
"""

@staticmethod
def is_int(x: DType) -> bool:
"""Check if a data type is an integer type."""
return x in (
dtypes.int32,
)

@staticmethod
def is_float(x: DType) -> bool:
"""Check if a data type is a float type."""
return x in (dtypes.float32,)

@staticmethod
def from_np(x) -> DType:
"""Convert a numpy data type to a DType."""
return DTYPES_DICT[np.dtype(x).name]

@staticmethod
def fields() -> Dict[str, DType]:
return DTYPES_DICT

@staticmethod # NOTE: isinstance(True, int) is True in python
def from_py(x) -> DType:
return (
dtypes.only_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.only_int
)

# Definition of various data types
bool: Final[DType] = DType(0, 1, "bool", np.bool_)
float32: Final[DType] = DType(2, 4, "float", np.float32)
int32: Final[DType] = DType(1, 4, "int", np.int32)

only_float: ClassVar[DType] = float32
only_int: ClassVar[DType] = int32


# Dictionary mapping data type names to DType objects
DTYPES_DICT = {
k: v
for k, v in dtypes.__dict__.items()
if not k.startswith("__") and not k.startswith("only") and not callable(v) and not v.__class__ == staticmethod
}
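
A few lines showing how the new module behaves, assuming it is importable as edugrad.dtypes (the output comments follow from the definitions above):

import numpy as np
from edugrad.dtypes import dtypes

print(dtypes.from_np(np.float32))   # dtypes.float
print(dtypes.from_py(3))            # dtypes.int  (plain Python ints map to only_int)
print(dtypes.from_py(True))         # dtypes.bool (bool is checked before the int fallback)
print(dtypes.is_int(dtypes.int32))  # True
print(sorted(dtypes.fields()))      # ['bool', 'float32', 'int32']
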
12 changes: 11 additions & 1 deletion edugrad/function.py
@@ -11,7 +11,8 @@
"""
import math
from typing import Tuple, Optional, cast
from edugrad.helpers import argsort, DType, shape_int
from edugrad.dtypes import DType
from edugrad.helpers import argsort, shape_int
from edugrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
from edugrad.data import TensorData

@@ -355,3 +356,12 @@ def backward(self, grad_output: TensorData) -> TensorData:
), "symbolic shrink does not support backward"
# need this cast because mypy cannot narrow the type even with assert
return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg))


class Flip(Function):
def forward(self, x: TensorData, axis: Tuple[int, ...]) -> TensorData:
self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
return x.stride(self.arg)

def backward(self, grad_output: TensorData) -> TensorData:
return grad_output.stride(self.arg)
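
Flip records a stride of -1 along each flipped axis and 1 elsewhere; because flipping is its own inverse, the backward pass simply applies the same strides to the incoming gradient. The equivalent behaviour in NumPy terms (a sketch of the idea, not TensorData.stride itself):

import numpy as np

def flip(x, axes):
    # step of -1 along every flipped axis, matching Flip.forward's stride argument
    index = tuple(slice(None, None, -1) if i in set(axes) else slice(None)
                  for i in range(x.ndim))
    return x[index]

x = np.arange(6, dtype=np.float32).reshape(2, 3)
flipped = flip(x, (1,))
print(flipped)                              # columns reversed
print(np.allclose(flip(flipped, (1,)), x))  # True: flipping twice restores the input
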