谷歌今天又开源了,这次是Sketch-RNN
liuian 2025-04-11 01:00 75 浏览
前不久,谷歌公布了一项最新技术,可以教机器画画。今天,谷歌开源了代码。在我们研究其代码之前,首先先按要求设置Magenta环境。(
https://github.com/tensorflow/magenta/blob/master/README.md)
本文详细解释了Sketch-RNN的TensorFlow代码,即之前发布的两篇文章《Teaching Machines to Draw》和《A Neural Representation of Sketch Drawings》中描述的循环神经网络模型(RNN)。
模型概览
sketch-rnn是序列到序列的变体自动编码器。编码器RNN是双向RNN,解码器是自回归混合密度RNN。你可以使用enc_model,dec_model,enc_size,dec_size设置指定要使用的RNN单元格的类型和RNN的大小。
编码器将采用一个潜在代码z,一个维度为z_size的浮点矢量。像VAE一样,我们可以对z强制执行高斯IID分布,并使用kl_weight来控制KL发散损失项的强度。KL散度损失与重建损失之间将会有一个权衡。我们还允许潜在的代码存储信息的一些空间,而不是纯高斯IID。一旦KL损失期限低于kl_tolerance,我们将停止对该期限的优化。
对于中小型数据集,丢失(dropout)和数据扩充是避免过度拟合的非常有用的技术。我们提供了输入丢失、输出丢失、不存在内存丢失的循环丢失三个选项。实际上,我们只使用循环丢失,通常根据数据集将其设置在65%到90%之间。层次归一化和反复丢失可以一起使用,形成了一个强大的组合,用于在小型数据集上训练循环神经网络。
谷歌提供了两种数据增强技术。第一个是随机缩放训练图像大小的random_scale_factor。第二种增加技术(sketch-rnn论文中未使用)剔除线笔划中的随机点。给定一个具有超过2点的线段,我们可以随机放置线段内的点,并且仍然保持类似的矢量图像。这种类型的数据增强在小数据集上使用时非常强大,并且对矢量图是唯一的,因为难以在文本或MIDI数据中删除随机字符或音符,并且也不可能在像素图像数据中丢弃随机像素而不引起大的视觉差异。我们通常将数据增加参数设置为10%至20%。如果在与普通示例相比较的情况下,人类观众几乎没有差异,那么我们应用数据增强技术,而不考虑训练数据集的大小。
有效地使用丢弃和数据扩充,可以避免过度拟合到一个小的训练集。
训练模型
要训练模型,首先需要一个包含训练/验证/测试例子的数据集。我们提供了指向aaron_sheep数据集的链接,默认情况下,该模型将使用此轻量级数据集。
使用示例:
sketch_rnn_train --log_root=checkpoint_path --data_dir=dataset_path --hparams={"data_set"="dataset_filename.npz"}我们建议你在模型和数据集内部创建子目录,以保存自己的数据和检查点。 TensorBoard日志将存储在checkpoint_path内,用于查看训练/验证/测试数据集中各种损失的训练曲线。
以下是模型的完整选项列表以及默认设置:
data_set='aaron_sheep.npz', # Our dataset.
num_steps=10000000, # Total number of training set. Keeplarge.
save_every=500, # Number of batches percheckpoint creation.
dec_rnn_size=512, # Size of decoder.
dec_model='lstm', # Decoder: lstm, layer_norm orhyper.
enc_rnn_size=256, # Size of encoder.
enc_model='lstm', # Encoder: lstm, layer_norm orhyper.
z_size=128, # Size of latent vector z.Recommend 32, 64 or 128.
kl_weight=0.5, # KL weight of loss equation.Recommend 0.5 or 1.0.
kl_weight_start=0.01, # KL start weight when annealing.
kl_tolerance=0.2, # Level of KL loss at which to stopoptimizing for KL.
batch_size=100, # Minibatch size. Recommendleaving at 100.
grad_clip=1.0, # Gradient clipping. Recommendleaving at 1.0.
num_mixture=20, # Number of mixtures in Gaussianmixture model.
learning_rate=0.001, # Learning rate.
decay_rate=0.9999, # Learning rate decay per minibatch.
kl_decay_rate=0.99995, # KL annealing decay rate per minibatch.
min_learning_rate=0.00001, # Minimum learning rate.
use_recurrent_dropout=True, # Recurrent Dropout without Memory Loss.Recomended.
recurrent_dropout_prob=0.90, # Probabilityof recurrent dropout keep.
use_input_dropout=False, # Input dropout. Recommend leaving False.
input_dropout_prob=0.90, # Probability of input dropout keep.
use_output_dropout=False, # Output droput. Recommend leaving False.
output_dropout_prob=0.90, # Probability of output dropout keep.
random_scale_factor=0.15, # Random scaling data augmentionproportion.
augment_stroke_prob=0.10, # Point dropping augmentation proportion.
conditional=True, # If False, use decoder-only model.
以下是一些可能需要用于在非常大的数据集上训练模型的选项,并使用HyperLSTM作为RNN单元。对于小于10K的训练样本的小数据集,具有层规范化(包括enc_model和dec_model的layer_norm)的LSTM效果最佳。
sketch_rnn_train --log_root=models/big_model --data_dir=datasets/big_dataset --hparams={"data_set"="big_dataset_filename.npz","dec_model":"hyper","dec_rnn_size":2048,"enc_model":"layer_norm","enc_rnn_size":512,"save_every":5000,"grad_clip":1.0,"use_recurrent_dropout":0}对于Python 2.7,我们已经在TensorFlow 1.0和1.1上测试了这个模型。
数据集
由于大小限制,此报告不包含任何数据集。
我们已经准备好了许多使用Sketch-RNN开箱即用的数据集。Google QuickDraw数据集(
https://quickdraw.withgoogle.com/data)是涵盖345个类别的50M矢量草图的集合。在quickdraw数据集中,有一个名为Sketch-RNNQuickDraw Dataset的部分描述了可用于此项目的预处理数据文件。每个类别类都存储在其自己的文件中,如cat.npz,并包含70000/2500/2500示例的训练/验证/测试集大小。
从Google云(
https://console.cloud.google.com/storage/quickdraw_dataset/sketchrnn)
下载.npz数据集,以供本地使用。我们建议你创建一个名为datasets / quickdraw的子目录,并将这些.npz文件保存在此子目录中。
除了QuickDraw数据集之外,我们还在较小的数据集上测试了该模型。在sketch-rnn-datasets(
https://github.com/hardmaru/sketch-rnn-datasets)报告中,还有3个数据集:AaronKoblin Sheep Market、Kanji和Omniglot。如果你希望在本地使用它们,我们建议你为每个数据集创建一个子目录,如datasets/ aaron_sheep。如前所述,在小型数据集上训练模型以避免过度拟合时,应使用循环退出和数据增加。
创建自己的数据集
请创建你自己有趣的数据集并训练这些算法!创建新的数据集是乐趣的一部分。你很可能发现有趣的矢量线图数据集,为什么要用现有的预先打包好的数据集呢?在我们的实验中,由几千个例子组成的数据集大小足以产生一些有意义的结果。在这里,我们描述模型期望看到的数据集文件的格式。
数据集中的每个示例都存储为坐标偏移的列表:Δx,Δy用来二进制值表示笔是否从纸张提起。这种格式,我们称之为stroke-3,在论文中有描述(
https://arxiv.org/abs/1308.0850)。 请注意,论文中描述的数据格式有5个元素(stroke-5格式),此转换在DataLoader内自动完成。以下是使用以下格式的乌龟示例草图:
图:作为(Δx,Δy,二进制笔状态)序列的示例草图点和渲染形式。在渲染草图中,线条颜色对应于顺序笔画排列。
在我们的数据集中,示例列表中的每个示例都用np.int16数据类型表示为np.array。你可以将它们存储为np.int8,你可以将其存储起来以节省存储空间。如果你的数据必须是浮点格式,也可以使用np.float16。np.float32可能会浪费存储空间。在我们的数据中,Δx和Δy偏移通常用像素位置表示,它们大于神经网络模型喜欢看到的数字范围,所以在模型中内置了归一化缩放过程。当我们加载训练数据时,模型将自动转换为np.float并在训练前相应规范化。
如果要创建自己的数据集,则必须为训练/验证/测试集创建三个示例列表,以避免过度拟合到训练集。该模型将使用验证集来处理早期停止。对于aaron_sheep数据集,我们使用了7400/300/300的示例,并将每个内容放在python列表中,名为train_data,validation_data和test_data。之后,我们创建了一个名为datasets / aaron_sheep的子目录,我们使用内置的savez_compressed方法将数据集的压缩版本保存在aaron_sheep.npz文件中。在我们的所有实验中,每个数据集的大小是100的确切倍数。
filename = os.path.join('datasets/your_dataset_directory', 'your_dataset_name.npz')我们还通过执行简单的笔画简化来预处理数据,称为Ramer-Douglas-Peucker。 在这里应用这个算法有一些易于使用的开源代码(
https://github.com/fhirschmann/rdp)。 实际上,我们可以将epsilon参数设置为0.2到3.0之间的值,具体取决于我们想要简单的线条。 在本文中,我们使用了一个2.0的epsilon参数。 我们建议你建立最大序列长度小于250的数据集。
如果你有大量简单的SVG图像,则可以使用一些可用的库(
https://pypi.python.org/pypi/svg.path)来将SVG的子集转换为线段,然后可以在将数据转换为stroke-3格式之前对线段应用RDP。
预训练模型
我们为aaron_sheep数据集提供了预先训练的模型,用于条件和无条件训练模式,使用vanilla LSTM单元以及带有层规范化的LSTM单元。这些型号将通过运行Jupyter Notebook下载。它们存储在:
/tmp/sketch_rnn/models/aaron_sheep/lstm
/tmp/sketch_rnn/models/aaron_sheep/lstm_uncond
/tmp/sketch_rnn/models/aaron_sheep/layer_norm
/tmp/sketch_rnn/models/aaron_sheep/layer_norm_uncond
此外,我们为选定的QuickDraw数据集提供了预先训练的模型:
/tmp/sketch_rnn/models/owl/lstm
/tmp/sketch_rnn/models/flamingo/lstm_uncond
/tmp/sketch_rnn/models/catbus/lstm
/tmp/sketch_rnn/models/elephantpig/lstm
使用Jupyter notebook的模型
让我们来模拟猫和公车之间的插值!
我们涵盖了一个简单的Jupyter notebook(
http://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn.ipynb),向你展示如何加载预先训练的模型并生成矢量草图。你能够在两个矢量图像之间进行编码,解码和变形,并生成新的随机图像。采样图像时,可以调整temperature参数来控制不确定度。
来源:
https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/README.md
相关推荐
-
- 驱动网卡(怎么从新驱动网卡)
-
网卡一般是指为电脑主机提供有线无线网络功能的适配器。而网卡驱动指的就是电脑连接识别这些网卡型号的桥梁。网卡只有打上了网卡驱动才能正常使用。并不是说所有的网卡一插到电脑上面就能进行数据传输了,他都需要里面芯片组的驱动文件才能支持他进行数据传输...
-
2026-01-30 00:37 liuian
- win10更新助手装系统(微软win10更新助手)
-
1、点击首页“系统升级”的按钮,给出弹框,告诉用户需要上传IMEI码才能使用升级服务。同时给出同意和取消按钮。华为手机助手2、点击同意,则进入到“系统升级”功能华为手机助手华为手机助手3、在检测界面,...
- windows11专业版密钥最新(windows11专业版激活码永久)
-
Windows11专业版的正版密钥,我们是对windows的激活所必备的工具。该密钥我们可以通过微软商城或者通过计算机的硬件供应商去购买获得。获得了windows11专业版的正版密钥后,我...
-
- 手机删过的软件恢复(手机删除过的软件怎么恢复)
-
操作步骤:1、首先,我们需要先打开手机。然后在许多图标中找到带有[文件管理]文本的图标,然后单击“文件管理”进入页面。2、进入页面后,我们将在顶部看到一行文本:手机,最新信息,文档,视频,图片,音乐,收藏,最后是我们正在寻找的[更多],单击...
-
2026-01-29 23:55 liuian
- 一键ghost手动备份系统步骤(一键ghost 备份)
-
步骤1、首先把装有一键GHOST装系统的U盘插在电脑上,然后打开电脑马上按F2或DEL键入BIOS界面,然后就选择BOOT打USDHDD模式选择好,然后按F10键保存,电脑就会马上重启。 步骤...
- 怎么创建局域网(怎么创建局域网打游戏)
-
1、购买路由器一台。进入路由器把dhcp功能打开 2、购买一台交换机。从路由器lan端口拉出一条网线查到交换机的任意一个端口上。 3、两台以上电脑。从交换机任意端口拉出网线插到电脑上(电脑设置...
- 精灵驱动器官方下载(精灵驱动手机版下载)
-
是的。驱动精灵是一款集驱动管理和硬件检测于一体的、专业级的驱动管理和维护工具。驱动精灵为用户提供驱动备份、恢复、安装、删除、在线更新等实用功能。1、全新驱动精灵2012引擎,大幅提升硬件和驱动辨识能力...
- 一键还原系统步骤(一键还原系统有哪些)
-
1、首先需要下载安装一下Windows一键还原程序,在安装程序窗口中,点击“下一步”,弹出“用户许可协议”窗口,选择“我同意该许可协议的条款”,并点击“下一步”。 2、在弹出的“准备安装”窗口中,可...
- 电脑加速器哪个好(电脑加速器哪款好)
-
我认为pp加速器最好用,飞速土豆太懒,急速酷六根本不工作。pp加速器什么网页都加速,太任劳任怨了!以上是个人观点,具体性能请自己试。ps:我家电脑性能很好。迅游加速盒子是可以加速电脑的。因为有过之...
- 任何u盘都可以做启动盘吗(u盘必须做成启动盘才能装系统吗)
-
是的,需要注意,U盘的大小要在4G以上,最好是8G以上,因为启动盘里面需要装系统,内存小的话,不能用来安装系统。内存卡或者U盘或者移动硬盘都可以用来做启动盘安装系统。普通的U盘就可以,不过最好U盘...
- u盘怎么恢复文件(u盘文件恢复的方法)
-
开360安全卫士,点击上面的“功能大全”。点击文件恢复然后点击“数据”下的“文件恢复”功能。选择驱动接着选择需要恢复的驱动,选择接入的U盘。点击开始扫描选好就点击中间的“开始扫描”,开始扫描U盘数据。...
- 系统虚拟内存太低怎么办(系统虚拟内存占用过高什么原因)
-
1.检查系统虚拟内存使用情况,如果发现有大量的空闲内存,可以尝试释放一些不必要的进程,以释放内存空间。2.如果系统虚拟内存使用率较高,可以尝试增加系统虚拟内存的大小,以便更多的应用程序可以使用更多...
-
- 剪贴板权限设置方法(剪贴板访问权限)
-
1、首先打开iphone手机,触碰并按住单词或图像直到显示选择选项。2、其次,然后选取“拷贝”或“剪贴板”。3、勾选需要的“权限”,最后选择开启,即可完成苹果剪贴板权限设置。仅参考1.打开苹果手机设置按钮,点击【通用】。2.点击【键盘】,再...
-
2026-01-29 21:37 liuian
- 平板系统重装大师(平板重装win系统)
-
如果你的平板开不了机,但可以连接上电脑,那就能好办,楼主下载安装个平板刷机王到你的个人电脑上,然后连接你的平板,平板刷机王会自动识别你的平板,平板刷机王上有你平板的我刷机包,楼主点击下载一个,下载完成...
- 联想官网售后服务网点(联想官网售后服务热线)
-
联想3c服务中心是联想旗下的官方售后,是基于互联网O2O模式开发的全新服务平台。可以为终端用户提供多品牌手机、电脑以及其他3C类产品的维修、保养和保险服务。根据客户需求层次,联想服务针对个人及家庭客户...
- 一周热门
- 最近发表
- 标签列表
-
- python判断字典是否为空 (50)
- crontab每周一执行 (48)
- aes和des区别 (43)
- bash脚本和shell脚本的区别 (35)
- canvas库 (33)
- dataframe筛选满足条件的行 (35)
- gitlab日志 (33)
- lua xpcall (36)
- blob转json (33)
- python判断是否在列表中 (34)
- python html转pdf (36)
- 安装指定版本npm (37)
- idea搜索jar包内容 (33)
- css鼠标悬停出现隐藏的文字 (34)
- linux nacos启动命令 (33)
- gitlab 日志 (36)
- adb pull (37)
- python判断元素在不在列表里 (34)
- python 字典删除元素 (34)
- vscode切换git分支 (35)
- python bytes转16进制 (35)
- grep前后几行 (34)
- hashmap转list (35)
- c++ 字符串查找 (35)
- mysql刷新权限 (34)
