diff options
author | Carl Friedrich Bolz-Tereick <cfbolz@gmx.de> | 2021-02-26 11:58:51 +0100 |
---|---|---|
committer | Carl Friedrich Bolz-Tereick <cfbolz@gmx.de> | 2021-02-26 11:58:51 +0100 |
commit | 28fab560253e17f0e6f5b5f2c2f5d443cf50155f (patch) | |
tree | 5b2bec6be7c0f483a354b0c09150d167b444e345 /rpython | |
parent | add a random test for finding (diff) | |
download | pypy-28fab560253e17f0e6f5b5f2c2f5d443cf50155f.tar.gz pypy-28fab560253e17f0e6f5b5f2c2f5d443cf50155f.tar.bz2 pypy-28fab560253e17f0e6f5b5f2c2f5d443cf50155f.zip |
follow what cpython is doing more systematically:
add similar cases, stop using StringBuilder, make a correctly sized llstr
directly. needs a refactoring
Diffstat (limited to 'rpython')
-rw-r--r-- | rpython/rlib/rstring.py | 130 | ||||
-rw-r--r-- | rpython/rlib/test/test_rstring.py | 16 |
2 files changed, 117 insertions, 29 deletions
diff --git a/rpython/rlib/rstring.py b/rpython/rlib/rstring.py index ed7a61d734..c77a364069 100644 --- a/rpython/rlib/rstring.py +++ b/rpython/rlib/rstring.py @@ -251,34 +251,10 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False): builder.append_slice(input, upper, len(input)) replacements = upper + 1 - elif isinstance(sub, str) and len(sub) == 1: - # a copy of the code that is specialized for single (ascii) characters - sub = sub[0] - cnt = count(input, sub, 0, len(input)) - if cnt == 0: - return input, 0 - if maxsplit > 0 and cnt > maxsplit: - cnt = maxsplit - diff_len = len(by) - 1 - try: - result_size = ovfcheck(diff_len * cnt) - result_size = ovfcheck(result_size + len(input)) - except OverflowError: - raise - replacements = cnt - - builder = Builder(result_size) - start = 0 - while maxsplit != 0: - next = find(input, sub, start, len(input)) - if next < 0: - break - builder.append_slice(input, start, next) - builder.append(by) - start = next + 1 - maxsplit -= 1 # NB. if it's already < 0, it stays < 0 - - builder.append_slice(input, start, len(input)) + elif isinstance(input, str) and len(sub) == 1: + if len(by) == 1: + return replace_count_str_chr_chr(input, sub[0], by[0], maxsplit) + return replace_count_str_chr_str(input, sub[0], by, maxsplit) else: # First compute the exact result size @@ -286,6 +262,8 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False): cnt = count(input, sub, 0, len(input)) if isinstance(input, str) and cnt == 0: return input, 0 + if isinstance(input, str): + return replace_count_str_str_str(input, sub, by, cnt, maxsplit) else: assert isutf8 from rpython.rlib import rutf8 @@ -330,6 +308,102 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False): return builder.build(), replacements +def replace_count_str_chr_chr(input, c1, c2, maxsplit): + from rpython.rtyper.annlowlevel import llstr, hlstr + s = llstr(input) + length = len(s.chars) + start = find(input, c1, 0, len(input)) + if start < 0: + return input, 0 + newstr = s.malloc(length) + src = s.chars + dst = newstr.chars + s.copy_contents(s, newstr, 0, 0, len(input)) + dst[start] = c2 + count = 1 + start += 1 + maxsplit -= 1 + while maxsplit != 0: + next = find(input, c1, start, len(input)) + if next < 0: + break + dst[next] = c2 + start = next + 1 + maxsplit -= 1 + count += 1 + + return hlstr(newstr), count + +def replace_count_str_chr_str(input, sub, by, maxsplit): + from rpython.rtyper.annlowlevel import llstr, hlstr + cnt = count(input, sub, 0, len(input)) + if cnt == 0: + return input, 0 + if maxsplit > 0 and cnt > maxsplit: + cnt = maxsplit + diff_len = len(by) - 1 + try: + result_size = ovfcheck(diff_len * cnt) + result_size = ovfcheck(result_size + len(input)) + except OverflowError: + raise + + s = llstr(input) + by_ll = llstr(by) + + newstr = s.malloc(result_size) + dst = 0 + start = 0 + while maxsplit != 0: + next = find(input, sub, start, len(input)) + if next < 0: + break + s.copy_contents(s, newstr, start, dst, next - start) + dst += next - start + s.copy_contents(by_ll, newstr, 0, dst, len(by)) + dst += len(by) + + start = next + 1 + maxsplit -= 1 # NB. if it's already < 0, it stays < 0 + + s.copy_contents(s, newstr, start, dst, len(input) - start) + assert dst - start + len(input) == result_size + return hlstr(newstr), cnt + +def replace_count_str_str_str(input, sub, by, cnt, maxsplit): + from rpython.rtyper.annlowlevel import llstr, hlstr + if cnt > maxsplit and maxsplit > 0: + cnt = maxsplit + diff_len = len(by) - len(sub) + try: + result_size = ovfcheck(diff_len * cnt) + result_size = ovfcheck(result_size + len(input)) + except OverflowError: + raise + + s = llstr(input) + by_ll = llstr(by) + newstr = s.malloc(result_size) + sublen = len(sub) + bylen = len(by) + inputlen = len(input) + dst = 0 + start = 0 + while maxsplit != 0: + next = find(input, sub, start, inputlen) + if next < 0: + break + s.copy_contents(s, newstr, start, dst, next - start) + dst += next - start + s.copy_contents(by_ll, newstr, 0, dst, bylen) + dst += bylen + start = next + sublen + maxsplit -= 1 # NB. if it's already < 0, it stays < 0 + s.copy_contents(s, newstr, start, dst, len(input) - start) + assert dst - start + len(input) == result_size + return hlstr(newstr), cnt + + def _normalize_start_end(length, start, end): if start < 0: start += length diff --git a/rpython/rlib/test/test_rstring.py b/rpython/rlib/test/test_rstring.py index 66c23dbd52..a2ce14e0a8 100644 --- a/rpython/rlib/test/test_rstring.py +++ b/rpython/rlib/test/test_rstring.py @@ -6,7 +6,7 @@ from rpython.rlib.rstring import find, rfind, count, _search, SEARCH_COUNT, SEAR from rpython.rlib.buffer import StringBuffer from rpython.rtyper.test.tool import BaseRtypingTest -from hypothesis import given, strategies as st +from hypothesis import given, strategies as st, assume def test_split(): def check_split(value, sub, *args, **kwargs): @@ -326,3 +326,17 @@ def test_hypothesis_search(u, prefix, suffix): count = _search(s, u, 0, len(s), SEARCH_COUNT) assert count == s.count(u) assert 1 <= count + + +@given(st.text(), st.lists(st.text(), min_size=2), st.text(), st.integers(min_value=0, max_value=1000000)) +def test_hypothesis_search(needle, pieces, by, maxcount): + needle = needle.encode("utf-8") + pieces = [piece.encode("utf-8") for piece in pieces] + by = by.encode("utf-8") + input = needle.join(pieces) + assume(len(input) > 0) + + if needle == '' and pieces == [] and by == '0' and maxcount == 1: + import pdb; pdb.set_trace() + res = replace(input, needle, by, maxcount) + assert res == input.replace(needle, by, maxcount) |