引擎#
工作流#
Workflow#
- class monai.engines.Workflow(device, max_epochs, data_loader, epoch_length=None, non_blocking=False, prepare_batch=<function default_prepare_batch>, iteration_update=None, postprocessing=None, key_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, handlers=None, amp=False, event_names=None, event_to_attr=None, decollate=True, to_kwargs=None, amp_kwargs=None)[source]#
Workflow 定义了继承自 Ignite 引擎的核心工作流程。所有的训练器(trainer)、验证器(validator)和评估器(evaluator)都共享此工作流作为基类,因为它们都可以被视为相同的 Ignite 引擎循环。它在 Ignite engine.state 中初始化所有可共享的数据。并基于事件处理(Event-Handler)机制将附加处理逻辑附加到 Ignite 引擎。
用户应考虑继承 trainer 或 evaluator 来开发更多的训练器或评估器。
- 参数:
device – 表示运行设备的类或对象。
max_epochs – 引擎运行的总轮次(epoch)数,验证器和评估器只有 1 个轮次。
data_loader – Ignite 引擎用于运行的数据加载器,必须是 Iterable 或 torch.DataLoader。
epoch_length – 一个轮次的迭代次数,默认为 len(data_loader)。
non_blocking – 如果为 True 且此复制发生在 CPU 和 GPU 之间,则复制可能相对于主机异步进行。对于其他情况,此参数无效。
prepare_batch – 用于在每次迭代中从 engine.state.batch 中解析预期数据(通常是 image、label 和其他网络参数)的函数,更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.create_supervised_trainer.html。
iteration_update – 每次迭代的可调用函数,预期接受 engine 和 engine.state.batch 作为输入,返回的数据将存储在 engine.state.output 中。如果未提供,则使用 self._iteration() 代替。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html。
postprocessing – 对模型输出数据执行附加的变换。通常是由 Compose 组合的几个基于 Tensor 的变换。
key_metric – 在每次迭代完成后计算指标,并在轮次完成后将平均值保存到 engine.state.metrics。key_metric 是用于比较和将检查点保存到文件中的主要指标。
additional_metrics – 更多也附加到 Ignite 引擎的 Ignite 指标。
metric_cmp_fn – 用于比较当前关键指标与先前最佳关键指标值的函数,它必须接受 2 个参数(current_metric, previous_best)并返回一个布尔结果:如果为 True,将使用当前指标和轮次更新 best_metric 和 best_metric_epoch,默认为 大于。
handlers – 每个处理程序都是一组 Ignite 事件处理程序(Event-Handlers),必须具有 attach 函数,例如:CheckpointHandler, StatsHandler 等。
amp – 是否启用自动混合精度训练或推理,默认为 False。
event_names – 将注册到引擎的其他自定义 Ignite 事件。新事件可以是字符串列表或 ignite.engine.events.EventEnum。
event_to_attr – 一个将事件映射到状态属性,然后添加到 engine.state 的字典。更多详情请查看:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html #ignite.engine.engine.Engine.register_events。
decollate – 模型计算后是否将 batch-first 数据去批次化(decollate)为数据列表,当 postprocessing 使用 monai.transforms 中的组件时,建议设置 decollate=True。默认为 True。
to_kwargs – prepare_batch API 在转换输入数据时使用的其他参数字典,除了 device, non_blocking。
amp_kwargs – torch.cuda.amp.autocast() API 的参数字典,更多详情请查看:https://pytorch.ac.cn/docs/stable/amp.html#torch.cuda.amp.autocast。
- 抛出:
TypeError – 当
data_loader
不是torch.utils.data.DataLoader
时。TypeError – 当
key_metric
不是Optional[dict]
时。TypeError – 当
additional_metrics
不是Optional[dict]
时。
Trainer#
- class monai.engines.Trainer(device, max_epochs, data_loader, epoch_length=None, non_blocking=False, prepare_batch=<function default_prepare_batch>, iteration_update=None, postprocessing=None, key_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, handlers=None, amp=False, event_names=None, event_to_attr=None, decollate=True, to_kwargs=None, amp_kwargs=None)[source]#
所有训练器的基类,继承自 Workflow。
SupervisedTrainer#
- class monai.engines.SupervisedTrainer(device, max_epochs, train_data_loader, network, optimizer, loss_function, epoch_length=None, non_blocking=False, prepare_batch=<function default_prepare_batch>, iteration_update=None, inferer=None, postprocessing=None, key_train_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, train_handlers=None, amp=False, event_names=None, event_to_attr=None, decollate=True, optim_set_to_none=False, to_kwargs=None, amp_kwargs=None, compile=False, compile_kwargs=None)[source]#
标准的有监督训练方法,使用图像和标签,继承自
Trainer
和Workflow
。- 参数:
device – 表示运行设备的类或对象。
max_epochs – 训练器运行的总轮次(epoch)数。
train_data_loader – Ignite 引擎用于运行的数据加载器,必须是 Iterable 或 torch.DataLoader。
network – 在训练器中训练的网络,应为标准的 PyTorch torch.nn.Module。
optimizer – 与网络相关的优化器,应为 torch.optim 或其子类中的标准 PyTorch 优化器。
loss_function – 与优化器相关的损失函数,应为标准的 PyTorch 损失函数,继承自 torch.nn.modules.loss。
epoch_length – 一个轮次的迭代次数,默认为 len(train_data_loader)。
non_blocking – 如果为 True 且此复制发生在 CPU 和 GPU 之间,则复制可能相对于主机异步进行。对于其他情况,此参数无效。
prepare_batch – 用于在每次迭代中从 engine.state.batch 中解析预期数据(通常是 image、label 和其他网络参数)的函数,更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.create_supervised_trainer.html。
iteration_update – 每次迭代的可调用函数,预期接受 engine 和 engine.state.batch 作为输入,返回的数据将存储在 engine.state.output 中。如果未提供,则使用 self._iteration() 代替。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html。
inferer – 在输入数据上执行模型前向计算的推理方法,例如:SlidingWindow 等。
postprocessing – 对模型输出数据执行附加的变换。通常是由 Compose 组合的几个基于 Tensor 的变换。
key_train_metric – 在每次迭代完成后计算指标,并在轮次完成后将平均值保存到 engine.state.metrics。key_train_metric 是用于比较和将检查点保存到文件中的主要指标。
additional_metrics – 更多也附加到 Ignite 引擎的 Ignite 指标。
metric_cmp_fn – 用于比较当前关键指标与先前最佳关键指标值的函数,它必须接受 2 个参数(current_metric, previous_best)并返回一个布尔结果:如果为 True,将使用当前指标和轮次更新 best_metric 和 best_metric_epoch,默认为 大于。
train_handlers – 每个处理程序都是一组 Ignite 事件处理程序(Event-Handlers),必须具有 attach 函数,例如:CheckpointHandler, StatsHandler 等。
amp – 是否启用自动混合精度训练,默认为 False。
event_names – 将注册到引擎的其他自定义 Ignite 事件。新事件可以是字符串列表或 ignite.engine.events.EventEnum。
event_to_attr – 一个将事件映射到状态属性,然后添加到 engine.state 的字典。更多详情请查看:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html #ignite.engine.engine.Engine.register_events。
decollate – 模型计算后是否将 batch-first 数据去批次化(decollate)为数据列表,当 postprocessing 使用 monai.transforms 中的组件时,建议设置 decollate=True。默认为 True。
optim_set_to_none – 调用 optimizer.zero_grad() 时,将梯度设置为 None 而不是零。更多详情:https://pytorch.ac.cn/docs/stable/generated/torch.optim.Optimizer.zero_grad.html。
to_kwargs – prepare_batch API 在转换输入数据时使用的其他参数字典,除了 device, non_blocking。
amp_kwargs – torch.cuda.amp.autocast() API 的参数字典,更多详情请查看:https://pytorch.ac.cn/docs/stable/amp.html#torch.cuda.amp.autocast。
compile – 是否使用 torch.compile,默认为 False。如果为 True,MetaTensor 输入将在前向传播之前转换为 torch.Tensor,然后携带复制的元信息转换回去。
compile_kwargs – torch.compile() API 的参数字典,更多详情:https://pytorch.ac.cn/docs/stable/generated/torch.compile.html#torch-compile。
GanTrainer#
- class monai.engines.GanTrainer(device, max_epochs, train_data_loader, g_network, g_optimizer, g_loss_function, d_network, d_optimizer, d_loss_function, epoch_length=None, g_inferer=None, d_inferer=None, d_train_steps=1, latent_shape=64, non_blocking=False, d_prepare_batch=<function default_prepare_batch>, g_prepare_batch=<function default_make_latent>, g_update_latents=True, iteration_update=None, postprocessing=None, key_train_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, train_handlers=None, decollate=True, optim_set_to_none=False, to_kwargs=None, amp_kwargs=None)[source]#
基于 Goodfellow 等人 2014 年论文 https://arxiv.org/abs/1406.266 的生成对抗网络训练,继承自
Trainer
和Workflow
。- 训练循环:对于每个数据批次大小为 m 的批次
从随机潜在代码生成 m 个假样本。
使用这些假样本和当前批次的真实样本更新判别器,重复 d_train_steps 次。
如果 g_update_latents 为 True,从新的随机潜在代码生成 m 个假样本。
使用这些假样本和判别器的反馈更新生成器。
- 参数:
device – 表示运行设备的类或对象。
max_epochs – 引擎运行的总轮次(epoch)数。
train_data_loader – 核心 ignite 引擎使用 DataLoader 进行训练循环批次数据加载。
g_network – 生成器 (G) 网络架构。
g_optimizer – G 优化器函数。
g_loss_function – G 优化器的损失函数。
d_network – 判别器 (D) 网络架构。
d_optimizer – D 优化器函数。
d_loss_function – D 优化器的损失函数。
epoch_length – 一个轮次的迭代次数,默认为 len(train_data_loader)。
g_inferer – 执行 G 模型前向计算的推理方法。默认为
SimpleInferer()
。d_inferer – 执行 D 模型前向计算的推理方法。默认为
SimpleInferer()
。d_train_steps – 使用真实数据小批次更新 D 的次数。默认为
1
。latent_shape – G 输入潜在代码的大小。默认为
64
。non_blocking – 如果为 True 且此复制发生在 CPU 和 GPU 之间,则复制可能相对于主机异步进行。对于其他情况,此参数无效。
d_prepare_batch – 为 D 推理器准备批次数据的回调函数。默认为在批次数据字典中返回
GanKeys.REALS
。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.create_supervised_trainer.html。g_prepare_batch – 为 G 推理器创建潜在输入批次的回调函数。默认为返回随机潜在代码。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.create_supervised_trainer.html。
g_update_latents – 使用新的潜在代码计算 G 损失。默认为
True
。iteration_update – 每次迭代的可调用函数,预期接受 engine 和 engine.state.batch 作为输入,返回的数据将存储在 engine.state.output 中。如果未提供,则使用 self._iteration() 代替。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html。
postprocessing – 对模型输出数据执行附加的变换。通常是由 Compose 组合的几个基于 Tensor 的变换。
key_train_metric – 在每次迭代完成后计算指标,并在轮次完成后将平均值保存到 engine.state.metrics。key_train_metric 是用于比较和将检查点保存到文件中的主要指标。
additional_metrics – 更多也附加到 Ignite 引擎的 Ignite 指标。
metric_cmp_fn – 用于比较当前关键指标与先前最佳关键指标值的函数,它必须接受 2 个参数(current_metric, previous_best)并返回一个布尔结果:如果为 True,将使用当前指标和轮次更新 best_metric 和 best_metric_epoch,默认为 大于。
train_handlers – 每个处理程序都是一组 Ignite 事件处理程序(Event-Handlers),必须具有 attach 函数,例如:CheckpointHandler, StatsHandler 等。
decollate – 模型计算后是否将 batch-first 数据去批次化(decollate)为数据列表,当 postprocessing 使用 monai.transforms 中的组件时,建议设置 decollate=True。默认为 True。
optim_set_to_none – 调用 optimizer.zero_grad() 时,将梯度设置为 None 而不是零。更多详情:https://pytorch.ac.cn/docs/stable/generated/torch.optim.Optimizer.zero_grad.html。
to_kwargs – prepare_batch API 在转换输入数据时使用的其他参数字典,除了 device, non_blocking。
amp_kwargs – torch.cuda.amp.autocast() API 的参数字典,更多详情请查看:https://pytorch.ac.cn/docs/stable/amp.html#torch.cuda.amp.autocast。
AdversarialTrainer#
- class monai.engines.AdversarialTrainer(device, max_epochs, train_data_loader, g_network, g_optimizer, g_loss_function, recon_loss_function, d_network, d_optimizer, d_loss_function, epoch_length=None, non_blocking=False, prepare_batch=<function default_prepare_batch>, iteration_update=None, g_inferer=None, d_inferer=None, postprocessing=None, key_train_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, train_handlers=None, amp=False, event_names=None, event_to_attr=None, decollate=True, optim_set_to_none=False, to_kwargs=None, amp_kwargs=None)[source]#
用于启用对抗性损失的神经网络的标准有监督训练工作流。
- 参数:
device – 表示运行设备的类或对象。
max_epochs – 引擎运行的总轮次(epoch)数。
train_data_loader – 核心 ignite 引擎使用 DataLoader 进行训练循环批次数据加载。
g_network – “生成器”(G)网络架构。
g_optimizer – G 优化器函数。
g_loss_function – 用于对抗性训练的 G 损失函数。
recon_loss_function – 用于重构的 G 损失函数。
d_network – 判别器 (D) 网络架构。
d_optimizer – D 优化器函数。
d_loss_function – 用于对抗性训练的 D 损失函数。
epoch_length – 一个轮次的迭代次数,默认为 len(train_data_loader)。
non_blocking – 如果为 True 且此复制发生在 CPU 和 GPU 之间,则复制可能相对于主机异步进行。对于其他情况,此参数无效。
prepare_batch – 用于解析当前迭代的图像和标签的函数。
iteration_update – 每次迭代的可调用函数,预期接受 engine 和 batchdata 作为输入参数。如果未提供,则使用 self._iteration() 代替。
g_inferer – 执行 G 模型前向计算的推理方法。默认为
SimpleInferer()
。d_inferer – 执行 D 模型前向计算的推理方法。默认为
SimpleInferer()
。postprocessing – 对模型输出数据执行附加的变换。通常是由 Compose 组合的几个基于 Tensor 的变换。默认为 None
key_train_metric – 在每次迭代完成后计算指标,并在轮次完成后将平均值保存到 engine.state.metrics。key_train_metric 是用于比较和将检查点保存到文件中的主要指标。
additional_metrics – 更多也附加到 Ignite 引擎的 Ignite 指标。
metric_cmp_fn – 用于比较当前关键指标与先前最佳关键指标值的函数,它必须接受 2 个参数(current_metric, previous_best)并返回一个布尔结果:如果为 True,将使用当前指标和轮次更新 ‘best_metric` 和 best_metric_epoch,默认为 大于。
train_handlers – 每个处理程序都是一组 Ignite 事件处理程序(Event-Handlers),必须具有 attach 函数,例如:CheckpointHandler, StatsHandler 等。
amp – 是否启用自动混合精度训练,默认为 False。
event_names – 将注册到引擎的其他自定义 Ignite 事件。新事件可以是字符串列表或 ignite.engine.events.EventEnum。
event_to_attr – 一个将事件映射到状态属性,然后添加到 engine.state 的字典。更多详情请查看:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html #ignite.engine.engine.Engine.register_events。
decollate – 模型计算后是否将 batch-first 数据去批次化(decollate)为数据列表,当 postprocessing 使用 monai.transforms 中的组件时,建议设置 decollate=True。默认为 True。
optim_set_to_none – 调用 optimizer.zero_grad() 时,将梯度设置为 None 而不是零。更多详情:https://pytorch.ac.cn/docs/stable/generated/torch.optim.Optimizer.zero_grad.html。
to_kwargs – prepare_batch API 在转换输入数据时使用的其他参数字典,除了 device, non_blocking。
amp_kwargs – torch.cuda.amp.autocast() API 的参数字典,更多详情请查看:https://pytorch.ac.cn/docs/stable/amp.html#torch.cuda.amp.autocast。
Evaluator#
- class monai.engines.Evaluator(device, val_data_loader, epoch_length=None, non_blocking=False, prepare_batch=<function default_prepare_batch>, iteration_update=None, postprocessing=None, key_val_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, val_handlers=None, amp=False, mode=eval, event_names=None, event_to_attr=None, decollate=True, to_kwargs=None, amp_kwargs=None)[source]#
所有评估器的基类,继承自 Workflow。
- 参数:
device – 表示运行设备的类或对象。
val_data_loader – Ignite 引擎用于运行的数据加载器,必须是 Iterable,通常是 torch.DataLoader。
epoch_length – 一个轮次的迭代次数,默认为 len(val_data_loader)。
non_blocking – 如果为 True 且此复制发生在 CPU 和 GPU 之间,则复制可能相对于主机异步进行。对于其他情况,此参数无效。
prepare_batch – 用于在每次迭代中从 engine.state.batch 中解析预期数据(通常是 image、label 和其他网络参数)的函数,更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.create_supervised_trainer.html。
iteration_update – 每次迭代的可调用函数,预期接受 engine 和 engine.state.batch 作为输入,返回的数据将存储在 engine.state.output 中。如果未提供,则使用 self._iteration() 代替。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html。
postprocessing – 对模型输出数据执行附加的变换。通常是由 Compose 组合的几个基于 Tensor 的变换。
key_val_metric – 在每次迭代完成后计算指标,并在轮次完成后将平均值保存到 engine.state.metrics。key_val_metric 是用于比较和将检查点保存到文件中的主要指标。
additional_metrics – 更多也附加到 Ignite 引擎的 Ignite 指标。
metric_cmp_fn – 用于比较当前关键指标与先前最佳关键指标值的函数,它必须接受 2 个参数(current_metric, previous_best)并返回一个布尔结果:如果为 True,将使用当前指标和轮次更新 best_metric 和 best_metric_epoch,默认为 大于。
val_handlers – 每个处理程序都是一组 Ignite 事件处理程序(Event-Handlers),必须具有 attach 函数,例如:CheckpointHandler, StatsHandler 等。
amp – 是否启用自动混合精度评估,默认为 False。
mode – 评估期间模型的向前传播模式,应为 ‘eval’ 或 ‘train’,分别映射到 model.eval() 或 model.train(),默认为 ‘eval’。
event_names – 将注册到引擎的其他自定义 Ignite 事件。新事件可以是字符串列表或 ignite.engine.events.EventEnum。
event_to_attr – 一个将事件映射到状态属性,然后添加到 engine.state 的字典。更多详情请查看:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html #ignite.engine.engine.Engine.register_events。
decollate – 模型计算后是否将 batch-first 数据去批次化(decollate)为数据列表,当 postprocessing 使用 monai.transforms 中的组件时,建议设置 decollate=True。默认为 True。
to_kwargs – prepare_batch API 在转换输入数据时使用的其他参数字典,除了 device, non_blocking。
amp_kwargs – torch.cuda.amp.autocast() API 的参数字典,更多详情请查看:https://pytorch.ac.cn/docs/stable/amp.html#torch.cuda.amp.autocast。
SupervisedEvaluator#
- class monai.engines.SupervisedEvaluator(device, val_data_loader, network, epoch_length=None, non_blocking=False, prepare_batch=<function default_prepare_batch>, iteration_update=None, inferer=None, postprocessing=None, key_val_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, val_handlers=None, amp=False, mode=eval, event_names=None, event_to_attr=None, decollate=True, to_kwargs=None, amp_kwargs=None, compile=False, compile_kwargs=None)[source]#
标准的有监督评估方法,使用图像和(可选的)标签,继承自 evaluator 和 Workflow。
- 参数:
device – 表示运行设备的类或对象。
val_data_loader – Ignite 引擎用于运行的数据加载器,必须是 Iterable,通常是 torch.DataLoader。
network – 在评估器中评估的网络,应为标准的 PyTorch torch.nn.Module。
epoch_length – 一个轮次的迭代次数,默认为 len(val_data_loader)。
non_blocking – 如果为 True 且此复制发生在 CPU 和 GPU 之间,则复制可能相对于主机异步进行。对于其他情况,此参数无效。
prepare_batch – 用于在每次迭代中从 engine.state.batch 中解析预期数据(通常是 image、label 和其他网络参数)的函数,更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.create_supervised_trainer.html。
iteration_update – 每次迭代的可调用函数,预期接受 engine 和 engine.state.batch 作为输入,返回的数据将存储在 engine.state.output 中。如果未提供,则使用 self._iteration() 代替。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html。
inferer – 在输入数据上执行模型前向计算的推理方法,例如:SlidingWindow 等。
postprocessing – 对模型输出数据执行附加的变换。通常是由 Compose 组合的几个基于 Tensor 的变换。
key_val_metric – 在每次迭代完成后计算指标,并在轮次完成后将平均值保存到 engine.state.metrics。key_val_metric 是用于比较和将检查点保存到文件中的主要指标。
additional_metrics – 更多也附加到 Ignite 引擎的 Ignite 指标。
metric_cmp_fn – 用于比较当前关键指标与先前最佳关键指标值的函数,它必须接受 2 个参数(current_metric, previous_best)并返回一个布尔结果:如果为 True,将使用当前指标和轮次更新 best_metric 和 best_metric_epoch,默认为 大于。
val_handlers – 每个处理程序都是一组 Ignite 事件处理程序(Event-Handlers),必须具有 attach 函数,例如:CheckpointHandler, StatsHandler 等。
amp – 是否启用自动混合精度评估,默认为 False。
mode – 评估期间模型的向前传播模式,应为 ‘eval’ 或 ‘train’,分别映射到 model.eval() 或 model.train(),默认为 ‘eval’。
event_names – 将注册到引擎的其他自定义 Ignite 事件。新事件可以是字符串列表或 ignite.engine.events.EventEnum。
event_to_attr – 一个将事件映射到状态属性,然后添加到 engine.state 的字典。更多详情请查看:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html #ignite.engine.engine.Engine.register_events。
decollate – 模型计算后是否将 batch-first 数据去批次化(decollate)为数据列表,当 postprocessing 使用 monai.transforms 中的组件时,建议设置 decollate=True。默认为 True。
to_kwargs – prepare_batch API 在转换输入数据时使用的其他参数字典,除了 device, non_blocking。
amp_kwargs – torch.cuda.amp.autocast() API 的参数字典,更多详情请查看:https://pytorch.ac.cn/docs/stable/amp.html#torch.cuda.amp.autocast。
compile – 是否使用 torch.compile,默认为 False。如果为 True,MetaTensor 输入将在前向传播之前转换为 torch.Tensor,然后携带复制的元信息转换回去。
compile_kwargs – torch.compile() API 的参数字典,更多详情:https://pytorch.ac.cn/docs/stable/generated/torch.compile.html#torch-compile。
EnsembleEvaluator#
- class monai.engines.EnsembleEvaluator(device, val_data_loader, networks, pred_keys=None, epoch_length=None, non_blocking=False, prepare_batch=<function default_prepare_batch>, iteration_update=None, inferer=None, postprocessing=None, key_val_metric=None, additional_metrics=None, metric_cmp_fn=<function default_metric_cmp_fn>, val_handlers=None, amp=False, mode=eval, event_names=None, event_to_attr=None, decollate=True, to_kwargs=None, amp_kwargs=None)[source]#
用于多个模型的集成评估,继承自 evaluator 和 Workflow。它接受模型列表进行推理,并输出预测列表以进行进一步操作。
- 参数:
device – 表示运行设备的类或对象。
val_data_loader – Ignite 引擎用于运行的数据加载器,必须是 Iterable,通常是 torch.DataLoader。
epoch_length – 一个轮次的迭代次数,默认为 len(val_data_loader)。
networks – 在评估器中按顺序评估的网络,应为标准的 PyTorch torch.nn.Module。
pred_keys – 存储每个预测数据的键。长度必须与网络数量完全匹配。如果为 None,则使用 “pred_{index}” 作为对应 N 个网络的键,索引从 0 到 N-1。
non_blocking – 如果为 True 且此复制发生在 CPU 和 GPU 之间,则复制可能相对于主机异步进行。对于其他情况,此参数无效。
prepare_batch – 用于在每次迭代中从 engine.state.batch 中解析预期数据(通常是 image、label 和其他网络参数)的函数,更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.create_supervised_trainer.html。
iteration_update – 每次迭代的可调用函数,预期接受 engine 和 engine.state.batch 作为输入,返回的数据将存储在 engine.state.output 中。如果未提供,则使用 self._iteration() 代替。更多详情请参考:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html。
inferer – 在输入数据上执行模型前向计算的推理方法,例如:SlidingWindow 等。
postprocessing – 对模型输出数据执行附加的变换。通常是由 Compose 组合的几个基于 Tensor 的变换。
key_val_metric – 在每次迭代完成后计算指标,并在轮次完成后将平均值保存到 engine.state.metrics。key_val_metric 是用于比较和将检查点保存到文件中的主要指标。
additional_metrics – 更多也附加到 Ignite 引擎的 Ignite 指标。
metric_cmp_fn – 用于比较当前关键指标与先前最佳关键指标值的函数,它必须接受 2 个参数(current_metric, previous_best)并返回一个布尔结果:如果为 True,将使用当前指标和轮次更新 best_metric 和 best_metric_epoch,默认为 大于。
val_handlers – 每个处理程序都是一组 Ignite 事件处理程序(Event-Handlers),必须具有 attach 函数,例如:CheckpointHandler, StatsHandler 等。
amp – 是否启用自动混合精度评估,默认为 False。
mode – 评估期间模型的向前传播模式,应为 ‘eval’ 或 ‘train’,分别映射到 model.eval() 或 model.train(),默认为 ‘eval’。
event_names – 将注册到引擎的其他自定义 Ignite 事件。新事件可以是字符串列表或 ignite.engine.events.EventEnum。
event_to_attr – 一个将事件映射到状态属性,然后添加到 engine.state 的字典。更多详情请查看:https://pytorch.ac.cn/ignite/generated/ignite.engine.engine.Engine.html #ignite.engine.engine.Engine.register_events。
decollate – 模型计算后是否将 batch-first 数据去批次化(decollate)为数据列表,当 postprocessing 使用 monai.transforms 中的组件时,建议设置 decollate=True。默认为 True。
to_kwargs – prepare_batch API 在转换输入数据时使用的其他参数字典,除了 device, non_blocking。
amp_kwargs – torch.cuda.amp.autocast() API 的参数字典,更多详情请查看:https://pytorch.ac.cn/docs/stable/amp.html#torch.cuda.amp.autocast。
工具类#
- class monai.engines.utils.DiffusionPrepareBatch(num_train_timesteps, condition_name=None)[source]#
此类用作扩散训练引擎类中 prepare_batch 参数的可调用对象。
假设是监督训练过程,它将使用 get_noise 为输入图像生成噪声场,并将图像和噪声场作为图像/目标对返回,同时在 kwargs 中以键 “noise” 存储噪声场。这假设与此类一起使用的推理器期望提供一个 “noise” 参数。
如果提供了 condition_name,这必须指向输入字典中包含要传递给推理器的条件字段的键。这将在关键字参数中以键 “condition” 出现。
- class monai.engines.utils.IterationEvents(value, event_filter=None, name=None)[source]#
引擎可以在迭代过程中注册和触发的附加事件。参考 ignite 中的示例:https://pytorch.ac.cn/ignite/generated/ignite.engine.events.EventEnum.html。这些事件可以在训练迭代期间触发:FORWARD_COMPLETED 是 network(image, label) 完成时的事件。LOSS_COMPLETED 是 loss(pred, label) 完成时的事件。BACKWARD_COMPLETED 是 loss.backward() 完成时的事件。MODEL_COMPLETED 是所有模型相关操作完成时的事件。INNER_ITERATION_STARTED 是当迭代具有内循环且内循环开始时的事件。INNER_ITERATION_COMPLETED 是当迭代具有内循环且内循环完成时的事件。
- class monai.engines.utils.PrepareBatch[source]#
训练器或评估器工作流中自定义 prepare_batch 的接口。它接受当前批次的数据、目标设备和 non_blocking 标志作为输入。参数 batchdata、device、non_blocking 参考 ignite API:https://pytorch.ac.cn/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html。kwargs 支持 Tensor.to() API 的其他参数。
- class monai.engines.utils.PrepareBatchDefault[source]#
这封装了 default_prepare_batch 以仅返回 image 和 label,因此与其 API 一致。
- class monai.engines.utils.PrepareBatchExtraInput(extra_keys)[source]#
支持网络额外输入数据的训练器或评估器的自定义 prepare batch 可调用对象。额外项由 extra_keys 参数指定,并从输入字典(即批次)中提取。这使用 default_prepare_batch 但需要字典输入。
- 参数:
extra_keys – 如果提供字符串或字符串序列,将从输入字典中提取这些键的值并作为额外的位置参数传递给网络。如果提供字典,则该字典中的每对 (k, v) 将成为一个新的关键字参数,将输入字典中键为 v 的值赋给 k。
- class monai.engines.utils.VPredictionPrepareBatch(scheduler, num_train_timesteps, condition_name=None)[source]#
此类用作扩散训练引擎类中 prepare_batch 参数的可调用对象。
假设是监督训练过程,它将使用 get_noise 为输入图像生成噪声场,并由此使用提供的调度器计算速度。此值用作目标,取代噪声场本身,尽管噪声场在 kwargs 中以键 “noise” 存储。这假设与此类一起使用的推理器期望提供一个 “noise” 参数。
如果提供了 condition_name,这必须指向输入字典中包含要传递给推理器的条件字段的键。这将在关键字参数中以键 “condition” 出现。
- monai.engines.utils.default_metric_cmp_fn(current_metric, prev_best)[source]#
用于比较当前指标值与先前最佳指标值的默认函数。
- 参数:
current_metric (
float
) – 当前轮次计算的指标值。prev_best (
float
) – 要比较的先前轮次的最佳指标值。
- 返回类型:
bool
- monai.engines.utils.default_prepare_batch(batchdata, device=None, non_blocking=False, **kwargs)[source]#
为当前迭代准备数据的默认函数。
输入 batchdata 可以是单个张量、一对张量或数据字典。在第一种情况下,返回值为张量和 None,在第二种情况下,返回值为两个张量,在字典情况下,返回值取决于存在哪些键。如果存在 CommonKeys.IMAGE 和 CommonKeys.LABEL,则返回它们对应的张量,如果仅存在 CommonKeys.IMAGE,则返回该张量和 None。如果存在 CommonKeys.REALS,则返回此项和 None。所有返回的张量在返回之前都会使用给定的 non-blocking 参数移动到指定设备上。
此函数实现了 Ignite 中 prepare_batch 可调用对象的预期 API:https://pytorch.ac.cn/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html
- 参数:
batchdata – 输入批次数据,可以是单个张量、一对张量或字典
device – 将每个返回的张量移动到的设备
non_blocking – Tensor.to 的等效参数
kwargs – Tensor.to 的其他参数
- 返回:
图像,标签(可选)。
- monai.engines.utils.engine_apply_transform(batch, output, transform)[source]#
对 batch 和 output 应用变换。如果 batch 和 output 是字典,则暂时组合它们进行变换,否则,仅对 output 数据应用变换。
- 返回类型:
tuple
[Any
,Any
]
- monai.engines.utils.get_devices_spec(devices=None)[source]#
获取一个或多个设备的有效规范。如果 devices 为 None,则获取所有可用 CUDA 设备。如果 devices 是长度为零的结构,则返回单个 CPU 计算设备。在任何其他情况下,devices 将保持不变返回。
- 参数:
devices – 要请求的设备列表,None 表示所有 GPU 设备,[] 表示 CPU。
- 抛出:
RuntimeError – 当选择了所有 GPU(
devices=None
)但没有可用的 GPU 时。- 返回:
设备列表。
- 返回类型:
list of torch.device