手写svm primal form形式

news2025/3/15 3:35:30

svm.py

import numpy as np

class SVM:
    def __init__(self,C=1.0,lr=0.01,batch_size=32,epochs=100):
        self.C=C
        self.lr=lr
        self.batch_size=batch_size
        self.epochs=epochs
        self.w=None
        self.b=0.0
        self.epoch=0

    #计算最高得分和对应w,b
    def fit(self,X,y,X_val=None,y_val=None):
        sample,feature=X.shape
        self.w=np.zeros(feature)
        self.b=0.0
        best_score=-np.inf
        #best_w=self.w  错误
        best_w=self.w.copy()
        best_b=self.b
        best_epoch=0

        for epoch in range(self.epochs):
            #打乱顺序
            shu_index=np.random.permutation(sample)
            shu_X=X[shu_index]
            shu_y=y[shu_index]
            for i in range(0,sample,self.batch_size):
                end=i+self.batch_size
                #第x个批量
                x_batch=shu_X[i:end]
                y_batch=shu_y[i:end]
                dw,db=self.com_gradient(x_batch,y_batch)
                self.w-=self.lr*dw
                self.b-=self.lr*db
            if X_val is not None and y_val is not None:
                y_pred=self.predict(X_val)
                #np.mean(x,y)错误
                score=np.mean(y_pred==y_val)
                if score>best_score:
                    best_score=score
                    best_w=self.w.copy()
                    best_b=self.b
                    #best_epoch=self.epoch 错误
                    best_epoch=epoch
                print(f"第{epoch+1}轮训练,准确率为:{score:.4f}")
        if X_val is not None and y_val is not None:
            self.w=best_w
            self.b=best_b
            self.epoch=best_epoch
        
    def com_gradient(self,X_batch,y_batch):
        n=X_batch.shape[0]
        dw_hinge=np.zeros_like(self.w)
        db_hinge=0.0
        for i in range(n):
            xi=X_batch[i]
            yi=y_batch[i]
            #margin=yi*np.dot(xi,self.w)+self.b 注意是xi
            margin=yi*np.dot(xi,self.w)+self.b
            if margin<1:
                dw_hinge+=-yi*xi
                db_hinge+=-yi
            #注意 是计算完n个样本的dw_hinge才算dw
        dw=self.w+(self.C/n)*dw_hinge
        db=(self.C/n)*db_hinge
        return dw,db

    def predict(self,X):
        linear=np.dot(X,self.w)+self.b
        return np.sign(linear)
    
    def evaluate(self,X,y):
        y_true=y
        y_pre=self.predict(X)
        #注意是标签是-1和1,而非0,1
        tp=np.sum((y_pre==1)&(y_true==1))
        fp=np.sum((y_pre==1)&(y_true==-1))
        tn=np.sum((y_pre==-1)&(y_true==-1))
        fn=np.sum((y_pre==-1)&(y_true==1))
        accuracy=(tp+tn)/(tp+tn+fp+fn)
        precision=tp/(tp+fp) if tp+fp!=0 else 0
        recall=tp/(tp+fn) if tp+fn!=0 else 0
        f1=(2*precision*recall)/(precision+recall) if precision+recall!=0 else 0
        #注意字典的键值对xx:xx
        return{
            'accuracy':accuracy,
            'precision':precision,
            'recall':recall,
            'f1':f1
        }
    

    def save_weight(self,filename):
        #注意w和b要保存进文件
        np.savez(filename,w=self.w,b=self.b,epoch=self.epoch,C=self.C,lr=self.lr,batch_size=self.batch_size,epochs=self.epochs)

    @classmethod
    def load_weight(cls,filename):
        data=np.load(filename)
        svm=cls(C=data['C'],lr=data['lr'],batch_size=data['batch_size'],epochs=data['epochs'])
        svm.w=data['w']
        svm.b=data['b']
        svm.epoch=data['epoch']
        return svm


    

train.py

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from joblib import dump
from svm import SVM

data=datasets.load_breast_cancer()
X=data.data
y=data.target
y=np.where(y==0,-1,1)

X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2,random_state=42)
X_train,X_val,y_train,y_val=train_test_split(X_train,y_train,test_size=0.25,random_state=42)

scaler=StandardScaler()
X_train=scaler.fit_transform(X_train)
X_val=scaler.transform(X_val)
X_test=scaler.transform(X_test)
dump(scaler,'scaler.joblib')



#最佳准确率以及最佳模型
best_accu=-np.inf
best_model=None

C_values=[0.1,1,10,100]

for C in C_values:
    print(f"开始C:{C}")
    model=SVM(C=C,lr=0.01,batch_size=32,epochs=100)
    model.fit(X_train,y_train,X_val,y_val)
    #注意要评估X_val,y_val的得分,传参
    m_metrics=model.evaluate(X_val,y_val)
    if m_metrics['accuracy']>best_accu:
        #注意m_metrics['accuracy']传参
        best_accu=m_metrics['accuracy']
        best_model=model
best_model.save_weight("best_weight.npz")
print(f"最优C:{best_model.C}") 
print(f"最优C对应的epoch:{best_model.epoch+1}")   

test.py

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from svm import SVM
from joblib import load

data=datasets.load_breast_cancer()
X=data.data
y=data.target
y=np.where(y==0,-1,1)

_,X_test,_,y_test=train_test_split(X,y,test_size=0.2,random_state=42)

scaler=load('scaler.joblib')
X_test=scaler.transform(X_test)

model=SVM.load_weight('best_weight.npz')
print(f"C:{model.C}")
print(f"最优C的epoch:{model.epoch+1}")
t_metrics=model.evaluate(X_test,y_test)

print(f"Accuracy:{t_metrics['accuracy']:.4f},Precision:{t_metrics['precision']:.4f},Recall:{t_metrics['recall']:.4f},f1分数:{t_metrics['f1']:.4f}")

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2315207.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VBA+FreePic2Pdf 找出没有放入PDF组合的单个PDF工艺文件

设计部门针对某个项目做了一个工艺汇总报告&#xff0c;原先只要几十个工艺文件&#xff0c;组合成一个PDF&#xff0c;但后来要求要多放点PDF进去&#xff0c;但工艺文件都混在一起又不知道哪些是重复的&#xff0c;找上我让我帮忙处理一下&#xff0c;我开始建议让她重新再组…

计网面试准备

正确理解网络数据传输过程 同一路由器的不同接口属于不同局域网&#xff0c;广播只能在同一个局域网

【数据分享】1999—2023年我国地级市社会消费品零售总额和年末金融机构存贷款余额(Shp/Excel格式)

在之前的文章中&#xff0c;我们分享过基于2000-2024年《中国城市统计年鉴》整理的1999-2023年地级市的人口相关数据、染物排放和环境治理相关数据和房地产投资情况和商品房销售面积相关指标数据&#xff08;均可查看之前的文章获悉详情&#xff09;&#xff01; 本次我们分享…

PHP批量去除Bom头的方法

检查的代码&#xff1a; <?php$dir __DIR__; $files new RecursiveIteratorIterator(new RecursiveDirectoryIterator($dir));foreach ($files as $file) {if ($file->isFile() && pathinfo($file, PATHINFO_EXTENSION) php) {$content file_get_contents(…

字节攻克关键技术,大模型训练效率提升1.7倍,成本节省40%

近日&#xff0c;字节豆包大模型团队开源针对 MoE 架构的关键优化技术COMET&#xff0c;该技术可将大模型训练效率提升1.7倍&#xff0c;成本节省40%。据悉&#xff0c;该技术已实际应用于字节的万卡集群训练&#xff0c;累计帮助节省了数百万 GPU 小时训练算力。 MoE&#xff…

[Pytorch报错问题解决]AttributeError: ‘nn.Sequential‘ object has no attribute ‘append‘

问题 运行深度学习代码的时候遇到了以下报错问题&#xff1a; Traceback (most recent call last):File "/home/anaconda3/envs/Text2HOI/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_contextreturn func(*args, **kwargs)Fi…

基于威胁的安全测试值得关注,RASP将大放异彩

2‍021年7月21日&#xff0c;由中国信息通信研究院&#xff08;CAICT&#xff09;指导、悬镜安全主办、腾讯安全协办的中国首届DevSecOps敏捷安全大会&#xff08;DSO 2021&#xff09;在北京圆满举办。大会以“安全从供应链开始”为主题&#xff0c;寓意安全基础决定“上层建筑…

AGI大模型(2):GPT:Generative Pre-trained Transformer

1 Generative Pre-trained Transformer 1.1 Generative生成式 GPT中的“生成式”指的是该模型能够根据输入自动生成文本内容&#xff0c;而不仅仅是从已有的文本库中检索答案。 具体来说&#xff1a; 生成&#xff08;Generative&#xff09;&#xff1a;GPT是一个生成…

DeepSeek 助力 Vue3 开发:打造丝滑的表格(Table)之添加列宽调整功能,示例Table14_06带搜索功能的固定表头表格

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;并提供具体代码帮助大家深入理解&#xff0c;彻底掌握&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 Deep…

MySQL再次基础 向初级工程师迈进

作者&#xff1a;在计算机行业找不到工作的大四失业者 Run run run ! ! ! 1、MySQL概述 1.1数据库相关概念 1.2MySQL数据库 2、SQL 2.1SQL通用语法 SQL语句可以单行或多行书写&#xff0c;以分号结尾。SQL语句可以使用空格/缩进来增强语句的可读性。MySQL数据库的SQL语句不区…

使用 Doris 和 Hudi

作为一种全新的开放式的数据管理架构&#xff0c;湖仓一体&#xff08;Data Lakehouse&#xff09;融合了数据仓库的高性能、实时性以及数据湖的低成本、灵活性等优势&#xff0c;帮助用户更加便捷地满足各种数据处理分析的需求&#xff0c;在企业的大数据体系中已经得到越来越…

城市林业的无声革命:人工智能与古老生态学如何重新设计城市

城市林业的无声革命&#xff1a;人工智能与古老生态学如何重新设计城市 在摩天大楼的阴影下&#xff0c;一场静悄悄的变革正在发生——它融合了硅芯片与古老根系&#xff0c;算法与原住民智慧。 作者&#xff1a;保罗桑杜 作者利用 PicLumen 创建的图像 城市森林不再只是城市…

Linux第七讲:基础IO

Linux第七讲&#xff1a;基础IO 1.什么是文件2.文件操作的复习2.1文件基本操作复习2.2将信息输出到显示器&#xff0c;你有哪种方法2.3stdin、stdout、stderror2.4细节问题讲解 3.系统文件IO3.1open函数使用3.1.1理解标志位3.1.2权限问题3.1.3write和read接口介绍3.1.4谈谈fd以…

力扣热题 100:多维动态规划专题经典题解析

系列文章目录 力扣热题 100&#xff1a;哈希专题三道题详细解析(JAVA) 力扣热题 100&#xff1a;双指针专题四道题详细解析(JAVA) 力扣热题 100&#xff1a;滑动窗口专题两道题详细解析&#xff08;JAVA&#xff09; 力扣热题 100&#xff1a;子串专题三道题详细解析(JAVA) 力…

【Unity】在项目中使用VisualScripting

1. 在packagemanager添加插件 2. 在设置中进行初始化。 Edit > Project Settings > Visual Scripting Initialize Visual Scripting You must select Initialize Visual Scripting the first time you use Visual Scripting in a project. Initialize Visual Scripting …

Pytest自动化测试框架pytest-xdist分布式测试插件

平常我们功能测试用例非常多时&#xff0c;比如有1千条用例&#xff0c;假设每个用例执行需要1分钟&#xff0c;如果单个测试人员执行需要1000分钟才能跑完&#xff1b; 当项目非常紧急时&#xff0c;会需要协调多个测试资源来把任务分成两部分&#xff0c;于是执行时间缩短一…

文件解析漏洞靶场解析全集详解

lls解析漏洞 目录解析 在网站的下面将一个1.asp文件夹&#xff0c;在里面建一个2.txt文件在里面写入<% -now()%>这个显示时间的代码&#xff0c;再将文件名改为2.jpg。 发现2.jpg文件以asp形式执行 畸形文件解析 将2.jpg文件移到网站的下面与1.asp并列&#xff0c;将名…

【一次成功】Win10本地化单机部署k8s v1.31.2版本及可视化看板

【一次成功】Win10本地化单机部署k8s v1.31.2版本及可视化看板 零、安装清单一、安装Docker Desktop软件1.1 安装前<启用或关闭Windows功能> 中的描红的三项1.2 查看软件版本1.3 配置Docker镜像 二、更新装Docker Desktop三、安装 k8s3.1 点击启动安装3.2 查看状态3.3 查…

Vue项目搜索引擎优化(SEO)终极指南:从原理到实战

文章目录 1. SEO基础与Vue项目的挑战1.1 为什么Vue项目需要特殊SEO处理&#xff1f;1.2 搜索引擎爬虫工作原理 2. 服务端渲染&#xff08;SSR&#xff09;解决方案2.1 Nuxt.js框架实战原理代码实现流程图 2.2 自定义SSR实现 3. 静态站点生成&#xff08;SSG&#xff09;技术3.1…