引擎#

工作流#

Workflow#

class monai.engines.Workflow(device, max_epochs, data_loader, epoch_length=None, non_blocking=False, prepare_batch=<function default_prepare_batch>, iteration_update=None, postprocessing=None, key_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, handlers=None, amp=False, event_names=None, event_to_attr=None, decollate=True, to_kwargs=None, amp_kwargs=None)[source]#

Workflow 定义了继承自 Ignite 引擎的核心工作流程。所有的训练器(trainer)、验证器(validator)和评估器(evaluator)都共享此工作流作为基类,因为它们都可以被视为相同的 Ignite 引擎循环。它在 Ignite engine.state 中初始化所有可共享的数据。并基于事件处理(Event-Handler)机制将附加处理逻辑附加到 Ignite 引擎。

用户应考虑继承 trainerevaluator 来开发更多的训练器或评估器。

参数:
  • device – 表示运行设备的类或对象。

  • max_epochs – 引擎运行的总轮次(epoch)数,验证器和评估器只有 1 个轮次。

  • data_loader – Ignite 引擎用于运行的数据加载器,必须是 Iterable 或 torch.DataLoader。

  • epoch_length – 一个轮次的迭代次数,默认为 len(data_loader)

  • non_blocking – 如果为 True 且此复制发生在 CPU 和 GPU 之间,则复制可能相对于主机异步进行。对于其他情况,此参数无效。

  • prepare_batch – 用于在每次迭代中从 engine.state.batch 中解析预期数据(通常是 imagelabel 和其他网络参数)的函数,更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.create_supervised_trainer.html

  • iteration_update – 每次迭代的可调用函数,预期接受 engineengine.state.batch 作为输入,返回的数据将存储在 engine.state.output 中。如果未提供,则使用 self._iteration() 代替。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html

  • postprocessing – 对模型输出数据执行附加的变换。通常是由 Compose 组合的几个基于 Tensor 的变换。

  • key_metric – 在每次迭代完成后计算指标,并在轮次完成后将平均值保存到 engine.state.metrics。key_metric 是用于比较和将检查点保存到文件中的主要指标。

  • additional_metrics – 更多也附加到 Ignite 引擎的 Ignite 指标。

  • metric_cmp_fn – 用于比较当前关键指标与先前最佳关键指标值的函数,它必须接受 2 个参数(current_metric, previous_best)并返回一个布尔结果:如果为 True,将使用当前指标和轮次更新 best_metricbest_metric_epoch,默认为 大于

  • handlers – 每个处理程序都是一组 Ignite 事件处理程序(Event-Handlers),必须具有 attach 函数,例如:CheckpointHandler, StatsHandler 等。

  • amp – 是否启用自动混合精度训练或推理,默认为 False。

  • event_names – 将注册到引擎的其他自定义 Ignite 事件。新事件可以是字符串列表或 ignite.engine.events.EventEnum

  • event_to_attr – 一个将事件映射到状态属性,然后添加到 engine.state 的字典。更多详情请查看:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html #ignite.engine.engine.Engine.register_events。

  • decollate – 模型计算后是否将 batch-first 数据去批次化(decollate)为数据列表,当 postprocessing 使用 monai.transforms 中的组件时,建议设置 decollate=True。默认为 True

  • to_kwargsprepare_batch API 在转换输入数据时使用的其他参数字典,除了 device, non_blocking

  • amp_kwargstorch.cuda.amp.autocast() API 的参数字典,更多详情请查看:https://pytorch.ac.cn/docs/stable/amp.html#torch.cuda.amp.autocast

抛出:
  • TypeError – 当 data_loader 不是 torch.utils.data.DataLoader 时。

  • TypeError – 当 key_metric 不是 Optional[dict] 时。

  • TypeError – 当 additional_metrics 不是 Optional[dict] 时。

get_stats(*vars)[source]#

获取工作流程过程的统计信息。

参数:

varsself.state 中的变量名,将使用变量名作为键,状态内容作为值。如果变量不存在,默认值为 None

run()[source]#

基于 Ignite 引擎执行训练、验证或评估。

返回类型:

None

Trainer#

class monai.engines.Trainer(device, max_epochs, data_loader, epoch_length=None, non_blocking=False, prepare_batch=<function default_prepare_batch>, iteration_update=None, postprocessing=None, key_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, handlers=None, amp=False, event_names=None, event_to_attr=None, decollate=True, to_kwargs=None, amp_kwargs=None)[source]#

所有训练器的基类,继承自 Workflow。

get_stats(*vars)[source]#

获取训练过程的统计信息。默认返回 rankcurrent_epochcurrent_iterationtotal_epochstotal_iterations

参数:

vars – 除了默认统计信息外,self.state 中要返回的其他变量名,将使用变量名作为键,状态内容作为值。如果变量不存在,默认值为 None

run()[source]#

基于 Ignite 引擎执行训练。如果多次调用此函数,它将从先前的状态继续运行。

返回类型:

None

SupervisedTrainer#

class monai.engines.SupervisedTrainer(device, max_epochs, train_data_loader, network, optimizer, loss_function, epoch_length=None, non_blocking=False, prepare_batch=<function default_prepare_batch>, iteration_update=None, inferer=None, postprocessing=None, key_train_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, train_handlers=None, amp=False, event_names=None, event_to_attr=None, decollate=True, optim_set_to_none=False, to_kwargs=None, amp_kwargs=None, compile=False, compile_kwargs=None)[source]#

标准的有监督训练方法,使用图像和标签,继承自 TrainerWorkflow

参数:
  • device – 表示运行设备的类或对象。

  • max_epochs – 训练器运行的总轮次(epoch)数。

  • train_data_loader – Ignite 引擎用于运行的数据加载器,必须是 Iterable 或 torch.DataLoader。

  • network – 在训练器中训练的网络,应为标准的 PyTorch torch.nn.Module

  • optimizer – 与网络相关的优化器,应为 torch.optim 或其子类中的标准 PyTorch 优化器。

  • loss_function – 与优化器相关的损失函数,应为标准的 PyTorch 损失函数,继承自 torch.nn.modules.loss

  • epoch_length – 一个轮次的迭代次数,默认为 len(train_data_loader)

  • non_blocking – 如果为 True 且此复制发生在 CPU 和 GPU 之间,则复制可能相对于主机异步进行。对于其他情况,此参数无效。

  • prepare_batch – 用于在每次迭代中从 engine.state.batch 中解析预期数据(通常是 imagelabel 和其他网络参数)的函数,更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.create_supervised_trainer.html

  • iteration_update – 每次迭代的可调用函数,预期接受 engineengine.state.batch 作为输入,返回的数据将存储在 engine.state.output 中。如果未提供,则使用 self._iteration() 代替。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html

  • inferer – 在输入数据上执行模型前向计算的推理方法,例如:SlidingWindow 等。

  • postprocessing – 对模型输出数据执行附加的变换。通常是由 Compose 组合的几个基于 Tensor 的变换。

  • key_train_metric – 在每次迭代完成后计算指标,并在轮次完成后将平均值保存到 engine.state.metrics。key_train_metric 是用于比较和将检查点保存到文件中的主要指标。

  • additional_metrics – 更多也附加到 Ignite 引擎的 Ignite 指标。

  • metric_cmp_fn – 用于比较当前关键指标与先前最佳关键指标值的函数,它必须接受 2 个参数(current_metric, previous_best)并返回一个布尔结果:如果为 True,将使用当前指标和轮次更新 best_metricbest_metric_epoch,默认为 大于

  • train_handlers – 每个处理程序都是一组 Ignite 事件处理程序(Event-Handlers),必须具有 attach 函数,例如:CheckpointHandler, StatsHandler 等。

  • amp – 是否启用自动混合精度训练,默认为 False。

  • event_names – 将注册到引擎的其他自定义 Ignite 事件。新事件可以是字符串列表或 ignite.engine.events.EventEnum

  • event_to_attr – 一个将事件映射到状态属性,然后添加到 engine.state 的字典。更多详情请查看:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html #ignite.engine.engine.Engine.register_events。

  • decollate – 模型计算后是否将 batch-first 数据去批次化(decollate)为数据列表,当 postprocessing 使用 monai.transforms 中的组件时,建议设置 decollate=True。默认为 True

  • optim_set_to_none – 调用 optimizer.zero_grad() 时,将梯度设置为 None 而不是零。更多详情:https://pytorch.ac.cn/docs/stable/generated/torch.optim.Optimizer.zero_grad.html

  • to_kwargsprepare_batch API 在转换输入数据时使用的其他参数字典,除了 device, non_blocking

  • amp_kwargstorch.cuda.amp.autocast() API 的参数字典,更多详情请查看:https://pytorch.ac.cn/docs/stable/amp.html#torch.cuda.amp.autocast

  • compile – 是否使用 torch.compile,默认为 False。如果为 True,MetaTensor 输入将在前向传播之前转换为 torch.Tensor,然后携带复制的元信息转换回去。

  • compile_kwargstorch.compile() API 的参数字典,更多详情:https://pytorch.ac.cn/docs/stable/generated/torch.compile.html#torch-compile

GanTrainer#

class monai.engines.GanTrainer(device, max_epochs, train_data_loader, g_network, g_optimizer, g_loss_function, d_network, d_optimizer, d_loss_function, epoch_length=None, g_inferer=None, d_inferer=None, d_train_steps=1, latent_shape=64, non_blocking=False, d_prepare_batch=<function default_prepare_batch>, g_prepare_batch=<function default_make_latent>, g_update_latents=True, iteration_update=None, postprocessing=None, key_train_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, train_handlers=None, decollate=True, optim_set_to_none=False, to_kwargs=None, amp_kwargs=None)[source]#

基于 Goodfellow 等人 2014 年论文 https://arxiv.org/abs/1406.266 的生成对抗网络训练,继承自 TrainerWorkflow

训练循环:对于每个数据批次大小为 m 的批次
  1. 从随机潜在代码生成 m 个假样本。

  2. 使用这些假样本和当前批次的真实样本更新判别器,重复 d_train_steps 次。

  3. 如果 g_update_latents 为 True,从新的随机潜在代码生成 m 个假样本。

  4. 使用这些假样本和判别器的反馈更新生成器。

参数:
  • device – 表示运行设备的类或对象。

  • max_epochs – 引擎运行的总轮次(epoch)数。

  • train_data_loader – 核心 ignite 引擎使用 DataLoader 进行训练循环批次数据加载。

  • g_network – 生成器 (G) 网络架构。

  • g_optimizer – G 优化器函数。

  • g_loss_function – G 优化器的损失函数。

  • d_network – 判别器 (D) 网络架构。

  • d_optimizer – D 优化器函数。

  • d_loss_function – D 优化器的损失函数。

  • epoch_length – 一个轮次的迭代次数,默认为 len(train_data_loader)

  • g_inferer – 执行 G 模型前向计算的推理方法。默认为 SimpleInferer()

  • d_inferer – 执行 D 模型前向计算的推理方法。默认为 SimpleInferer()

  • d_train_steps – 使用真实数据小批次更新 D 的次数。默认为 1

  • latent_shape – G 输入潜在代码的大小。默认为 64

  • non_blocking – 如果为 True 且此复制发生在 CPU 和 GPU 之间,则复制可能相对于主机异步进行。对于其他情况,此参数无效。

  • d_prepare_batch – 为 D 推理器准备批次数据的回调函数。默认为在批次数据字典中返回 GanKeys.REALS。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.create_supervised_trainer.html

  • g_prepare_batch – 为 G 推理器创建潜在输入批次的回调函数。默认为返回随机潜在代码。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.create_supervised_trainer.html

  • g_update_latents – 使用新的潜在代码计算 G 损失。默认为 True

  • iteration_update – 每次迭代的可调用函数,预期接受 engineengine.state.batch 作为输入,返回的数据将存储在 engine.state.output 中。如果未提供,则使用 self._iteration() 代替。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html

  • postprocessing – 对模型输出数据执行附加的变换。通常是由 Compose 组合的几个基于 Tensor 的变换。

  • key_train_metric – 在每次迭代完成后计算指标,并在轮次完成后将平均值保存到 engine.state.metrics。key_train_metric 是用于比较和将检查点保存到文件中的主要指标。

  • additional_metrics – 更多也附加到 Ignite 引擎的 Ignite 指标。

  • metric_cmp_fn – 用于比较当前关键指标与先前最佳关键指标值的函数,它必须接受 2 个参数(current_metric, previous_best)并返回一个布尔结果:如果为 True,将使用当前指标和轮次更新 best_metricbest_metric_epoch,默认为 大于

  • train_handlers – 每个处理程序都是一组 Ignite 事件处理程序(Event-Handlers),必须具有 attach 函数,例如:CheckpointHandler, StatsHandler 等。

  • decollate – 模型计算后是否将 batch-first 数据去批次化(decollate)为数据列表,当 postprocessing 使用 monai.transforms 中的组件时,建议设置 decollate=True。默认为 True

  • optim_set_to_none – 调用 optimizer.zero_grad() 时,将梯度设置为 None 而不是零。更多详情:https://pytorch.ac.cn/docs/stable/generated/torch.optim.Optimizer.zero_grad.html

  • to_kwargsprepare_batch API 在转换输入数据时使用的其他参数字典,除了 device, non_blocking

  • amp_kwargstorch.cuda.amp.autocast() API 的参数字典,更多详情请查看:https://pytorch.ac.cn/docs/stable/amp.html#torch.cuda.amp.autocast

AdversarialTrainer#

class monai.engines.AdversarialTrainer(device, max_epochs, train_data_loader, g_network, g_optimizer, g_loss_function, recon_loss_function, d_network, d_optimizer, d_loss_function, epoch_length=None, non_blocking=False, prepare_batch=<function default_prepare_batch>, iteration_update=None, g_inferer=None, d_inferer=None, postprocessing=None, key_train_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, train_handlers=None, amp=False, event_names=None, event_to_attr=None, decollate=True, optim_set_to_none=False, to_kwargs=None, amp_kwargs=None)[source]#

用于启用对抗性损失的神经网络的标准有监督训练工作流。

参数:
  • device – 表示运行设备的类或对象。

  • max_epochs – 引擎运行的总轮次(epoch)数。

  • train_data_loader – 核心 ignite 引擎使用 DataLoader 进行训练循环批次数据加载。

  • g_network – “生成器”(G)网络架构。

  • g_optimizer – G 优化器函数。

  • g_loss_function – 用于对抗性训练的 G 损失函数。

  • recon_loss_function – 用于重构的 G 损失函数。

  • d_network – 判别器 (D) 网络架构。

  • d_optimizer – D 优化器函数。

  • d_loss_function – 用于对抗性训练的 D 损失函数。

  • epoch_length – 一个轮次的迭代次数,默认为 len(train_data_loader)

  • non_blocking – 如果为 True 且此复制发生在 CPU 和 GPU 之间,则复制可能相对于主机异步进行。对于其他情况,此参数无效。

  • prepare_batch – 用于解析当前迭代的图像和标签的函数。

  • iteration_update – 每次迭代的可调用函数,预期接受 enginebatchdata 作为输入参数。如果未提供,则使用 self._iteration() 代替。

  • g_inferer – 执行 G 模型前向计算的推理方法。默认为 SimpleInferer()

  • d_inferer – 执行 D 模型前向计算的推理方法。默认为 SimpleInferer()

  • postprocessing – 对模型输出数据执行附加的变换。通常是由 Compose 组合的几个基于 Tensor 的变换。默认为 None

  • key_train_metric – 在每次迭代完成后计算指标,并在轮次完成后将平均值保存到 engine.state.metrics。key_train_metric 是用于比较和将检查点保存到文件中的主要指标。

  • additional_metrics – 更多也附加到 Ignite 引擎的 Ignite 指标。

  • metric_cmp_fn – 用于比较当前关键指标与先前最佳关键指标值的函数,它必须接受 2 个参数(current_metric, previous_best)并返回一个布尔结果:如果为 True,将使用当前指标和轮次更新 ‘best_metric` 和 best_metric_epoch,默认为 大于

  • train_handlers – 每个处理程序都是一组 Ignite 事件处理程序(Event-Handlers),必须具有 attach 函数,例如:CheckpointHandler, StatsHandler 等。

  • amp – 是否启用自动混合精度训练,默认为 False。

  • event_names – 将注册到引擎的其他自定义 Ignite 事件。新事件可以是字符串列表或 ignite.engine.events.EventEnum

  • event_to_attr – 一个将事件映射到状态属性,然后添加到 engine.state 的字典。更多详情请查看:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html #ignite.engine.engine.Engine.register_events。

  • decollate – 模型计算后是否将 batch-first 数据去批次化(decollate)为数据列表,当 postprocessing 使用 monai.transforms 中的组件时,建议设置 decollate=True。默认为 True

  • optim_set_to_none – 调用 optimizer.zero_grad() 时,将梯度设置为 None 而不是零。更多详情:https://pytorch.ac.cn/docs/stable/generated/torch.optim.Optimizer.zero_grad.html

  • to_kwargsprepare_batch API 在转换输入数据时使用的其他参数字典,除了 device, non_blocking

  • amp_kwargstorch.cuda.amp.autocast() API 的参数字典,更多详情请查看:https://pytorch.ac.cn/docs/stable/amp.html#torch.cuda.amp.autocast

Evaluator#

class monai.engines.Evaluator(device, val_data_loader, epoch_length=None, non_blocking=False, prepare_batch=<function default_prepare_batch>, iteration_update=None, postprocessing=None, key_val_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, val_handlers=None, amp=False, mode=eval, event_names=None, event_to_attr=None, decollate=True, to_kwargs=None, amp_kwargs=None)[source]#

所有评估器的基类,继承自 Workflow。

参数:
  • device – 表示运行设备的类或对象。

  • val_data_loader – Ignite 引擎用于运行的数据加载器,必须是 Iterable,通常是 torch.DataLoader。

  • epoch_length – 一个轮次的迭代次数,默认为 len(val_data_loader)

  • non_blocking – 如果为 True 且此复制发生在 CPU 和 GPU 之间,则复制可能相对于主机异步进行。对于其他情况,此参数无效。

  • prepare_batch – 用于在每次迭代中从 engine.state.batch 中解析预期数据(通常是 imagelabel 和其他网络参数)的函数,更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.create_supervised_trainer.html

  • iteration_update – 每次迭代的可调用函数,预期接受 engineengine.state.batch 作为输入,返回的数据将存储在 engine.state.output 中。如果未提供,则使用 self._iteration() 代替。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html

  • postprocessing – 对模型输出数据执行附加的变换。通常是由 Compose 组合的几个基于 Tensor 的变换。

  • key_val_metric – 在每次迭代完成后计算指标,并在轮次完成后将平均值保存到 engine.state.metrics。key_val_metric 是用于比较和将检查点保存到文件中的主要指标。

  • additional_metrics – 更多也附加到 Ignite 引擎的 Ignite 指标。

  • metric_cmp_fn – 用于比较当前关键指标与先前最佳关键指标值的函数,它必须接受 2 个参数(current_metric, previous_best)并返回一个布尔结果:如果为 True,将使用当前指标和轮次更新 best_metricbest_metric_epoch,默认为 大于

  • val_handlers – 每个处理程序都是一组 Ignite 事件处理程序(Event-Handlers),必须具有 attach 函数,例如:CheckpointHandler, StatsHandler 等。

  • amp – 是否启用自动混合精度评估,默认为 False。

  • mode – 评估期间模型的向前传播模式,应为 ‘eval’ 或 ‘train’,分别映射到 model.eval()model.train(),默认为 ‘eval’。

  • event_names – 将注册到引擎的其他自定义 Ignite 事件。新事件可以是字符串列表或 ignite.engine.events.EventEnum

  • event_to_attr – 一个将事件映射到状态属性,然后添加到 engine.state 的字典。更多详情请查看:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html #ignite.engine.engine.Engine.register_events。

  • decollate – 模型计算后是否将 batch-first 数据去批次化(decollate)为数据列表,当 postprocessing 使用 monai.transforms 中的组件时,建议设置 decollate=True。默认为 True

  • to_kwargsprepare_batch API 在转换输入数据时使用的其他参数字典,除了 device, non_blocking

  • amp_kwargstorch.cuda.amp.autocast() API 的参数字典,更多详情请查看:https://pytorch.ac.cn/docs/stable/amp.html#torch.cuda.amp.autocast

get_stats(*vars)[source]#

获取验证过程的统计信息。默认返回 rankbest_validation_epochbest_validation_metric

参数:

vars – 除了默认统计信息外,self.state 中要返回的其他变量名,将使用变量名作为键,状态内容作为值。如果变量不存在,默认值为 None

run(global_epoch=1)[source]#

基于 Ignite 引擎执行验证/评估。

参数:

global_epoch (int) – 如果在训练期间,表示总的轮次。评估器引擎可以从训练器中获取它。

返回类型:

None

SupervisedEvaluator#

class monai.engines.SupervisedEvaluator(device, val_data_loader, network, epoch_length=None, non_blocking=False, prepare_batch=<function default_prepare_batch>, iteration_update=None, inferer=None, postprocessing=None, key_val_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, val_handlers=None, amp=False, mode=eval, event_names=None, event_to_attr=None, decollate=True, to_kwargs=None, amp_kwargs=None, compile=False, compile_kwargs=None)[source]#

标准的有监督评估方法,使用图像和(可选的)标签,继承自 evaluator 和 Workflow。

参数:
  • device – 表示运行设备的类或对象。

  • val_data_loader – Ignite 引擎用于运行的数据加载器,必须是 Iterable,通常是 torch.DataLoader。

  • network – 在评估器中评估的网络,应为标准的 PyTorch torch.nn.Module

  • epoch_length – 一个轮次的迭代次数,默认为 len(val_data_loader)

  • non_blocking – 如果为 True 且此复制发生在 CPU 和 GPU 之间,则复制可能相对于主机异步进行。对于其他情况,此参数无效。

  • prepare_batch – 用于在每次迭代中从 engine.state.batch 中解析预期数据(通常是 imagelabel 和其他网络参数)的函数,更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.create_supervised_trainer.html

  • iteration_update – 每次迭代的可调用函数,预期接受 engineengine.state.batch 作为输入,返回的数据将存储在 engine.state.output 中。如果未提供,则使用 self._iteration() 代替。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html

  • inferer – 在输入数据上执行模型前向计算的推理方法,例如:SlidingWindow 等。

  • postprocessing – 对模型输出数据执行附加的变换。通常是由 Compose 组合的几个基于 Tensor 的变换。

  • key_val_metric – 在每次迭代完成后计算指标,并在轮次完成后将平均值保存到 engine.state.metrics。key_val_metric 是用于比较和将检查点保存到文件中的主要指标。

  • additional_metrics – 更多也附加到 Ignite 引擎的 Ignite 指标。

  • metric_cmp_fn – 用于比较当前关键指标与先前最佳关键指标值的函数,它必须接受 2 个参数(current_metric, previous_best)并返回一个布尔结果:如果为 True,将使用当前指标和轮次更新 best_metricbest_metric_epoch,默认为 大于

  • val_handlers – 每个处理程序都是一组 Ignite 事件处理程序(Event-Handlers),必须具有 attach 函数,例如:CheckpointHandler, StatsHandler 等。

  • amp – 是否启用自动混合精度评估,默认为 False。

  • mode – 评估期间模型的向前传播模式,应为 ‘eval’ 或 ‘train’,分别映射到 model.eval()model.train(),默认为 ‘eval’。

  • event_names – 将注册到引擎的其他自定义 Ignite 事件。新事件可以是字符串列表或 ignite.engine.events.EventEnum

  • event_to_attr – 一个将事件映射到状态属性,然后添加到 engine.state 的字典。更多详情请查看:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html #ignite.engine.engine.Engine.register_events。

  • decollate – 模型计算后是否将 batch-first 数据去批次化(decollate)为数据列表,当 postprocessing 使用 monai.transforms 中的组件时,建议设置 decollate=True。默认为 True

  • to_kwargsprepare_batch API 在转换输入数据时使用的其他参数字典,除了 device, non_blocking

  • amp_kwargstorch.cuda.amp.autocast() API 的参数字典,更多详情请查看:https://pytorch.ac.cn/docs/stable/amp.html#torch.cuda.amp.autocast

  • compile – 是否使用 torch.compile,默认为 False。如果为 True,MetaTensor 输入将在前向传播之前转换为 torch.Tensor,然后携带复制的元信息转换回去。

  • compile_kwargstorch.compile() API 的参数字典,更多详情:https://pytorch.ac.cn/docs/stable/generated/torch.compile.html#torch-compile

EnsembleEvaluator#

class monai.engines.EnsembleEvaluator(device, val_data_loader, networks, pred_keys=None, epoch_length=None, non_blocking=False, prepare_batch=<function default_prepare_batch>, iteration_update=None, inferer=None, postprocessing=None, key_val_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, val_handlers=None, amp=False, mode=eval, event_names=None, event_to_attr=None, decollate=True, to_kwargs=None, amp_kwargs=None)[source]#

用于多个模型的集成评估,继承自 evaluator 和 Workflow。它接受模型列表进行推理,并输出预测列表以进行进一步操作。

参数:
  • device – 表示运行设备的类或对象。

  • val_data_loader – Ignite 引擎用于运行的数据加载器,必须是 Iterable,通常是 torch.DataLoader。

  • epoch_length – 一个轮次的迭代次数,默认为 len(val_data_loader)

  • networks – 在评估器中按顺序评估的网络,应为标准的 PyTorch torch.nn.Module

  • pred_keys – 存储每个预测数据的键。长度必须与网络数量完全匹配。如果为 None,则使用 “pred_{index}” 作为对应 N 个网络的键,索引从 0N-1

  • non_blocking – 如果为 True 且此复制发生在 CPU 和 GPU 之间,则复制可能相对于主机异步进行。对于其他情况,此参数无效。

  • prepare_batch – 用于在每次迭代中从 engine.state.batch 中解析预期数据(通常是 imagelabel 和其他网络参数)的函数,更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.create_supervised_trainer.html

  • iteration_update – 每次迭代的可调用函数,预期接受 engineengine.state.batch 作为输入,返回的数据将存储在 engine.state.output 中。如果未提供,则使用 self._iteration() 代替。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html

  • inferer – 在输入数据上执行模型前向计算的推理方法,例如:SlidingWindow 等。

  • postprocessing – 对模型输出数据执行附加的变换。通常是由 Compose 组合的几个基于 Tensor 的变换。

  • key_val_metric – 在每次迭代完成后计算指标,并在轮次完成后将平均值保存到 engine.state.metrics。key_val_metric 是用于比较和将检查点保存到文件中的主要指标。

  • additional_metrics – 更多也附加到 Ignite 引擎的 Ignite 指标。

  • metric_cmp_fn – 用于比较当前关键指标与先前最佳关键指标值的函数,它必须接受 2 个参数(current_metric, previous_best)并返回一个布尔结果:如果为 True,将使用当前指标和轮次更新 best_metricbest_metric_epoch,默认为 大于

  • val_handlers – 每个处理程序都是一组 Ignite 事件处理程序(Event-Handlers),必须具有 attach 函数,例如:CheckpointHandler, StatsHandler 等。

  • amp – 是否启用自动混合精度评估,默认为 False。

  • mode – 评估期间模型的向前传播模式,应为 ‘eval’ 或 ‘train’,分别映射到 model.eval()model.train(),默认为 ‘eval’。

  • event_names – 将注册到引擎的其他自定义 Ignite 事件。新事件可以是字符串列表或 ignite.engine.events.EventEnum

  • event_to_attr – 一个将事件映射到状态属性,然后添加到 engine.state 的字典。更多详情请查看:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html #ignite.engine.engine.Engine.register_events。

  • decollate – 模型计算后是否将 batch-first 数据去批次化(decollate)为数据列表,当 postprocessing 使用 monai.transforms 中的组件时,建议设置 decollate=True。默认为 True

  • to_kwargsprepare_batch API 在转换输入数据时使用的其他参数字典,除了 device, non_blocking

  • amp_kwargstorch.cuda.amp.autocast() API 的参数字典,更多详情请查看:https://pytorch.ac.cn/docs/stable/amp.html#torch.cuda.amp.autocast

工具类#

class monai.engines.utils.DiffusionPrepareBatch(num_train_timesteps, condition_name=None)[source]#

此类用作扩散训练引擎类中 prepare_batch 参数的可调用对象。

假设是监督训练过程,它将使用 get_noise 为输入图像生成噪声场,并将图像和噪声场作为图像/目标对返回,同时在 kwargs 中以键 “noise” 存储噪声场。这假设与此类一起使用的推理器期望提供一个 “noise” 参数。

如果提供了 condition_name,这必须指向输入字典中包含要传递给推理器的条件字段的键。这将在关键字参数中以键 “condition” 出现。

get_noise(images)[source]#

返回输入张量 images 的噪声张量,对于不同的噪声分布请覆盖此方法。

返回类型:

Tensor

get_target(images, noise, timesteps)[source]#

返回损失函数的目标,默认为 noise 值。

返回类型:

Tensor

get_timesteps(images)[source]#

获取一个时间步长,默认为 0 和 self.num_train_timesteps 之间的随机整数。

返回类型:

Tensor

class monai.engines.utils.IterationEvents(value, event_filter=None, name=None)[source]#

引擎可以在迭代过程中注册和触发的附加事件。参考 ignite 中的示例:https://pytorch.ac.cn/ignite/generated/ignite.engine.events.EventEnum.html。这些事件可以在训练迭代期间触发:FORWARD_COMPLETEDnetwork(image, label) 完成时的事件。LOSS_COMPLETEDloss(pred, label) 完成时的事件。BACKWARD_COMPLETEDloss.backward() 完成时的事件。MODEL_COMPLETED 是所有模型相关操作完成时的事件。INNER_ITERATION_STARTED 是当迭代具有内循环且内循环开始时的事件。INNER_ITERATION_COMPLETED 是当迭代具有内循环且内循环完成时的事件。

class monai.engines.utils.PrepareBatch[source]#

训练器或评估器工作流中自定义 prepare_batch 的接口。它接受当前批次的数据、目标设备和 non_blocking 标志作为输入。参数 batchdatadevicenon_blocking 参考 ignite API:https://pytorch.ac.cn/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.htmlkwargs 支持 Tensor.to() API 的其他参数。

class monai.engines.utils.PrepareBatchDefault[source]#

这封装了 default_prepare_batch 以仅返回 imagelabel,因此与其 API 一致。

class monai.engines.utils.PrepareBatchExtraInput(extra_keys)[source]#

支持网络额外输入数据的训练器或评估器的自定义 prepare batch 可调用对象。额外项由 extra_keys 参数指定,并从输入字典(即批次)中提取。这使用 default_prepare_batch 但需要字典输入。

参数:

extra_keys – 如果提供字符串或字符串序列,将从输入字典中提取这些键的值并作为额外的位置参数传递给网络。如果提供字典,则该字典中的每对 (k, v) 将成为一个新的关键字参数,将输入字典中键为 v 的值赋给 k

class monai.engines.utils.VPredictionPrepareBatch(scheduler, num_train_timesteps, condition_name=None)[source]#

此类用作扩散训练引擎类中 prepare_batch 参数的可调用对象。

假设是监督训练过程,它将使用 get_noise 为输入图像生成噪声场,并由此使用提供的调度器计算速度。此值用作目标,取代噪声场本身,尽管噪声场在 kwargs 中以键 “noise” 存储。这假设与此类一起使用的推理器期望提供一个 “noise” 参数。

如果提供了 condition_name,这必须指向输入字典中包含要传递给推理器的条件字段的键。这将在关键字参数中以键 “condition” 出现。

get_target(images, noise, timesteps)[source]#

返回损失函数的目标,默认为 noise 值。

monai.engines.utils.default_metric_cmp_fn(current_metric, prev_best)[source]#

用于比较当前指标值与先前最佳指标值的默认函数。

参数:
  • current_metric (float) – 当前轮次计算的指标值。

  • prev_best (float) – 要比较的先前轮次的最佳指标值。

返回类型:

bool

monai.engines.utils.default_prepare_batch(batchdata, device=None, non_blocking=False, **kwargs)[source]#

为当前迭代准备数据的默认函数。

输入 batchdata 可以是单个张量、一对张量或数据字典。在第一种情况下,返回值为张量和 None,在第二种情况下,返回值为两个张量,在字典情况下,返回值取决于存在哪些键。如果存在 CommonKeys.IMAGECommonKeys.LABEL,则返回它们对应的张量,如果仅存在 CommonKeys.IMAGE,则返回该张量和 None。如果存在 CommonKeys.REALS,则返回此项和 None。所有返回的张量在返回之前都会使用给定的 non-blocking 参数移动到指定设备上。

此函数实现了 Ignite 中 prepare_batch 可调用对象的预期 API:https://pytorch.ac.cn/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html

参数:
  • batchdata – 输入批次数据,可以是单个张量、一对张量或字典

  • device – 将每个返回的张量移动到的设备

  • non_blockingTensor.to 的等效参数

  • kwargsTensor.to 的其他参数

返回:

图像,标签(可选)。

monai.engines.utils.engine_apply_transform(batch, output, transform)[source]#

batchoutput 应用变换。如果 batchoutput 是字典,则暂时组合它们进行变换,否则,仅对 output 数据应用变换。

返回类型:

tuple[Any, Any]

monai.engines.utils.get_devices_spec(devices=None)[source]#

获取一个或多个设备的有效规范。如果 devices 为 None,则获取所有可用 CUDA 设备。如果 devices 是长度为零的结构,则返回单个 CPU 计算设备。在任何其他情况下,devices 将保持不变返回。

参数:

devices – 要请求的设备列表,None 表示所有 GPU 设备,[] 表示 CPU。

抛出:

RuntimeError – 当选择了所有 GPU(devices=None)但没有可用的 GPU 时。

返回:

设备列表。

返回类型:

list of torch.device