Skip to content

Sample weighting

Open in Colab

This example illustrates how to weigh individual training samples differently with klax.fit by using a custom loss function.

To run it locally install klax with plotting capability via pip install 'klax[plot]'.

We'll start by importing the required packages for model creation, optimization and plotting.

import jax
from jax import numpy as jnp
from jax import random as jr
from matplotlib import pyplot as plt

import klax

key = jr.key(0)  # Global PRNG key; every random draw below derives from this seed.

First, we generate some dummy data along with sample weights.

Assume our data comes from the function \(f(x) = \sin(x) + \mathcal{N}(0, 0.2)\). In practice, we don’t have direct access to the underlying function and only observe sampled points.
Importantly, our samples are not uniformly distributed: we have many more data points in the region \(x \in [0, 4]\) than in \(x \in [4, 10]\).
To prevent the model from focusing disproportionately on the dense region, we assign larger sample weights to points in the sparser region.

def f(x):
    """Ground-truth target function underlying the noisy samples."""
    clean_signal = jnp.sin(x)
    return clean_signal


# Non-uniform sampling: 100 points on [0, 4] but only 5 points on [4, 10].
x_samples = jnp.concat([jnp.linspace(0, 4, 100), jnp.linspace(4, 10, 5)])
y_samples = f(x_samples) + 0.2 * jr.normal(key, shape=x_samples.shape)
# Up-weight the sparse region so the optimizer does not neglect it.
sample_weights = jnp.where(x_samples > 4, 20, 1)

# Dense grid used only for plotting the true function.
x_dense = jnp.linspace(0, 10, 1000)
y_dense = f(x_dense)

# Plot the data: marker size visualizes each sample's weight.
ax = plt.gca()
ax.plot(x_dense, y_dense, alpha=0.5, c="grey", label="True function")
ax.scatter(
    x_samples, y_samples, alpha=0.8, c="k", s=sample_weights, label="Samples"
)
ax.set(
    xlabel="x",
    ylabel="f(x)",
    title="Data samples with size indicating the sample weight",
)
plt.legend()
plt.show()

img

Let's fit a simple klax.nn.MLP to this data.
First, we create a custom loss function that computes a weighted mean squared error. To pass the weights to the loss function we integrate the sample weights into the dataset.
For comparison we also train an identical model without the sample weighting.

@klax.loss
def custom_loss(model, data, run_state):
    """Weighted mean squared error.

    ``data`` is an ``(inputs, targets, weights)`` triple; each squared
    residual is scaled by its sample weight before averaging.
    """
    x, y, w = data
    y_hat = jax.vmap(model)(x)
    return jnp.mean(w * (y_hat - y) ** 2)


# One key initializes both models identically; the other drives batching.
init_key, fit_key = jr.split(key)


def _make_mlp(k):
    """Build the small MLP architecture shared by both experiments."""
    return klax.nn.MLP(
        in_size="scalar", out_size="scalar", width_sizes=[8, 8], key=k
    )


# Identical optimization settings for the weighted and baseline runs.
fit_kwargs = dict(batch_size=32, steps=20_000, log_every=100, key=fit_key)

model = _make_mlp(init_key)
model, history = klax.fit(
    model,
    data=(x_samples, y_samples, sample_weights),
    loss=custom_loss,
    **fit_kwargs,
)

baseline_model = _make_mlp(init_key)
baseline_model, baseline_history = klax.fit(
    baseline_model,
    data=(x_samples, y_samples),
    **fit_kwargs,
)
0%|          | 0/20000 [00:00<?, ?it/s]
0%|          | 0/20000 [00:00<?, ?it/s]
# Compare convergence of the two runs on a shared, log-scaled loss axis.
ax = history.plot(c="red")
baseline_history.plot(ax=ax, c="grey", ls="--")
ax.set(
    title="Training loss with and without sample weighting",
    xlabel="Training step",
    ylabel="Loss",
    yscale="log",
)
plt.legend(["weighted loss", "baseline loss"])
plt.show()

img

Let's plot the model predictions.

# Evaluate both trained models on the dense grid and plot the fits.
y_pred = jax.vmap(model)(x_dense)
y_pred_baseline = jax.vmap(baseline_model)(x_dense)

ax = plt.gca()
ax.plot(x_dense, y_dense, alpha=0.5, c="grey", label="True function")
ax.scatter(
    x_samples, y_samples, alpha=0.8, c="k", s=sample_weights, label="Samples"
)
ax.plot(x_dense, y_pred, c="red", label="Model with sample weights")
ax.plot(
    x_dense,
    y_pred_baseline,
    ls="--",
    c="red",
    label="Model without sample weights",
)
ax.set(
    xlabel="x",
    ylabel="f(x)",
    title="Model predictions with and without sample weighting",
)
plt.legend(loc=3)
plt.show()

img