A while back I helped a colleague import CSV data into MySQL. There were two very large CSV files: one of 3 GB with 21 million records, and one of 7 GB with 35 million records. At that scale, a naive single-process, single-threaded import would take far too long, so I ended up doing it with multiple processes. I won't walk through the whole process; here are the key points:
Insert in batches rather than row by row.
To speed up the inserts, don't build the indexes until the data is loaded (see the short sketch right after this list).
Use a producer-consumer model: the main process reads the file, and multiple worker processes perform the inserts.
Keep the number of workers under control so you don't put too much pressure on MySQL.
Handle the exceptions caused by dirty data.
The source data is GBK-encoded, so it also has to be converted to UTF-8.
Wrap everything as a command-line tool with click.
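On the index point: the script below does not manage indexes itself; the idea is to load into a table that only has its primary key (or to drop the secondary indexes first) and then build the indexes in a single pass after the import, which is far cheaper than maintaining them for every inserted row. A minimal sketch of what that can look like with the same SQLAlchemy engine; the table and index names are made up for illustration:

import sqlalchemy

engine = sqlalchemy.create_engine('mysql://root@localhost:3306/example?charset=utf8')

# before the import: drop the secondary index (or create the table without it)
engine.execute('ALTER TABLE `example_table` DROP INDEX `idx_order_no`')

# ... run the parallel import here ...

# after the import: rebuild the index in one pass
engine.execute('ALTER TABLE `example_table` ADD INDEX `idx_order_no` (`order_no`)')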
The full code is as follows:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import codecs
import csv
import logging
import multiprocessing
import os
import warnings

import click
import MySQLdb
import sqlalchemy

warnings.filterwarnings('ignore', category=MySQLdb.Warning)
# number of records per batch insert
BATCH = 5000
DB_URI = 'mysql://root@localhost:3306/example?charset=utf8'

engine = sqlalchemy.create_engine(DB_URI)
def get_table_cols(table):
    sql = 'SELECT * FROM `{table}` LIMIT 0'.format(table=table)
    res = engine.execute(sql)
    return res.keys()
def insert_many(table, cols, rows, cursor):
    sql = 'INSERT INTO `{table}` ({cols}) VALUES ({marks})'.format(
        table=table,
        cols=', '.join(cols),
        marks=', '.join(['%s'] * len(cols)))
    cursor.execute(sql, *rows)
    logging.info('process %s inserted %s rows into table %s',
                 os.getpid(), len(rows), table)
def insert_worker(table, cols, queue):
    rows = []
    # each worker process creates its own engine object
    cursor = sqlalchemy.create_engine(DB_URI)
    while True:
        row = queue.get()
        if row is None:
            if rows:
                insert_many(table, cols, rows, cursor)
            break
        rows.append(row)
        if len(rows) == BATCH:
            insert_many(table, cols, rows, cursor)
            rows = []
def insert_parallel(table, reader, w=10):
    cols = get_table_cols(table)
    # task queue: the main process reads the file and puts rows on the queue,
    # and the worker processes take rows off it and insert them.
    # Cap the queue size so a slow consumer doesn't let too many rows pile up
    # in memory.
    queue = multiprocessing.Queue(maxsize=w * BATCH * 2)
    workers = []
    for i in range(w):
        p = multiprocessing.Process(target=insert_worker,
                                    args=(table, cols, queue))
        p.start()
        workers.append(p)
        logging.info('starting # %s worker process, pid: %s...', i + 1, p.pid)
    dirty_data_file = './{}_dirty_rows.csv'.format(table)
    xf = open(dirty_data_file, 'w')
    writer = csv.writer(xf, delimiter=reader.dialect.delimiter)
    for line in reader:
        # log and skip dirty rows whose field count does not match the table
        if len(line) != len(cols):
            writer.writerow(line)
            continue
        # replace the string 'NULL' with None so MySQL stores a real NULL
        clean_line = [None if x == 'NULL' else x for x in line]
        # put the row on the queue
        queue.put(tuple(clean_line))
        if reader.line_num % 500000 == 0:
            logging.info('put %s tasks into queue.', reader.line_num)
    xf.close()
    # tell every worker that there are no more tasks
    logging.info('send close signal to worker processes')
    for i in range(w):
        queue.put(None)
    for p in workers:
        p.join()
def convert_file_to_utf8(f, rv_file=None):
    if not rv_file:
        name, ext = os.path.splitext(f)
        if isinstance(name, unicode):
            name = name.encode('utf8')
        rv_file = '{}_utf8{}'.format(name, ext)
    logging.info('start to process file %s', f)
    with open(f) as infd:
        with open(rv_file, 'w') as outfd:
            lines = []
            loop = 0
            chunk = 200000
            first_line = infd.readline().strip(codecs.BOM_UTF8).strip() + '\n'
            lines.append(first_line)
            for line in infd:
                clean_line = line.decode('gb18030').encode('utf8')
                clean_line = clean_line.rstrip() + '\n'
                lines.append(clean_line)
                if len(lines) == chunk:
                    outfd.writelines(lines)
                    lines = []
                    loop += 1
                    logging.info('processed %s lines.', loop * chunk)
            outfd.writelines(lines)
            logging.info('processed %s lines.', loop * chunk + len(lines))
@click.group()
def cli():
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s')
@cli.command('gbk_to_utf8')
@click.argument('f')
def convert_gbk_to_utf8(f):
    convert_file_to_utf8(f)
@cli.command('load')
@click.option('-t', '--table', required=True, help='table name')
@click.option('-i', '--filename', required=True, help='input file')
@click.option('-w', '--workers', default=10, help='number of workers, default 10')
def load_fac_day_pro_nos_sal_table(table, filename, workers):
    with open(filename) as fd:
        fd.readline()  # skip the header line
        reader = csv.reader(fd)
        insert_parallel(table, reader, w=workers)
if __name__ == '__main__':
    cli()
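For reference, a typical run then looks something like the following; the script and file names are only placeholders, and the converted file gets a _utf8 suffix because that is the default output name in convert_file_to_utf8:

python import_csv.py gbk_to_utf8 data.csv
python import_csv.py load -t example_table -i data_utf8.csv -w 10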