5

I recently implemented Karatsuba multiplication as a personal exercise. I wrote my implementation in Python following the pseudocode provided on Wikipedia:

procedure karatsuba(num1, num2)
  if (num1 < 10) or (num2 < 10)
    return num1*num2
  /* calculates the size of the numbers */
  m = max(size_base10(num1), size_base10(num2))
  m2 = m/2
  /* split the digit sequences about the middle */
  high1, low1 = split_at(num1, m2)
  high2, low2 = split_at(num2, m2)
  /* 3 calls made to numbers approximately half the size */
  z0 = karatsuba(low1, low2)
  z1 = karatsuba((low1+high1), (low2+high2))
  z2 = karatsuba(high1, high2)
  return (z2*10^(2*m2)) + ((z1-z2-z0)*10^(m2)) + (z0)

Here is my Python implementation:

def karat(x,y):
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x*y
    else:
        m = max(len(str(x)),len(str(y)))
        m2 = m / 2

        a = x / 10**(m2)
        b = x % 10**(m2)
        c = y / 10**(m2)
        d = y % 10**(m2)

        z0 = karat(b,d)
        z1 = karat((a+b),(c+d))
        z2 = karat(a,c)

        return (z2 * 10**(2*m2)) + ((z1 - z2 - z0) * 10**(m2)) + (z0)

My question is about the final merge of z0, z1, and z2.
z2 is shifted m digits over (where m is the length of the larger of the two multiplied numbers).
Instead of simply multiplying by 10^(m), the algorithm uses *10^(2*m2)* where m2 is m/2.

I tried replacing 2*m2 with m and got incorrect results. I think this has to do with how the numbers are split but I'm not really sure what's going on.

  • Why are you computing the same 10**(m2) 6 times at each recursion level, plus calling len(str(x)) and len(str(y)) twice whenever it's still going down recursion levels? Also, since you've already spent the processing cost of converting x and y into numeric strings, why not simply use string slicing to extract the digits for a, b, c, d instead of performing any sort of math? Commented Apr 6 at 2:41
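The comment's two suggestions (compute each string and each power of ten once, and split by string slicing) can be sketched as below. This is a minimal sketch that assumes non-negative integer inputs (zfill and slicing would mishandle a leading minus sign); it splits at the same m2 = m // 2 as the floor-division versions in the answers.

```python
def karat(x, y):
    # Convert each operand to a string once instead of calling len(str(...)) repeatedly
    sx, sy = str(x), str(y)
    if len(sx) == 1 or len(sy) == 1:
        return x * y
    m = max(len(sx), len(sy))
    m2 = m // 2
    # Left-pad both to m digits; the last m2 characters are the low part
    sx, sy = sx.zfill(m), sy.zfill(m)
    a, b = int(sx[:-m2]), int(sx[-m2:])
    c, d = int(sy[:-m2]), int(sy[-m2:])
    z0 = karat(b, d)
    z1 = karat(a + b, c + d)
    z2 = karat(a, c)
    p = 10 ** m2  # the power of ten is computed once per call
    return z2 * p * p + (z1 - z2 - z0) * p + z0


print(karat(12345, 678))
```

Slicing the padded string yields the same a, b, c, d as x // 10**m2 and x % 10**m2, so the recombination is unchanged.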

9 Answers

13

Depending on your Python version, you must or should replace / with the explicit floor division operator //, which is the appropriate one here; it rounds down, ensuring that your exponents remain whole numbers.

This is essential, for example, when splitting your operands into high digits (by floor-dividing by 10^m2) and low digits (by taking the remainder modulo 10^m2); this would not work with a fractional m2.

It also explains why 2 * (m // 2) does not necessarily equal m, but rather m - 1 if m is odd. In the last line of the algorithm, 2*m2 is correct because what you are doing is giving a and c their zeros back.

If you are on an older Python version (Python 2), your code may still work, because / used to be interpreted as floor division when applied to integers.

def karat(x, y):
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x * y
    m = max(len(str(x)), len(str(y)))
    m2 = m // 2

    a = x // 10 ** (m2)
    b = x % 10 ** (m2)
    c = y // 10 ** (m2)
    d = y % 10 ** (m2)

    z0 = karat(b, d)
    z1 = karat((a + b), (c + d))
    z2 = karat(a, c)

    return (z2 * 10 ** (2 * m2)) + ((z1 - z2 - z0) * 10 ** (m2)) + (z0)
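To see concretely why the last line needs 2*m2 rather than m, here is a small check with a 5-digit x, so m = 5 is odd and m2 = 2. The helper split and the operands are chosen purely for illustration.

```python
def split(n, m2):
    """Split n into (high, low) so that n == high * 10**m2 + low."""
    return n // 10 ** m2, n % 10 ** m2


x, y = 12345, 678              # x has 5 digits, so m = 5 (odd)
m = max(len(str(x)), len(str(y)))
m2 = m // 2                    # 2, so 2 * m2 == 4, not 5

a, b = split(x, m2)            # a = 123, b = 45
c, d = split(y, m2)            # c = 6,   d = 78

# a and c each lost m2 trailing zeros, so a*c must be shifted back by 2*m2 digits
right = a * c * 10 ** (2 * m2) + (a * d + b * c) * 10 ** m2 + b * d
wrong = a * c * 10 ** m + (a * d + b * c) * 10 ** m2 + b * d

print(x * y, right, wrong)     # 8369910 8369910 74789910
```

Shifting by m only works when m is even, because that is the only case where 2 * (m // 2) == m.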
  • A little off topic, but why do you need to put "else" after "if"? If the condition is true, the function returns, so there is no need for "else".
    – Pedram
    Commented Mar 14, 2018 at 17:12
  • @Pedram Matter of taste, I'd say. It visually emphasises the logical symmetry and makes some patterns of modification easier (say you decided at some point that you need to postprocess the results of all branches). Of course, one might argue that here it's not symmetric and the first branch is just a corner case to get over with. As I said, a matter of taste. Commented Mar 14, 2018 at 17:57
  • Is Karatsuba multiplication supposed to work with two negative numbers as input? If it is, your solution reaches max recursion depth.
    – user9652688
    Commented Oct 9, 2019 at 19:59
  • @Saurabh Good catch! This probably doesn't work because Python's floor division rounds down (I think the C equivalent rounds towards zero). I suppose the simplest solution would be to strip the signs off in the beginning and put whatever is appropriate back in the end. Commented Oct 9, 2019 at 20:43
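The sign-stripping idea from the last comment can be sketched as a wrapper. This assumes the floor-division karat from this answer, repeated here (with divmod) so the snippet runs on its own; signed_karat is a hypothetical name, not part of the original answer.

```python
def karat(x, y):
    """Karatsuba for non-negative integers (the // version from this answer)."""
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x * y
    m = max(len(str(x)), len(str(y)))
    m2 = m // 2
    a, b = divmod(x, 10 ** m2)
    c, d = divmod(y, 10 ** m2)
    z0 = karat(b, d)
    z1 = karat(a + b, c + d)
    z2 = karat(a, c)
    return z2 * 10 ** (2 * m2) + (z1 - z2 - z0) * 10 ** m2 + z0


def signed_karat(x, y):
    """Hypothetical wrapper: multiply the magnitudes, then restore the sign."""
    sign = -1 if (x < 0) != (y < 0) else 1
    return sign * karat(abs(x), abs(y))


print(signed_karat(-12, 34), signed_karat(-12, -34))
```

Because karat only ever sees non-negative values, floor division and truncation agree and the recursion no longer gets stuck on negative high parts.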
2

I have implemented the same idea, but I have restricted the base case to 2-digit multiplication because I can reduce float multiplication in the function.

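def multiply(x, y):
    sx = str(x)
    sy = str(y)
    nx = len(sx)
    ny = len(sy)
    # Base case: fall back to built-in multiplication for 1- or 2-digit operands
    if ny <= 2 or nx <= 2:
        return int(x) * int(y)
    # Left-pad the shorter operand with zeros so both have n digits
    n = max(nx, ny)
    sx = sx.rjust(n, "0")
    sy = sy.rjust(n, "0")
    # Round an odd n up to the next even number; the high parts then get one
    # digit fewer, so every low part has exactly n // 2 digits
    offset = n % 2
    n += offset
    split = n // 2 - offset
    a = sx[:split]
    b = sx[split:]
    c = sy[:split]
    d = sy[split:]

    ac = multiply(a, c)
    bd = multiply(b, d)

    ad_bc = multiply(int(a) + int(b), int(c) + int(d)) - ac - bd
    # n // 2 (not n / 2) keeps every term an exact integer instead of a float
    return (10 ** n) * ac + (10 ** (n // 2)) * ad_bc + bd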

print(multiply(4,5))
print(multiply(4,58779))
print(int(multiply(4872139874092183,5977098709879)))
print(int(4872139874092183*5977098709879))
print(int(multiply(4872349085723098457,597340985723098475)))
print(int(4872349085723098457*597340985723098475))
print(int(multiply(4908347590823749,97098709870985)))
print(int(4908347590823749*97098709870985))
  • I don't understand "because I can reduce float multiplication in function": what goal are you trying to achieve in the first place using floating-point arithmetic?
    – greybeard
    Commented Jan 2, 2018 at 17:46
1

"I tried replacing 2*m2 with m and got incorrect results. I think this has to do with how the numbers are split but I'm not really sure what's going on."

This goes to the heart of how you split your numbers for the recursive calls. If you use an odd n, then n//2 is rounded down to the nearest whole number, so the second (low) part of each number has floor(n/2) digits, and to restore the first (high) part you have to pad it with floor(n/2) zeros. Since we use the same n for both numbers, this applies to both. So if you stick to the original odd n in the final step, you would be padding the first term with n zeros instead of the number of zeros that results from combining the two paddings, which is floor(n/2)*2 = n - 1.
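The same argument in the notation of the question's code, writing m for the digit count (n in this answer) and splitting both operands at m2 = floor(m/2) digits:

```latex
\begin{aligned}
x &= a \cdot 10^{m_2} + b, \qquad y = c \cdot 10^{m_2} + d, \qquad m_2 = \lfloor m/2 \rfloor \\
xy &= ac \cdot 10^{2m_2} + (ad + bc) \cdot 10^{m_2} + bd \\
ad + bc &= (a+b)(c+d) - ac - bd = z_1 - z_2 - z_0 \\
2m_2 &= \begin{cases} m, & m \text{ even} \\ m - 1, & m \text{ odd} \end{cases}
\end{aligned}
```

Here z0 = bd and z2 = ac, so the factor on z2 must be 10^(2*m2); it coincides with 10^m only when m is even.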

1

You have used m2 as a float (in Python 3, m / 2 is true division and returns a float). It needs to be an integer, so use floor division: m2 = m // 2.

def karat(x,y):
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x*y
    else:
        m = max(len(str(x)),len(str(y)))
        m2 = m // 2

        a = x // 10**(m2)
        b = x % 10**(m2)
        c = y // 10**(m2)
        d = y % 10**(m2)

        z0 = karat(b,d)
        z1 = karat((a+b),(c+d))
        z2 = karat(a,c)

        return (z2 * 10**(2*m2)) + ((z1 - z2 - z0) * 10**(m2)) + (z0)

  • Please don't post only code as an answer, but also provide an explanation of what your code does and how it solves the problem of the question. Answers with an explanation are usually more helpful and of better quality, and are more likely to attract upvotes. Commented Jul 30, 2020 at 13:35
0

Your code and logic are correct; there is just an issue with your base case. Since, according to the algorithm, a, b, c, d are 2-digit numbers, you should modify your base case and keep the length of x and y equal to 2 in the base case.

  • Can you elaborate? This answer isn't entirely clear. Commented Jul 28, 2018 at 21:44
  • In your code x and y are 4-digit numbers, and you have divided x as (10^(n/2))a + b; that is, if x = 5678, then a = 56 and b = 78, so x = ((10^2)a) + b. This makes the length of the string in our base case 2. Also change your return statement as follows: ((10**(2*m2))*z2)+((10**m2)*(z1-z2-z0))+z0. There is nothing wrong with your return expression, just some missing brackets which I think can cause a change in the calculation. I hope you have understood what is wrong with the base case. Commented Jul 30, 2018 at 19:28
  • Don't post essential information in a comment, put it in your answer instead. Comments are transient; answers are not. Commented Jul 30, 2018 at 19:49
0

I think it is better to use the math.log10 function to calculate the number of digits instead of converting to a string, something like this:

import math

def number_of_digits(number):
  """
  Used log10 to find no. of digits
  """
  if number > 0:
    return int(math.log10(number)) + 1
  elif number == 0:
    return 1
  else:
    return int(math.log10(-number)) + 1 # Don't count the '-'
  • I agree, although the base case makes it so that we'll never see n==0. From there, we might as well just use abs to take the absolute value anyway. So the body could just be: return int(math.log10(abs(number)))+1 Commented Mar 5, 2023 at 2:35
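One caveat with the log10 approach: math.log10 works in double-precision floating point, so close to a power of ten the estimate can be off by one. For 10**16 - 1 (sixteen nines), the value rounds to 1e16 as a float, and the log10 version then reports 17 digits. The sketch below compares it with len(str(...)); digits_checked is a hypothetical helper that keeps the log10 estimate but corrects it with exact integer comparisons.

```python
import math


def digits_log10(n: int) -> int:
    """Digit count via log10, as proposed above (n must be non-zero)."""
    return int(math.log10(abs(n))) + 1


def digits_checked(n: int) -> int:
    """Hypothetical helper: log10 estimate corrected with exact integer comparisons."""
    n = abs(n)
    if n == 0:
        return 1
    d = int(math.log10(n)) + 1
    while d > 1 and 10 ** (d - 1) > n:  # estimate too high
        d -= 1
    while 10 ** d <= n:  # estimate too low
        d += 1
    return d


n = 10 ** 16 - 1  # sixteen nines
print(len(str(n)), digits_log10(n), digits_checked(n))
```

In Karatsuba, an overcount by one only shifts the split point, and the product stays correct because the recombination uses the same m2; it matters more if the digit count is used elsewhere.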
0

I wanted a version that could fit nicely on a single slide, without sacrificing clarity. Here's what I came up with:

import math

def number_of_digits(number: int) -> int:
    return int(math.log10(abs(number))) + 1

def multiply(x: int, y: int) -> int:
    # We CAN multiply small numbers
    if abs(x) < 10 or abs(y) < 10:
        return x * y

    # Calculate the size of the numbers
    digits = max(number_of_digits(x), number_of_digits(y))
    midpoint = 10 ** (digits // 2)

    # Split digit sequences in the middle
    high_x = x // midpoint
    low_x = x % midpoint
    high_y = y // midpoint
    low_y = y % midpoint

    # 3 recursive calls to numbers approximately half the size
    z0 = multiply(low_x, low_y)
    z1 = multiply(low_x + high_x, low_y + high_y)
    z2 = multiply(high_x, high_y)

    return (z2 * midpoint**2) + ((z1 - z2 - z0) * midpoint) + (z0)

print(multiply(2**100, 3**100))

I'd argue that:

  1. the variable names are clearer;
  2. the number_of_digits helper function should be using log10 to find the number of digits, instead of str+len;
  3. the math is a little clearer by extracting out the 10 ** (digits // 2) term.
0

A hyper-optimized approach would be to use bitwise operators, which are faster than arithmetic operations, and to split on powers of 2 instead of powers of 10. It also avoids converting numbers to strings, which is a slow operation.

For example, the base case checks if either x or y is less than 10. This is a simpler and faster check than converting the numbers to strings and checking their lengths:

- if len(str(x)) == 1 or len(str(y)) == 1:
+ if x < 10 or y < 10:

You could use x.bit_length() and y.bit_length() to find the number of bits in x and y. This is faster than converting the numbers to strings and finding their string lengths.

- m = max(len(str(x)),len(str(y)))
+ m = max(x.bit_length(), y.bit_length())

Instead of floating-point division, we could use a faster bitwise right shift to divide m by 2 (for non-negative m, m >> 1 equals m // 2).

- m2 = m / 2
+ m2 = m >> 1

We can use bitwise operations to split x and y into high and low parts. This is faster than division and modulo by powers of 10.

- a = x / 10**(m2)
- b = x % 10**(m2)
- c = y / 10**(m2)
- d = y % 10**(m2)
+ high_x, low_x = x >> m2, x & ((1 << m2) - 1)
+ high_y, low_y = y >> m2, y & ((1 << m2) - 1)

The recursive calls remain the same, but the arguments are calculated more efficiently.

- z0 = karat(b,d)
- z1 = karat((a+b),(c+d))
- z2 = karat(a,c)
+ z0 = karat(low_x, low_y)
+ z1 = karat(low_x + high_x, low_y + high_y)
+ z2 = karat(high_x, high_y)

Finally, for the return statement, we can use bitwise shift operations to multiply by powers of 2. This is faster than multiplication by powers of 10.

- return (z2 * 10**(2*m2)) + ((z1 - z2 - z0) * 10**(m2)) + (z0)
+ return (z2 << (m2 << 1)) + ((z1 - z2 - z0) << m2) + z0

Full Implementation:

def karat(x, y):
    if x < 10 or y < 10:
        return x * y

    m = max(x.bit_length(), y.bit_length())
    m2 = m >> 1

    high_x, low_x = x >> m2, x & ((1 << m2) - 1)
    high_y, low_y = y >> m2, y & ((1 << m2) - 1)

    z0 = karat(low_x, low_y)
    z1 = karat(low_x + high_x, low_y + high_y)
    z2 = karat(high_x, high_y)

    return (z2 << (m2 << 1)) + ((z1 - z2 - z0) << m2) + z0

print(karat(4, 5))
print(karat(4, 58779))
print(int(karat(4872139874092183, 5977098709879)))
print(int(4872139874092183 * 5977098709879))
print(int(karat(4872349085723098457, 597340985723098475)))
print(int(4872349085723098457 * 597340985723098475))
print(int(karat(4908347590823749, 97098709870985)))
print(int(4908347590823749 * 97098709870985))

While this approach is a bit challenging for the average beginner, it is a great way to get better scores for runtime and optimization!
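Whether the bitwise version is actually faster is easy to measure rather than assume. A minimal sketch with timeit, assuming random 1000-bit operands; karat_dec (decimal split with floor division) and karat_bin (this answer's binary split) are names introduced here, and no particular timings are implied.

```python
import random
import timeit


def karat_dec(x, y):
    """Decimal split with floor division, as in the earlier answers."""
    if x < 10 or y < 10:
        return x * y
    m2 = max(len(str(x)), len(str(y))) // 2
    p = 10 ** m2
    a, b = divmod(x, p)
    c, d = divmod(y, p)
    z0 = karat_dec(b, d)
    z1 = karat_dec(a + b, c + d)
    z2 = karat_dec(a, c)
    return z2 * p * p + (z1 - z2 - z0) * p + z0


def karat_bin(x, y):
    """Binary split with shifts and masks, as in this answer."""
    if x < 10 or y < 10:
        return x * y
    m2 = max(x.bit_length(), y.bit_length()) >> 1
    mask = (1 << m2) - 1
    high_x, low_x = x >> m2, x & mask
    high_y, low_y = y >> m2, y & mask
    z0 = karat_bin(low_x, low_y)
    z1 = karat_bin(low_x + high_x, low_y + high_y)
    z2 = karat_bin(high_x, high_y)
    return (z2 << (m2 << 1)) + ((z1 - z2 - z0) << m2) + z0


x = random.getrandbits(1000)
y = random.getrandbits(1000)
assert karat_dec(x, y) == karat_bin(x, y) == x * y  # both must be correct first

for f in (karat_dec, karat_bin):
    t = timeit.timeit(lambda: f(x, y), number=5)
    print(f"{f.__name__}: {t:.3f} s for 5 calls")
```

Results will depend on the Python version and operand sizes. For comparison, CPython's built-in int multiplication already switches to Karatsuba internally for large enough operands.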

  • "bitwise operators […] are faster than arithmetic operations": that is one bold statement. You may want to motivate "the - old/+ new syntax".
    – greybeard
    Commented Jun 3 at 19:13
-1

The base case if len(str(x)) == 1 or len(str(y)) == 1: return x*y is incorrect. If you run any of the Python code given in the other answers against large integers, the karat() function will not produce the correct answer.

To make the code correct, you need to change the base case to if len(str(x)) < 3 or len(str(y)) < 3: return x*y.

Below is a modified implementation of Paul Panzer's answer that correctly multiplies large integers.

def karat(x,y):
    if len(str(x)) < 3 or len(str(y)) < 3:
        return x*y

    n = max(len(str(x)),len(str(y))) // 2

    a = x // 10**(n)
    b = x % 10**(n)
    c = y // 10**(n)
    d = y % 10**(n)

    z0 = karat(b,d)
    z1 = karat((a+b), (c+d))
    z2 = karat(a,c)

    return ((10**(2*n))*z2)+((10**n)*(z1-z2-z0))+z0
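A claim like this can be tested with a randomized comparison against Python's built-in *. The sketch below assumes non-negative inputs and the floor-division karat; the base_len parameter and the check helper are hypothetical, introduced so both base cases can be run side by side (base_len=1 corresponds to "== 1", base_len=2 to "< 3").

```python
import random


def karat(x, y, base_len=1):
    """// version of karat; the base case fires when an operand has at most base_len digits."""
    if len(str(x)) <= base_len or len(str(y)) <= base_len:
        return x * y
    n = max(len(str(x)), len(str(y))) // 2
    a, b = divmod(x, 10 ** n)
    c, d = divmod(y, 10 ** n)
    z0 = karat(b, d, base_len)
    z1 = karat(a + b, c + d, base_len)
    z2 = karat(a, c, base_len)
    return z2 * 10 ** (2 * n) + (z1 - z2 - z0) * 10 ** n + z0


def check(base_len, trials=200, max_digits=60):
    """Return the first (x, y) where karat disagrees with *, or None if none was found."""
    for _ in range(trials):
        x = random.randrange(10 ** random.randint(1, max_digits))
        y = random.randrange(10 ** random.randint(1, max_digits))
        if karat(x, y, base_len) != x * y:
            return x, y
    return None


for base_len in (1, 2):
    print(base_len, check(base_len))
```

A randomized test like this can expose a bug by printing a failing pair, but a run with no mismatches does not prove correctness.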
  • I can't really see an improvement in your code. You just replaced the ==1 conditions in the if clause and, as far as I can see, nothing else changed. Have I overlooked something really important? Or is the change from ==1 to <3 the major part? If so, some explanation of this would be good.
    – colidyre
    Commented Aug 19, 2018 at 13:03
  • @colidyre ==1 is not the correct recursion base case. Therefore, the previous answers' code examples do not actually give the correct answer when multiplying their input numbers. I made some edits to my answer to clarify.
    – micarlise
    Commented Aug 19, 2018 at 23:07
