anomalib1.0学习纪实-续3:结合python lightning理思路

news2024/11/28 16:38:25

一、python lightning

python lightning是个好东西,但不见得那么友好。

GPT4给我讲解了他的用法:

 

 

二、anomalib的思路

 1、 创建一个Lightning Module。

首先,在src\anomalib\models\components\base\anomaly_module.py中, 

class AnomalyModule(pl.LightningModule, ABC):
    """AnomalyModule to train, validate, predict and test images.

    Acts as a base class for all the Anomaly Modules in the library.
    """

    def __init__(self) -> None:
        super().__init__()
        logger.info("Initializing %s model.", self.__class__.__name__)

        self.save_hyperparameters()
        self.model: nn.Module
        self.loss: nn.Module
        self.callbacks: list[Callback]

        self.image_threshold: BaseThreshold
        self.pixel_threshold: BaseThreshold

        self.normalization_metrics: Metric

        self.image_metrics: AnomalibMetricCollection
        self.pixel_metrics: AnomalibMetricCollection

    def forward(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> Any:  # noqa: ANN401
        """Perform the forward-pass by passing input tensor to the module.

        Args:
            batch (dict[str, str | torch.Tensor]): Input batch.
            *args: Arguments.
            **kwargs: Keyword arguments.

        Returns:
            Tensor: Output tensor from the model.
        """
        del args, kwargs  # These variables are not used.

        return self.model(batch)

    def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT:
        """To be implemented in the subclasses."""
        raise NotImplementedError

    def predict_step(
        self,
        batch: dict[str, str | torch.Tensor],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> STEP_OUTPUT:
        """Step function called during :meth:`~lightning.pytorch.trainer.Trainer.predict`.

        By default, it calls :meth:`~lightning.pytorch.core.lightning.LightningModule.forward`.
        Override to add any processing logic.

        Args:
            batch (Any): Current batch
            batch_idx (int): Index of current batch
            dataloader_idx (int): Index of the current dataloader

        Return:
            Predicted output
        """
        del batch_idx, dataloader_idx  # These variables are not used.

        return self.validation_step(batch)
。。。以下省略

定义了一堆类似 def forward的虚函数,都有待于他的之类去实现。

这里就要说一下,在python中,也有类似c++中虚函数的概念吗?

GPT给了我回答,是的。只不过,在c++中,虚函数需要明确指出,但是在python中,在Python中实现类似C++虚函数的行为,主要依靠方法重写(Override)。当子类重写了父类的方法时,无论是通过对象直接调用该方法,还是通过父类的接口调用,实际执行的都是子类中重写的方法。这使得我们可以在子类中改变或扩展在父类中定义的行为,这与C++中虚函数的目的是一致的。

所以我们看一下,在我们自己搞的Ddad类中,如何重写了AnomalyModule类的一些方法。

在src\anomalib\models\image\ddad\lightning_model.py中,

class Ddad(MemoryBankMixin, AnomalyModule):
    """Ddad: a Patch Distribution Modeling Framework for Anomaly Detection and Localization.

    

。。。省略

    @staticmethod
    def configure_optimizers() -> None:
        """Ddad doesn't require optimization, therefore returns no optimizers."""
        return
    

    def on_train_epoch_start (self)-> None:
        print("----------------------------------------on_train_epoch_start")


    def prepare_data(self) -> None:
         print("----------------------------------------prepare_data")

    def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> None:
        print("---------------training_step")
        """Perform the training step of Ddad. For each batch, hierarchical features are extracted from the CNN.

        Args:
            batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask
            args: Additional arguments.
            kwargs: Additional keyword arguments.

        Returns:
            Hierarchical feature map
        """
        
        
        del args, kwargs  # These variables are not used.
        
        self.model.feature_extractor.eval()
        embedding = self.model(batch["image"])

        self.embeddings.append(embedding.cpu())
        

    def fit(self) -> None:

        """Fit a Gaussian to the embedding collected from the training set."""
        
        logger.info("Aggregating the embedding extracted from the training set.")
        embeddings = torch.vstack(self.embeddings)

        logger.info("Fitting a Gaussian to the embedding collected from the training set.")
        self.stats = self.model.gaussian.fit(embeddings)
        

    def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT:
        """Perform a validation step of PADIM.

        Similar to the training step, hierarchical features are extracted from the CNN for each batch.

        Args:
            batch (dict[str, str | torch.Tensor]): Input batch
            args: Additional arguments.
            kwargs: Additional keyword arguments.

        Returns:
            Dictionary containing images, features, true labels and masks.
            These are required in `validation_epoch_end` for feature concatenation.
        """

        
        del args, kwargs  # These variables are not used.

        batch["anomaly_maps"] = self.model(batch["image"])
        return batch
        

        return 0

你看,重写了training_step、fit、validation_step等重要函数。

 2、 准备数据

在src\anomalib\data\base\dataset.py中,定义了Dataset的一个之类AnomalibDataset,如下:

 

class AnomalibDataset(Dataset, ABC):
    """Anomalib dataset.

    Args:
        task (str): Task type, either 'classification' or 'segmentation'
        transform (A.Compose): Albumentations Compose object describing the transforms that are applied to the inputs.
    """

    def __init__(self, task: TaskType, transform: A.Compose) -> None:
        super().__init__()
        self.task = task
        self.transform = transform
        self._samples: DataFrame

    def __len__(self) -> int:
        """Get length of the dataset."""
        return len(self.samples)

    def subsample(self, indices: Sequence[int], inplace: bool = False) -> "AnomalibDataset":
        """Subsamples the dataset at the provided indices.

        Args:
            indices (Sequence[int]): Indices at which the dataset is to be subsampled.
            inplace (bool): When true, the subsampling will be performed on the instance itself.
                Defaults to ``False``.
        """
        assert len(set(indices)) == len(indices), "No duplicates allowed in indices."
        dataset = self if inplace else copy.deepcopy(self)
        dataset.samples = self.samples.iloc[indices].reset_index(drop=True)
        return dataset

    @property
    def is_setup(self) -> bool:
        """Checks if setup() been called."""
        return hasattr(self, "_samples")

    @property
    def samples(self) -> DataFrame:
        """Get the samples dataframe."""
        if not self.is_setup:
            msg = "Dataset is not setup yet. Call setup() first."
            raise RuntimeError(msg)
        return self._samples

    @samples.setter
    def samples(self, samples: DataFrame) -> None:
        """Overwrite the samples with a new dataframe.

        Args:
            samples (DataFrame): DataFrame with new samples.
        """
        # validate the passed samples by checking the
        assert isinstance(samples, DataFrame), f"samples must be a pandas.DataFrame, found {type(samples)}"
        expected_columns = _EXPECTED_COLUMNS_PERTASK[self.task]
        assert all(
            col in samples.columns for col in expected_columns
        ), f"samples must have (at least) columns {expected_columns}, found {samples.columns}"
        assert samples["image_path"].apply(lambda p: Path(p).exists()).all(), "missing file path(s) in samples"

        self._samples = samples.sort_values(by="image_path", ignore_index=True)

    @property
    def has_normal(self) -> bool:
        """Check if the dataset contains any normal samples."""
        return 0 in list(self.samples.label_index)

    @property
    def has_anomalous(self) -> bool:
        """Check if the dataset contains any anomalous samples."""
        return 1 in list(self.samples.label_index)

    def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
        """Get dataset item for the index ``index``.

        Args:
            index (int): Index to get the item.

        Returns:
            dict[str, str | torch.Tensor]: Dict of image tensor during training. Otherwise, Dict containing image path,
                target path, image tensor, label and transformed bounding box.
        """
        image_path = self._samples.iloc[index].image_path
        mask_path = self._samples.iloc[index].mask_path
        label_index = self._samples.iloc[index].label_index

        image = read_image(image_path)
        item = {"image_path": image_path, "label": label_index}

        if self.task == TaskType.CLASSIFICATION:
            transformed = self.transform(image=image)
            item["image"] = transformed["image"]
        elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION):
            # Only Anomalous (1) images have masks in anomaly datasets
            # Therefore, create empty mask for Normal (0) images.

            mask = np.zeros(shape=image.shape[:2]) if label_index == 0 else cv2.imread(mask_path, flags=0) / 255.0
            mask = mask.astype(np.single)

            transformed = self.transform(image=image, mask=mask)

            item["image"] = transformed["image"]
            item["mask_path"] = mask_path
            item["mask"] = transformed["mask"]

            if self.task == TaskType.DETECTION:
                # create boxes from masks for detection task
                boxes, _ = masks_to_boxes(item["mask"])
                item["boxes"] = boxes[0]
        else:
            msg = f"Unknown task type: {self.task}"
            raise ValueError(msg)

        return item

    def __add__(self, other_dataset: "AnomalibDataset") -> "AnomalibDataset":
        """Concatenate this dataset with another dataset.

        Args:
            other_dataset (AnomalibDataset): Dataset to concatenate with.

        Returns:
            AnomalibDataset: Concatenated dataset.
        """
        assert isinstance(other_dataset, self.__class__), "Cannot concatenate datasets that are not of the same type."
        assert self.is_setup, "Cannot concatenate uninitialized datasets. Call setup first."
        assert other_dataset.is_setup, "Cannot concatenate uninitialized datasets. Call setup first."
        dataset = copy.deepcopy(self)
        dataset.samples = pd.concat([self.samples, other_dataset.samples], ignore_index=True)
        return dataset

    def setup(self) -> None:
        """Load data/metadata into memory."""
        if not self.is_setup:
            self._setup()
        assert self.is_setup, "setup() should set self._samples"

    @abstractmethod
    def _setup(self) -> DataFrame:
        """Set up the data module.

        This method should return a dataframe that contains the information needed by the dataloader to load each of
        the dataset items into memory.

        The DataFrame must, at least, include the following columns:
            - `split` (str): The subset to which the dataset item is assigned (e.g., 'train', 'test').
            - `image_path` (str): Path to the file system location where the image is stored.
            - `label_index` (int): Index of the anomaly label, typically 0 for 'normal' and 1 for 'anomalous'.
            - `mask_path` (str, optional): Path to the ground truth masks (for the anomalous images only).
            Required if task is 'segmentation'.

        Example DataFrame:
            +---+-------------------+-----------+-------------+------------------+-------+
            |   | image_path        | label     | label_index | mask_path        | split |
            +---+-------------------+-----------+-------------+------------------+-------+
            | 0 | path/to/image.png | anomalous | 1           | path/to/mask.png | train |
            +---+-------------------+-----------+-------------+------------------+-------+

        Note:
            The example above is illustrative and may need to be adjusted based on the specific dataset structure.
        """
        raise NotImplementedError

重写了__len__、__getitem__等重要函数。最终,通过def _setup(self) -> DataFrame:

获得了DataFrame如下

最后,在src\anomalib\data\base\datamodule.py的AnomalibDataModule 类中,

 后来,这个train_dataloader就被自动调用了。至于怎么被自动调用的,我还没看明白

欢迎留言指点一下。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1456022.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Java SSM框架实现电影售票系统项目【项目源码】

基于java的SSM框架实现电影售票系统演示 SSM框架 当今流行的“SSM组合框架”是Spring SpringMVC MyBatis的缩写,受到很多的追捧,“组合SSM框架”是强强联手、各司其职、协调互补的团队精神。web项目的框架,通常更简单的数据源。Spring属于…

生成式 AI - Diffusion 模型的数学原理(3)

来自 论文《 Denoising Diffusion Probabilistic Model》(DDPM) 论文链接: https://arxiv.org/abs/2006.11239 Hung-yi Lee 课件整理 文章目录 一、图像生成模型本质上的共同目标二、最大似然估计三、和VAE的关联四、概率计算 一、图像生成模…

LeetCode.590. N 叉树的后序遍历

题目 590. N 叉树的后序遍历 分析 我们之前有做过LeetCode的 145. 二叉树的后序遍历,其实对于 N 叉树来说和二叉树的思路是一模一样的。 二叉树的后序遍历是【左 右 根】 N叉树的后序遍历顺序是【孩子 根】,你可以把二叉树的【左 右 根】想象成【孩子…

MySQL为什么改进LRU算法?

普通LRU算法 LRU = Least Recently Used(最近最少使用): 就是末尾淘汰法,新数据从链表头部加入,释放空间时从末尾淘汰. 当要访问某个页时,如果不在Buffer Pool,需要把该页加载到缓冲池,并且把该缓冲页对应的控制块作为节点添加到LRU链表的头部。当要访问某个页时,如果在…

js设计模式:代理模式

作用: 创建代理的数据来复刻对原有数据的操作,并且可以添加自己的逻辑 vue中的data就是用的代理模式,比较经典 示例: let proxyFun (obj)>{return new Proxy(obj,{get:(obj,prop,value)>{return obj[prop]},set:(obj,prop,value)>{obj[prop] valuereturn true}})…

从阿里宜搭到吉客云通过接口配置打通数据

从阿里宜搭到吉客云通过接口配置打通数据 来源系统:阿里宜搭 宜搭是阿里巴巴自研的低代码应用搭建平台,传统情况下需要2周才能搭建完成的应用,用宜搭2小时就可完成。宜搭于2019年3月上线,用户可以在可视化界面上以拖拉拽的方式编辑和配置页面…

【漏洞复现-通达OA】通达OA report_bi存在前台SQL注入漏洞

一、漏洞简介 通达OA(Office Anywhere网络智能办公系统)是由北京通达信科科技有限公司自主研发的协同办公自动化软件,是与中国企业管理实践相结合形成的综合管理办公平台。通达OA为各行业不同规模的众多用户提供信息化管理能力,包括流程审批、行政办公、日常事务、数据统计…

二叉树和N叉数的遍历合集

二叉树和N叉数的遍历合集 二叉树的前序遍历 前序遍历的顺序是 根 -> 左儿子 -> 右儿子&#xff0c;所以我们直接按照这个顺序 dfs 就行 dfs class Solution { public:vector<int> preorderTraversal(TreeNode* root) {vector<int> res;function<void(…

如何在极低成本硬件上落地人工智能算法 —— 分布式AI

一、背景 分布式AI的发展前景非常广阔&#xff0c;随着5G、6G等高速网络通信技术的普及和边缘计算能力的提升&#xff0c;以及AI算法和硬件的不断优化进步&#xff0c;分布式AI将在多个领域展现出强大的应用潜力和市场价值&#xff1a; 1. **物联网&#xff08;IoT&#xff0…

unity学习(20)——客户端与服务器合力完成注册功能(2)调试注册逻辑

接着上一节的问题&#xff0c;想办法升级成具备数据库功能的服务器&#xff0c;这个是必须的。 至少在初始化要学会把文件转换为session&#xff0c;新知识&#xff0c;有挑战的。 现在是从LoginHandler.cs跳到了AccountBiz.cs的create&#xff0c;跳度还是很大的。 create函…

宝塔安装MySQL、设置MySQL密码、设置navicat连接

1、登录宝塔面板进行安装 2、设置MySQL连接密码 3、安装好了设置navicat连接 登录MySQL [roothecs-394544 ~]# mysql -uroot -p Enter password: 切换到MySQL数据 mysql> use mysql Database changed mysql> 查询用户信息 mysql> select host,user from user; ---…

一起玩儿物联网人工智能小车(ESP32)——63 SD和TF卡模块的使用

摘要&#xff1a;本文介绍SD和TF卡模块的使用方法 前面介绍了非易失性存储的使用方法&#xff0c;由于空间和本身只支持键值对的限制&#xff0c;非易失性存储只适用于少量数据的记录。而不适用于各种声音、图片、大量数据等情况的使用。这时候就需要有文件系统或者更大容量存…

无人机数据链技术,无人机数据链路系统技术详解,无人机数传技术

早期的无人机更多的为军事应用服务&#xff0c;如军事任务侦查等&#xff0c;随着技术和社会的发展&#xff0c;工业级无人机和民用无人机得到快速的发展&#xff0c;工业级无人机用于农业植保、地理测绘、电力巡检、救灾援助等&#xff1b;民用无人机用于航拍、物流等等领域。…

Unity之闪电侠大战蓝毒兽

目录 &#x1f3a8;一、创建地形 &#x1f3ae;二、创建角色 &#x1f3c3;2.1 动画 &#x1f3c3;2.2 拖尾 &#x1f3c3;2.3 角色控制 ​&#x1f3c3;2.4 技能释放 &#x1f3c3;2.5 准星 &#x1f4f1;三、创建敌人 &#x1f432;3.1 选择模型 &#x1f432;3.…

GitLab安装配置

一、GitLab的简介 GitLab是开源的代码托管平台&#xff0c;提供版本控制功能、代码审查、持续集成等工具&#xff0c;帮助团队协作开发软件项目。用户可以创建仓库存储代码&#xff0c;管理问题追踪&#xff0c;部署自动化流程等。 二、GitLab的安装 1、Rocky_Linux 下载安装 …

使用RK3588开发板使用 SFTP 互传-windows与开发板互传

MobaXterm 软件网盘下载路径&#xff1a;“iTOP-3588 开发板\02_【iTOP-RK3588 开发板】开发资料\04_iTOP-3588 开发板所需 PC 软件&#xff08;工具&#xff09;\02-MobaXterm”。 打开 MobaXterm 创建一个 SFTP 会话&#xff0c;如下图所示&#xff1a; 输入密码 topeet 进入…

《UE5_C++多人TPS完整教程》学习笔记19 ——《P20 我们子系统的回调函数(Callbacks to Our Subsystem)》

本文为B站系列教学视频 《UE5_C多人TPS完整教程》 —— 《P20 我们子系统的回调函数&#xff08;Callbacks to Our Subsystem&#xff09;》 的学习笔记&#xff0c;该系列教学视频为 Udemy 课程 《Unreal Engine 5 C Multiplayer Shooter》 的中文字幕翻译版&#xff0c;UP主&…

外汇天眼:小白开始实盘之前,必须知道的7件事

在进行外汇交易时&#xff0c;保持松弛的心态和学习外汇知识是一件很重要的事情&#xff0c;但对于缺乏交易经验的交易小白来说&#xff0c;想保持松弛的心态和学习外汇知识比较困难&#xff0c;考虑到这一点&#xff0c;天眼给大家总结了7件在交易前必须知道的事情。 1、实现财…

基于PSO优化的GRU多输入分类(Matlab)粒子群优化门控循环单元神经网络分类预测

目录 一、程序及算法内容介绍&#xff1a; 基本内容&#xff1a; 亮点与优势&#xff1a; 二、实际运行效果&#xff1a; 三、部分程序&#xff1a; 四、完整代码数据分享下载&#xff1a; 一、程序及算法内容介绍&#xff1a; 基本内容&#xff1a; 本代码基于Matlab平台…

167基于matlab的根据《液体动静压轴承》编写的有回油槽径向静压轴承的程序

基于matlab的根据《液体动静压轴承》编写的有回油槽径向静压轴承的程序&#xff0c;可显示承载能力、压强、刚度及温升等图谱.程序已调通&#xff0c;可直接运行。 167 显示承载能力、压强、刚度及温升 (xiaohongshu.com)https://www.xiaohongshu.com/explore/65d212b200000000…