Skip to content

Latest commit

 

History

History
161 lines (124 loc) · 8.89 KB

KAN函数详解.md

File metadata and controls

161 lines (124 loc) · 8.89 KB

class Kan()

以下是对这个类的详细解释:

  • biases:一个包含 nn.Linear() 对象的列表,用于在节点上添加偏置项(原则上偏置可以被吸收到激活函数中,但这里保留是为了更好地优化)。
  • act_fun:一个包含 KANLayer 的列表,即 KAN 层。
  • depth:表示 KAN 的深度。
  • width:一个列表,指定了每层的神经元数量,例如 [2,5,5,3] 表示 2 维输入,5 维输出,中间有 2 层各有 5 个隐藏神经元。
  • grid:网格的数量。
  • k:分段多项式的阶数。
  • base_fun:残差函数 b(x),一个激活函数 phi(x) = sb_scale * b(x) + sp_scale * spline(x)
  • symbolic_fun:一个包含 Symbolic_KANLayer 的列表,即符号化的 KAN 层。
  • symbolic_enabled:布尔值,如果为 False,则不计算符号化前端(以节省时间),默认是 True

方法部分:

  • __init__():用于初始化一个 KAN 类实例。
  • initialize_from_another_model():从另一个 KAN(具有相同形状但可能不同网格)初始化当前 KAN
  • update_grid_from_samples():基于样本更新样条网格。
  • initialize_grid_from_another_model():从另一个 KAN 初始化网格。
  • forward():前向传播方法。
  • set_mode():设置激活函数的模式,如 'n' 表示数值模式,'s' 表示符号模式,'ns' 表示组合模式(注意在绘制时它们的可视化方式不同,'n' 为黑色,'s' 为红色,'ns' 为紫色)。
  • fix_symbolic():将一个激活函数固定为符号模式。
  • suggest_symbolic():为数值样条型激活函数建议符号化候选。
  • lock():锁定激活函数以共享参数。
  • unlock():解锁锁定的激活函数。
  • get_range():获取激活函数的输入和输出范围。
  • plot():绘制 KAN 的图。
  • train():训练 KAN
  • prune():对 KAN 进行剪枝。
  • remove_edge():移除 KAN 的某些边。
  • remove_node():移除 KAN 的某些节点。
  • auto_symbolic():自动将所有样条拟合为符号函数。
  • symbolic_formula():获取 KAN 网络的符号公式。

model= KAN(width=[4,5,2], grid=3, k=3, seed=0)

以下是对 __init__ 方法的详细解释:

  • width:指定每层神经元数量的列表。
  • grid:网格的数量,默认为 3。
  • k:分段多项式的阶数,默认为 3。
  • noise_scale:初始注入到样条中的噪声大小,默认为 0.1。
  • noise_scale_base:可能与基础函数相关的噪声规模。
  • base_fun:残差函数,默认为 torch.nn.SiLU()
  • symbolic_enabled:布尔值,决定是否计算或跳过符号计算(为了效率),默认为真。
  • bias_trainable:布尔值,指示偏置参数是否可更新,默认为真。
  • grid_eps:用于控制网格划分方式,当为 0 时是均匀网格,为 1 时根据样本分位数划分,介于 0 到 1 之间则是两者的插值,默认为 0.02。
  • grid_range:指定网格范围的列表(形状为 (2,)),默认为 [-1, 1]。
  • sp_trainable:如果为真,则对应的缩放参数 scale_sp 可训练,默认为真。
  • sb_trainable:如果为真,则 scale_base 可训练,默认为真。
  • device:模型运行的设备,默认为 'cpu'。
  • seed:随机种子。

这个初始化方法用于设置 KAN 模型的各种参数和属性,为模型的构建和后续操作奠定基础。

model.plot()

以下是对这个 plot 函数的详细解释:

  • folder:指定用于存储 PNG 图像的文件夹。
  • beta:一个正数,用于控制每个激活的透明度,透明度通过 tanh(beta*l1) 计算。
  • mask:布尔值,如果为真,则使用掩码进行绘制(需要先运行 prune() 函数以获得掩码),默认为假,即绘制所有激活函数。
  • mode:模式,可以是“有监督”或“无监督”,这决定了如何测量 l1(有监督时通过绝对值,无监督时通过标准差减去均值)。
  • scale:控制图形的大小。
  • in_vars:可以为 None 或输入变量名称的列表。
  • out_vars:可以为 None 或输出变量名称的列表。
  • title:可以为 None 或图形的标题。

这个函数的主要目的是用于绘制与模型相关的图形,根据指定的参数来展示激活函数等信息。通过设置不同的参数,可以控制绘制的细节,如存储位置、透明度、模式、大小以及输入和输出变量的标注等。最后返回绘制的图形对象。

示例部分展示了如何创建一个 KAN 模型,进行前向传播获取相关激活数据,然后调用 plot 函数进行绘制。

modle.train()

以下是对这个 train 函数的详细解释:

这个函数的文档解释了training函数的各个参数。以下是对每个参数的详细解释:

  1. dataset : dict

    • 包含训练和测试数据集,字典结构如下:
      • dataset['train_input']: 训练输入数据
      • dataset['train_label']: 训练标签数据
      • dataset['test_input']: 测试输入数据
      • dataset['test_label']: 测试标签数据
  2. opt : str

    • 优化器类型,值可以是 "LBFGS" 或 "Adam",决定了使用的优化算法。
  3. steps : int

    • 训练步骤数,表示训练过程中的总步数。
  4. log : int

    • 日志记录频率,表示每隔多少步记录一次日志信息。
  5. lamb : float

    • 总体惩罚强度系数,用于正则化项的总体惩罚。
  6. lamb_l1 : float

    • L1惩罚强度系数,用于稀疏化正则化。
  7. lamb_entropy : float

    • 熵惩罚强度系数,用于熵正则化。
  8. lamb_coef : float

    • 系数大小惩罚强度系数,用于控制系数大小的正则化。
  9. lamb_coefdiff : float

    • 相邻系数差异(平滑性)惩罚强度系数,用于控制相邻系数差异的正则化。
  10. update_grid : bool

    • 是否定期更新网格。如果为 True,则在特定步骤前定期更新网格。
  11. grid_update_num : int

    • 在停止网格更新步骤之前,网格更新的次数。
  12. stop_grid_update_step : int

    • 在此步骤之后不再更新网格。
  13. batch : int

    • 批量大小。如果值为 -1,则使用整个数据集进行训练(即全量批次)。
  14. small_mag_threshold : float

    • 确定大小数的阈值(可能对小数施加较大的惩罚)。
  15. small_reg_factor : float

    • 对小因子相对于大因子的惩罚强度。
  16. device : str

    • 计算设备,例如 "cpu" 或 "cuda"(用于指定运行训练的硬件设备)。
  17. save_fig_freq : int

    • 每隔多少步保存一次图像。

这些参数共同定义了训练过程中的数据、优化策略、正则化方法、批量处理方式以及日志记录和图像保存的频率。

model.prune()

以下是对上述 prune 函数的详细解释:

函数功能: 这个函数用于在节点级别对 KAN 模型进行剪枝操作。如果一个节点的输入或输出连接较小,就会被剪枝掉。

参数解释

  • threshold:这是一个浮点数阈值。用于在自动模式(mode="auto")下,判断一个节点是否足够小从而决定是否剪枝。
  • mode:一个字符串参数,取值为 "auto""manual""auto" 表示使用阈值自动剪枝节点;"manual" 表示需要通过 active_neurons_id 来指定保留哪些神经元(其他的被丢弃)。
  • active_neurons_id:这是一个列表的列表。例如 [[0,1],[0,2,3]] 表示在第一个隐藏层保留 0/1 神经元,在第二个隐藏层保留 0/2/3 神经元。目前不支持对输入和输出神经元进行剪枝。

函数内部实现

  • 首先,初始化一个 mask 列表和 active_neurons 列表。mask 用于记录每个层中节点的保留情况,active_neurons 记录每个层中活跃神经元的索引。
  • 然后,通过一个循环处理每一层(除了输入层)。
    • 在自动模式下,根据当前层和下一层的激活尺度计算哪些节点是重要的(大于阈值),通过逻辑与操作得到整体重要的节点。
    • 在手动模式下,根据提供的 active_neurons_id 来确定哪些节点是重要的。
  • 将重要节点的信息添加到 maskactive_neurons 列表中。
  • mask 赋值给模型的属性。
  • 通过另一个循环,对不重要的节点进行移除操作。
  • 创建一个新的 KAN 模型 model2 ,并复制原模型的状态字典。
  • 对新模型进行一些更新操作,包括调整偏置的权重、更新激活函数、调整宽度和符号函数等。

返回值: 返回经过剪枝操作后的新模型 model2

示例: 示例中首先创建了一个 KAN 模型,进行训练,然后调用 prune 函数进行剪枝,并绘制剪枝后的模型。

希望以上解释对您理解这个函数有所帮助!如果您还有其他疑问,请随时提问。