28th November, 2025
Having implemented 32-bit integer multiplication for Tenstorrent’s AI accelerators, we turn our attention to 32-bit integer division.
We have an optimised reciprocal implementation for FP32 (the subject of an upcoming article), giving us ~23.5 bits of precision. This means we need to do additional work to obtain the remaining bits of precision for 32-bit signed integers.
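To make the precision gap concrete, here is a minimal Python sketch (not from the original article). The helpers `f32` and `fp32_quotient_estimate` are hypothetical names introduced for illustration, and the correctly rounded `1.0 / x` is only a stand-in for the hardware reciprocal, whose exact error behaviour is not modelled.

```python
import struct


def f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def fp32_quotient_estimate(a: int, b: int) -> int:
    """Estimate a/b (a >= 0, b > 0) using only binary32-rounded steps."""
    inv_b = f32(1.0 / f32(float(b)))  # assumption: stand-in for reciprocal()
    return int(f32(f32(float(a)) * inv_b))


# binary32 has a 24-bit significand, so 2**24 + 1 is not representable.
print(f32(float(2**24 + 1)) == float(2**24))  # True

# For large operands the FP32 estimate misses the exact integer quotient.
a, b = 2147483647, 3
print(fp32_quotient_estimate(a, b) - a // b)  # non-zero error
```

The estimate is close in relative terms, but for quotients wider than 24 bits the absolute error spans several low bits, which is what the correction step has to recover.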
The approach:

1. Convert a and b to FP32, compute an approximate FP32 reciprocal, inv_b_f = reciprocal(b_f), and then multiply with a_f to get an initial approximation of the quotient q_f in FP32.
2. Convert q_f to an integer q. This lets us compute remainder = a - q*b using the 32-bit integer multiplication techniques we learned in the other article.
3. Compute correction_f = remainder_f * inv_b_f to be added or subtracted from the approximate quotient.

The result is undefined if b=0, or if a=-2147483648 and b=-1, due to overflow.
If a and b have different signs and the remainder is non-zero, then floor(a/b) = trunc(a/b) - 1; otherwise they’re the same.

a = load(in0)
b = load(in1)
# Preserve sign of result.
sign = a ^ b
# Note: abs(-2147483648) = -2147483648 (0x8000_0000).
# We take the absolute value so that we can use sign-magnitude-to-FP32
# instructions.
a = abs(a)
b = abs(b)
# Convert sign-magnitude integers to FP32. Note that -2147483648 (0x8000_0000)
# means -0 in sign-magnitude form, so we check for -0.0; all other values will
# be >= 0.0.
a_f = sm32_to_fp32(a)
if a_f < 0.0:
    a_f = 2147483648.0
b_f = sm32_to_fp32(b)
if b_f < 0.0:
    b_f = 2147483648.0
# Compute our initial approximate quotient q = a/b (FP32)
inv_b_f = reciprocal(b_f)
# Note that we don't need the low bits of q, since we'll be doing a
# correction step anyway. This means we can zero the low bits of q and
# simplify the conversion to integer, and reduce the number of multiplications.
# If we add a mantissa alignment factor 2.0**(23+9), this gives us (q>>9)
# with a single mantissa extraction instruction.
q_f = a_f * inv_b_f + 2.0**32
q = exman(q_f)
# Compute qb = q * b. This tells us how close our approximation `q` is to
# the target `a`.
#
# Set q = (q1<<9) + 0, where q1 is 23b.
# Set b = (b1<<23) + b0, where b1 is 9b and b0 is 23b.
#
# Now: q*b = (q1<<9) * (b1<<23) + (q1<<9) * b0
#          = (q1<<9) * b0
#          = (q1*b0) << 9
#
# The first term overflows 32 bits and is discarded, leaving us with just one
# multiplication.
# q is already right-shifted by 9
q1 = q
b0 = b
# shift q to proper alignment
q <<= 9
# mul24 multiplies two 23-bit integers, giving the low or high 23 bits of
# the product. Inputs do not need to be masked as this is done internally.
qb = mul24_lo(q1, b0) << 9
# Compute remainder: a - q*b
r = a - qb
r_f = sm32_to_fp32(abs(r))
# Correction calculation. We know that the correction will only require ~10
# bits, so fp32_to_u16 is safe to use (round-to-nearest with ties away from
# zero).
correction_f = r_f * inv_b_f
correction = fp32_to_u16(correction_f)
# Compute tmp = correction * b
#             = (correction * (b1<<23)) + (correction * b0)
#
#             = (mul24_lo(correction, b1) << 23)
#             + (mul24_hi(correction, b0) << 23)
#             + mul24_lo(correction, b0)
#
#             = ((mul24_lo(correction, b1) + mul24_hi(correction, b0)) << 23)
#             + mul24_lo(correction, b0)
b1 = b >> 23
tmp = ((mul24_lo(correction, b1) + mul24_hi(correction, b0)) << 23) \
    + mul24_lo(correction, b0)
# Apply the correction and adjust remainder.
if r < 0:
    q -= correction
    r += tmp
else:
    q += correction
    r -= tmp
# Since the correction might have been rounded, we may need to correct one
# additional bit.
# The (r-1)<0 check excludes the case where r=-2147483648.
if r < 0 and (r-1) < 0:
    q -= 1
    r += b
elif r >= b:
    q += 1
    r -= b
result = q
# Restore sign and optionally convert "trunc" to "floor" rounding.
if sign < 0:
    result = -result
    if floor and r != 0:
        result -= 1
store(result)

The implementation is quite similar, except that we no longer have access to SFPMUL24, and so we resort to using SFPMAD on 11-bit chunks.
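The 11-bit chunk trick used in the listing that follows can be sketched in plain Python. This is an illustrative model only: `exman` here is a hypothetical stand-in that assumes the instruction returns the 23 explicit mantissa bits of a binary32 value, and `mul11` and `mul32_lo` are names invented for this sketch.

```python
import struct


def exman(x: float) -> int:
    """Assumed behaviour: the 23 explicit mantissa bits of x as binary32."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits & 0x7FFFFF


def mul11(x: int, y: int) -> int:
    """Multiply two 11-bit integers with one float multiply-add."""
    assert 0 <= x < 2**11 and 0 <= y < 2**11
    # x*y < 2**22, so x*y + 2**23 lies in [2**23, 2**24): the exponent is
    # fixed and the mantissa field holds the product exactly.
    return exman(float(x) * float(y) + 8388608.0)


def mul32_lo(x: int, y: int) -> int:
    """Low 32 bits of x*y for unsigned 32-bit x, y, from 11-bit chunks."""
    mask = 0x7FF
    x0, x1, x2 = x & mask, (x >> 11) & mask, x >> 22
    y0, y1, y2 = y & mask, (y >> 11) & mask, y >> 22
    lo = mul11(x0, y0)
    mid = mul11(x0, y1) + mul11(x1, y0)
    top = mul11(x0, y2) + mul11(x1, y1) + mul11(x2, y0)
    # Partial products shifted by 33 bits or more cannot affect the result.
    return (lo + (mid << 11) + (top << 22)) & 0xFFFFFFFF


print(mul11(2047, 2047) == 2047 * 2047)                            # True
print(mul32_lo(123456789, 987654321) == (123456789 * 987654321) & 0xFFFFFFFF)  # True
```

Because each chunk product stays below 2**22, two of them can also be accumulated on top of the same 2**23 offset without leaving the [2**23, 2**24) range, which is what the listing does for its `hi` term.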
a = load(in0)
b = load(in1)
# Preserve sign of result.
sign = a ^ b
# Note: abs(-2147483648) = -2147483648 (0x8000_0000).
# We take the absolute value so that we can use sign-magnitude-to-FP32
# instructions.
a = abs(a)
b = abs(b)
# Convert sign-magnitude integers to FP32. Note that -2147483648 (0x8000_0000)
# means -0 in sign-magnitude form, so we check for -0.0; all other values will
# be >= 0.0.
a_f = sm32_to_fp32(a)
if a_f < 0.0:
    a_f = 2147483648.0
b_f = sm32_to_fp32(b)
if b_f < 0.0:
    b_f = 2147483648.0
# Compute our initial approximate quotient q = a/b (FP32)
inv_b_f = reciprocal(b_f)
# Note that we don't need the low bits of q, since we'll be doing a
# correction step anyway. This means we can zero the low bits of q and
# simplify the conversion to integer, and reduce the number of multiplications.
# If we add a mantissa alignment factor 2.0**(23+10), this gives us (q>>10)
# with a single mantissa extraction instruction.
q_f = a_f * inv_b_f + 2.0**33
q = exman(q_f)
# Compute remainder: a - q*b
#
# Split q into q = (q2<<21) + (q1<<10) + 0, where q2 and q1 are 11b.
# Split b into b = (b2<<22) + (b1<<11) + b0, where b2 is 10b, and b1 and b0 are 11b.
#
# Then q*b = (q2<<21) * (b2<<22)
#          + (q2<<21) * (b1<<11)
#          + (q2<<21) * b0
#          + (q1<<10) * (b2<<22)
#          + (q1<<10) * (b1<<11)
#          + (q1<<10) * b0
#
# Terms shifted left by 32 bits or more overflow and are discarded, leaving:
#
#          = (q2<<21) * b0
#          + (q1<<10) * (b1<<11)
#          + (q1<<10) * b0
#
#          = (((q2 * b0) + (q1 * b1)) << 21)
#          + ((q1 * b0) << 10)
mask = 0x7ff
mantissa_alignment_offset = 8388608.0 # fp32
# q is already right-shifted by 10
q1 = sm32_to_fp32(q & mask)
q2 = sm32_to_fp32(q >> 11)
b1 = sm32_to_fp32((b >> 11) & mask)
b0 = sm32_to_fp32(b & mask)
# shift q to proper alignment
q <<= 10
# Adding 2**23 in FP32 aligns the mantissa bits for easy extraction as integers.
lo = q1 * b0 + mantissa_alignment_offset
hi = q2 * b0 + mantissa_alignment_offset
hi = q1 * b1 + hi
# Extract the integers and compute qb.
qb = (exman(hi) << 21) + (exman(lo) << 10)
# Compute the remainder.
r = a - qb
r_f = sm32_to_fp32(abs(r))
# Correction calculation. We know that the correction will only require ~11
# bits, so fp32_to_u16 is safe to use (round-to-nearest with ties away from
# zero).
correction_f = r_f * inv_b_f
correction = fp32_to_u16(correction_f)
# Convert the integer correction back to FP32.
correction_f = sm32_to_fp32(correction)
# Compute correction * ((b2<<22) + (b1<<11) + b0)
b2 = sm32_to_fp32(b >> 22)
low = correction_f * b0 + mantissa_alignment_offset
mid = correction_f * b1 + mantissa_alignment_offset
top = correction_f * b2 + mantissa_alignment_offset
# This is used to adjust the remainder.
tmp = (exman(top) << 22) + (exman(mid) << 11) + exman(low)
# Apply the correction and adjust remainder.
if r < 0:
    q -= correction
    r += tmp
else:
    q += correction
    r -= tmp
# Since the correction might have been rounded, we may need to correct one
# additional bit.
# The (r-1)<0 check excludes the case where r=-2147483648.
if r < 0 and (r-1) < 0:
    q -= 1
    r += b
elif r >= b:
    q += 1
    r -= b
result = q
# Restore sign and optionally convert "trunc" to "floor" rounding.
if sign < 0:
    result = -result
    if floor and r != 0:
        result -= 1
store(result)
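
As a cross-check on the overall scheme shared by both listings (FP32 estimate, exact integer remainder, FP32 correction, sign and rounding fix-up), here is a small host-side model in Python. It is a sketch under stated assumptions, not the hardware implementation: `f32` and `div_model` are hypothetical names, the correctly rounded `1.0 / x` stands in for the hardware reciprocal, and the final fix-up uses loops rather than the single conditional step in the listings so that the model stays exact regardless of reciprocal accuracy.

```python
import struct


def f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def div_model(a: int, b: int, floor: bool = False) -> int:
    """Model of signed 32-bit division with trunc (default) or floor rounding."""
    if b == 0 or (a == -2**31 and b == -1):
        raise ValueError("result is not representable")
    negative = (a < 0) != (b < 0)
    a_mag, b_mag = abs(a), abs(b)

    inv_b = f32(1.0 / f32(float(b_mag)))        # assumption: stand-in reciprocal
    q = int(f32(f32(float(a_mag)) * inv_b))     # initial FP32 estimate
    r = a_mag - q * b_mag                       # exact integer remainder

    # One FP32 correction step, rounded to the nearest integer.
    correction = round(f32(f32(float(abs(r))) * inv_b))
    if r < 0:
        q -= correction
        r += correction * b_mag
    else:
        q += correction
        r -= correction * b_mag

    # Final fix-up (loops here; the listings use one conditional step).
    while r < 0:
        q -= 1
        r += b_mag
    while r >= b_mag:
        q += 1
        r -= b_mag

    result = -q if negative else q
    if floor and negative and r != 0:
        result -= 1                             # trunc -> floor
    return result


print(div_model(-7, 2))              # -3 (trunc)
print(div_model(-7, 2, floor=True))  # -4 (floor)
```

With `floor=True` the model can be compared directly against Python's `//` operator, which also rounds towards negative infinity.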