谷歌今天又开源了,这次是Sketch-RNN
liuian 2025-04-11 01:00 54 浏览
前不久,谷歌公布了一项最新技术,可以教机器画画。今天,谷歌开源了代码。在我们研究其代码之前,首先先按要求设置Magenta环境。(
https://github.com/tensorflow/magenta/blob/master/README.md)
本文详细解释了Sketch-RNN的TensorFlow代码,即之前发布的两篇文章《Teaching Machines to Draw》和《A Neural Representation of Sketch Drawings》中描述的循环神经网络模型(RNN)。
模型概览
sketch-rnn是序列到序列的变体自动编码器。编码器RNN是双向RNN,解码器是自回归混合密度RNN。你可以使用enc_model,dec_model,enc_size,dec_size设置指定要使用的RNN单元格的类型和RNN的大小。
编码器将采用一个潜在代码z,一个维度为z_size的浮点矢量。像VAE一样,我们可以对z强制执行高斯IID分布,并使用kl_weight来控制KL发散损失项的强度。KL散度损失与重建损失之间将会有一个权衡。我们还允许潜在的代码存储信息的一些空间,而不是纯高斯IID。一旦KL损失期限低于kl_tolerance,我们将停止对该期限的优化。
对于中小型数据集,丢失(dropout)和数据扩充是避免过度拟合的非常有用的技术。我们提供了输入丢失、输出丢失、不存在内存丢失的循环丢失三个选项。实际上,我们只使用循环丢失,通常根据数据集将其设置在65%到90%之间。层次归一化和反复丢失可以一起使用,形成了一个强大的组合,用于在小型数据集上训练循环神经网络。
谷歌提供了两种数据增强技术。第一个是随机缩放训练图像大小的random_scale_factor。第二种增加技术(sketch-rnn论文中未使用)剔除线笔划中的随机点。给定一个具有超过2点的线段,我们可以随机放置线段内的点,并且仍然保持类似的矢量图像。这种类型的数据增强在小数据集上使用时非常强大,并且对矢量图是唯一的,因为难以在文本或MIDI数据中删除随机字符或音符,并且也不可能在像素图像数据中丢弃随机像素而不引起大的视觉差异。我们通常将数据增加参数设置为10%至20%。如果在与普通示例相比较的情况下,人类观众几乎没有差异,那么我们应用数据增强技术,而不考虑训练数据集的大小。
有效地使用丢弃和数据扩充,可以避免过度拟合到一个小的训练集。
训练模型
要训练模型,首先需要一个包含训练/验证/测试例子的数据集。我们提供了指向aaron_sheep数据集的链接,默认情况下,该模型将使用此轻量级数据集。
使用示例:
sketch_rnn_train --log_root=checkpoint_path --data_dir=dataset_path --hparams={"data_set"="dataset_filename.npz"}我们建议你在模型和数据集内部创建子目录,以保存自己的数据和检查点。 TensorBoard日志将存储在checkpoint_path内,用于查看训练/验证/测试数据集中各种损失的训练曲线。
以下是模型的完整选项列表以及默认设置:
data_set='aaron_sheep.npz', # Our dataset.
num_steps=10000000, # Total number of training set. Keeplarge.
save_every=500, # Number of batches percheckpoint creation.
dec_rnn_size=512, # Size of decoder.
dec_model='lstm', # Decoder: lstm, layer_norm orhyper.
enc_rnn_size=256, # Size of encoder.
enc_model='lstm', # Encoder: lstm, layer_norm orhyper.
z_size=128, # Size of latent vector z.Recommend 32, 64 or 128.
kl_weight=0.5, # KL weight of loss equation.Recommend 0.5 or 1.0.
kl_weight_start=0.01, # KL start weight when annealing.
kl_tolerance=0.2, # Level of KL loss at which to stopoptimizing for KL.
batch_size=100, # Minibatch size. Recommendleaving at 100.
grad_clip=1.0, # Gradient clipping. Recommendleaving at 1.0.
num_mixture=20, # Number of mixtures in Gaussianmixture model.
learning_rate=0.001, # Learning rate.
decay_rate=0.9999, # Learning rate decay per minibatch.
kl_decay_rate=0.99995, # KL annealing decay rate per minibatch.
min_learning_rate=0.00001, # Minimum learning rate.
use_recurrent_dropout=True, # Recurrent Dropout without Memory Loss.Recomended.
recurrent_dropout_prob=0.90, # Probabilityof recurrent dropout keep.
use_input_dropout=False, # Input dropout. Recommend leaving False.
input_dropout_prob=0.90, # Probability of input dropout keep.
use_output_dropout=False, # Output droput. Recommend leaving False.
output_dropout_prob=0.90, # Probability of output dropout keep.
random_scale_factor=0.15, # Random scaling data augmentionproportion.
augment_stroke_prob=0.10, # Point dropping augmentation proportion.
conditional=True, # If False, use decoder-only model.
以下是一些可能需要用于在非常大的数据集上训练模型的选项,并使用HyperLSTM作为RNN单元。对于小于10K的训练样本的小数据集,具有层规范化(包括enc_model和dec_model的layer_norm)的LSTM效果最佳。
sketch_rnn_train --log_root=models/big_model --data_dir=datasets/big_dataset --hparams={"data_set"="big_dataset_filename.npz","dec_model":"hyper","dec_rnn_size":2048,"enc_model":"layer_norm","enc_rnn_size":512,"save_every":5000,"grad_clip":1.0,"use_recurrent_dropout":0}对于Python 2.7,我们已经在TensorFlow 1.0和1.1上测试了这个模型。
数据集
由于大小限制,此报告不包含任何数据集。
我们已经准备好了许多使用Sketch-RNN开箱即用的数据集。Google QuickDraw数据集(
https://quickdraw.withgoogle.com/data)是涵盖345个类别的50M矢量草图的集合。在quickdraw数据集中,有一个名为Sketch-RNNQuickDraw Dataset的部分描述了可用于此项目的预处理数据文件。每个类别类都存储在其自己的文件中,如cat.npz,并包含70000/2500/2500示例的训练/验证/测试集大小。
从Google云(
https://console.cloud.google.com/storage/quickdraw_dataset/sketchrnn)
下载.npz数据集,以供本地使用。我们建议你创建一个名为datasets / quickdraw的子目录,并将这些.npz文件保存在此子目录中。
除了QuickDraw数据集之外,我们还在较小的数据集上测试了该模型。在sketch-rnn-datasets(
https://github.com/hardmaru/sketch-rnn-datasets)报告中,还有3个数据集:AaronKoblin Sheep Market、Kanji和Omniglot。如果你希望在本地使用它们,我们建议你为每个数据集创建一个子目录,如datasets/ aaron_sheep。如前所述,在小型数据集上训练模型以避免过度拟合时,应使用循环退出和数据增加。
创建自己的数据集
请创建你自己有趣的数据集并训练这些算法!创建新的数据集是乐趣的一部分。你很可能发现有趣的矢量线图数据集,为什么要用现有的预先打包好的数据集呢?在我们的实验中,由几千个例子组成的数据集大小足以产生一些有意义的结果。在这里,我们描述模型期望看到的数据集文件的格式。
数据集中的每个示例都存储为坐标偏移的列表:Δx,Δy用来二进制值表示笔是否从纸张提起。这种格式,我们称之为stroke-3,在论文中有描述(
https://arxiv.org/abs/1308.0850)。 请注意,论文中描述的数据格式有5个元素(stroke-5格式),此转换在DataLoader内自动完成。以下是使用以下格式的乌龟示例草图:
图:作为(Δx,Δy,二进制笔状态)序列的示例草图点和渲染形式。在渲染草图中,线条颜色对应于顺序笔画排列。
在我们的数据集中,示例列表中的每个示例都用np.int16数据类型表示为np.array。你可以将它们存储为np.int8,你可以将其存储起来以节省存储空间。如果你的数据必须是浮点格式,也可以使用np.float16。np.float32可能会浪费存储空间。在我们的数据中,Δx和Δy偏移通常用像素位置表示,它们大于神经网络模型喜欢看到的数字范围,所以在模型中内置了归一化缩放过程。当我们加载训练数据时,模型将自动转换为np.float并在训练前相应规范化。
如果要创建自己的数据集,则必须为训练/验证/测试集创建三个示例列表,以避免过度拟合到训练集。该模型将使用验证集来处理早期停止。对于aaron_sheep数据集,我们使用了7400/300/300的示例,并将每个内容放在python列表中,名为train_data,validation_data和test_data。之后,我们创建了一个名为datasets / aaron_sheep的子目录,我们使用内置的savez_compressed方法将数据集的压缩版本保存在aaron_sheep.npz文件中。在我们的所有实验中,每个数据集的大小是100的确切倍数。
filename = os.path.join('datasets/your_dataset_directory', 'your_dataset_name.npz')我们还通过执行简单的笔画简化来预处理数据,称为Ramer-Douglas-Peucker。 在这里应用这个算法有一些易于使用的开源代码(
https://github.com/fhirschmann/rdp)。 实际上,我们可以将epsilon参数设置为0.2到3.0之间的值,具体取决于我们想要简单的线条。 在本文中,我们使用了一个2.0的epsilon参数。 我们建议你建立最大序列长度小于250的数据集。
如果你有大量简单的SVG图像,则可以使用一些可用的库(
https://pypi.python.org/pypi/svg.path)来将SVG的子集转换为线段,然后可以在将数据转换为stroke-3格式之前对线段应用RDP。
预训练模型
我们为aaron_sheep数据集提供了预先训练的模型,用于条件和无条件训练模式,使用vanilla LSTM单元以及带有层规范化的LSTM单元。这些型号将通过运行Jupyter Notebook下载。它们存储在:
/tmp/sketch_rnn/models/aaron_sheep/lstm
/tmp/sketch_rnn/models/aaron_sheep/lstm_uncond
/tmp/sketch_rnn/models/aaron_sheep/layer_norm
/tmp/sketch_rnn/models/aaron_sheep/layer_norm_uncond
此外,我们为选定的QuickDraw数据集提供了预先训练的模型:
/tmp/sketch_rnn/models/owl/lstm
/tmp/sketch_rnn/models/flamingo/lstm_uncond
/tmp/sketch_rnn/models/catbus/lstm
/tmp/sketch_rnn/models/elephantpig/lstm
使用Jupyter notebook的模型
让我们来模拟猫和公车之间的插值!
我们涵盖了一个简单的Jupyter notebook(
http://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn.ipynb),向你展示如何加载预先训练的模型并生成矢量草图。你能够在两个矢量图像之间进行编码,解码和变形,并生成新的随机图像。采样图像时,可以调整temperature参数来控制不确定度。
来源:
https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/README.md
相关推荐
- psp模拟器ios(psp模拟器ios推荐)
-
psp手机模拟器推荐PPSSPP,作为最流行的开源PSP模拟器,因为其强大的功能和兼容性广受玩家们喜爱。虽然提供了PC和安卓双平台的支持,但是有碍于安卓设备的硬件,移动端PPSSPP的功能并不完整。不...
- 台式机重装系统按f几(重装电脑系统按f几)
-
F8、F9、F10、F11、F12、F2、del。一般用到这几个。下面以联想电脑装WIN10系统为例:1、将制作好的U盘插入要重装系统的电脑,开机画面出现电脑品牌logo时,不停地按“f2键”进入“B...
- win10激活错误代码0x8007007b
-
Win10激活出现0x8007007b解决方法如下1、找到计算机,右键点击属性,确认你的电脑系统是否是windows10。2、鼠标右击桌面,依次点击个性化-主题-桌面图标设置,勾选计算机后依次点击应用...
-
- 4000台式电脑最好的组装配置
-
四千元价格组装电脑主机与五千元组装电脑主机的价格类似,因为电脑主机就几个大部件,电脑主机主板是多少代的产品?主板内存的插槽数?电脑处理器等如果是自己组装,都可以配置到十二代产品,电脑硬盘可以分为256G固态硬盘做系统盘,1T机械硬盘作为工作...
-
2025-11-06 20:05 liuian
- linux是一种什么系统(linux属于什么系统)
-
Linux,全称GNU/Linux,是一种免费使用和自由传播的类UNIX操作系统,是一个基于POSIX的多用户、多任务、支持多线程和多CPU的操作系统。其内核由林纳斯·本纳第克特·托瓦兹于1991年1...
- 手机管理大师免费版(手机管理大师极速版)
-
使用手机“文件管理”打开文件夹时提示访问受限,需要前往“文件”应用查看1.进入手机设置——安全——应用权限——权限/应用2.在手机桌面找到手机管家——权限隐私——应用权限——权限/应用?当然,相对于被...
- 电脑能开机但是进不去桌面怎么办
-
打开任务管理器按Ctrl+Shift+Esc打开任务管理器。文件中运行新任务点击文件,运行新任务。输入指令重启桌面输入explorer.exe,点击确定,等待桌面重启完成就可以了。电脑已经是我们生活中...
- 怎样解除自动关机模式(怎样解除自动开关机)
-
1、打开手机主界面,找到系统自带的“时钟”应用,点击打开它。2、点击进入时钟后,点击右下角的“计时器”。3、进入到计时器后,点击“在计时结束启用雷达”这个选项。4、然后在这里,下拉到最下面,勾选“停...
- 电脑最高配置是什么配置2025
-
一,2023最新主流电脑装机配置如下。二,处理器可以使用十二代的i512400或者i512490f,内存16gb双通道,显卡rtx3060,主板可以使用b660m或者h610m。三,如果十三代酷睿...
- MySQL慢查询优化:从explain到索引,DBA手把手教你提升10倍性能
-
数据库性能是应用系统的生命线,而慢查询就像隐藏在系统中的定时炸弹。某电商平台曾因一条未优化的SQL导致订单系统响应时间从200ms飙升至8秒,最终引发用户投诉和订单流失。今天我们就来系统学习MySQL...
- 一文读懂SQL五大操作类别(DDL/DML/DQL/DCL/TCL)的基础语法
-
在SQL中,DDL、DML、DQL、DCL、TCL是按操作类型划分的五大核心语言类别,缩写及简介如下:DDL(DataDefinitionLanguage,数据定义语言):用于定义和管理数据库结构...
- 闲来无事,学学Mysql增、删,改,查
-
Mysql增、删,改,查1“增”——添加数据1.1为表中所有字段添加数据1.1.1INSERT语句中指定所有字段名语法:INSERTINTO表名(字段名1,字段名2,…)VALUES(值1...
- 数据库:MySQL 高性能优化规范建议
-
数据库命令规范所有数据库对象名称必须使用小写字母并用下划线分割所有数据库对象名称禁止使用MySQL保留关键字(如果表名中包含关键字查询时,需要将其用单引号括起来)数据库对象的命名要能做到见名识意,...
- 下载工具合集_下载工具手机版
-
迅雷,在国内的下载地位还是很难撼动的,所需要用到的地方还挺多。缺点就是不开会员,软件会限速。EagleGet,全能下载管理器,支持HTTP(S)FTPMMSRTSP协议,也可以使用浏览器扩展检测...
- mediamtx v1.15.2 更新详解:功能优化与问题修复
-
mediamtxv1.15.2已于2025年10月14日发布,本次更新在功能、性能优化以及问题修复方面带来了多项改进,同时也更新了部分依赖库并提升了安全性。以下为本次更新的详细内容:...
- 一周热门
- 最近发表
- 标签列表
-
- python判断字典是否为空 (50)
- crontab每周一执行 (48)
- aes和des区别 (43)
- bash脚本和shell脚本的区别 (35)
- canvas库 (33)
- dataframe筛选满足条件的行 (35)
- gitlab日志 (33)
- lua xpcall (36)
- blob转json (33)
- python判断是否在列表中 (34)
- python html转pdf (36)
- 安装指定版本npm (37)
- idea搜索jar包内容 (33)
- css鼠标悬停出现隐藏的文字 (34)
- linux nacos启动命令 (33)
- gitlab 日志 (36)
- adb pull (37)
- python判断元素在不在列表里 (34)
- python 字典删除元素 (34)
- vscode切换git分支 (35)
- python bytes转16进制 (35)
- grep前后几行 (34)
- hashmap转list (35)
- c++ 字符串查找 (35)
- mysql刷新权限 (34)
