Skip to content

Latest commit

 

History

History
76 lines (61 loc) · 1.61 KB

elixir_MNIST_nb.livemd

File metadata and controls

76 lines (61 loc) · 1.61 KB

Simple MNIST

Mix.install([
  {:vega_lite, "~> 0.1.6"},
  {:kino_vega_lite, "~> 0.1.10"},
  {:axon, "~> 0.6"},
  {:exla, "~> 0.6"},
  {:req, "~> 0.3.1"}
  # {:torchx, "~> 0.6"},
])

alias VegaLite, as: Vl

Load the data, train the model, and predict

# Common URL prefix of the MNIST archive files.
base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

# Download the training images and the training labels (in that order),
# keeping only each response body.
# NOTE(review): the bodies are parsed as raw IDX data further down, so Req is
# presumably expected to have gunzipped the ".gz" payloads already - confirm.
[train_images, train_labels] =
  for archive <- ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz"] do
    Req.get!(base_url <> archive).body
  end

# Parse the IDX headers. Each header field is a 32-bit integer (Elixir's
# default for bitstring integers: unsigned, big-endian); everything after the
# header is the raw payload.
#
# The magic numbers are matched literally (2051 = image file, 2049 = label
# file) so that a corrupt or still-compressed download fails right here with a
# MatchError instead of surfacing later as a confusing reshape error.
<<2051::32, n_images::32, n_rows::32, n_cols::32, images::binary>> = train_images
<<2049::32, n_labels::32, labels::binary>> = train_labels

# Images and labels are paired positionally when the training batches are
# zipped together, and zipping stops silently at the shorter side, so a count
# mismatch must be rejected explicitly.
if n_labels != n_images do
  raise ArgumentError,
        "MNIST label count (#{n_labels}) does not match image count (#{n_images})"
end
# Turn the raw pixel bytes into a tensor, one step at a time:
# 1. decode every byte as an unsigned 8-bit value,
images = Nx.from_binary(images, {:u, 8})

# 2. arrange the values as {n_images, 1, n_rows, n_cols} with named axes,
images =
  Nx.reshape(images, {n_images, 1, n_rows, n_cols},
    names: [:images, :channels, :height, :width]
  )

# 3. divide by 255 so the 0..255 byte values land in the 0..1 range.
images = Nx.divide(images, 255)

# Preview: heatmaps of the images at indices 0..4 along the :images axis.
Nx.to_heatmap(images[[images: 0..4]])

# From here on `images` is no longer a single tensor but the batches of 32
# produced by Nx.to_batched/2.
images = Nx.to_batched(images, 32)
# The ten class values 0..9 that each label is compared against.
digit_classes = Nx.tensor(Enum.to_list(0..9))

# Decode the label bytes as unsigned 8-bit values.
targets = Nx.from_binary(labels, {:u, 8})

# Add a trailing axis so every label broadcasts against the ten classes;
# the comparison then yields a 1 only in the column equal to the label
# (one-hot encoding).
targets = Nx.new_axis(targets, -1)
targets = Nx.equal(targets, digit_classes)

# Batch with the same size (32) as the images so both line up batch by batch.
targets = Nx.to_batched(targets, 32)
# Network input named "input": batches (size left open via nil) of
# 1-channel 28x28 images.
input_layer = Axon.input("input", shape: {nil, 1, 28, 28})

# Flatten each image and feed it through a 128-unit dense layer with ReLU.
hidden_layer =
  input_layer
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)

# Output layer: 10 units with softmax, one per digit class.
model = Axon.dense(hidden_layer, 10, activation: :softmax)

# Zero-arity function wrapping the whole training run, so it can be handed to
# a timer. Returns whatever Axon.Loop.run/4 returns.
train_loop = fn ->
  # Supervised training loop: categorical cross-entropy loss, Adam optimizer,
  # with accuracy reported as an additional metric.
  loop = Axon.Loop.trainer(model, :categorical_cross_entropy, :adam)
  loop = Axon.Loop.metric(loop, :accuracy)

  # Pair each image batch with the target batch at the same position.
  batches = Stream.zip(images, targets)

  # Start from an empty initial state, run 10 epochs, compile with EXLA.
  Axon.Loop.run(loop, batches, %{}, epochs: 10, compiler: EXLA)
end

# Run the training function under :timer.tc/1, which returns the elapsed time
# in microseconds together with the function's result (bound to model_state).
{time_in_microseconds, model_state} = :timer.tc(train_loop)

# Convert microseconds to seconds for display.
elapsed_seconds = time_in_microseconds / 1_000_000

"Seconds #{elapsed_seconds}"
# Take the first batch of images (these are the training images batched above).
first_batch = images |> Enum.at(0)

# Run the model with the trained state on that batch.
output = model |> Axon.predict(model_state, first_batch)

# For every row, pick the index of the largest value along axis 1,
# i.e. the predicted class for each image.
output |> Nx.argmax(axis: 1)