Rounding on Tenstorrent

12th December, 2025

Introduction

Below are optimised implementations of various rounding operations for Tenstorrent’s AI accelerators.

Assembly syntax: the destination register is the last register specified, e.g. sfpmad A,B,C,D means D=A*B+C.

trunc(x)

This operation discards the fractional part of a floating point number.

One way to achieve this is by masking the bit representation:

if x.Exp < 0:
  # for all |x| < 1, trunc(x) = sgn(x) * 0, i.e. preserve the sign only
  mask = 0x8000_0000
elif x.Exp < 23:
  # for all |x| < 2**23, we zero the fractional part.
  mask = 0xffff_ffff << (23 - x.Exp)
else:
  # for all |x| ≥ 2**23, x is already an integer.
  mask = 0xffff_ffff
result = mask & x

This translates into the following assembly code, achieving 7 cycles:

; assume constant register L12 = 23
; L0 is input, L1 is result
sfploadi L1, 0x8000, MOD0_FLOATB ; mask = 0x8000_0000
sfpexexp 0, L0, L2, 8|2          ; exp = exexp(x); disable lanes where exp<0
sfploadi L1, 0xffff, MOD0_SHORT  ; mask = 0xffff_ffff
sfpiadd 0, L12, L2, 8|2          ; exp = 23 - exp; disable lanes where exp≥23
sfpshft 0, L2, L1, 0             ; mask <<= exp
sfpencc 0, 0, 0, 0               ; reset lanes
sfpand 0, L0, L1, 0              ; mask &= x

Note that SFPIADD does not support an immediate operand when negating the destination. We want to calculate exp = 23 - exp, so instead we use a constant register containing the value 23.

We store the result in L1 to avoid overwriting the original value in L0, which is useful for the other operations below.
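The lane masking is easy to get subtly wrong, so here is a small Python model of the seven instructions, run over a list of lanes. It is a sketch that rests on our own assumptions: bit patterns are plain Python ints, each condition-setting instruction only narrows the set of enabled lanes (as the comments in the listing describe), and SFPENCC re-enables every lane. The helper names f32_bits, bits_f32 and trunc_lanes are hypothetical and are not part of any Tenstorrent API.

```python
import struct


def f32_bits(x: float) -> int:
    """Round x to float32 and return its bit pattern as an unsigned 32-bit int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(b: int) -> float:
    """Reinterpret the low 32 bits of b as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", b & 0xFFFF_FFFF))[0]


def trunc_lanes(xs: list[float]) -> list[float]:
    """Model of the 7-instruction trunc sequence, one list entry per lane."""
    L0 = [f32_bits(x) for x in xs]                # input x (bit patterns)
    n = len(L0)
    L12 = 23                                      # constant register holding 23
    L1 = [0x8000_0000] * n                        # sfploadi: mask = 0x8000_0000
    L2 = [((b >> 23) & 0xFF) - 127 for b in L0]   # sfpexexp: unbiased exponent
    enabled = [e >= 0 for e in L2]                # ...disable lanes where exp < 0
    for i in range(n):
        if enabled[i]:
            L1[i] = 0xFFFF_FFFF                   # sfploadi: mask = 0xffff_ffff
    for i in range(n):
        if enabled[i]:
            L2[i] = L12 - L2[i]                   # sfpiadd: exp = 23 - exp
            enabled[i] = L2[i] > 0                # ...disable lanes where exp >= 23
    for i in range(n):
        if enabled[i]:
            L1[i] = (L1[i] << L2[i]) & 0xFFFF_FFFF  # sfpshft: mask <<= exp
    # sfpencc re-enables every lane, then sfpand applies the mask everywhere.
    return [bits_f32(mask & x) for mask, x in zip(L1, L0)]


print(trunc_lanes([-2.75, -0.5, 3.999, 8388609.0]))  # [-2.0, -0.0, 3.0, 8388609.0]
```

Lanes with an exponent of exactly 23 are harmless either way, because the shift amount there is 0.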

frac(x)

\text{frac}(x) = x - \text{trunc}(x)

We can reuse the trunc primitive and compute this in two additional cycles, for a total of 9 cycles. The trailing SFPNOP is needed because SFPMAD only writes its result at the end of the following cycle:

# frac(x) = x - trunc(x)
v = trunc(u)
v = u - v # Wormhole doesn't support negation, so instead we use v = u + v * -1.0 here

; first compute L1=trunc(L0) in 7 cycles
sfpmad L1, -1.0, L0, L1, 0 ; Wormhole doesn't support negation, so we multiply v by the constant -1.0
sfpnop
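The following Python sketch shows that one rounding step is enough. It emulates float32 with struct, and trunc_f32 and frac_f32 are hypothetical helper names. It mirrors the SFPMAD by computing v * -1.0 + u and rounding once.

```python
import struct


def f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def trunc_f32(x: float) -> float:
    """Bit-masking trunc from the previous section, for a single value."""
    u = struct.unpack("<I", struct.pack("<f", x))[0]
    exp = ((u >> 23) & 0xFF) - 127
    if exp < 0:
        mask = 0x8000_0000                     # |x| < 1: keep the sign only
    elif exp < 23:
        mask = (0xFFFF_FFFF << (23 - exp)) & 0xFFFF_FFFF
    else:
        mask = 0xFFFF_FFFF                     # already an integer
    return struct.unpack("<f", struct.pack("<I", u & mask))[0]


def frac_f32(x: float) -> float:
    u = f32(x)
    v = trunc_f32(u)
    return f32(v * -1.0 + u)                   # like sfpmad L1, -1.0, L0, L1


print(frac_f32(2.75), frac_f32(-2.75))  # 0.75 -0.75
```

No precision is lost. For |x| ≥ 1, trunc(x) has the same sign as x and lies within a factor of two of it, so by Sterbenz's lemma x - trunc(x) is exactly representable. For |x| < 1, trunc(x) is ±0.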

floor(x)

\text{floor}(x) = \lfloor x \rfloor

We can reuse the trunc primitive here. Since trunc(x) rounds towards zero, it already equals floor(x) unless x is negative and has a fractional part, in which case it is one too large. Equivalently, since \text{floor}(x) \le x must always hold, we subtract 1 exactly when trunc(x) > x.

On Blackhole, we can check this with a single instruction, SFPGT:

v = trunc(u)
if v > u:
  v -= 1.0

This costs an additional 3 cycles, making the total 10 cycles.

; compute L1=trunc(L0) in 7 cycles.
sfpgt 0, L0, L1, 1          ; enable lanes where trunc(x) > x, otherwise disable
sfpmad L1, 1.0, -1.0, L1, 0 ; L1 = L1 - 1.0
sfpencc 0, 0, 0, 0          ; reset lanes
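To check the comparison-based adjustment, the Python sketch below emulates float32 with the struct module and compares the result against math.floor. trunc_f32 and floor_gt are hypothetical helper names, and the `v > u` test stands in for SFPGT.

```python
import math
import struct


def f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def trunc_f32(x: float) -> float:
    """Bit-masking trunc for a single float32 value."""
    u = struct.unpack("<I", struct.pack("<f", x))[0]
    exp = ((u >> 23) & 0xFF) - 127
    if exp < 0:
        mask = 0x8000_0000
    elif exp < 23:
        mask = (0xFFFF_FFFF << (23 - exp)) & 0xFFFF_FFFF
    else:
        mask = 0xFFFF_FFFF
    return struct.unpack("<f", struct.pack("<I", u & mask))[0]


def floor_gt(x: float) -> float:
    u = f32(x)
    v = trunc_f32(u)
    if v > u:         # trunc moved a negative, non-integer x up towards zero
        v -= 1.0      # exact: |v| < 2**23 whenever this branch runs
    return v


for x in [-2.5, -2.0, -0.5, 0.5, 2.5]:
    assert floor_gt(x) == math.floor(x)
```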

On Wormhole, SFPGT is not available. As noted above, we only need to adjust the value if x is negative, so we check for this first. Next we check whether x had a fractional part by treating u and v as two's complement integers. SFPIADD computes u = v - u (overwriting the original value, which we no longer need) and enables or disables lanes based on the sign of the result.

This works because trunc only clears bits of u below the sign bit, and clearing any of the low 31 bits of a two's complement integer can only decrease it. Hence bit_cast<i32>(trunc(x)) ≤ bit_cast<i32>(x), which means that v - u ≤ 0 always holds. The subtraction cannot overflow, because u and v share a sign bit. So we only need to distinguish v - u == 0, where x is already an integer, from v - u < 0, where x had a fractional part.

v = trunc(u)
if u < 0 and v - u < 0:
  v -= 1.0

This costs an additional 4 cycles, making the total 11 cycles.

; compute L1=trunc(L0) in 7 cycles.
sfpsetcc 0, L0, 0, LT0      ; disable lanes where x ≥ 0
sfpiadd 0, L1, L0, 2        ; L0 = v - u; disable lanes where v - u ≥ 0
sfpmad L1, 1.0, -1.0, L1, 0 ; L1 = L1 - 1.0
sfpencc 0, 0, 0, 0          ; reset lanes
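The integer trick is the subtle part, so the sketch below tests it in Python on random float32 bit patterns, skipping infinities and NaNs. It checks two things: that v - u ≤ 0 when both are read as 32-bit two's complement integers, and that the Wormhole-style test reproduces math.floor. All helper names are hypothetical, and the sign-bit test is our own model of the x < 0 check.

```python
import math
import random
import struct


def bits_f32(b: int) -> float:
    """Reinterpret a 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


def as_i32(b: int) -> int:
    """Two's complement view of a 32-bit pattern, like bit_cast<i32>."""
    return b - (1 << 32) if b & 0x8000_0000 else b


def trunc_bits(u: int) -> int:
    """Bit-masking trunc applied to a float32 bit pattern."""
    exp = ((u >> 23) & 0xFF) - 127
    if exp < 0:
        return u & 0x8000_0000
    if exp < 23:
        return u & ((0xFFFF_FFFF << (23 - exp)) & 0xFFFF_FFFF)
    return u


def floor_iadd(u: int) -> float:
    """Wormhole-style floor of the float32 whose bit pattern is u."""
    v = trunc_bits(u)
    negative = (u >> 31) == 1               # our model of the x < 0 check
    has_frac = as_i32(v) - as_i32(u) < 0    # SFPIADD: v - u, two's complement
    if negative and has_frac:
        return bits_f32(v) - 1.0            # exact: |v| < 2**23 here
    return bits_f32(v)


random.seed(1)
for _ in range(20_000):
    u = random.getrandbits(32)
    if ((u >> 23) & 0xFF) == 0xFF:          # skip infinities and NaNs
        continue
    assert as_i32(trunc_bits(u)) <= as_i32(u)
    assert floor_iadd(u) == math.floor(bits_f32(u))
```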

ceil(x)

\text{ceil}(x) = \lceil x \rceil

Similarly to floor(x), we can reuse the trunc primitive. This time trunc(x) is one too small exactly when x is positive and has a fractional part. Equivalently, since \text{ceil}(x) \ge x must always hold, we add 1 exactly when x > trunc(x).

On Blackhole, we can check this with a single instruction, SFPGT:

v = trunc(u)
if u > v:
  v += 1.0

This costs an additional 3 cycles, making the total 10 cycles.

; compute L1=trunc(L0) in 7 cycles.
sfpgt 0, L1, L0, 1         ; enable lanes where x > trunc(x), otherwise disable
sfpmad L1, 1.0, 1.0, L1, 0 ; L1 = L1 + 1.0
sfpencc 0, 0, 0, 0         ; reset lanes
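The same style of float32 model in Python works for the SFPGT version of ceil. ceil_gt and trunc_f32 are hypothetical helper names, and `u > v` stands in for SFPGT.

```python
import math
import struct


def f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def trunc_f32(x: float) -> float:
    """Bit-masking trunc for a single float32 value."""
    u = struct.unpack("<I", struct.pack("<f", x))[0]
    exp = ((u >> 23) & 0xFF) - 127
    if exp < 0:
        mask = 0x8000_0000
    elif exp < 23:
        mask = (0xFFFF_FFFF << (23 - exp)) & 0xFFFF_FFFF
    else:
        mask = 0xFFFF_FFFF
    return struct.unpack("<f", struct.pack("<I", u & mask))[0]


def ceil_gt(x: float) -> float:
    u = f32(x)
    v = trunc_f32(u)
    if u > v:         # trunc moved a positive, non-integer x down towards zero
        v += 1.0      # exact: |v| < 2**23 whenever this branch runs
    return v


for x in [-2.5, -0.5, 0.5, 2.0, 2.5]:
    assert ceil_gt(x) == math.ceil(x)
```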

Due to the lack of SFPGT on Wormhole, we do a very similar check to the one for floor(x), except that we check for positive x instead of negative x. The check for a fractional part remains identical.

v = trunc(u)
if u > 0 and v - u < 0:
  v += 1.0

This costs an additional 4 cycles, making the total 11 cycles.

; compute L1=trunc(L0) in 7 cycles.
sfpsetcc 0, L0, 0, GTE0    ; disable lanes where x < 0
sfpiadd 0, L1, L0, 2       ; L0 = v - u; disable lanes where v - u ≥ 0
sfpmad L1, 1.0, 1.0, L1, 0 ; L1 = L1 + 1.0
sfpencc 0, 0, 0, 0         ; reset lanes
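Here is a Python sketch of this variant on float32 bit patterns. Compared with the Wormhole floor, only the sign test changes. As before, the helper names are hypothetical, the sign-bit test is our model of the x ≥ 0 check, and infinities and NaNs are skipped.

```python
import math
import random
import struct


def bits_f32(b: int) -> float:
    """Reinterpret a 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


def as_i32(b: int) -> int:
    """Two's complement view of a 32-bit pattern, like bit_cast<i32>."""
    return b - (1 << 32) if b & 0x8000_0000 else b


def trunc_bits(u: int) -> int:
    """Bit-masking trunc applied to a float32 bit pattern."""
    exp = ((u >> 23) & 0xFF) - 127
    if exp < 0:
        return u & 0x8000_0000
    if exp < 23:
        return u & ((0xFFFF_FFFF << (23 - exp)) & 0xFFFF_FFFF)
    return u


def ceil_iadd(u: int) -> float:
    """Wormhole-style ceil of the float32 whose bit pattern is u."""
    v = trunc_bits(u)
    non_negative = (u >> 31) == 0           # our model of the x >= 0 check
    has_frac = as_i32(v) - as_i32(u) < 0    # same fractional-part test as floor
    if non_negative and has_frac:
        return bits_f32(v) + 1.0            # exact: |v| < 2**23 here
    return bits_f32(v)


random.seed(2)
for _ in range(20_000):
    u = random.getrandbits(32)
    if ((u >> 23) & 0xFF) == 0xFF:          # skip infinities and NaNs
        continue
    assert ceil_iadd(u) == math.ceil(bits_f32(u))
```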

round(x)

Round x to nearest integer (ties to even).

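For |v| < 2**23, adding 2**23 to |v| and then subtracting it again rounds |v| to an integer. Float32 values in [2**23, 2**24) are spaced exactly 1.0 apart. Provided the addition rounds to nearest with ties to even, the sum is therefore 2**23 plus |v| rounded to the nearest integer, and the subtraction is exact. Values with an exponent of 23 or more are already integers and are left alone. Working on t = |v| and copying only the exponent and mantissa back into v preserves the sign, so that, for example, -0.3 rounds to -0.0: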

t = abs(v)
if exexp(t) < 23:
  t += 2.0**23
  t -= 2.0**23
  v.{Exp,Man} = {t.Exp,t.Man}
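The identity can be tried out in Python by emulating float32 rounding with struct. On typical platforms, struct.pack('<f', ...) rounds to the nearest float32 with ties to even. round_magic is a hypothetical helper name. The sketch mirrors the pseudocode line by line and uses Python's built-in round(), which also rounds ties to even, as a reference.

```python
import math
import struct


def f32(x: float) -> float:
    """Round a Python float (binary64) to float32 via struct."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def round_magic(x: float) -> float:
    v = f32(x)
    t = abs(v)                                  # t.{Sign,Exp,Man} = {0,t.Exp,t.Man}
    exp = ((struct.unpack("<I", struct.pack("<f", t))[0] >> 23) & 0xFF) - 127
    if exp < 23:                                # exexp(t) < 23
        t = f32(t + 2.0**23)                    # float32 spacing in [2**23, 2**24) is 1.0
        t = f32(t - 2.0**23)                    # exact: result is an integer <= 2**23
        v = math.copysign(t, v)                 # v.{Exp,Man} = {t.Exp,t.Man}
    return v


print([round_magic(x) for x in (0.5, 1.5, 2.5, -2.5, -0.3)])  # [0.0, 2.0, 2.0, -2.0, -0.0]
```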

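Optimised code (7 cycles) is as follows. The SFPEXEXP and SFPIADD marked "hide nop" fill the cycles in which the preceding SFPADDI's result is not yet available, so no explicit SFPNOP is needed: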

sfpsetsgn 0, L0, L1, 1    ; t.{Sign,Exp,Man} = {0,t.Exp,t.Man}
sfpaddi 8388608.0, L1, 0  ; t += 2**23
sfpexexp 0, L0, L2, 0     ; hide nop; extract exponent of v
sfpaddi -8388608.0, L1, 0 ; t += -2**23
sfpiadd -23, L2, L2, 1    ; hide nop; exp -= 23; disable lanes where exp≥23
sfpsetsgn 0, L1, L0, 0    ; v.{Exp,Man} = {t.Exp,t.Man}
sfpencc 0, 0, 0, 0        ; reset lanes; v will now be rounded
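The exponent guard (SFPEXEXP plus the SFPIADD that disables lanes with exp ≥ 23) is required for correctness: for |v| ≥ 2**23 the addition itself can round. A float32 emulation in Python shows this for v = 2**23 + 1, which is already an integer. add_sub_magic is a hypothetical helper.

```python
import struct


def f32(x: float) -> float:
    """Round a Python float (binary64) to float32, ties to even."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def add_sub_magic(t: float) -> float:
    """The two SFPADDIs applied with no exponent guard."""
    return f32(f32(t + 2.0**23) - 2.0**23)


print(add_sub_magic(8388609.0))  # 8388608.0
```

Here 2**23 + 1 + 2**23 = 2**24 + 1 is not representable in float32 and rounds (ties to even) to 2**24, so the unguarded result is 2**23 rather than v.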

SFPLOADMACRO

Just for fun, here’s a version optimised with SFPLOADMACRO, achieving a throughput of 4 cycles per row.

magic = 2**23

m = load(0)

# simultaneously load v, and compute m = ±2**23 with correct sign
v = load(0) ; m.{Exp,Man} = magic.{Exp,Man}

# simultaneously issue:
# 1. load the original value to r (L0)
# 2. SFPIADD (writes result end of this cycle)
# 3. SFPMAD (writes result end of next cycle)
r = load(0) ; v = m - v ; v = mad(1.0, m, v)

# L7 can safely read the result of m - v here, right-shifting by 31 gives 0 or 1
L7 = v >> 31

# indirect SFPMAD writes to L0 (r) or L1 depending on L7, at the end of the next cycle
L[L7] = mad(-1.0, m, v)
# nop

store(r, 0)

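The trick is to use the indirect addressing mode of SFPMAD, which writes its result to VD=L[L7]. We store the sign bit of the integer difference m - v in L7. Since m = ±2**23 has the same sign as v, this bit is set exactly when |v| > 2**23, which also covers infinities and NaNs. In that case the result is written to L1 and ignored, and L0 keeps the original value. Otherwise, we overwrite L0 with (v + m) - m, which is the rounded result.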
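Below is a rough Python model of the dataflow, not of SFPLOADMACRO's scheduling. It assumes that SFPMAD rounds to the nearest float32 with ties to even and that SFPIADD wraps as a 32-bit integer add. The two-element list regs stands in for the L0/L1 pair selected by L7. round_loadmacro and the other helper names are hypothetical.

```python
import math
import struct


def f32(x: float) -> float:
    """Round a Python float (binary64) to float32, ties to even."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def f32_bits(x: float) -> int:
    """Bit pattern of a float32 value as an unsigned 32-bit integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def round_loadmacro(x: float) -> float:
    r = f32(x)                              # L0: the value finally stored
    v = r
    m = math.copysign(2.0**23, v)           # m.{Exp,Man} = magic.{Exp,Man}: ±2**23
    diff = (f32_bits(m) - f32_bits(v)) & 0xFFFF_FFFF  # SFPIADD: m - v as integers
    L7 = diff >> 31                         # 1 iff |v| > 2**23 (incl. inf/NaN)
    v = f32(m + v)                          # SFPMAD: v = 1.0 * m + v
    regs = [r, 0.0]                         # regs[0] is L0 (r), regs[1] is L1
    regs[L7] = f32(-m + v)                  # indirect SFPMAD into L[L7]
    return regs[0]                          # store(r, 0)


print([round_loadmacro(x) for x in (0.5, 1.5, 2.5, -2.5, 8388609.0)])  # [0.0, 2.0, 2.0, -2.0, 8388609.0]
```

In this model, negative inputs that round to zero come back as +0.0 rather than -0.0. The reason is that (v + m) - m cancels exactly, and IEEE round-to-nearest gives +0.0 for an exact zero sum. The SFPSETSGN version above keeps the sign of zero.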

Acknowledgements

Thanks to Tenstorrent for sponsoring this work.