1.问题缘由:
由于工作需求,需要将多个(总量10G+)geojson文件写入到sq3库。众所周知,SQLite 同一时刻只允许一个写连接(写操作是串行的),多线程并发写入会互相争锁,那该怎么办呢?在网上也查了很多策略,都没有达到立竿见影的效果。于是还是回到写文件的本质:多进程写多个文件,各写各的库,就绕开了加锁的机制。
2.单线程读取的效果
单线程读写原始26个geojson文件,共294M,耗时:547S
写完的sq3文件大小:73.3M
3.多进程并发
多进程并发读写geojson,生成多个sq3文件,再合并到一个sq3文件耗时:16.5S
4.工具代码:
4.1 rw_data_geojson.py: 读写geojson文件
import os
import json
GEOMETRY = 'geometry'
def read_all_layer(src_path):
    """
    Read every ``.geojson`` file in a directory into nested dicts.

    :param src_path: directory containing the .geojson files
    :return: mapping ``{layer_name: {feature_id: properties_dict}}``;
             each properties dict also carries the feature geometry under
             the ``GEOMETRY`` key when the source feature has one
    """
    # Only process GeoJSON files; everything else in the folder is ignored.
    geojson_names = [f for f in os.listdir(src_path) if f.endswith('.geojson')]
    geo_properties_map = {}
    for filename in geojson_names:
        file_path = os.path.join(src_path, filename)
        # GeoJSON is UTF-8 by spec (RFC 7946) — be explicit instead of
        # relying on the platform default encoding.
        with open(file_path, 'r', encoding='utf-8') as file:
            geojson_data = json.load(file)
        features = geojson_data.get('features', [])
        layer = {}  # renamed from `dict`: never shadow the builtin
        for feature in features:
            # Guard against features whose "properties" is null/missing;
            # a feature without an "id" property still raises KeyError.
            properties = feature.get('properties') or {}
            if GEOMETRY in feature:
                # Fold the geometry into the properties dict so one flat
                # dict per feature reaches the caller.
                properties[GEOMETRY] = feature.get('geometry')
            layer[properties["id"]] = properties
        # Strip only the trailing suffix (replace() would also remove a
        # ".geojson" occurring mid-name).
        layer_name = filename[:-len(".geojson")]
        geo_properties_map[layer_name] = layer
    return geo_properties_map
def read_single_layer(geojson_path):
    """
    Read one ``.geojson`` file into the same nested-dict shape as
    ``read_all_layer``.

    :param geojson_path: path of the .geojson file to read
    :return: ``{layer_name: {feature_id: properties_dict}}``; an empty
             dict when the path does not end in ``.geojson``
    """
    geo_properties_map = {}
    if not geojson_path.endswith('.geojson'):
        return geo_properties_map
    # GeoJSON is UTF-8 by spec (RFC 7946) — be explicit about the encoding.
    with open(geojson_path, 'r', encoding='utf-8') as file:
        geojson_data = json.load(file)
    layer = {}  # renamed from `dict`: never shadow the builtin
    for feature in geojson_data.get('features', []):
        # Guard against features whose "properties" is null/missing;
        # a feature without an "id" property still raises KeyError.
        properties = feature.get('properties') or {}
        if GEOMETRY in feature:
            properties[GEOMETRY] = feature.get('geometry')
        layer[properties["id"]] = properties
    # Strip only the trailing suffix (replace() would also remove a
    # ".geojson" occurring mid-name).
    key = os.path.basename(geojson_path)[:-len(".geojson")]
    geo_properties_map[key] = layer
    return geo_properties_map
def build_geojson(src_feats, layer_name='', epsg_crs=None):
    """
    Format a ``{feature_id: properties_dict}`` mapping back into a
    GeoJSON FeatureCollection dict.

    NOTE: when a properties dict holds a geometry, that entry is removed
    from *src_feats* as a side effect (same as the original code).

    :param src_feats: mapping feature id -> properties, geometry stored
                      under the ``GEOMETRY`` key
    :param layer_name: optional layer name written into the collection
    :param epsg_crs: optional EPSG code (int or digit string) or a full
                     CRS name string
    :return: GeoJSON FeatureCollection dict
    """
    features = []
    # Iterate values directly — the old code built a throwaway list from
    # .items() only to discard the keys.
    for attr in src_feats.values():
        geos_obj = attr.get(GEOMETRY)
        feature = {"properties": attr, "type": "Feature"}
        if geos_obj is not None:
            # pop() both reads and removes the geometry, replacing the
            # old assignment + del pair.
            feature[GEOMETRY] = attr.pop(GEOMETRY)
        features.append(feature)
    layer = {"type": "FeatureCollection", "features": features}
    if layer_name:
        layer['name'] = layer_name
    # Only attach a CRS when at least one feature actually has a geometry.
    if epsg_crs and src_feats and any(GEOMETRY in f for f in features):
        if isinstance(epsg_crs, int) or (isinstance(epsg_crs, str) and epsg_crs.isdigit()):
            crs_str = "urn:ogc:def:crs:EPSG::%s" % epsg_crs
        else:
            crs_str = epsg_crs
        layer['crs'] = {"type": "name", "properties": {"name": crs_str}}
    return layer
def write_layer(target_path, layer_name, node_data):
    '''
    Write one layer's data to ``<target_path>/<layer_name>.geojson``.

    :param target_path: target directory (created when missing)
    :param layer_name: output file name without the ``.geojson`` suffix
    :param node_data: JSON-serializable nested dict to dump
    :return: None
    '''
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(target_path, exist_ok=True)
    # Build the output path once instead of concatenating it twice.
    out_file = os.path.join(target_path, layer_name + ".geojson")
    # GeoJSON is UTF-8 by spec (RFC 7946) — be explicit about the encoding.
    with open(out_file, 'w', encoding='utf-8') as f:
        json.dump(node_data, f)
    print(out_file + " 写入完毕")
4.2 db_sq3_tool.py :处理sq3数据库
import sqlite3
import os
from shapely.geometry import shape
from read_file import rw_data_geojson
import random
import time
import multiprocessing
import datetime
def create_connection(db_file):
    """Open a connection to the SQLite database at *db_file*.

    Returns the connection object, or ``None`` when sqlite3 raises
    (the error is printed, not propagated).
    """
    try:
        return sqlite3.connect(db_file)
    except sqlite3.Error as exc:
        print(exc)
        return None
def create_table(conn, create_table_sql):
    """Run a CREATE TABLE statement on *conn* and commit.

    SQLite errors are printed rather than raised.
    """
    try:
        conn.cursor().execute(create_table_sql)
        conn.commit()
    except sqlite3.Error as exc:
        print(exc)
def insert_data(conn, insert_sql, data):
    """Execute one parameterized INSERT and commit.

    SQLite errors are printed rather than raised.
    """
    try:
        cur = conn.cursor()
        cur.execute(insert_sql, data)
        conn.commit()
    except sqlite3.Error as exc:
        print(exc)
def batch_insert_data(conn, data_list, table_name, columns):
    '''
    Insert many rows into *table_name* with a single commit.

    :param conn: open sqlite3 connection
    :param data_list: list of row value sequences, one per insert
    :param table_name: target table name
    :param columns: column names matching each row's value order
    :return:
    '''
    # One "?" placeholder per column.
    marks = ', '.join('?' for _ in columns)
    sql = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({marks})"
    cur = conn.cursor()
    try:
        cur.executemany(sql, data_list)
        conn.commit()
    except sqlite3.Error as exc:
        print(f"An error occurred: {exc}")
        conn.rollback()
def select_data(conn, select_sql):
    """Run a SELECT and return all rows.

    Returns ``None`` after printing the error when sqlite3 raises.
    """
    try:
        return conn.cursor().execute(select_sql).fetchall()
    except sqlite3.Error as exc:
        print(exc)
def update_data(conn, update_sql, data):
    """Execute a parameterized UPDATE and commit.

    SQLite errors are printed rather than raised.
    """
    try:
        # Connection.execute creates the cursor implicitly.
        conn.execute(update_sql, data)
        conn.commit()
    except sqlite3.Error as exc:
        print(exc)
def delete_data(conn, delete_sql, data):
    """Execute a parameterized DELETE and commit.

    SQLite errors are printed rather than raised.
    """
    try:
        # Connection.execute creates the cursor implicitly.
        conn.execute(delete_sql, data)
        conn.commit()
    except sqlite3.Error as exc:
        print(exc)
def dict_data_write_sqlite(node_data, table_name, conn, batch=1):
    '''
    Write one layer's ``{feature_id: properties_dict}`` mapping into a
    sq3 table named *table_name*.

    The schema is inferred from one randomly sampled row — this assumes
    every row shares the same keys in the same order (TODO confirm for
    heterogeneous layers). Geometry dicts are converted to WKT text via
    shapely before insertion.

    :param node_data: data to write, ``{id: {column: value}}``
    :param table_name: target table name
    :param conn: open sqlite3 connection (no-op when None)
    :param batch: truthy -> single executemany commit (fast);
                  falsy -> one parameterized INSERT per row (slow path)
    :return:
    '''
    try:
        if len(node_data) == 0:
            return
        # Infer column names / SQL types from one sample row.
        random_key, random_value = random.choice(list(node_data.items()))
        row0 = random_value
        flag2type = {'str': 'TEXT', 'int': 'BIGINT', 'float': 'REAL', 'dict': 'TEXT'}
        fld_types = []
        columns = []
        for key, value in row0.items():
            value_type = type(value).__name__
            # Unknown Python types (None, bool, list, ...) fall back to
            # TEXT instead of raising KeyError like the old dict lookup.
            fld_types.append((key, flag2type.get(value_type, 'TEXT')))
            columns.append(key)
        fld_sql = ','.join(f'{fld} {typ}' for fld, typ in fld_types if fld != 'id')
        pk_sql = 'id BIGINT PRIMARY KEY'
        create_tab_sql = f'CREATE TABLE IF NOT EXISTS {table_name} ({pk_sql}, {fld_sql});'
        if conn is None:
            return
        # 1. make sure the table exists
        create_table(conn, create_tab_sql)

        def _row_values(data):
            # Serialize one feature dict into an ordered value list;
            # geometry dicts become WKT strings, everything else str().
            values = []
            for key, value in data.items():
                if 'geometry' == key:
                    values.append(str(shape(value).wkt))
                else:
                    values.append(str(value))
            return values

        if batch:
            # Fast path: one executemany + one commit.
            data_list = [_row_values(data) for data in node_data.values()]
            batch_insert_data(conn, data_list, table_name, columns)
        else:
            # Slow path: one parameterized INSERT per row. The old code
            # inlined unquoted values into the SQL string AND passed the
            # raw dict as parameters — both broken; placeholders fix it.
            placeholders = ','.join(['?'] * len(columns))
            insert_sql = f"insert into {table_name} ({','.join(columns)}) values ({placeholders})"
            for data in node_data.values():
                insert_data(conn, insert_sql, _row_values(data))
        print(f"{table_name} sq3写入成功")
    except Exception as e:
        # str(e) is required: the old `" ..." + e` raised TypeError and
        # masked the real error.
        print(" 写入sq3异常: " + str(e))
def read_single_geojson_write_sq3(args):
    '''
    Read one geojson file and write it into its own .sq3 database —
    one database per worker process sidesteps SQLite's single-writer
    limitation.

    :param args: ``(file_name, target_path)`` or
                 ``(file_name, target_path, src_folder)``. With the
                 2-tuple form the source directory falls back to the
                 module-level ``folder_path`` global (the original,
                 fragile behavior — prefer passing the 3-tuple).
    :return: the processed file name
    '''
    if len(args) == 3:
        file_name, target_path, src_folder = args
    else:
        file_name, target_path = args
        # Backward compatible: rely on the script-level global as before.
        src_folder = folder_path
    layer = file_name.replace(".geojson", "")
    db_file = os.path.join(target_path, layer + '.sq3')
    # open the per-layer database
    conn = create_connection(db_file)
    try:
        # read the geojson, then write every layer it contains
        node_data = rw_data_geojson.read_single_layer(os.path.join(src_folder, file_name))
        if len(node_data[layer]) > 0:
            for layer_name, layer_value in node_data.items():
                dict_data_write_sqlite(layer_value, layer_name, conn, batch=1)
        else:
            # drop the empty .sq3 file
            os.remove(db_file)
    finally:
        # Close even when reading/writing raises (the old code leaked
        # the connection on any error).
        if conn is not None:
            conn.close()
    return file_name
def merge_sq3(target_path):
    """
    Merge every per-layer ``.sq3`` file under *target_path* into one
    database named ``merge_sq3_finish/<YYYYMMDD>.sq3``.

    Each source file is expected to hold a single table named after the
    file (``<table>.sq3``); tables are copied via ATTACH + CREATE TABLE
    AS SELECT.

    :param target_path: directory holding the per-layer .sq3 files
    :return: None
    """
    merge_folder = os.path.join(target_path, "merge_sq3_finish")
    os.makedirs(merge_folder, exist_ok=True)
    # Target database named after today's date; recreate it each run.
    day_stamp = time.strftime('%Y%m%d', time.localtime())
    target_db = os.path.join(merge_folder, day_stamp + ".sq3")
    if os.path.exists(target_db):
        os.remove(target_db)
        print(f"{target_db} 已被删除。")
    target_conn = create_connection(target_db)
    target_cursor = target_conn.cursor()
    try:
        for item in os.listdir(target_path):
            source_db_file = os.path.join(target_path, item)
            # Only merge real .sq3 files — the old code special-cased
            # '.DS_Store' but let any other stray file break the ATTACH.
            if not (os.path.isfile(source_db_file) and item.endswith('.sq3')):
                continue
            table_name = item[:-len('.sq3')]
            # Attach the source db, copy its layer table, detach again.
            target_cursor.execute(f"ATTACH DATABASE '{source_db_file}' AS source_db;")
            target_cursor.execute(f"CREATE TABLE {table_name} AS SELECT * FROM source_db.{table_name}")
            target_cursor.execute("DETACH DATABASE source_db;")
            target_conn.commit()
    finally:
        # Close even on failure (the old code leaked the connection).
        target_conn.close()
5.单线程读写代码
# Benchmark script, section 5: single-threaded full read + write.
# NOTE(review): paths are hardcoded to the author's machine — adjust before running.
folder_path = '/Users/admin/Desktop/123/sq3效率/geojson'
target_path = "/Users/admin/Desktop/123/sq3效率/merge_sq3"
# 1. read every layer in one thread, then write them sequentially
start_time = time.time()
node_data = rw_data_geojson.read_all_layer(folder_path)
# create the database connection (delete any leftover file first)
db_file = target_path + '/' + '20240928.sq3'
if os.path.exists(db_file):
    os.remove(db_file)
    print(f"{db_file} 已被删除。")
conn = create_connection(db_file)
# Write each non-empty layer; batch=0 exercises the slow row-by-row path.
for layer_name, layer_value in node_data.items():
    if len(node_data[layer_name]) > 0:
        dict_data_write_sqlite(layer_value, layer_name, conn, batch=0)
end_time = time.time()
execution_time = end_time - start_time
print(f"写入sq3 函数执行时间:{execution_time} 秒")
exit()
6.多进程读写,合并到一个sq3数据库
# Benchmark script, section 6: multi-process read/write, one worker per file.
start_time = time.time()
# Clear out every .sq3 left over from a previous run.
for root, dirs, files in os.walk(target_path):
    for file in files:
        db_file = os.path.join(root, file)
        os.remove(db_file)
        print(f"{db_file} 已被删除。")
# BUG FIX: the original called pool.map() once per file with a
# single-element list inside the loop; each map() call blocks until its
# one task finishes, so the "pool" processed files strictly serially.
# Build the full parameter list first and hand it to a single map() call
# so the 5 workers actually run in parallel.
task_params = [
    (file_name, target_path)
    for file_name in os.listdir(folder_path)
    if file_name != '.DS_Store'
]
with multiprocessing.Pool(processes=5) as pool:
    pool.map(read_single_geojson_write_sq3, task_params)
# merge the per-layer .sq3 files into one database
merge_sq3(target_path)
end_time = time.time()
execution_time = end_time - start_time
print(f"写入sq3 函数执行时间:{execution_time} 秒")
exit()
7.在上述基础上,再继续提效
若单个geojson文件太大时,可多线程分批读取,将读取的块内容,写到一个分块的.sq3,再并发合并到单个图层的sq3,最后将多个图层合并到一个sq3中。