使用GCN根据颗粒图像预测对应性能

news2024/12/28 18:46:05

之前做一个小实验写的代码,本想创建个git repo,想了想好像没必要,直接用篇博文记录一下吧。
对应资源 : https://download.csdn.net/download/rayso9898/87865298

0. 大纲

0.1 代码说明

  1. dataGeneration.py ->
    RSA生成n张图像,可以指定颗粒个数为m,半径为r。

  2. exGraph.py ->
    将一张sim仿真图像,转换成一个node_feature matrix, n * 4 : n个节点,3个特征-质心x y,等效直径,面积
    有r12 - r22共 6 种不同粒径的图像,组成字典,每种类别有100张图的node_feature

  3. create_dataset.py ->
    主要是 create_graph函数

  4. gcn.py ->
    gcn回归模型

  5. task_script ->
    提供模型训练 预测保存 loss查看等功能

0.2 数据说明

  1. sim_data.zip -> dataGeneration.py 生成
    r12-r22共6个文件夹,每个文件夹100张图像。

  2. sim_gdata.pt -> exGraph.py 生成
    sim_gdata structure:
    all - r12 r14 r16 r18 r20 r22
    r12 - img1 img2 … img100
    img1 - particle1 … paritle n
    particle1 - [x,y,dalimeter,area]

  3. dataset_img_property.pt -> create_dataset.py 生成
    x = [] # 图所对应的节点 num4 (x,y,r,area)
    pos = []# 图所对应节点的位置 num
    2 (x,y)
    y = stress[i] # 图所对应的力学性能数据
    x = torch.tensor(x,dtype=torch.float32)
    # 构造一张img的一个图
    y = torch.tensor([y],dtype=torch.float32)
    pos = torch.tensor(pos,dtype=torch.float32)
    edge_index = knn_graph(pos, k=5)
    g = Data(x=x, y=y, pos=pos,edge_index=edge_index)
    all.append(g)

  4. gcn.pt ->
    模型 后续加载即可使用

  5. loss_gcn.pt
    训练过程中的训练集和测试集loss变化

1. dataGeneration.py ->随机序列吸附法(RSA)生成颗粒图像

可以指定生成颗粒的个数和颗粒半径,以及生成的图像个数。
在这里插入图片描述
在这里插入图片描述

# -*- coding: utf-8 -*-
# @Time    : 2021/5/1 15:12
# @Author  : Ray_song
# @File    : dataGeneration.py
# @Software: PyCharm

import cv2
import math
import numpy as np

def calcDis(a,b,c,d):
    return math.sqrt(math.pow(a-c,2)+math.pow(b-d,2))

def generate(n,r,filename,save_path):
    '''
    基于RSA算法,生成一张图像
    :param n: 该图像中有n个颗粒
    :param r: 颗粒的半径大小为r
    :param filename: 图像的名字  如 1 2 3 ....
    :param save_path: 图像的保存路径 - 文件夹
    :return: None
    '''
    # 1.创建白色背景图片
    d = 512
    img = np.ones((d, d, 3), np.uint8) * 0
    #testing

    list = []
    center_x = np.random.randint(0, high=d)
    center_y = np.random.randint(0, high=d)
    list.append([center_x,center_y])

    # 随机半径与颜色
    radius = r
    color = (0, 255, 0)
    cv2.circle(img, (center_x, center_y), radius, color, -1)

    # 2.循环随机绘制实心圆
    for i in range(1, n):
        flag = True
        # 随机中心点
        while flag:
            center_x_new = np.random.randint(radius, high=d-radius)
            center_y_new = np.random.randint(radius, high=d-radius)
            panduan = True
            for per in list:
                Dis = calcDis(center_x_new, center_y_new, per[0], per[1])
                if Dis<2*r:
                    panduan = False
                    break
                else:
                    continue
            if panduan:
                list.append([center_x_new,center_y_new])
                cv2.circle(img, (center_x_new, center_y_new), radius, color, -1)
                break
    # 3.显示结果
    # cv2.imshow("img", img)
    # cv2.waitKey()
    # cv2.destroyAllWindows()

    # 4.保存结果
    root = f'{save_path}/{filename}.jpg'
    cv2.imwrite(root,img)

def main():
    # example1 : 随机生成 100张 颗粒个数为80 半径为20 的 图像
    save_path = 'sim_data/r20'
    for i in range(100):
        generate(80,20,i+1,save_path)

if __name__ == '__main__':
    main()

2. exGraph.py

将一张sim仿真图像,转换成一个node_feature matrix, n * 4 : n个节点,3个特征-质心x y,等效直径,面积
有r12 - r22共 6 种不同粒径的图像,组成字典,每种类别有100张图的node_feature
# -*- coding: utf-8 -*-
# @Time    : 2021/11/3 22:18
# @Author  : Ray_song
# @File    : exGraph.py
# @Software: PyCharm

import os
import torch

# 计算面积占比
def countArea(img):
    # 返回面积占比
    area = 0
    size = img.shape
    height,width = size[0],size[1]
    for i in range(height):
        for j in range(width):
            if img[i, j] == 255:
                area += 1
    total = height * width
    ratio = area / total
    return ratio


def distributionFit(img):
    '''
     计算一张照片中所有的 质心坐标xy、颗粒直径、面积,从一张图像中构建graph
    :param img: RSA生成的图像
    :return: 图像对应的graph,n*4, n个node, 4个feature
    '''
    import numpy as np
    import cv2 as cv
    img_color = cv.imread(img,1) # countors现实阶段使用
    img = cv.imread(img,0)

    # 遍历文件夹中所有的图像

    thresh_mode = 'THRESH_BINARY+THRESH_OTSU'
    # 阈值分割内容
    thresh_down = 127
    thresh_up = 256
    if thresh_mode == 'THRESH_BINARY':
        ret, thresh = cv.threshold(img, thresh_down, thresh_up, cv.THRESH_BINARY)
    elif thresh_mode == 'THRESH_BINARY_INV':
        ret, thresh = cv.threshold(img, thresh_down, thresh_up, cv.THRESH_BINARY_INV)
    elif thresh_mode == 'THRESH_TRUNC':
        ret, thresh = cv.threshold(img, thresh_down, thresh_up, cv.THRESH_TRUNC)
    elif thresh_mode == 'THRESH_TOZERO':
        ret, thresh = cv.threshold(img, thresh_down, thresh_up, cv.THRESH_TOZERO)
    elif thresh_mode == 'THRESH_TOZERO_INV':
        ret, thresh = cv.threshold(img, thresh_down, thresh_up, cv.THRESH_TOZERO_INV)
    elif thresh_mode == 'THRESH_BINARY+THRESH_OTSU':
        ret, thresh = cv.threshold(img, thresh_down, thresh_up, cv.THRESH_BINARY + cv.THRESH_OTSU)
    elif thresh_mode == 'No':
        thresh = img

    contours, hierarchy = cv.findContours(thresh, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_NONE)

    # 显示contours的代码,可以取消注释进行显示
    # cv.drawContours(img_color,contours,-1,(255,0,0),1)
    # cv.imshow('1',img_color)
    # cv.waitKey(0)

    graph = []
    for i in range(len(contours)):
        node = []
        cnt = contours[i]
        M = cv.moments(cnt)  # 计算图像矩
        # print(M)
        cx = int(M['m10'] / M['m00'])  # 重心的x坐标
        cy = int(M['m01'] / M['m00'])  # 重心的y坐标
        area = cv.contourArea(cnt)
        equi_diameter = np.sqrt(4 * area / np.pi)
        node.append(cx)
        node.append(cy)
        node.append(equi_diameter)
        node.append(area)
        graph.append(node)
    return graph

def exGraph():
    dirs = os.listdir('./sim_data')
    print(dirs)
    sim_data = {}
    list = ['r12','r14','r16','r18','r20','r22']
    index = 0
    for dir in dirs:
        path = r'sim_data'+'//'+dir
        # path = './sim_data/r12'
        imgs = os.listdir(path)
        print(len(imgs))
        con = []
        for i in range(len(imgs)):
            img_path = path+'/'+imgs[i]
            graph = distributionFit(img_path)
            con.append(graph)
        sim_data[list[index]] = con
        index = index+1
    torch.save(sim_data,'sim_gdata.pt')

if __name__ == '__main__':
    a = torch.load('./sim_gdata.pt')
    print(a)

3. create_dataset.py ->

主要是 create_graph函数
# -*- coding: utf-8 -*-
# @Time    : 2021/11/4 23:01
# @Author  : Ray_song
# @File    : create_dataset.py
# @Software: PyCharm

import torch
from torch_geometric.data import Data
from torch_cluster import knn_graph

'''

sim_gdata structure:
    all - r12  r14  r16  r18 r20 r22
        r12 - img1 img2 ... img100
            img1 - particle1 .. paritle n
                particle1 - [x,y,dalimeter,area]
'''

def create_graph():
    path = r'./sim_gdata.pt'
    data = torch.load(path)

    r = ['r12', 'r14', 'r16', 'r18', 'r20', 'r22']
    stress = [225, 230, 235, 240, 245, 250, 255]

    # 构造图
    all = []
    for i in range(len(data)):
        ri = r[i] # 字典 - 图数据的字典

        imgs = data[ri]

        y = stress[i]  # 图所对应的力学性能数据

        # 遍历ri中的每一张图
        for j in range(len(imgs)):

            x = [] # 图所对应的节点 num*4 (x,y,r,area)

            pos = []# 图所对应节点的位置 num*2  (x,y)

            img = imgs[j]

            # 遍历图中所有的节点
            for k in range(len(img)):
                # 单个节点的特征
                xi = []
                xi.append(img[k][0])
                xi.append(img[k][1])
                xi.append(img[k][2])
                xi.append(img[k][3])
                x.append(xi)

                # 位置信息(x,y)
                posi = []
                posi.append(img[k][0])
                posi.append(img[k][1])
                pos.append(posi)
            x = torch.tensor(x,dtype=torch.float32)
            # 构造一张img的一个图
            y = torch.tensor([y],dtype=torch.float32)
            pos = torch.tensor(pos,dtype=torch.float32)
            g = Data(x=x, y=y, pos=pos)
            g.edge_index = knn_graph(pos, k=5)
            all.append(g)
    torch.save(all,r'dataset_img_property.pt')
    return all

if __name__ == '__main__':
    create_graph()

4. gcn.py ->

gcn回归模型

# -*- coding: utf-8 -*-
# @Time    : 2021/11/5 19:13
# @Author  : Ray_song
# @File    : gcn.py
# @Software: PyCharm

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        # torch.manual_seed(12345)
        self.conv1 = GCNConv(4, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        # self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        # x = x.relu()
        # x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final regressor
        x = self.lin(x)

        return x

5. task_script ->

提供模型训练 预测保存 loss查看等功能
import numpy as np
import pandas as pd
import torch
# from torch_geometric.data import Data
# from torch_cluster import knn_graph
# import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D
# from torch_geometric.datasets import GeometricShapes
# from torch_cluster import knn_graph
from torch_geometric.loader import DataLoader
from gcn import GCN,GCNConv

def showDataset():
    path = r'./dataset_img_property.pt'
    dataset = torch.load(path)

    print()
    # print(f'Dataset: {dataset}:')
    print('====================')
    print(f'Number of graphs: {len(dataset)}')
    # print(f'Number of features: {dataset.num_features}')
    # print(f'Number of classes: {dataset.num_classes}')

    data = dataset[0]  # Get the first graph object.

    print()
    print(data)
    print('=============================================================')

    # Gather some statistics about the first graph.
    print(f'Number of nodes: {data.num_nodes}')
    print(f'Number of edges: {data.num_edges}')
    print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
    print(f'Has isolated nodes: {data.has_isolated_nodes()}')
    print(f'Has self-loops: {data.has_self_loops()}')
    # print(f'Is undirected: {data.is_undirected()}')

def split_train_val(path):
    dataset = torch.load(path)
    import random
    random.shuffle(dataset)
    num = len(dataset)
    # print(dataset,num)
    ratio = 0.8
    train_dataset = dataset[:int(num*ratio)]
    val_dataset = dataset[int(num*ratio):]
    print(len(train_dataset),len(val_dataset))
    return train_dataset,val_dataset

def begin():
    import torch
    path = r'./dataset_img_property.pt'
    dataset = torch.load(path)
    train_dataset,test_dataset = split_train_val(path)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    model = GCN(hidden_channels=64)
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,weight_decay=1e-8)
    criterion = torch.nn.MSELoss()

    def train():
        model.train()
        train_loss = 0
        for data in train_loader:  # Iterate in batches over the training dataset.
            out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
            loss = criterion(out.squeeze(-1), data.y)  # Compute the loss.
            train_loss = train_loss + loss.item()
            # print(f'train_loss:{train_loss/data.batch}')
            loss.backward()  # Derive gradients.
            optimizer.step()  # Update parameters based on gradients.
            optimizer.zero_grad()  # Clear gradients.

    def test(loader):
         model.eval()
         loss_ = 0
         count = 0
         for data in loader:  # Iterate in batches over the training/test dataset.
             out = model(data.x, data.edge_index, data.batch)
             loss = criterion(out.squeeze(-1), data.y)
             loss_ += loss.item()
             count += 1
         return loss_/count

    loss_t = []
    loss_v = []

    for epoch in range(1, 300):
        loss = []
        train()
        train_acc = test(train_loader)
        test_acc = test(test_loader)
        if test_acc<45:
            torch.save(model.state_dict(),'./gcn.pt')
        loss_t.append(train_acc)
        loss_v.append(test_acc)
        print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
        loss.append(loss_t)
        loss.append(loss_v)
        torch.save(loss,'loss_gcn.pt')

def draw_loss():
    import matplotlib.pyplot as plt
    loss = torch.load('loss_gcn.pt')
    plt.plot(loss[0])
    plt.plot(loss[1])
    plt.show()

def prediction_all():
    from sklearn.metrics import mean_squared_error
    prediction_classes = 1
    g_path = r'./dataset_img_property.pt'
    save_path = r'./prediction.csv'
    g_data = torch.load(g_path)

    model = GCN(hidden_channels=64)
    model_weight = r'./gcn.pt'
    model.load_state_dict(torch.load(model_weight))
    model.eval()
    # begin predicting...
    res = []
    val_data = g_data[:600]
    pre_dataloader = DataLoader(val_data, batch_size=1)
    y = []
    with torch.no_grad():
        for item in pre_dataloader:
            # predict class
            y.append(item.y.item())
            output = torch.squeeze(model(item.x,item.edge_index,item.batch))  # 将batch维度压缩掉
            prediction = output.numpy()
            if prediction_classes > 1:
                prediction = list(prediction)
            res.append(prediction)

    res = np.array(res)
    y = np.array(y)
    acc = mean_squared_error(res,y)
    print('acc',acc)
    res = pd.DataFrame(res)
    res.to_csv(save_path,header=None,index=None)

if __name__ == '__main__':
    prediction_all()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/611444.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Springboot】| 阿里云发送短信验证码,你会了吗?

目录 &#x1f981; 题外话&#x1f981; 需要准备的东西&#x1f981; 进入主题1. 添加依赖2. 配置yaml文件3. 创建阿里云客户端4. 编写发送短信方法5. 完整代码展示6. 测试 &#x1f981; 场景实操1. 编写生成验证码工具类2. 保存到redis操作3. 编写发送验证码短信4. 发送登录…

大数据:spark共享广播变量,累加器

大数据&#xff1a;共享变量 2022找工作是学历、能力和运气的超强结合体&#xff0c;遇到寒冬&#xff0c;大厂不招人&#xff0c;可能很多算法学生都得去找开发&#xff0c;测开 测开的话&#xff0c;你就得学数据库&#xff0c;sql&#xff0c;oracle&#xff0c;尤其sql要学…

三面阿里被挂,竟获内推名额,历经 5 面拿下口碑 offer...

每一个互联网人心中都有一个大厂梦&#xff0c;百度、阿里巴巴、腾讯是很多互联网人梦寐以求的地方&#xff0c;而我也不例外。但是&#xff0c;BAT 等一线互联网大厂并不是想进就能够进的&#xff0c;它对人才的技术能力和学历都是有一定要求的&#xff0c;所以除了学历以外&a…

STM32单片机WIFI物联网厨房燃气安全系统超声波人员检测MQ4燃气报警

实践制作DIY- GC0140-WIFI物联网厨房燃气安全系统 基于STM32单片机设计---WIFI物联网厨房燃气安全系统 二、功能介绍&#xff1a; 硬件组成&#xff1a;STM32F103C系列最小系统继电器模拟阀门MQ-4然气传感器HSR04超声波测距LCD1602显示器ESP8266-WIFI模块蜂鸣器多个按键 1.有…

MySQL命令行速查手册(持续更新ing...)

诸神缄默不语-个人CSDN博文目录 最近更新时间&#xff1a;2023.6.5 最早更新时间&#xff1a;2023.6.5 每个命令都以;作为结尾&#xff08;以下localhost都可以替换成实际IP地址&#xff09;&#xff08;和’的区别应该不大&#xff09;用户管理 修改密码&#xff1a;ALTER U…

如何使用Facebook Business Suite来管理你的FB和Ins商业账户

Facebook Business Suite是Facebook推出的一种强大的数字营销工具&#xff0c;可帮助企业轻松管理其在Facebook和Instagram上的商业账户。该工具集成了多种功能&#xff0c;提供了一种简单、直观的方式来管理你的社交媒体营销活动。 在本文中&#xff0c;我们将详细介绍如何优化…

深眸科技基于技术与人才优势,创新研发机器视觉系统赋能工业生产

随着人工智能技术加速进入生产生活&#xff0c;机器视觉系统作为工业发展的刚需&#xff0c;凭借着能够为机器提供视觉&#xff0c;并在众多场景实现柔性化生产应用的能力&#xff0c;逐步被接受和普及&#xff0c;并在工业生产领域发挥巨大作用。 深眸科技作为国家高新技术企…

物流货运车货匹配平台源码

网络货运平台具有较强的信息数据交互和处理能力&#xff0c;能够对托运人&#xff0c;平台运营人&#xff0c;实际承运人&#xff0c;驾驶员的相关方的交易&#xff0c;运输&#xff0c;结算等全过程进行透明&#xff0c;动态的管理&#xff0c;该平台由托运人、实际承运人、司…

ControlNet: Adding Conditional Control to Text-to-Image Diffusion Models

Adding Conditional Control to Text-to-Image Diffusion Models (Paper reading) Lvmin Zhang and Maneesh Agrawala, Stanford University, arXiv, Cited:113, Code, Paper 1. 前言 我们提出了一种名为ControlNet的神经网络结构&#xff0c;用于控制预训练的大规模扩散模型…

element中table的列标题自定义

一、需求 工作中要求表格table中的某一列标题为红色如图 二、方案一 使用el-table-column自带的:render-header"renderHeader"函数 render-header列标题 Label 区域渲染使用的 FunctionFunction(h, { column, $index })—— 使用有点像v-html插入代码片段&#xf…

PubChem介绍及API及PubChempy

PubChem 【官网 https://pubchem.ncbi.nlm.nih.gov/】 简介 PubChem is the world’s largest collection of freely accessible chemical information. Search chemicals by name, molecular formula, structure, and other identifiers. Find chemical and physical proper…

casbin基于RBAC的权限管理案例

在RBAC模型中新定义了角色和继承关系&#xff0c;用户可以通过角色区分不同的权限&#xff0c;继承不同的角色时用户有多个权限。 [role_definition] g _, _ g2 _, _g 是一个 RBAC系统, g2 是另一个 RBAC 系统。 _, _表示角色继承关系的前项和后项&#xff0c;即前项继承后项…

局部探索测试的要素

局部探索测试的要素 局部探索测试是软件测试过程中的一种方法&#xff0c;旨在发现一个系统、软件或应用程序的局部缺陷和问题。局部探索测试不是全面测试&#xff0c;而是通过对特定功能、模块或环节进行测试来检查其中潜在的缺陷&#xff0c;从而提高软件的质量和可靠性。 局…

【白话机器学习系列】白话Broadcasting

白话 Broadcasting 文章目录 什么是 BroadcastingBroadcasting 的规则逐元素操作向量与标量运算矩阵与向量运算行向量列向量 张量与向量运算张量与矩阵运算 矩阵与张量的点积总结 什么是 Broadcasting 在 《白话张量》 中我们讲过&#xff0c;张量之间进行运算需要满足一定的…

Hadoop之MapReduce概述

MapReduce概述 MapReduce定义MapReduce优缺点MapReduce核心思想MapReduce进程MapReduce编程规范MapTask并行度决定机制ReduceTask并行度决定机制mapreduce中job的提交流程MapReduce工作流程shuffle机制分区partition数据清洗&#xff08;ETL&#xff09;进一步分析MapTask和Red…

Jenkins+RF持续集成测试(二) 定时更新SVN完成构建

在上一篇中讲了Jenkins的安装&#xff0c;这篇将介绍 定时从SVN库中&#xff08;git库与之类似&#xff0c;这里就不具体介绍了&#xff0c;有需要自己折腾&#xff09;拉取最新的测试脚本&#xff0c;完成jenkins的定时构建。这是我们做自动化测试最基本的环节&#xff0c;每天…

【Linux】还在用top命令?可以试试atop工具,信息一目了然,运维工程师的新选择

atop使用 Linux以其稳定性&#xff0c;越来越多地被用作服务器的操作系统(当然&#xff0c;有人会较真地说一句&#xff1a;Linux只是操作系统内核:)。但使用了Linux作为底层的操作系统&#xff0c;是否我们就能保证我们的服务做到7*24地稳定呢&#xff1f;非也&#xff0c;要…

06.05

1.二进制求和 给你两个二进制字符串 a 和 b &#xff0c;以二进制字符串的形式返回它们的和。 考虑一个最朴素的方法&#xff1a;先将 aaa 和 bbb 转化成十进制数&#xff0c;求和后再转化为二进制数。利用 Python 和 Java 自带的高精度运算&#xff0c;我们可以很简单地写出这…

发现问题更全面,减少测试成本:WEB自动化测试的价值分析!

目录 前言&#xff1a; 一、WEB自动化测试的价值 1. 提高测试效率 2. 提高软件的质量 3. 减少测试成本 二、WEB自动化测试的瓶颈 1. 可维护性差 2. 兼容性问题 3. 比手工测试慢 三、代码示例 四、总结 前言&#xff1a; 自动化测试是软件开发中必不可少的一环&…

shell简单命令

命令入门&#xff1a; [rootlocalhost ~]# #/root [jinxflocalhost ~]$ #/home/jinxf 用户名主机名 当前目录 #系统权限 $普通权限 命令格式 命令 选项 参数&#xff08;三者之间要有空格&#xff0c;区分大小写&#xff09; command [-options] [args]…