百度360必应搜狗淘宝本站头条
当前位置:网站首页 > IT知识 > 正文

Gorgonia为Go开发者打开了机器学习的大门

liuian 2025-05-08 19:41 51 浏览

在Python主导的机器学习领域,Go语言凭借其卓越的并发性能和编译型语言的效率优势逐渐崭露头角。Gorgonia作为Go语言生态中领先的机器学习库,提供了类似Theano/TensorFlow的计算图抽象。

官方文档:https://gorgonia.org/

开源地址:
https://github.com/gorgonia/gorgonia

  • 可以进行自动微分
  • 可以进行符号微分
  • 可以执行梯度下降优化
  • 提供多种便捷功能来帮助创建神经网络
  • 相当快(与 Theano 和 TensorFlow 速度相当)
  • 支持 CUDA/GPU 计算(OpenCL 尚不支持)
  • 将支持分布式计算

获取库及其依赖项

$ go get gorgonia.org/gorgonia

下面是一个使用 Gorgonia 构建的三层卷积网络的示例。使用的是 MNIST 数据集。

package main

import (
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"math/rand"
	"os"
	"os/signal"
	"runtime/pprof"
	"syscall"

	"net/http"
	_ "net/http/pprof"

	"github.com/pkg/errors"
	"gorgonia.org/gorgonia"
	"gorgonia.org/gorgonia/examples/mnist"
	"gorgonia.org/tensor"

	"time"

	"gopkg.in/cheggaaa/pb.v1"
)

var (
	epochs     = flag.Int("epochs", 100, "Number of epochs to train for")
	dataset    = flag.String("dataset", "train", "Which dataset to train on? Valid options are \"train\" or \"test\"")
	dtype      = flag.String("dtype", "float64", "Which dtype to use")
	batchsize  = flag.Int("batchsize", 100, "Batch size")
	cpuprofile = flag.String("cpuprofile", "", "CPU profiling")
)

const loc = "../testdata/mnist/"

var dt tensor.Dtype

func parseDtype() {
	switch *dtype {
	case "float64":
		dt = tensor.Float64
	case "float32":
		dt = tensor.Float32
	default:
		log.Fatalf("Unknown dtype: %v", *dtype)
	}
}

type sli struct {
	start, end int
}

func (s sli) Start() int { return s.start }
func (s sli) End() int   { return s.end }
func (s sli) Step() int  { return 1 }

type convnet struct {
	g                  *gorgonia.ExprGraph
	w0, w1, w2, w3, w4 *gorgonia.Node // weights. the number at the back indicates which layer it's used for
	d0, d1, d2, d3     float64        // dropout probabilities

	out *gorgonia.Node
}

func newConvNet(g *gorgonia.ExprGraph) *convnet {
	w0 := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(32, 1, 3, 3), gorgonia.WithName("w0"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
	w1 := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(64, 32, 3, 3), gorgonia.WithName("w1"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
	w2 := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(128, 64, 3, 3), gorgonia.WithName("w2"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
	w3 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(128*3*3, 625), gorgonia.WithName("w3"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
	w4 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(625, 10), gorgonia.WithName("w4"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
	return &convnet{
		g:  g,
		w0: w0,
		w1: w1,
		w2: w2,
		w3: w3,
		w4: w4,

		d0: 0.2,
		d1: 0.2,
		d2: 0.2,
		d3: 0.55,
	}
}

func (m *convnet) learnables() gorgonia.Nodes {
	return gorgonia.Nodes{m.w0, m.w1, m.w2, m.w3, m.w4}
}

// This function is particularly verbose for educational reasons. In reality, you'd wrap up the layers within a layer struct type and perform per-layer activations
func (m *convnet) fwd(x *gorgonia.Node) (err error) {
	var c0, c1, c2, fc *gorgonia.Node
	var a0, a1, a2, a3 *gorgonia.Node
	var p0, p1, p2 *gorgonia.Node
	var l0, l1, l2, l3 *gorgonia.Node

	// LAYER 0
	// here we convolve with stride = (1, 1) and padding = (1, 1),
	// which is your bog standard convolution for convnet
	if c0, err = gorgonia.Conv2d(x, m.w0, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil {
		return errors.Wrap(err, "Layer 0 Convolution failed")
	}
	if a0, err = gorgonia.Rectify(c0); err != nil {
		return errors.Wrap(err, "Layer 0 activation failed")
	}
	if p0, err = gorgonia.MaxPool2D(a0, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil {
		return errors.Wrap(err, "Layer 0 Maxpooling failed")
	}
	log.Printf("p0 shape %v", p0.Shape())
	if l0, err = gorgonia.Dropout(p0, m.d0); err != nil {
		return errors.Wrap(err, "Unable to apply a dropout")
	}

	// Layer 1
	if c1, err = gorgonia.Conv2d(l0, m.w1, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil {
		return errors.Wrap(err, "Layer 1 Convolution failed")
	}
	if a1, err = gorgonia.Rectify(c1); err != nil {
		return errors.Wrap(err, "Layer 1 activation failed")
	}
	if p1, err = gorgonia.MaxPool2D(a1, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil {
		return errors.Wrap(err, "Layer 1 Maxpooling failed")
	}
	if l1, err = gorgonia.Dropout(p1, m.d1); err != nil {
		return errors.Wrap(err, "Unable to apply a dropout to layer 1")
	}

	// Layer 2
	if c2, err = gorgonia.Conv2d(l1, m.w2, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil {
		return errors.Wrap(err, "Layer 2 Convolution failed")
	}
	if a2, err = gorgonia.Rectify(c2); err != nil {
		return errors.Wrap(err, "Layer 2 activation failed")
	}
	if p2, err = gorgonia.MaxPool2D(a2, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil {
		return errors.Wrap(err, "Layer 2 Maxpooling failed")
	}
	log.Printf("p2 shape %v", p2.Shape())

	var r2 *gorgonia.Node
	b, c, h, w := p2.Shape()[0], p2.Shape()[1], p2.Shape()[2], p2.Shape()[3]
	if r2, err = gorgonia.Reshape(p2, tensor.Shape{b, c * h * w}); err != nil {
		return errors.Wrap(err, "Unable to reshape layer 2")
	}
	log.Printf("r2 shape %v", r2.Shape())
	if l2, err = gorgonia.Dropout(r2, m.d2); err != nil {
		return errors.Wrap(err, "Unable to apply a dropout on layer 2")
	}

	ioutil.WriteFile("tmp.dot", []byte(m.g.ToDot()), 0644)

	// Layer 3
	if fc, err = gorgonia.Mul(l2, m.w3); err != nil {
		return errors.Wrapf(err, "Unable to multiply l2 and w3")
	}
	if a3, err = gorgonia.Rectify(fc); err != nil {
		return errors.Wrapf(err, "Unable to activate fc")
	}
	if l3, err = gorgonia.Dropout(a3, m.d3); err != nil {
		return errors.Wrapf(err, "Unable to apply a dropout on layer 3")
	}

	// output decode
	var out *gorgonia.Node
	if out, err = gorgonia.Mul(l3, m.w4); err != nil {
		return errors.Wrapf(err, "Unable to multiply l3 and w4")
	}
	m.out, err = gorgonia.SoftMax(out)
	return
}

func main() {
	flag.Parse()
	parseDtype()
	rand.Seed(1337)

	// intercept Ctrl+C
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	doneChan := make(chan bool, 1)

	var inputs, targets tensor.Tensor
	var err error

	go func() {
		log.Println(http.ListenAndServe("localhost:6060", nil))
	}()

	trainOn := *dataset
	if inputs, targets, err = mnist.Load(trainOn, loc, dt); err != nil {
		log.Fatal(err)
	}

	// the data is in (numExamples, 784).
	// In order to use a convnet, we need to massage the data
	// into this format (batchsize, numberOfChannels, height, width).
	//
	// This translates into (numExamples, 1, 28, 28).
	//
	// This is because the convolution operators actually understand height and width.
	//
	// The 1 indicates that there is only one channel (MNIST data is black and white).
	numExamples := inputs.Shape()[0]
	bs := *batchsize
	// todo - check bs not 0

	if err := inputs.Reshape(numExamples, 1, 28, 28); err != nil {
		log.Fatal(err)
	}
	g := gorgonia.NewGraph()
	x := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(bs, 1, 28, 28), gorgonia.WithName("x"))
	y := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(bs, 10), gorgonia.WithName("y"))
	m := newConvNet(g)
	if err = m.fwd(x); err != nil {
		log.Fatalf("%+v", err)
	}
	losses := gorgonia.Must(gorgonia.HadamardProd(m.out, y))
	cost := gorgonia.Must(gorgonia.Mean(losses))
	cost = gorgonia.Must(gorgonia.Neg(cost))

	// we wanna track costs
	var costVal gorgonia.Value
	gorgonia.Read(cost, &costVal)

	// if _, err = gorgonia.Grad(cost, m.learnables()...); err != nil {
	// 	log.Fatal(err)
	// }

	// debug
	// ioutil.WriteFile("fullGraph.dot", []byte(g.ToDot()), 0644)
	// log.Printf("%v", prog)
	// logger := log.New(os.Stderr, "", 0)
	// vm := gorgonia.NewTapeMachine(g, gorgonia.BindDualValues(m.learnables()...), gorgonia.WithLogger(logger), gorgonia.WithWatchlist())

	prog, locMap, _ := gorgonia.Compile(g)
	log.Printf("%v", prog)

	vm := gorgonia.NewTapeMachine(g, gorgonia.WithPrecompiled(prog, locMap), gorgonia.BindDualValues(m.learnables()...))
	solver := gorgonia.NewRMSPropSolver(gorgonia.WithBatchSize(float64(bs)))
	defer vm.Close()
	// pprof
	// handlePprof(sigChan, doneChan)

	var profiling bool
	if *cpuprofile != "" {
		f, err := os.Create(*cpuprofile)
		if err != nil {
			log.Fatal(err)
		}
		profiling = true
		pprof.StartCPUProfile(f)
		defer pprof.StopCPUProfile()
	}
	go cleanup(sigChan, doneChan, profiling)

	batches := numExamples / bs
	log.Printf("Batches %d", batches)
	bar := pb.New(batches)
	bar.SetRefreshRate(time.Second)
	bar.SetMaxWidth(80)

	for i := 0; i < *epochs; i++ {
		bar.Prefix(fmt.Sprintf("Epoch %d", i))
		bar.Set(0)
		bar.Start()
		for b := 0; b < batches; b++ {
			start := b * bs
			end := start + bs
			if start >= numExamples {
				break
			}
			if end > numExamples {
				end = numExamples
			}

			var xVal, yVal tensor.Tensor
			if xVal, err = inputs.Slice(sli{start, end}); err != nil {
				log.Fatal("Unable to slice x")
			}

			if yVal, err = targets.Slice(sli{start, end}); err != nil {
				log.Fatal("Unable to slice y")
			}
			if err = xVal.(*tensor.Dense).Reshape(bs, 1, 28, 28); err != nil {
				log.Fatalf("Unable to reshape %v", err)
			}

			gorgonia.Let(x, xVal)
			gorgonia.Let(y, yVal)
			if err = vm.RunAll(); err != nil {
				log.Fatalf("Failed at epoch  %d: %v", i, err)
			}
			solver.Step(gorgonia.NodesToValueGrads(m.learnables()))
			vm.Reset()
			bar.Increment()
		}
		log.Printf("Epoch %d | cost %v", i, costVal)

	}
}

func cleanup(sigChan chan os.Signal, doneChan chan bool, profiling bool) {
	select {
	case <-sigChan:
		log.Println("EMERGENCY EXIT!")
		if profiling {
			log.Println("Stop profiling")
			pprof.StopCPUProfile()
		}
		os.Exit(1)

	case <-doneChan:
		return
	}
}

func handlePprof(sigChan chan os.Signal, doneChan chan bool) {
	var profiling bool
	if *cpuprofile != "" {
		f, err := os.Create(*cpuprofile)
		if err != nil {
			log.Fatal(err)
		}
		profiling = true
		pprof.StartCPUProfile(f)
		defer pprof.StopCPUProfile()
	}
	go cleanup(sigChan, doneChan, profiling)
}

使用 Gorgonia 的主要原因是开发人员的舒适度。如果您广泛使用 Go 技术栈,现在您可以在熟悉且舒适的环境中创建可用于生产的机器学习系统。

Gorgonia 目前性能相当出色——其速度与 PyTorch 和 Tensorflow 的 CPU 实现相当。由于 CGO 负担过重,GPU 实现的比较略显困难,但请放心,这是一个正在积极改进的领域。

相关推荐

Python 中 必须掌握的 20 个核心函数——items()函数

items()是Python字典对象的方法,用于返回字典中所有键值对的视图对象。它提供了对字典完整内容的高效访问和操作。一、items()的基本用法1.1方法签名dict.items()返回:字典键...

Python字典:键值对的艺术_python字典的用法

字典(dict)是Python的核心数据结构之一,与列表同属可变序列,但采用完全不同的存储方式:定义方式:使用花括号{}(列表使用方括号[])存储结构:以键值对(key-valuepair)...

python字典中如何添加键值对_python怎么往字典里添加键

添加键值对首先定义一个空字典1>>>dic={}直接对字典中不存在的key进行赋值来添加123>>>dic['name']='zhangsan'>>...

Spring Boot @ConfigurationProperties 详解与 Nacos 配置中心集成

本文将深入探讨SpringBoot中@ConfigurationProperties的详细用法,包括其语法细节、类型转换、复合类型处理、数据校验,以及与Nacos配置中心的集成方式。通过...

Dubbo概述_dubbo工作原理和机制

什么是RPCRPC是RemoteProcedureCall的缩写翻译为:远程过程调用目标是为了实现两台(多台)计算机\服务器,互相调用方法\通信的解决方案RPC的概念主要定义了两部分内容序列化协...

再见 Feign!推荐一款微服务间调用神器,跟 SpringCloud 绝配

在微服务项目中,如果我们想实现服务间调用,一般会选择Feign。之前介绍过一款HTTP客户端工具Retrofit,配合SpringBoot非常好用!其实Retrofit不仅支持普通的HTTP调用,还能...

SpringGateway 网关_spring 网关的作用

奈非框架简介早期(2020年前)奈非提供的微服务组件和框架受到了很多开发者的欢迎这些框架和SpringCloudAlibaba的对应关系我们要知道Nacos对应Eureka都是注册中心Dubbo...

Sentinel 限流详解-Sentinel与OpenFeign服务熔断那些事

SentinelResource我们使用到过这个注解,我们需要了解的是其中两个属性:value:资源名称,必填且唯一。@SentinelResource(value="test/get&#...

超详细MPLS学习指南 手把手带你实现IP与二层网络的无缝融合

大家晚上好,我是小老虎,今天的文章有点长,但是都是干货,耐心看下去,不会让你失望的哦!随着ASIC技术的发展,路由查找速度已经不是阻碍网络发展的瓶颈。这使得MPLS在提高转发速度方面不再具备明显的优势...

Cisco 尝试配置MPLS-V.P.N从开始到放弃

本人第一次接触这个协议,所以打算分两篇进行学习和记录,本文枯燥预警,配置命令在下一篇全为定义,其也是算我毕业设计的一个小挑战。新概念重点备注为什么选择该协议IPSecVPN都属于传统VPN传统VP...

MFC -- 网络通信编程_mfc编程教程

要买东西的时候,店家常常说,你要是真心买的,还能给你便宜,你看真心就是不怎么值钱。。。----网易云热评一、创建服务端1、新建一个控制台应用程序,添加源文件server2、添加代码框架#includ...

35W快充?2TB存储?iPhone14爆料汇总,不要再漫天吹15了

iPhone14都还没发布,关于iPhone15的消息却已经漫天飞,故加紧整理了关于iPhone14目前已爆出的消息。本文将从机型、刘海、屏幕、存储、芯片、拍照、信号、机身材质、充电口、快充、配色、价...

SpringCloud Alibaba(四) - Nacos 配置中心

1、环境搭建1.1依赖<!--nacos注册中心注解@EnableDiscoveryClient--><dependency><groupI...

Nacos注册中心最全详解(图文全面总结)

Nacos注册中心是微服务的核心组件,也是大厂经常考察的内容,下面我就重点来详解Nacos注册中心@mikechen本篇已收于mikechen原创超30万字《阿里架构师进阶专题合集》里面。微服务注册中...

网络技术领域端口号备忘录,受益匪浅 !

你好,这里是网络技术联盟站,我是瑞哥。网络端口是计算机网络中用于区分不同应用程序和服务的标识符。每个端口号都是一个16位的数字,范围从0到65535。网络端口的主要功能是帮助网络设备(如计算机和服务器...