Skip to content

Commit

Permalink
Merge branch 'main' into inlinearray-compiler-slowness
Browse files Browse the repository at this point in the history
  • Loading branch information
msaelices authored Mar 2, 2025
2 parents b5bdf12 + bf1929c commit 4bcd693
Show file tree
Hide file tree
Showing 29 changed files with 1,663 additions and 1,618 deletions.
48 changes: 16 additions & 32 deletions examples/custom_ops/kernels/mandelbrot.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,6 @@ from runtime.asyncrt import DeviceContextPtr
from utils.index import IndexList


@always_inline
fn mandelbrot_inner_simd[
float_type: DType, int_type: DType, simd_width: Int
](
c: ComplexSIMD[float_type, simd_width], max_iterations: SIMD[int_type, 1]
) -> SIMD[int_type, simd_width]:
"""A vectorized implementation of the inner Mandelbrot computation."""
var z = ComplexSIMD[float_type, simd_width](0, 0)
var iters = SIMD[int_type, simd_width](0)

var in_set_mask: SIMD[DType.bool, simd_width] = True
for _ in range(max_iterations):
if not any(in_set_mask):
break
in_set_mask = z.squared_norm() <= 4
iters = in_set_mask.select(iters + 1, iters)
z = z.squared_add(c)

return iters


alias float_dtype = DType.float32


Expand All @@ -68,25 +47,30 @@ struct Mandelbrot:
fn elementwise_mandelbrot[
width: Int
](idx: IndexList[out.rank]) -> SIMD[out.type, width]:
# Obtain the position in the grid from the X, Y thread locations.
var row = idx[0]
var col = idx[1]

# Calculate the complex C corresponding to that grid location.
var cx = min_x.cast[float_dtype]() + (
col + iota[float_dtype, width]()
) * scale_x.cast[float_dtype]()
var cy = min_y.cast[float_dtype]() + row * SIMD[float_dtype, width](
scale_y.cast[float_dtype]()
)
var c = ComplexSIMD[float_dtype, width](cx, cy)
return mandelbrot_inner_simd[cx.type, out.type, width](
c, max_iterations.cast[out.type]()
)
var z = ComplexSIMD[float_dtype, width](0, 0)

foreach[elementwise_mandelbrot, target=target](out, ctx)
# Perform the Mandelbrot iteration loop calculation.
var iters = SIMD[out.type, width](0)
var in_set_mask: SIMD[DType.bool, width] = True
for _ in range(max_iterations):
if not any(in_set_mask):
break
in_set_mask = z.squared_norm() <= 4
iters = in_set_mask.select(iters + 1, iters)
z = z.squared_add(c)

# You only need to implement this if you do not manually annotate
# output shapes in the graph.
@staticmethod
fn shape(
x: ManagedTensorSlice,
) raises -> IndexList[x.rank]:
raise "NotImplemented"
return iters

foreach[elementwise_mandelbrot, target=target](out, ctx)
2 changes: 1 addition & 1 deletion examples/custom_ops/kernels/matrix_multiplication.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from gpu import WARP_SIZE, block_dim, block_idx, thread_idx
from gpu import WARP_SIZE, barrier, block_dim, block_idx, thread_idx
from gpu.host import DeviceBuffer, DeviceContext
from gpu.memory import async_copy_wait_all
from layout.layout_tensor import (
Expand Down
142 changes: 71 additions & 71 deletions examples/custom_ops/magic.lock
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/linux-64/max-core-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/linux-64/max-python-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030205-release.conda
- conda: https://conda.modular.com/max-nightly/linux-64/max-core-25.2.0.dev2025030205-release.conda
- conda: https://conda.modular.com/max-nightly/linux-64/max-python-25.2.0.dev2025030205-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030205-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030205-release.conda
- conda: https://conda.anaconda.org/conda-forge/noarch/mypy_extensions-1.0.0-pyha770c72_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h2d0b736_3.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.4-py312heda63a1_0.conda
Expand Down Expand Up @@ -107,11 +107,11 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/linux-aarch64/libuuid-2.38.1-hb4cce97_0.conda
- conda: https://conda.anaconda.org/conda-forge/linux-aarch64/libxcrypt-4.4.36-h31becfc_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-aarch64/libzlib-1.3.1-h86ecc28_2.conda
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-core-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-python-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030205-release.conda
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-core-25.2.0.dev2025030205-release.conda
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-python-25.2.0.dev2025030205-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030205-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030205-release.conda
- conda: https://conda.anaconda.org/conda-forge/noarch/mypy_extensions-1.0.0-pyha770c72_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-aarch64/ncurses-6.5-ha32ae93_3.conda
- conda: https://conda.anaconda.org/conda-forge/linux-aarch64/numpy-1.26.4-py312h470d778_0.conda
Expand Down Expand Up @@ -163,11 +163,11 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.49.1-h3f77e49_1.conda
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/libzlib-1.3.1-h8359307_2.conda
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/llvm-openmp-19.1.7-hdb05f8b_0.conda
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-core-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-python-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030205-release.conda
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-core-25.2.0.dev2025030205-release.conda
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-python-25.2.0.dev2025030205-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030205-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030205-release.conda
- conda: https://conda.anaconda.org/conda-forge/noarch/mypy_extensions-1.0.0-pyha770c72_1.conda
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/ncurses-6.5-h5e97a16_3.conda
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-1.26.4-py312h8442bc7_0.conda
Expand Down Expand Up @@ -1301,48 +1301,48 @@ packages:
license_family: APACHE
size: 280830
timestamp: 1736986295869
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030106-release.conda
- conda: https://conda.modular.com/max-nightly/noarch/max-25.2.0.dev2025030205-release.conda
noarch: python
sha256: aa0c1db61615cc9eeb8c3ff93563f190169d44c9842a229dc22356b3fb39c1c1
md5: ff398daef8f5c7807b6a423a03ff69a5
sha256: ab31bd8b9eef4c57467d6c151398a9d461776c6b563ec3bb5e0d60290196bb27
md5: 6614f831f5ed728170fa194ec1624f00
depends:
- max-core ==25.2.0.dev2025030106 release
- max-python ==25.2.0.dev2025030106 release
- mojo-jupyter ==25.2.0.dev2025030106 release
- mblack ==25.2.0.dev2025030106 release
- max-core ==25.2.0.dev2025030205 release
- max-python ==25.2.0.dev2025030205 release
- mojo-jupyter ==25.2.0.dev2025030205 release
- mblack ==25.2.0.dev2025030205 release
license: LicenseRef-Modular-Proprietary
size: 9904
timestamp: 1740812805176
- conda: https://conda.modular.com/max-nightly/linux-64/max-core-25.2.0.dev2025030106-release.conda
sha256: 6cd9c8b031f6a8e6a42561a3b68b5d5d893e8b57cd608111cdebf5d27fd8073a
md5: 8d5deae1a12f695bb1b313fa10ba4332
size: 9900
timestamp: 1740892623813
- conda: https://conda.modular.com/max-nightly/linux-64/max-core-25.2.0.dev2025030205-release.conda
sha256: 4ce97457fc052c19b4b4570fddca5275d0ac90d7bd5764b1e6addbc70922da2e
md5: 911b344e6f297830312effec03ac9177
depends:
- mblack ==25.2.0.dev2025030106 release
- mblack ==25.2.0.dev2025030205 release
license: LicenseRef-Modular-Proprietary
size: 249844103
timestamp: 1740812805175
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-core-25.2.0.dev2025030106-release.conda
sha256: 05eca10ca12e8bc08ee92fbacbbd0283df899af3b527af77b9d184fb512a3783
md5: bbca6a0604e0d5591b3167b4203e72ee
size: 249844800
timestamp: 1740892767853
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-core-25.2.0.dev2025030205-release.conda
sha256: 3f3624e96eb63bdcbe5315fbb5eef1017e3b78f742b930a81aae3928aa8e5476
md5: 5c731c6714e2994a177c29f68c356705
depends:
- mblack ==25.2.0.dev2025030106 release
- mblack ==25.2.0.dev2025030205 release
license: LicenseRef-Modular-Proprietary
size: 252111155
timestamp: 1740812763755
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-core-25.2.0.dev2025030106-release.conda
sha256: 63c02e7c8e0430951e9cd4b4505afe9e191e5b0bbf5343f4fdedc39ddeeb63d0
md5: 257c1c4a080f1d9117b0d2d44dda3b06
size: 252139076
timestamp: 1740892623813
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-core-25.2.0.dev2025030205-release.conda
sha256: 7e797baa59c9884b627250ff524a2fe4a4ae4e3152dd8c812cc46117143e758b
md5: a9fa47a9307c18b36d289cc9d905d75f
depends:
- mblack ==25.2.0.dev2025030106 release
- mblack ==25.2.0.dev2025030205 release
license: LicenseRef-Modular-Proprietary
size: 217288777
timestamp: 1740813876536
- conda: https://conda.modular.com/max-nightly/linux-64/max-python-25.2.0.dev2025030106-release.conda
size: 217343460
timestamp: 1740893672847
- conda: https://conda.modular.com/max-nightly/linux-64/max-python-25.2.0.dev2025030205-release.conda
noarch: python
sha256: fef4cdeded5a65511ad025d286cfcdeeba0dbda5a325ba05ab34e3da281256fc
md5: a89e5884fc3955acb191409c11c82d7f
sha256: 0cfc1563d861755f9ddf7c3fad5232b92e0120b46b3a5251aa4876e8133c7963
md5: 676e0dc9d02168c97b58922b7da48f80
depends:
- max-core ==25.2.0.dev2025030106 release
- max-core ==25.2.0.dev2025030205 release
- click >=8.0.0
- numpy >=1.18,<2.0
- sentencepiece >=0.2.0
Expand Down Expand Up @@ -1379,14 +1379,14 @@ packages:
- uvloop >=0.21.0
- xgrammar ==0.1.11
license: LicenseRef-Modular-Proprietary
size: 123614388
timestamp: 1740812805176
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-python-25.2.0.dev2025030106-release.conda
size: 123622714
timestamp: 1740892767853
- conda: https://conda.modular.com/max-nightly/linux-aarch64/max-python-25.2.0.dev2025030205-release.conda
noarch: python
sha256: 35ca6e41ac83be72027aa30ab8e24874b790f205b6d1dfd2cbc6e36ea8ea9537
md5: 1ed03583fcda964b93e193808e643111
sha256: d77e85d078817b894732219a0a3836abc55da25f894cd84f6239814b45a82598
md5: 31366b531c4d4ad9643e2860d8267f39
depends:
- max-core ==25.2.0.dev2025030106 release
- max-core ==25.2.0.dev2025030205 release
- click >=8.0.0
- numpy >=1.18,<2.0
- sentencepiece >=0.2.0
Expand Down Expand Up @@ -1423,14 +1423,14 @@ packages:
- uvloop >=0.21.0
- xgrammar ==0.1.11
license: LicenseRef-Modular-Proprietary
size: 125987433
timestamp: 1740812763755
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-python-25.2.0.dev2025030106-release.conda
size: 125978958
timestamp: 1740892623813
- conda: https://conda.modular.com/max-nightly/osx-arm64/max-python-25.2.0.dev2025030205-release.conda
noarch: python
sha256: ce716702ae9c7f2a5d194cbb19f8ffa70cd9956fe243dbb70985fc3679277c06
md5: 58f0a87442d47428ba3ee68b9e839a44
sha256: bc485be7dbfd9a9d245853ca86fdb86a81b72199245021bfecb7c950acb23dbd
md5: 5fe672a8c8c68068c2190257d76ac333
depends:
- max-core ==25.2.0.dev2025030106 release
- max-core ==25.2.0.dev2025030205 release
- click >=8.0.0
- numpy >=1.18,<2.0
- sentencepiece >=0.2.0
Expand Down Expand Up @@ -1467,12 +1467,12 @@ packages:
- uvloop >=0.21.0
- xgrammar ==0.1.11
license: LicenseRef-Modular-Proprietary
size: 112588416
timestamp: 1740813876537
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030106-release.conda
size: 112609556
timestamp: 1740893672847
- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.2.0.dev2025030205-release.conda
noarch: python
sha256: 46eeb957d71a44b69341001b0c4a7718b8ac0ef195bf8c716e7ead02b15baba4
md5: 4f031fd6f2bf81cfba322178106b315e
sha256: 738a6870f7f9538e5007d331182cffadf3ed31d245823d008ee851e2b2614d1f
md5: 341f0c6fcc18aa9fdf01b5032d1600f6
depends:
- python >=3.9,<3.13
- click >=8.0.0
Expand All @@ -1483,20 +1483,20 @@ packages:
- typing_extensions >=v4.12.2
- python
license: MIT
size: 130843
timestamp: 1740812805175
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030106-release.conda
size: 130850
timestamp: 1740892623813
- conda: https://conda.modular.com/max-nightly/noarch/mojo-jupyter-25.2.0.dev2025030205-release.conda
noarch: python
sha256: 1453f8559b8ada824083fc236f73a1ad4922ac904daa15b0d574dbfc4669867a
md5: 6e064b52057bbd756183ef214c0c0100
sha256: aad1b9c6cfd56e2e6d8336e52ef3c99dcf98b37bb2570222415629bfbc747284
md5: 5c81a8345de7d762d194990da7556384
depends:
- max-core ==25.2.0.dev2025030106 release
- max-core ==25.2.0.dev2025030205 release
- python >=3.9,<3.13
- jupyter_client >=8.6.2,<8.7
- python
license: LicenseRef-Modular-Proprietary
size: 22986
timestamp: 1740812805175
size: 22994
timestamp: 1740892623813
- conda: https://conda.anaconda.org/conda-forge/noarch/mypy_extensions-1.0.0-pyha770c72_1.conda
sha256: 1895f47b7d68581a6facde5cb13ab8c2764c2e53a76bd746f8f98910dc4e08fe
md5: 29097e7ea634a45cc5386b95cac6568f
Expand Down
25 changes: 19 additions & 6 deletions examples/custom_ops/mandelbrot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@
from max.graph import Graph, TensorType, ops


def draw_mandelbrot(tensor: Tensor, width: int, height: int, iterations: int):
"""A helper function to visualize the Mandelbrot set in ASCII art."""
sr = "....,c8M@jawrpogOQEPGJ"
for row in range(height):
for col in range(width):
v = tensor[row, col].item()
if v < iterations:
idx = int(v % len(sr))
p = sr[idx]
print(p, end="")
else:
print(" ", end="")
print("")


def create_mandelbrot_graph(
width: int,
height: int,
Expand Down Expand Up @@ -61,10 +76,10 @@ def create_mandelbrot_graph(
path = Path(__file__).parent / "kernels.mojopkg"

# Establish Mandelbrot set ranges.
WIDTH = 15
HEIGHT = 15
WIDTH = 60
HEIGHT = 25
MAX_ITERATIONS = 100
MIN_X = -1.5
MIN_X = -2.0
MAX_X = 0.7
MIN_Y = -1.12
MAX_Y = 1.12
Expand Down Expand Up @@ -94,6 +109,4 @@ def create_mandelbrot_graph(
assert isinstance(result, Tensor)
result = result.to(CPU())

print("Iterations to escape:")
print(result.to_numpy())
print()
draw_mandelbrot(result, WIDTH, HEIGHT, MAX_ITERATIONS)
Loading

0 comments on commit 4bcd693

Please sign in to comment.