[committed] coroutines: Revise await expansions [PR94528]

Message ID B9E0A7E7-92B6-491F-9C77-401B57276433@sandoe.co.uk
State New
Headers show
Series
  • [committed] coroutines: Revise await expansions [PR94528]
Related show

Commit Message

Iain Sandoe April 10, 2020, 11:57 p.m.
The expansions for await expressions were specific to particular
cases, this revises it to be more generic and thus handles the case
that triggered the PR.  Most of the change is code-factoring.

Tested on x86_64-linux/darwin (and powerpc64-linux-gnu)
approved by Nathan on the PR thread (I fixed the comments that
refered to a following patch),

applied to master
thanks
Iain

a: Revise co_await statement walkers.

We want to process the co_awaits one statement at a time.
We also want to be able to determine the insertion points for
new bind scopes needed to cater for temporaries that are
captured by reference and have lifetimes that need extension
to the end of the full expression.  Likewise, the handling of
captured references in the evaluation of conditions might
result in the need to make a frame copy.

This reorganises the statement walking code to make it easier to
extend for these purposes.

b: Factor reference-captured temp code.

We want to be able to use the code that writes a new bind expr
with vars (and their initializers) from several places, so split
that out of the maybe_promote_captured_temps() function into a
new replace_statement_captures ().  Update some comments.

c: Generalize await statement expansion.

This revises the expansion to avoid the need to expand conditionally
on the tree type.  It resolves PR 94528.

gcc/cp/ChangeLog:

2020-04-10  Iain Sandoe  <iain@sandoe.co.uk>

	PR c++/94528
	* coroutines.cc (co_await_expander): Remove.
	(expand_one_await_expression): New.
	(process_one_statement): New.
	(await_statement_expander): New.
	(build_actor_fn): Revise to use per-statement expander.
	(struct susp_frame_data): Reorder and comment.
	(register_awaits): Factor code.
	(replace_statement_captures): New, factored from...
	(maybe_promote_captured_temps):.. here.
	(await_statement_walker): Revise to process per statement.
	(morph_fn_to_coro): Use revised susp_frame_data layout.

gcc/testsuite/ChangeLog:

2020-04-10  Iain Sandoe  <iain@sandoe.co.uk>

	PR c++/94528
	* g++.dg/coroutines/pr94528.C: New test.

Patch

diff --git a/gcc/cp/coroutines.cc b/gcc/cp/coroutines.cc
index ab06c0aef54..57172853639 100644
--- a/gcc/cp/coroutines.cc
+++ b/gcc/cp/coroutines.cc
@@ -1359,6 +1359,13 @@  struct coro_aw_data
   unsigned index;  /* This is our current resume index.  */
 };
 
+/* Lighweight search for the first await expression in tree-walk order.
+   returns:
+     The first await expression found in STMT.
+     NULL_TREE if there are none.
+   So can be used to determine if the statement needs to be processed for
+   awaits.  */
+
 static tree
 co_await_find_in_subtree (tree *stmt, int *do_subtree ATTRIBUTE_UNUSED, void *d)
 {
@@ -1371,57 +1378,33 @@  co_await_find_in_subtree (tree *stmt, int *do_subtree ATTRIBUTE_UNUSED, void *d)
   return NULL_TREE;
 }
 
-/* When we come here:
-    the first operand is the [currently unused] handle for suspend.
-    the second operand is the var to be copy-initialized
-    the third operand is 'o' (the initializer for the second)
-			      as defined in [await.expr] (3.3)
-    the fourth operand is the mode as per the comment on build_co_await ().
+/* Starting with a statment:
 
-   When we leave:
-   the IFN_CO_YIELD carries the labels of the resume and destroy
-   branch targets for this await.  */
+   stmt => some tree containing one or more await expressions.
 
-static tree
-co_await_expander (tree *stmt, int * /*do_subtree*/, void *d)
-{
-  if (STATEMENT_CLASS_P (*stmt) || !EXPR_P (*stmt))
-    return NULL_TREE;
+   We replace the statement with:
+   <STATEMENT_LIST> {
+      initialise awaitable
+      if (!ready)
+	{
+	  suspension context.
+	}
+      resume:
+	revised statement with one await expression rewritten to its
+	await_resume() return value.
+   }
+
+   We then recurse into the initializer and the revised statement
+   repeating this replacement until there are no more await expressions
+   in either.  */
 
+static tree *
+expand_one_await_expression (tree *stmt, tree *await_expr, void *d)
+{
   coro_aw_data *data = (coro_aw_data *) d;
-  enum tree_code stmt_code = TREE_CODE (*stmt);
-  tree stripped_stmt = *stmt;
-  tree *buried_stmt = NULL;
-  tree saved_co_await = NULL_TREE;
-  enum tree_code sub_code = NOP_EXPR;
-
-  if (stmt_code == MODIFY_EXPR || stmt_code == INIT_EXPR)
-    {
-      sub_code = TREE_CODE (TREE_OPERAND (stripped_stmt, 1));
-      if (sub_code == CO_AWAIT_EXPR)
-	saved_co_await = TREE_OPERAND (stripped_stmt, 1); /* Get the RHS.  */
-      else if (tree r
-	       = cp_walk_tree (&TREE_OPERAND (stripped_stmt, 1),
-			       co_await_find_in_subtree, &buried_stmt, NULL))
-	saved_co_await = r;
-    }
-  else if (stmt_code == CALL_EXPR)
-    {
-      if (tree r = cp_walk_tree (&stripped_stmt, co_await_find_in_subtree,
-				 &buried_stmt, NULL))
-	saved_co_await = r;
-    }
-  else if ((stmt_code == CONVERT_EXPR || stmt_code == NOP_EXPR)
-	   && TREE_CODE (TREE_OPERAND (stripped_stmt, 0)) == CO_AWAIT_EXPR)
-    saved_co_await = TREE_OPERAND (stripped_stmt, 0);
-  else if (stmt_code == CO_AWAIT_EXPR)
-    saved_co_await = stripped_stmt;
-
-  if (!saved_co_await)
-    return NULL_TREE;
 
-  /* We want to splice in the await_resume() value in some cases.  */
   tree saved_statement = *stmt;
+  tree saved_co_await = *await_expr;
 
   tree actor = data->actor_fn;
   location_t loc = EXPR_LOCATION (*stmt);
@@ -1454,6 +1437,7 @@  co_await_expander (tree *stmt, int * /*do_subtree*/, void *d)
   tree stmt_list = NULL;
   tree t_expr = STRIP_NOPS (expr);
   tree r;
+  tree *await_init = NULL;
   if (t_expr == var)
     dtor = NULL_TREE;
   else
@@ -1461,7 +1445,9 @@  co_await_expander (tree *stmt, int * /*do_subtree*/, void *d)
       /* Initialize the var from the provided 'o' expression.  */
       r = build2 (INIT_EXPR, await_type, var, expr);
       r = coro_build_cvt_void_expr_stmt (r, loc);
-      append_to_statement_list (r, &stmt_list);
+      append_to_statement_list_force (r, &stmt_list);
+      /* We have an initializer, which might itself contain await exprs.  */
+      await_init = tsi_stmt_ptr (tsi_last (stmt_list));
     }
 
   /* Use the await_ready() call to test if we need to suspend.  */
@@ -1597,46 +1583,77 @@  co_await_expander (tree *stmt, int * /*do_subtree*/, void *d)
   if (REFERENCE_REF_P (resume_call))
     /* Sink to await_resume call_expr.  */
     resume_call = TREE_OPERAND (resume_call, 0);
-  switch (stmt_code)
-    {
-    default: /* not likely to work .. but... */
-      append_to_statement_list (resume_call, &stmt_list);
-      break;
-    case CONVERT_EXPR:
-    case NOP_EXPR:
-      TREE_OPERAND (stripped_stmt, 0) = resume_call;
-      append_to_statement_list (saved_statement, &stmt_list);
-      break;
-    case INIT_EXPR:
-    case MODIFY_EXPR:
-    case CALL_EXPR:
-      /* Replace the use of co_await by the resume expr.  */
-      if (sub_code == CO_AWAIT_EXPR)
-	{
-	  /* We're updating the interior of a possibly <(void) expr>cleanup.  */
-	  TREE_OPERAND (stripped_stmt, 1) = resume_call;
-	  append_to_statement_list (saved_statement, &stmt_list);
-	}
-      else if (buried_stmt != NULL)
-	{
-	  *buried_stmt = resume_call;
-	  append_to_statement_list (saved_statement, &stmt_list);
-	}
-      else
-	{
-	  error_at (loc, "failed to substitute the resume method in %qE",
-		    saved_statement);
-	  append_to_statement_list (saved_statement, &stmt_list);
-	}
-      break;
-    }
+
+  *await_expr = resume_call; /* Replace the co_await expr with its result.  */
+  append_to_statement_list_force (saved_statement, &stmt_list);
+  /* Get a pointer to the revised statment.  */
+  tree *revised = tsi_stmt_ptr (tsi_last (stmt_list));
   if (needs_dtor)
     append_to_statement_list (dtor, &stmt_list);
   data->index += 2;
+
+  /* Replace the original statement with the expansion.  */
   *stmt = stmt_list;
+
+  /* Now, if the awaitable had an initializer, expand any awaits that might
+     be embedded in it.  */
+  tree *aw_expr_ptr;
+  if (await_init &&
+      cp_walk_tree (await_init, co_await_find_in_subtree, &aw_expr_ptr, NULL))
+    expand_one_await_expression (await_init, aw_expr_ptr, d);
+
+  /* Expand any more await expressions in the the original statement.  */
+  if (cp_walk_tree (revised, co_await_find_in_subtree, &aw_expr_ptr, NULL))
+    expand_one_await_expression (revised, aw_expr_ptr, d);
+
+  return NULL;
+}
+
+/* Check to see if a statement contains at least one await expression, if
+   so, then process that.  */
+
+static tree
+process_one_statement (tree *stmt, void *d)
+{
+  tree *aw_expr_ptr;
+  if (cp_walk_tree (stmt, co_await_find_in_subtree, &aw_expr_ptr, NULL))
+    expand_one_await_expression (stmt, aw_expr_ptr, d);
   return NULL_TREE;
 }
 
+static tree
+await_statement_expander (tree *stmt, int *do_subtree, void *d)
+{
+  tree res = NULL_TREE;
+
+  /* Process a statement at a time.  */
+  if (TREE_CODE (*stmt) == BIND_EXPR)
+    res = cp_walk_tree (&BIND_EXPR_BODY (*stmt), await_statement_expander,
+			d, NULL);
+  else if (TREE_CODE (*stmt) == STATEMENT_LIST)
+    {
+      tree_stmt_iterator i;
+      for (i = tsi_start (*stmt); !tsi_end_p (i); tsi_next (&i))
+	{
+	  res = cp_walk_tree (tsi_stmt_ptr (i), await_statement_expander,
+			      d, NULL);
+	  if (res)
+	    return res;
+	}
+      *do_subtree = 0; /* Done subtrees.  */
+    }
+  else if (STATEMENT_CLASS_P (*stmt))
+    return NULL_TREE; /* Process the sub-trees.  */
+  else if (EXPR_P (*stmt))
+    {
+      process_one_statement (stmt, d);
+      *do_subtree = 0; /* Done subtrees.  */
+    }
+
+  /* Continue statement walk, where required.  */
+  return res;
+}
+
 /* Suspend point hash_map.  */
 
 struct suspend_point_info
@@ -2398,7 +2415,7 @@  build_actor_fn (location_t loc, tree coro_frame_type, tree actor, tree fnbody,
   coro_aw_data data = {actor, actor_fp, resume_pt_number, i_a_r_c,
 		       ash, del_promise_label, ret_label,
 		       continue_label, continuation, 2};
-  cp_walk_tree (&actor_body, co_await_expander, &data, NULL);
+  cp_walk_tree (&actor_body, await_statement_expander, &data, NULL);
 
   actor_body = pop_stmt_list (actor_body);
   BIND_EXPR_BODY (actor_bind) = actor_body;
@@ -2564,16 +2581,21 @@  coro_make_frame_entry (tree *field_list, const char *name, tree fld_type,
   return id;
 }
 
+/* This data set is used when analyzing statements for await expressions.  */
 struct susp_frame_data
 {
-  tree *field_list;
-  tree handle_type;
-  hash_set<tree> captured_temps;
-  vec<tree, va_gc> *to_replace;
-  vec<tree, va_gc> *block_stack;
-  unsigned count;
-  unsigned saw_awaits;
-  bool captures_temporary;
+  /* Function-wide.  */
+  tree *field_list; /* The current coroutine frame field list.  */
+  tree handle_type; /* The self-handle type for this coroutine.  */
+  vec<tree, va_gc> *block_stack; /* Track block scopes.  */
+  vec<tree, va_gc> *bind_stack;  /* Track current bind expr.  */
+  unsigned await_number;	 /* Which await in the function.  */
+  unsigned condition_number;	 /* Which replaced condition in the fn.  */
+  /* Temporary values for one statement or expression being analyzed.  */
+  hash_set<tree> captured_temps; /* The suspend captured these temps.  */
+  vec<tree, va_gc> *to_replace;  /* The VAR decls to replace.  */
+  unsigned saw_awaits;		 /* Count of awaits in this statement  */
+  bool captures_temporary;	 /* This expr captures temps by ref.  */
 };
 
 /* Walk the sub-tree looking for call expressions that both capture
@@ -2704,19 +2726,15 @@  register_awaits (tree *stmt, int *do_subtree ATTRIBUTE_UNUSED, void *d)
   if (TREE_CODE (*stmt) != CO_AWAIT_EXPR && TREE_CODE (*stmt) != CO_YIELD_EXPR)
     return NULL_TREE;
 
-  /* co_yield is syntactic sugar, re-write it to co_await.  */
   tree aw_expr = *stmt;
   location_t aw_loc = EXPR_LOCATION (aw_expr); /* location of the co_xxxx.  */
+  /* co_yield is syntactic sugar, re-write it to co_await.  */
   if (TREE_CODE (aw_expr) == CO_YIELD_EXPR)
     {
       aw_expr = TREE_OPERAND (aw_expr, 1);
       *stmt = aw_expr;
     }
 
-  /* Count how many awaits full expression contains.  This is not the same
-     as the counter used for the function-wide await point number.  */
-  data->saw_awaits++;
-
   /* If the awaitable is a parm or a local variable, then we already have
      a frame copy, so don't make a new one.  */
   tree aw = TREE_OPERAND (aw_expr, 1);
@@ -2731,7 +2749,7 @@  register_awaits (tree *stmt, int *do_subtree ATTRIBUTE_UNUSED, void *d)
     {
       /* The required field has the same type as the proxy stored in the
 	 await expr.  */
-      char *nam = xasprintf ("__aw_s.%d", data->count);
+      char *nam = xasprintf ("__aw_s.%d", data->await_number);
       aw_field_nam = coro_make_frame_entry (data->field_list, nam,
 					    aw_field_type, aw_loc);
       free (nam);
@@ -2739,7 +2757,10 @@  register_awaits (tree *stmt, int *do_subtree ATTRIBUTE_UNUSED, void *d)
 
   register_await_info (aw_expr, aw_field_type, aw_field_nam);
 
-  data->count++; /* Each await suspend context is unique.  */
+  /* Count how many awaits the current expression contains.  */
+  data->saw_awaits++;
+  /* Each await suspend context is unique, this is a function-wide value.  */
+  data->await_number++;
 
   /* We now need to know if to take special action on lifetime extension
      of temporaries captured by reference.  This can only happen if such
@@ -2760,6 +2781,100 @@  register_awaits (tree *stmt, int *do_subtree ATTRIBUTE_UNUSED, void *d)
    We don't want to incur the effort of checking for this unless we have
    an await expression in the current full expression.  */
 
+/* This takes the statement which contains one or more temporaries that have
+   been 'captured' by reference in the initializer(s) of co_await(s).
+   The statement is replaced by a bind expression that has actual variables
+   to replace the temporaries.  These variables will be added to the coro-
+   frame in the same manner as user-authored ones.  */
+
+static void
+replace_statement_captures (tree *stmt, void *d)
+{
+  susp_frame_data *awpts = (susp_frame_data *) d;
+  location_t sloc = EXPR_LOCATION (*stmt);
+  tree aw_bind
+    = build3_loc (sloc, BIND_EXPR, void_type_node, NULL, NULL, NULL);
+
+  /* Any cleanup point expression might no longer be necessary, since we
+     are removing one or more temporaries.  */
+  tree aw_statement_current = *stmt;
+  if (TREE_CODE (aw_statement_current) == CLEANUP_POINT_EXPR)
+    aw_statement_current = TREE_OPERAND (aw_statement_current, 0);
+
+  /* Collected the scope vars we need move the temps to regular. */
+  tree aw_bind_body = push_stmt_list ();
+  tree varlist = NULL_TREE;
+  int vnum = -1;
+  while (!awpts->to_replace->is_empty ())
+    {
+      tree to_replace = awpts->to_replace->pop ();
+      tree orig_temp;
+      if (TREE_CODE (to_replace) == CO_AWAIT_EXPR)
+	{
+	  orig_temp = TREE_OPERAND (to_replace, 3);
+	  orig_temp = TREE_VEC_ELT (orig_temp, 2);
+	  orig_temp = TREE_OPERAND (orig_temp, 0);
+	}
+      else
+	orig_temp = TREE_OPERAND (to_replace, 0);
+
+      tree var_type = TREE_TYPE (orig_temp);
+      gcc_checking_assert (same_type_p (TREE_TYPE (to_replace), var_type));
+      /* Build a variable to hold the captured value, this will be included
+	 in the frame along with any user-authored locals.  */
+      char *nam = xasprintf ("aw_%d.tmp.%d", awpts->await_number, ++vnum);
+      tree newvar = build_lang_decl (VAR_DECL, get_identifier (nam), var_type);
+      free (nam);
+      /* If we have better location than the whole expression use that, else
+	 fall back to the expression loc.  */
+      DECL_CONTEXT (newvar) = DECL_CONTEXT (orig_temp);
+      if (DECL_SOURCE_LOCATION (orig_temp))
+	sloc = DECL_SOURCE_LOCATION (orig_temp);
+     else
+	sloc = EXPR_LOCATION (*stmt);
+      DECL_SOURCE_LOCATION (newvar) = sloc;
+      DECL_CHAIN (newvar) = varlist;
+      varlist = newvar; /* Chain it onto the list for the bind expr.  */
+      /* Declare and initialize it in the new bind scope.  */
+      add_decl_expr (newvar);
+      tree new_s = build2_loc (sloc, INIT_EXPR, var_type, newvar, to_replace);
+      new_s = coro_build_cvt_void_expr_stmt (new_s, sloc);
+      add_stmt (new_s);
+
+     /* Replace all instances of that temp in the original expr.  */
+      proxy_replace pr = {to_replace, newvar};
+       cp_walk_tree (&aw_statement_current, replace_proxy, &pr, NULL);
+    }
+
+  /* What's left should be the original statement with any co_await captured
+     temporaries broken out.  Other temporaries might remain so see if we
+     need to wrap the revised statement in a cleanup.  */
+  aw_statement_current = maybe_cleanup_point_expr_void (aw_statement_current);
+  add_stmt (aw_statement_current);
+
+  BIND_EXPR_BODY (aw_bind) = pop_stmt_list (aw_bind_body);
+  awpts->captured_temps.empty ();
+
+  BIND_EXPR_VARS (aw_bind) = nreverse (varlist);
+  tree b_block = make_node (BLOCK);
+  if (!awpts->block_stack->is_empty ())
+    {
+      tree s_block = awpts->block_stack->last ();
+      if (s_block)
+	{
+	BLOCK_SUPERCONTEXT (b_block) = s_block;
+	BLOCK_CHAIN (b_block) = BLOCK_SUBBLOCKS (s_block);
+	BLOCK_SUBBLOCKS (s_block) = b_block;
+	}
+    }
+  BIND_EXPR_BLOCK (aw_bind) = b_block;
+  *stmt = aw_bind;
+}
+
+/* This is called for single statements from the co-await statement walker.
+   It checks to see if the statement contains any co-awaits and, if so,
+   whether any of these 'capture' a temporary by reference.  */
+
 static tree
 maybe_promote_captured_temps (tree *stmt, void *d)
 {
@@ -2769,90 +2884,19 @@  maybe_promote_captured_temps (tree *stmt, void *d)
 
   /* When register_awaits sees an await, it walks the initializer for
      that await looking for temporaries captured by reference and notes
-     them in awpts->captured_temps.  We only need to take any action here
-     if the statement contained any awaits, and any of those had temporaries
-     captured by reference in the initializers for their class.  */
-
-  tree res = cp_walk_tree (stmt, register_awaits, d, &visited);
-  if (!res && awpts->saw_awaits > 0 && !awpts->captured_temps.is_empty ())
-    {
-      location_t sloc = EXPR_LOCATION (*stmt);
-      tree aw_bind
-	= build3_loc (sloc, BIND_EXPR, void_type_node, NULL, NULL, NULL);
-
-      /* Any cleanup point expression might no longer be necessary, since we
-	 are removing one or more temporaries.  */
-      tree aw_statement_current = *stmt;
-      if (TREE_CODE (aw_statement_current) == CLEANUP_POINT_EXPR)
-	aw_statement_current = TREE_OPERAND (aw_statement_current, 0);
-
-      /* Collected the scope vars we need move the temps to regular. */
-      tree aw_bind_body = push_stmt_list ();
-      tree varlist = NULL_TREE;
-      int vnum = -1;
-      while (!awpts->to_replace->is_empty ())
-	{
-	  size_t bufsize = sizeof ("__aw_.tmp.") + 20;
-	  char *buf = (char *) alloca (bufsize);
-	  snprintf (buf, bufsize, "__aw_%d.tmp.%d", awpts->count, ++vnum);
-	  tree to_replace = awpts->to_replace->pop ();
-	  tree orig_temp;
-	  if (TREE_CODE (to_replace) == CO_AWAIT_EXPR)
-	    {
-	      orig_temp = TREE_OPERAND (to_replace, 3);
-	      orig_temp = TREE_VEC_ELT (orig_temp, 2);
-	      orig_temp = TREE_OPERAND (orig_temp, 0);
-	    }
-	  else
-	    orig_temp = TREE_OPERAND (to_replace, 0);
-
-	  tree var_type = TREE_TYPE (orig_temp);
-	  gcc_assert (same_type_p (TREE_TYPE (to_replace), var_type));
-	  tree newvar
-	    = build_lang_decl (VAR_DECL, get_identifier (buf), var_type);
-	  DECL_CONTEXT (newvar) = DECL_CONTEXT (orig_temp);
-	  if (DECL_SOURCE_LOCATION (orig_temp))
-	    sloc = DECL_SOURCE_LOCATION (orig_temp);
-	  DECL_SOURCE_LOCATION (newvar) = sloc;
-	  DECL_CHAIN (newvar) = varlist;
-	  varlist = newvar; /* Chain it onto the list for the bind expr.  */
-	  /* Declare and initialze it in the new bind scope.  */
-	  add_decl_expr (newvar);
-	  tree stmt
-	    = build2_loc (sloc, INIT_EXPR, var_type, newvar, to_replace);
-	  stmt = coro_build_cvt_void_expr_stmt (stmt, sloc);
-	  add_stmt (stmt);
-	  proxy_replace pr = {to_replace, newvar};
-	  /* Replace all instances of that temp in the original expr.  */
-	  cp_walk_tree (&aw_statement_current, replace_proxy, &pr, NULL);
-	}
+     them in awpts->captured_temps.  */
 
-      /* What's left should be the original statement with any co_await
-	 captured temporaries broken out.  Other temporaries might remain
-	 so see if we need to wrap the revised statement in a cleanup.  */
-      aw_statement_current =
-	maybe_cleanup_point_expr_void (aw_statement_current);
-      add_stmt (aw_statement_current);
-      BIND_EXPR_BODY (aw_bind) = pop_stmt_list (aw_bind_body);
-      awpts->captured_temps.empty ();
-
-      BIND_EXPR_VARS (aw_bind) = nreverse (varlist);
-      tree b_block = make_node (BLOCK);
-      if (!awpts->block_stack->is_empty ())
-	{
-	  tree s_block = awpts->block_stack->last ();
-	  if (s_block)
-	    {
-	      BLOCK_SUPERCONTEXT (b_block) = s_block;
-	      BLOCK_CHAIN (b_block) = BLOCK_SUBBLOCKS (s_block);
-	      BLOCK_SUBBLOCKS (s_block) = b_block;
-	    }
-	}
-      BIND_EXPR_BLOCK (aw_bind) = b_block;
+  if (tree res = cp_walk_tree (stmt, register_awaits, d, &visited))
+    return res; /* We saw some reason to abort the tree walk.  */
 
-      *stmt = aw_bind;
-    }
-  return res;
+  /* We only need to take any action here if the statement contained any
+     awaits and any of those had temporaries captured by reference in their
+     initializers. */
+
+  if (awpts->saw_awaits > 0 && !awpts->captured_temps.is_empty ())
+    replace_statement_captures (stmt, d);
+
+  return NULL_TREE;
 }
 
 static tree
@@ -2861,45 +2905,39 @@  await_statement_walker (tree *stmt, int *do_subtree, void *d)
   tree res = NULL_TREE;
   susp_frame_data *awpts = (susp_frame_data *) d;
 
-  /* We might need to insert a new bind expression, and want to link it
-     into the correct scope, so keep a note of the current block scope.  */
+  /* Process a statement at a time.  */
   if (TREE_CODE (*stmt) == BIND_EXPR)
     {
-      tree *body = &BIND_EXPR_BODY (*stmt);
+      /* We might need to insert a new bind expression, and want to link it
+	 into the correct scope, so keep a note of the current block scope.  */
       tree blk = BIND_EXPR_BLOCK (*stmt);
       vec_safe_push (awpts->block_stack, blk);
-
-      if (TREE_CODE (*body) == STATEMENT_LIST)
-	{
-	  tree_stmt_iterator i;
-	  for (i = tsi_start (*body); !tsi_end_p (i); tsi_next (&i))
-	    {
-	      tree *new_stmt = tsi_stmt_ptr (i);
-	      if (STATEMENT_CLASS_P (*new_stmt) || !EXPR_P (*new_stmt)
-		  || TREE_CODE (*new_stmt) == BIND_EXPR)
-		res = cp_walk_tree (new_stmt, await_statement_walker, d, NULL);
-	      else
-		res = maybe_promote_captured_temps (new_stmt, d);
-	      if (res)
-		return res;
-	    }
-	  *do_subtree = 0; /* Done subtrees.  */
-	}
-      else if (!STATEMENT_CLASS_P (*body) && EXPR_P (*body)
-	       && TREE_CODE (*body) != BIND_EXPR)
+      res = cp_walk_tree (&BIND_EXPR_BODY (*stmt), await_statement_walker,
+			  d, NULL);
+      awpts->block_stack->pop ();
+      *do_subtree = 0; /* Done subtrees.  */
+    }
+  else if (TREE_CODE (*stmt) == STATEMENT_LIST)
+    {
+      tree_stmt_iterator i;
+      for (i = tsi_start (*stmt); !tsi_end_p (i); tsi_next (&i))
 	{
-	  res = maybe_promote_captured_temps (body, d);
-	  *do_subtree = 0; /* Done subtrees.  */
+	  res = cp_walk_tree (tsi_stmt_ptr (i), await_statement_walker,
+			      d, NULL);
+	  if (res)
+	    return res;
 	}
-      awpts->block_stack->pop ();
+      *do_subtree = 0; /* Done subtrees.  */
     }
-  else if (!STATEMENT_CLASS_P (*stmt) && EXPR_P (*stmt)
-	   && TREE_CODE (*stmt) != BIND_EXPR)
+  else if (STATEMENT_CLASS_P (*stmt))
+    return NULL_TREE; /* Process the subtrees.  */
+  else if (EXPR_P (*stmt))
     {
       res = maybe_promote_captured_temps (stmt, d);
       *do_subtree = 0; /* Done subtrees.  */
     }
-  /* If it wasn't a statement list, or a single statement, continue.  */
+ 
+  /* Continue recursion, if needed.  */
   return res;
 }
 
@@ -3273,9 +3311,11 @@  morph_fn_to_coro (tree orig, tree *resumer, tree *destroyer)
      to promote any temporaries that are captured by reference (to regular
      vars) they will get added to the coro frame along with other locals.  */
   susp_frame_data body_aw_points
-    = {&field_list, handle_type, hash_set<tree> (), NULL, NULL, 0, 0, false};
-  body_aw_points.to_replace = make_tree_vector ();
+    = {&field_list, handle_type, NULL, NULL, 0, 0,
+       hash_set<tree> (), NULL, 0, false};
   body_aw_points.block_stack = make_tree_vector ();
+  body_aw_points.bind_stack = make_tree_vector ();
+  body_aw_points.to_replace = make_tree_vector ();
   cp_walk_tree (&fnbody, await_statement_walker, &body_aw_points, NULL);
 
   /* Final suspend is mandated.  */
@@ -3914,7 +3954,7 @@  morph_fn_to_coro (tree orig, tree *resumer, tree *destroyer)
   /* Actor ...  */
   build_actor_fn (fn_start, coro_frame_type, actor, fnbody, orig, param_uses,
 		  &local_var_uses, param_dtor_list, initial_await, final_await,
-		  body_aw_points.count, frame_size);
+		  body_aw_points.await_number, frame_size);
 
   /* Destroyer ... */
   build_destroy_fn (fn_start, coro_frame_type, destroy, actor);

diff --git a/gcc/testsuite/g++.dg/coroutines/pr94528.C b/gcc/testsuite/g++.dg/coroutines/pr94528.C
new file mode 100644
index 00000000000..80e7273f178
--- /dev/null
+++ b/gcc/testsuite/g++.dg/coroutines/pr94528.C
@@ -0,0 +1,64 @@ 
+//  { dg-additional-options "-std=c++20 -fpreprocessed -w" }
+namespace std {
+inline namespace {
+template <typename _Result, typename> struct coroutine_traits {
+  using promise_type = _Result::promise_type;
+};
+template <typename = void> struct coroutine_handle;
+template <> struct coroutine_handle<> { public: };
+template <typename> struct coroutine_handle : coroutine_handle<> {};
+struct suspend_always {
+  bool await_ready();
+  void await_suspend(coroutine_handle<>);
+  void await_resume();
+};
+} // namespace
+} // namespace std
+namespace coro = std;
+namespace cppcoro {
+class task {
+private:
+  struct awaitable_base {
+    coro::coroutine_handle<> m_coroutine;
+    bool await_ready() const noexcept;
+    void await_suspend(coro::coroutine_handle<> awaitingCoroutine) noexcept;
+  };
+
+public:
+  auto operator co_await() const &noexcept {
+    struct awaitable : awaitable_base {
+      decltype(auto) await_resume() {}
+    };
+    return awaitable{m_coroutine};
+  }
+
+private:
+  coro::coroutine_handle<> m_coroutine;
+};
+class shared_task;
+class shared_task_promise_base {
+  struct final_awaiter {
+    bool await_ready() const noexcept;
+    template <typename PROMISE>
+    void await_suspend(coro::coroutine_handle<PROMISE> h) noexcept;
+    void await_resume() noexcept;
+  };
+
+public:
+  coro::suspend_always initial_suspend() noexcept;
+  final_awaiter final_suspend() noexcept;
+  void unhandled_exception() noexcept;
+};
+class shared_task_promise : public shared_task_promise_base {
+public:
+  shared_task get_return_object() noexcept;
+  void return_void() noexcept;
+};
+class shared_task {
+public:
+  using promise_type = shared_task_promise;
+};
+auto make_shared_task(cppcoro::task awaitable) -> shared_task {
+  co_return co_await static_cast<cppcoro::task &&>(awaitable);
+}
+} // namespace cppcoro