<!doctype html><html lang=en-us><head><meta charset=utf-8><meta http-equiv=x-ua-compatible content="IE=edge"><meta name=viewport content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=no"><title>'mesh' Dialect - MLIR</title><meta name=description content="Multi-Level IR Compiler Framework"><meta name=generator content="Hugo 0.119.0"><link href=https://mlir.llvm.org/index.xml rel=alternate type=application/rss+xml><link rel=canonical href=https://mlir.llvm.org/docs/Dialects/Mesh/><link rel=stylesheet href=https://mlir.llvm.org/css/theme.css><script src=https://use.fontawesome.com/releases/v5.0.6/js/all.js></script> <link rel=stylesheet href=https://mlir.llvm.org/css/chroma.min.css><script src=https://cdn.jsdelivr.net/npm/jquery@3.3.1/dist/jquery.min.js></script> <script src=https://cdn.jsdelivr.net/npm/jquery.easing@1.4.1/jquery.easing.min.js></script> <script src=https://mlir.llvm.org/js/bundle.js></script> <script type=text/javascript src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> <script type=text/x-mathjax-config> MathJax.Hub.Config({ tex2jax: { inlineMath: [['$', '$'] ], displayMath: [ ['$$','$$'], ["\\[","\\]"] ] } }); </script><link rel=apple-touch-icon sizes=180x180 href="/apple-touch-icon.png?v=1"><link rel=icon type=image/png sizes=32x32 href="/favicon-32x32.png?v=1"><link rel=icon type=image/png sizes=16x16 href="/favicon-16x16.png?v=1"><link rel=manifest href="/site.webmanifest?v=1"><link rel=mask-icon href="/safari-pinned-tab.svg?v=1" color=#3775e0><link rel="shortcut icon" href="/favicon.ico?v=1"><meta name=msapplication-TileColor content="#2d89ef"><meta name=theme-color content="#ffffff"><link rel=icon href=/favicon.svg type=image/svg+xml sizes=any><style>:root{}</style></head><body><div class=container><header><h1><div><img src=https://mlir.llvm.org//mlir-logo.png width=40px align=absmiddle> MLIR</div></h1><p class=description>Multi-Level IR Compiler 
Framework</p></header><div class=global-menu><nav><ul><li class=parent><a href>Community<i class="fas fa-angle-right"></i></a><ul class=sub-menu><li class=child><a href=https://llvm.discourse.group/c/mlir/31>Forums</a></li><li class=child><a href=https://discord.gg/xS7Z362>Chat</a></li></ul></li><li><a href=/getting_started/Debugging/>Debugging Tips</a></li><li><a href=/getting_started/Faq/>FAQ</a></li><li class=parent><a href=https://github.com/llvm/llvm-project/tree/main/mlir>Source<i class="fas fa-angle-right"></i></a><ul class=sub-menu><li class=child><a href=/doxygen/>Doxygen</a></li><li class=child><a href=https://github.com/llvm/llvm-project/tree/main/mlir>GitHub</a></li></ul></li><li><a href="https://bugs.llvm.org/buglist.cgi?bug_status=__open__&list_id=177877&order=changeddate%20DESC%2Cpriority%2Cbug_severity&product=MLIR&query_format=specific">Bugs</a></li><li><a href=https://github.com/llvm/mlir-www/tree/main/website/static/LogoAssets>Logo Assets</a></li><li><a href=https://www.youtube.com/MLIRCompiler>Youtube Channel</a></li></ul></nav></div><div class=content-container><main><h1>'mesh' Dialect</h1><p>The <code>mesh</code> dialect contains a set of attributes, operations and interfaces that are useful for representing sharding and communication on a device mesh cluster.</p><p><nav id=TableOfContents><ul><li><a href=#collective-communication-operations>Collective Communication Operations</a><ul><li><a href=#device-groups>Device groups</a></li><li><a href=#in-group-device>In-group Device</a></li><li><a href=#purity>Purity</a></li></ul></li><li><a href=#operations>Operations</a><ul><li><a href=#meshall_gather-meshallgatherop><code>mesh.all_gather</code> (mesh::AllGatherOp)</a></li><li><a href=#meshall_reduce-meshallreduceop><code>mesh.all_reduce</code> (mesh::AllReduceOp)</a></li><li><a href=#meshall_slice-meshallsliceop><code>mesh.all_slice</code> (mesh::AllSliceOp)</a></li><li><a href=#meshall_to_all-meshalltoallop><code>mesh.all_to_all</code> 
(mesh::AllToAllOp)</a></li><li><a href=#meshbroadcast-meshbroadcastop><code>mesh.broadcast</code> (mesh::BroadcastOp)</a></li><li><a href=#meshgather-meshgatherop><code>mesh.gather</code> (mesh::GatherOp)</a></li><li><a href=#meshmesh-meshmeshop><code>mesh.mesh</code> (mesh::MeshOp)</a></li><li><a href=#meshmesh_shape-meshmeshshapeop><code>mesh.mesh_shape</code> (mesh::MeshShapeOp)</a></li><li><a href=#meshprocess_linear_index-meshprocesslinearindexop><code>mesh.process_linear_index</code> (mesh::ProcessLinearIndexOp)</a></li><li><a href=#meshprocess_multi_index-meshprocessmultiindexop><code>mesh.process_multi_index</code> (mesh::ProcessMultiIndexOp)</a></li><li><a href=#meshrecv-meshrecvop><code>mesh.recv</code> (mesh::RecvOp)</a></li><li><a href=#meshreduce-meshreduceop><code>mesh.reduce</code> (mesh::ReduceOp)</a></li><li><a href=#meshreduce_scatter-meshreducescatterop><code>mesh.reduce_scatter</code> (mesh::ReduceScatterOp)</a></li><li><a href=#meshscatter-meshscatterop><code>mesh.scatter</code> (mesh::ScatterOp)</a></li><li><a href=#meshsend-meshsendop><code>mesh.send</code> (mesh::SendOp)</a></li><li><a href=#meshshard-meshshardop><code>mesh.shard</code> (mesh::ShardOp)</a></li><li><a href=#meshshard_shape-meshshardshapeop><code>mesh.shard_shape</code> (mesh::ShardShapeOp)</a></li><li><a href=#meshsharding-meshshardingop><code>mesh.sharding</code> (mesh::ShardingOp)</a></li><li><a href=#meshshift-meshshiftop><code>mesh.shift</code> (mesh::ShiftOp)</a></li><li><a href=#meshupdate_halo-meshupdatehaloop><code>mesh.update_halo</code> (mesh::UpdateHaloOp)</a></li></ul></li><li><a href=#attributes-20>Attributes</a><ul><li><a href=#meshaxesarrayattr>MeshAxesArrayAttr</a></li><li><a href=#reductionkindattr>ReductionKindAttr</a></li></ul></li></ul></nav><h2 id=collective-communication-operations>Collective Communication Operations <a class=headline-hash href=#collective-communication-operations>¶</a></h2><p>There are a number of operations in the Mesh dialect to 
facilitate communication between devices in a mesh. It is assumed that the user is familiar with collective operations. <a href=https://en.wikipedia.org/wiki/Collective_operation>Wikipedia</a> has a good explanation. The main addition is that the collectives in this dialect have mesh semantics.</p><h3 id=device-groups>Device groups <a class=headline-hash href=#device-groups>¶</a></h3><p>The operation attributes <code>mesh</code> and <code>mesh_axes</code> specify the device mesh and a list of its axes that partition the devices into disjoint groups. The collective operation is performed among the devices of the same group. Devices that have the same coordinates along all axes outside of <code>mesh_axes</code> are in the same group. A group is described by its multi-index along the axes outside of <code>mesh_axes</code>. For example, if we have a device mesh of size <code>2x3x4x5</code> and the partition mesh axes list is <code>[0, 1]</code>, then the devices are partitioned into the groups <code>{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }</code>. The device groups are indexed by <code>{ (k, m) | 0<=k<4, 0<=m<5 }</code>. Devices (1, 0, 2, 3) and (1, 1, 2, 3) are in the same group, because they differ only along axis 1. Device (1, 0, 2, 4) is in another group. Some collective operations, such as all-to-all and all-gather, depend on the order of devices. The order of devices within a device group is induced by the order of axes in <code>mesh_axes</code>, with the axes ordered from outer to inner. For example, with the axis list <code>[3, 1]</code>, device <code>(i, 1, k, 0)</code> precedes both devices <code>(i, 0, k, 1)</code> and <code>(i, 2, k, 0)</code>: its in-group index (axis 3, then axis 1) is <code>(0, 1)</code>, which comes before <code>(1, 0)</code> and <code>(0, 2)</code>.</p><h3 id=in-group-device>In-group Device <a class=headline-hash href=#in-group-device>¶</a></h3><p>Some operations, such as <code>broadcast</code>, <code>scatter</code> and <code>send</code>, specify devices in each device group. These devices are represented by their multi-index over the mesh axes that are not constant within a device group. 
These are the axes specified by the <code>mesh_axes</code> attribute.</p><p>For example, on a 3D mesh, an operation with <code>mesh_axes = [0, 2]</code> specifies an in-group device as <code>(i, j)</code>. Then, for each group with index <code>g</code> on the second axis, the in-group device is <code>(i, g, j)</code>.</p><h3 id=purity>Purity <a class=headline-hash href=#purity>¶</a></h3><p>Collectives in which the whole device group participates in a single operation are pure. The exceptions are <code>send</code> and <code>recv</code>.</p><p>Execution is assumed to be SPMD: not only does each process run the same program, but at the point where a collective operation executes, all processes are in a coherent state. Consequently, all compiler transformations must be consistent: collective operations in the IR that may correspond to the same runtime collective operation must be transformed in the same way. For example, if a collective operation is optimized out, then it must not appear on any execution path of any process.</p><p>Marking the operations as <code>Pure</code> implies that, if an interpreter executes IR containing <code>mesh</code> collectives, all processes execute the same line when they reach a pure collective operation. This requirement stems from the need to be compatible with general optimization passes such as dead code elimination and common subexpression elimination.</p><h2 id=operations>Operations <a class=headline-hash href=#operations>¶</a></h2><p><a href=https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td>source</a></p><h3 id=meshall_gather-meshallgatherop><code>mesh.all_gather</code> (mesh::AllGatherOp) <a class=headline-hash href=#meshall_gather-meshallgatherop>¶</a></h3><p><em>All-gather over a device mesh.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.all_gather` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? 
`gather_axis` `=` $gather_axis attr-dict `:` type($input) `->` type($result) </code></pre><p>Gathers along the <code>gather_axis</code> tensor axis: within each device group, the input tensors are concatenated along that axis in in-group device order, and every device of the group receives the result.</p><p>Example:</p><div class=highlight><pre tabindex=0 class=chroma><code class=language-mlir data-lang=mlir><span class=line><span class=cl>mesh<span class=p>.</span>mesh <span class=nf>@mesh0</span><span class=p>(</span><span class=nl>shape =</span> <span class=m>2x2</span><span class=p>)</span> </span></span><span class=line><span class=cl><span class=p>...</span> </span></span><span class=line><span class=cl><span class=nv>%1</span> <span class=p>=</span> mesh<span class=p>.</span>all_gather <span class=nv>%0</span> on <span class=nf>@mesh0</span> <span class=nl>mesh_axes =</span> <span class=p>[</span><span class=m>1</span><span class=p>]</span> <span class=nl>gather_axis =</span> <span class=m>1</span> </span></span><span class=line><span class=cl> <span class=p>:</span> <span class=kt>tensor</span><span class=p><</span><span class=m>2x2x</span><span class=k>i8</span><span class=p>></span> <span class=p>-></span> <span class=kt>tensor</span><span class=p><</span><span class=m>2x4x</span><span class=k>i8</span><span class=p>></span> </span></span></code></pre></div><p>Input:</p><pre tabindex=0><code>                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+
</code></pre><p>Result:</p><pre tabindex=0><code>gather tensor
axis 1
------------>
+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
</code></pre><p>Traits: <code>AlwaysSpeculatableImplTrait</code>, <code>SameOperandsAndResultElementType</code>, <code>SameOperandsAndResultRank</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, <code>OpAsmOpInterface</code>, 
<code>SymbolUserOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes>Attributes: <a class=headline-hash href=#attributes>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>mesh_axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array attribute</td></tr><tr><td><code>gather_axis</code></td><td>::mlir::IntegerAttr</td><td>index attribute</td></tr></table><h4 id=operands>Operands: <a class=headline-hash href=#operands>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>input</code></td><td>non-0-ranked.tensor of any type values</td></tr></tbody></table><h4 id=results>Results: <a class=headline-hash href=#results>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>non-0-ranked.tensor of any type values</td></tr></tbody></table><h3 id=meshall_reduce-meshallreduceop><code>mesh.all_reduce</code> (mesh::AllReduceOp) <a class=headline-hash href=#meshall_reduce-meshallreduceop>¶</a></h3><p><em>All-reduce over a device mesh.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.all_reduce` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)? attr-dict `:` type($input) `->` type($result) </code></pre><p>The accumulation element type is specified by the result type and it does not need to match the input element type. 
The input element is converted to the result element type before performing the reduction.</p><p>Attributes: <code>reduction</code>: Indicates the reduction method.</p><p>Example:</p><pre tabindex=0><code>%1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max> : tensor<3x4xf32> -> tensor<3x4xf64> </code></pre><p>Traits: <code>AlwaysSpeculatableImplTrait</code>, <code>SameOperandsAndResultShape</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, <code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-1>Attributes: <a class=headline-hash href=#attributes-1>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>mesh_axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array attribute</td></tr><tr><td><code>reduction</code></td><td>::mlir::mesh::ReductionKindAttr</td><td><details><summary>Reduction of an iterator/mesh dimension.</summary><p>Enum cases:</p><ul><li>sum (<code>Sum</code>)</li><li>max (<code>Max</code>)</li><li>min (<code>Min</code>)</li><li>product (<code>Product</code>)</li><li>average (<code>Average</code>)</li><li>bitwise_and (<code>BitwiseAnd</code>)</li><li>bitwise_or (<code>BitwiseOr</code>)</li><li>bitwise_xor (<code>BitwiseXor</code>)</li><li>generic (<code>Generic</code>)</li></ul></details></td></tr></table><h4 id=operands-1>Operands: <a class=headline-hash href=#operands-1>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>input</code></td><td>ranked tensor of any type values</td></tr></tbody></table><h4 id=results-1>Results: <a class=headline-hash href=#results-1>¶</a></h4><table><thead><tr><th 
style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>ranked tensor of any type values</td></tr></tbody></table><h3 id=meshall_slice-meshallsliceop><code>mesh.all_slice</code> (mesh::AllSliceOp) <a class=headline-hash href=#meshall_slice-meshallsliceop>¶</a></h3><p><em>All-slice over a device mesh. This is the inverse of all-gather.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.all_slice` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `slice_axis` `=` $slice_axis attr-dict `:` type($input) `->` type($result) </code></pre><p>Slice along the <code>slice_axis</code> tensor axis. This operation can be thought of as the inverse of all-gather. Technically, it is not required that all processes have the same input tensor. Each process will slice a piece of its local tensor based on its in-group device index. The operation does not communicate data between devices.</p><p>Example:</p><div class=highlight><pre tabindex=0 class=chroma><code class=language-mlir data-lang=mlir><span class=line><span class=cl>mesh<span class=p>.</span>mesh <span class=nf>@mesh0</span><span class=p>(</span><span class=nl>shape =</span> <span class=m>2x2</span><span class=p>)</span> </span></span><span class=line><span class=cl><span class=p>...</span> </span></span><span class=line><span class=cl><span class=nv>%1</span> <span class=p>=</span> mesh<span class=p>.</span>all_slice <span class=nv>%0</span> on <span class=nf>@mesh0</span> <span class=nl>mesh_axes =</span> <span class=p>[</span><span class=m>1</span><span class=p>]</span> <span class=nl>slice_axis =</span> <span class=m>1</span> </span></span><span class=line><span class=cl> <span class=p>:</span> <span class=kt>tensor</span><span class=p><</span><span class=m>2x4x</span><span class=k>i8</span><span class=p>></span> <span class=p>-></span> <span class=kt>tensor</span><span class=p><</span><span class=m>2x2x</span><span 
class=k>i8</span><span class=p>></span> </span></span></code></pre></div><p>Input:</p><pre tabindex=0><code>+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
</code></pre><p>Result:</p><pre tabindex=0><code>slice tensor
axis 1
------------>
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+
</code></pre><p>Traits: <code>AlwaysSpeculatableImplTrait</code>, <code>SameOperandsAndResultElementType</code>, <code>SameOperandsAndResultRank</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, <code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-2>Attributes: <a class=headline-hash href=#attributes-2>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>mesh_axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array attribute</td></tr><tr><td><code>slice_axis</code></td><td>::mlir::IntegerAttr</td><td>index attribute</td></tr></table><h4 id=operands-2>Operands: <a class=headline-hash href=#operands-2>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>input</code></td><td>non-0-ranked.tensor of any type values</td></tr></tbody></table><h4 id=results-2>Results: <a class=headline-hash href=#results-2>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>non-0-ranked.tensor of any type values</td></tr></tbody></table><h3 
id=meshall_to_all-meshalltoallop><code>mesh.all_to_all</code> (mesh::AllToAllOp) <a class=headline-hash href=#meshall_to_all-meshalltoallop>¶</a></h3><p><em>All-to-all over a device mesh.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.all_to_all` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `split_axis` `=` $split_axis `concat_axis` `=` $concat_axis attr-dict `:` type($input) `->` type($result) </code></pre><p>Performs an all-to-all on tensor pieces split along <code>split_axis</code>: each device splits its input along that axis into one piece per device in its group and sends the <code>j</code>-th piece to the <code>j</code>-th device of the group. The received pieces are concatenated along <code>concat_axis</code> on each device, in in-group device order.</p><p>Example:</p><pre tabindex=0><code>mesh.mesh @mesh0(shape = 3)
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
  split_axis = 0 concat_axis = 0
  : tensor<3x2xi8> -> tensor<3x2xi8>
</code></pre><p>Input:</p><pre tabindex=0><code> device  device  device
 (0)     (1)     (2)
+-------+-------+-------+  | split and concat along
| 11 12 | 21 22 | 31 32 |  | tensor axis 0
| 13 14 | 23 24 | 33 34 |  ↓
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+
</code></pre><p>Result:</p><pre tabindex=0><code> device  device  device
 (0)     (1)     (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
</code></pre><p>Traits: <code>AlwaysSpeculatableImplTrait</code>, <code>SameOperandsAndResultElementType</code>, <code>SameOperandsAndResultRank</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, <code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-3>Attributes: <a class=headline-hash href=#attributes-3>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>mesh_axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array 
attribute</td></tr><tr><td><code>split_axis</code></td><td>::mlir::IntegerAttr</td><td>index attribute</td></tr><tr><td><code>concat_axis</code></td><td>::mlir::IntegerAttr</td><td>index attribute</td></tr></table><h4 id=operands-3>Operands: <a class=headline-hash href=#operands-3>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>input</code></td><td>non-0-ranked.tensor of any type values</td></tr></tbody></table><h4 id=results-3>Results: <a class=headline-hash href=#results-3>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>non-0-ranked.tensor of any type values</td></tr></tbody></table><h3 id=meshbroadcast-meshbroadcastop><code>mesh.broadcast</code> (mesh::BroadcastOp) <a class=headline-hash href=#meshbroadcast-meshbroadcastop>¶</a></h3><p><em>Broadcast over a device mesh.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.broadcast` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `root` `=` custom<DynamicIndexList>($root_dynamic, $root) attr-dict `:` functional-type(operands, results) </code></pre><p>Broadcast the tensor on <code>root</code> to all devices in each respective group. The operation broadcasts along mesh axes <code>mesh_axes</code>. 
The <code>root</code> device specifies the in-group multi-index that is broadcast to all other devices in the group.</p><p>Example:</p><pre tabindex=0><code>mesh.mesh @mesh0(shape = 2x2)

%1 = mesh.broadcast %0 on @mesh0
  mesh_axes = [0]
  root = [0]
  : (tensor<2xi8>) -> tensor<2xi8>
</code></pre><p>Input:</p><pre tabindex=0><code>                 +-------+-------+                   | broadcast
device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
                 +-------+-------+                   ↓
device (1, 0) -> |       |       | <- device (1, 1)
                 +-------+-------+
</code></pre><p>Output:</p><pre tabindex=0><code>                 +-------+-------+
device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)
                 +-------+-------+
device (1, 0) -> |  1  2 |  3  4 | <- device (1, 1)
                 +-------+-------+
</code></pre><p>Traits: <code>AlwaysSpeculatableImplTrait</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, <code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-4>Attributes: <a class=headline-hash href=#attributes-4>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>mesh_axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array attribute</td></tr><tr><td><code>root</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr></table><h4 id=operands-4>Operands: <a class=headline-hash href=#operands-4>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>input</code></td><td>ranked tensor of any type values</td></tr><tr><td style=text-align:center><code>root_dynamic</code></td><td>variadic of index</td></tr></tbody></table><h4 id=results-4>Results: <a class=headline-hash href=#results-4>¶</a></h4><table><thead><tr><th 
style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>ranked tensor of any type values</td></tr></tbody></table><h3 id=meshgather-meshgatherop><code>mesh.gather</code> (mesh::GatherOp) <a class=headline-hash href=#meshgather-meshgatherop>¶</a></h3><p><em>Gather over a device mesh.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.gather` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis `root` `=` custom<DynamicIndexList>($root_dynamic, $root) attr-dict `:` functional-type(operands, results) </code></pre><p>Gathers on device <code>root</code> along the <code>gather_axis</code> tensor axis. <code>root</code> specifies the coordinates of a device along <code>mesh_axes</code>. It uniquely identifies the root device for each device group. The result tensor on non-root devices is undefined. Using it will result in undefined behavior.</p><p>Example:</p><div class=highlight><pre tabindex=0 class=chroma><code class=language-mlir data-lang=mlir><span class=line><span class=cl>mesh<span class=p>.</span>mesh <span class=nf>@mesh0</span><span class=p>(</span><span class=nl>shape =</span> <span class=m>2x2</span><span class=p>)</span> </span></span><span class=line><span class=cl><span class=p>...</span> </span></span><span class=line><span class=cl><span class=nv>%1</span> <span class=p>=</span> mesh<span class=p>.</span>gather <span class=nv>%0</span> on <span class=nf>@mesh0</span> <span class=nl>mesh_axes =</span> <span class=p>[</span><span class=m>1</span><span class=p>]</span> </span></span><span class=line><span class=cl> <span class=nl>gather_axis =</span> <span class=m>1</span> <span class=nl>root =</span> <span class=p>[</span><span class=m>1</span><span class=p>]</span> </span></span><span class=line><span class=cl> <span class=p>:</span> <span class=p>(</span><span class=kt>tensor</span><span class=p><</span><span class=m>2x2x</span><span 
class=k>i8</span><span class=p>>)</span> <span class=p>-></span> <span class=kt>tensor</span><span class=p><</span><span class=m>2x4x</span><span class=k>i8</span><span class=p>></span> </span></span></code></pre></div><p>Input:</p><pre tabindex=0><code>                  gather tensor
                  axis 1
                  ------------>
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+
</code></pre><p>Result:</p><pre tabindex=0><code>+-------------+
|  1  2  5  6 | <- device (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- device (1, 1)
| 11 12 15 16 |
+-------------+
</code></pre><p>Devices <code>(0, 0)</code> and <code>(1, 0)</code> have an undefined result.</p><p>Traits: <code>AlwaysSpeculatableImplTrait</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, <code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-5>Attributes: <a class=headline-hash href=#attributes-5>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>mesh_axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array attribute</td></tr><tr><td><code>gather_axis</code></td><td>::mlir::IntegerAttr</td><td>index attribute</td></tr><tr><td><code>root</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr></table><h4 id=operands-5>Operands: <a class=headline-hash href=#operands-5>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>input</code></td><td>non-0-ranked.tensor of any type values</td></tr><tr><td style=text-align:center><code>root_dynamic</code></td><td>variadic of 
index</td></tr></tbody></table><h4 id=results-5>Results: <a class=headline-hash href=#results-5>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>non-0-ranked.tensor of any type values</td></tr></tbody></table><h3 id=meshmesh-meshmeshop><code>mesh.mesh</code> (mesh::MeshOp) <a class=headline-hash href=#meshmesh-meshmeshop>¶</a></h3><p><em>Description of a device/process mesh.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.mesh` $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)` attr-dict </code></pre><p>The <code>mesh.mesh</code> operation is a symbol operation that identifies a specific mesh. The operation has two attributes:</p><ol><li><p><code>sym_name</code>: This attribute is the unique name of the mesh. This name serves as a symbolic reference to the mesh throughout the MLIR module, allowing for consistent referencing and easier debugging.</p></li><li><p><code>shape</code>: This attribute represents the shape of the device mesh. It uses the same notation as a tensor shape, including dynamic dimensions, for example <code>2x?x4</code>. This flexibility allows for dynamic device assignment or configurations where the exact number of devices might not be known at compile time.</p></li></ol><p>Example:</p><pre tabindex=0><code>// A device mesh with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
mesh.mesh @mesh0(shape = 4x8x12)

// A device mesh with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
mesh.mesh @mesh1(shape = 4x?)

// A device mesh with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
mesh.mesh @mesh2(shape = ?x4)

// A device mesh with 2 axes, the number of devices along both axes
// is unknown
mesh.mesh @mesh3(shape = ?x?)
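
// Illustrative addition (a sketch, not part of the upstream example): inside a
// function body, the runtime sizes of both axes of @mesh1, including the
// dynamic second axis, can be queried with mesh.mesh_shape. With no `axes`
// given, it returns one index value per mesh axis. The name %sizes is an
// arbitrary, hypothetical SSA name.
%sizes:2 = mesh.mesh_shape @mesh1 : index, index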
</code></pre><p>Interfaces: <code>Symbol</code></p><h4 id=attributes-6>Attributes: <a class=headline-hash href=#attributes-6>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>sym_name</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr><tr><td><code>shape</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr></table><h3 id=meshmesh_shape-meshmeshshapeop><code>mesh.mesh_shape</code> (mesh::MeshShapeOp) <a class=headline-hash href=#meshmesh_shape-meshmeshshapeop>¶</a></h3><p><em>Get the shape of the mesh.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.mesh_shape` $mesh (`axes` `=` $axes^)? attr-dict `:` type($result) </code></pre><p>Traits: <code>AlwaysSpeculatableImplTrait</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, <code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-7>Attributes: <a class=headline-hash href=#attributes-7>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array attribute</td></tr></table><h4 id=results-6>Results: <a class=headline-hash href=#results-6>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>variadic of index</td></tr></tbody></table><h3 id=meshprocess_linear_index-meshprocesslinearindexop><code>mesh.process_linear_index</code> (mesh::ProcessLinearIndexOp) <a class=headline-hash href=#meshprocess_linear_index-meshprocesslinearindexop>¶</a></h3><p><em>Get the linear index of the current device.</em></p><p>Syntax:</p><pre 
tabindex=0><code>operation ::= `mesh.process_linear_index` `on` $mesh attr-dict `:` type($result) </code></pre><p>Example:</p><pre tabindex=0><code>%idx = mesh.process_linear_index on @mesh : index </code></pre><p>If <code>@mesh</code> has shape <code>(10, 20, 30)</code>, a device with multi-index <code>(1, 2, 3)</code> has linear index <code>3 + 30*2 + 20*30*1</code> = 663 (row-major order: the last mesh axis varies fastest).</p><p>Traits: <code>AlwaysSpeculatableImplTrait</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>InferTypeOpInterface</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, <code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-8>Attributes: <a class=headline-hash href=#attributes-8>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr></table><h4 id=results-7>Results: <a class=headline-hash href=#results-7>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>index</td></tr></tbody></table><h3 id=meshprocess_multi_index-meshprocessmultiindexop><code>mesh.process_multi_index</code> (mesh::ProcessMultiIndexOp) <a class=headline-hash href=#meshprocess_multi_index-meshprocessmultiindexop>¶</a></h3><p><em>Get the multi-index of the current device along the specified mesh axes.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.process_multi_index` `on` $mesh (`axes` `=` $axes^)? attr-dict `:` type($result) </code></pre><p>It is used in the SPMD form of the IR. The values in <code>axes</code> must be non-negative and less than the total number of mesh axes. 
If <code>axes</code> is empty, the index along all mesh axes is returned.</p><p>Traits: <code>AlwaysSpeculatableImplTrait</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, <code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-9>Attributes: <a class=headline-hash href=#attributes-9>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array attribute</td></tr></table><h4 id=results-8>Results: <a class=headline-hash href=#results-8>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>variadic of index</td></tr></tbody></table><h3 id=meshrecv-meshrecvop><code>mesh.recv</code> (mesh::RecvOp) <a class=headline-hash href=#meshrecv-meshrecvop>¶</a></h3><p><em>Receive over a device mesh.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.recv` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)? 
attr-dict `:` functional-type(operands, results) </code></pre><p>Receive from a device within a device group.</p><p>Interfaces: <code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><h4 id=attributes-10>Attributes: <a class=headline-hash href=#attributes-10>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>mesh_axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array attribute</td></tr><tr><td><code>source</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr></table><h4 id=operands-6>Operands: <a class=headline-hash href=#operands-6>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>input</code></td><td>non-0-ranked.tensor of any type values</td></tr><tr><td style=text-align:center><code>source_dynamic</code></td><td>variadic of index</td></tr></tbody></table><h4 id=results-9>Results: <a class=headline-hash href=#results-9>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>ranked tensor of any type values</td></tr></tbody></table><h3 id=meshreduce-meshreduceop><code>mesh.reduce</code> (mesh::ReduceOp) <a class=headline-hash href=#meshreduce-meshreduceop>¶</a></h3><p><em>Reduce over a device mesh.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.reduce` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)? `root` `=` custom<DynamicIndexList>($root_dynamic, $root) attr-dict `:` functional-type(operands, results) </code></pre><p>Reduces on device <code>root</code> within each device group. <code>root</code> specifies the coordinates of a device along <code>mesh_axes</code>. 
It uniquely identifies the root device within its device group. The accumulation element type is given by the result type and need not match the input element type; input elements are converted to the result element type before the reduction is performed.</p><p>The <code>reduction</code> attribute specifies the reduction method.</p><p>Example:</p><pre tabindex=0><code>%1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max> root = [2, 3] : (tensor<3x4xf32>) -> tensor<3x4xf64> </code></pre><p>Traits: <code>AlwaysSpeculatableImplTrait</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, <code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-11>Attributes: <a class=headline-hash href=#attributes-11>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>mesh_axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array attribute</td></tr><tr><td><code>reduction</code></td><td>::mlir::mesh::ReductionKindAttr</td><td><details><summary>Reduction of an iterator/mesh dimension.</summary><p>Enum cases:</p><ul><li>sum (<code>Sum</code>)</li><li>max (<code>Max</code>)</li><li>min (<code>Min</code>)</li><li>product (<code>Product</code>)</li><li>average (<code>Average</code>)</li><li>bitwise_and (<code>BitwiseAnd</code>)</li><li>bitwise_or (<code>BitwiseOr</code>)</li><li>bitwise_xor (<code>BitwiseXor</code>)</li><li>generic (<code>Generic</code>)</li></ul></details></td></tr><tr><td><code>root</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr></table><h4 id=operands-7>Operands: <a class=headline-hash href=#operands-7>¶</a></h4><table><thead><tr><th 
style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>input</code></td><td>ranked tensor of any type values</td></tr><tr><td style=text-align:center><code>root_dynamic</code></td><td>variadic of index</td></tr></tbody></table><h4 id=results-10>Results: <a class=headline-hash href=#results-10>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>ranked tensor of any type values</td></tr></tbody></table><h3 id=meshreduce_scatter-meshreducescatterop><code>mesh.reduce_scatter</code> (mesh::ReduceScatterOp) <a class=headline-hash href=#meshreduce_scatter-meshreducescatterop>¶</a></h3><p><em>Reduce-scatter over a device mesh.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.reduce_scatter` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)? `scatter_axis` `=` $scatter_axis attr-dict `:` type($input) `->` type($result) </code></pre><p>After the reduction, the result is scattered within each device group: the tensor is split along <code>scatter_axis</code> and the pieces are distributed across the devices of the group.</p><p>Example:</p><pre tabindex=0><code>mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1] reduction = <sum> scatter_axis = 0
  : tensor<2x2xf32> -> tensor<1x2xf64>
</code></pre><p>Input:</p><pre tabindex=0><code>                       device (0, 1)
                             ↓
                 +-------+-------+  | scatter tensor
device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                 |  3  4 |  7  8 |  ↓
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 |
                 | 11 12 | 15 16 |
                 +-------+-------+
                             ↑
                       device (1, 1)
</code></pre><p>Result:</p><pre tabindex=0><code>+-------+
|  6  8 | <- device (0, 0)
+-------+
| 10 12 | <- device (0, 1)
+-------+
| 22 24 | <- device (1, 0)
+-------+
| 26 28 | <- device (1, 1)
+-------+
</code></pre><p>Traits: <code>AlwaysSpeculatableImplTrait</code>, <code>SameOperandsAndResultRank</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, <code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-12>Attributes: <a class=headline-hash href=#attributes-12>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>mesh_axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array attribute</td></tr><tr><td><code>reduction</code></td><td>::mlir::mesh::ReductionKindAttr</td><td><details><summary>Reduction of an iterator/mesh dimension.</summary><p>Enum cases:</p><ul><li>sum (<code>Sum</code>)</li><li>max (<code>Max</code>)</li><li>min (<code>Min</code>)</li><li>product (<code>Product</code>)</li><li>average (<code>Average</code>)</li><li>bitwise_and (<code>BitwiseAnd</code>)</li><li>bitwise_or (<code>BitwiseOr</code>)</li><li>bitwise_xor (<code>BitwiseXor</code>)</li><li>generic (<code>Generic</code>)</li></ul></details></td></tr><tr><td><code>scatter_axis</code></td><td>::mlir::IntegerAttr</td><td>index attribute</td></tr></table><h4 id=operands-8>Operands: <a 
class=headline-hash href=#operands-8>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>input</code></td><td>non-0-ranked.tensor of any type values</td></tr></tbody></table><h4 id=results-11>Results: <a class=headline-hash href=#results-11>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>ranked tensor of any type values</td></tr></tbody></table><h3 id=meshscatter-meshscatterop><code>mesh.scatter</code> (mesh::ScatterOp) <a class=headline-hash href=#meshscatter-meshscatterop>¶</a></h3><p><em>Scatter over a device mesh.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.scatter` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `scatter_axis` `=` $scatter_axis `root` `=` custom<DynamicIndexList>($root_dynamic, $root) attr-dict `:` functional-type(operands, results) </code></pre><p>For each device group, split the input tensor held by the <code>root</code> device along axis <code>scatter_axis</code> and scatter the parts across the devices of the group.</p><p>Example:</p><pre tabindex=0><code>mesh.mesh @mesh0(shape = 2x2)
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
  scatter_axis = 0 root = [1]
  : (tensor<2x2xi8>) -> tensor<1x2xi8>
</code></pre><p>Input:</p><pre tabindex=0><code>                       device (0, 1)
                             ↓
                 +-------+-------+  | scatter tensor
device (0, 0) -> |       |       |  | axis 0
                 |       |       |  ↓
                 +-------+-------+
device (1, 0) -> |  1  2 |  5  6 |
                 |  3  4 |  7  8 |
                 +-------+-------+
                             ↑
                       device (1, 1)
</code></pre><p>Result:</p><pre tabindex=0><code>                       device (0, 1)
                             ↓
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 |
                 +-------+-------+
device (1, 0) -> |  3  4 |  7  8 |
                 +-------+-------+
                             ↑
                       device (1, 1)
</code></pre><p>Traits: <code>AlwaysSpeculatableImplTrait</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, 
<code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-13>Attributes: <a class=headline-hash href=#attributes-13>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>mesh_axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array attribute</td></tr><tr><td><code>scatter_axis</code></td><td>::mlir::IntegerAttr</td><td>index attribute</td></tr><tr><td><code>root</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr></table><h4 id=operands-9>Operands: <a class=headline-hash href=#operands-9>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>input</code></td><td>non-0-ranked.tensor of any type values</td></tr><tr><td style=text-align:center><code>root_dynamic</code></td><td>variadic of index</td></tr></tbody></table><h4 id=results-12>Results: <a class=headline-hash href=#results-12>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>ranked tensor of any type values</td></tr></tbody></table><h3 id=meshsend-meshsendop><code>mesh.send</code> (mesh::SendOp) <a class=headline-hash href=#meshsend-meshsendop>¶</a></h3><p><em>Send over a device mesh.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.send` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? 
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination) attr-dict `:` functional-type(operands, results) </code></pre><p>Send from one device to another within a device group.</p><p>Interfaces: <code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><h4 id=attributes-14>Attributes: <a class=headline-hash href=#attributes-14>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>mesh_axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array attribute</td></tr><tr><td><code>destination</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr></table><h4 id=operands-10>Operands: <a class=headline-hash href=#operands-10>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>input</code></td><td>non-0-ranked.tensor of any type values</td></tr><tr><td style=text-align:center><code>destination_dynamic</code></td><td>variadic of index</td></tr></tbody></table><h4 id=results-13>Results: <a class=headline-hash href=#results-13>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>ranked tensor of any type values</td></tr></tbody></table><h3 id=meshshard-meshshardop><code>mesh.shard</code> (mesh::ShardOp) <a class=headline-hash href=#meshshard-meshshardop>¶</a></h3><p><em>Annotate how a tensor is sharded across a mesh.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.shard` $src `to` $sharding (`annotate_for_users` $annotate_for_users^)? attr-dict `:` type($result) </code></pre><p>The <code>mesh.shard</code> operation specifies and guides the sharding behavior of a tensor value across a mesh topology. 
This operation has two operands and one optional attribute:</p><ol><li><p><code>src</code>: This operand represents the tensor value that needs to be annotated for sharding.</p></li><li><p><code>sharding</code>: This operand is of type <code>MeshShardingType</code> (<code>!mesh.sharding</code>), the core data structure that represents the distribution of a tensor on a mesh. It is typically defined by a <code>mesh.sharding</code> operation.</p></li><li><p><code>annotate_for_users</code>: A unit attribute addressing the scenario in which a tensor’s sharding annotation differs based on its context of use (either as a result or as an operand). If specified, the sharding pertains to specific users of the tensor value, indicating how it should be considered when used as an operand in subsequent operations. If not, the sharding applies to the operation that defines the tensor value.</p></li></ol><p>Example:</p><pre tabindex=0><code>func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
  %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
  ...
}

func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
  %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
  ...
}

func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
  %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
  %1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
  ...
}

// The first mesh.shard op applies to %arg0, the second mesh.shard op
// applies to the operand of op0, and the third mesh.shard op applies to the
// operand of op1.
func.func @both_result_and_multi_operands_annotated(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
  %sharding1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
  %1 = mesh.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
  %sharding2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding
  %2 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
  "op0"(%1) : ...
  "op1"(%2) : ...
  ...
}
</code></pre><p>The following usages are undefined:</p><pre tabindex=0><code>func.func @annotate_on_same_result_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32>
  %1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32>
  ...
}

func.func @annotate_on_same_result_same_value_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32>
  %1 = mesh.shard %arg0 to %sharding2 : tensor<4x8xf32>
  ...
}

func.func @annotate_on_same_operand_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
  %1 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
  ...
}

func.func @result_annotated_after_operand(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
  %1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32>
  ...
}
</code></pre><p>Traits: <code>AlwaysSpeculatableImplTrait</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>InferTypeOpInterface</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, <code>OpAsmOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-15>Attributes: <a class=headline-hash href=#attributes-15>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>annotate_for_users</code></td><td>::mlir::UnitAttr</td><td>unit attribute</td></tr></table><h4 id=operands-11>Operands: <a class=headline-hash href=#operands-11>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>src</code></td><td>ranked tensor of any type values</td></tr><tr><td style=text-align:center><code>sharding</code></td><td>sharding definition</td></tr></tbody></table><h4 id=results-14>Results: <a class=headline-hash href=#results-14>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>ranked tensor of any type values</td></tr></tbody></table><h3 id=meshshard_shape-meshshardshapeop><code>mesh.shard_shape</code> (mesh::ShardShapeOp) <a class=headline-hash href=#meshshard_shape-meshshardshapeop>¶</a></h3><p><em>Get the shard shape of a given process/device.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.shard_shape` custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result) </code></pre><p>The device/process id is a
linearized id of the device/process in the mesh. This operation might be used during spmdization when the shard shape depends on (non-constant) values used in <code>mesh.sharding</code>.</p><p>Traits: <code>AlwaysSpeculatableImplTrait</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-16>Attributes: <a class=headline-hash href=#attributes-16>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>shape</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr></table><h4 id=operands-12>Operands: <a class=headline-hash href=#operands-12>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>sharding</code></td><td>sharding definition</td></tr><tr><td style=text-align:center><code>device</code></td><td>index</td></tr></tbody></table><h4 id=results-15>Results: <a class=headline-hash href=#results-15>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>variadic of index</td></tr></tbody></table><h3 id=meshsharding-meshshardingop><code>mesh.sharding</code> (mesh::ShardingOp) <a class=headline-hash href=#meshsharding-meshshardingop>¶</a></h3><p><em>Define a sharding of a tensor.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.sharding` $mesh `split_axes` `=` $split_axes (`partial` `=` $partial_type $partial_axes^)? (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)? (`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)? attr-dict `:` type($result) </code></pre><p>The MeshSharding specifies how a tensor is sharded and distributed across the process mesh. 
It is typically used in a <code>mesh.shard</code> operation. The operation has the following attributes and operands:</p><ol><li><p><code>mesh</code>: this attribute is a FlatSymbolRefAttr that refers to the device mesh where the distributed tensor is placed. The symbol must resolve to a <code>mesh.mesh</code> operation.</p></li><li><p><code>split_axes</code>: an array composed of int64_t sub-arrays. The outer array’s maximum size is the <code>rank</code> of the related tensor. For the i-th sub-array, a value of [x, y] indicates that the tensor’s i-th dimension is split along the x and y axes of the device mesh.</p></li><li><p>[Optional] <code>partial_axes</code>: if not empty, the tensor is partial along the specified mesh axes. An all-reduce should be applied to obtain the complete tensor, with the reduction type specified by <code>partial_type</code>.</p></li><li><p>[Optional] <code>partial_type</code>: indicates the reduction type of the possible all-reduce op. Its value is one of the reduction kinds listed in the attribute table below; <code>generic</code> is not an allowed value inside a sharding.</p></li><li><p>[Optional] Sizes of halos to be added for each sharded tensor dimension. <code>halo_sizes</code> is provided as a flattened 1d array of i64s, 2 values for each sharded dimension. <code>halo_sizes = [1, 2]</code> means that the first sharded dimension gets an additional halo of size 1 at its start and a halo of size 2 at its end. <code>halo_sizes = [1, 2, 2, 3]</code> defines halos for the first 2 sharded dimensions, i.e. the first sharded dimension gets <code>[1,2]</code> halos and the second gets <code>[2,3]</code> halos. <code>?</code> indicates dynamic halo sizes.</p></li><li><p>[Optional] Offsets for each shard and sharded tensor dimension. <code>sharded_dims_offsets</code> is provided as a flattened 1d array of i64s. 
For each sharded tensor dimension, the offsets (starting indices) of all shards in that dimension are provided, followed by one additional value for the end of the last shard. For a 1d sharding this means that position <code>i</code> holds the exclusive prefix sum for shard <code>i</code>, and since only contiguous sharding is supported, its inclusive prefix sum is at position <code>i+1</code>.</p></li></ol><p>Assuming a 3d-tensor of shape 32x32x32 with its first 2 dimensions sharded, <code>sharded_dims_offsets = [0, 24, 32, 0, 20, 32]</code> means that the first device of the device mesh gets a shard of shape 24x20x32 and the second device gets a shard of shape 8x12x32. <code>?</code> indicates dynamic shard dimensions.</p><p><code>halo_sizes</code> and <code>sharded_dims_offsets</code> are mutually exclusive.</p><p>Examples:</p><pre tabindex=0><code>mesh.mesh @mesh0(shape = 2x2x4)
mesh.mesh @mesh1d_4(shape = 4)

// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in axes, even
// if it's empty. Otherwise, a parsing error will occur.
%sharding0 = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding

// The tensor is sharded on the first dimension along axis 0 of @mesh0
%sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding

// The tensor is sharded on its first dimension along axis 0 of @mesh0 and
// it is also a partial_sum along mesh axis 1.
%sharding2 = mesh.sharding @mesh0 split_axes = [[0]] partial = sum[1] : !mesh.sharding

// The tensor is sharded on its first dimension along axis 0 of @mesh0 and
// it is also a partial_max along mesh axis 1.
%sharding3 = mesh.sharding @mesh0 split_axes = [[0]] partial = max[1] : !mesh.sharding

// Could be used for a mesh.shard op
%sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32>

// The tensor is sharded on its first dimension along axis 0 of @mesh0
// and it has halo sizes of 1 and 2 on the sharded dim.
%halo_sharding = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] : !mesh.sharding
%sharded1 = mesh.shard %arg0 to %halo_sharding : tensor<4x8xf32>

// The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
%sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14] : !mesh.sharding
%sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
</code></pre><p>Traits: <code>AlwaysSpeculatableImplTrait</code>, <code>AttrSizedOperandSegments</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>InferTypeOpInterface</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, <code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-17>Attributes: <a class=headline-hash href=#attributes-17>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>split_axes</code></td><td>::mlir::mesh::MeshAxesArrayAttr</td><td></td></tr><tr><td><code>partial_axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array attribute</td></tr><tr><td><code>partial_type</code></td><td>::mlir::mesh::ReductionKindAttr</td><td><details><summary>Reduction of an iterator/mesh dimension.</summary><p>Enum cases:</p><ul><li>sum (<code>Sum</code>)</li><li>max (<code>Max</code>)</li><li>min (<code>Min</code>)</li><li>product (<code>Product</code>)</li><li>average (<code>Average</code>)</li><li>bitwise_and (<code>BitwiseAnd</code>)</li><li>bitwise_or (<code>BitwiseOr</code>)</li><li>bitwise_xor (<code>BitwiseXor</code>)</li><li>generic (<code>Generic</code>)</li></ul></details></td></tr><tr><td><code>static_sharded_dims_offsets</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array 
attribute</td></tr><tr><td><code>static_halo_sizes</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr></table><h4 id=operands-13>Operands: <a class=headline-hash href=#operands-13>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>dynamic_sharded_dims_offsets</code></td><td>variadic of 64-bit signless integer</td></tr><tr><td style=text-align:center><code>dynamic_halo_sizes</code></td><td>variadic of 64-bit signless integer</td></tr></tbody></table><h4 id=results-16>Results: <a class=headline-hash href=#results-16>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>sharding definition</td></tr></tbody></table><h3 id=meshshift-meshshiftop><code>mesh.shift</code> (mesh::ShiftOp) <a class=headline-hash href=#meshshift-meshshiftop>¶</a></h3><p><em>Shift over a device mesh.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.shift` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `shift_axis` `=` $shift_axis `offset` `=` $offset (`rotate` $rotate^)? attr-dict `:` type($input) `->` type($result) </code></pre><p>Within each device group, the data is shifted along mesh axis <code>shift_axis</code> by <code>offset</code> positions. The result on devices that do not have a corresponding source is undefined. <code>shift_axis</code> must be one of <code>mesh_axes</code>. 
If the <code>rotate</code> attribute is present, a rotation is performed instead of a shift.</p><p>Example:</p><pre tabindex=0><code>mesh.mesh @mesh0(shape = 2x4)
%1 = mesh.shift %0 on @mesh0 mesh_axes = [1] shift_axis = 1 offset = 2 rotate
  : tensor<2xi8> -> tensor<2xi8>
</code></pre><p>Input:</p><pre tabindex=0><code>mesh axis 1
----------->

+----+----+----+----+
|  1 |  2 |  3 |  4 |
+----+----+----+----+
|  5 |  6 |  7 |  8 |
+----+----+----+----+
</code></pre><p>Result:</p><pre tabindex=0><code>+----+----+----+----+
|  3 |  4 |  1 |  2 |
+----+----+----+----+
|  7 |  8 |  5 |  6 |
+----+----+----+----+
</code></pre><p>Traits: <code>AlwaysSpeculatableImplTrait</code>, <code>SameOperandsAndResultElementType</code>, <code>SameOperandsAndResultShape</code></p><p>Interfaces: <code>ConditionallySpeculatable</code>, <code>NoMemoryEffect (MemoryEffectOpInterface)</code>, <code>OpAsmOpInterface</code>, <code>SymbolUserOpInterface</code></p><p>Effects: <code>MemoryEffects::Effect{}</code></p><h4 id=attributes-18>Attributes: <a class=headline-hash href=#attributes-18>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>mesh_axes</code></td><td>::mlir::DenseI16ArrayAttr</td><td>i16 dense array attribute</td></tr><tr><td><code>shift_axis</code></td><td>::mlir::IntegerAttr</td><td>index attribute</td></tr><tr><td><code>offset</code></td><td>::mlir::IntegerAttr</td><td>64-bit signless integer attribute</td></tr><tr><td><code>rotate</code></td><td>::mlir::UnitAttr</td><td>unit attribute</td></tr></table><h4 id=operands-14>Operands: <a class=headline-hash href=#operands-14>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>input</code></td><td>non-0-ranked.tensor of any type values</td></tr></tbody></table><h4 id=results-17>Results: <a class=headline-hash 
href=#results-17>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>ranked tensor of any type values</td></tr></tbody></table><h3 id=meshupdate_halo-meshupdatehaloop><code>mesh.update_halo</code> (mesh::UpdateHaloOp) <a class=headline-hash href=#meshupdate_halo-meshupdatehaloop>¶</a></h3><p><em>Update halo data.</em></p><p>Syntax:</p><pre tabindex=0><code>operation ::= `mesh.update_halo` $source `into` $destination `on` $mesh `split_axes` `=` $split_axes (`source_halo_sizes` `=` custom&lt;DynamicIndexList&gt;($source_halo_sizes, $static_source_halo_sizes)^)? (`destination_halo_sizes` `=` custom&lt;DynamicIndexList&gt;($destination_halo_sizes, $static_destination_halo_sizes)^)? attr-dict `:` type($source) `->` type($result) </code></pre><p>This operation updates the halo regions of shards. This is needed, for example, when the sharding specifies halos and the underlying tensor/memref data may have changed on the remote devices, either because of mutating operations or because the new halo regions are larger than the existing ones.</p><p>Source and destination may have different halo sizes.</p><p>The operation assumes that all devices hold tensors with the same halo sizes, as given by <code>source_halo_sizes</code>/<code>static_source_halo_sizes</code> for the source shard and by <code>destination_halo_sizes</code>/<code>static_destination_halo_sizes</code> for the destination/result shard.</p><p><code>split_axes</code> specifies, for each tensor axis, the mesh axes along which its halo data is updated.</p><p>Traits: <code>AttrSizedOperandSegments</code></p><p>Interfaces: <code>DestinationStyleOpInterface</code>, <code>SymbolUserOpInterface</code></p><h4 id=attributes-19>Attributes: <a class=headline-hash href=#attributes-19>¶</a></h4><table><tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr><tr><td><code>mesh</code></td><td>::mlir::FlatSymbolRefAttr</td><td>flat symbol reference attribute</td></tr><tr><td><code>split_axes</code></td><td>::mlir::mesh::MeshAxesArrayAttr</td><td></td></tr><tr><td><code>static_source_halo_sizes</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr><tr><td><code>static_destination_halo_sizes</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr></table><h4 id=operands-15>Operands: <a class=headline-hash href=#operands-15>¶</a></h4><table><thead><tr><th style=text-align:center>Operand</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>source</code></td><td>non-0-ranked.memref of any type values or non-0-ranked.tensor of any type values</td></tr><tr><td style=text-align:center><code>destination</code></td><td>non-0-ranked.memref of any type values or non-0-ranked.tensor of any type values</td></tr><tr><td style=text-align:center><code>source_halo_sizes</code></td><td>variadic of 64-bit signless integer</td></tr><tr><td style=text-align:center><code>destination_halo_sizes</code></td><td>variadic of 64-bit 
signless integer</td></tr></tbody></table><h4 id=results-18>Results: <a class=headline-hash href=#results-18>¶</a></h4><table><thead><tr><th style=text-align:center>Result</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center><code>result</code></td><td>non-0-ranked.memref of any type values or non-0-ranked.tensor of any type values</td></tr></tbody></table><h2 id=attributes-20>Attributes <a class=headline-hash href=#attributes-20>¶</a></h2><h3 id=meshaxesarrayattr>MeshAxesArrayAttr <a class=headline-hash href=#meshaxesarrayattr>¶</a></h3><p>Syntax:</p><pre tabindex=0><code>#mesh.axisarray&lt;
  ::llvm::ArrayRef&lt;MeshAxesAttr&gt;   # axes
&gt;
</code></pre><h4 id=parameters>Parameters: <a class=headline-hash href=#parameters>¶</a></h4><table><thead><tr><th style=text-align:center>Parameter</th><th style=text-align:center>C++ type</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center>axes</td><td style=text-align:center><code>::llvm::ArrayRef&lt;MeshAxesAttr&gt;</code></td><td></td></tr></tbody></table><h3 id=reductionkindattr>ReductionKindAttr <a class=headline-hash href=#reductionkindattr>¶</a></h3><p>Reduction of an iterator/mesh dimension.</p><p>Syntax:</p><pre tabindex=0><code>#mesh.partial&lt;
  ::mlir::mesh::ReductionKind   # value
&gt;
</code></pre><p>Enum cases:</p><ul><li>sum (<code>Sum</code>)</li><li>max (<code>Max</code>)</li><li>min (<code>Min</code>)</li><li>product (<code>Product</code>)</li><li>average (<code>Average</code>)</li><li>bitwise_and (<code>BitwiseAnd</code>)</li><li>bitwise_or (<code>BitwiseOr</code>)</li><li>bitwise_xor (<code>BitwiseXor</code>)</li><li>generic (<code>Generic</code>)</li></ul><h4 id=parameters-1>Parameters: <a class=headline-hash href=#parameters-1>¶</a></h4><table><thead><tr><th style=text-align:center>Parameter</th><th style=text-align:center>C++ type</th><th>Description</th></tr></thead><tbody><tr><td style=text-align:center>value</td><td 
style=text-align:center><code>::mlir::mesh::ReductionKind</code></td><td>an enum of type ReductionKind</td></tr></tbody></table><footer><p class=powered>Powered by <a href=https://gohugo.io>Hugo</a>. Theme by <a href=https://themes.gohugo.io/hugo-theme-techdoc/>TechDoc</a>. Designed by <a href=https://github.com/thingsym/hugo-theme-techdoc>Thingsym</a>.</p></footer></main></div></div></body></html>