Entropix Sampling
Oct 2024
An entropy-based sampling approach by xjdr: a per-token sampling strategy for LLMs built from information about the distribution of the logits^1.
Entropy and Varentropy
 Entropy tells you how uncertain the model is, on average, about the next token
 Varentropy tells you how much that uncertainty varies across the possible next tokens (formalized just below)
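In symbols (restating the two bullets above, in bits to match the reference implementation below), for a next-token distribution $p$:

$$H(p) = -\sum_i p_i \log_2 p_i$$
$$\mathrm{VE}(p) = \sum_i p_i \left( \log_2 p_i + H(p) \right)^2$$

That is, varentropy is the variance of the token-level surprisal $-\log_2 p_i$ around its mean $H(p)$.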
Entropix Sampling Approach
The sampler's behavior is determined by the combination of entropy and varentropy of the next-token distribution, which splits into four regimes (a dispatch sketch follows the list).
 (⬇️ entropy, ⬇️ varentropy) High confidence: the model returns the highest-probability token.
 (⬆️ entropy, ⬇️ varentropy) Consistently unsure: it either backspaces and resamples to get back on track, or emits an EOT token to prevent hallucination.
 (⬇️ entropy, ⬆️ varentropy) Confident on multiple paths: it branches out and explores, returning the most confident path.
 (⬆️ entropy, ⬆️ varentropy) Randomness needed: the temperature is raised substantially and top_p is decreased to prevent gibberish.
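Concretely, the quadrant dispatch might look like the sketch below. This is a minimal illustration, not the actual entropix decision logic: the thresholds are hypothetical placeholders, the "backspace" and "branch" behaviors are stubbed with simple temperature sampling, and it reuses calculate_varentropy_logsoftmax from the reference implementation in the next section.

import jax
import jax.numpy as jnp

# Hypothetical thresholds (in bits); real cutoffs would be tuned per model.
ENT_LOW = 0.1
VENT_LOW = 0.1

def entropix_sample(logits, key):
    # Dispatch on the (entropy, varentropy) quadrant for a single
    # (unbatched) logits vector of shape [vocab].
    ent, vent = calculate_varentropy_logsoftmax(logits)
    if ent < ENT_LOW and vent < VENT_LOW:
        # Confident: return the highest-probability token.
        return jnp.argmax(logits, axis=-1)
    if ent >= ENT_LOW and vent < VENT_LOW:
        # Consistently unsure: stand-in for backspace/resample or EOT.
        return jax.random.categorical(key, logits / 0.8, axis=-1)
    if ent < ENT_LOW and vent >= VENT_LOW:
        # Confident on multiple paths: stand-in for branch-and-explore.
        return jax.random.categorical(key, logits / 0.5, axis=-1)
    # High entropy and high varentropy: inject randomness via temperature.
    return jax.random.categorical(key, logits / 1.5, axis=-1)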
Reference Implementation
import math
import jax
import jax.numpy as jnp

LN_2 = math.log(2)  # ln(2), used to convert nats to bits

def calculate_varentropy_logsoftmax(logits):
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    probs = jnp.exp(log_probs)
    entropy = -jnp.sum(probs * log_probs, axis=-1) / LN_2  # note the minus sign; nats -> bits
    varentropy = jnp.sum(probs * (log_probs / LN_2 + entropy[..., None]) ** 2, axis=-1)
    return entropy, varentropy
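A quick smoke test on a toy batch of logits (the values are arbitrary, just to exercise the function):

logits = jnp.array([[2.0, 1.0, 0.5, 0.1]])
entropy, varentropy = calculate_varentropy_logsoftmax(logits)
print(entropy, varentropy)  # per-row entropy (bits) and varentropy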
Footnotes

1. (and attention heads)