-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathasal_metrics.py
79 lines (68 loc) · 2.81 KB
/
asal_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import jax
import jax.numpy as jnp
from einops import repeat
def calc_supervised_target_score(z, z_txt):
"""
Calculates the supervisted target score from ASAL.
The returned score should be minimized, since we add a minus sign here.
Parameters
----------
z : jnp.ndarray of shape (T, D)
The latent representation of the images over time.
z_txt : jnp.ndarray of shape (T2, D)
The latent representation of the text prompts over time.
"""
T, T2 = z.shape[0], z_txt.shape[0]
assert T % T2 == 0
z_txt = repeat(z_txt, "T2 D -> (k T2) D", k=T//T2) # repeat to match shape, creating even intervals for each prompt
kernel = z_txt @ z.T # T, T
return -jnp.diag(kernel).mean()
def calc_supervised_target_softmax_score(z, z_txt, temperature_softmax=0.01):
"""
Calculates the supervisted target score from ASAL with softmax.
This isn't part of the original ASAL, but it's a useful extension.
This score helps incentivize the simulation to find unique images for each prompt rather than one static image satisfying all prompts.
The returned score should be minimized.
Parameters
----------
z : jnp.ndarray of shape (T, D)
The latent representation of the images over time.
z_txt : jnp.ndarray of shape (T2, D)
The latent representation of the text prompts over time.
temperature_softmax : float
The temperature for the softmax function. For CLIP, leave it at 0.01, since that is default CLIP softmax temperature.
"""
T, T2 = z.shape[0], z_txt.shape[0]
assert T % T2 == 0
z_txt = repeat(z_txt, "T2 D -> (k T2) D", k=T//T2) # repeat to match shape, creating even intervals for each prompt
kernel = z_txt @ z.T # T, T
loss_sm1 = jax.nn.softmax(kernel/temperature_softmax, axis=-1)
loss_sm2 = jax.nn.softmax(kernel/temperature_softmax, axis=-2)
loss_sm1 = -jnp.log(jnp.diag(loss_sm1))
loss_sm2 = -jnp.log(jnp.diag(loss_sm2))
return (loss_sm1.mean() + loss_sm2.mean())/2.
def calc_open_endedness_score(z):
"""
Calculates the open-endedness score from ASAL.
The returned score should be minimized.
Parameters
----------
z : jnp.ndarray of shape (T, D)
The latent representation of the images over time.
"""
kernel = (z @ z.T) # T, T
kernel = jnp.tril(kernel, k=-1)
return kernel.max(axis=-1).mean()
def calc_illumination_score(zs):
"""
Calculates the illumination score from ASAL.
The returned score should be minimized.
Parameters
----------
zs : jnp.ndarray of shape (N, D)
The latent representation of the images from different simulation parameters.
"""
N, D = zs.shape
kernel = (zs @ zs.T) # N, N
kernel = jnp.where(jnp.eye(N, dtype=bool), -jnp.inf, kernel) # set diagonal to -inf
return kernel.max(axis=-1).mean()