Update Readme and Jax_Utils
longtanle committed May 24, 2023
1 parent fa494de commit adafdb4
Showing 4 changed files with 21 additions and 14 deletions.
16 changes: 9 additions & 7 deletions README.md
@@ -60,20 +60,22 @@ Please extract the zip file and replace the original ``./data/`` folder by the e
- Jaxlib == 0.3.25+cuda11.cudnn82
- jaxopt == 0.5.5
- optax == 0.1.4
- chex == 0.1.5
- dm-haiku == 0.0.9

To install the dependencies: `pip3 install -r requirements.txt`
Please follow the installation guidelines at [http://github.com/google/jax](https://github.com/google/jax#pip-installation-gpu-cuda) to install Jax and Jaxlib versions compatible with your machine. The versions of the Jax-related packages (Jaxopt, Optax, Chex) may also need to be adjusted for compatibility with Jax and Jaxlib.
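For the pins in this commit (jax==0.3.25, jaxlib==0.3.25+cuda11.cudnn82), one plausible invocation is sketched below; the exact extras tag and wheel index depend on your CUDA/cuDNN setup, so verify against the Jax install guide before relying on it.

```bash
# Assumed example for the jax==0.3.25 / cuda11.cudnn82 pins; confirm the
# extras tag and the wheel index against the official Jax install guide.
pip3 install "jax[cuda11_cudnn82]==0.3.25" \
    -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```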

To install other dependencies: `pip3 install -r requirements.txt`

## Experiments

The general template commands for running an experiment are:

```bash
bash runners/<dataset>/run_<algorithm>.sh --repeat 1 [other flags]
bash runners/<dataset>/run_<algorithm>.sh [other flags]
```

### Basic Usage
### Flags

| params | full params | description | default value | options |
|--------|--------------------|-----------------------------------------------------------|---------------|---------|
@@ -109,10 +111,10 @@ bash runners/shakespeare/run_fedeq.sh -r 1 -g 0
Run algorithms to evaluate generalization to unseen clients:

```bash
bash runners/femnist/run_fedeq.sh -r 1 -fu 0.1 -g 0
bash runners/cifar10/run_fedeq.sh -r 1 -fu 0.1 -g 0
bash runners/cifar100/run_fedeq.sh -r 1 -fu 0.1 -g 0
bash runners/shakespeare/run_fedeq.sh -r 1 -fu 0.1 -g 0
bash runners/femnist/run_fedeq.sh -fu 0.1 -r 1 -g 0
bash runners/cifar10/run_fedeq.sh -fu 0.1 -r 1 -g 0
bash runners/cifar100/run_fedeq.sh -fu 0.1 -r 1 -g 0
bash runners/shakespeare/run_fedeq.sh -fu 0.1 -r 1 -g 0
```

---
14 changes: 9 additions & 5 deletions requirements.txt
@@ -1,9 +1,13 @@
jaxopt == 0.5.5
dm-haiku == 0.0.9
optax == 0.1.4
chex == 0.1.5
numpy
tqdm
scikit-learn
pandas
jax==0.3.25
jaxlib==0.3.25+cuda11.cudnn82
jaxopt==0.5.5
dm-haiku==0.0.9
optax==0.1.4
torch
torchvision
seaborn
matplotlib
tqdm
2 changes: 1 addition & 1 deletion runners/shakespeare/run_fedeq.sh
@@ -1,3 +1,3 @@
bash runners/shakespeare/main.sh \
--trainer fedeq_seq --model deq_transformer --rho 0.01 --lam_admm 0.01 \
--trainer fedeq_seq --model deq_transformer --rho 0.01 --lam_admm 0.001 \
$@
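Because the runner forwards its remaining arguments via `$@`, the README's flags are appended after the preset trainer options; for example (the same command shown in the README section above):

```bash
# -r 1 matches --repeat 1 from the README template; -g 0 presumably selects GPU 0.
bash runners/shakespeare/run_fedeq.sh -r 1 -g 0
```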
3 changes: 2 additions & 1 deletion utils/jax_utils.py
@@ -211,7 +211,7 @@ def rdeq_nll_loss_fn(model,
"""Computes the (scalar) LM loss on `data` w.r.t. params."""
inputs, targets = batch

params = hk.data_structures.merge(shared_params, personalized_params)
params = hk.data_structures.merge(deq_params, personalized_params)
logits, embedding = model.apply(params, rng, inputs)
targets = jax.nn.one_hot(targets, vocab_size)
assert logits.shape == targets.shape
@@ -222,6 +222,7 @@ def rdeq_nll_loss_fn(model,
quadratic_term = 0.5 * rho * tree_l2_norm(diff_params, squared=True)

mask = jnp.greater(inputs, 0)
#print(jnp.sum(mask))
log_likelihood = jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
return -jnp.sum(log_likelihood * mask)/jnp.sum(mask) + linear_term/jnp.sum(mask) + quadratic_term/jnp.sum(mask) # NLL per token.
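
As a reading aid for the return line above, here is a minimal, self-contained sketch (a hypothetical helper, not the repo's API) of how the masked per-token NLL combines with the ADMM linear and quadratic terms; `tree_l2_norm` is the jaxopt utility the snippet already uses, and every other name here is an assumption for illustration.

```python
import jax
import jax.numpy as jnp
from jaxopt.tree_util import tree_l2_norm  # same utility as in the snippet above

def penalized_token_nll(logits, targets, mask, linear_term, diff_params, rho):
    """Masked per-token NLL plus ADMM penalties, mirroring the return above.

    logits/targets: (batch, seq, vocab), targets one-hot; mask: (batch, seq).
    """
    log_likelihood = jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    quadratic_term = 0.5 * rho * tree_l2_norm(diff_params, squared=True)
    n_tokens = jnp.sum(mask)  # number of real (non-padding) tokens
    return (-jnp.sum(log_likelihood * mask) + linear_term + quadratic_term) / n_tokens

# Smoke test with dummy shapes:
key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (2, 5, 7))
targets = jax.nn.one_hot(jnp.ones((2, 5), dtype=jnp.int32), 7)
mask = jnp.ones((2, 5))
loss = penalized_token_nll(logits, targets, mask, 0.0, {"w": jnp.ones(3)}, rho=0.01)
```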

