Mimic the C# yield instruction in VC++

Recently mickey0 posted a question on the CodeGuru forum, asking whether VC++ had any syntax similar to the C# yield instruction. Although I knew the answer was ‘no’, I was still intrigued to figure out if it was possible to implement something similar.

I started with a web search and found a couple of existing implementations, all of which used Win32 fibers, so I realized that fibers were the way to go. Even though the examples I found were good tutorials on fibers, they did not meet my requirements for a proper yield implementation. So I had to get my hands dirty.

Don’t feel like reading the whole post? Here is the Yield Iterator Sample Code.

My requirements

  • It should be able to yield return something.
  • It should be able to prematurely terminate the enumeration loop (which I’ve decided to call iteration) without any strange side effects.
  • It should be able to handle exceptions in a proper manner.

In addition there was one thing that in my honest opinion was missing from the C# implementation:

  • It should allow nested functions to call yield_return and yield_break.

My final expectations

So, with my requirements specified, here is how I would like the final result to be used:

#include <iostream>
#include <string>
#include <tchar.h>   // _tmain and _TCHAR (Visual C++)

// Sample name iterator
class name_iterator : public yield_iterator<std::string>
{
public:
    // implement iterator that returns a set of names
    void iterate()
    {
        yield_return("Egil");
        yield_return("Berit");
        yield_return("Nina");
        yield_break;
        yield_return("Petter");     // never reached: yield_break ends the iteration
        yield_return("Karianne");
    }
};

int _tmain(int argc, _TCHAR* argv[])
{
    // iterate names
    for (name_iterator ni; ni != ni.end(); ++ni)
        std::cout << *ni << std::endl;  

    return 0;
}

Let’s start out with the class definition and its default constructor.

A light peek

I’ve decided to base my implementation on a standard forward iterator. Even though iterators are typically used to traverse some other collection, my yield_iterator is self-contained: there is no other collection (or enumerable) that produces the elements. Instead, the yield_iterator produces the elements itself through the user-implemented pure virtual iterate() function:

#include <windows.h>   // CreateFiber, SwitchToFiber, ConvertThreadToFiber, ...
#include <crtdbg.h>    // _ASSERT
#include <exception>
#include <iterator>

template <typename T>
class yield_iterator : public std::iterator<std::forward_iterator_tag, T>
{
private:
    T m_value;               // Iterator's current value.
    bool m_initialized;      // Iteration initialized?
    bool m_done;             // Iteration done?
    void* m_fiber;           // Main fiber
    void* m_iteratorFiber;   // Fiber used to maintain state of iteration
    bool m_throws;           // Specify whether main fiber throws exception
    std::exception m_ex;     // Exception to throw

public:
    // Default constructor. Does nothing more than initialize member fields
    // (listed in declaration order).
    yield_iterator() : m_initialized(false), m_done(false), m_fiber(0),
        m_iteratorFiber(0), m_throws(false)
    {
    }

A quick explanation of the class’s member fields:

The m_value field holds the iterator’s current value, the one you would typically access through iterator::operator*.

The m_initialized flag simply tells whether the iterator has been initialized. Once initialized, the iterator’s fibers are set up, and the iteration loop has been executed at least once to produce a valid m_value (or the iterator should be equal to yield_iterator::end()).

The m_done flag signals that the iteration loop has completed. There are three ways to complete the loop. The first is normal completion, just as any function completes when there are no more instructions to execute. The second is a call to yield_break (or simply return). The third is destruction of the iterator. When messing around with fibers you need to ensure proper stack unwinding: if the iterator is destroyed before the loop has fully completed, you need to tell it to perform a last iteration and unwind any previously constructed objects that live on the stack.

The m_fiber and m_iteratorFiber void pointers are our, well, fibers. To reach our goal we need to be able to manipulate our thread context, and this is done using multiple (in this case two) fibers. More on that as we get deeper into the mud.

The last two fields, m_ex and m_throws, are used for exception handling, as exceptions thrown in the iterator fiber need to be copied into the main fiber.

The first steps

What any forward iterator needs is a way to get the current value, and a way to move on to the next value. Those two operations are implemented through the operator*() function and the operator++() function:

// Return this iterators current value
T operator*()
{
    // We need to ensure that everything is up and running when the user
    // tries to dereference the current value.
    if (!m_initialized)
        iterateImpl();  

    return m_value;
}  

// Advances the iterator by one.
yield_iterator& operator++()
{
    // Make an additional call to the iterate function.
    iterateImpl();  

    // Return this iterator
    return *this;
}

There is not much to say about these functions, except that they both call into the iterateImpl() function… the fun starts within iterateImpl().
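The post never shows end() or the operator!= that the for loop in _tmain relies on. One way to fill that gap, assuming end() returns a sentinel iterator that is already done, is sketched below. The names countdown_iterator and step() are hypothetical: step() stands in for iterateImpl() and simply counts down instead of switching fibers, so only the m_initialized/m_done bookkeeping and the comparison are exercised.

```cpp
#include <cassert>
#include <vector>

// Hypothetical iterator with the same m_initialized/m_done bookkeeping as
// yield_iterator, but step() counts down instead of switching fibers.
class countdown_iterator
{
public:
    explicit countdown_iterator(int from) : m_next(from) {}

    // Sentinel returned by end(): already initialized and already done.
    countdown_iterator end() const
    {
        countdown_iterator sentinel(0);
        sentinel.m_initialized = true;
        sentinel.m_done = true;
        return sentinel;
    }

    int operator*()
    {
        if (!m_initialized)
            step();
        return m_value;
    }

    countdown_iterator& operator++()
    {
        step();
        return *this;
    }

    // Only meant for comparison against end(): the iterators differ while
    // exactly one of them is done. Running the first step here means an
    // empty sequence is recognized before anything is dereferenced.
    bool operator!=(const countdown_iterator& other)
    {
        if (!m_initialized)
            step();
        return m_done != other.m_done;
    }

private:
    // Stand-in for iterateImpl(): produce the next value or flag completion.
    void step()
    {
        assert(!m_done);
        m_initialized = true;
        if (m_next > 0)
            m_value = m_next--;
        else
            m_done = true;
    }

    int m_next;
    int m_value = 0;
    bool m_initialized = false;
    bool m_done = false;
};

// Same loop shape as the post's _tmain.
std::vector<int> collect(int from)
{
    std::vector<int> out;
    for (countdown_iterator it(from); it != it.end(); ++it)
        out.push_back(*it);
    return out;
}
```

With this design, collect(3) yields 3, 2, 1, and an iterator that produces nothing never enters the loop body, so operator* is never called on a value that was never yielded.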

Digging deeper

Now, let’s take a look at the iterateImpl() function. Every call to operator++() goes through this function, as does the first call to operator*():

void iterateImpl()
{
    // Already done? the user should compare this iterator with
    // yield_iterator::end() to avoid this.
    _ASSERT(!m_done);  

    // Initial call? Setup fibers.
    if (!m_initialized)
    {
        m_initialized = true;  

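        // If we're not already running in a fiber, then convert this
        // thread to one. Note: IsBadReadPtr is only a heuristic here and
        // is not guaranteed to recognize a valid fiber; on Windows Vista
        // and later, IsThreadAFiber() answers the question directly.
        m_fiber = GetCurrentFiber();

        if (IsBadReadPtr(m_fiber, sizeof(m_fiber)))
            m_fiber = ConvertThreadToFiber(this);
        _ASSERT(m_fiber != NULL);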

        // Create additional fiber for running the iterator procedure
        m_iteratorFiber = CreateFiber(4096, iteratorFiberProc, this);
        _ASSERT(m_iteratorFiber != NULL);
    }  

    // Switch to the iterator fiber. If this is the first time, then the
    // iterator fiber starts running from the beginning of
    // iteratorFiberProc. If not, the fiber continues running
    // from the SwitchToFiber(m_fiber) call made by yieldImpl.
    SwitchToFiber(m_iteratorFiber);

    // We're here either due to the SwitchToFiber(m_fiber) call made in
    // yieldImpl or because iteratorFiberProc has completed. If
    // it's the latter, then the done flag is set, and it's time to
    // clean up.
    if (m_done)
        DeleteFiber(m_iteratorFiber);  

    // Check if exception was thrown.
    if (m_throws)
        throw m_ex;
}

The function begins by checking if our yield_iterator is initialized or not. If not, the fibers are initialized.

To use fibers, the thread must first be converted to one, which is done with the ConvertThreadToFiber function. Since this code was written on and for Windows XP, there is no good way to determine whether a thread is already a fiber; Windows Vista adds IsThreadAFiber for exactly this purpose. Instead, we first get the currently running fiber using GetCurrentFiber(), and if that returns a readable memory region, we assume that we’re already in fiber mode. (As Raymond Chen points out in the comments, this is only a heuristic: nothing guarantees that a valid fiber passes IsBadReadPtr.) If not, we call ConvertThreadToFiber to get a handle to our main fiber, the one that is currently running.
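Following Raymond Chen’s suggestion in the comments, a more dependable alternative to probing with IsBadReadPtr is to record the conversion yourself. Below is a minimal sketch of that bookkeeping, assuming every fiber conversion on a thread goes through one helper. The names t_main_fiber and ensure_thread_is_fiber are hypothetical, and the convert callback stands in for a call to ConvertThreadToFiber so the logic can be exercised without Win32.

```cpp
#include <cassert>
#include <functional>

// Hypothetical per-thread record of the main fiber. It is only authoritative
// if all code on the thread converts through ensure_thread_is_fiber.
thread_local void* t_main_fiber = nullptr;

// `convert` stands in for ConvertThreadToFiber; it is injected so the
// bookkeeping can be tested on any platform.
void* ensure_thread_is_fiber(const std::function<void*()>& convert)
{
    if (t_main_fiber == nullptr)
        t_main_fiber = convert();   // convert at most once per thread
    assert(t_main_fiber != nullptr);
    return t_main_fiber;
}
```

On Windows the callback would wrap ConvertThreadToFiber (which takes an LPVOID fiber-data argument and returns the fiber address, or NULL on failure). Because the record is thread_local, each thread converts itself exactly once.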

The second step is to create an additional fiber for our iterator function, which is done with CreateFiber. This fiber is initialized to start at iteratorFiberProc, and we pass the this pointer as fiber data. Remember that there is a difference between threads and fibers: only one fiber runs (within one thread) at any given time. So our main fiber is still running, while our iterator fiber is silently waiting at the beginning of iteratorFiberProc.

The third step is to switch to our iterator fiber. The SwitchToFiber call saves the state of the current (main) fiber and then transfers control to our iterator fiber. If this is our first run (and let’s just assume that for now), execution continues at the beginning of iteratorFiberProc.
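Fibers are Windows-only, but the create/switch pattern can be tried out on Linux with the POSIX ucontext functions (removed from recent POSIX editions, yet still provided by glibc). The sketch below is an analogy, not the post’s code: makecontext plays the role of CreateFiber and swapcontext that of SwitchToFiber. The names g_trace, iterator_proc and run_ping_pong are hypothetical; the trace records the order in which the two contexts run.

```cpp
#include <cassert>
#include <string>
#include <ucontext.h>
#include <vector>

static std::vector<std::string> g_trace;   // order in which the contexts run
static ucontext_t g_main_ctx;              // role of m_fiber
static ucontext_t g_iter_ctx;              // role of m_iteratorFiber

// Entry point of the second context, in the role of iteratorFiberProc.
static void iterator_proc()
{
    g_trace.push_back("iterator: started");
    swapcontext(&g_iter_ctx, &g_main_ctx);   // like SwitchToFiber(m_fiber)
    g_trace.push_back("iterator: resumed");
    // Like iteratorFiberProc, never return: with uc_link == nullptr,
    // returning would end the thread. Switch back for good instead.
    swapcontext(&g_iter_ctx, &g_main_ctx);
}

std::vector<std::string> run_ping_pong()
{
    g_trace.clear();
    std::vector<char> stack(128 * 1024);

    // Like CreateFiber: a context that will start at iterator_proc.
    getcontext(&g_iter_ctx);
    g_iter_ctx.uc_stack.ss_sp = stack.data();
    g_iter_ctx.uc_stack.ss_size = stack.size();
    g_iter_ctx.uc_link = nullptr;
    makecontext(&g_iter_ctx, iterator_proc, 0);

    g_trace.push_back("main: switching");
    swapcontext(&g_main_ctx, &g_iter_ctx);   // starts iterator_proc
    g_trace.push_back("main: back");
    swapcontext(&g_main_ctx, &g_iter_ctx);   // resumes after its first switch
    g_trace.push_back("main: done");
    return g_trace;   // the suspended context is simply abandoned
}
```

The resulting trace alternates between the two contexts: main switching, iterator started, main back, iterator resumed, main done. That is the same ping-pong that iterateImpl and yieldImpl perform with real fibers.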

Getting dirty

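The static iteratorFiberProc is our entry point for the iterator fiber. CreateFiber expects a plain function pointer, so it cannot call a non-static member function directly. Instead we give it a static member function, pass a pointer to the yield_iterator as fiber data, and read that pointer back from the void *pv argument: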

// Static dispatch function that simply forwards the call to its equivalent
// non-static member function.
static void WINAPI iteratorFiberProc(void *pv)
{
    _ASSERT(pv != NULL);
    yield_iterator<T> *yit = static_cast<yield_iterator<T>*>(pv);  

    // Invoke the user implemented iterate function. Any exceptions thrown
    // from the user implemented iterate function are preserved and rethrown
    // by the main fiber.
    try
    {
        yit->iterate();
    } catch(std::exception& e)
    {
        yit->setException(e);
    } catch(const char *str)
    {
        yit->setException(std::exception(str));
    } catch(...)
    {
        yit->setException(std::exception(
            "Unknown exception thrown in yield_iterator::iterate."));
    }  

    // Signal that we're done.
    yit->m_done = true;

    // Switch back to main fiber which is either waiting in iterateImpl or
    // in the destructor.
    SwitchToFiber(yit->m_fiber);
}

As you can see, we start out by getting a pointer to the yield_iterator, and then call the user-implemented iterate() function wrapped in a try-catch statement. The setException calls preserve any exception thrown by iterate(). As you should know by now, we’re running in the iterator fiber, and at some point we switch back to the main fiber and return to the caller. Since an exception cannot propagate from one fiber’s stack to another’s, we need to store it temporarily and rethrow it once we have switched back, in the last statement of iterateImpl().

The last thing to do is to flag that we’re done and switch back to the main fiber. This final switch is essential: if a fiber function returns, the thread running it exits. Looking at this function in isolation, it seems to be no more than a call to iterate() followed by SwitchToFiber(), and of course that is exactly what it is. But once the iterate() function makes one or more calls to the yieldImpl() function, things start to make sense.
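Copying the caught exception into a std::exception member slices it: a std::runtime_error thrown by iterate() comes back as a plain std::exception, so the caller can no longer catch it by its real type. With a C++11 compiler, std::exception_ptr (std::current_exception and std::rethrow_exception) can carry any exception across the fiber switch unchanged. The sketch below shows only the capture-and-rethrow pattern, independent of fibers. The helper names run_and_capture, rethrow_if_set and runtime_error_survives are hypothetical.

```cpp
#include <cassert>
#include <exception>
#include <stdexcept>
#include <string>

// Hypothetical helper: runs `body` the way iteratorFiberProc runs iterate(),
// but stores whatever was thrown instead of letting it escape.
template <typename F>
std::exception_ptr run_and_capture(F body)
{
    try {
        body();
    } catch (...) {
        return std::current_exception();   // keeps the dynamic type: no slicing
    }
    return std::exception_ptr();            // nothing was thrown
}

// Counterpart of the `if (m_throws) throw m_ex;` at the end of iterateImpl().
inline void rethrow_if_set(const std::exception_ptr& ep)
{
    if (ep)
        std::rethrow_exception(ep);
}

// Returns true if a std::runtime_error comes back as a std::runtime_error.
inline bool runtime_error_survives()
{
    std::exception_ptr ep =
        run_and_capture([] { throw std::runtime_error("bad name"); });
    try {
        rethrow_if_set(ep);
    } catch (const std::runtime_error& e) {
        return std::string(e.what()) == "bad name";
    } catch (...) {
        // Any other type means the exception was altered on the way.
    }
    return false;
}
```

In yield_iterator terms, a single std::exception_ptr member could replace both m_ex and m_throws. It would be set in the catch (...) of iteratorFiberProc and rethrown after SwitchToFiber returns in iterateImpl(). This would also preserve non-std::exception throws such as the const char* case.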

Let’s revisit the user-implemented iterate function:

void iterate()
{
    yield_return("Egil");
    yield_return("Berit");
    yield_return("Nina");
    yield_break;
    yield_return("Petter");
    yield_return("Karianne");
}

This function, which should be considered just an example, calls yield_return a few times and yield_break once. The yield_return macro is implemented using the following yieldImpl() function:

void yieldImpl(T value)
{
    // Update iterators current value.
    m_value = value;  

    // Switch back to main fiber 'waiting' in iterateImpl.
    SwitchToFiber(m_fiber);
}

This is where it all comes together. yieldImpl stores the provided value (now the iterator’s current value, and the one to be returned to the caller) and then switches back to the main fiber. If you’ve managed to follow along, the main fiber should be ‘waiting’ at the SwitchToFiber(m_iteratorFiber) statement in the iterateImpl() function. Now the main fiber wakes up, leaving the iterator fiber suspended, and executes the last couple of statements before returning to the user:

    // We're here either due to the SwitchToFiber(m_fiber) call made in
    // yieldImpl or because iteratorFiberProc has completed. If
    // it's the latter, then the done flag is set, and it's time to
    // clean up.
    if (m_done)
        DeleteFiber(m_iteratorFiber);  

    // Check if exception was thrown.
    if (m_throws)
        throw m_ex;
}

After this, the execution point is back in user code, and the user may dereference the iterator to get its current value. All this, and all we’ve actually achieved is to run the first yield_return statement in our iterate() function ;)

Now, if the user decides to call the iterator::operator++ once more we simply repeat the whole operation. The only difference now is that the iterator fiber (m_iteratorFiber) isn’t located at the beginning of the iteratorFiberProc(). This time it’s located at the end of the yieldImpl() function, just about to return and execute the second statement in iterate().

The second line of the iterate() function is yet another call to yield_return(), which again stores the new ‘yielded’ value before switching back to the main fiber and returning to user code. This can go on for as long as the user wants and as long as there are values to yield.

So far, so good

At this point we have the yield functionality halfway up and running. As stated in my requirements, I wanted to implement yield_return, yield_break and a new feature which I’ve decided to call yield_nested. The latter should enable nested functions to call yield_return, yield_break or even yield_nested in turn…

I’ll start by explaining the yield_return(x) macro:

 #define yield_return(x) do { yieldImpl(x); if (yield_iterator::done()) return; } while (0)

This macro calls yieldImpl with the given argument, then checks the done flag and, if we’re done, makes a clean return. The do { ... } while (0) wrapper makes the macro behave like a single statement, so it can also be used in an unbraced if/else. The done check is very important. You might wonder how the iterator could become ‘done’ at this point? The reason follows.

Whenever control returns to user code, you won’t know what the user will do next. If the iterate() function is halfway through and the user decides not to perform any more iterations, then any stack finalization in the iterate() function will never occur. To make sure that the stack is finalized properly, we add some logic to the yield_iterator’s destructor. When the destructor is called, we check the iterator’s state and, if necessary, raise the done flag and transfer control to the iterator fiber for one last run:

// Destructor. Clean up midway or when done.
virtual ~yield_iterator()
{
    if (m_initialized && !m_done)
    {
        // The user implemented iterate function might be 'midway'. We
        // need to unwind its stack to make sure that everything is
        // properly cleaned up and to avoid strange effects. Set done flag
        // and switch back to the iterator fiber.
        m_done = true;
        SwitchToFiber(m_iteratorFiber);  

        // Delete iterator fiber
        DeleteFiber(m_iteratorFiber);
    }
}

So, when returning from yieldImpl, the done flag might have been set by the destructor, in which case the macro makes the iterate() function return right away, unwinding and finalizing its stack.
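To see the destructor trick end to end outside Windows, here is a compact sketch of the same design built on POSIX ucontext (getcontext/makecontext/swapcontext, still provided by glibc on Linux) instead of Win32 fibers. The names ucontext_yield_iterator, next(), counting_iterator, log_guard and take are hypothetical. next() folds the post’s operator++/iterateImpl into one call, and exception forwarding is left out. The log_guard destructor shows that abandoning the loop after two values still unwinds iterate()’s stack.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <ucontext.h>
#include <vector>

static thread_local void* g_starting_iterator = nullptr;   // read once by entry()

// Hypothetical ucontext-based analogue of yield_iterator (no Win32 fibers).
template <typename T>
class ucontext_yield_iterator
{
public:
    ucontext_yield_iterator() : m_stack(256 * 1024) {}
    ucontext_yield_iterator(const ucontext_yield_iterator&) = delete;

    // Same idea as the post's destructor: force one last run so iterate() unwinds.
    virtual ~ucontext_yield_iterator()
    {
        if (m_initialized && !m_done) {
            m_done = true;
            swapcontext(&m_mainCtx, &m_iterCtx);
        }
    }

    // First-call setup plus the switch that iterateImpl() performs.
    bool next()
    {
        assert(!m_done);
        if (!m_initialized) {
            m_initialized = true;
            getcontext(&m_iterCtx);
            m_iterCtx.uc_stack.ss_sp = m_stack.data();
            m_iterCtx.uc_stack.ss_size = m_stack.size();
            m_iterCtx.uc_link = nullptr;
            g_starting_iterator = this;
            makecontext(&m_iterCtx, &ucontext_yield_iterator::entry, 0);
        }
        swapcontext(&m_mainCtx, &m_iterCtx);
        return !m_done;
    }

    const T& value() const { return m_value; }

protected:
    virtual void iterate() = 0;
    void yieldImpl(const T& v) { m_value = v; swapcontext(&m_iterCtx, &m_mainCtx); }
    bool done() const { return m_done; }

private:
    static void entry()   // plays the role of iteratorFiberProc
    {
        auto* self = static_cast<ucontext_yield_iterator*>(g_starting_iterator);
        self->iterate();
        self->m_done = true;
        swapcontext(&self->m_iterCtx, &self->m_mainCtx);   // never resumed
    }

    T m_value{};
    bool m_initialized = false;
    bool m_done = false;
    ucontext_t m_mainCtx, m_iterCtx;
    std::vector<char> m_stack;
};

// do/while(0) form of the post's yield_return macro.
#define yield_return(x) do { yieldImpl(x); if (done()) return; } while (0)

// Kept outside the iterator object: the derived part is already destroyed
// when the base destructor resumes iterate().
struct log_guard
{
    std::vector<std::string>* log;
    ~log_guard() { log->push_back("guard destroyed"); }
};

class counting_iterator : public ucontext_yield_iterator<int>
{
public:
    explicit counting_iterator(std::vector<std::string>* log) : m_log(log) {}
private:
    void iterate() override
    {
        log_guard guard{m_log};   // pointer copied into this stack frame
        for (int i = 1; i <= 5; ++i)
            yield_return(i);
    }
    std::vector<std::string>* m_log;
};

// Pulls at most `limit` values; the destructor of `it` finishes the run.
std::vector<int> take(std::size_t limit, std::vector<std::string>& log)
{
    std::vector<int> seen;
    counting_iterator it(&log);
    while (seen.size() < limit && it.next())
        seen.push_back(it.value());
    return seen;
}
```

One caveat applies equally to the post’s destructor. By the time a base-class destructor resumes iterate(), the derived part of the object has already been destroyed. Anything iterate() touches while unwinding should therefore live in its own stack frame (as the copied pointer does here) rather than in derived-class members. The object also must not be copied or moved once started, since the saved contexts point into it.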

The next macro is yield_break:

 #define yield_break do { return; } while (0)

This is simply a wrapped return statement; the do { ... } while (0) wrapper only makes it behave like a single statement. Whether or not the iteration has completed, the iterate() function unwinds/finalizes its stack and then returns.

The last macro is yield_nested:

 #define yield_nested(x) do { x; if (yield_iterator::done()) return; } while (0)

The macro itself does not call yieldImpl. But by checking the done flag after the nested call, it makes sure that if a nested function made a call to yieldImpl that left the iterator in a ‘done’ state, then the stack is finalized and the iterate() function returns immediately. The do { ... } while (0) wrapper again only makes the macro behave like a single statement.

Final result…

…should look like this:

[Screenshot (yield_iterator_result.png): console output of the sample program]

Here is the Yield Iterator Sample Code.

Resources
Using fibers to simplify enumerators, part 1: When life is easier for the enumerator
Yield Return Iterator for Native C++ Using Fibers

- petter

3 Responses to “Mimic the C# yield instruction in VC++”

  1. Raymond Chen Says:

    There is no promise that a valid fiber will pass IsBadReadPtr. In the absence of IsThreadAFiber you just need to keep track of this by some external convention. (Who knows, maybe in the next version of Windows, fibers will be indices into an array instead of pointers.)

  2. Petter Labråten Says:

    Thanks for pointing that out, Raymond.

    As you say it’s probably better to remove ConvertThreadToFiber from the iterator, and leave it up to the user/developer to make sure that the thread is converted to a fiber. At least on a pre-Vista box.

  3. blog Says:

    hi…

    wonderful…
