Xiaobo's Blog

Gradient Estimation of KL Divergence in Large Language Model Reinforcement Learning

Three estimators, one loss, and the subtle gap between ∇E[·] and E[∇·].

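In reinforcement learning training of large language models, there are two ways to introduce a KL divergence constraint: reward shaping and a KL loss. The former adds a KL estimate directly to the reward, yielding a new reward $r \leftarrow r - \beta \cdot \widehat{\mathrm{KL}}$. The latter puts the KL estimator into the loss, computes $\nabla_\theta \widehat{\mathrm{KL}}$, and backpropagates it together with the rest of the loss.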
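To make the distinction concrete, here is a minimal sketch of where the KL estimate enters in each case. The function names, the coefficient `beta`, and the use of plain Python lists of per-sample log-probabilities are illustrative assumptions, not taken from any particular training framework. In real training the policy log-probabilities in the loss term would be differentiable tensors.

```python
def shaped_rewards(rewards, logp_policy, logp_ref, beta):
    """Reward shaping: fold a per-sample KL penalty into the reward.

    The penalty is treated as a constant, so no gradient flows through it.
    """
    return [r - beta * (lp - lr) for r, lp, lr in zip(rewards, logp_policy, logp_ref)]


def kl_loss_term(logp_policy, logp_ref, beta, estimator):
    """KL loss: batch mean of a per-sample estimator, scaled by beta.

    This term is added to the policy loss and differentiated with it.
    """
    values = [estimator(lp, lr) for lp, lr in zip(logp_policy, logp_ref)]
    return beta * sum(values) / len(values)


def k2(logp_policy, logp_ref):
    """Half squared log-ratio, one possible per-sample estimator."""
    return 0.5 * (logp_policy - logp_ref) ** 2


# Hypothetical numbers for two samples.
shaped = shaped_rewards([1.0, 0.0], [-1.0, -2.0], [-1.5, -1.0], beta=0.1)
loss_term = kl_loss_term([-1.0, -2.0], [-1.5, -1.0], beta=0.1, estimator=k2)
```

In the first function the KL estimate only changes the numbers fed to the policy-gradient update. In the second, the estimator's own derivative becomes part of the update, which is why the choice of estimator matters for the gradient.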

Schulman (2020) gives three estimators of the KL divergence, of which k3 is regarded as a good estimator that is both unbiased and low variance. DeepSeek's GRPO algorithm (Guo et al. 2025) uses it as the KL loss:

Fig. 1. The GRPO algorithm. (Source: Guo et al. 2025)
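As a reminder of what "unbiased and low variance" means for the value of the estimators, the sketch below computes the exact mean and variance of k1, k2, and k3 on a small three-outcome distribution. The distributions `p` and `q` and the helper name are made up for illustration; `p` plays the role of the sampling policy and `q` the reference.

```python
import math


def kl_value_estimators(p, q):
    """Exact mean and variance of k1, k2, k3 for x ~ p, estimating KL(p || q)."""
    estimators = {
        "k1": lambda r: -math.log(r),
        "k2": lambda r: 0.5 * math.log(r) ** 2,
        "k3": lambda r: (r - 1.0) - math.log(r),
    }
    stats = {}
    for name, k in estimators.items():
        values = [k(qi / pi) for pi, qi in zip(p, q)]  # r = q(x) / p(x)
        mean = sum(pi * v for pi, v in zip(p, values))
        var = sum(pi * (v - mean) ** 2 for pi, v in zip(p, values))
        stats[name] = (mean, var)
    return stats


p = [0.5, 0.3, 0.2]  # sampling distribution (hypothetical)
q = [0.4, 0.4, 0.2]  # reference distribution (hypothetical)
true_kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
stats = kl_value_estimators(p, q)
```

On this example k1 and k3 both have mean equal to the true KL, k3 has a much smaller variance than k1, and k2 is slightly biased. All of this concerns the estimated value only and says nothing yet about gradients.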

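However, with a KL loss we are in fact using samples to estimate the gradient of the KL divergence, which differs somewhat from the analysis in Schulman (2020). In training, once the estimator $\widehat{\mathrm{KL}}$ is constructed, it is used directly as the loss and backpropagated, in the hope that this is still a good approximation:

$$\mathbb{E}_{x \sim \pi_\theta}\left[\nabla_\theta \widehat{\mathrm{KL}}(x)\right] \approx \nabla_\theta \mathrm{KL}(\pi_\theta \,\|\, \pi_{ref})$$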

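But the gradient of an unbiased estimator of the KL is not necessarily an unbiased estimator of the gradient of the KL. In general,

$$\mathbb{E}_{x \sim \pi_\theta}\left[\nabla_\theta \widehat{\mathrm{KL}}(x)\right] \neq \nabla_\theta \mathbb{E}_{x \sim \pi_\theta}\left[\widehat{\mathrm{KL}}(x)\right] = \nabla_\theta \mathrm{KL}(\pi_\theta \,\|\, \pi_{ref})$$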

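In fact, the two differ by one term. For any per-sample loss $f_\theta(x)$,

$$\nabla_\theta \mathbb{E}_{x \sim \pi_\theta}[f_\theta(x)] = \mathbb{E}_{x \sim \pi_\theta}[\nabla_\theta f_\theta(x)] + \mathbb{E}_{x \sim \pi_\theta}[f_\theta(x)\, \nabla_\theta \log \pi_\theta(x)]$$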
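A short derivation of where the extra term comes from, assuming differentiation and integration can be exchanged and writing $f_\theta(x)$ for a generic per-sample loss (the product rule followed by the log-derivative trick $\nabla_\theta \pi_\theta = \pi_\theta \nabla_\theta \log \pi_\theta$):

```latex
\begin{aligned}
\nabla_\theta \mathbb{E}_{x \sim \pi_\theta}[f_\theta(x)]
&= \nabla_\theta \int \pi_\theta(x)\, f_\theta(x)\, dx \\
&= \int \pi_\theta(x)\, \nabla_\theta f_\theta(x)\, dx
 + \int f_\theta(x)\, \nabla_\theta \pi_\theta(x)\, dx \\
&= \mathbb{E}_{x \sim \pi_\theta}\!\left[\nabla_\theta f_\theta(x)\right]
 + \mathbb{E}_{x \sim \pi_\theta}\!\left[f_\theta(x)\, \nabla_\theta \log \pi_\theta(x)\right]
\end{aligned}
```

Backpropagating through a sampled loss only produces the first expectation. The second one, which accounts for the dependence of the sampling distribution on $\theta$, is dropped.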

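For samples $x \sim \pi_\theta$ generated by the policy, taking $f_\theta$ to be the k1, k2, and k3 losses in turn, with all expectations over $x \sim \pi_\theta$, gives

  • k1: $k_1 = \log\frac{\pi_\theta(x)}{\pi_{ref}(x)}$, and $\mathbb{E}[\nabla_\theta k_1] = \mathbb{E}[\nabla_\theta \log \pi_\theta(x)] = 0$

  • k2: $k_2 = \frac{1}{2}\left(\log\frac{\pi_\theta(x)}{\pi_{ref}(x)}\right)^2$, and $\mathbb{E}[\nabla_\theta k_2] = \mathbb{E}\left[\log\frac{\pi_\theta(x)}{\pi_{ref}(x)}\, \nabla_\theta \log \pi_\theta(x)\right] = \nabla_\theta \mathrm{KL}(\pi_\theta \,\|\, \pi_{ref})$

  • k3: $k_3 = \frac{\pi_{ref}(x)}{\pi_\theta(x)} - 1 - \log\frac{\pi_{ref}(x)}{\pi_\theta(x)}$, and $\mathbb{E}[\nabla_\theta k_3] = \mathbb{E}\left[\left(1 - \frac{\pi_{ref}(x)}{\pi_\theta(x)}\right) \nabla_\theta \log \pi_\theta(x)\right] = \nabla_\theta \mathrm{KL}(\pi_{ref} \,\|\, \pi_\theta) \neq \nabla_\theta \mathrm{KL}(\pi_\theta \,\|\, \pi_{ref})$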
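These three expectations can be checked exactly on a toy problem. The sketch below assumes a three-outcome softmax policy, for which the score is $\partial_{\theta_j} \log \pi_\theta(x) = \mathbb{1}[x = j] - \pi_\theta(j)$, and compares the expected per-sample loss gradients with finite-difference gradients of the two KL divergences. The parameter values and helper names are made up for illustration.

```python
import math


def softmax(theta):
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    z = sum(exps)
    return [e / z for e in exps]


def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


def finite_diff_grad(fn, theta, eps=1e-6):
    """Central finite-difference gradient of a scalar function of theta."""
    grad = []
    for j in range(len(theta)):
        up = theta[:j] + [theta[j] + eps] + theta[j + 1:]
        dn = theta[:j] + [theta[j] - eps] + theta[j + 1:]
        grad.append((fn(up) - fn(dn)) / (2 * eps))
    return grad


def expected_loss_grads(theta, ref):
    """E_{x ~ pi}[grad of k1, k2, k3], i.e. what backprop through the loss averages to."""
    pi = softmax(theta)
    n = len(theta)
    out = {"k1": [0.0] * n, "k2": [0.0] * n, "k3": [0.0] * n}
    for x in range(n):
        score = [(1.0 if j == x else 0.0) - pi[j] for j in range(n)]
        log_ratio = math.log(pi[x] / ref[x])  # log(pi / pi_ref)
        r = ref[x] / pi[x]                    # pi_ref / pi
        # Each per-sample gradient is a scalar coefficient times the score.
        coeff = {"k1": 1.0, "k2": log_ratio, "k3": 1.0 - r}
        for name, c in coeff.items():
            for j in range(n):
                out[name][j] += pi[x] * c * score[j]
    return out


theta = [0.2, -0.5, 0.9]  # hypothetical policy logits
ref = [0.3, 0.3, 0.4]     # hypothetical reference distribution
grads = expected_loss_grads(theta, ref)
forward_kl_grad = finite_diff_grad(lambda t: kl(softmax(t), ref), theta)
reverse_kl_grad = finite_diff_grad(lambda t: kl(ref, softmax(t)), theta)
```

Here `grads["k1"]` is zero up to rounding, `grads["k2"]` matches `forward_kl_grad`, and `grads["k3"]` matches `reverse_kl_grad` rather than `forward_kl_grad`.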

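So although the k2 loss is not an unbiased estimator of the KL, its gradient is an unbiased estimator of the KL gradient, while k1 and k3 are the other way around. Notably, the expected gradient of the k1 loss is always 0, so with a large batch size it contributes essentially nothing. This also means that the KL loss used in GRPO is not a good estimator of the KL gradient.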

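Sometimes the reverse KL divergence is used instead, which swaps the positions of the policy model and the reference model in the KL computation. Since the samples $x$ are always drawn from the policy model, we need to multiply by a probability ratio:

$$\mathrm{KL}(\pi_{ref} \,\|\, \pi_\theta) = \mathbb{E}_{x \sim \pi_{ref}}\left[\log\frac{\pi_{ref}(x)}{\pi_\theta(x)}\right] = \mathbb{E}_{x \sim \pi_\theta}\left[\frac{\pi_{ref}(x)}{\pi_\theta(x)} \log\frac{\pi_{ref}(x)}{\pi_\theta(x)}\right]$$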

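Thus, to estimate its gradient through a KL loss, one unbiased choice is to use the probability ratio itself as the loss:

$$\mathbb{E}_{x \sim \pi_\theta}\left[\nabla_\theta \frac{\pi_{ref}(x)}{\pi_\theta(x)}\right] = -\mathbb{E}_{x \sim \pi_{ref}}\left[\nabla_\theta \log \pi_\theta(x)\right] = \nabla_\theta \mathrm{KL}(\pi_{ref} \,\|\, \pi_\theta)$$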

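Here too, the gradient of the unbiased KL estimator is not an unbiased estimator of the KL gradient:

$$\mathbb{E}_{x \sim \pi_\theta}\left[\nabla_\theta \left(\frac{\pi_{ref}(x)}{\pi_\theta(x)} \log\frac{\pi_{ref}(x)}{\pi_\theta(x)}\right)\right] = \nabla_\theta \mathrm{KL}(\pi_{ref} \,\|\, \pi_\theta) - \mathbb{E}_{x \sim \pi_{ref}}\left[\log\frac{\pi_{ref}(x)}{\pi_\theta(x)}\, \nabla_\theta \log \pi_\theta(x)\right]$$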
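The reverse-KL case can be checked the same way without relying on any hand-derived gradient formula. The sketch below, again assuming a hypothetical three-outcome softmax policy, differentiates each per-sample loss numerically with the sample held fixed, averages those gradients under the policy, and compares the result with the finite-difference gradient of the reverse KL itself. The names `reverse_k1` and `reverse_k2` mirror the two losses discussed here; everything else is illustrative.

```python
import math


def softmax(theta):
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    z = sum(exps)
    return [e / z for e in exps]


def grad_fd(fn, theta, eps=1e-6):
    """Central finite-difference gradient of a scalar function of theta."""
    grad = []
    for j in range(len(theta)):
        up = theta[:j] + [theta[j] + eps] + theta[j + 1:]
        dn = theta[:j] + [theta[j] - eps] + theta[j + 1:]
        grad.append((fn(up) - fn(dn)) / (2 * eps))
    return grad


def expected_sample_grad(loss, theta, ref):
    """E_{x ~ pi_theta}[grad_theta loss(x)] with the sampling weights held fixed."""
    weights = softmax(theta)
    total = [0.0] * len(theta)
    for x, w in enumerate(weights):
        g = grad_fd(lambda t: loss(softmax(t)[x], ref[x]), theta)
        total = [a + w * b for a, b in zip(total, g)]
    return total


def reverse_k1(pi_x, ref_x):
    r = ref_x / pi_x
    return r * math.log(r)  # unbiased estimate of the reverse KL value


def reverse_k2(pi_x, ref_x):
    return ref_x / pi_x     # plain probability ratio used as a surrogate loss


def reverse_kl(theta, ref):
    return sum(q * math.log(q / p) for p, q in zip(softmax(theta), ref))


theta = [0.2, -0.5, 0.9]  # hypothetical policy logits
ref = [0.3, 0.3, 0.4]     # hypothetical reference distribution
true_grad = grad_fd(lambda t: reverse_kl(t, ref), theta)
g1 = expected_sample_grad(reverse_k1, theta, ref)
g2 = expected_sample_grad(reverse_k2, theta, ref)
```

On this example `g2` agrees with `true_grad` to finite-difference accuracy, while `g1` is visibly off in every component.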

Experiments

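Consider $\pi_\theta = \mathcal{N}(\mu, \sigma^2)$ with $(\mu, \sigma) = (0, 1)$ and $\pi_{ref} = \mathcal{N}(1, 1.1^2)$, where the parameters $\theta = (\mu, \sigma)$ are the mean and standard deviation of the Gaussian. The estimates agree with the theoretical analysis: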

Fig. 2. Results of the KL estimation experiment
Average gradient value comparison (mean parameter):
True KL gradient: -0.826446
k1 average gradient: 0.000137
k2 average gradient: -0.826291
k3 average gradient: -0.999618

Average gradient value comparison (standard deviation parameter):
True KL gradient: -0.173554
k1 average gradient: -0.000104
k2 average gradient: -0.173379
k3 average gradient: -1.209091

Reverse KL - Average gradient value comparison (mean parameter):
True Reverse KL gradient: -1.000000
reverse_k1 average gradient: -2.718389
reverse_k2 average gradient: -0.999755

Reverse KL - Average gradient value comparison (standard deviation parameter):
True Reverse KL gradient: -1.210000
reverse_k1 average gradient: -4.496196
reverse_k2 average gradient: -1.208987

KL - Comparison of bias and standard deviation for mean parameter gradients:
╒════╤═════════════╤══════════════╕
│    │   bias/true │   stdev/true │
╞════╪═════════════╪══════════════╡
│ k1 │      1.0002 │       0.0054 │
├────┼─────────────┼──────────────┤
│ k2 │      0.0002 │       0.0065 │
├────┼─────────────┼──────────────┤
│ k3 │      0.2095 │       0.026  │
╘════╧═════════════╧══════════════╛

KL - Comparison of bias and standard deviation for standard deviation parameter gradients:
╒════╤═════════════╤══════════════╕
│    │   bias/true │   stdev/true │
╞════╪═════════════╪══════════════╡
│ k1 │      0.9994 │       0.0359 │
├────┼─────────────┼──────────────┤
│ k2 │      0.001  │       0.0686 │
├────┼─────────────┼──────────────┤
│ k3 │      5.9667 │       0.4311 │
╘════╧═════════════╧══════════════╛

Reverse KL - Comparison of bias and standard deviation for mean parameter gradients:
╒════════════╤═════════════╤══════════════╕
│            │   bias/true │   stdev/true │
╞════════════╪═════════════╪══════════════╡
│ reverse_k1 │      1.7184 │       0.1081 │
├────────────┼─────────────┼──────────────┤
│ reverse_k2 │      0.0002 │       0.0231 │
╘════════════╧═════════════╧══════════════╛

Reverse KL - Comparison of bias and standard deviation for standard deviation parameter gradients:
╒════════════╤═════════════╤══════════════╕
│            │   bias/true │   stdev/true │
╞════════════╪═════════════╪══════════════╡
│ reverse_k1 │      2.7159 │       0.3419 │
├────────────┼─────────────┼──────────────┤
│ reverse_k2 │      0.0008 │       0.0636 │
╘════════════╧═════════════╧══════════════╛
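The "True" gradient rows in the output can be cross-checked by hand from the closed-form KL divergence between two Gaussians. The sketch below differentiates that closed form analytically. The function names are made up for illustration, and the default parameter values $(0, 1)$ and $(1, 1.1)$ are the ones used in the experiment code.

```python
def forward_kl_grads(mu=0.0, sigma=1.0, mu_ref=1.0, sigma_ref=1.1):
    """Gradients of KL(N(mu, sigma^2) || N(mu_ref, sigma_ref^2)) w.r.t. mu and sigma."""
    d_mu = (mu - mu_ref) / sigma_ref ** 2
    d_sigma = -1.0 / sigma + sigma / sigma_ref ** 2
    return d_mu, d_sigma


def reverse_kl_grads(mu=0.0, sigma=1.0, mu_ref=1.0, sigma_ref=1.1):
    """Gradients of KL(N(mu_ref, sigma_ref^2) || N(mu, sigma^2)) w.r.t. mu and sigma."""
    d_mu = (mu - mu_ref) / sigma ** 2
    d_sigma = 1.0 / sigma - (sigma_ref ** 2 + (mu_ref - mu) ** 2) / sigma ** 3
    return d_mu, d_sigma


fwd = forward_kl_grads()  # approximately (-0.826446, -0.173554)
rev = reverse_kl_grads()  # approximately (-1.0, -1.21)
```

These values match the "True KL gradient" and "True Reverse KL gradient" lines, and the k3 averages sit next to the reverse-KL values rather than the forward-KL ones.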

Experiment code

import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions import Normal
import seaborn as sns

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Define parameters for two normal distributions
def setup_distributions(mu_theta=0.0, sigma_theta=1.0, mu_ref=1.0, sigma_ref=1.1):
    # Create trainable parameters
    mu_theta_param = torch.tensor(mu_theta, requires_grad=True)
    sigma_theta_param = torch.tensor(sigma_theta, requires_grad=True)
    
    # Reference distribution parameters (fixed)
    mu_ref_param = torch.tensor(mu_ref)
    sigma_ref_param = torch.tensor(sigma_ref)
    
    # Create distributions
    pi_theta = Normal(mu_theta_param, sigma_theta_param)
    pi_ref = Normal(mu_ref_param, sigma_ref_param)
    
    return pi_theta, pi_ref, mu_theta_param, sigma_theta_param

# Calculate the true KL divergence (analytical solution for normal distributions)
def true_kl_divergence(pi_theta, pi_ref):
    mu_theta = pi_theta.loc
    sigma_theta = pi_theta.scale
    mu_ref = pi_ref.loc
    sigma_ref = pi_ref.scale
    
    kl = (torch.log(sigma_ref/sigma_theta) + 
          (sigma_theta**2 + (mu_theta - mu_ref)**2)/(2*sigma_ref**2) - 0.5)
    rkl = (torch.log(sigma_theta/sigma_ref) + 
        (sigma_ref**2 + (mu_ref - mu_theta)**2)/(2*sigma_theta**2) - 0.5)
    
    return kl, rkl

# Three different KL divergence estimates
def k1_loss(x, pi_theta, pi_ref):
    return pi_theta.log_prob(x) - pi_ref.log_prob(x)

def k2_loss(x, pi_theta, pi_ref):
    return 0.5 * (pi_theta.log_prob(x) - pi_ref.log_prob(x))**2

def k3_loss(x, pi_theta, pi_ref):
    ratio = torch.exp(pi_ref.log_prob(x) - pi_theta.log_prob(x))
    return ratio - 1 - torch.log(ratio)

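# Two reverse-KL losses, evaluated on samples from pi_theta and
# importance-weighted by ratio = pi_ref(x) / pi_theta(x)
def reverse_k1_loss(x, pi_theta, pi_ref):
    # Unbiased estimate of the reverse KL value: ratio * log(ratio)
    ratio = torch.exp(pi_ref.log_prob(x) - pi_theta.log_prob(x))
    return ratio * (pi_ref.log_prob(x) - pi_theta.log_prob(x))

def reverse_k2_loss(x, pi_theta, pi_ref):
    # Surrogate loss: its value is not a KL estimate, but its gradient is an
    # unbiased estimate of the reverse-KL gradient
    return torch.exp(pi_ref.log_prob(x) - pi_theta.log_prob(x))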

# Sample and compute gradients
def estimate_gradients(sample_size=10000, num_trials=10):
    pi_theta, pi_ref, mu_param, sigma_param = setup_distributions()
    
    # Compute the gradient of the true and reverse KL divergence
    true_kl, true_reverse_kl = true_kl_divergence(pi_theta, pi_ref)
    
    true_kl.backward()
    true_grad_mu = mu_param.grad.item()
    true_grad_sigma = sigma_param.grad.item()
    mu_param.grad.zero_()
    sigma_param.grad.zero_()

    true_reverse_kl.backward()
    true_reverse_grad_mu = mu_param.grad.item()
    true_reverse_grad_sigma = sigma_param.grad.item()
    mu_param.grad.zero_()
    sigma_param.grad.zero_()
    
    # Store gradients from different estimates
    k1_grads_mu = []
    k1_grads_sigma = []
    k2_grads_mu = []
    k2_grads_sigma = []
    k3_grads_mu = []
    k3_grads_sigma = []
    reverse_k1_grads_mu = []
    reverse_k1_grads_sigma = []
    reverse_k2_grads_mu = []
    reverse_k2_grads_sigma = []
    
    for _ in range(num_trials):
        pi_theta, pi_ref, mu_param, sigma_param = setup_distributions()
        
        # Sample from the current policy
        samples = pi_theta.sample((sample_size,))
        
        # Get gradient of KL estimation
        k1_values = k1_loss(samples, pi_theta, pi_ref)
        k1_mean = k1_values.mean()
        k1_mean.backward()
        k1_grads_mu.append(mu_param.grad.item())
        k1_grads_sigma.append(sigma_param.grad.item())
        mu_param.grad.zero_()
        sigma_param.grad.zero_()
        
        k2_values = k2_loss(samples, pi_theta, pi_ref)
        k2_mean = k2_values.mean()
        k2_mean.backward()
        k2_grads_mu.append(mu_param.grad.item())
        k2_grads_sigma.append(sigma_param.grad.item())
        mu_param.grad.zero_()
        sigma_param.grad.zero_()
        
        k3_values = k3_loss(samples, pi_theta, pi_ref)
        k3_mean = k3_values.mean()
        k3_mean.backward()
        k3_grads_mu.append(mu_param.grad.item())
        k3_grads_sigma.append(sigma_param.grad.item())
        mu_param.grad.zero_()
        sigma_param.grad.zero_()

        reverse_k1_values = reverse_k1_loss(samples, pi_theta, pi_ref)
        reverse_k1_mean = reverse_k1_values.mean()
        reverse_k1_mean.backward()
        reverse_k1_grads_mu.append(mu_param.grad.item())
        reverse_k1_grads_sigma.append(sigma_param.grad.item())
        mu_param.grad.zero_()
        sigma_param.grad.zero_()
        
        reverse_k2_values = reverse_k2_loss(samples, pi_theta, pi_ref)
        reverse_k2_mean = reverse_k2_values.mean()
        reverse_k2_mean.backward()
        reverse_k2_grads_mu.append(mu_param.grad.item())
        reverse_k2_grads_sigma.append(sigma_param.grad.item())
        mu_param.grad.zero_()
        sigma_param.grad.zero_()
    
    return {
        'true_grad_mu': true_grad_mu,
        'true_grad_sigma': true_grad_sigma,
        'k1_grads_mu': k1_grads_mu,
        'k1_grads_sigma': k1_grads_sigma,
        'k2_grads_mu': k2_grads_mu,
        'k2_grads_sigma': k2_grads_sigma,
        'k3_grads_mu': k3_grads_mu,
        'k3_grads_sigma': k3_grads_sigma,
        'true_reverse_grad_mu': true_reverse_grad_mu,
        'true_reverse_grad_sigma': true_reverse_grad_sigma,
        'reverse_k1_grads_mu': reverse_k1_grads_mu,
        'reverse_k1_grads_sigma': reverse_k1_grads_sigma,
        'reverse_k2_grads_mu': reverse_k2_grads_mu,
        'reverse_k2_grads_sigma': reverse_k2_grads_sigma
    }

def create_nice_table(results):
    true_grad_mu = results['true_grad_mu']
    true_grad_sigma = results['true_grad_sigma']
    true_reverse_grad_mu = results['true_reverse_grad_mu']
    true_reverse_grad_sigma = results['true_reverse_grad_sigma']
    k1_bias_mu = abs(np.mean(results['k1_grads_mu']) - true_grad_mu) / abs(true_grad_mu)
    k2_bias_mu = abs(np.mean(results['k2_grads_mu']) - true_grad_mu) / abs(true_grad_mu)
    k3_bias_mu = abs(np.mean(results['k3_grads_mu']) - true_grad_mu) / abs(true_grad_mu)
    k1_std_mu = np.std(results['k1_grads_mu']) / abs(true_grad_mu)
    k2_std_mu = np.std(results['k2_grads_mu']) / abs(true_grad_mu)
    k3_std_mu = np.std(results['k3_grads_mu']) / abs(true_grad_mu)
    k1_bias_sigma = abs(np.mean(results['k1_grads_sigma']) - true_grad_sigma) / abs(true_grad_sigma)
    k2_bias_sigma = abs(np.mean(results['k2_grads_sigma']) - true_grad_sigma) / abs(true_grad_sigma)
    k3_bias_sigma = abs(np.mean(results['k3_grads_sigma']) - true_grad_sigma) / abs(true_grad_sigma)
    k1_std_sigma = np.std(results['k1_grads_sigma']) / abs(true_grad_sigma)
    k2_std_sigma = np.std(results['k2_grads_sigma']) / abs(true_grad_sigma)
    k3_std_sigma = np.std(results['k3_grads_sigma']) / abs(true_grad_sigma)
    reverse_k1_bias_mu = abs(np.mean(results['reverse_k1_grads_mu']) - true_reverse_grad_mu) / abs(true_reverse_grad_mu)
    reverse_k2_bias_mu = abs(np.mean(results['reverse_k2_grads_mu']) - true_reverse_grad_mu) / abs(true_reverse_grad_mu)
    reverse_k1_std_mu = np.std(results['reverse_k1_grads_mu']) / abs(true_reverse_grad_mu)
    reverse_k2_std_mu = np.std(results['reverse_k2_grads_mu']) / abs(true_reverse_grad_mu)
    reverse_k1_bias_sigma = abs(np.mean(results['reverse_k1_grads_sigma']) - true_reverse_grad_sigma) / abs(true_reverse_grad_sigma)
    reverse_k2_bias_sigma = abs(np.mean(results['reverse_k2_grads_sigma']) - true_reverse_grad_sigma) / abs(true_reverse_grad_sigma)
    reverse_k1_std_sigma = np.std(results['reverse_k1_grads_sigma']) / abs(true_reverse_grad_sigma)
    reverse_k2_std_sigma = np.std(results['reverse_k2_grads_sigma']) / abs(true_reverse_grad_sigma)
    
    # Create table
    from tabulate import tabulate
    import pandas as pd
    
    df_mu = pd.DataFrame({
        'bias/true': [k1_bias_mu, k2_bias_mu, k3_bias_mu],
        'stdev/true': [k1_std_mu, k2_std_mu, k3_std_mu]
    }, index=['k1', 'k2', 'k3'])
    df_reverse_mu = pd.DataFrame({
        'bias/true': [reverse_k1_bias_mu, reverse_k2_bias_mu],
        'stdev/true': [reverse_k1_std_mu, reverse_k2_std_mu]
    }, index=['reverse_k1', 'reverse_k2'])
    df_sigma = pd.DataFrame({
        'bias/true': [k1_bias_sigma, k2_bias_sigma, k3_bias_sigma],
        'stdev/true': [k1_std_sigma, k2_std_sigma, k3_std_sigma]
    }, index=['k1', 'k2', 'k3'])

    df_reverse_sigma = pd.DataFrame({
        'bias/true': [reverse_k1_bias_sigma, reverse_k2_bias_sigma],
        'stdev/true': [reverse_k1_std_sigma, reverse_k2_std_sigma]
    }, index=['reverse_k1', 'reverse_k2'])
    
    # Format values
    df_mu = df_mu.round(4)
    df_sigma = df_sigma.round(4)
    df_reverse_mu = df_reverse_mu.round(4)
    df_reverse_sigma = df_reverse_sigma.round(4)
    
    # Print nice table
    print("KL - Comparison of bias and standard deviation for mean parameter gradients:")
    print(tabulate(df_mu, headers='keys', tablefmt='fancy_grid'))
    
    print("\nKL - Comparison of bias and standard deviation for standard deviation parameter gradients:")
    print(tabulate(df_sigma, headers='keys', tablefmt='fancy_grid'))

    print("\nReverse KL - Comparison of bias and standard deviation for mean parameter gradients:")
    print(tabulate(df_reverse_mu, headers='keys', tablefmt='fancy_grid'))
    
    print("\nReverse KL - Comparison of bias and standard deviation for standard deviation parameter gradients:")
    print(tabulate(df_reverse_sigma, headers='keys', tablefmt='fancy_grid'))
    
    return df_mu, df_sigma, df_reverse_mu, df_reverse_sigma


def visualize_results(results):
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    
    axes[0,0].axhline(y=results['true_grad_mu'], color='r', linestyle='-', label='True KL')
    sns.violinplot(data=[results['k1_grads_mu'], results['k2_grads_mu'], results['k3_grads_mu']], ax=axes[0,0])
    axes[0,0].set_title('Gradient of KL Divergence with Respect to Mean Parameter')
    axes[0,0].set_xticks([0,1,2])
    axes[0,0].set_xticklabels(['k1', 'k2', 'k3'])
    axes[0,0].set_ylabel('Gradient Value')
    axes[0,0].legend()
    
    axes[0,1].axhline(y=results['true_grad_sigma'], color='r', linestyle='-', label='True KL')
    sns.violinplot(data=[results['k1_grads_sigma'], results['k2_grads_sigma'], results['k3_grads_sigma']], ax=axes[0,1])
    axes[0,1].set_title('Gradient of KL Divergence with Respect to Standard Deviation Parameter')
    axes[0,1].set_xticks([0,1,2])
    axes[0,1].set_xticklabels(['k1', 'k2', 'k3'])
    axes[0,1].set_ylabel('Gradient Value')
    axes[0,1].legend()
    
    axes[1,0].axhline(y=results['true_reverse_grad_mu'], color='r', linestyle='-', label='True Reverse KL')
    sns.violinplot(data=[results['reverse_k1_grads_mu'], results['reverse_k2_grads_mu']], ax=axes[1,0])
    axes[1,0].set_title('Gradient of Reverse KL Divergence with Respect to Mean Parameter')
    axes[1,0].set_xticks([0,1])
    axes[1,0].set_xticklabels(['reverse_k1', 'reverse_k2'])
    axes[1,0].set_ylabel('Gradient Value')
    axes[1,0].legend()
    
    axes[1,1].axhline(y=results['true_reverse_grad_sigma'], color='r', linestyle='-', label='True Reverse KL')
    sns.violinplot(data=[results['reverse_k1_grads_sigma'], results['reverse_k2_grads_sigma']], ax=axes[1,1])
    axes[1,1].set_title('Gradient of Reverse KL Divergence with Respect to Standard Deviation Parameter')
    axes[1,1].set_xticks([0,1])
    axes[1,1].set_xticklabels(['reverse_k1', 'reverse_k2'])
    axes[1,1].set_ylabel('Gradient Value')
    axes[1,1].legend()
    
    plt.tight_layout()
    plt.show()
    
    # Print mean comparison
    print("Average gradient value comparison (mean parameter):")
    print(f"True KL gradient: {results['true_grad_mu']:.6f}")
    print(f"k1 average gradient: {np.mean(results['k1_grads_mu']):.6f}")
    print(f"k2 average gradient: {np.mean(results['k2_grads_mu']):.6f}")
    print(f"k3 average gradient: {np.mean(results['k3_grads_mu']):.6f}")
    
    print("\nAverage gradient value comparison (standard deviation parameter):")
    print(f"True KL gradient: {results['true_grad_sigma']:.6f}")
    print(f"k1 average gradient: {np.mean(results['k1_grads_sigma']):.6f}")
    print(f"k2 average gradient: {np.mean(results['k2_grads_sigma']):.6f}")
    print(f"k3 average gradient: {np.mean(results['k3_grads_sigma']):.6f}")

    print("\nReverse KL - Average gradient value comparison (mean parameter):")
    print(f"True Reverse KL gradient: {results['true_reverse_grad_mu']:.6f}")
    print(f"reverse_k1 average gradient: {np.mean(results['reverse_k1_grads_mu']):.6f}")
    print(f"reverse_k2 average gradient: {np.mean(results['reverse_k2_grads_mu']):.6f}")
    
    print("\nReverse KL - Average gradient value comparison (standard deviation parameter):")
    print(f"True Reverse KL gradient: {results['true_reverse_grad_sigma']:.6f}")
    print(f"reverse_k1 average gradient: {np.mean(results['reverse_k1_grads_sigma']):.6f}")
    print(f"reverse_k2 average gradient: {np.mean(results['reverse_k2_grads_sigma']):.6f}")

    # Add nice table output
    print()
    df = create_nice_table(results)

if __name__ == "__main__":
    sample_size = 50000
    num_trials = 1000
    results = estimate_gradients(sample_size, num_trials)
    visualize_results(results)

Supplement

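The variance of the k2 loss gradient can be reduced along the lines of the analysis in Schulman (2020): add a zero-mean statistic and solve for its coefficient by the method of undetermined coefficients. Since $\mathbb{E}_{x \sim \pi_\theta}[\nabla_\theta \log \pi_\theta(x)] = 0$, adding $\lambda\, \nabla_\theta \log \pi_\theta(x)$ to the k2 gradient leaves its expectation unchanged.

First solve for $\lambda$ component by component, minimizing the variance of the $i$-th gradient component:

$$\lambda_i^* = \arg\min_{\lambda_i} \mathrm{Var}\left[\left(\log\frac{\pi_\theta(x)}{\pi_{ref}(x)} + \lambda_i\right) \partial_{\theta_i} \log \pi_\theta(x)\right] = -\frac{\mathbb{E}\left[\log\frac{\pi_\theta(x)}{\pi_{ref}(x)} \left(\partial_{\theta_i} \log \pi_\theta(x)\right)^2\right]}{\mathbb{E}\left[\left(\partial_{\theta_i} \log \pi_\theta(x)\right)^2\right]}$$

Then use the corrected estimator component by component:

$$\hat{g}_i(x) = \left(\log\frac{\pi_\theta(x)}{\pi_{ref}(x)} + \lambda_i^*\right) \partial_{\theta_i} \log \pi_\theta(x)$$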

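This is cumbersome, however. First, the coefficient $\lambda$ is hard to estimate: it can only be computed once the score (the gradient of the log-probability) is available, so two backward passes may be needed. Second, every component has to be estimated separately.

As an alternative, one could reuse the score from the previous step and estimate a single $\lambda$ for all components:

$$\lambda^* = -\frac{\mathbb{E}\left[\log\frac{\pi_\theta(x)}{\pi_{ref}(x)}\, \left\|\nabla_\theta \log \pi_\theta(x)\right\|^2\right]}{\mathbb{E}\left[\left\|\nabla_\theta \log \pi_\theta(x)\right\|^2\right]}$$

This is still fairly involved: it requires per-sample gradient norms and already introduces some bias, so it may not be worth the cost. The actual effect of the KL loss itself is also unclear, and whether extra variance reduction is needed here is open to debate. Perhaps k2 alone is sufficient.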
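To illustrate the control-variate idea in an idealized setting, the sketch below computes exact means and variances of the modified k2 gradient $(\log\frac{\pi_\theta}{\pi_{ref}} + \lambda)\, \nabla_\theta \log \pi_\theta$ on a hypothetical three-outcome softmax policy. It uses exact expectations, so it shows none of the estimation difficulties or the bias mentioned above. The variance-minimizing formulas for $\lambda$ are the least-squares solutions, and all names and numbers are assumptions made for the example.

```python
import math


def softmax(theta):
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    z = sum(exps)
    return [e / z for e in exps]


def k2_grad_moments(theta, ref, lam):
    """Exact per-component mean and variance of (log(pi/ref) + lam[i]) * score_i."""
    pi = softmax(theta)
    n = len(theta)
    mean, second = [0.0] * n, [0.0] * n
    for x in range(n):
        log_ratio = math.log(pi[x] / ref[x])
        for i in range(n):
            score_i = (1.0 if i == x else 0.0) - pi[i]
            g = (log_ratio + lam[i]) * score_i
            mean[i] += pi[x] * g
            second[i] += pi[x] * g * g
    var = [s - m * m for s, m in zip(second, mean)]
    return mean, var


def optimal_lambdas(theta, ref):
    """Variance-minimizing lambda per component, and a single shared lambda."""
    pi = softmax(theta)
    n = len(theta)
    num, den = [0.0] * n, [0.0] * n
    for x in range(n):
        log_ratio = math.log(pi[x] / ref[x])
        for i in range(n):
            s2 = ((1.0 if i == x else 0.0) - pi[i]) ** 2
            num[i] += pi[x] * log_ratio * s2
            den[i] += pi[x] * s2
    per_component = [-a / b for a, b in zip(num, den)]
    shared = -sum(num) / sum(den)  # weights samples by the squared score norm
    return per_component, shared


theta = [0.2, -0.5, 0.9]  # hypothetical policy logits
ref = [0.3, 0.3, 0.4]     # hypothetical reference distribution
lam_i, lam_shared = optimal_lambdas(theta, ref)
mean_plain, var_plain = k2_grad_moments(theta, ref, [0.0] * 3)
mean_comp, var_comp = k2_grad_moments(theta, ref, lam_i)
mean_shared, var_shared = k2_grad_moments(theta, ref, [lam_shared] * 3)
```

The means are identical in all three cases, since the added term has zero expectation. The per-component choice gives the lowest variance in every coordinate, and the shared $\lambda$ sits between it and plain k2 in total variance.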

Citation

Cite as:

Yang, Xiaobo. (Mar 2025). Gradient Estimation of KL Divergence in Large Language Model Reinforcement Learning. Xiaobo's Blog.
https://xiaobo-yang.github.io/posts/kl_grad/.

Or in BibTeX format:

@article{yang2025klgradient,
    title = "Gradient Estimation of KL Divergence in Large Language Model Reinforcement Learning.",
    author = "Yang, Xiaobo",
    journal = "xiaobo-yang.github.io",
    year = "2025",
    month = "Mar",
    url = "https://xiaobo-yang.github.io/posts/kl_grad/"
}

References

[1] Schulman, John. "Approximating KL Divergence." 2020.

[2] Guo et al. "DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning." 2025.

— Xiaobo Yang