Training the PoseNet deep network for 6-DoF pose estimation, a Python 3 implementation


0. Related GitHub links

  • Original GitHub code (Caffe implementation)
  • TensorFlow implementation; it targets old library versions and Python 2, and this post migrates that code to Python 3.
  • PyTorch implementation, which swaps the backbone from GoogLeNet to ResNet; it reports better results but does not release pretrained weights.
  • Note: there is also a human pose estimation network called PoseNet; take care not to confuse the two when searching.
  • PoseNet paper summary

1. Task background

  • Image Matching Challenge 2023
  • Task in brief: estimate the camera pose from each image, consisting of a 3×3 rotation matrix and a 3-dimensional translation vector
  • Data description: train_labels.csv
    • dataset: dataset name
    • scene: scene name
    • image_path: path to the image
    • rotation_matrix: the 3×3 rotation matrix
    • translation_vector: the 3-dimensional translation vector
      (Figure: layout of the training data directory.)

2. Converting between quaternions and rotation matrices

  • Reference: see here
  • Functions for converting in both directions:
#convert a rotation matrix to a quaternion, see https://zhuanlan.zhihu.com/p/45404840
def matrix2quaternion(m):
    #m: 3x3 numpy array; assumes trace(m) > -1 so that w != 0
    w = ((np.trace(m) + 1) ** 0.5) / 2
    x = (m[2][1] - m[1][2]) / (4 * w)
    y = (m[0][2] - m[2][0]) / (4 * w)
    z = (m[1][0] - m[0][1]) / (4 * w)
    return w,x,y,z
def quaternion2matrix(q):
    #q: iterable of (w, x, y, z)
    w,x,y,z = q
    return np.array([[1-2*y*y-2*z*z, 2*x*y-2*z*w, 2*x*z+2*y*w],
             [2*x*y+2*z*w, 1-2*x*x-2*z*z, 2*y*z-2*x*w],
             [2*x*z-2*y*w, 2*y*z+2*x*w, 1-2*x*x-2*y*y]])
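  • A quick round-trip sanity check for the two converters (a minimal sketch; the test quaternion is my own choice, and matrix2quaternion assumes trace(m) > -1 so that w != 0):
import numpy as np

q = (0.5, 0.5, 0.5, 0.5)                  # a unit quaternion (w, x, y, z)
R = quaternion2matrix(q)
assert np.allclose(R @ R.T, np.eye(3))    # R is orthonormal
assert np.allclose(matrix2quaternion(R), q)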

3. Processing train_labels.csv

  • Compute the quaternion for each rotation_matrix value and store it in a new column, rotation_matrix_quaternion (done with a list comprehension over map):
def m(a):
    #a: one rotation_matrix cell, a ';'-separated string of 9 floats (row-major)
    a = a.split(';')
    a = [float(i) for i in a]
    A = np.array([[a[0],a[1],a[2]],
                [a[3],a[4],a[5]],
                [a[6],a[7],a[8]]])
    return matrix2quaternion(A)

change_train_labels = 1
if change_train_labels:
    train_labels = pd.read_csv('/kaggle/input/image-matching-challenge-2023/train/train_labels.csv')
    train_labels['rotation_matrix_quaternion'] = [i for i in map(m,train_labels['rotation_matrix'])]
    train_labels.to_csv('/kaggle/working/my_train_labels.csv')
  • Reads '/kaggle/input/image-matching-challenge-2023/train/train_labels.csv' and writes '/kaggle/working/my_train_labels.csv'.
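  • Note that when my_train_labels.csv is read back, the new column comes back as the string repr of the tuple, i.e. "(w, x, y, z)" with numeric values, which is why get_data() below parses it by splitting on parentheses. A sketch of a sturdier alternative (ast.literal_eval is a standard-library helper, not part of the original code; it assumes the values were written as plain numbers):
import ast
import pandas as pd

df = pd.read_csv('/kaggle/working/my_train_labels.csv')
w, x, y, z = ast.literal_eval(df['rotation_matrix_quaternion'][0])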

4. Network model

  • The composable Network base class (network.py in the original repo; np and the tf.compat.v1 API are imported in the complete code in section 9):
DEFAULT_PADDING = 'SAME'


def layer(op):
    '''Decorator for composable network layers.'''

    def layer_decorated(self, *args, **kwargs):
        # Automatically set a name if not provided.
        name = kwargs.setdefault('name', self.get_unique_name(op.__name__))
        # Figure out the layer inputs.
        if len(self.terminals) == 0:
            raise RuntimeError('No input variables found for layer %s.' % name)
        elif len(self.terminals) == 1:
            layer_input = self.terminals[0]
        else:
            layer_input = list(self.terminals)
        # Perform the operation and get the output.
        layer_output = op(self, layer_input, *args, **kwargs)
        # Add to layer LUT.
        self.layers[name] = layer_output
        # This output is now the input for the next layer.
        self.feed(layer_output)
        # Return self for chained calls.
        return self

    return layer_decorated


class Network(object):

    def __init__(self, inputs, trainable=True):
        # The input nodes for this network
        self.inputs = inputs
        # The current list of terminal nodes
        self.terminals = []
        # Mapping from layer names to layers
        self.layers = dict(inputs)
        # If true, the resulting variables are set as trainable
        self.trainable = trainable
        # Switch variable for dropout
        self.use_dropout = tf.placeholder_with_default(tf.constant(1.0),
                                                       shape=[],
                                                       name='use_dropout')
        self.setup()

    def setup(self):
        '''Construct the network. '''
        raise NotImplementedError('Must be implemented by the subclass.')

    def load(self, data_path, session, ignore_missing=False):
        '''Load network weights.
        data_path: The path to the numpy-serialized network weights
        session: The current TensorFlow session
        ignore_missing: If true, serialized weights for missing layers are ignored.
        '''
        data_dict = np.load(data_path,allow_pickle=True,encoding="latin1").item()
        for op_name in data_dict:
            with tf.variable_scope(op_name, reuse=True):
                for param_name, data in data_dict[op_name].items():
                    try:
                        var = tf.get_variable(param_name)
                        session.run(var.assign(data))
                    except ValueError:
                        if not ignore_missing:
                            raise

    def feed(self, *args):
        '''Set the input(s) for the next operation by replacing the terminal nodes.
        The arguments can be either layer names or the actual layers.
        '''
        assert len(args) != 0
        self.terminals = []
        for fed_layer in args:
            if isinstance(fed_layer, str):
                try:
                    fed_layer = self.layers[fed_layer]
                except KeyError:
                    raise KeyError('Unknown layer name fed: %s' % fed_layer)
            self.terminals.append(fed_layer)
        return self

    def get_output(self):
        '''Returns the current network output.'''
        return self.terminals[-1]

    def get_unique_name(self, prefix):
        '''Returns an index-suffixed unique name for the given prefix.
        This is used for auto-generating layer names based on the type-prefix.
        '''
        ident = sum(t.startswith(prefix) for t, _ in self.layers.items()) + 1
        return '%s_%d' % (prefix, ident)

    def make_var(self, name, shape):
        '''Creates a new TensorFlow variable.'''
        return tf.get_variable(name, shape, trainable=self.trainable)

    def validate_padding(self, padding):
        '''Verifies that the padding is one of the supported ones.'''
        assert padding in ('SAME', 'VALID')

    @layer
    def conv(self,
             input,
             k_h,
             k_w,
             c_o,
             s_h,
             s_w,
             name,
             relu=True,
             padding=DEFAULT_PADDING,
             group=1,
             biased=True):
        # Verify that the padding is acceptable
        self.validate_padding(padding)
        # Get the number of channels in the input
        c_i = input.get_shape()[-1]
        # Verify that the grouping parameter is valid
        assert c_i % group == 0
        assert c_o % group == 0
        # Convolution for a given input and kernel
        convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding)
        with tf.variable_scope(name) as scope:
            kernel = self.make_var('weights', shape=[k_h, k_w, int(int(c_i) / group), c_o])
            if group == 1:
                # This is the common-case. Convolve the input without any further complications.
                output = convolve(input, kernel)
            else:
                # Split the input into groups and then convolve each of them independently
                # (updated to the TF1 argument order: tf.split(value, num, axis))
                input_groups = tf.split(input, group, axis=3)
                kernel_groups = tf.split(kernel, group, axis=3)
                output_groups = [convolve(i, k) for i, k in zip(input_groups, kernel_groups)]
                # Concatenate the groups
                output = tf.concat(output_groups, axis=3)
            # Add the biases
            if biased:
                biases = self.make_var('biases', [c_o])
                output = tf.nn.bias_add(output, biases)
            if relu:
                # ReLU non-linearity
                output = tf.nn.relu(output, name=scope.name)
            return output

    @layer
    def relu(self, input, name):
        return tf.nn.relu(input, name=name)

    @layer
    def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
        self.validate_padding(padding)
        return tf.nn.max_pool(input,
                              ksize=[1, k_h, k_w, 1],
                              strides=[1, s_h, s_w, 1],
                              padding=padding,
                              name=name)

    @layer
    def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
        self.validate_padding(padding)
        return tf.nn.avg_pool(input,
                              ksize=[1, k_h, k_w, 1],
                              strides=[1, s_h, s_w, 1],
                              padding=padding,
                              name=name)

    @layer
    def lrn(self, input, radius, alpha, beta, name, bias=1.0):
        return tf.nn.local_response_normalization(input,
                                                  depth_radius=radius,
                                                  alpha=alpha,
                                                  beta=beta,
                                                  bias=bias,
                                                  name=name)

    @layer
    def concat(self, inputs, axis, name):
        return tf.concat(values=inputs, axis=axis, name=name)

    @layer
    def add(self, inputs, name):
        return tf.add_n(inputs, name=name)

    @layer
    def fc(self, input, num_out, name, relu=True):
        with tf.variable_scope(name) as scope:
            input_shape = input.get_shape()
            if input_shape.ndims == 4:
                # The input is spatial. Vectorize it first.
                dim = 1
                for d in input_shape[1:].as_list():
                    dim *= d
                feed_in = tf.reshape(input, [-1, dim])
            else:
                feed_in, dim = (input, input_shape[-1].value)
            weights = self.make_var('weights', shape=[dim, num_out])
            biases = self.make_var('biases', [num_out])
            op = tf.nn.relu_layer if relu else tf.nn.xw_plus_b
            fc = op(feed_in, weights, biases, name=scope.name)
            return fc

    @layer
    def softmax(self, input, name):
        # list() is needed in Python 3, where map() has no len()
        input_shape = list(map(lambda v: v.value, input.get_shape()))
        if len(input_shape) > 2:
            # For certain models (like NiN), the singleton spatial dimensions
            # need to be explicitly squeezed, since they're not broadcast-able
            # in TensorFlow's NHWC ordering (unlike Caffe's NCHW).
            if input_shape[1] == 1 and input_shape[2] == 1:
                input = tf.squeeze(input, squeeze_dims=[1, 2])
            else:
                raise ValueError('Rank 2 tensor input expected for softmax!')
        return tf.nn.softmax(input, name)

    @layer
    def batch_normalization(self, input, name, scale_offset=True, relu=False):
        # NOTE: Currently, only inference is supported
        with tf.variable_scope(name) as scope:
            shape = [input.get_shape()[-1]]
            if scale_offset:
                scale = self.make_var('scale', shape=shape)
                offset = self.make_var('offset', shape=shape)
            else:
                scale, offset = (None, None)
            output = tf.nn.batch_normalization(
                input,
                mean=self.make_var('mean', shape=shape),
                variance=self.make_var('variance', shape=shape),
                offset=offset,
                scale=scale,
                # TODO: This is the default Caffe batch norm eps
                # Get the actual eps from parameters
                variance_epsilon=1e-5,
                name=name)
            if relu:
                output = tf.nn.relu(output)
            return output

    @layer
    def dropout(self, input, keep_prob, name):
        keep = 1 - self.use_dropout + (self.use_dropout * keep_prob)
        return tf.nn.dropout(input, keep, name=name)
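  • The layer decorator pushes each op's output back onto self.terminals and returns self, so a subclass's setup() can chain layers fluently. A toy subclass for illustration only (not part of the original code):
class TinyNet(Network):
    def setup(self):
        # data -> conv1 -> pool1 -> fc1; each call feeds the next one
        (self.feed('data')
             .conv(3, 3, 16, 1, 1, name='conv1')
             .max_pool(2, 2, 2, 2, name='pool1')
             .fc(10, relu=False, name='fc1'))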
  • Building the GoogLeNet backbone:
class GoogLeNet(Network):
    def setup(self):
        (self.feed('data')
             .conv(7, 7, 64, 2, 2, name='conv1')
             .max_pool(3, 3, 2, 2, name='pool1')
             .lrn(2, 2e-05, 0.75, name='norm1')
             .conv(1, 1, 64, 1, 1, name='reduction2')
             .conv(3, 3, 192, 1, 1, name='conv2')
             .lrn(2, 2e-05, 0.75, name='norm2')
             .max_pool(3, 3, 2, 2, name='pool2')
             .conv(1, 1, 96, 1, 1, name='icp1_reduction1')
             .conv(3, 3, 128, 1, 1, name='icp1_out1'))

        (self.feed('pool2')
             .conv(1, 1, 16, 1, 1, name='icp1_reduction2')
             .conv(5, 5, 32, 1, 1, name='icp1_out2'))

        (self.feed('pool2')
             .max_pool(3, 3, 1, 1, name='icp1_pool')
             .conv(1, 1, 32, 1, 1, name='icp1_out3'))

        (self.feed('pool2')
             .conv(1, 1, 64, 1, 1, name='icp1_out0'))

        (self.feed('icp1_out0', 
                   'icp1_out1', 
                   'icp1_out2', 
                   'icp1_out3')
             .concat(3, name='icp2_in')
             .conv(1, 1, 128, 1, 1, name='icp2_reduction1')
             .conv(3, 3, 192, 1, 1, name='icp2_out1'))

        (self.feed('icp2_in')
             .conv(1, 1, 32, 1, 1, name='icp2_reduction2')
             .conv(5, 5, 96, 1, 1, name='icp2_out2'))

        (self.feed('icp2_in')
             .max_pool(3, 3, 1, 1, name='icp2_pool')
             .conv(1, 1, 64, 1, 1, name='icp2_out3'))

        (self.feed('icp2_in')
             .conv(1, 1, 128, 1, 1, name='icp2_out0'))

        (self.feed('icp2_out0', 
                   'icp2_out1', 
                   'icp2_out2', 
                   'icp2_out3')
             .concat(3, name='icp2_out')
             .max_pool(3, 3, 2, 2, name='icp3_in')
             .conv(1, 1, 96, 1, 1, name='icp3_reduction1')
             .conv(3, 3, 208, 1, 1, name='icp3_out1'))

        (self.feed('icp3_in')
             .conv(1, 1, 16, 1, 1, name='icp3_reduction2')
             .conv(5, 5, 48, 1, 1, name='icp3_out2'))

        (self.feed('icp3_in')
             .max_pool(3, 3, 1, 1, name='icp3_pool')
             .conv(1, 1, 64, 1, 1, name='icp3_out3'))

        (self.feed('icp3_in')
             .conv(1, 1, 192, 1, 1, name='icp3_out0'))

        (self.feed('icp3_out0', 
                   'icp3_out1', 
                   'icp3_out2', 
                   'icp3_out3')
             .concat(3, name='icp3_out')
             .avg_pool(5, 5, 3, 3, padding='VALID', name='cls1_pool')
             .conv(1, 1, 128, 1, 1, name='cls1_reduction_pose')
             .fc(1024, name='cls1_fc1_pose')
             .fc(3, relu=False, name='cls1_fc_pose_xyz'))

        (self.feed('cls1_fc1_pose')
             .fc(4, relu=False, name='cls1_fc_pose_wpqr'))

        (self.feed('icp3_out')
             .conv(1, 1, 112, 1, 1, name='icp4_reduction1')
             .conv(3, 3, 224, 1, 1, name='icp4_out1'))

        (self.feed('icp3_out')
             .conv(1, 1, 24, 1, 1, name='icp4_reduction2')
             .conv(5, 5, 64, 1, 1, name='icp4_out2'))

        (self.feed('icp3_out')
             .max_pool(3, 3, 1, 1, name='icp4_pool')
             .conv(1, 1, 64, 1, 1, name='icp4_out3'))

        (self.feed('icp3_out')
             .conv(1, 1, 160, 1, 1, name='icp4_out0'))

        (self.feed('icp4_out0', 
                   'icp4_out1', 
                   'icp4_out2', 
                   'icp4_out3')
             .concat(3, name='icp4_out')
             .conv(1, 1, 128, 1, 1, name='icp5_reduction1')
             .conv(3, 3, 256, 1, 1, name='icp5_out1'))

        (self.feed('icp4_out')
             .conv(1, 1, 24, 1, 1, name='icp5_reduction2')
             .conv(5, 5, 64, 1, 1, name='icp5_out2'))

        (self.feed('icp4_out')
             .max_pool(3, 3, 1, 1, name='icp5_pool')
             .conv(1, 1, 64, 1, 1, name='icp5_out3'))

        (self.feed('icp4_out')
             .conv(1, 1, 128, 1, 1, name='icp5_out0'))

        (self.feed('icp5_out0', 
                   'icp5_out1', 
                   'icp5_out2', 
                   'icp5_out3')
             .concat(3, name='icp5_out')
             .conv(1, 1, 144, 1, 1, name='icp6_reduction1')
             .conv(3, 3, 288, 1, 1, name='icp6_out1'))

        (self.feed('icp5_out')
             .conv(1, 1, 32, 1, 1, name='icp6_reduction2')
             .conv(5, 5, 64, 1, 1, name='icp6_out2'))

        (self.feed('icp5_out')
             .max_pool(3, 3, 1, 1, name='icp6_pool')
             .conv(1, 1, 64, 1, 1, name='icp6_out3'))

        (self.feed('icp5_out')
             .conv(1, 1, 112, 1, 1, name='icp6_out0'))

        (self.feed('icp6_out0', 
                   'icp6_out1', 
                   'icp6_out2', 
                   'icp6_out3')
             .concat(3, name='icp6_out')
             .avg_pool(5, 5, 3, 3, padding='VALID', name='cls2_pool')
             .conv(1, 1, 128, 1, 1, name='cls2_reduction_pose')
             .fc(1024, name='cls2_fc1')
             .fc(3, relu=False, name='cls2_fc_pose_xyz'))

        (self.feed('cls2_fc1')
             .fc(4, relu=False, name='cls2_fc_pose_wpqr'))

        (self.feed('icp6_out')
             .conv(1, 1, 160, 1, 1, name='icp7_reduction1')
             .conv(3, 3, 320, 1, 1, name='icp7_out1'))

        (self.feed('icp6_out')
             .conv(1, 1, 32, 1, 1, name='icp7_reduction2')
             .conv(5, 5, 128, 1, 1, name='icp7_out2'))

        (self.feed('icp6_out')
             .max_pool(3, 3, 1, 1, name='icp7_pool')
             .conv(1, 1, 128, 1, 1, name='icp7_out3'))

        (self.feed('icp6_out')
             .conv(1, 1, 256, 1, 1, name='icp7_out0'))

        (self.feed('icp7_out0', 
                   'icp7_out1', 
                   'icp7_out2', 
                   'icp7_out3')
             .concat(3, name='icp7_out')
             .max_pool(3, 3, 2, 2, name='icp8_in')
             .conv(1, 1, 160, 1, 1, name='icp8_reduction1')
             .conv(3, 3, 320, 1, 1, name='icp8_out1'))

        (self.feed('icp8_in')
             .conv(1, 1, 32, 1, 1, name='icp8_reduction2')
             .conv(5, 5, 128, 1, 1, name='icp8_out2'))

        (self.feed('icp8_in')
             .max_pool(3, 3, 1, 1, name='icp8_pool')
             .conv(1, 1, 128, 1, 1, name='icp8_out3'))

        (self.feed('icp8_in')
             .conv(1, 1, 256, 1, 1, name='icp8_out0'))

        (self.feed('icp8_out0', 
                   'icp8_out1', 
                   'icp8_out2', 
                   'icp8_out3')
             .concat(3, name='icp8_out')
             .conv(1, 1, 192, 1, 1, name='icp9_reduction1')
             .conv(3, 3, 384, 1, 1, name='icp9_out1'))

        (self.feed('icp8_out')
             .conv(1, 1, 48, 1, 1, name='icp9_reduction2')
             .conv(5, 5, 128, 1, 1, name='icp9_out2'))

        (self.feed('icp8_out')
             .max_pool(3, 3, 1, 1, name='icp9_pool')
             .conv(1, 1, 128, 1, 1, name='icp9_out3'))

        (self.feed('icp8_out')
             .conv(1, 1, 384, 1, 1, name='icp9_out0'))

        (self.feed('icp9_out0', 
                   'icp9_out1', 
                   'icp9_out2', 
                   'icp9_out3')
             .concat(3, name='icp9_out')
             .avg_pool(7, 7, 1, 1, padding='VALID', name='cls3_pool')
             .fc(2048, name='cls3_fc1_pose')
             .fc(3, relu=False, name='cls3_fc_pose_xyz'))

        (self.feed('cls3_fc1_pose')
             .fc(4, relu=False, name='cls3_fc_pose_wpqr'))
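  • Like the original GoogLeNet, the model has three classifier heads (cls1, cls2, cls3); here each ends in a 3-unit xyz branch (translation) and a 4-unit wpqr branch (quaternion). A quick shape check (a sketch that builds a throwaway graph):
inputs = tf.placeholder(tf.float32, [1, 224, 224, 3])
net = GoogLeNet({'data': inputs})
for head in ('cls1', 'cls2', 'cls3'):
    print(head,
          net.layers[head + '_fc_pose_xyz'].shape,    # (1, 3)
          net.layers[head + '_fc_pose_wpqr'].shape)   # (1, 4)
tf.reset_default_graph()   # discard the throwaway graph before the real build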

5. Image preprocessing

  • PoseNet takes 224×224 inputs, and since this task is sensitive to image scale and viewpoint, naive resizing is a poor fit; a center crop is used instead. The center-crop function:
def centeredCrop(img, output_side_length):
    height, width, depth = img.shape
    new_height = output_side_length
    new_width = output_side_length
    if height > width:
        new_height = output_side_length * height / width
    else:
        new_width = output_side_length * width / height
    height_offset = (new_height - output_side_length) / 2
    width_offset = (new_width - output_side_length) / 2
    # int() casts are required in Python 3, where / yields floats
    cropped_img = img[int(height_offset):int(height_offset + output_side_length),
                        int(width_offset):int(width_offset + output_side_length)]
    return cropped_img
  • The preprocessing function below applies the center crop, subtracts the per-channel dataset mean from every image, and performs the dimension transposes needed for the network input:
def preprocess(images):
    images_out = [] #final result
    #Resize and crop and compute mean!
    images_cropped = []
    for i in tqdm(range(len(images))):
        #print(images[i])
        X = cv2.imread(images[i])
        #X = cv2.resize(X, (455, 256))
        X = centeredCrop(X, 224)
        images_cropped.append(X)
    #compute images mean
    N = 0
    mean = np.zeros((1, 3, 224, 224))
    for X in tqdm(images_cropped):
        X = np.transpose(X,(2,0,1))
        #print(X.shape)#3,224,224
        #print(X[0,:,:].shape)#3,224
        #print(mean[0][0].shape)#224,224
        mean[0][0] += X[0,:,:]
        mean[0][1] += X[1,:,:]
        mean[0][2] += X[2,:,:]
        N += 1
    mean[0] /= N
    #Subtract mean from all images
    for X in tqdm(images_cropped):
        X = np.transpose(X,(2,0,1))
        X = X - mean
        X = np.squeeze(X)
        X = np.transpose(X, (1,2,0))
        images_out.append(X)
    return images_out
  • When debugging, for a quick check you can avoid processing the whole image list, e.g. by capping the loop bound with len(images)*0+2, as shown below.
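    The loop header in preprocess then becomes:
for i in tqdm(range(len(images)*0+2)):   # debug: read only the first 2 images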




6. Data loading

  • Basic configuration (my_train_labels is the path written out by the file processing in section 3):
batch_size = 75
max_iterations = 30000
# Set this path to your dataset directory
my_train_labels = '/kaggle/working/my_train_labels.csv'
  • A datasource class that stores images and poses as pairs:
class datasource(object):
    def __init__(self, images, poses):
        self.images = images
        self.poses = poses
  • Yield one sample at a time, in shuffled order:
def gen_data(source):
    while True:
        indices = list(range(len(source.images)))
        random.shuffle(indices)
        for i in indices:
            image = source.images[i]
            pose_x = source.poses[i][0:3]
            pose_q = source.poses[i][3:7]
            yield image, pose_x, pose_q
  • Batch samples together:
def gen_data_batch(source):
    data_gen = gen_data(source)
    while True:
        image_batch = []
        pose_x_batch = []
        pose_q_batch = []
        for _ in range(batch_size):
            image, pose_x, pose_q = next(data_gen)
            image_batch.append(image)
            pose_x_batch.append(pose_x)
            pose_q_batch.append(pose_q)
        yield np.array(image_batch), np.array(pose_x_batch), np.array(pose_q_batch)
  • The final data-loading function (the hard-coded path prefix must match where the training images live; with the directory structure from section 1, it is the folder containing train_labels.csv):
def get_data():
    poses = []
    images = []
    for i in pd.read_csv(my_train_labels).itertuples():
        # i[4]: image_path
        # i[6]: translation vector 'x;y;z'
        # i[7]: quaternion tuple string '(w, x, y, z)'
        p0,p1,p2 = i[6].split(';')
        p3,p4,p5,p6 = i[7].split('(')[1].split(')')[0].split(',')
        p0 = float(p0)
        p1 = float(p1)
        p2 = float(p2)
        p3 = float(p3)
        p4 = float(p4)
        p5 = float(p5)
        p6 = float(p6)
        poses.append((p0,p1,p2,p3,p4,p5,p6))
        images.append('/kaggle/input/image-matching-challenge-2023/train/' + i[4])
        #print(poses,images)
    images = preprocess(images)
    return datasource(images, poses)
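  • Pulling one batch to sanity-check shapes (a sketch; it requires the CSV from section 3 and the Kaggle images to be present):
ds = get_data()
gen = gen_data_batch(ds)
imgs, xs, qs = next(gen)
print(imgs.shape, xs.shape, qs.shape)   # (75, 224, 224, 3) (75, 3) (75, 4)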

7. Preparing the network and data placeholders

images = tf.placeholder(tf.float32, [batch_size, 224, 224, 3])
poses_x = tf.placeholder(tf.float32, [batch_size, 3])
poses_q = tf.placeholder(tf.float32, [batch_size, 4])
datasource = get_data()

net = GoogLeNet({'data': images})

p1_x = net.layers['cls1_fc_pose_xyz']
p1_q = net.layers['cls1_fc_pose_wpqr']
p2_x = net.layers['cls2_fc_pose_xyz']
p2_q = net.layers['cls2_fc_pose_wpqr']
p3_x = net.layers['cls3_fc_pose_xyz']
p3_q = net.layers['cls3_fc_pose_wpqr']

l1_x = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p1_x, poses_x)))) * 0.3
l1_q = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p1_q, poses_q)))) * 150
l2_x = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p2_x, poses_x)))) * 0.3
l2_q = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p2_q, poses_q)))) * 150
l3_x = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p3_x, poses_x)))) * 1
l3_q = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p3_q, poses_q)))) * 500

loss = l1_x + l1_q + l2_x + l2_q + l3_x + l3_q
opt = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.9, beta2=0.999, epsilon=0.00000001, use_locking=False, name='Adam').minimize(loss)

# Set GPU options
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6833)

init = tf.global_variables_initializer()
saver = tf.train.Saver()
outputFile = "PoseNet.ckpt"
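  • The six terms implement PoseNet's weighted loss, roughly L = ||x_hat - x||_2 + beta * ||q_hat - q||_2 per head: the two auxiliary heads use position weight 0.3 with beta = 150, while the final head uses position weight 1 with beta = 500.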
  • Running this cell a second time raises an error, because the graph has already been built in memory.
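  • A common workaround with the TF1-compat API is to reset the default graph before rebuilding (a sketch, not part of the original notebook):
tf.reset_default_graph()   # drops all previously created ops and variables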

8. Training

  • The code below runs in both CPU and GPU environments; it prints the loss every 20 iterations and saves the weights every 500 iterations.
  • The pretrained weights, converted from the official Caffe release, are loaded from '/kaggle/input/tensorflow-posenet-master/tensorflow-posenet-master/posenet.npy' (download link in the original post).
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    # Load the data
    sess.run(init)
    net.load('/kaggle/input/tensorflow-posenet-master/tensorflow-posenet-master/posenet.npy', sess)

    data_gen = gen_data_batch(datasource)
    for i in range(max_iterations):
        np_images, np_poses_x, np_poses_q = next(data_gen)
        feed = {images: np_images, poses_x: np_poses_x, poses_q: np_poses_q}

        sess.run(opt, feed_dict=feed)
        np_loss = sess.run(loss, feed_dict=feed)
        if i % 20 == 0:
            print("iteration: " + str(i) + "\n\t" + "Loss is: " + str(np_loss))
        if i % 500 == 0:
            saver.save(sess, outputFile)
            print("Intermediate file saved at: " + outputFile)
    saver.save(sess, outputFile)
    print("Intermediate file saved at: " + outputFile)
  • The results look as follows:
    (Figure: training-log screenshot from the original post.)

9. Complete code

  • Full notebook download link
  • TensorFlow weights file download link
  • The complete code:
import pandas as pd
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import random
import cv2
from tqdm import tqdm

change_train_labels = 1
def matrix2quaternion(m):
    #m:array
    w = ((np.trace(m) + 1) ** 0.5) / 2
    x = (m[2][1] - m[1][2]) / (4 * w)
    y = (m[0][2] - m[2][0]) / (4 * w)
    z = (m[1][0] - m[0][1]) / (4 * w)
    return w,x,y,z
def quaternion2matrix(q):
    #q:list
    w,x,y,z = q
    return np.array([[1-2*y*y-2*z*z, 2*x*y-2*z*w, 2*x*z+2*y*w],
             [2*x*y+2*z*w, 1-2*x*x-2*z*z, 2*y*z-2*x*w],
             [2*x*z-2*y*w, 2*y*z+2*x*w, 1-2*x*x-2*y*y]])

def m(a):
    a = a.split(';')
    a = [float(i) for i in a]
    A = np.array([[a[0],a[1],a[2]],
                [a[3],a[4],a[5]],
                [a[6],a[7],a[8]]])
    return matrix2quaternion(A)

if change_train_labels:
    train_labels = pd.read_csv('/kaggle/input/image-matching-challenge-2023/train/train_labels.csv')
    train_labels['rotation_matrix_quaternion'] = [i for i in map(m,train_labels['rotation_matrix'])]
    train_labels.to_csv('/kaggle/working/my_train_labels.csv')

DEFAULT_PADDING = 'SAME'


def layer(op):
    '''Decorator for composable network layers.'''

    def layer_decorated(self, *args, **kwargs):
        # Automatically set a name if not provided.
        name = kwargs.setdefault('name', self.get_unique_name(op.__name__))
        # Figure out the layer inputs.
        if len(self.terminals) == 0:
            raise RuntimeError('No input variables found for layer %s.' % name)
        elif len(self.terminals) == 1:
            layer_input = self.terminals[0]
        else:
            layer_input = list(self.terminals)
        # Perform the operation and get the output.
        layer_output = op(self, layer_input, *args, **kwargs)
        # Add to layer LUT.
        self.layers[name] = layer_output
        # This output is now the input for the next layer.
        self.feed(layer_output)
        # Return self for chained calls.
        return self

    return layer_decorated


class Network(object):

    def __init__(self, inputs, trainable=True):
        # The input nodes for this network
        self.inputs = inputs
        # The current list of terminal nodes
        self.terminals = []
        # Mapping from layer names to layers
        self.layers = dict(inputs)
        # If true, the resulting variables are set as trainable
        self.trainable = trainable
        # Switch variable for dropout
        self.use_dropout = tf.placeholder_with_default(tf.constant(1.0),
                                                       shape=[],
                                                       name='use_dropout')
        self.setup()

    def setup(self):
        '''Construct the network. '''
        raise NotImplementedError('Must be implemented by the subclass.')

    def load(self, data_path, session, ignore_missing=False):
        '''Load network weights.
        data_path: The path to the numpy-serialized network weights
        session: The current TensorFlow session
        ignore_missing: If true, serialized weights for missing layers are ignored.
        '''
        data_dict = np.load(data_path,allow_pickle=True,encoding="latin1").item()
        for op_name in data_dict:
            with tf.variable_scope(op_name, reuse=True):
                for param_name, data in data_dict[op_name].items():
                    try:
                        var = tf.get_variable(param_name)
                        session.run(var.assign(data))
                    except ValueError:
                        if not ignore_missing:
                            raise

    def feed(self, *args):
        '''Set the input(s) for the next operation by replacing the terminal nodes.
        The arguments can be either layer names or the actual layers.
        '''
        assert len(args) != 0
        self.terminals = []
        for fed_layer in args:
            if isinstance(fed_layer, str):
                try:
                    fed_layer = self.layers[fed_layer]
                except KeyError:
                    raise KeyError('Unknown layer name fed: %s' % fed_layer)
            self.terminals.append(fed_layer)
        return self

    def get_output(self):
        '''Returns the current network output.'''
        return self.terminals[-1]

    def get_unique_name(self, prefix):
        '''Returns an index-suffixed unique name for the given prefix.
        This is used for auto-generating layer names based on the type-prefix.
        '''
        ident = sum(t.startswith(prefix) for t, _ in self.layers.items()) + 1
        return '%s_%d' % (prefix, ident)

    def make_var(self, name, shape):
        '''Creates a new TensorFlow variable.'''
        return tf.get_variable(name, shape, trainable=self.trainable)

    def validate_padding(self, padding):
        '''Verifies that the padding is one of the supported ones.'''
        assert padding in ('SAME', 'VALID')

    @layer
    def conv(self,
             input,
             k_h,
             k_w,
             c_o,
             s_h,
             s_w,
             name,
             relu=True,
             padding=DEFAULT_PADDING,
             group=1,
             biased=True):
        # Verify that the padding is acceptable
        self.validate_padding(padding)
        # Get the number of channels in the input
        c_i = input.get_shape()[-1]
        # Verify that the grouping parameter is valid
        assert c_i % group == 0
        assert c_o % group == 0
        # Convolution for a given input and kernel
        convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding)
        with tf.variable_scope(name) as scope:
            kernel = self.make_var('weights', shape=[k_h, k_w, int(int(c_i) / group), c_o])
            if group == 1:
                # This is the common-case. Convolve the input without any further complications.
                output = convolve(input, kernel)
            else:
                # Split the input into groups and then convolve each of them independently
                # (updated to the TF1 argument order: tf.split(value, num, axis))
                input_groups = tf.split(input, group, axis=3)
                kernel_groups = tf.split(kernel, group, axis=3)
                output_groups = [convolve(i, k) for i, k in zip(input_groups, kernel_groups)]
                # Concatenate the groups
                output = tf.concat(output_groups, axis=3)
            # Add the biases
            if biased:
                biases = self.make_var('biases', [c_o])
                output = tf.nn.bias_add(output, biases)
            if relu:
                # ReLU non-linearity
                output = tf.nn.relu(output, name=scope.name)
            return output

    @layer
    def relu(self, input, name):
        return tf.nn.relu(input, name=name)

    @layer
    def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
        self.validate_padding(padding)
        return tf.nn.max_pool(input,
                              ksize=[1, k_h, k_w, 1],
                              strides=[1, s_h, s_w, 1],
                              padding=padding,
                              name=name)

    @layer
    def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
        self.validate_padding(padding)
        return tf.nn.avg_pool(input,
                              ksize=[1, k_h, k_w, 1],
                              strides=[1, s_h, s_w, 1],
                              padding=padding,
                              name=name)

    @layer
    def lrn(self, input, radius, alpha, beta, name, bias=1.0):
        return tf.nn.local_response_normalization(input,
                                                  depth_radius=radius,
                                                  alpha=alpha,
                                                  beta=beta,
                                                  bias=bias,
                                                  name=name)

    @layer
    def concat(self, inputs, axis, name):
        return tf.concat(values=inputs, axis=axis, name=name)

    @layer
    def add(self, inputs, name):
        return tf.add_n(inputs, name=name)

    @layer
    def fc(self, input, num_out, name, relu=True):
        with tf.variable_scope(name) as scope:
            input_shape = input.get_shape()
            if input_shape.ndims == 4:
                # The input is spatial. Vectorize it first.
                dim = 1
                for d in input_shape[1:].as_list():
                    dim *= d
                feed_in = tf.reshape(input, [-1, dim])
            else:
                feed_in, dim = (input, input_shape[-1].value)
            weights = self.make_var('weights', shape=[dim, num_out])
            biases = self.make_var('biases', [num_out])
            op = tf.nn.relu_layer if relu else tf.nn.xw_plus_b
            fc = op(feed_in, weights, biases, name=scope.name)
            return fc

    @layer
    def softmax(self, input, name):
        # list() is needed in Python 3, where map() has no len()
        input_shape = list(map(lambda v: v.value, input.get_shape()))
        if len(input_shape) > 2:
            # For certain models (like NiN), the singleton spatial dimensions
            # need to be explicitly squeezed, since they're not broadcast-able
            # in TensorFlow's NHWC ordering (unlike Caffe's NCHW).
            if input_shape[1] == 1 and input_shape[2] == 1:
                input = tf.squeeze(input, squeeze_dims=[1, 2])
            else:
                raise ValueError('Rank 2 tensor input expected for softmax!')
        return tf.nn.softmax(input, name)

    @layer
    def batch_normalization(self, input, name, scale_offset=True, relu=False):
        # NOTE: Currently, only inference is supported
        with tf.variable_scope(name) as scope:
            shape = [input.get_shape()[-1]]
            if scale_offset:
                scale = self.make_var('scale', shape=shape)
                offset = self.make_var('offset', shape=shape)
            else:
                scale, offset = (None, None)
            output = tf.nn.batch_normalization(
                input,
                mean=self.make_var('mean', shape=shape),
                variance=self.make_var('variance', shape=shape),
                offset=offset,
                scale=scale,
                # TODO: This is the default Caffe batch norm eps
                # Get the actual eps from parameters
                variance_epsilon=1e-5,
                name=name)
            if relu:
                output = tf.nn.relu(output)
            return output

    @layer
    def dropout(self, input, keep_prob, name):
        keep = 1 - self.use_dropout + (self.use_dropout * keep_prob)
        return tf.nn.dropout(input, keep, name=name)

def centeredCrop(img, output_side_length):
    height, width, depth = img.shape
    new_height = output_side_length
    new_width = output_side_length
    if height > width:
        new_height = output_side_length * height / width
    else:
        new_width = output_side_length * width / height
    height_offset = (new_height - output_side_length) / 2
    width_offset = (new_width - output_side_length) / 2
    cropped_img = img[int(height_offset):int(height_offset + output_side_length),
                        int(width_offset):int(width_offset + output_side_length)]
    return cropped_img
def preprocess(images):
    images_out = [] #final result
    #Resize and crop and compute mean!
    images_cropped = []
    for i in tqdm(range(len(images))):
        #print(images[i])
        X = cv2.imread(images[i])
        #X = cv2.resize(X, (455, 256))
        X = centeredCrop(X, 224)
        images_cropped.append(X)
    #compute images mean
    N = 0
    mean = np.zeros((1, 3, 224, 224))
    for X in tqdm(images_cropped):
        X = np.transpose(X,(2,0,1))
        #print(X.shape)#3,224,224
        #print(X[0,:,:].shape)#3,224
        #print(mean[0][0].shape)#224,224
        mean[0][0] += X[0,:,:]
        mean[0][1] += X[1,:,:]
        mean[0][2] += X[2,:,:]
        N += 1
    mean[0] /= N
    #Subtract mean from all images
    for X in tqdm(images_cropped):
        X = np.transpose(X,(2,0,1))
        X = X - mean
        X = np.squeeze(X)
        X = np.transpose(X, (1,2,0))
        images_out.append(X)
    return images_out

class GoogLeNet(Network):
    def setup(self):
        (self.feed('data')
             .conv(7, 7, 64, 2, 2, name='conv1')
             .max_pool(3, 3, 2, 2, name='pool1')
             .lrn(2, 2e-05, 0.75, name='norm1')
             .conv(1, 1, 64, 1, 1, name='reduction2')
             .conv(3, 3, 192, 1, 1, name='conv2')
             .lrn(2, 2e-05, 0.75, name='norm2')
             .max_pool(3, 3, 2, 2, name='pool2')
             .conv(1, 1, 96, 1, 1, name='icp1_reduction1')
             .conv(3, 3, 128, 1, 1, name='icp1_out1'))

        (self.feed('pool2')
             .conv(1, 1, 16, 1, 1, name='icp1_reduction2')
             .conv(5, 5, 32, 1, 1, name='icp1_out2'))

        (self.feed('pool2')
             .max_pool(3, 3, 1, 1, name='icp1_pool')
             .conv(1, 1, 32, 1, 1, name='icp1_out3'))

        (self.feed('pool2')
             .conv(1, 1, 64, 1, 1, name='icp1_out0'))

        (self.feed('icp1_out0', 
                   'icp1_out1', 
                   'icp1_out2', 
                   'icp1_out3')
             .concat(3, name='icp2_in')
             .conv(1, 1, 128, 1, 1, name='icp2_reduction1')
             .conv(3, 3, 192, 1, 1, name='icp2_out1'))

        (self.feed('icp2_in')
             .conv(1, 1, 32, 1, 1, name='icp2_reduction2')
             .conv(5, 5, 96, 1, 1, name='icp2_out2'))

        (self.feed('icp2_in')
             .max_pool(3, 3, 1, 1, name='icp2_pool')
             .conv(1, 1, 64, 1, 1, name='icp2_out3'))

        (self.feed('icp2_in')
             .conv(1, 1, 128, 1, 1, name='icp2_out0'))

        (self.feed('icp2_out0', 
                   'icp2_out1', 
                   'icp2_out2', 
                   'icp2_out3')
             .concat(3, name='icp2_out')
             .max_pool(3, 3, 2, 2, name='icp3_in')
             .conv(1, 1, 96, 1, 1, name='icp3_reduction1')
             .conv(3, 3, 208, 1, 1, name='icp3_out1'))

        (self.feed('icp3_in')
             .conv(1, 1, 16, 1, 1, name='icp3_reduction2')
             .conv(5, 5, 48, 1, 1, name='icp3_out2'))

        (self.feed('icp3_in')
             .max_pool(3, 3, 1, 1, name='icp3_pool')
             .conv(1, 1, 64, 1, 1, name='icp3_out3'))

        (self.feed('icp3_in')
             .conv(1, 1, 192, 1, 1, name='icp3_out0'))

        (self.feed('icp3_out0', 
                   'icp3_out1', 
                   'icp3_out2', 
                   'icp3_out3')
             .concat(3, name='icp3_out')
             .avg_pool(5, 5, 3, 3, padding='VALID', name='cls1_pool')
             .conv(1, 1, 128, 1, 1, name='cls1_reduction_pose')
             .fc(1024, name='cls1_fc1_pose')
             .fc(3, relu=False, name='cls1_fc_pose_xyz'))

        (self.feed('cls1_fc1_pose')
             .fc(4, relu=False, name='cls1_fc_pose_wpqr'))

        (self.feed('icp3_out')
             .conv(1, 1, 112, 1, 1, name='icp4_reduction1')
             .conv(3, 3, 224, 1, 1, name='icp4_out1'))

        (self.feed('icp3_out')
             .conv(1, 1, 24, 1, 1, name='icp4_reduction2')
             .conv(5, 5, 64, 1, 1, name='icp4_out2'))

        (self.feed('icp3_out')
             .max_pool(3, 3, 1, 1, name='icp4_pool')
             .conv(1, 1, 64, 1, 1, name='icp4_out3'))

        (self.feed('icp3_out')
             .conv(1, 1, 160, 1, 1, name='icp4_out0'))

        (self.feed('icp4_out0', 
                   'icp4_out1', 
                   'icp4_out2', 
                   'icp4_out3')
             .concat(3, name='icp4_out')
             .conv(1, 1, 128, 1, 1, name='icp5_reduction1')
             .conv(3, 3, 256, 1, 1, name='icp5_out1'))

        (self.feed('icp4_out')
             .conv(1, 1, 24, 1, 1, name='icp5_reduction2')
             .conv(5, 5, 64, 1, 1, name='icp5_out2'))

        (self.feed('icp4_out')
             .max_pool(3, 3, 1, 1, name='icp5_pool')
             .conv(1, 1, 64, 1, 1, name='icp5_out3'))

        (self.feed('icp4_out')
             .conv(1, 1, 128, 1, 1, name='icp5_out0'))

        (self.feed('icp5_out0', 
                   'icp5_out1', 
                   'icp5_out2', 
                   'icp5_out3')
             .concat(3, name='icp5_out')
             .conv(1, 1, 144, 1, 1, name='icp6_reduction1')
             .conv(3, 3, 288, 1, 1, name='icp6_out1'))

        (self.feed('icp5_out')
             .conv(1, 1, 32, 1, 1, name='icp6_reduction2')
             .conv(5, 5, 64, 1, 1, name='icp6_out2'))

        (self.feed('icp5_out')
             .max_pool(3, 3, 1, 1, name='icp6_pool')
             .conv(1, 1, 64, 1, 1, name='icp6_out3'))

        (self.feed('icp5_out')
             .conv(1, 1, 112, 1, 1, name='icp6_out0'))

        (self.feed('icp6_out0', 
                   'icp6_out1', 
                   'icp6_out2', 
                   'icp6_out3')
             .concat(3, name='icp6_out')
             .avg_pool(5, 5, 3, 3, padding='VALID', name='cls2_pool')
             .conv(1, 1, 128, 1, 1, name='cls2_reduction_pose')
             .fc(1024, name='cls2_fc1')
             .fc(3, relu=False, name='cls2_fc_pose_xyz'))

        (self.feed('cls2_fc1')
             .fc(4, relu=False, name='cls2_fc_pose_wpqr'))

        (self.feed('icp6_out')
             .conv(1, 1, 160, 1, 1, name='icp7_reduction1')
             .conv(3, 3, 320, 1, 1, name='icp7_out1'))

        (self.feed('icp6_out')
             .conv(1, 1, 32, 1, 1, name='icp7_reduction2')
             .conv(5, 5, 128, 1, 1, name='icp7_out2'))

        (self.feed('icp6_out')
             .max_pool(3, 3, 1, 1, name='icp7_pool')
             .conv(1, 1, 128, 1, 1, name='icp7_out3'))

        (self.feed('icp6_out')
             .conv(1, 1, 256, 1, 1, name='icp7_out0'))

        (self.feed('icp7_out0', 
                   'icp7_out1', 
                   'icp7_out2', 
                   'icp7_out3')
             .concat(3, name='icp7_out')
             .max_pool(3, 3, 2, 2, name='icp8_in')
             .conv(1, 1, 160, 1, 1, name='icp8_reduction1')
             .conv(3, 3, 320, 1, 1, name='icp8_out1'))

        (self.feed('icp8_in')
             .conv(1, 1, 32, 1, 1, name='icp8_reduction2')
             .conv(5, 5, 128, 1, 1, name='icp8_out2'))

        (self.feed('icp8_in')
             .max_pool(3, 3, 1, 1, name='icp8_pool')
             .conv(1, 1, 128, 1, 1, name='icp8_out3'))

        (self.feed('icp8_in')
             .conv(1, 1, 256, 1, 1, name='icp8_out0'))

        (self.feed('icp8_out0', 
                   'icp8_out1', 
                   'icp8_out2', 
                   'icp8_out3')
             .concat(3, name='icp8_out')
             .conv(1, 1, 192, 1, 1, name='icp9_reduction1')
             .conv(3, 3, 384, 1, 1, name='icp9_out1'))

        (self.feed('icp8_out')
             .conv(1, 1, 48, 1, 1, name='icp9_reduction2')
             .conv(5, 5, 128, 1, 1, name='icp9_out2'))

        (self.feed('icp8_out')
             .max_pool(3, 3, 1, 1, name='icp9_pool')
             .conv(1, 1, 128, 1, 1, name='icp9_out3'))

        (self.feed('icp8_out')
             .conv(1, 1, 384, 1, 1, name='icp9_out0'))

        (self.feed('icp9_out0', 
                   'icp9_out1', 
                   'icp9_out2', 
                   'icp9_out3')
             .concat(3, name='icp9_out')
             .avg_pool(7, 7, 1, 1, padding='VALID', name='cls3_pool')
             .fc(2048, name='cls3_fc1_pose')
             .fc(3, relu=False, name='cls3_fc_pose_xyz'))

        (self.feed('cls3_fc1_pose')
             .fc(4, relu=False, name='cls3_fc_pose_wpqr'))

batch_size = 75
max_iterations = 30000
# Set this path to your dataset directory
# (directory/dataset are unused leftovers from the original KingsCollege demo)
directory = 'path_to_datasets/KingsCollege/'
dataset = 'dataset_train.txt'
my_train_labels = '/kaggle/working/my_train_labels.csv'

class datasource(object):
    def __init__(self, images, poses):
        self.images = images
        self.poses = poses

def centeredCrop(img, output_side_length):
    height, width, depth = img.shape
    new_height = output_side_length
    new_width = output_side_length
    if height > width:
        new_height = output_side_length * height / width
    else:
        new_width = output_side_length * width / height
    height_offset = (new_height - output_side_length) / 2
    width_offset = (new_width - output_side_length) / 2
    cropped_img = img[int(height_offset):int(height_offset + output_side_length),
                        int(width_offset):int(width_offset + output_side_length)]
    return cropped_img

def gen_data(source):
    # Yield (image, translation, quaternion) samples indefinitely,
    # reshuffling the visiting order at the start of each pass.
    while True:
        indices = list(range(len(source.images)))
        random.shuffle(indices)
        for i in indices:
            image = source.images[i]
            pose_x = source.poses[i][0:3]   # x, y, z
            pose_q = source.poses[i][3:7]   # w, p, q, r
            yield image, pose_x, pose_q

def gen_data_batch(source):
    data_gen = gen_data(source)
    while True:
        image_batch = []
        pose_x_batch = []
        pose_q_batch = []
        for _ in range(batch_size):
            image, pose_x, pose_q = next(data_gen)
            image_batch.append(image)
            pose_x_batch.append(pose_x)
            pose_q_batch.append(pose_q)
        yield np.array(image_batch), np.array(pose_x_batch), np.array(pose_q_batch)

def get_data():
    poses = []
    images = []
    for row in pd.read_csv(my_train_labels).itertuples():
        # row[4]: image_path
        # row[6]: translation vector, stored as "x;y;z"
        # row[7]: quaternion, stored as the string "(w, x, y, z)"
        p0, p1, p2 = (float(v) for v in row[6].split(';'))
        p3, p4, p5, p6 = (float(v) for v in row[7].split('(')[1].split(')')[0].split(','))
        poses.append((p0, p1, p2, p3, p4, p5, p6))
        images.append('/kaggle/input/image-matching-challenge-2023/train/' + row[4])
    images = preprocess(images)
    return datasource(images, poses)
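
# Optional sanity check (assumes the definitions above have been executed):
# draw one batch and confirm shapes before committing to a long training run.
#   imgs, xs, qs = next(gen_data_batch(get_data()))
#   print(imgs.shape, xs.shape, qs.shape)  # (75, 224, 224, 3), (75, 3), (75, 4)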


images = tf.placeholder(tf.float32, [batch_size, 224, 224, 3])
poses_x = tf.placeholder(tf.float32, [batch_size, 3])
poses_q = tf.placeholder(tf.float32, [batch_size, 4])
datasource = get_data()  # note: shadows the datasource class defined above

net = GoogLeNet({'data': images})

p1_x = net.layers['cls1_fc_pose_xyz']
p1_q = net.layers['cls1_fc_pose_wpqr']
p2_x = net.layers['cls2_fc_pose_xyz']
p2_q = net.layers['cls2_fc_pose_wpqr']
p3_x = net.layers['cls3_fc_pose_xyz']
p3_q = net.layers['cls3_fc_pose_wpqr']

# PoseNet loss terms: Euclidean distance on the translation (x) and on the raw,
# unnormalised quaternion (q) for each of the three heads. The two auxiliary
# heads are down-weighted by 0.3; the factors 150/500 play the role of the
# paper's beta, trading rotation error off against translation error.
l1_x = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p1_x, poses_x)))) * 0.3
l1_q = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p1_q, poses_q)))) * 150
l2_x = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p2_x, poses_x)))) * 0.3
l2_q = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p2_q, poses_q)))) * 150
l3_x = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p3_x, poses_x)))) * 1
l3_q = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p3_q, poses_q)))) * 500

loss = l1_x + l1_q + l2_x + l2_q + l3_x + l3_q
opt = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.9, beta2=0.999, epsilon=0.00000001, use_locking=False, name='Adam').minimize(loss)
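
# Side note (not part of the original script): the L2 distance on quaternions
# above is the training objective; for *reporting* rotation accuracy, the usual
# PoseNet metric is the geodesic angle between predicted and true rotations.
def quaternion_angle_deg(q_pred, q_true):
    # Geodesic angle (degrees) between two rotations given as (w, x, y, z)
    # quaternions. The network output is unnormalised, so normalise first;
    # abs() accounts for q and -q encoding the same rotation.
    q1 = np.asarray(q_pred, dtype=np.float64); q1 /= np.linalg.norm(q1)
    q2 = np.asarray(q_true, dtype=np.float64); q2 /= np.linalg.norm(q2)
    d = min(abs(float(np.dot(q1, q2))), 1.0)  # clamp against round-off
    return np.degrees(2.0 * np.arccos(d))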

# Set GPU options
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6833)

init = tf.global_variables_initializer()
saver = tf.train.Saver()
outputFile = "PoseNet.ckpt"


with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    # Load the data
    sess.run(init)
    net.load('/kaggle/input/tensorflow-posenet-master/tensorflow-posenet-master/posenet.npy', sess)

    data_gen = gen_data_batch(datasource)
    for i in range(max_iterations):
        np_images, np_poses_x, np_poses_q = next(data_gen)
        feed = {images: np_images, poses_x: np_poses_x, poses_q: np_poses_q}

        # Run the optimiser step and fetch the training loss in a single pass.
        _, np_loss = sess.run([opt, loss], feed_dict=feed)
        if i % 20 == 0:
            print("iteration: " + str(i) + "\n\t" + "Loss is: " + str(np_loss))
        if i % 500 == 0:
            saver.save(sess, outputFile)
            print("Intermediate file saved at: " + outputFile)
    saver.save(sess, outputFile)
    print("Final model saved at: " + outputFile)

9.Environment versions

  • !pip list
Package                                Version              Editable project location
-------------------------------------- -------------------- -------------------------
absl-py                                1.4.0
accelerate                             0.12.0
access                                 1.1.9
affine                                 2.4.0
aiobotocore                            2.5.0
aiofiles                               22.1.0
aiohttp                                3.8.4
aiohttp-cors                           0.7.0
aioitertools                           0.11.0
aiorwlock                              1.3.0
aiosignal                              1.3.1
aiosqlite                              0.19.0
albumentations                         1.3.0
alembic                                1.11.1
altair                                 5.0.0
annoy                                  1.17.2
ansiwrap                               0.8.4
anyio                                  3.6.2
apache-beam                            2.46.0
aplus                                  0.11.0
appdirs                                1.4.4
argon2-cffi                            21.3.0
argon2-cffi-bindings                   21.2.0
array-record                           0.2.0
arrow                                  1.2.3
arviz                                  0.12.1
astroid                                2.15.5
astropy                                5.3
asttokens                              2.2.1
astunparse                             1.6.3
async-timeout                          4.0.2
atpublic                               3.1.1
attrs                                  23.1.0
audioread                              3.0.0
autopep8                               2.0.2
Babel                                  2.12.1
backcall                               0.2.0
backoff                                2.2.1
backports.functools-lru-cache          1.6.4
bayesian-optimization                  1.4.3
bayespy                                0.5.25
beatrix-jupyterlab                     2023.58.190319
beautifulsoup4                         4.12.2
bidict                                 0.22.1
biopython                              1.81
blake3                                 0.2.1
bleach                                 6.0.0
blessed                                1.20.0
blinker                                1.6.2
blis                                   0.7.9
blosc2                                 2.0.0
bokeh                                  2.4.3
boltons                                23.0.0
Boruta                                 0.3
boto3                                  1.26.100
botocore                               1.29.76
bq-helper                              0.4.1                /src/bq-helper
bqplot                                 0.12.39
branca                                 0.6.0
brewer2mpl                             1.4.1
brotlipy                               0.7.0
cached-property                        1.5.2
cachetools                             4.2.4
Cartopy                                0.21.1
catalogue                              2.0.8
catalyst                               22.4
catboost                               1.2
category-encoders                      2.6.1
certifi                                2023.5.7
cesium                                 0.12.1
cffi                                   1.15.1
cftime                                 1.6.2
charset-normalizer                     2.1.1
chex                                   0.1.7
cleverhans                             4.0.0
click                                  8.1.3
click-plugins                          1.1.1
cligj                                  0.7.2
cloud-tpu-client                       0.10
cloud-tpu-profiler                     2.4.0
cloudpickle                            2.2.1
cmaes                                  0.9.1
cmdstanpy                              1.1.0
cmudict                                1.0.13
colorama                               0.4.6
colorcet                               3.0.1
colorful                               0.5.5
colorlog                               6.7.0
colorlover                             0.3.0
comm                                   0.1.3
commonmark                             0.9.1
conda                                  23.3.1
conda-content-trust                    0+unknown
conda-package-handling                 2.0.2
conda_package_streaming                0.7.0
confection                             0.0.4
contextily                             1.3.0
contourpy                              1.0.7
convertdate                            2.4.0
crcmod                                 1.7
cryptography                           40.0.2
cubinlinker                            0.2.2
cuda-python                            11.8.1
cudf                                   23.4.1
cufflinks                              0.17.3
cuml                                   23.4.1
cupy                                   11.6.0
CVXcanon                               0.1.2
cycler                                 0.11.0
cymem                                  2.0.7
cysignals                              1.11.2
Cython                                 0.29.34
cytoolz                                0.12.0
daal                                   2023.1.1
daal4py                                2023.1.1
dask                                   2023.5.0
dask-cuda                              23.4.0
dask-cudf                              23.4.1
dataclasses                            0.8
dataclasses-json                       0.5.7
datasets                               2.1.0
datashader                             0.14.4
datashape                              0.5.2
datatile                               1.0.3
db-dtypes                              1.1.1
deap                                   1.3.3
debugpy                                1.6.7
decorator                              5.1.1
defusedxml                             0.7.1
Delorean                               1.0.0
deprecat                               2.1.1
Deprecated                             1.2.13
deprecation                            2.1.0
descartes                              1.1.0
dill                                   0.3.6
dipy                                   1.7.0
distlib                                0.3.6
distributed                            2023.3.2.1
dm-tree                                0.1.8
docker                                 6.1.1
docker-pycreds                         0.4.0
docopt                                 0.6.2
docstring-parser                       0.15
docstring-to-markdown                  0.12
docutils                               0.20.1
earthengine-api                        0.1.354
easydict                               1.10
easyocr                                1.6.2
ecos                                   2.0.12
eli5                                   0.13.0
emoji                                  2.2.0
en-core-web-lg                         3.5.0
en-core-web-sm                         3.5.0
entrypoints                            0.4
ephem                                  4.1.4
esda                                   2.4.3
essentia                               2.1b6.dev1034
et-xmlfile                             1.1.0
etils                                  1.2.0
executing                              1.2.0
explainable-ai-sdk                     1.3.3
fastai                                 2.7.12
fastapi                                0.95.1
fastavro                               1.7.4
fastcore                               1.5.29
fastdownload                           0.0.7
fasteners                              0.18
fastjsonschema                         2.16.3
fastprogress                           1.0.3
fastrlock                              0.8
fasttext                               0.9.2
fbpca                                  1.0
feather-format                         0.4.1
featuretools                           1.26.0
filelock                               3.12.0
Fiona                                  1.8.22
fire                                   0.5.0
fitter                                 1.5.2
flake8                                 6.0.0
flashtext                              2.7
Flask                                  2.3.2
flatbuffers                            23.3.3
flax                                   0.6.10
flit_core                              3.8.0
folium                                 0.14.0
fonttools                              4.39.3
fqdn                                   1.5.1
frozendict                             2.3.8
frozenlist                             1.3.3
fsspec                                 2023.5.0
funcy                                  2.0
fury                                   0.9.0
future                                 0.18.3
fuzzywuzzy                             0.18.0
gast                                   0.4.0
gatspy                                 0.3
gcsfs                                  2023.5.0
gensim                                 4.3.1
geographiclib                          2.0
Geohash                                1.0
geojson                                3.0.1
geopandas                              0.13.0
geoplot                                0.5.1
geopy                                  2.3.0
geoviews                               1.9.6
ggplot                                 0.11.5
giddy                                  2.3.4
gitdb                                  4.0.10
GitPython                              3.1.31
google-api-core                        1.33.2
google-api-python-client               2.86.0
google-apitools                        0.5.31
google-auth                            2.17.3
google-auth-httplib2                   0.1.0
google-auth-oauthlib                   1.0.0
google-cloud-aiplatform                0.6.0a1
google-cloud-artifact-registry         1.8.1
google-cloud-automl                    1.0.1
google-cloud-bigquery                  2.34.4
google-cloud-bigtable                  1.7.3
google-cloud-core                      2.3.2
google-cloud-datastore                 2.15.2
google-cloud-dlp                       3.12.1
google-cloud-language                  2.6.1
google-cloud-monitoring                2.14.2
google-cloud-pubsub                    2.16.1
google-cloud-pubsublite                1.8.1
google-cloud-recommendations-ai        0.7.1
google-cloud-resource-manager          1.10.0
google-cloud-spanner                   3.33.0
google-cloud-storage                   1.44.0
google-cloud-translate                 3.8.4
google-cloud-videointelligence         2.8.3
google-cloud-vision                    2.8.0
google-crc32c                          1.5.0
google-pasta                           0.2.0
google-resumable-media                 2.5.0
googleapis-common-protos               1.57.1
gplearn                                0.4.2
gpustat                                1.0.0
gpxpy                                  1.5.0
graphviz                               0.20.1
greenlet                               2.0.2
grpc-google-iam-v1                     0.12.6
grpcio                                 1.51.1
grpcio-status                          1.48.1
gviz-api                               1.10.0
gym                                    0.26.2
gym-notices                            0.0.8
Gymnasium                              0.26.3
gymnasium-notices                      0.0.1
h11                                    0.14.0
h2o                                    3.40.0.4
h5py                                   3.8.0
haversine                              2.8.0
hdfs                                   2.7.0
hep-ml                                 0.7.2
hijri-converter                        2.3.1
hmmlearn                               0.3.0
holidays                               0.24
holoviews                              1.16.0
hpsklearn                              0.1.0
html5lib                               1.1
htmlmin                                0.1.12
httplib2                               0.21.0
httptools                              0.5.0
huggingface-hub                        0.14.1
humanize                               4.6.0
hunspell                               0.5.5
husl                                   4.0.3
hydra-slayer                           0.4.1
hyperopt                               0.2.7
hypertools                             0.8.0
ibis-framework                         5.1.0
idna                                   3.4
igraph                                 0.10.4
imagecodecs                            2023.3.16
ImageHash                              4.3.1
imageio                                2.28.1
imbalanced-learn                       0.10.1
imgaug                                 0.4.0
implicit                               0.5.2
importlib-metadata                     5.2.0
importlib-resources                    5.12.0
inequality                             1.0.0
ipydatawidgets                         4.3.3
ipykernel                              6.23.0
ipyleaflet                             0.17.2
ipympl                                 0.7.0
ipython                                8.13.2
ipython-genutils                       0.2.0
ipython-sql                            0.5.0
ipyvolume                              0.6.1
ipyvue                                 1.9.0
ipyvuetify                             1.8.10
ipywebrtc                              0.6.0
ipywidgets                             7.7.1
isoduration                            20.11.0
isort                                  5.12.0
isoweek                                1.3.3
itsdangerous                           2.1.2
Janome                                 0.4.2
jaraco.classes                         3.2.3
jax                                    0.4.10
jaxlib                                 0.4.7+cuda11.cudnn86
jedi                                   0.18.2
jeepney                                0.8.0
jieba                                  0.42.1
Jinja2                                 3.1.2
jmespath                               1.0.1
joblib                                 1.2.0
json5                                  0.9.11
jsonpatch                              1.32
jsonpointer                            2.0
jsonschema                             4.17.3
jupyter_client                         7.4.9
jupyter-console                        6.6.3
jupyter_core                           5.3.0
jupyter-events                         0.6.3
jupyter-http-over-ws                   0.0.8
jupyter-lsp                            1.5.1
jupyter_server                         2.5.0
jupyter_server_fileid                  0.9.0
jupyter-server-mathjax                 0.2.6
jupyter_server_proxy                   4.0.0
jupyter_server_terminals               0.4.4
jupyter_server_ydoc                    0.8.0
jupyter-ydoc                           0.2.4
jupyterlab                             3.6.3
jupyterlab-git                         0.41.0
jupyterlab-lsp                         4.1.0
jupyterlab-pygments                    0.2.2
jupyterlab_server                      2.22.1
jupyterlab-widgets                     3.0.7
jupytext                               1.14.5
kaggle                                 1.5.13
kaggle-environments                    1.12.0
keras                                  2.12.0
keras-tuner                            1.3.5
keyring                                23.13.1
keyrings.google-artifactregistry-auth  1.1.2
kfp                                    1.8.21
kfp-pipeline-spec                      0.1.16
kfp-server-api                         1.8.5
kiwisolver                             1.4.4
kmapper                                2.0.1
kmodes                                 0.12.2
korean-lunar-calendar                  0.3.1
kornia                                 0.6.12
kt-legacy                              1.0.5
kubernetes                             25.3.0
langcodes                              3.3.0
langid                                 1.1.6
lazy_loader                            0.2
lazy-object-proxy                      1.9.0
learntools                             0.3.4
leven                                  1.0.4
Levenshtein                            0.21.0
libclang                               16.0.0
libmambapy                             1.4.2
libpysal                               4.7.0
librosa                                0.10.0.post2
lightgbm                               3.3.2
lightning-utilities                    0.8.0
lime                                   0.2.0.1
line-profiler                          4.0.3
llvmlite                               0.39.1
lml                                    0.1.0
locket                                 1.0.0
LunarCalendar                          0.0.9
lxml                                   4.9.2
lz4                                    4.3.2
Mako                                   1.2.4
mamba                                  1.4.2
mapclassify                            2.5.0
marisa-trie                            0.8.0
Markdown                               3.4.3
markdown-it-py                         2.2.0
markovify                              0.9.4
MarkupSafe                             2.1.2
marshmallow                            3.19.0
marshmallow-enum                       1.5.1
matplotlib                             3.6.3
matplotlib-inline                      0.1.6
matplotlib-venn                        0.11.9
mccabe                                 0.7.0
mdit-py-plugins                        0.3.5
mdurl                                  0.1.2
memory-profiler                        0.61.0
mercantile                             1.2.1
mgwr                                   2.1.2
missingno                              0.5.2
mistune                                0.8.4
mizani                                 0.9.1
ml-dtypes                              0.1.0
mlcrate                                0.2.0
mlens                                  0.2.3
mlxtend                                0.22.0
mmh3                                   4.0.0
mne                                    1.4.0
mnist                                  0.2.2
mock                                   5.0.2
momepy                                 0.6.0
more-itertools                         9.1.0
mpld3                                  0.5.9
mpmath                                 1.3.0
msgpack                                1.0.5
msgpack-numpy                          0.4.8
multidict                              6.0.4
multimethod                            1.9.1
multipledispatch                       0.6.0
multiprocess                           0.70.14
munch                                  3.0.0
munkres                                1.1.4
murmurhash                             1.0.9
mypy-extensions                        1.0.0
nb-conda                               2.2.1
nb-conda-kernels                       2.3.1
nbclassic                              1.0.0
nbclient                               0.5.13
nbconvert                              6.4.5
nbdime                                 3.2.0
nbformat                               5.8.0
nest-asyncio                           1.5.6
netCDF4                                1.6.3
networkx                               3.1
nibabel                                5.1.0
nilearn                                0.10.1
ninja                                  1.11.1
nltk                                   3.2.4
nose                                   1.3.7
notebook                               6.5.4
notebook-executor                      0.2
notebook_shim                          0.2.3
numba                                  0.56.4
numexpr                                2.8.4
numpy                                  1.23.5
nvidia-ml-py                           11.495.46
nvtx                                   0.2.5
oauth2client                           4.1.3
oauthlib                               3.2.2
objsize                                0.6.1
odfpy                                  1.4.1
olefile                                0.46
onnx                                   1.14.0
opencensus                             0.11.2
opencensus-context                     0.1.3
opencv-contrib-python                  4.5.4.60
opencv-python                          4.5.4.60
opencv-python-headless                 4.5.4.60
openpyxl                               3.1.2
openslide-python                       1.2.0
opentelemetry-api                      1.17.0
opentelemetry-exporter-otlp            1.17.0
opentelemetry-exporter-otlp-proto-grpc 1.17.0
opentelemetry-exporter-otlp-proto-http 1.17.0
opentelemetry-proto                    1.17.0
opentelemetry-sdk                      1.17.0
opentelemetry-semantic-conventions     0.38b0
opt-einsum                             3.3.0
optax                                  0.1.5
optuna                                 3.1.1
orbax-checkpoint                       0.2.3
orderedmultidict                       1.0.1
orjson                                 3.8.12
ortools                                9.4.1874
osmnx                                  1.1.1
overrides                              6.5.0
packaging                              21.3
pandas                                 1.5.3
pandas-datareader                      0.10.0
pandas-profiling                       3.6.6
pandas-summary                         0.2.0
pandasql                               0.7.3
pandocfilters                          1.5.0
panel                                  0.14.4
papermill                              2.4.0
param                                  1.13.0
parso                                  0.8.3
parsy                                  2.1
partd                                  1.4.0
path                                   16.6.0
path.py                                12.5.0
pathos                                 0.3.0
pathtools                              0.1.2
pathy                                  0.10.1
patsy                                  0.5.3
pdf2image                              1.16.3
pexpect                                4.8.0
phik                                   0.12.3
pickleshare                            0.7.5
Pillow                                 9.5.0
pip                                    23.1.2
pkgutil_resolve_name                   1.3.10
platformdirs                           3.5.0
plotly                                 5.14.1
plotly-express                         0.4.1
plotnine                               0.10.1
pluggy                                 1.0.0
pointpats                              2.3.0
polars                                 0.17.15
polyglot                               16.7.4
pooch                                  1.6.0
pox                                    0.3.2
ppca                                   0.0.4
ppft                                   1.7.6.6
preprocessing                          0.1.13
preshed                                3.0.8
prettytable                            3.7.0
progressbar2                           4.2.0
prometheus-client                      0.16.0
promise                                2.3
prompt-toolkit                         3.0.38
pronouncing                            0.2.0
prophet                                1.1.1
proto-plus                             1.22.2
protobuf                               3.20.3
psutil                                 5.9.3
ptxcompiler                            0.8.1
ptyprocess                             0.7.0
pudb                                   2022.1.3
PuLP                                   2.7.0
pure-eval                              0.2.2
py-cpuinfo                             9.0.0
py-lz4framed                           0.14.0
py-spy                                 0.3.14
py4j                                   0.10.9.7
pyaml                                  23.5.9
PyArabic                               0.6.15
pyarrow                                10.0.1
pyasn1                                 0.4.8
pyasn1-modules                         0.2.7
PyAstronomy                            0.19.0
pybind11                               2.10.4
pyclipper                              1.3.0.post4
pycodestyle                            2.10.0
pycolmap                               0.4.0
pycosat                                0.6.4
pycparser                              2.21
pycryptodome                           3.18.0
pyct                                   0.5.0
pycuda                                 2022.2.2
pydantic                               1.10.7
pydegensac                             0.1.2
pydicom                                2.3.1
pydocstyle                             6.3.0
pydot                                  1.4.2
pydub                                  0.25.1
pyemd                                  1.0.0
pyerfa                                 2.0.0.3
pyexcel-io                             0.6.6
pyexcel-ods                            0.6.0
pyfasttext                             0.4.6
pyflakes                               3.0.1
pygltflib                              1.15.6
Pygments                               2.15.1
PyJWT                                  2.6.0
pykalman                               0.9.5
pyLDAvis                               3.2.2
pylibraft                              23.4.1
pylint                                 2.17.4
pymc3                                  3.11.5
PyMeeus                                0.5.12
pymongo                                3.13.0
Pympler                                1.0.1
pynndescent                            0.5.10
pynvml                                 11.4.1
pynvrtc                                9.2
pyocr                                  0.8.3
pyOpenSSL                              23.1.1
pyparsing                              3.0.9
pypdf                                  3.9.0
pyproj                                 3.5.0
pyrsistent                             0.19.3
pysal                                  23.1
pyshp                                  2.3.1
PySocks                                1.7.1
pytesseract                            0.3.10
python-bidi                            0.4.2
python-dateutil                        2.8.2
python-dotenv                          1.0.0
python-igraph                          0.10.4
python-json-logger                     2.0.7
python-Levenshtein                     0.21.0
python-louvain                         0.16
python-lsp-jsonrpc                     1.0.0
python-lsp-server                      1.7.3
python-slugify                         8.0.1
python-utils                           3.5.2
pythreejs                              2.4.2
pytoolconfig                           1.2.5
pytools                                2022.1.14
pytorch-ignite                         0.4.12
pytorch-lightning                      2.0.2
pytz                                   2023.3
pyu2f                                  0.1.5
PyUpSet                                0.1.1.post7
pyviz-comms                            2.2.1
PyWavelets                             1.4.1
PyYAML                                 5.4.1
pyzmq                                  25.0.2
qgrid                                  1.3.1
qtconsole                              5.4.3
QtPy                                   2.3.1
quantecon                              0.7.0
quantities                             0.14.1
qudida                                 0.0.4
raft-dask                              23.4.1
randomgen                              1.23.1
rapidfuzz                              3.0.0
rasterio                               1.3.7
rasterstats                            0.18.0
ray                                    2.4.0
ray-cpp                                2.4.0
regex                                  2023.5.5
requests                               2.28.2
requests-oauthlib                      1.3.1
requests-toolbelt                      0.10.1
responses                              0.18.0
retrying                               1.3.3
rfc3339-validator                      0.1.4
rfc3986-validator                      0.1.1
rgf-python                             3.12.0
rich                                   12.6.0
rmm                                    23.4.1
rope                                   1.8.0
rsa                                    4.9
Rtree                                  1.0.1
ruamel.yaml                            0.17.24
ruamel.yaml.clib                       0.2.7
ruamel-yaml-conda                      0.15.100
s2sphere                               0.2.5
s3fs                                   2023.5.0
s3transfer                             0.6.1
safetensors                            0.3.1
scattertext                            0.1.19
scikit-image                           0.20.0
scikit-learn                           1.2.2
scikit-learn-intelex                   2023.1.1
scikit-multilearn                      0.2.0
scikit-optimize                        0.9.0
scikit-plot                            0.3.7
scikit-surprise                        1.1.3
scipy                                  1.10.1
seaborn                                0.12.2
SecretStorage                          3.3.3
segment-anything                       1.0
segregation                            2.4.2
semver                                 3.0.0
Send2Trash                             1.8.2
sentencepiece                          0.1.99
sentry-sdk                             1.24.0
setproctitle                           1.3.2
setuptools                             59.8.0
setuptools-git                         1.2
setuptools-scm                         7.1.0
shap                                   0.41.0
Shapely                                1.8.5.post1
shellingham                            1.5.1
simpervisor                            0.4
SimpleITK                              2.2.1
simplejson                             3.19.1
six                                    1.16.0
sklearn-pandas                         2.2.0
slicer                                 0.0.7
smart-open                             6.3.0
smhasher                               0.150.1
smmap                                  5.0.0
sniffio                                1.3.0
snowballstemmer                        2.2.0
snuggs                                 1.4.7
sortedcontainers                       2.4.0
soundfile                              0.12.1
soupsieve                              2.3.2.post1
soxr                                   0.3.5
spacy                                  3.5.3
spacy-legacy                           3.0.12
spacy-loggers                          1.0.4
spaghetti                              1.7.2
spectral                               0.23.1
spglm                                  1.0.8
sphinx-rtd-theme                       0.2.4
spint                                  1.0.7
splot                                  1.1.5.post1
spopt                                  0.5.0
spreg                                  1.3.2
spvcm                                  0.3.0
SQLAlchemy                             2.0.12
sqlglot                                11.7.1
sqlparse                               0.4.4
squarify                               0.4.3
srsly                                  2.4.6
stack-data                             0.6.2
starlette                              0.26.1
statsmodels                            0.13.5
stemming                               1.0.1
stop-words                             2018.7.23
stopit                                 1.1.2
strip-hints                            0.1.10
stumpy                                 1.11.1
sympy                                  1.12
tables                                 3.8.0
tabulate                               0.9.0
tangled-up-in-unicode                  0.2.0
tbb                                    2021.9.0
tblib                                  1.7.0
tenacity                               8.2.2
tensorboard                            2.12.3
tensorboard-data-server                0.7.0
tensorboard-plugin-profile             2.11.2
tensorboardX                           2.6
tensorflow                             2.12.0
tensorflow-addons                      0.20.0
tensorflow-cloud                       0.1.16
tensorflow-datasets                    4.9.2
tensorflow-decision-forests            1.3.0
tensorflow-estimator                   2.12.0
tensorflow-gcs-config                  2.12.0
tensorflow-hub                         0.12.0
tensorflow-io                          0.31.0
tensorflow-io-gcs-filesystem           0.31.0
tensorflow-metadata                    0.14.0
tensorflow-probability                 0.20.0
tensorflow-serving-api                 2.12.1
tensorflow-text                        2.12.1
tensorflow-transform                   0.14.0
tensorflowjs                           3.15.0
tensorpack                             0.11
tensorstore                            0.1.36
termcolor                              2.3.0
terminado                              0.17.1
testpath                               0.6.0
text-unidecode                         1.3
textblob                               0.17.1
texttable                              1.6.7
textwrap3                              0.9.2
Theano                                 1.0.5
Theano-PyMC                            1.1.2
thinc                                  8.1.10
threadpoolctl                          3.1.0
tifffile                               2023.4.12
timm                                   0.9.2
tinycss2                               1.2.1
tobler                                 0.10
tokenizers                             0.13.3
toml                                   0.10.2
tomli                                  2.0.1
tomlkit                                0.11.8
toolz                                  0.12.0
torch                                  2.0.0
torchaudio                             2.0.1
torchdata                              0.6.0
torchinfo                              1.8.0
torchmetrics                           0.11.4
torchtext                              0.15.1
torchvision                            0.15.1
tornado                                6.3.1
TPOT                                   0.11.7
tqdm                                   4.64.1
traceml                                1.0.8
traitlets                              5.9.0
traittypes                             0.2.1
transformers                           4.29.2
treelite                               3.2.0
treelite-runtime                       3.2.0
trueskill                              0.4.5
tsfresh                                0.20.0
typeguard                              2.13.3
typer                                  0.7.0
typing_extensions                      4.5.0
typing-inspect                         0.8.0
tzlocal                                5.0.1
ucx-py                                 0.31.0
ujson                                  5.7.0
umap-learn                             0.5.3
unicodedata2                           15.0.0
Unidecode                              1.3.6
update-checker                         0.18.0
uri-template                           1.2.0
uritemplate                            3.0.1
urllib3                                1.26.15
urwid                                  2.1.2
urwid-readline                         0.13
uvicorn                                0.22.0
uvloop                                 0.17.0
vaex                                   4.16.0
vaex-astro                             0.9.3
vaex-core                              4.16.1
vaex-hdf5                              0.14.1
vaex-jupyter                           0.8.1
vaex-ml                                0.18.1
vaex-server                            0.8.1
vaex-viz                               0.5.4
vecstack                               0.4.0
virtualenv                             20.21.0
visions                                0.7.5
vowpalwabbit                           9.8.0
vtk                                    9.2.6
Wand                                   0.6.11
wandb                                  0.15.3
wasabi                                 1.1.1
watchfiles                             0.19.0
wavio                                  0.0.7
wcwidth                                0.2.6
webcolors                              1.13
webencodings                           0.5.1
websocket-client                       1.5.1
websockets                             11.0.3
Werkzeug                               2.3.4
wfdb                                   4.1.1
whatthepatch                           1.0.5
wheel                                  0.40.0
widgetsnbextension                     3.6.4
witwidget                              1.8.1
woodwork                               0.23.0
Wordbatch                              1.4.9
wordcloud                              1.9.2
wordsegment                            1.3.1
wrapt                                  1.14.1
wurlitzer                              3.0.3
xarray                                 2023.5.0
xarray-einstats                        0.5.1
xgboost                                1.7.5
xvfbwrapper                            0.2.9
xxhash                                 3.2.0
xyzservices                            2023.5.0
y-py                                   0.5.9
yapf                                   0.33.0
yarl                                   1.9.1
ydata-profiling                        4.1.2
yellowbrick                            1.5
ypy-websocket                          0.8.2
zict                                   3.0.0
zipp                                   3.15.0
zstandard                              0.19.0

