Skip to content

Files

Latest commit

c3011f5 · Jan 9, 2022

History

History
771 lines (527 loc) · 26.4 KB

File metadata and controls

771 lines (527 loc) · 26.4 KB

二、数组

当我的许多同事第一次开始使用 Python 时,他们惊讶地发现该语言没有内置的数组数据结构。Python 列表数据结构足够通用,可以处理许多编程场景,但是对于科学和数值编程来说,需要数组和矩阵。SciPy 和 NumPy 库中最基本的对象是数组数据结构。下面的截图显示了这一章的方向。

图 21:数字阵列演示

在 2.1 节中,您将学习创建和初始化 NumPy 数组的最常见方法,并了解 NumPy 支持的各种数字数据类型。

在 2.2 节中,您将学习如何使用where()函数、使用in关键字以及使用程序定义的函数在数组中搜索目标值。

在 2.3 节中,您将学习如何使用三种不同的内置排序算法(快速排序、合并排序和堆排序)对 NumPy 数组进行排序。您还将了解 NumPy 数组引用机制。

在 2.4 节中,您将学习如何使用 NumPy shuffle()函数随机化数组,以及如何使用程序定义的函数和 Fisher-Yates 算法随机化数组。

NumPy 库中最基本的对象是数组数据结构。NumPy 数组类似于其他编程语言中的数组。数组有固定的大小,每个单元格必须包含相同的类型。NumPy 库有几个函数可以让你创建一个数组。

看看代码清单 4 中的演示程序。在两个初步的print语句之后,程序执行开始于使用硬编码的数值创建一个数组:

arr = np.array([1., 3., 5., 7., 9.]) dt = np.dtype(arr[0]) print "Cell element type is " + str(dt.name) # displays 'float64'

NumPy array()函数接受一个 Python 列表(如方括号所示),并返回一个包含列表值的数组。注意小数点。这些告诉解释器将单元格值转换为float64,这是数组的默认浮点数据类型。如果没有小数,解释器会将值转换为int32,数组的默认整数类型。

代码清单 4:初始化数值数组

  # arrays.py
  # Python 2.7

  import numpy as np

  def >my_print(arr, cols, dec):
    n = len(arr) # print array using indexing
    fmt = "%." + str(dec) + "f" # like %.4f
    for i in xrange(n):  # alt: for x in arr
      if i > 0 and i % cols == 0:
        print ""
      print fmt % arr[i],
    print "\n"

  # =====

  print "\nBegin array demo \n"

  print "Creating array arr using np.array() and list with
  hard-coded values "
  arr = np.array([1., 3., 5., 7., 9.]) # float64
  dt = np.dtype(arr[0])
  print "Cell element type is " + str(dt.name)
  print ""

  print "Printing array arr using built-in print() "
  print arr
  print ""

  print "Creating int array arr using np.arange(9) "
  arr = np.arange(9) # [0, 1, . . 8] # int32
  print "Printing array arr using built-in print() "
  print arr
  print ""

  cols = 4; dec = 0
  print "Printing array arr using my_print() with cols=" + str(cols),
  print "and dec=" + str(dec)
  my_print(arr, cols, dec)

  print "Creating array arr using np.zeros(5) "
  arr = np.zeros(5)
  print "Printing array arr using built-in print() "
  print arr
  print ""

  print "Creating array arr using  np.linspace(2., 5., 6)"
  arr = np.linspace(2., 5., 6) # 6 values from [2.0 to 5.0] inc.
  print "Printing array arr using built-in print() "
  print arr
  print ""

  print "\nEnd demo \n"
  C:\SciPy\Ch2>
  python arrays.py

  Begin
  array demo 

  Creating
  array arr using np.array() and list with hard-coded values 
  Cell
  element type is float64

  Printing
  array arr using built-in print() 
  [
  1.  3.  5.  7.  9.]

  Creating
  int array arr using np.arange(9) 
  Printing
  array arr using built-in print() 
  [0
  1 2 3 4 5 6 7 8]

  Printing
  array arr using my_print() with cols=4 and dec=0
  0
  1 2 3 
  4
  5 6 7 
  8

  Creating
  array arr using np.zeros(5) 
  Printing
  array arr using built-in print() 
  [
  0.  0.  0.  0.  0.]

  Creating
  array arr using  np.linspace(2., 5., 6)
  Printing
  array arr using built-in print() 
  [
  2.   2.6  3.2  3.8  4.4  5\. ]

  End demo

如果您正在创建一个数组,并且float64int32都不合适,您可以将数据类型显式化。例如:

arr = np.array([2.0, 4.0, 6.0], dtype=np.float16)

NumPy 有四种浮点数据类型:float_float16float32float64。默认的浮点类型是float64:11 位指数和 52 位尾数的有符号值。NumPy 还支持复数。

NumPy 有 11 种整数数据类型,包括int32int64uint64。默认整数数据类型为int32(即可能值介于-2,147,483,648 和+2,147,483,647 之间的带符号 32 位整数)。

您也可以使用字符串值和布尔值创建数组。例如:

arr_b = np.array([True, False, True]) arr_s = np.array(["ant", "bat", "cow"])

创建数组后,演示使用内置的print语句显示数组值:

print "Printing array arr using built-in print() " print arr

Python 2.7 print语句在大多数情况下显示 NumPy 数组都是简单有效的。如果需要自定义输出格式,可以使用 NumPy set_printoptions()功能或者编写程序自定义显示功能。

接下来,演示程序使用 NumPy arange()函数创建并初始化一个数组:

arr = np.arange(9)``print "Printing array arr using built-in print() " print arr    # displays [0 1 2 3 4 5 6 7 8]

arange(n)的调用返回一个int32数组,其顺序值为 0,1,2,… (n-1)。请注意,NumPy arange()函数(名称代表数组-范围,不是单词array的拼写错误)与 Python range()函数(返回整数值列表)和 Python xrange()函数(返回可用于遍历列表或数组的迭代器对象)有很大不同。

接下来,演示程序使用名为my_print()的程序定义函数显示由arange()函数生成的数组:

cols = 4; dec = 0 print "Printing array arr using my_print() with cols=" + str(cols), print "and dec=" + str(dec) my_print(arr, cols, dec)

自定义函数使用指定的小数位数(此处为 0,因为值是整数)在指定的列数(此处为 4)中显示一个数组。

如果您是 Python 新手,您可能会对第一个print语句后的尾随逗号字符感到困惑。该语法用于打印时不带换行符,类似于 C# Console.Write()方法(与WriteLine()方法相反)或 Java System.out.print()方法(与println()方法相反)。

程序定义功能my_print()实现如下:

def my_print(arr, cols, dec):   n = len(arr)   fmt = "%." + str(dec) + "f"  # like %.4f   for i in xrange(n):     if i > 0 and i % cols == 0:       print ""     print fmt % arr[i],   print "\n"

该函数首先使用 Python len()函数查找数组中的单元格数量。另一种方法是使用更高效的 NumPy size属性:

n = arr.size

注意size后面没有括号,因为是属性,不是函数。my_print()函数使用传统的数组索引遍历数组:

for i in xrange(n):

使用这种技术,数组arr中的一个单元值作为arr[i]被访问。另一种方法是迭代数组,如下所示:

for x in arr

这里,x是单元格值。这种技术类似于在其他语言(如 C#)中使用“for-each”循环。在大多数情况下,我更喜欢使用数组索引而不是“for-eaking”,但是我的大多数同事更喜欢“for x in arr”语法。

接下来,演示程序使用 NumPy zeros()函数创建一个数组:

arr = np.zeros(5) print "Printing array arr using built-in print() " print arr

根据我的经验,使用zeros()函数可能是创建 NumPy 数组最常见的方法。顾名思义,调用zeros(n)会创建一个包含n单元格的数组,并将每个单元格初始化为 0.0 值。默认的元素值是float64,所以如果你想要一个初始化为 0 值的整数数组,你必须给zeros()提供dtype参数值,如下所示:

arr = np.zeros(5, dtype=np.int32)

一个密切相关的 NumPy 函数是ones(),它将一个数组初始化为所有 1.0(或整数 1)的值。

演示最后使用 NumPy linspace()函数创建并初始化了一个数组:

arr = np.linspace(2., 5., 6) print "Printing array arr using built-in print() " print arr

linspace(start, stop, num)的调用返回一个数组,该数组包含num个单元格,其值在startstop之间均匀分布,包括两者。演示调用np.linspace(2., 5., 6)返回六个float64值的数组,从 2.0 开始,到 5.0 结束(2.0、2.6、3.2、3.8、4.4 和 5.0)。

请注意,几乎所有接受启动和停止参数的 Python 和 NumPy 函数都在[start,stop]中返回值,即在启动包含和停止排除之间。NumPy linspace()功能是一个例外。

还有很多其他可以创建数组的 NumPy 函数,但是array()zeros() 函数可以处理大多数编程场景。您可以使用程序定义的函数创建专门的数组。例如,假设您需要创建第一个 n 个奇数的数组。您可以定义:

def my_odds(n):   result = np.zeros(n, dtype=np.int32)   v = 1   for i in xrange(n):     result[i] = v     v += 2   return result

然后您可以通过调用创建一个包含前四个奇数的数组:

arr = my_odds(4)

与创建 NumPy 数组密切相关的一项任务是复制现有数组。NumPy copy()功能可以做到这一点,本电子书后面会详细介绍。

资源

有关 NumPy 数字数据类型的更多详细信息,请参见

有关 NumPy 数值数组初始化函数的更多信息,请参见

有关 NumPy 数组迭代的更多详细信息,请参见

在数值数组中搜索某个目标值是一项常见的任务。有三种基本方法可以搜索 NumPy 数组。可以使用 NumPy where()函数,可以使用 Python in关键字,也可以使用程序定义的搜索函数。有趣的是,没有哪种 NumPy index()函数像 C#和 Java 等几种编程语言中的函数一样。

代码清单 5:数组搜索演示

  # searching.py
  # Python 2.7

  import numpy as np

  def >my_search(a, x, eps):
    for i in xrange(len(a)):
      if (np.isclose(a[i], x, eps)):
         return i
    return -1

  # =====

  print "\nBegin array search demo \n"

  arr = np.array([7.0, 9.0, 5.0, 1.0, 5.0, 8.0])

  print "Array arr is "
  print arr
  print ""

  target = 5.0
  print "Target value is "
  print target
  print ""

  found = target in arr
  print "Search result using 'target in arr' syntax = " + str(found)
  print ""

  result = np.where(arr == target)
  print "Search result using np.where(arr == target) is "
  print result
  print ""

  idx = my_search(arr, target, 1.0e-6)
  print "Search result using my_search() = "
  print idx
  print ""

  print "\nEnd demo \n"
  C:\SciPy\Ch2>
  python searching.py

  Begin
  array search demo 

  Array
  arr is 
  [
  7.  9.  5.  1.  5.  8.]

  Target
  value is 
  5.0

  Search
  result using 'target in arr' syntax = True

  Search
  result using np.where(arr == target) is 
  (array([2,
  4], dtype=int64),)

  Search
  result using my_search() = 
  2

  End
  demo 

演示程序首先创建一个数组和一个要搜索的目标值:

arr = np.array([7.0, 9.0, 5.0, 1.0, 5.0, 8.0]) print "Array arr is " print arr

target = 5.0 print "Target value is " print target

接下来,演示程序使用 Python in关键字在数组中搜索目标值:

found = target in arr print "Search result using 'target in arr' syntax = " + str(found) # 'True'

调用target in arr的返回结果是布尔型的,要么是True要么是False。又好又简单。然而,使用这种语法来搜索浮点值数组并不是一个好主意。问题是比较两个浮点值是否完全相等非常棘手。例如:

`>>> x = 0.15 + 0.15

y = 0.20 + 0.10 'yes' if x == y else 'no' 'no'

what the heck?!`

当比较两个浮点值是否相等时,通常不应该比较精确相等;相反,您应该检查这两个值是否非常非常接近。

存储在内存中的浮点值有时只是接近真实值,因此比较两个浮点值是否完全相等可能会产生意想不到的结果。target in arr语法没有给你任何直接的方法来控制目标值与数组中的值的比较。请注意,整数数组(或字符串数组或布尔数组)不存在检查精确相等的问题,因此target in arr语法适用于这些数组。

target in arr语法在演示程序中确实工作正常,返回了正确的True结果。接下来,演示程序使用数字 T2 功能进行搜索:

result = np.where(arr == target) print "Search result using np.where(arr == target) is " print result

有点棘手的返回结果是:

(array([2, 4], dtype=int64,)

where()函数返回包含数组的元组(如括号所示)。该数组保存搜索到的数组中出现目标值的索引,在这种情况下是单元格 2 和 4。如果搜索不在数组中的目标值,返回的结果是一个数组长度为 0:

(array([], dtype=int64),)

因此,如果您只想知道目标值是否在数组中,您可以沿着以下行检查返回值:

if len(result[0]) == 0:   print "target not found in array" else:   print "target is in array"

与使用in关键字搜索的情况一样,不建议使用where()函数搜索浮点值数组,因为您无法控制单元格值与目标值的比较方式。但是对整数、字符串和布尔数组使用where()函数是安全有效的。

接下来,演示使用程序定义的函数搜索数组:

idx = my_search(arr, target, 1.0e-6) print "Search result using my_search() = " print idx

如果在数组中找不到目标值,程序定义的my_search()函数返回-1,如果目标在数组中,返回目标第一次出现的单元格索引。在这种情况下,返回值为 2,因为目标值 5.0 位于数组的单元格[2]和[4]中。第三个参数 1.0e-6 是定义两个浮点值必须有多接近才能被认为相等的容差。

功能my_search()定义为:

def my_search(a, x, eps):   for i in xrange(len(a)):     if np.isclose(a[i], x, eps):        return i   return -1

NumPy isclose()函数比较两个值,如果值在彼此的 eps 范围内(这代表ε,数学中常用的希腊字母,代表一个小值),则返回True

不使用isclose()函数,您可以直接使用 Python 内置的abs()函数或 NumPy fabs()函数进行比较,如下所示:

if abs(a[i] - x) < eps:``  return i

if np.fabs(a[i] - x) < eps:   return i

总而言之,要搜索浮点值的数组,我建议使用程序定义的函数,它允许您控制两个值的相等比较。对于整数、字符串和布尔数组,可以使用in关键字、where()函数或程序定义的函数。

在某些情况下,您可能希望找到目标值在数组中最后一次出现的位置。使用带有整数、字符串或布尔数组的where()函数,您可以编写如下代码:

result = np.where(arr == target) if len(result[0]) == 0:   print "-1" # not found else:   print result[0][len(result[0])-1] # last idx

要在程序定义的函数中找到目标值的最后一次出现,可以使用for i in xrange(len(a)-1, -1, -1):循环从后向前遍历数组。

资源

有关 NumPy where()功能的更多信息,请参见

关于 NumPy 如何在内存中存储数组的技术细节,请参见

有关 Python 内置函数的列表,如绝对值,请参见 【https://docs.python.org/2/library/functions.html】

将数组中的值按顺序排列是一项常见且基本的编程任务。NumPy 库有一个sort()函数,用于实现三种不同的排序算法:快速排序算法、合并排序算法和堆排序算法。

有趣的是,与大多数其他语言中的排序函数不同,对sort(arr)的调用会返回一个已排序的数组,而原始数组arr保持不变。许多编程语言中的排序函数会对数组参数进行就地排序,并且不会返回新的排序数组。但是,如果您愿意,您可以使用调用arr.sort()将 NumPy 数组arr排序到位。

代码清单 6:数组排序演示

  # sorting.py
  # Python 2.7

  import numpy as np
  import time

  def >my_qsort(a):
    quick_sorter(a, 0, len(a)-1)

  def >quick_sorter(a, lo, hi):
    if lo < hi:
      p = partition(a, lo, hi)
      quick_sorter(a, lo, p-1)
      quick_sorter(a, p+1, hi)

  def >partition(a, lo, hi):
    piv = a[hi]
    i = lo
    for j in xrange(lo, hi):
      if a[j] <= piv:
        a[i], a[j] = a[j], a[i]
        i += 1
    a[i], a[hi] = a[hi], a[i]
    return i

  # =====

  print "\nBegin array sorting demo \n"

  arr = np.array([4.0, 3.0, 0.0, 2.0, 1.0, 9.0, 7.0, 6.0, 5.0])
  print "Original array is "
  print arr
  print ""

  s_arr = np.sort(arr, kind='quicksort')
  print "Return result of sorting using np.sort(arr, 'quicksort')
  is "
  print s_arr
  print ""
  print "Original array after calling np.sort() is "
  print arr
  print ""

  print "Calling my_qsort(arr) "
  start_time = time.clock() # record starting time
  my_qsort(arr)
  end_time = time.clock()
  elapsed_time = end_time - start_time

  print "Elapsed time = "
  print str(elapsed_time) + " seconds"
  print ""

  print "Original array after calling my_qsort(arr) is "
  print arr
  print ""

  print "\nEnd demo \n"
  C:\SciPy\Ch2>
  python sorting.py

  Begin
  array sorting demo 

  Original
  array is 
  [
  4.  3.  0.  2\.  1.  9.  7.  6.  5.]

  Return
  result of sorting using np.sort(arr, 'quicksort') is 
  [
  0.  1.  2.  3.  4.  5.  6.  7.  9.]

  Original
  array after calling np.sort() is 
  [
  4.  3.  0.  2.  1.  9.  7.  6.  5.]

  Calling
  my_qsort(arr) 
  Elapsed
  time = 
  3.6342481559e-05
  seconds

  Original
  array after calling my_qsort(arr) is 
  [
  0.  1.  2.  3.  4.  5.  6.  7.  9.]

  End demo

演示程序首先创建一个数组:

arr = np.array([4.0, 3.0, 0.0, 2.0, 1.0, 9.0, 7.0, 6.0, 5.0]) print "Original array is " print arr

接下来,演示程序使用 NumPy sort()函数对数组进行排序:

s_arr = np.sort(arr, kind='quicksort')

sort()函数返回一个新的数组,数组中的值按顺序排列,保持原始数组不变:

print "Return result of sorting using np.sort(arr, 'quicksort') is " print s_arr   # in order print "Original array after calling np.sort() is " print arr     # unchanged

可以使用稍微不同的语法或调用模式对数组进行就地排序:

arr.sort(kind='quicksort') print arr     # arr is sorted

arr = np.sort(arr, kind='quicksort') print arr     # arr is sorted

默认为快速排序算法,因此对sort()的调用可能是这样写的:

s_arr = np.sort(arr)

sort()提供一个明确的kind参数作为一种文档形式是有用的,尤其是在其他人将使用你的代码的情况下。

另外两种排序算法可以这样称呼:

s_arr = np.sort(arr, kind='mergesort') s_arr = np.sort(arr, kind='heapsort')

arr.sort(kind='mergesort') arr.sort(kind='heapsort')

默认情况下,sort()函数从低值到高值对数组元素进行排序。如果要对数组进行从高值到低值的排序,不能直接这样做,但排序后可以使用 Python 切片语法进行反向(数组没有明确的reverse()函数):

arr = np.array([4.0, 8.0, 6.0, 5.0]) s_arr = np.sort(arr, kind='quicksort')  # s_arr = [4.0 5.0 6.0 8.0] r_arr = s_arr[::-1]                     # r_arr = [8.0 6.0 5.0 4.0]

arr = np.array([4.0, 8.0, 6.0, 5.0]) orig = np.copy(arr)                  # make a copy of original arr[::-1].sort(kind='quicksort')     # reverse sort arr in-place r_arr = np.copy(arr)                 # copy reversed to r_arr if you wish arr = np.copy(orig)                  # restore arr to original if you wish

注意sort()功能有一个可选的order参数。但是,当数组中的单元格包含具有多个字段的对象时,此参数控制字段的比较顺序。因此order不控制升序或降序排序行为。

演示程序最后使用程序定义的my_qsort()函数对数组进行排序:

`start_time = time.clock() my_qsort(arr) end_time = time.clock()

elapsed_time = end_time - start_time print "Elapsed time = " print str(elapsed_time) + " seconds"

print "Original array after callng my_qsort(arr) is " print arr    # arr is sorted`

程序定义的my_qsort()函数对其数组参数进行适当排序。该演示通过用time.clock()函数调用包装其调用来测量my_qsort()使用的大致时间量。请注意,演示程序在源代码顶部有一个import time语句,用于将clock()函数带入范围。

使用像 NumPy 这样的库的全部意义在于,您可以使用像sort()这样的内置函数,因此您不必编写程序定义的函数。然而,在某些情况下,编写一个自定义版本的 NumPy 函数是有用的。特别是,您可以自定义程序定义函数的行为,通常会牺牲额外的时间(编写函数)和性能。

快速排序算法的核心是partition()函数。快速排序和分区如何工作的详细解释不在本电子书的范围内,但是任何快速排序实现的行为都取决于如何选择所谓的枢轴值。自定义partition()功能中的关键代码行是:

piv = a[hi]

透视值被选为当前正在处理的子阵列中的最后一个单元值。备选方案是选择第一个单元格值(piv = a[lo])、中间单元格值或在a[lo]a[hi]之间随机选择的单元格值。

资源

有关 NumPy sort()功能的更多信息,请参见

本节中程序定义的快速排序功能基于维基百科中 【https://en.wikipedia.org/wiki/Quicksort】的文章。

有关使用 Python time模块的更多信息,请参见 【https://docs.python.org/2/library/time.html】

有关 NumPy 数组切片的更多信息,请参见

数据科学编程中一个令人惊讶的常见任务是洗牌。洗牌是将数组中的单元格值重新排列成随机的顺序。你可以把洗牌看作排序的对立面。您可以使用内置的 NumPy random.shuffle()函数或通过编写程序定义的洗牌函数来洗牌。

代码清单 7:数组洗牌演示

  # shuffling.py
  # Python 2.7

  import numpy as np

  def >my_shuffle(a, seed):
    # shuffle array a in place using Fisher-Yates algorithm
    np.random.seed(seed)
    n = len(a)
    for i in xrange(n):
      r = np.random.randint(i, n)
      a[r], a[i] = a[i], a[r]
    return

  # =====

  arr = np.arange(10, dtype=np.int64) # [0, 1, 2, .. 9]
  orig = np.copy(arr)
  print "Array arr is "
  print arr
  print ""

  np.random.shuffle(arr)
  print "Array arr after a call to np.random.shuffle(arr) is
  "
  print arr
  print ""

  print "Resetting array arr"
  arr = np.copy(orig)
  print "Array arr is "
  print arr
  print ""

  my_shuffle(arr, seed=0)
  print "Array arr after call to my_shuffle(arr, seed=0) is
  "
  print arr
  print ""

  print "\nEnd demo \n" 
  C:\SciPy\Ch2>
  python shuffling.py

  Array
  arr is 
  [0
  1 2 3 4 5 6 7 8 9]

  Array
  arr after a call to np.random.shuffle(arr) is 
  [1
  9 8 3 0 2 7 5 6 4]

  Resetting
  array arr
  Array
  arr is 
  [0
  1 2 3 4 5 6 7 8 9]

  Array
  arr after call to my_shuffle(arr, seed=0) is 
  [5
  1 0 6 7 3 9 8 4 2]

  End
  demo

演示程序首先使用arange()函数创建一个包含 10 个值(0 到 9)的有序整数数组,然后使用copy()函数复制该数组:

arr = np.arange(10, dtype=np.int64) orig = np.copy(arr) print "Array arr is " print arr

接下来,演示使用 NumPy random.shuffle()函数对数组的内容进行洗牌:

np.random.shuffle(arr) print "Array arr after a call to np.random.shuffle(arr) is " print arr

random.shuffle()函数将数组参数的内容按随机顺序重新排序。在这个例子中,底层随机数生成器的种子值没有被设置,所以如果你再次运行这个程序,你几乎肯定会得到一个不同的数组排序。如果您希望程序结果具有可重复性(通常是这种情况),您可以像这样显式设置底层种子值:

np.random.seed(0) np.random.shuffle(arr)

这里种子被任意设置为 0。接下来,演示程序使用副本将数组重置为其原始值:

print "Resetting array arr" arr = np.copy(orig) print "Array arr is " print arr

使用赋值运算符代替copy()函数来复制原始数组是一个错误。例如,假设您编写了以下代码:

arr = np.arange(10, dtype=np.int64)orig = arr   # assignment is probably not what you intended print "Array arr is " print arr

因为数组赋值是通过引用而不是通过值工作的,所以origarr本质上都是指向内存中同一个数组的指针。对arr所做的任何改变,例如对random.shuffle(arr)的呼叫,也隐含着对orig的影响。因此,在呼叫random.shuffle()后尝试重置arr将没有效果。

NumPy 数组作为引用对象的另一个重要后果是,带有数组参数的函数可以就地修改数组。也可以使用view()函数创建对数组的引用,例如arr_v = arr.view()创建arr的引用副本。

演示程序最后使用程序定义的函数my_shuffle()对数组进行洗牌:

my_shuffle(arr, seed=0) print "Array arr after call to my_shuffle(arr, seed=0) is " print arr

功能my_shuffle()定义为:

def my_shuffle(a, seed):  np.random.seed(seed)   n = len(a)   for i in xrange(n):     r = np.random.randint(i, n)     a[r], a[i] = a[i], a[r]   return

将一个数组打乱成一个随机的顺序是非常棘手的,而且很容易写出错误的代码。函数my_shuffle()使用所谓的费希尔-耶茨算法,这是大多数情况下最好的方法。请注意,该函数使用非常方便的a,b = b,a Python 习惯用法来交换两个值。另一种方法是使用标准的tmp=a; a=b; b=tmp习语,这是其他编程语言所需要的。

资源

关于 NumPy random.shuffle()功能的更多详细信息,请参见

关于设置随机种子的更多信息,请参见

关于费希尔-耶茨洗牌算法的有趣解释,见 https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle