Apache MXNet - KVStore 和可视化


本章讨论 python 包 KVStore 和可视化。

KVStore包

KV 存储代表键值存储。它是用于多设备培训的关键组件。这很重要,因为单台以及多台机器上的设备之间的参数通信是通过一台或多台带有参数 KVStore 的服务器进行传输的。

让我们通过以下几点来了解 KVStore 的工作原理:

  • KVStore 中的每个值都由一个键和一个表示。

  • 网络中的每个参数数组都分配有一个键,并且该参数数组的权重由引用。

  • 之后,工作节点在处理完一批后推送梯度。他们还在处理新批次之前提取更新的权重。

简单来说,我们可以说KVStore是一个数据共享的地方,每个设备都可以将数据推入和拉出。

数据推入和拉出

KVStore 可以被认为是在不同设备(例如 GPU 和计算机)之间共享的单个对象,其中每个设备都能够将数据推入和拉出。

以下是设备推送数据和拉出数据需要遵循的实现步骤:

实施步骤

初始化- 第一步是初始化值。对于我们的示例,我们将在 KVStrore 中初始化一对 (int, NDArray) 对,然后将值拉出 -

import mxnet as mx
kv = mx.kv.create('local') # create a local KVStore.
shape = (3,3)
kv.init(3, mx.nd.ones(shape)*2)
a = mx.nd.zeros(shape)
kv.pull(3, out = a)
print(a.asnumpy())

输出

这会产生以下输出 -

[[2. 2. 2.]
[2. 2. 2.]
[2. 2. 2.]]

推送、聚合和更新- 初始化后,我们可以将新值推送到 KVStore 中,其形状与键相同 -

kv.push(3, mx.nd.ones(shape)*8)
kv.pull(3, out = a)
print(a.asnumpy())

输出

输出如下 -

[[8. 8. 8.]
 [8. 8. 8.]
 [8. 8. 8.]]

用于推送的数据可以存储在任何设备上,例如 GPU 或计算机。我们还可以将多个值推送到同一个键中。在这种情况下,KVStore 首先将所有这些值相加,然后按如下方式推送聚合值 -

contexts = [mx.cpu(i) for i in range(4)]
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.push(3, b)
kv.pull(3, out = a)
print(a.asnumpy())

输出

您将看到以下输出 -

[[4. 4. 4.]
 [4. 4. 4.]
 [4. 4. 4.]]

对于您应用的每次推送,KVStore 都会将推送的值与已存储的值结合起来。这将在更新程序的帮助下完成。此处,默认更新程序是 ASSIGN。

def update(key, input, stored):
   print("update on key: %d" % key)
   
   stored += input * 2
kv.set_updater(update)
kv.pull(3, out=a)
print(a.asnumpy())

输出

当您执行上述代码时,您应该看到以下输出 -

[[4. 4. 4.]
 [4. 4. 4.]
 [4. 4. 4.]]

例子

kv.push(3, mx.nd.ones(shape))
kv.pull(3, out=a)
print(a.asnumpy())

输出

下面给出的是代码的输出 -

update on key: 3
[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]

Pull - 与 Push 一样,我们也可以通过一次调用将值拉到多个设备上,如下所示 -

b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.pull(3, out = b)
print(b[1].asnumpy())

输出

输出如下:

[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]

完整的实施示例

下面给出的是完整的实现示例 -

import mxnet as mx
kv = mx.kv.create('local')
shape = (3,3)
kv.init(3, mx.nd.ones(shape)*2)
a = mx.nd.zeros(shape)
kv.pull(3, out = a)
print(a.asnumpy())
kv.push(3, mx.nd.ones(shape)*8)
kv.pull(3, out = a) # pull out the value
print(a.asnumpy())
contexts = [mx.cpu(i) for i in range(4)]
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.push(3, b)
kv.pull(3, out = a)
print(a.asnumpy())
def update(key, input, stored):
   print("update on key: %d" % key)
   stored += input * 2
kv._set_updater(update)
kv.pull(3, out=a)
print(a.asnumpy())
kv.push(3, mx.nd.ones(shape))
kv.pull(3, out=a)
print(a.asnumpy())
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.pull(3, out = b)
print(b[1].asnumpy())

处理键值对

我们上面实现的所有操作都涉及单个键,但 KVStore 还提供了键值对列表的接口-

对于单个设备

以下示例显示了单个设备的键值对列表的 KVStore 接口 -

keys = [5, 7, 9]
kv.init(keys, [mx.nd.ones(shape)]*len(keys))
kv.push(keys, [mx.nd.ones(shape)]*len(keys))
b = [mx.nd.zeros(shape)]*len(keys)
kv.pull(keys, out = b)
print(b[1].asnumpy())

输出

您将收到以下输出 -

update on key: 5
update on key: 7
update on key: 9
[[3. 3. 3.]
 [3. 3. 3.]
 [3. 3. 3.]]

对于多设备

以下示例显示了多个设备的键值对列表的 KVStore 接口 -

b = [[mx.nd.ones(shape, ctx) for ctx in contexts]] * len(keys)
kv.push(keys, b)
kv.pull(keys, out = b)
print(b[1][1].asnumpy())

输出

您将看到以下输出 -

update on key: 5
update on key: 7
update on key: 9
[[11. 11. 11.]
 [11. 11. 11.]
 [11. 11. 11.]]

可视化包

可视化包是 Apache MXNet 包,用于将神经网络 (NN) 表示为由节点和边组成的计算图。

可视化神经网络

在下面的示例中,我们将使用mx.viz.plot_network来可视化神经网络。以下是这样做的先决条件 -

先决条件

  • Jupyter笔记本

  • Graphviz 库

实施例

在下面的示例中,我们将可视化用于线性矩阵分解的样本 NN -

import mxnet as mx
user = mx.symbol.Variable('user')
item = mx.symbol.Variable('item')
score = mx.symbol.Variable('score')

# Set the dummy dimensions
k = 64
max_user = 100
max_item = 50

# The user feature lookup
user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k)

# The item feature lookup
item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)

# predict by the inner product and then do sum
N_net = user * item
N_net = mx.symbol.sum_axis(data = N_net, axis = 1)
N_net = mx.symbol.Flatten(data = N_net)

# Defining the loss layer
N_net = mx.symbol.LinearRegressionOutput(data = N_net, label = score)

# Visualize the network
mx.viz.plot_network(N_net)