Skip to content

Commit

Permalink
Merge branch 'main' into duration-proposal
Browse files Browse the repository at this point in the history
  • Loading branch information
ematejska authored Mar 3, 2025
2 parents cfdacd3 + a504dec commit 0fa2b7d
Show file tree
Hide file tree
Showing 38 changed files with 1,787 additions and 1,789 deletions.
16 changes: 12 additions & 4 deletions examples/custom_ops/fused_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,18 @@ def main():
path = Path(__file__).parent / "kernels.mojopkg"

dtype = DType.float32
N = 8
D = 8
BD = 4
BN = 4

if accelerator_count() == 0:
N = 8
D = 8
BD = 4
BN = 4
else:
N = 32
D = 32
BD = 8
BN = 16

with Graph(
"fused_attention",
input_types=[
Expand Down
48 changes: 16 additions & 32 deletions examples/custom_ops/kernels/mandelbrot.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,6 @@ from runtime.asyncrt import DeviceContextPtr
from utils.index import IndexList


@always_inline
fn mandelbrot_inner_simd[
    float_type: DType, int_type: DType, simd_width: Int
](
    c: ComplexSIMD[float_type, simd_width], max_iterations: SIMD[int_type, 1]
) -> SIMD[int_type, simd_width]:
    """Count Mandelbrot iterations for every SIMD lane of `c`.

    Starting from z = 0, each pass checks `z.squared_norm() <= 4` per lane,
    adds one to the count of every lane that passes the check, and then
    replaces z with `z.squared_add(c)`. At most `max_iterations` passes run.

    Parameters:
        float_type: Element type of the complex values.
        int_type: Element type of the returned counts.
        simd_width: Number of lanes processed together.

    Args:
        c: The complex constants, one per lane.
        max_iterations: Upper bound on the number of loop passes.

    Returns:
        The per-lane iteration counts.
    """
    var counts = SIMD[int_type, simd_width](0)
    var point = ComplexSIMD[float_type, simd_width](0, 0)
    var still_bounded: SIMD[DType.bool, simd_width] = True

    for _ in range(max_iterations):
        # Stop early when no lane passed the bound check on the previous pass.
        if not any(still_bounded):
            break
        # Order matters: test the current point, count it, and only then
        # advance it, so the count reflects the state before the update.
        still_bounded = point.squared_norm() <= 4
        counts = still_bounded.select(counts + 1, counts)
        point = point.squared_add(c)

    return counts


alias float_dtype = DType.float32


Expand All @@ -68,25 +47,30 @@ struct Mandelbrot:
fn elementwise_mandelbrot[
width: Int
](idx: IndexList[out.rank]) -> SIMD[out.type, width]:
# Obtain the position in the grid from the X, Y thread locations.
var row = idx[0]
var col = idx[1]

# Calculate the complex C corresponding to that grid location.
var cx = min_x.cast[float_dtype]() + (
col + iota[float_dtype, width]()
) * scale_x.cast[float_dtype]()
var cy = min_y.cast[float_dtype]() + row * SIMD[float_dtype, width](
scale_y.cast[float_dtype]()
)
var c = ComplexSIMD[float_dtype, width](cx, cy)
return mandelbrot_inner_simd[cx.type, out.type, width](
c, max_iterations.cast[out.type]()
)
var z = ComplexSIMD[float_dtype, width](0, 0)

foreach[elementwise_mandelbrot, target=target](out, ctx)
# Perform the Mandelbrot iteration loop calculation.
var iters = SIMD[out.type, width](0)
var in_set_mask: SIMD[DType.bool, width] = True
for _ in range(max_iterations):
if not any(in_set_mask):
break
in_set_mask = z.squared_norm() <= 4
iters = in_set_mask.select(iters + 1, iters)
z = z.squared_add(c)

# You only need to implement this if you do not manually annotate
# output shapes in the graph.
@staticmethod
fn shape(
x: ManagedTensorSlice,
) raises -> IndexList[x.rank]:
raise "NotImplemented"
return iters

foreach[elementwise_mandelbrot, target=target](out, ctx)
2 changes: 1 addition & 1 deletion examples/custom_ops/kernels/matrix_multiplication.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from gpu import WARP_SIZE, block_dim, block_idx, thread_idx
from gpu import WARP_SIZE, barrier, block_dim, block_idx, thread_idx
from gpu.host import DeviceBuffer, DeviceContext
from gpu.memory import async_copy_wait_all
from layout.layout_tensor import (
Expand Down
142 changes: 71 additions & 71 deletions examples/custom_ops/magic.lock
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/linux-64/max-core-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/linux-64/max-python-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030305-release.conda
- conda: https://conda.modular.com/max-nightly/linux-64/max-core-25.2.0.dev2025030305-release.conda
- conda: https://conda.modular.com/max-nightly/linux-64/max-python-25.2.0.dev2025030305-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030305-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030305-release.conda
- conda: https://conda.anaconda.org/conda-forge/noarch/mypy_extensions-1.0.0-pyha770c72_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h2d0b736_3.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.4-py312heda63a1_0.conda
Expand Down Expand Up @@ -107,11 +107,11 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/linux-aarch64/libuuid-2.38.1-hb4cce97_0.conda
- conda: https://conda.anaconda.org/conda-forge/linux-aarch64/libxcrypt-4.4.36-h31becfc_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-aarch64/libzlib-1.3.1-h86ecc28_2.conda
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-core-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-python-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030305-release.conda
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-core-25.2.0.dev2025030305-release.conda
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-python-25.2.0.dev2025030305-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030305-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030305-release.conda
- conda: https://conda.anaconda.org/conda-forge/noarch/mypy_extensions-1.0.0-pyha770c72_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-aarch64/ncurses-6.5-ha32ae93_3.conda
- conda: https://conda.anaconda.org/conda-forge/linux-aarch64/numpy-1.26.4-py312h470d778_0.conda
Expand Down Expand Up @@ -163,11 +163,11 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.49.1-h3f77e49_1.conda
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/libzlib-1.3.1-h8359307_2.conda
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/llvm-openmp-19.1.7-hdb05f8b_0.conda
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-core-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-python-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030305-release.conda
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-core-25.2.0.dev2025030305-release.conda
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-python-25.2.0.dev2025030305-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030305-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030305-release.conda
- conda: https://conda.anaconda.org/conda-forge/noarch/mypy_extensions-1.0.0-pyha770c72_1.conda
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/ncurses-6.5-h5e97a16_3.conda
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-1.26.4-py312h8442bc7_0.conda
Expand Down Expand Up @@ -1301,48 +1301,48 @@ packages:
license_family: APACHE
size: 280830
timestamp: 1736986295869
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030305-release.conda
noarch: python
sha256: aa0c1db61615cc9eeb8c3ff93563f190169d44c9842a229dc22356b3fb39c1c1
md5: ff398daef8f5c7807b6a423a03ff69a5
sha256: b1e1d327b440cc4410242695b71af4107c55371ffa45ea6afc451b926b75ad06
md5: bc042a54567c146308baa503ecf64b53
depends:
- max-core ==25.2.0.dev2025030106 release
- max-python ==25.2.0.dev2025030106 release
- mojo-jupyter ==25.2.0.dev2025030106 release
- mblack ==25.2.0.dev2025030106 release
- max-core ==25.2.0.dev2025030305 release
- max-python ==25.2.0.dev2025030305 release
- mojo-jupyter ==25.2.0.dev2025030305 release
- mblack ==25.2.0.dev2025030305 release
license: LicenseRef-Modular-Proprietary
size: 9904
timestamp: 1740812805176
- conda: https://conda.modular.com/max-nightly/linux-64/max-core-25.2.0.dev2025030106-release.conda
sha256: 6cd9c8b031f6a8e6a42561a3b68b5d5d893e8b57cd608111cdebf5d27fd8073a
md5: 8d5deae1a12f695bb1b313fa10ba4332
size: 9912
timestamp: 1740979061163
- conda: https://conda.modular.com/max-nightly/linux-64/max-core-25.2.0.dev2025030305-release.conda
sha256: 3636ae685e256479172339fe8c5d674a843fa2fc686b576e3223034ec24c2632
md5: ff81362bc96cbd58884f73487150bba4
depends:
- mblack ==25.2.0.dev2025030106 release
- mblack ==25.2.0.dev2025030305 release
license: LicenseRef-Modular-Proprietary
size: 249844103
timestamp: 1740812805175
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-core-25.2.0.dev2025030106-release.conda
sha256: 05eca10ca12e8bc08ee92fbacbbd0283df899af3b527af77b9d184fb512a3783
md5: bbca6a0604e0d5591b3167b4203e72ee
size: 249912659
timestamp: 1740979075048
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-core-25.2.0.dev2025030305-release.conda
sha256: 8992d4b65fbb37be38b79d0b013b997af8ce4d297e9bfdf5967444e911b7fc09
md5: 8c2cf14006dda0791b1c82f6528e9b8a
depends:
- mblack ==25.2.0.dev2025030106 release
- mblack ==25.2.0.dev2025030305 release
license: LicenseRef-Modular-Proprietary
size: 252111155
timestamp: 1740812763755
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-core-25.2.0.dev2025030106-release.conda
sha256: 63c02e7c8e0430951e9cd4b4505afe9e191e5b0bbf5343f4fdedc39ddeeb63d0
md5: 257c1c4a080f1d9117b0d2d44dda3b06
size: 252209414
timestamp: 1740979061163
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-core-25.2.0.dev2025030305-release.conda
sha256: abf6e1d66d53fdeb5ee40b9981b68f412ea5b0ca4d8e6401695e6f3e40577ba3
md5: 74e56cd46e352f72a2706cb18b3f57f8
depends:
- mblack ==25.2.0.dev2025030106 release
- mblack ==25.2.0.dev2025030305 release
license: LicenseRef-Modular-Proprietary
size: 217288777
timestamp: 1740813876536
- conda: https://conda.modular.com/max-nightly/linux-64/max-python-25.2.0.dev2025030106-release.conda
size: 217419570
timestamp: 1740980182916
- conda: https://conda.modular.com/max-nightly/linux-64/max-python-25.2.0.dev2025030305-release.conda
noarch: python
sha256: fef4cdeded5a65511ad025d286cfcdeeba0dbda5a325ba05ab34e3da281256fc
md5: a89e5884fc3955acb191409c11c82d7f
sha256: 1eaa96322c64db1ea143c529a5ce3cb28bdc74abaedbb35011646fe4a8f5e746
md5: 5fc2738c9c7aa3b7a848825f4ab5fbd9
depends:
- max-core ==25.2.0.dev2025030106 release
- max-core ==25.2.0.dev2025030305 release
- click >=8.0.0
- numpy >=1.18,<2.0
- sentencepiece >=0.2.0
Expand Down Expand Up @@ -1379,14 +1379,14 @@ packages:
- uvloop >=0.21.0
- xgrammar ==0.1.11
license: LicenseRef-Modular-Proprietary
size: 123614388
timestamp: 1740812805176
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-python-25.2.0.dev2025030106-release.conda
size: 123616168
timestamp: 1740979075048
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-python-25.2.0.dev2025030305-release.conda
noarch: python
sha256: 35ca6e41ac83be72027aa30ab8e24874b790f205b6d1dfd2cbc6e36ea8ea9537
md5: 1ed03583fcda964b93e193808e643111
sha256: e7114c9fec921c82ccc6a41313dbfabb77ea5dfdf06b3e764cc3f8219dee942d
md5: 7965896da871ef6f5ec6d3e63754a7de
depends:
- max-core ==25.2.0.dev2025030106 release
- max-core ==25.2.0.dev2025030305 release
- click >=8.0.0
- numpy >=1.18,<2.0
- sentencepiece >=0.2.0
Expand Down Expand Up @@ -1423,14 +1423,14 @@ packages:
- uvloop >=0.21.0
- xgrammar ==0.1.11
license: LicenseRef-Modular-Proprietary
size: 125987433
timestamp: 1740812763755
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-python-25.2.0.dev2025030106-release.conda
size: 125978148
timestamp: 1740979061163
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-python-25.2.0.dev2025030305-release.conda
noarch: python
sha256: ce716702ae9c7f2a5d194cbb19f8ffa70cd9956fe243dbb70985fc3679277c06
md5: 58f0a87442d47428ba3ee68b9e839a44
sha256: fec8e20e96fe5dca33a7f269e401121869b15fd2ba4b5ecd8eef9a2fe22d07c1
md5: 1e733edfaa00254c1149ee717540fa28
depends:
- max-core ==25.2.0.dev2025030106 release
- max-core ==25.2.0.dev2025030305 release
- click >=8.0.0
- numpy >=1.18,<2.0
- sentencepiece >=0.2.0
Expand Down Expand Up @@ -1467,12 +1467,12 @@ packages:
- uvloop >=0.21.0
- xgrammar ==0.1.11
license: LicenseRef-Modular-Proprietary
size: 112588416
timestamp: 1740813876537
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030106-release.conda
size: 112590648
timestamp: 1740980182917
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030305-release.conda
noarch: python
sha256: 46eeb957d71a44b69341001b0c4a7718b8ac0ef195bf8c716e7ead02b15baba4
md5: 4f031fd6f2bf81cfba322178106b315e
sha256: 0f45f5a5c9949198034a22dc02747a815db7614ac0b081cd6be081c08e2820f3
md5: 91b18174cd7cf314753c91bf1e6ba7b4
depends:
- python >=3.9,<3.13
- click >=8.0.0
Expand All @@ -1483,20 +1483,20 @@ packages:
- typing_extensions >=v4.12.2
- python
license: MIT
size: 130843
timestamp: 1740812805175
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030106-release.conda
size: 130864
timestamp: 1740979061163
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030305-release.conda
noarch: python
sha256: 1453f8559b8ada824083fc236f73a1ad4922ac904daa15b0d574dbfc4669867a
md5: 6e064b52057bbd756183ef214c0c0100
sha256: 32f162a17a82318ec55fb9ff05158828c53cf850532b79e20fc01c4cb342acd3
md5: 0539b71f182f98f2ac274a3ba9420f63
depends:
- max-core ==25.2.0.dev2025030106 release
- max-core ==25.2.0.dev2025030305 release
- python >=3.9,<3.13
- jupyter_client >=8.6.2,<8.7
- python
license: LicenseRef-Modular-Proprietary
size: 22986
timestamp: 1740812805175
size: 22994
timestamp: 1740979061163
- conda: https://conda.anaconda.org/conda-forge/noarch/mypy_extensions-1.0.0-pyha770c72_1.conda
sha256: 1895f47b7d68581a6facde5cb13ab8c2764c2e53a76bd746f8f98910dc4e08fe
md5: 29097e7ea634a45cc5386b95cac6568f
Expand Down
25 changes: 19 additions & 6 deletions examples/custom_ops/mandelbrot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@
from max.graph import Graph, TensorType, ops


def draw_mandelbrot(
    tensor: Tensor, width: int, height: int, iterations: int
) -> None:
    """Print a grid of Mandelbrot iteration counts as ASCII art.

    Each cell is read with ``tensor[row, col].item()``. A value below
    ``iterations`` is drawn with a palette character chosen by the value
    modulo the palette length; any other value is drawn as a space.

    Args:
        tensor: Grid of per-cell values, indexable as ``tensor[row, col]``
            with each element exposing ``.item()``.
        width: Number of columns to draw from each row.
        height: Number of rows to draw.
        iterations: Threshold at or above which a cell is drawn blank
            (presumably the iteration cap used to compute ``tensor`` -
            confirm against the caller).
    """
    palette = "....,c8M@jawrpogOQEPGJ"
    for row in range(height):
        cells = []
        for col in range(width):
            count = tensor[row, col].item()
            if count < iterations:
                cells.append(palette[int(count % len(palette))])
            else:
                cells.append(" ")
        # Emit the whole row with a single write instead of one print call
        # per character; the text produced is unchanged.
        print("".join(cells))


def create_mandelbrot_graph(
width: int,
height: int,
Expand Down Expand Up @@ -61,10 +76,10 @@ def create_mandelbrot_graph(
path = Path(__file__).parent / "kernels.mojopkg"

# Establish Mandelbrot set ranges.
WIDTH = 15
HEIGHT = 15
WIDTH = 60
HEIGHT = 25
MAX_ITERATIONS = 100
MIN_X = -1.5
MIN_X = -2.0
MAX_X = 0.7
MIN_Y = -1.12
MAX_Y = 1.12
Expand Down Expand Up @@ -94,6 +109,4 @@ def create_mandelbrot_graph(
assert isinstance(result, Tensor)
result = result.to(CPU())

print("Iterations to escape:")
print(result.to_numpy())
print()
draw_mandelbrot(result, WIDTH, HEIGHT, MAX_ITERATIONS)
22 changes: 18 additions & 4 deletions examples/custom_ops/mojoproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,27 @@ version = "0.1.0"
package = "mojo package kernels/ -o kernels.mojopkg"
addition = { cmd = "python addition.py", depends-on = ["package"] }
mandelbrot = { cmd = "python mandelbrot.py", depends-on = ["package"] }
vector_addition = { cmd = "python vector_addition.py", depends-on = ["package"] }
vector_addition = { cmd = "python vector_addition.py", depends-on = [
"package",
] }
top_k = { cmd = "python top_k.py", depends-on = ["package"] }
fused_attention = { cmd = "python fused_attention.py", depends-on = ["package"] }
matrix_multiplication = { cmd = "python matrix_multiplication.py", depends-on = ["package"] }
fused_attention = { cmd = "python fused_attention.py", depends-on = [
"package",
] }
matrix_multiplication = { cmd = "python matrix_multiplication.py", depends-on = [
"package",
] }
histogram = { cmd = "python histogram.py", depends-on = ["package"] }
benchmark = { cmd = "mojo benchmarks.mojo", depends-on = ["package"] }
test = { depends-on = ["addition", "mandelbrot", "vector_addition", "top_k", "matrix_multiplication", "benchmark"] }
test = { depends-on = [
"addition",
"mandelbrot",
"vector_addition",
"top_k",
"fused_attention",
"matrix_multiplication",
"benchmark",
] }

[dependencies]
python = ">=3.9,<3.13"
Expand Down
Loading

0 comments on commit 0fa2b7d

Please sign in to comment.