When trait in Rust

Photo by Ruhan Shete / Unsplash

I will start this article by demonstrating the problem that I was facing.

I am developing a game in Ratatui, a really nice library for creating UIs in the terminal (TUI).

Ratatui uses a builder pattern for widgets: you create a widget and then configure it through a chain of method calls.
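To see why such a chain works, here is a minimal sketch of a consuming builder. The `Label` type is hypothetical and only imitates the shape of a Ratatui widget; I'm assuming Ratatui's widget methods follow the same by-value style, which matches how the chains below are written. Every setter takes `self` by value and returns `Self`, so the next call can be made directly on the result.

```rust
// Hypothetical widget type; it only imitates the shape of a Ratatui builder.
#[derive(Debug, Clone, PartialEq)]
struct Label {
    text: String,
    centered: bool,
    bordered: bool,
}

impl Label {
    // Entry point of the chain: a label with default settings.
    fn new(text: &str) -> Self {
        Label { text: text.to_string(), centered: false, bordered: false }
    }

    // Each setter takes `self` by value, changes one field and hands the
    // value back, so the next method can be called on the result.
    fn centered(mut self) -> Self {
        self.centered = true;
        self
    }

    fn bordered(mut self) -> Self {
        self.bordered = true;
        self
    }
}

fn main() {
    let label = Label::new("Hello There!").bordered().centered();
    println!("{label:?}");
}
```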

Step 1: Building a simple widget

// area: Rect, buf: &mut Buffer
Paragraph::new("Hello There!")
    .block(Block::default().borders(Borders::ALL)
        .title("Game Menu").title_alignment(Center)
        .white())
    .light_green()
    .alignment(Center)
    .render(area, buf);

The resulting render:

As you can see, the text is drawn in green, centered, with a white box around it and a title that is also centered.

Now let's do something more complicated: I want to display this widget differently depending on whether it is currently "active". If the menu is not active, let's replace all the colors with darker ones.

Step 2: Dynamic Colors

let title_color = if menu_is_active {
    Color::White
} else {
    Color::DarkGray
};
let text_color = if menu_is_active {
    Color::LightGreen
} else {
    Color::DarkGray
};

Paragraph::new("Hello There!")
    .block(Block::default().borders(Borders::ALL)
        .title("Game Menu").title_alignment(Center)
        .style(Style::default().fg(title_color)))
    .style(Style::default().fg(text_color))
    .alignment(Center)
    .render(area, buf);

Active vs Inactive:

It works, but the code blew up quite a bit: two extra if expressions just to pick colors, plus explicit Style values.

Step 3: Using the dim() method

As it turns out, Ratatui already has a dim() method for widgets, which adds the DIM modifier to the widget's style so that all of its colors are rendered dimmed.

With this, I can get rid of the color logic. I can store the widget in an intermediate variable and then apply the transformation conditionally.

let widget = Paragraph::new("Hello There!")
    .block(Block::default().borders(Borders::ALL)
        .title("Game Menu").title_alignment(Center)
        .white())
    .light_green()
    .alignment(Center);

// dim the widget only when the menu is NOT active
let widget = if !menu_is_active { widget.dim() } else { widget };

widget.render(area, buf);

But this code ruins the fluent, elegant chain we had originally: because the transformation is conditional, I have to interrupt the chain, store the widget in a variable, and rebind it.
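Where does dim() come from? As far as I can tell, Ratatui provides it through its Stylize trait, which is made available to many styled types at once via a blanket impl. The sketch below is a heavily simplified imitation of that idea, not Ratatui's actual code: `Style`, `HasStyle`, `DimExt` and `Text` are all hypothetical names. A trait with a default method plus a blanket impl gives every type that carries a style a dim() method.

```rust
// Hypothetical, simplified style: only the flag we care about here.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
struct Style {
    dim: bool,
}

// Hypothetical trait for "things that carry a Style".
trait HasStyle {
    fn style_mut(&mut self) -> &mut Style;
}

// Extension trait with a default method: dim() is written once here...
trait DimExt: HasStyle + Sized {
    fn dim(mut self) -> Self {
        self.style_mut().dim = true;
        self
    }
}

// ...and this blanket impl hands it to every type that implements HasStyle.
impl<T: HasStyle> DimExt for T {}

// A hypothetical widget that opts in simply by implementing HasStyle.
#[derive(Debug, Default)]
struct Text {
    content: String,
    style: Style,
}

impl HasStyle for Text {
    fn style_mut(&mut self) -> &mut Style {
        &mut self.style
    }
}

fn main() {
    let text = Text { content: "Hello There!".to_string(), ..Default::default() }.dim();
    println!("{} {:?}", text.content, text.style);
}
```

Keep this shape in mind: a trait plus a blanket impl is exactly the trick we are about to use ourselves.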

Step 4: When Trait

Can we somehow turn an if expression into a chainable method? Yes we can!

I added this trait to my utils file:

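pub trait When {
    /// Applies `action` to `self` if `condition` is true; otherwise returns `self` unchanged.
    fn when(self, condition: bool, action: impl FnOnce(Self) -> Self) -> Self
    where
        Self: Sized;
}

// Blanket implementation: every (sized) type gets the when() method.
impl<T> When for T {
    fn when(self, condition: bool, action: impl FnOnce(Self) -> Self) -> Self {
        if condition {
            action(self)
        } else {
            self
        }
    }
}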
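If you want to play with the trait outside of a Ratatui project, here is a self-contained sketch. It swaps the Ratatui widget for a hypothetical `Menu` struct so that it only needs the standard library, and it writes the Sized requirement as a supertrait (`When: Sized`) instead of a where clause on the method; both forms work with the blanket impl.

```rust
pub trait When: Sized {
    // Apply `action` only if `condition` holds; otherwise pass `self` through.
    fn when(self, condition: bool, action: impl FnOnce(Self) -> Self) -> Self;
}

// Blanket impl: every sized type gets when().
impl<T> When for T {
    fn when(self, condition: bool, action: impl FnOnce(Self) -> Self) -> Self {
        if condition { action(self) } else { self }
    }
}

// Hypothetical stand-in for a Ratatui widget.
#[derive(Debug, PartialEq)]
struct Menu {
    title: String,
    dimmed: bool,
}

impl Menu {
    fn new(title: &str) -> Self {
        Menu { title: title.to_string(), dimmed: false }
    }

    fn dim(mut self) -> Self {
        self.dimmed = true;
        self
    }
}

// Builds the menu in one chain; dims it only while it is inactive.
fn build_menu(menu_is_active: bool) -> Menu {
    Menu::new("Game Menu").when(!menu_is_active, |m| m.dim())
}

fn main() {
    println!("{:?}", build_menu(true));
    println!("{:?}", build_menu(false));
}
```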

And now I can write my widget like this:

use crate::utils::When; // make the trait available

Paragraph::new("Hello There!")
    .block(Block::default().borders(Borders::ALL)
        .title("Game Menu").title_alignment(Center)
        .white())
    .light_green()
    .alignment(Center)
    .when(!menu_is_active, |w| w.dim()) // <---- the good stuff
    .render(area, buf);

Isn't that glorious?

Let's add another when() to give the block a double border while the menu is active:

Paragraph::new("Hello There!")
    .block(Block::default().borders(Borders::ALL)
        .when(menu_is_active, |b| b.border_type(Double)) // <---- the good stuff
        .title("Game Menu").title_alignment(Center)
        .white())
    .light_green()
    .alignment(Center)
    .when(!menu_is_active, |w| w.dim()) // <---- the good stuff
    .render(area, buf);

The result:

Step 5: Realize what we just built

Our When trait is implemented generically for every T. Generic type parameters are implicitly Sized, so in practice that means every sized type.

So it is not limited to widgets: I can use it on numbers and strings, too.

use crate::utils::When; // make the trait available for any T: Sized
use rand::{rng, Rng}; // rand 0.9: rng() and Rng::random_range

let x = rng().random_range(0..10);
let y = rng().random_range(0..10);

let result = 1.when(x > 5, |a| a + 1)
              .when(y > 5, |a| a * 2);

let message = "There is".to_string()
    .when(x > 5, |m| format!("{m} a number greater than 5"))
    .when(y > 5, |m| format!("{m} and another number greater than 5"));

println!("Result: {result}, Message: {message}");

Is the above code sane? Probably not.
But is it extremely awesome that this works? I think so!
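To make the behavior easy to check, here is the same example with the random numbers replaced by function parameters. The `demo` function is hypothetical and the sketch uses only the standard library.

```rust
pub trait When: Sized {
    fn when(self, condition: bool, action: impl FnOnce(Self) -> Self) -> Self;
}

impl<T> When for T {
    fn when(self, condition: bool, action: impl FnOnce(Self) -> Self) -> Self {
        if condition { action(self) } else { self }
    }
}

// Hypothetical helper: the same chains as above, but x and y are parameters
// instead of random numbers, so the outcome is predictable.
fn demo(x: u32, y: u32) -> (i32, String) {
    let result = 1_i32
        .when(x > 5, |a| a + 1)
        .when(y > 5, |a| a * 2);

    let message = "There is".to_string()
        .when(x > 5, |m| format!("{m} a number greater than 5"))
        .when(y > 5, |m| format!("{m} and another number greater than 5"));

    (result, message)
}

fn main() {
    for (x, y) in [(0, 0), (7, 2), (7, 8), (2, 9)] {
        let (result, message) = demo(x, y);
        println!("x={x}, y={y} -> Result: {result}, Message: {message}");
    }
}
```

The (2, 9) case shows the flip side: with only y above 5, the message reads "There is and another number greater than 5".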

For builders where I need to call methods conditionally, this is a gift, especially since it also works on types I did not define myself. I can now use when() on any third-party builder without touching its source code, and it just works.

This language feature is often called extension methods. In Rust, it is enabled by trait implementations, and a trait written for this purpose is commonly called an extension trait. As I recently learned, many other languages have this feature as well, including Kotlin, Scala, Swift and C#.
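As a tiny illustration of an extension trait on a type we do not own, here is a made-up `Shout` trait that adds a method to the standard library's `str`. The trait and method names are hypothetical; the mechanism is the same one When relies on.

```rust
// Hypothetical extension trait: adds a method to `str`, a type we don't own.
trait Shout {
    fn shout(&self) -> String;
}

impl Shout for str {
    fn shout(&self) -> String {
        format!("{}!", self.to_uppercase())
    }
}

fn main() {
    // Works on string literals and, through auto-deref, on String as well.
    println!("{}", "hello there".shout());
    println!("{}", String::from("game over").shout());
}
```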

Addendum: The difference between Rust traits and PHP traits

If you are a PHP dev, you may still be a bit confused about the difference between traits in Rust and PHP.

In PHP we could do this:

trait When {
    function when(bool $condition, callable $action): static {
        return $condition ? $action($this) : $this;
    }
}

class MyClass {
    use When;
}

// Now, instances of MyClass have the when() method

First you define a trait, then you add it to a class definition.
The problem here is that you cannot add traits to classes that do not belong to you, because adding a trait means editing the actual source code of the class.

In Rust:

pub trait When {
    // the good stuff
}

impl<T> When for T {
    // the good stuff
}

// Since T can be any (sized) type, every such type has the when() method wherever the trait is imported.

First, we define the trait, which declares method signatures, similar to an interface. Then we implement the trait for a type. That type can be concrete, such as u64, or a generic parameter such as T.

So, structs and traits are decoupled from each other. Developers can come in afterwards and write impls that connect types with traits, as long as either the trait or the type is defined in their own crate (Rust's orphan rule).
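Here is a small sketch of the two combinations the orphan rule allows. `Score` and `YesNo` are hypothetical names for this example: a foreign trait (`Display`) for a local type, and a local trait for a foreign type (`bool`). The commented-out impl shows the combination the compiler rejects.

```rust
use std::fmt;

// Local type + foreign trait: allowed, because Score is defined in this crate.
struct Score(u32);

impl fmt::Display for Score {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{} points", self.0)
    }
}

// Local trait + foreign type: allowed, because YesNo is defined in this crate.
trait YesNo {
    fn yes_no(&self) -> &'static str;
}

impl YesNo for bool {
    fn yes_no(&self) -> &'static str {
        if *self { "yes" } else { "no" }
    }
}

// Foreign trait + foreign type is rejected by the orphan rule, for example:
// impl fmt::Display for Vec<u32> { ... } // error[E0117]

fn main() {
    println!("{} / {}", Score(42), true.yes_no());
}
```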

Here is another cool example that I use in my code:

pub trait ToDuration {
    fn milliseconds(&self) -> std::time::Duration;
}

// We are NOT defining the type u64 here (it is a built-in integer type).
// We establish a relationship between u64 and ToDuration.
impl ToDuration for u64 {
    fn milliseconds(&self) -> std::time::Duration {
        std::time::Duration::from_millis(*self)
    }
}

// usage with crossterm's event::poll: wait up to 16 milliseconds for an input event
if event::poll(16.milliseconds())? {
    self.handle_crossterm_events()?;
}
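If you want to try ToDuration without a terminal, here is a standard-library-only sketch. It replaces crossterm's event::poll with `std::thread::sleep`; the trait and impl are the same as above, with `std::time::Duration` imported once at the top.

```rust
use std::time::Duration;

pub trait ToDuration {
    fn milliseconds(&self) -> Duration;
}

// Connects the built-in u64 type to our trait.
impl ToDuration for u64 {
    fn milliseconds(&self) -> Duration {
        Duration::from_millis(*self)
    }
}

fn main() {
    // u64 is the only integer type implementing ToDuration,
    // so the literal 16 is inferred as u64 here.
    let frame_time = 16.milliseconds();
    std::thread::sleep(frame_time);
    println!("slept for {frame_time:?}");
}
```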