SegmentAnything官网demo使用vue+python实现

news2024/11/19 10:24:19

一、效果&准备工作

1.效果

没啥好说的,低质量复刻SAM官网 https://segment-anything.com/

需要提一点:所有生成embedding和mask的操作都是python后端做的,计算mask不是onnxruntime-web实现的,前端只负责了把rle编码的mask解码后画到canvas上,会有几十毫秒的网络传输延迟。我不会react和typescript,官网F12里的源代码太难懂了,生成的svg总是与期望的不一样

主页

在这里插入图片描述

鼠标移动动态分割(Hover)

throttle了一下,修改代码里的throttle delay,反应更快些,我觉得没必要已经够了,设置的150ms

在这里插入图片描述

点选前景背景(Click)

蓝色前景,红色背景,对应clickType分别为1和0

在这里插入图片描述

分割(Cut out object)

同官网,分割出该区域需要的最小矩形框部分
在这里插入图片描述

分割所有(Everything)

随便做了下,实在做不出官网的效果,可能模型也有问题 ,我用的vit_b,懒得试了,这功能对我来说没卵用

在这里插入图片描述

2.准备工作

安装依赖

前端使用了Vue3+ElementPlus(https://element-plus.org/zh-CN/#/zh-CN)+axios+lz-string,npm安装一下。

后端是fastapi(https://fastapi.tiangolo.com/),FastAPI 依赖 Python 3.8 及更高版本。

安装 FastAPI

pip install fastapi

另外我们还需要一个 ASGI 服务器,生产环境可以使用 Uvicorn 或者 Hypercorn:

pip install "uvicorn[standard]"
要用的js文件

@/util/request.js

import axios from "axios";
import { ElMessage } from "element-plus";

axios.interceptors.request.use(
    config => {
        return config;
    },
    error => {
        return Promise.reject(error);
    }
);

axios.interceptors.response.use(
    response => {
        if (response.data.success != null && !response.data.success) {
            return Promise.reject(response.data)
        }
        return response.data;
    },
    error => {
        console.log('error: ', error)
        ElMessage.error(' ');
        return Promise.reject(error);
    }
);

export default axios;

然后在main.js中绑定

import axios from './util/request.js'
axios.defaults.baseURL = 'http://localhost:9000'
axios.defaults.headers.post['Content-Type'] = 'application/x-www-form-urlencoded';
app.config.globalProperties.$http = axios

@/util/throttle.js

function throttle(func, delay) {
    let timer = null; // 定时器变量

    return function() {
        const context = this; // 保存this指向
        const args = arguments; // 保存参数列表

        if (!timer) {
            timer = setTimeout(() => {
                func.apply(context, args); // 调用原始函数并传入上下文和参数
                clearTimeout(timer); // 清除计时器
                timer = null; // 重置计时器为null
            }, delay);
        }
    };
}
export default throttle

@/util/mask_utils.js

/**
 * Parses RLE from compressed string
 * @param {Array<number>} input
 * @returns array of integers
 */
export const rleFrString = (input) => {
    let result = [];
    let charIndex = 0;
    while (charIndex < input.length) {
        let value = 0,
            k = 0,
            more = 1;
        while (more) {
            let c = input.charCodeAt(charIndex) - 48;
            value |= (c & 0x1f) << (5 * k);
            more = c & 0x20;
            charIndex++;
            k++;
            if (!more && c & 0x10) value |= -1 << (5 * k);
        }
        if (result.length > 2) value += result[result.length - 2];
        result.push(value);
    }
    return result;
};

/**
 * Parse RLE to mask array
 * @param rows
 * @param cols
 * @param counts
 * @returns {Uint8Array}
 */
export const decodeRleCounts = ([rows, cols], counts) => {
    let arr = new Uint8Array(rows * cols)
    let i = 0
    let flag = 0
    for (let k of counts) {
        while (k-- > 0) {
            arr[i++] = flag
        }
        flag = (flag + 1) % 2
    }
    return arr
};

/**
 * Parse Everything mode counts array to mask array
 * @param rows
 * @param cols
 * @param counts
 * @returns {Uint8Array}
 */
export const decodeEverythingMask = ([rows, cols], counts) => {
    let arr = new Uint8Array(rows * cols)
    let k = 0;
    for (let i = 0; i < counts.length; i += 2) {
        for (let j = 0; j < counts[i]; j++) {
            arr[k++] = counts[i + 1]
        }
    }
    return arr;
};

/**
 * Get globally unique color in the mask
 * @param category
 * @param colorMap
 * @returns {*}
 */
export const getUniqueColor = (category, colorMap) => {
    // 该种类没有颜色
    if (!colorMap.hasOwnProperty(category)) {
        // 生成唯一的颜色
        while (true) {
            const color = {
                r: Math.floor(Math.random() * 256),
                g: Math.floor(Math.random() * 256),
                b: Math.floor(Math.random() * 256)
            }
            // 检查颜色映射中是否已存在相同的颜色
            const existingColors = Object.values(colorMap);
            const isDuplicateColor = existingColors.some((existingColor) => {
                return color.r === existingColor.r && color.g === existingColor.g && color.b === existingColor.b;
            });
            // 如果不存在相同颜色,结束循环
            if (!isDuplicateColor) {
                colorMap[category] = color;
                break
            }
        }
        console.log("生成唯一颜色", category, colorMap[category])
        return colorMap[category]
    } else {
        return colorMap[category]
    }
}

/**
 * Cut out specific area of image uncovered by mask
 * @param w image's natural width
 * @param h image's natural height
 * @param image source image
 * @param canvas mask canvas
 * @param callback function to solve the image blob
 */
export const cutOutImage = ({w, h}, image, canvas, callback) => {
    const resultCanvas = document.createElement('canvas'),
        resultCtx = resultCanvas.getContext('2d', {willReadFrequently: true}),
        originalCtx = canvas.getContext('2d', {willReadFrequently: true});
    resultCanvas.width = w;
    resultCanvas.height = h;
    resultCtx.drawImage(image, 0, 0, w, h)
    const maskDataArray = originalCtx.getImageData(0, 0, w, h).data;
    const imageData = resultCtx.getImageData(0, 0, w, h);
    const imageDataArray = imageData.data
    // 将mask的部分去掉
    for (let i = 0; i < maskDataArray.length; i += 4) {
        const alpha = maskDataArray[i + 3];
        if (alpha !== 0) { // 不等于0,是mask区域
            imageDataArray[i + 3] = 0;
        }
    }
    // 计算被分割出来的部分的矩形框
    let minX = w;
    let minY = h;
    let maxX = 0;
    let maxY = 0;
    for (let y = 0; y < h; y++) {
        for (let x = 0; x < w; x++) {
            const alpha = imageDataArray[(y * w + x) * 4 + 3];
            if (alpha !== 0) {
                minX = Math.min(minX, x);
                minY = Math.min(minY, y);
                maxX = Math.max(maxX, x);
                maxY = Math.max(maxY, y);
            }
        }
    }
    const width = maxX - minX + 1;
    const height = maxY - minY + 1;
    const startX = minX;
    const startY = minY;
    resultCtx.putImageData(imageData, 0, 0)
    // 创建一个新的canvas来存储特定区域的图像
    const croppedCanvas = document.createElement("canvas");
    const croppedContext = croppedCanvas.getContext("2d");
    croppedCanvas.width = width;
    croppedCanvas.height = height;
    // 将特定区域绘制到新canvas上
    croppedContext.drawImage(resultCanvas, startX, startY, width, height, 0, 0, width, height);
    croppedCanvas.toBlob(blob => {
        if (callback) {
            callback(blob)
        }
    }, "image/png");
}

/**
 * Cut out specific area of image covered by target color mask
 * PS: 我写的这代码有问题,比较color的时候tmd明明mask canvas中有这个颜色,
 * 就是说不存在这颜色,所以不用这个函数,改成下面的了
 * @param w image's natural width
 * @param h image's natural height
 * @param image source image
 * @param canvas mask canvas
 * @param color target color
 * @param callback function to solve the image blob
 */
export const cutOutImageWithMaskColor = ({w, h}, image, canvas, color, callback) => {
    const resultCanvas = document.createElement('canvas'),
        resultCtx = resultCanvas.getContext('2d', {willReadFrequently: true}),
        originalCtx = canvas.getContext('2d', {willReadFrequently: true});
    resultCanvas.width = w;
    resultCanvas.height = h;
    resultCtx.drawImage(image, 0, 0, w, h)
    const maskDataArray = originalCtx.getImageData(0, 0, w, h).data;
    const imageData = resultCtx.getImageData(0, 0, w, h);
    const imageDataArray = imageData.data

    let find = false

    // 比较mask的color和目标color
    for (let i = 0; i < maskDataArray.length; i += 4) {
        const r = maskDataArray[i],
            g = maskDataArray[i + 1],
            b = maskDataArray[i + 2];
        if (r != color.r || g != color.g || b != color.b) { // 颜色与目标颜色不相同,是mask区域
            // 设置alpha为0
            imageDataArray[i + 3] = 0;
        } else {
            find = true
        }
    }
    // 计算被分割出来的部分的矩形框
    let minX = w;
    let minY = h;
    let maxX = 0;
    let maxY = 0;
    for (let y = 0; y < h; y++) {
        for (let x = 0; x < w; x++) {
            const alpha = imageDataArray[(y * w + x) * 4 + 3];
            if (alpha !== 0) {
                minX = Math.min(minX, x);
                minY = Math.min(minY, y);
                maxX = Math.max(maxX, x);
                maxY = Math.max(maxY, y);
            }
        }
    }
    const width = maxX - minX + 1;
    const height = maxY - minY + 1;
    const startX = minX;
    const startY = minY;
    // console.log(`矩形宽度:${width}`);
    // console.log(`矩形高度:${height}`);
    // console.log(`起点坐标:(${startX}, ${startY})`);
    resultCtx.putImageData(imageData, 0, 0)
    // 创建一个新的canvas来存储特定区域的图像
    const croppedCanvas = document.createElement("canvas");
    const croppedContext = croppedCanvas.getContext("2d");
    croppedCanvas.width = width;
    croppedCanvas.height = height;
    // 将特定区域绘制到新canvas上
    croppedContext.drawImage(resultCanvas, startX, startY, width, height, 0, 0, width, height);
    croppedCanvas.toBlob(blob => {
        if (callback) {
            callback(blob)
        }
    }, "image/png");
}

/**
 * Cut out specific area whose category is target category
 * @param w image's natural width
 * @param h image's natural height
 * @param image source image
 * @param arr original mask array that stores all pixel's category
 * @param category target category
 * @param callback function to solve the image blob
 */
export const cutOutImageWithCategory = ({w, h}, image, arr, category, callback) => {
    const resultCanvas = document.createElement('canvas'),
        resultCtx = resultCanvas.getContext('2d', {willReadFrequently: true});
    resultCanvas.width = w;
    resultCanvas.height = h;
    resultCtx.drawImage(image, 0, 0, w, h)
    const imageData = resultCtx.getImageData(0, 0, w, h);
    const imageDataArray = imageData.data
    // 比较mask的类别和目标类别
    let i = 0
    for(let y = 0; y < h; y++){
        for(let x = 0; x < w; x++){
            if (category != arr[i++]) { // 类别不相同,是mask区域
                // 设置alpha为0
                imageDataArray[3 + (w * y + x) * 4] = 0;
            }
        }
    }
    // 计算被分割出来的部分的矩形框
    let minX = w;
    let minY = h;
    let maxX = 0;
    let maxY = 0;
    for (let y = 0; y < h; y++) {
        for (let x = 0; x < w; x++) {
            const alpha = imageDataArray[(y * w + x) * 4 + 3];
            if (alpha !== 0) {
                minX = Math.min(minX, x);
                minY = Math.min(minY, y);
                maxX = Math.max(maxX, x);
                maxY = Math.max(maxY, y);
            }
        }
    }
    const width = maxX - minX + 1;
    const height = maxY - minY + 1;
    const startX = minX;
    const startY = minY;
    resultCtx.putImageData(imageData, 0, 0)
    // 创建一个新的canvas来存储特定区域的图像
    const croppedCanvas = document.createElement("canvas");
    const croppedContext = croppedCanvas.getContext("2d");
    croppedCanvas.width = width;
    croppedCanvas.height = height;
    // 将特定区域绘制到新canvas上
    croppedContext.drawImage(resultCanvas, startX, startY, width, height, 0, 0, width, height);
    croppedCanvas.toBlob(blob => {
        if (callback) {
            callback(blob)
        }
    }, "image/png");
}

二、后端代码

1.SAM下载

首先从github上下载SAM的代码https://github.com/facebookresearch/segment-anything

然后下载模型文件,保存到项目根目录/checkpoints中,

  • default or vit_h: ViT-H SAM model.
  • vit_l: ViT-L SAM model.
  • vit_b: ViT-B SAM model.

2.后端代码

在项目根目录下创建main.py

main.py

import os
import time

from PIL import Image
import numpy as np
import io
import base64
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
from pycocotools import mask as mask_utils
import lzstring


def init():
    # your model path
    checkpoint = "checkpoints/sam_vit_b_01ec64.pth"
    model_type = "vit_b"
    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    sam.to(device='cuda')
    predictor = SamPredictor(sam)
    mask_generator = SamAutomaticMaskGenerator(sam)
    return predictor, mask_generator


predictor, mask_generator = init()

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins="*",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

last_image = ""
last_logit = None


@app.post("/segment")
def process_image(body: dict):
    global last_image, last_logit
    print("start processing image", time.time())
    path = body["path"]
    is_first_segment = False
    # 看上次分割的图片是不是该图片
    if path != last_image:  # 不是该图片,重新生成图像embedding
        pil_image = Image.open(path)
        np_image = np.array(pil_image)
        predictor.set_image(np_image)
        last_image = path
        is_first_segment = True
        print("第一次识别该图片,获取embedding")
    # 获取mask
    clicks = body["clicks"]
    input_points = []
    input_labels = []
    for click in clicks:
        input_points.append([click["x"], click["y"]])
        input_labels.append(click["clickType"])
    print("input_points:{}, input_labels:{}".format(input_points, input_labels))
    input_points = np.array(input_points)
    input_labels = np.array(input_labels)
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        mask_input=last_logit[None, :, :] if not is_first_segment else None,
        multimask_output=is_first_segment  # 第一次产生3个结果,选择最优的
    )
    # 设置mask_input,为下一次做准备
    best = np.argmax(scores)
    last_logit = logits[best, :, :]
    masks = masks[best, :, :]
    # print(mask_utils.encode(np.asfortranarray(masks))["counts"])
    # numpy_array = np.frombuffer(mask_utils.encode(np.asfortranarray(masks))["counts"], dtype=np.uint8)
    # print("Uint8Array([" + ", ".join(map(str, numpy_array)) + "])")
    source_mask = mask_utils.encode(np.asfortranarray(masks))["counts"].decode("utf-8")
    # print(source_mask)
    lzs = lzstring.LZString()
    encoded = lzs.compressToEncodedURIComponent(source_mask)
    print("process finished", time.time())
    return {"shape": masks.shape, "mask": encoded}


@app.get("/everything")
def segment_everything(path: str):
    start_time = time.time()
    print("start segment_everything", start_time)
    pil_image = Image.open(path)
    np_image = np.array(pil_image)
    masks = mask_generator.generate(np_image)
    sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
    img = np.zeros((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1]), dtype=np.uint8)
    for idx, ann in enumerate(sorted_anns, 0):
        img[ann['segmentation']] = idx
    #看一下mask是什么样
    #plt.figure(figsize=(10,10))
	#plt.imshow(img) 
	#plt.show()
    # 压缩数组
    result = my_compress(img)
    end_time = time.time()
    print("finished segment_everything", end_time)
    print("time cost", end_time - start_time)
    return {"shape": img.shape, "mask": result}


@app.get('/automatic_masks')
def automatic_masks(path: str):
    pil_image = Image.open(path)
    np_image = np.array(pil_image)
    mask = mask_generator.generate(np_image)
    sorted_anns = sorted(mask, key=(lambda x: x['area']), reverse=True)
    lzs = lzstring.LZString()
    res = []
    for ann in sorted_anns:
        m = ann['segmentation']
        source_mask = mask_utils.encode(m)['counts'].decode("utf-8")
        encoded = lzs.compressToEncodedURIComponent(source_mask)
        r = {
            "encodedMask": encoded,
            "point_coord": ann['point_coords'][0],
        }
        res.append(r)
    return res


# 就是将连续的数字统计个数,然后把[个数,数字]放到result中,类似rle算法
# 比如[[1,1,1,2,3,2,2,4,4],[3,3,4...]]
# result是[3,1,  1,2,  1,3,  2,2,  2,4,  2,3,...]
def my_compress(img):
    result = []
    last_pixel = img[0][0]
    count = 0
    for line in img:
        for pixel in line:
            if pixel == last_pixel:
                count += 1
            else:
                result.append(count)
                result.append(int(last_pixel))
                last_pixel = pixel
                count = 1
    result.append(count)
    result.append(int(last_pixel))
    return result

3.原神启动

在cmd或者pycharm终端,cd到项目根目录下,输入uvicorn main:app --port 8006,启动服务器

三、前端代码

1.页面代码

template
<template>
  <div class="segment-container">
    <ElScrollbar class="tool-box">
      <div class="image-section">
        <div class="title">
          <div style="padding-left:15px">
            <el-icon><Picture /></el-icon><span style="font-size: 18px;font-weight: 550;">展示图像</span>
            <el-icon class="header-icon"></el-icon>
          </div>
        </div>
        <ElScrollbar height="350px">
          <div v-if="cutOuts.length === 0">
            <p>未进行抠图</p>
            <p>左键设置区域为前景</p>
            <p>右键设置区域为背景</p>
          </div>
          <img v-for="src in cutOuts" :src="src" alt="加载中"
               @click="openInNewTab(src)"/>
        </ElScrollbar>
      </div>
      <div class="options-section">
        <span class="option" @click="reset">重置</span>
        <span :class="'option'+(clicks.length===0?' disabled':'')" @click="undo">撤销</span>
        <span :class="'option'+(clickHistory.length===0?' disabled':'')" @click="redo">恢复</span>
      </div>
      <button :class="'segmentation-button'+(lock||clicks.length===0?' disabled':'')"
              @click="cutImage">分割</button>
      <button :class="'segmentation-button'+(lock||isEverything?' disabled':'')"
              @click="segmentEverything">分割所有</button>
    </ElScrollbar>
    <div class="segment-box">
      <div class="segment-wrapper" :style="{'left': left + 'px'}">
        <img v-show="path" id="segment-image" :src="url" :style="{width:w, height:h}" alt="加载失败" crossorigin="anonymous"
             @mousedown="handleMouseDown" @mouseenter="canvasVisible = true"
             @mouseout="() => {if (!this.clicks.length&&!this.isEverything) this.canvasVisible = false}"/>
        <canvas v-show="path && canvasVisible" id="segment-canvas" :width="originalSize.w" :height="originalSize.h"></canvas>
        <div id="point-box" :style="{width:w, height:h}"></div>
      </div>

    </div>
  </div>
</template>
script
<script>
import throttle from "@/util/throttle";
import LZString from "lz-string";
import {
  rleFrString,
  decodeRleCounts,
  decodeEverythingMask,
  getUniqueColor,
  cutOutImage,
  cutOutImageWithMaskColor, cutOutImageWithCategory
} from "@/util/mask_utils";
import {ElCollapse, ElCollapseItem, ElScrollbar} from "element-plus";
import {Picture} from '@element-plus/icons-vue'
export default {
  name: "segment",
  components: {
    ElCollapse, ElCollapseItem, ElScrollbar, Picture
  },
  data() {
    return {
      image: null,
      clicks: [],
      clickHistory: [],
      originalSize: {w: 0, h: 0},
      w: 0,
      h: 0,
      left: 0,
      scale: 1,
      url: null, // url用来设置成img的src展示
      path: null, // path是该图片在文件系统中的绝对路径
      loading: false,
      lock: false,
      canvasVisible: true,
      // cutOuts: ['http://localhost:9000/p/2024/01/19/112ce48bd76e47c7900863a3a0147853.jpg', 'http://localhost:9000/p/2024/01/19/112ce48bd76e47c7900863a3a0147853.jpg'],
      cutOuts: [],
      isEverything: false
    }
  },
  mounted() {
    this.init()
  },
  methods: {
    async init() {
      this.loading = true
      // 从路由获取id
      let id = this.$route.params.id
      if (!id) {
        this.$message.error('未选择图片')
        return
      }
      this.id = id
      // 获取图片信息
      try {
        const { path, url } = await this.getPathAndUrl()
        this.loadImage(path, url)
      } catch (e) {
        console.error(e)
        this.$message.error(e)
      }
    },
    async getPathAndUrl() {
      let res = await this.$http.get("/photo/path/" + this.id)
      console.log(res)
      return res.data
    },
    loadImage(path, url) {
      let image = new Image();
      image.src = this.$photo_base + url;
      image.onload = () => {
        let w = image.width, h = image.height
        let nw, nh
        let body = document.querySelector('.segment-box')
        let mw = body.clientWidth, mh = body.clientHeight
        let ratio = w / h
        if (ratio * mh > mw) {
          nw = mw
          nh = mw / ratio
        } else {
          nh = mh
          nw = ratio * mh
        }
        this.originalSize = {w, h}
        nw = parseInt(nw)
        nh = parseInt(nh)
        this.w = nw + 'px'
        this.h = nh + 'px'
        this.left = (mw - nw) / 2
        this.scale = nw / w
        this.url = this.$photo_base + url
        this.path = path
        console.log((this.scale > 1 ? '放大' : '缩小') + w + ' --> ' + nw)
        const img = document.getElementById('segment-image')
        img.addEventListener('contextmenu', e => e.preventDefault())
        img.addEventListener('mousemove', throttle(this.handleMouseMove, 150))
        const canvas = document.getElementById('segment-canvas')
        canvas.style.transform = `scale(${this.scale})`
      }
    },
    getClick(e) {
      let click = {
        x: e.offsetX,
        y: e.offsetY,
      }
      const imageScale = this.scale
      click.x /= imageScale;
      click.y /= imageScale;
      if(e.which === 3){ // 右键
        click.clickType = 0
      } else if(e.which === 1 || e.which === 0) { // 左键
        click.clickType = 1
      }
      return click
    },
    handleMouseMove(e) {
      if (this.isEverything) { // 分割所有模式,返回
        return;
      }
      if (this.clicks.length !== 0) { // 选择了点
        return;
      }
      if (this.lock) {
        return;
      }
      this.lock = true;
      let click = this.getClick(e);
      requestIdleCallback(() => {
        this.getMask([click])
      })
    },
    handleMouseDown(e) {
      e.preventDefault();
      e.stopPropagation();
      if (e.button === 1) {
        return;
      }
      // 如果是“分割所有”模式,返回
      if (this.isEverything) {
        return;
      }
      if (this.lock) {
        return;
      }
      this.lock = true
      let click = this.getClick(e);
      this.placePoint(e.offsetX, e.offsetY, click.clickType)
      this.clicks.push(click);
      requestIdleCallback(() => {
        this.getMask()
      })
    },
    placePoint(x, y, clickType) {
      let box = document.getElementById('point-box')
      let point = document.createElement('div')
      point.className = 'segment-point' + (clickType ? '' : ' negative')
      point.style = `position: absolute;
                      width: 10px;
                      height: 10px;
                      border-radius: 50%;
                      background-color: ${clickType?'#409EFF':'#F56C6C '};
                      left: ${x-5}px;
                      top: ${y-5}px`
      // 点的id是在clicks数组中的下标索引
      point.id = 'point-' + this.clicks.length
      box.appendChild(point)
    },
    removePoint(i) {
      const selector = 'point-' + i
      let point = document.getElementById(selector)
      if (point != null) {
        point.remove()
      }
    },
    getMask(clicks) {
      // 如果clicks为空,则是mouse move产生的click
      if (clicks == null) {
        clicks = this.clicks
      }
      const data = {
        path: this.path,
        clicks: clicks
      }
      console.log(data)
      this.$http.post('http://localhost:8006/segment', data, {
        headers: {
          "Content-Type": "application/json"
        }
      }).then(res => {
        const shape = res.shape
        const maskenc = LZString.decompressFromEncodedURIComponent(res.mask);
        const decoded = rleFrString(maskenc)
        this.drawCanvas(shape, decodeRleCounts(shape, decoded))
        this.lock = false
      }).catch(err => {
        console.error(err)
        this.$message.error("生成失败")
        this.lock = false
      })
    },
    segmentEverything() {
      if (this.isEverything) { // 上一次刚点过了
        return;
      }
      if (this.lock) {
        return;
      }
      this.lock = true
      this.reset()
      this.isEverything = true
      this.canvasVisible = true
      this.$http.get("http://localhost:8006/everything?path=" + this.path).then(res => {
        const shape = res.shape
        const counts = res.mask
        this.drawEverythingCanvas(shape, decodeEverythingMask(shape, counts))
      }).catch(err => {
        console.error(err)
        this.$message.error("生成失败")
      })
    },
    drawCanvas(shape, arr) {
      let height = shape[0],
          width = shape[1]
      console.log("height: ", height, " width: ", width)
      let canvas = document.getElementById('segment-canvas'),
          canvasCtx = canvas.getContext("2d"),
          imgData = canvasCtx.getImageData(0, 0, width, height),
          pixelData = imgData.data
      let i = 0
      for(let x = 0; x < width; x++){
        for(let y = 0; y < height; y++){
          if (arr[i++] === 0) { // 如果是0,是背景,遮住
            pixelData[0 + (width * y + x) * 4] = 40;
            pixelData[1 + (width * y + x) * 4] = 40;
            pixelData[2 + (width * y + x) * 4] = 40;
            pixelData[3 + (width * y + x) * 4] = 190;
          } else {
            pixelData[3 + (width * y + x) * 4] = 0;
          }
        }
      }
      canvasCtx.putImageData(imgData, 0, 0)
    },
    drawEverythingCanvas(shape, arr) {
      const height = shape[0],
          width = shape[1]
      console.log("height: ", height, " width: ", width)
      let canvas = document.getElementById('segment-canvas'),
          canvasCtx = canvas.getContext("2d"),
          imgData = canvasCtx.getImageData(0, 0, width, height),
          pixelData = imgData.data;
      const colorMap = {}
      let i = 0
      for(let y = 0; y < height; y++){
        for(let x = 0; x < width; x++){
          const category = arr[i++]
          const color = getUniqueColor(category, colorMap)
          pixelData[0 + (width * y + x) * 4] = color.r;
          pixelData[1 + (width * y + x) * 4] = color.g;
          pixelData[2 + (width * y + x) * 4] = color.b;
          pixelData[3 + (width * y + x) * 4] = 150;
        }
      }
      // 显示在图片上
      canvasCtx.putImageData(imgData, 0, 0)
      // 开始分割每一个mask的图片
      const image = document.getElementById('segment-image')
      Object.keys(colorMap).forEach(category => {
        cutOutImageWithCategory(this.originalSize, image, arr, category, blob => {
          const url = URL.createObjectURL(blob);
          this.cutOuts = [url, ...this.cutOuts]
        })
      })
    },
    reset() {
      for (let i = 0; i < this.clicks.length; i++) {
        this.removePoint(i)
      }
      this.clicks = []
      this.clickHistory = []
      this.isEverything = false
      this.clearCanvas()
    },
    undo() {
      if (this.clicks.length === 0)
        return
      const idx = this.clicks.length - 1
      const click = this.clicks[idx]
      this.clickHistory.push(click)
      this.clicks.splice(idx, 1)
      this.removePoint(idx)
      if (this.clicks.length) {
        this.getMask()
      } else {
        this.clearCanvas()
      }
    },
    redo() {
      if (this.clickHistory.length === 0)
        return
      const idx = this.clickHistory.length - 1
      const click = this.clickHistory[idx]
      console.log(this.clicks, this.clickHistory, click)
      this.placePoint(click.x * this.scale, click.y * this.scale, click.clickType)
      this.clicks.push(click)
      this.clickHistory.splice(idx, 1)
      this.getMask()
    },
    clearCanvas() {
      let canvas = document.getElementById('segment-canvas')
      canvas.getContext('2d').clearRect(0, 0, canvas.width, canvas.height)
    },
    cutImage() {
      if (this.lock || this.clicks.length === 0) {
        return;
      }
      const canvas = document.getElementById('segment-canvas'),
          image = document.getElementById('segment-image')
      const {w, h} = this.originalSize
      cutOutImage(this.originalSize, image, canvas, blob => {
        const url = URL.createObjectURL(blob);
        this.cutOuts = [url, ...this.cutOuts]
        // 不需要之后用下面的清除文件
        // URL.revokeObjectURL(url);
      })
    },
    openInNewTab(src) {
      window.open(src, '_blank')
    }
  }
}
</script>
style
<style scoped lang="scss">
.segment-container {
  position: relative;
}

.tool-box {
  position: absolute;
  left: 20px;
  top: 20px;
  width: 200px;
  height: 600px;
  border-radius: 20px;
  //background: pink;
  overflow: auto;
  box-shadow: 0 0 5px rgb(150, 150, 150);
  box-sizing: border-box;
  padding: 10px;

  .image-section {
    height: fit-content;
    width: 100%;
    .title {
      height: 48px;
      line-height: 48px;
      border-bottom: 1px solid lightgray;
      margin-bottom: 15px;
    }
  }

  .image-section img {
    max-width: 85%;
    max-height: 140px;
    margin: 10px auto;
    padding: 10px;
    box-sizing: border-box;
    object-fit: contain;
    display: block;
    transition: .3s;
    cursor: pointer;
  }
  .image-section img:hover {
    background: rgba(0, 30, 160, 0.3);
  }

  .image-section p {
    text-align: center;
  }

  .options-section {
    margin-top: 5px;
    display: flex;
    justify-content: space-between;
    align-items: center;
    padding: 10px;
    box-sizing: border-box;
    border: 3px solid lightgray;
    border-radius: 20px;
  }
  .options-section:hover {
    border: 3px solid #59ACFF;
  }

  .option {
    font-size: 15px;
    padding: 5px 10px;
    cursor: pointer;
  }
  .option:hover {
    color: #59ACFF;
  }
  .option.disabled {
    color: gray;
    cursor: not-allowed;
  }

  .segmentation-button {
    margin-top: 5px;
    width: 100%;
    height: 40px;
    background-color: white;
    color: rgb(40, 40, 40);
    font-size: 17px;
    cursor: pointer;
    border: 3px solid lightgray;
    border-radius: 20px;
  }
  .segmentation-button:hover {
    border: 3px solid #59ACFF;
  }
  .segmentation-button.disabled {
    color: lightgray;
    cursor: not-allowed;
  }
}

.segment-box {
  position: relative;
  margin-left: calc(220px);
  width: calc(100% - 220px);
  height: calc(100vh - 80px);
  //background: #42b983;
  .segment-wrapper {
    position: absolute;
    left: 0;
    top: 0;
  }
  #segment-canvas {
    position: absolute;
    left: 0;
    top: 0;
    pointer-events: none;
    transform-origin: left top;
    z-index: 1;
  }
  #point-box {
    position: absolute;
    left: 0;
    top: 0;
    z-index: 2;
    pointer-events: none;
  }
  .segment-point {
    position: absolute;
    width: 10px;
    height: 10px;
    border-radius: 50%;
    background-color: #409EFF;
  }
  .segment-point.negative {
    background-color: #F56C6C;
  }
}
</style>

2.代码说明

  • 本项目没做上传图片分割,就是简单的选择本地图片分割,data中url是img的src,path是绝对路径用来传给python后端进行分割,我是从我项目的系统获取的,请自行修改代码成你的图片路径,如src: “/assets/test.jpg”, path:“D:/project/segment/assets/test.jpg”
  • 由于pycocotools的rle encode是从上到下进行统计连续的0和1,为了方便,我在【@/util/mask_utils.js:decodeRleCounts】解码Click点选产生的mask时将(H,W)的矩阵转成了(W,H)顺序存储的Uint8array;而在Everything分割所有时,我没有使用pycocotools的encode,而是main.py中的my_compress函数编码的,是从左到右进行压缩,因此矩阵解码后仍然是(H,W)的矩阵,所以在drawCanvasdrawEverythingCanvas中的二层循环xy的顺序不一样,我实在懒得改了,就这样就可以了。

关于上面所提rle,可以在项目根目录/notebooks/predictor_example.ipynb中产生mask的位置添加代码自行观察他编码的rle,他只支持矩阵元素为0或1,result的第一个位置是0的个数,不管矩阵是不是0开头。

  • [0,0,1,1,0,1,0],rle counts是[2(两个0), 2(两个1), 1(一个0), 1(一个1), 1(一个0)];

  • [1,1,1,1,1,0],rle counts是[0(零个0),5(五个1),1(一个0)]

def decode_rle(rle_string): # 这是将pycocotools的counts编码的字符串转成counts数组,而非转成原矩阵
    result = []
    char_index = 0
    
    while char_index < len(rle_string):
        value = 0
        k = 0
        more = 1
        
        while more:
            c = ord(rle_string[char_index]) - 48
            value |= (c & 0x1f) << (5 * k)
            more = c & 0x20
            char_index += 1
            k += 1
            if not more and c & 0x10:
                value |= -1 << (5 * k)
        
        if len(result) > 2:
            value += result[-2]
        result.append(value)
    return result

from pycocotools import mask as mask_utils
import numpy as np
mask = np.array([[1,1,0,1,1,0],[1,1,1,1,1,1],[0,1,1,1,0,0],[1,1,1,1,1,1]])
mask = np.asfortranarray(mask, dtype=np.uint8)
print("原mask:\n{}".format(mask))
res = mask_utils.encode(mask)
print("encode:{}".format(res))
print("rle counts:{}".format(decode_rle(res["counts"].decode("utf-8"))))
# 转置后好看
print("转置:{}".format(mask.transpose()))
# flatten后更好看
print("flatten:{}".format(mask.transpose().flatten()))
#numpy_array = np.frombuffer(res["counts"], dtype=np.uint8)
# 打印numpy数组作为uint8array的格式
#print("Uint8Array([" + ", ".join(map(str, numpy_array)) + "])")

输出:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1442886.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[架构之路-275]:五张图向你展现软件开发不仅仅是编码,而是一个庞大的系统工程

目录 一、软件开发是组织架构的一部分&#xff0c;是为业务服务的 二、软件开发是一个系统工程&#xff0c;需要组织各种组织内的资源 三、目标软件是一个复杂的系统 四、软件开发过程本身是一个系统工程 五、目标软件的测试验证是一个系统工程 一、软件开发是组织架构的一…

MYSQL笔记:约束条件

MYSQL笔记&#xff1a;约束条件 主键约束 不能为空&#xff0c;值必须是不同的&#xff08;唯一性&#xff09; 一个表只能修饰一个主键 PRIMARY KEY自增约束 AUTO_INCREMENT唯一键约束 可以为空 unique非空约束 not null 默认值约束 default 外键约束 foreign key …

基于图像掩膜和深度学习的花生豆分拣(附源码)

目录 项目介绍 图像分类网络构建 处理花生豆图片完成预测 项目介绍 这是一个使用图像掩膜技术和深度学习技术实现的一个花生豆分拣系统 我们有大量的花生豆图片&#xff0c;并以及打好了标签&#xff0c;可以看一下目录结构和几张具体的图片 同时我们也有几张大的图片&…

Qt网络编程-ZMQ的使用

不同主机或者相同主机中不同进程之间可以借助网络通信相互进行数据交互&#xff0c;网络通信实现了进程之间的通信。比如两个进程之间需要借助UDP进行单播通信&#xff0c;则双方需要知道对方的IP和端口&#xff0c;假设两者不在同一主机中&#xff0c;如下示意图&#xff1a; …

【C语言】SYSCALL_DEFINE3(socket, int, family, int, type, int, protocol)

一、SYSCALL_DEFINE3与系统调用 在Linux操作系统中&#xff0c;为了从用户空间跳转到内核空间执行特定的内核级操作&#xff0c;使用了一种机制叫做"系统调用"&#xff08;System Call&#xff09;。系统调用是操作系统提供给程序员访问和使用内核功能的接口。例如&…

OnlyOffice-8.0版本深度测评

OnlyOffice 是一套全面的开源办公协作软件&#xff0c;不断演进的 OnlyOffice 8.0 版本为用户带来了一系列引人瞩目的新特性和功能改进。OnlyOffice 8.0 版本在功能丰富性、安全性和用户友好性上都有显著提升&#xff0c;为用户提供了更为强大、便捷和安全的文档处理和协作环境…

【Docker】02 镜像管理

文章目录 一、Images镜像二、管理操作2.1 搜索镜像2.1.1 命令行搜索2.1.2 页面搜索2.1.3 搜索条件 2.2 下载镜像2.3 查看本地镜像2.3.1 docker images2.3.2 --help2.3.3 repository name2.3.4 --filter2.3.5 -q2.3.6 --format 2.4 给镜像打标签2.5 推送镜像2.6 删除镜像2.7 导出…

React18原理: 渲染与更新时的重点关注事项

概述 react 在渲染过程中要做很多事情&#xff0c;所以不可能直接通过初始元素直接渲染还需要一个东西&#xff0c;就是虚拟节点&#xff0c;暂不涉及React Fiber的概念&#xff0c;将vDom树和Fiber 树统称为虚拟节点有了初始元素后&#xff0c;React 就会根据初始元素和其他可…

1g的视频怎么压缩到200m?3个步骤解决~

把1G的文件压缩到200M&#xff0c;可以有效节省存储空间&#xff0c;加快传输速度&#xff0c;在某些情况下&#xff0c;压缩文件可以提供更好的安全性&#xff0c;例如通过加密或压缩算法保护文件内容。下面就向大家介绍3个好用的方法。 方法一&#xff1a;使用嗨格式压缩大师…

立体感十足的地图组件,如何设计出来的?

以下是一些设计可视化界面中的地图组件更具备立体感的建议&#xff1a; 使用渐变色&#xff1a; 可以使用不同的渐变色来表现地图的高低差异&#xff0c;例如使用深蓝色或深紫色来表示海底&#xff0c;使用浅绿色或黄色来表示低地&#xff0c;使用橙色或红色来表示高地。 添加…

springboot173疫苗发布和接种预约系统

简介 【毕设源码推荐 javaweb 项目】基于springbootvue 的 适用于计算机类毕业设计&#xff0c;课程设计参考与学习用途。仅供学习参考&#xff0c; 不得用于商业或者非法用途&#xff0c;否则&#xff0c;一切后果请用户自负。 看运行截图看 第五章 第四章 获取资料方式 **项…

数据分析基础之《pandas(7)—高级处理2》

四、合并 如果数据由多张表组成&#xff0c;那么有时候需要将不同的内容合并在一起分析 1、先回忆下numpy中如何合并 水平拼接 np.hstack() 竖直拼接 np.vstack() 两个都能实现 np.concatenate((a, b), axis) 2、pd.concat([data1, data2], axis1) 按照行或者列…

[超分辨率重建]ESRGAN算法训练自己的数据集过程

一、下载数据集及项目包 1. 数据集 1.1 文件夹框架的介绍&#xff0c;如下图所示&#xff1a;主要有train和val&#xff0c;分别有高清&#xff08;HR&#xff09;和低清&#xff08;LR&#xff09;的图像。 1.2 原图先通过分割尺寸的脚本先将数据集图片处理成两个相同的图像…

政安晨:示例演绎机器学习中(深度学习)神经网络的数学基础——快速理解核心概念(一){两篇文章讲清楚}

进入人工智能领域免不了与算法打交道&#xff0c;算法依托数学基础&#xff0c;很多小伙伴可能新生畏惧&#xff0c;不用怕&#xff0c;算法没那么难&#xff0c;也没那么玄乎&#xff0c;未来人工智能时代说不得人人都要了解算法、应用算法。 本文试图以一篇文章&#xff0c;…

分享76个表单按钮JS特效,总有一款适合您

分享76个表单按钮JS特效&#xff0c;总有一款适合您 76个表单按钮JS特效下载链接&#xff1a;https://pan.baidu.com/s/1CW9aoh23UIwj9zdJGNVb5w?pwd8888 提取码&#xff1a;8888 Python采集代码下载链接&#xff1a;采集代码.zip - 蓝奏云 学习知识费力气&#xff0c;收集…

(坑点!!!)给定n条过原点的直线和m条抛物线(y=ax^2+bx+c,a>0),对于每一条抛物线,是否存在一条直线与它没有交点,若有,输出直线斜率

题目 思路: 1、区间端点可能是小数的时候,不能直接利用加减1将 < 转化为 <=,例如,x < 1.5 不等价于 x <= 2.5 2、该题中k在(b - sqrt(4 * a * c), b + sqrt(4 * a * c) 中,注意是开区间,那么可以将左端点向上取整,右端点向下取整,即sqrt(4 * a * c)向下取…

Netty中的常用组件(三)

ChannelPipeline 基于Netty的网路应用程序中根据业务需求会使用Netty已经提供的Channelhandler 或者自行开发ChannelHandler&#xff0c;这些ChannelHandler都放在ChannelPipeline中统一 管理&#xff0c;事件就会在ChannelPipeline中流动&#xff0c;并被其中一个或者多个Chan…

Mysql-Explain-使用说明

Explain 说明 explain SELECT * FROM tb_category_report;id&#xff1a;SELECT识别符&#xff0c;这是SELECT查询序列号。select_type&#xff1a;表示单位查询的查询类型&#xff0c;比如&#xff1a;普通查询、联合查询(union、union all)、子查询等复杂查询。table&#x…

房屋租赁系统的Java实战开发之旅

✍✍计算机编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java实战 |…

【Java多线程案例】实现阻塞队列

1. 阻塞队列简介 1.1 阻塞队列概念 阻塞队列&#xff1a;是一种特殊的队列&#xff0c;具有队列"先进先出"的特性&#xff0c;同时相较于普通队列&#xff0c;阻塞队列是线程安全的&#xff0c;并且带有阻塞功能&#xff0c;表现形式如下&#xff1a; 当队列满时&…