深入理解NumPy数据类型:避免数组赋值时的数据溢出问题

admin 百科 18

深入理解NumPy数据类型:避免数组赋值时的数据溢出问题

在使用numpy数组时,若为数组指定了不合适的`dtype`(数据类型),尤其是范围过小的整数类型如`np.uint8`,可能导致超出该类型表示范围的数值发生溢出,从而在数组赋值后出现意外的“数据改变”。本文将深入探讨numpy数据类型溢出的原理,并通过实例代码展示如何避免这一常见陷阱,确保数据完整性。

在Python的数据科学生态中,NumPy库以其高效的数值计算能力而广受欢迎。然而,初学者在使用NumPy数组时,有时会遇到一个令人困惑的问题:当将一个数组的值赋给另一个新创建的数组时,新数组中的数据却与预期不符,甚至出现完全不同的数值。这并非NumPy的“bug”,而是对数据类型(dtype)理解不足导致的数据溢出问题。

深入理解NumPy数据类型:避免数组赋值时的数据溢出问题-第2张图片-佛山资讯网

问题的根源:NumPy数据类型溢出

NumPy数组在创建时,会为其中存储的元素指定一个数据类型,例如np.int8、np.uint8、np.int32、np.float64等。每种数据类型都有其特定的取值范围。当试图将一个超出该类型表示范围的数值存储进去时,就会发生数据溢出(overflow),导致数值被截断或“环绕”(wrap around),从而得到一个看似错误的结果。

以np.uint8为例,它代表无符号8位整数(unsigned 8-bit integer)。这意味着它只能存储0到255之间的整数值。任何小于0或大于255的数值在被强制转换为np.uint8时,都会发生溢出。例如:

  • 573 转换为 np.uint8:573 % 256 = 61
  • 1023 转换为 np.uint8:1023 % 256 = 255 (因为1023 = 3*256 + 255)

我们可以使用np.iinfo函数来查看整数数据类型的详细信息,包括其最小值和最大值:

import numpy as np

# 查看 np.uint8 的信息
print(np.iinfo(np.uint8))
# 输出: iinfo(min=0, max=255, dtype=uint8)

# 查看 np.int32 的信息
print(np.iinfo(np.int32))
# 输出: iinfo(min=-2147483648, max=2147483647, dtype=int32)

登录后复制

案例分析:不正确的dtype导致的数据改变

考虑以下场景,一个函数旨在对三维NumPy数组进行重新排序:

import numpy as np

def reorder_problematic(points):
    # 将输入数组重塑为 (4, 2)
    points = points.reshape((4, 2))

    # 创建空的输出数组,指定 dtype 为 np.uint8
    # 这里的 dtype 是问题的关键
    points_new = np.zeros((4, 1, 2), np.uint8) 

    # 根据和与差进行排序逻辑
    add = points.sum(1)
    diff = np.diff(points, axis=1)

    points_new[0] = points[np.argmin(add)]
    points_new[3] = points[np.argmax(add)]
    points_new[1] = points[np.argmin(diff)]
    points_new[2] = points[np.argmax(diff)]

    return points_new

# 示例输入数据
input_data = np.array([[[ 573,  148]], [[  25,  223]], [[ 153, 1023]], [[ 730,  863]]])
output_data = reorder_problematic(input_data)
print("原始输入数据:\n", input_data)
print("使用 problematic 函数的输出数据:\n", output_data)

登录后复制

运行上述代码,会得到如下输出:

原始输入数据:
 [[[ 573  148]]
 [[  25  223]]
 [[ 153 1023]]
 [[ 730  863]]]
使用 problematic 函数的输出数据:
 [[[ 25 223]]
 [[ 61 148]]
 [[153 255]]
 [[218  95]]]

登录后复制

可以看到,output_data中的数值与input_data完全不同。例如,573变成了61,1023变成了255。这正是由于points_new数组被声明为np.uint8类型,而input_data中的许多数值(如573, 1023, 730, 863)都超出了np.uint8的取值范围(0-255),导致在赋值时发生了溢出。

为了进一步验证,我们可以尝试将原始输入数据直接转换为np.uint8类型:

# 显式将输入数据转换为 np.uint8
test_cast = np.array([[[573, 148]],
                      [[25, 223]],
                      [[153, 1023]],
                      [[730, 863]]],
                     dtype=np.uint8)
print("输入数据强制转换为 np.uint8 后的结果:\n", test_cast)

登录后复制

输出结果与reorder_problematic函数的输出高度吻合:

输入数据强制转换为 np.uint8 后的结果:
 [[[ 61 148]]
 [[ 25 223]]
 [[153 255]]
 [[218  95]]]

登录后复制

为什么基于列表的方法“有效”?

在原问题中,作者还提供了一个基于Python列表的重写版本,它似乎“修复”了这个问题:

def reorder_by_lst(points):
    points = points.reshape((4, 2))
    add = points.sum(1)
    diff = np.diff(points, axis=1)

    a = points[np.argmin(add)]
    d = points[np.argmax(add)]
    b = points[np.argmin(diff)]
    c = points[np.argmax(diff)]

    lst = [a, b, c, d]
    # 这里将列表转换为 NumPy 数组,但没有指定 dtype
    return np.array(lst) 

output_data_lst = reorder_by_lst(input_data)
print("使用 list 函数的输出数据:\n", output_data_lst)

登录后复制

输出结果:

标签: python overflow 为什么

发布评论 0条评论)

还木有评论哦,快来抢沙发吧~