Python字符串比较不会短路?

通常的说法是,在检查密码或哈希等内容时,必须在恒定时间内进行字符串比较,因此建议避免使用a == b. 但是,我运行了以下脚本,结果不支持a==b第一个不相同字符短路的假设。


from time import perf_counter_ns

import random


def timed_cmp(a, b):

    start = perf_counter_ns()

    a == b

    end = perf_counter_ns()

    return end - start


def n_timed_cmp(n, a, b):

    "average time for a==b done n times"

    ts = [timed_cmp(a, b) for _ in range(n)]

    return sum(ts) / len(ts)


def check_cmp_time():

    random.seed(123)

    # generate a random string of n characters

    n = 2 ** 8

    s = "".join([chr(random.randint(ord("a"), ord("z"))) for _ in range(n)])


    # generate a list of strings, which all differs from the original string

    # by one character, at a different position

    # only do that for the first 50 char, it's enough to get data

    diffs = [s[:i] + "A" + s[i+1:] for i in range(min(50, n))]


    timed = [(i, n_timed_cmp(10000, s, d)) for (i, d) in enumerate(diffs)]

    sorted_timed = sorted(timed, key=lambda t: t[1])


    # print the 10 fastest

    for x in sorted_timed[:10]:

        i, t = x

        print("{}\t{:3f}".format(i, t))


    print("---")

    i, t = timed[0]

    print("{}\t{:3f}".format(i, t))


    i, t = timed[1]

    print("{}\t{:3f}".format(i, t))


if __name__ == "__main__":

    check_cmp_time()


这是运行的结果,重新运行脚本给出的结果略有不同,但都不令人满意。


# ran with cpython 3.8.3


6   78.051700

1   78.203200

15  78.222700

14  78.384800

11  78.396300

12  78.441800

9   78.476900

13  78.519000

8   78.586200

3   78.631500

---

0   80.691100

1   78.203200

我原以为最快的比较是第一个不同字符位于字符串开头的位置,但这不是我得到的。知道发生了什么事吗???


qq_笑_17
浏览 76回答 2
2回答

30秒到达战场

有区别,您只是在这么小的弦上看不到它。这是一个适用于您的代码的小补丁,所以我使用更长的字符串,我通过将 A 放在一个位置来进行 10 次检查,从头到尾在原始字符串中均匀分布,我的意思是,像这样:A_____________________________________________________________________A_____________________________________________________________________A_____________________________________________________________________A_____________________________________________________________________A_____________________________________________________________________A_____________________________________________________________________A_____________________________________________________________________A_____________________________________________________________________A_____________________________________________________________________A_____________________________________________________________________A___@@ -15,13 +15,13 @@ def n_timed_cmp(n, a, b): def check_cmp_time():     random.seed(123)     # generate a random string of n characters-    n = 2 ** 8+    n = 2 ** 16     s = "".join([chr(random.randint(ord("a"), ord("z"))) for _ in range(n)])     # generate a list of strings, which all differs from the original string     # by one character, at a different position     # only do that for the first 50 char, it's enough to get data-    diffs = [s[:i] + "A" + s[i+1:] for i in range(min(50, n))]+    diffs = [s[:i] + "A" + s[i+1:] for i in range(0, n, n // 10)]     timed = [(i, n_timed_cmp(10000, s, d)) for (i, d) in enumerate(diffs)]     sorted_timed = sorted(timed, key=lambda t: t[1])你会得到:0   122.6210001   213.4657002   380.2141003   460.4220005   694.2787004   722.0100007   894.6303006   1020.7221009   1149.4730008   1341.754500---0   122.6210001   213.465700请注意,在您的示例中,只有2**8字符,它已经很明显,请应用此补丁:@@ -21,7 +21,7 @@ def check_cmp_time():     # generate a list of strings, which all differs from the original string     # by one character, at a different position     # only do that for the first 50 char, it's enough to get data-    diffs = [s[:i] + "A" + s[i+1:] for i in range(min(50, n))]+    diffs = [s[:i] + "A" + s[i+1:] for i in [0, n - 1]]      timed = [(i, n_timed_cmp(10000, s, d)) for (i, d) in enumerate(diffs)]     sorted_timed = sorted(timed, key=lambda t: t[1])只保留两种极端情况(第一个字母变化与最后一个字母变化),你会得到:$ python3 cmp.py0   124.1318001   135.566000数字可能会有所不同,但大多数时候 test0比 test 快一点1。为了更精确地隔离修改了哪个字符,只要 memcmp 一个字符一个字符地执行它就可以,只要它不使用整数比较,通常是在最后一个字符未对齐时,或者在非常短的字符串上,比如8 个字符的字符串,正如我在这里演示的那样:from time import perf_counter_nsfrom statistics import medianimport randomdef check_cmp_time():    random.seed(123)    # generate a random string of n characters    n = 8    s = "".join([chr(random.randint(ord("a"), ord("z"))) for _ in range(n)])    # generate a list of strings, which all differs from the original string    # by one character, at a different position    # only do that for the first 50 char, it's enough to get data    diffs = [s[:i] + "A" + s[i + 1 :] for i in range(n)]    values = {x: [] for x in range(n)}    for _ in range(10_000_000):        for i, diff in enumerate(diffs):            start = perf_counter_ns()            s == diff            values[i].append(perf_counter_ns() - start)    timed = [[k, median(v)] for k, v in values.items()]    sorted_timed = sorted(timed, key=lambda t: t[1])    # print the 10 fastest    for x in sorted_timed[:10]:        i, t = x        print("{}\t{:3f}".format(i, t))    print("---")    i, t = timed[0]    print("{}\t{:3f}".format(i, t))    i, t = timed[1]    print("{}\t{:3f}".format(i, t))if __name__ == "__main__":    check_cmp_time()这给了我:1   221.0000002   222.0000003   223.0000004   223.0000005   223.0000006   223.0000007   223.0000000   241.000000差异是如此之小,Python 和 perf_counter_ns 可能不再是这里的正确工具。

扬帆大鱼

看,要知道它为什么不短路,您必须进行一些挖掘。简单的答案当然是它不会短路,因为标准没有这样规定。但是您可能会想,“为什么实现不选择短路?当然,它必须更快!”。不完全的。出于显而易见的原因,让我们来看看cpython。查看中定义的函数的代码unicode_compare_equnicodeobject.cstatic intunicode_compare_eq(PyObject *str1, PyObject *str2){&nbsp; &nbsp; int kind;&nbsp; &nbsp; void *data1, *data2;&nbsp; &nbsp; Py_ssize_t len;&nbsp; &nbsp; int cmp;&nbsp; &nbsp; len = PyUnicode_GET_LENGTH(str1);&nbsp; &nbsp; if (PyUnicode_GET_LENGTH(str2) != len)&nbsp; &nbsp; &nbsp; &nbsp; return 0;&nbsp; &nbsp; kind = PyUnicode_KIND(str1);&nbsp; &nbsp; if (PyUnicode_KIND(str2) != kind)&nbsp; &nbsp; &nbsp; &nbsp; return 0;&nbsp; &nbsp; data1 = PyUnicode_DATA(str1);&nbsp; &nbsp; data2 = PyUnicode_DATA(str2);&nbsp; &nbsp; cmp = memcmp(data1, data2, len * kind);&nbsp; &nbsp; return (cmp == 0);}(注意:这个函数实际上是在推导之后调用的,str1并且str2不是同一个对象 - 如果它们是 - 那么这只是一个简单的True立即)特别关注这一行-cmp = memcmp(data1, data2, len * kind);啊,我们又回到了另一个十字路口。是否memcmp短路?C标准没有规定这样的要求。如opengroup 文档和C 标准草案的第 7.24.4.1 节所示7.24.4.1 memcmp 函数概要#include <string.h>int memcmp(const void *s1, const void *s2, size_t n);描述memcmp 函数将 s1 指向的对象的前 n 个字符与 s2 指向的对象的前 n 个字符进行比较。退货memcmp 函数返回一个大于、等于或小于零的整数,对应于 s1 指向的对象大于、等于或小于 s2 指向的对象。大多数C 实现(包括glibc)选择不短路。但为什么?我们是不是漏掉了什么,你为什么不短路?因为他们使用的比较可能不像逐字节检查那样天真。该标准不要求逐字节比较对象。这就是优化的机会。它的作用glibc是比较类型的元素,unsigned long int而不仅仅是unsigned char. 检查实施幕后还有很多事情要做——讨论远远超出了这个问题的范围,毕竟这甚至没有被标记为问题C;)。虽然我发现这个答案可能值得一看。但要知道,优化就在那里,只是形式与乍一看可能想到的方法大不相同。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python