jax/jax/experimental/jax2tf at main · jax-ml/jax · GitHub

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more.
class="HeaderMenu-dropdown-link d-block no-underline position-relative py-2 Link--secondary d-flex flex-items-center Link--has-description" data-analytics-event="{"location":"navbar","action":"premium_support","context":"enterprise","tag":"link","label":"premium_support_link_enterprise_navbar"}" href="/premium-support"> <svg aria-hidden="true" height="24" viewBox="0 0 24 24" version="1.1" width="24" data-view-component="true" class="octicon octicon-comment-discussion color-fg-subtle mr-3"> <path d="M1.75 1h12.5c.966 0 1.75.784 1.75 1.75v9.5A1.75 1.75 0 0 1 14.25 14H8.061l-2.574 2.573A1.458 1.458 0 0 1 3 15.543V14H1.75A1.75 1.75 0 0 1 0 12.25v-9.5C0 1.784.784 1 1.75 1ZM1.5 2.75v9.5c0 .138.112.25.25.25h2a.75.75 0 0 1 .75.75v2.19l2.72-2.72a.749.749 0 0 1 .53-.22h6.5a.25.25 0 0 0 .25-.25v-9.5a.25.25 0 0 0-.25-.25H1.75a.25.25 0 0 0-.25.25Z"></path><path d="M22.5 8.75a.25.25 0 0 0-.25-.25h-3.5a.75.75 0 0 1 0-1.5h3.5c.966 0 1.75.784 1.75 1.75v9.5A1.75 1.75 0 0 1 22.25 20H21v1.543a1.457 1.457 0 0 1-2.487 1.03L15.939 20H10.75A1.75 1.75 0 0 1 9 18.25v-1.465a.75.75 0 0 1 1.5 0v1.465c0 .138.112.25.25.25h5.5a.75.75 0 0 1 .53.22l2.72 2.72v-2.19a.75.75 0 0 1 .75-.75h2a.25.25 0 0 0 .25-.25v-9.5Z"></path> </svg> <div> <div class="color-fg-default h4">Premium Support</div> Enterprise-grade 24/7 support </div> </a></li> </ul> </div> </div> </div> </li> <li class="HeaderMenu-item position-relative flex-wrap flex-justify-between flex-items-center d-block d-lg-flex flex-lg-nowrap flex-lg-items-center js-details-container js-header-menu-item"> <a class="HeaderMenu-link no-underline px-0 px-lg-2 py-3 py-lg-2 d-block d-lg-inline-block" data-analytics-event="{"location":"navbar","action":"pricing","context":"global","tag":"link","label":"pricing_link_global_navbar"}" href="https://github.com/pricing">Pricing</a> </li> </ul> </nav> <div class="d-flex flex-column flex-lg-row width-full flex-justify-end flex-lg-items-center text-center mt-3 mt-lg-0 text-lg-left ml-lg-3"> <qbsearch-input class="search-input" data-scope="repo:jax-ml/jax" data-custom-scopes-path="/search/custom_scopes" data-delete-custom-scopes-csrf="vrR5dVN6SoxQshtsbXW8cBKre-70SDUyGFSIEFnxugoXG1QEIkR4cFvK6f43Oc0f1ovHVwKSWC6FrDdJ1RJfhw" data-max-custom-scopes="10" data-header-redesign-enabled="false" data-initial-value="" data-blackbird-suggestions-path="/search/suggestions" data-jump-to-suggestions-path="/_graphql/GetSuggestedNavigationDestinations" data-current-repository="jax-ml/jax" data-current-org="jax-ml" data-current-owner="" data-logged-in="false" data-copilot-chat-enabled="false" data-nl-search-enabled="false" data-retain-scroll-position="true"> <div class="search-input-container search-with-dialog position-relative d-flex flex-row flex-items-center mr-4 rounded" data-action="click:qbsearch-input#searchInputContainerClicked" > <button type="button" class="header-search-button placeholder input-button form-control d-flex flex-1 flex-self-stretch flex-items-center no-wrap width-full py-0 pl-2 pr-0 text-left border-0 box-shadow-none" data-target="qbsearch-input.inputButton" aria-label="Search or jump to…" aria-haspopup="dialog" placeholder="Search or jump to..." 
data-hotkey=s,/ autocapitalize="off" data-analytics-event="{"location":"navbar","action":"searchbar","context":"global","tag":"input","label":"searchbar_input_global_navbar"}" data-action="click:qbsearch-input#handleExpand" > <div class="mr-2 color-fg-muted"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-search"> <path d="M10.68 11.74a6 6 0 0 1-7.922-8.982 6 6 0 0 1 8.982 7.922l3.04 3.04a.749.749 0 0 1-.326 1.275.749.749 0 0 1-.734-.215ZM11.5 7a4.499 4.499 0 1 0-8.997 0A4.499 4.499 0 0 0 11.5 7Z"></path> </svg> </div> <span class="flex-1" data-target="qbsearch-input.inputButtonText">Search or jump to...</span> <div class="d-flex" data-target="qbsearch-input.hotkeyIndicator"> <svg xmlns="http://www.w3.org/2000/svg" width="22" height="20" aria-hidden="true" class="mr-1"><path fill="none" stroke="#979A9C" opacity=".4" d="M3.5.5h12c1.7 0 3 1.3 3 3v13c0 1.7-1.3 3-3 3h-12c-1.7 0-3-1.3-3-3v-13c0-1.7 1.3-3 3-3z"></path><path fill="#979A9C" d="M11.8 6L8 15.1h-.9L10.8 6h1z"></path></svg> </div> </button> <input type="hidden" name="type" class="js-site-search-type-field"> <div class="Overlay--hidden " data-modal-dialog-overlay> <modal-dialog data-action="close:qbsearch-input#handleClose cancel:qbsearch-input#handleClose" data-target="qbsearch-input.searchSuggestionsDialog" role="dialog" id="search-suggestions-dialog" aria-modal="true" aria-labelledby="search-suggestions-dialog-header" data-view-component="true" class="Overlay Overlay--width-large Overlay--height-auto"> <h1 id="search-suggestions-dialog-header" class="sr-only">Search code, repositories, users, issues, pull requests...</h1> <div class="Overlay-body Overlay-body--paddingNone"> <div data-view-component="true"> <div class="search-suggestions position-fixed width-full color-shadow-large border color-fg-default color-bg-default overflow-hidden d-flex flex-column query-builder-container" style="border-radius: 12px;" data-target="qbsearch-input.queryBuilderContainer" hidden > <!-- '"` --><!-- </textarea></xmp> --></option></form><form id="query-builder-test-form" action="" accept-charset="UTF-8" method="get"> <query-builder data-target="qbsearch-input.queryBuilder" id="query-builder-query-builder-test" data-filter-key=":" data-view-component="true" class="QueryBuilder search-query-builder"> <div class="FormControl FormControl--fullWidth"> <label id="query-builder-test-label" for="query-builder-test" class="FormControl-label sr-only"> Search </label> <div class="QueryBuilder-StyledInput width-fit " data-target="query-builder.styledInput" > <span id="query-builder-test-leadingvisual-wrap" class="FormControl-input-leadingVisualWrap QueryBuilder-leadingVisualWrap"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-search FormControl-input-leadingVisual"> <path d="M10.68 11.74a6 6 0 0 1-7.922-8.982 6 6 0 0 1 8.982 7.922l3.04 3.04a.749.749 0 0 1-.326 1.275.749.749 0 0 1-.734-.215ZM11.5 7a4.499 4.499 0 1 0-8.997 0A4.499 4.499 0 0 0 11.5 7Z"></path> </svg> </span> <div data-target="query-builder.styledInputContainer" class="QueryBuilder-StyledInputContainer"> <div aria-hidden="true" class="QueryBuilder-StyledInputContent" data-target="query-builder.styledInputContent" ></div> <div class="QueryBuilder-InputWrapper"> <div aria-hidden="true" class="QueryBuilder-Sizer" data-target="query-builder.sizer"></div> <input id="query-builder-test" name="query-builder-test" value="" 
autocomplete="off" type="text" role="combobox" spellcheck="false" aria-expanded="false" aria-describedby="validation-3757d968-61f4-4c81-a0b0-c92f6e5a4590" data-target="query-builder.input" data-action=" input:query-builder#inputChange blur:query-builder#inputBlur keydown:query-builder#inputKeydown focus:query-builder#inputFocus " data-view-component="true" class="FormControl-input QueryBuilder-Input FormControl-medium" /> </div> </div> <span class="sr-only" id="query-builder-test-clear">Clear</span> <button role="button" id="query-builder-test-clear-button" aria-labelledby="query-builder-test-clear query-builder-test-label" data-target="query-builder.clearButton" data-action=" click:query-builder#clear focus:query-builder#clearButtonFocus blur:query-builder#clearButtonBlur " variant="small" hidden="hidden" type="button" data-view-component="true" class="Button Button--iconOnly Button--invisible Button--medium mr-1 px-2 py-0 d-flex flex-items-center rounded-1 color-fg-muted"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-x-circle-fill Button-visual"> <path d="M2.343 13.657A8 8 0 1 1 13.658 2.343 8 8 0 0 1 2.343 13.657ZM6.03 4.97a.751.751 0 0 0-1.042.018.751.751 0 0 0-.018 1.042L6.94 8 4.97 9.97a.749.749 0 0 0 .326 1.275.749.749 0 0 0 .734-.215L8 9.06l1.97 1.97a.749.749 0 0 0 1.275-.326.749.749 0 0 0-.215-.734L9.06 8l1.97-1.97a.749.749 0 0 0-.326-1.275.749.749 0 0 0-.734.215L8 6.94Z"></path> </svg> </button> </div> <template id="search-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-search"> <path d="M10.68 11.74a6 6 0 0 1-7.922-8.982 6 6 0 0 1 8.982 7.922l3.04 3.04a.749.749 0 0 1-.326 1.275.749.749 0 0 1-.734-.215ZM11.5 7a4.499 4.499 0 1 0-8.997 0A4.499 4.499 0 0 0 11.5 7Z"></path> </svg> </template> <template id="code-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-code"> <path d="m11.28 3.22 4.25 4.25a.75.75 0 0 1 0 1.06l-4.25 4.25a.749.749 0 0 1-1.275-.326.749.749 0 0 1 .215-.734L13.94 8l-3.72-3.72a.749.749 0 0 1 .326-1.275.749.749 0 0 1 .734.215Zm-6.56 0a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042L2.06 8l3.72 3.72a.749.749 0 0 1-.326 1.275.749.749 0 0 1-.734-.215L.47 8.53a.75.75 0 0 1 0-1.06Z"></path> </svg> </template> <template id="file-code-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-file-code"> <path d="M4 1.75C4 .784 4.784 0 5.75 0h5.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v8.586A1.75 1.75 0 0 1 14.25 15h-9a.75.75 0 0 1 0-1.5h9a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 10 4.25V1.5H5.75a.25.25 0 0 0-.25.25v2.5a.75.75 0 0 1-1.5 0Zm1.72 4.97a.75.75 0 0 1 1.06 0l2 2a.75.75 0 0 1 0 1.06l-2 2a.749.749 0 0 1-1.275-.326.749.749 0 0 1 .215-.734l1.47-1.47-1.47-1.47a.75.75 0 0 1 0-1.06ZM3.28 7.78 1.81 9.25l1.47 1.47a.751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018l-2-2a.75.75 0 0 1 0-1.06l2-2a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042Zm8.22-6.218V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path> </svg> </template> <template id="history-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-history"> <path d="m.427 1.927 1.215 1.215a8.002 8.002 0 1 1-1.6 5.685.75.75 0 1 1 1.493-.154 6.5 6.5 0 
1 0 1.18-4.458l1.358 1.358A.25.25 0 0 1 3.896 6H.25A.25.25 0 0 1 0 5.75V2.104a.25.25 0 0 1 .427-.177ZM7.75 4a.75.75 0 0 1 .75.75v2.992l2.028.812a.75.75 0 0 1-.557 1.392l-2.5-1A.751.751 0 0 1 7 8.25v-3.5A.75.75 0 0 1 7.75 4Z"></path> </svg> </template> <template id="repo-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-repo"> <path d="M2 2.5A2.5 2.5 0 0 1 4.5 0h8.75a.75.75 0 0 1 .75.75v12.5a.75.75 0 0 1-.75.75h-2.5a.75.75 0 0 1 0-1.5h1.75v-2h-8a1 1 0 0 0-.714 1.7.75.75 0 1 1-1.072 1.05A2.495 2.495 0 0 1 2 11.5Zm10.5-1h-8a1 1 0 0 0-1 1v6.708A2.486 2.486 0 0 1 4.5 9h8ZM5 12.25a.25.25 0 0 1 .25-.25h3.5a.25.25 0 0 1 .25.25v3.25a.25.25 0 0 1-.4.2l-1.45-1.087a.249.249 0 0 0-.3 0L5.4 15.7a.25.25 0 0 1-.4-.2Z"></path> </svg> </template> <template id="bookmark-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-bookmark"> <path d="M3 2.75C3 1.784 3.784 1 4.75 1h6.5c.966 0 1.75.784 1.75 1.75v11.5a.75.75 0 0 1-1.227.579L8 11.722l-3.773 3.107A.751.751 0 0 1 3 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v9.91l3.023-2.489a.75.75 0 0 1 .954 0l3.023 2.49V2.75a.25.25 0 0 0-.25-.25Z"></path> </svg> </template> <template id="plus-circle-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-plus-circle"> <path d="M8 0a8 8 0 1 1 0 16A8 8 0 0 1 8 0ZM1.5 8a6.5 6.5 0 1 0 13 0 6.5 6.5 0 0 0-13 0Zm7.25-3.25v2.5h2.5a.75.75 0 0 1 0 1.5h-2.5v2.5a.75.75 0 0 1-1.5 0v-2.5h-2.5a.75.75 0 0 1 0-1.5h2.5v-2.5a.75.75 0 0 1 1.5 0Z"></path> </svg> </template> <template id="circle-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-dot-fill"> <path d="M8 4a4 4 0 1 1 0 8 4 4 0 0 1 0-8Z"></path> </svg> </template> <template id="trash-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-trash"> <path d="M11 1.75V3h2.25a.75.75 0 0 1 0 1.5H2.75a.75.75 0 0 1 0-1.5H5V1.75C5 .784 5.784 0 6.75 0h2.5C10.216 0 11 .784 11 1.75ZM4.496 6.675l.66 6.6a.25.25 0 0 0 .249.225h5.19a.25.25 0 0 0 .249-.225l.66-6.6a.75.75 0 0 1 1.492.149l-.66 6.6A1.748 1.748 0 0 1 10.595 15h-5.19a1.75 1.75 0 0 1-1.741-1.575l-.66-6.6a.75.75 0 1 1 1.492-.15ZM6.5 1.75V3h3V1.75a.25.25 0 0 0-.25-.25h-2.5a.25.25 0 0 0-.25.25Z"></path> </svg> </template> <template id="team-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-people"> <path d="M2 5.5a3.5 3.5 0 1 1 5.898 2.549 5.508 5.508 0 0 1 3.034 4.084.75.75 0 1 1-1.482.235 4 4 0 0 0-7.9 0 .75.75 0 0 1-1.482-.236A5.507 5.507 0 0 1 3.102 8.05 3.493 3.493 0 0 1 2 5.5ZM11 4a3.001 3.001 0 0 1 2.22 5.018 5.01 5.01 0 0 1 2.56 3.012.749.749 0 0 1-.885.954.752.752 0 0 1-.549-.514 3.507 3.507 0 0 0-2.522-2.372.75.75 0 0 1-.574-.73v-.352a.75.75 0 0 1 .416-.672A1.5 1.5 0 0 0 11 5.5.75.75 0 0 1 11 4Zm-5.5-.5a2 2 0 1 0-.001 3.999A2 2 0 0 0 5.5 3.5Z"></path> </svg> </template> <template id="project-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-project"> <path d="M1.75 0h12.5C15.216 0 16 .784 16 1.75v12.5A1.75 1.75 0 0 1 14.25 16H1.75A1.75 1.75 0 0 1 0 14.25V1.75C0 .784.784 0 1.75 0ZM1.5 1.75v12.5c0 .138.112.25.25.25h12.5a.25.25 0 0 0 
.25-.25V1.75a.25.25 0 0 0-.25-.25H1.75a.25.25 0 0 0-.25.25ZM11.75 3a.75.75 0 0 1 .75.75v7.5a.75.75 0 0 1-1.5 0v-7.5a.75.75 0 0 1 .75-.75Zm-8.25.75a.75.75 0 0 1 1.5 0v5.5a.75.75 0 0 1-1.5 0ZM8 3a.75.75 0 0 1 .75.75v3.5a.75.75 0 0 1-1.5 0v-3.5A.75.75 0 0 1 8 3Z"></path> </svg> </template> <template id="pencil-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-pencil"> <path d="M11.013 1.427a1.75 1.75 0 0 1 2.474 0l1.086 1.086a1.75 1.75 0 0 1 0 2.474l-8.61 8.61c-.21.21-.47.364-.756.445l-3.251.93a.75.75 0 0 1-.927-.928l.929-3.25c.081-.286.235-.547.445-.758l8.61-8.61Zm.176 4.823L9.75 4.81l-6.286 6.287a.253.253 0 0 0-.064.108l-.558 1.953 1.953-.558a.253.253 0 0 0 .108-.064Zm1.238-3.763a.25.25 0 0 0-.354 0L10.811 3.75l1.439 1.44 1.263-1.263a.25.25 0 0 0 0-.354Z"></path> </svg> </template> <template id="copilot-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-copilot"> <path d="M7.998 15.035c-4.562 0-7.873-2.914-7.998-3.749V9.338c.085-.628.677-1.686 1.588-2.065.013-.07.024-.143.036-.218.029-.183.06-.384.126-.612-.201-.508-.254-1.084-.254-1.656 0-.87.128-1.769.693-2.484.579-.733 1.494-1.124 2.724-1.261 1.206-.134 2.262.034 2.944.765.05.053.096.108.139.165.044-.057.094-.112.143-.165.682-.731 1.738-.899 2.944-.765 1.23.137 2.145.528 2.724 1.261.566.715.693 1.614.693 2.484 0 .572-.053 1.148-.254 1.656.066.228.098.429.126.612.012.076.024.148.037.218.924.385 1.522 1.471 1.591 2.095v1.872c0 .766-3.351 3.795-8.002 3.795Zm0-1.485c2.28 0 4.584-1.11 5.002-1.433V7.862l-.023-.116c-.49.21-1.075.291-1.727.291-1.146 0-2.059-.327-2.71-.991A3.222 3.222 0 0 1 8 6.303a3.24 3.24 0 0 1-.544.743c-.65.664-1.563.991-2.71.991-.652 0-1.236-.081-1.727-.291l-.023.116v4.255c.419.323 2.722 1.433 5.002 1.433ZM6.762 2.83c-.193-.206-.637-.413-1.682-.297-1.019.113-1.479.404-1.713.7-.247.312-.369.789-.369 1.554 0 .793.129 1.171.308 1.371.162.181.519.379 1.442.379.853 0 1.339-.235 1.638-.54.315-.322.527-.827.617-1.553.117-.935-.037-1.395-.241-1.614Zm4.155-.297c-1.044-.116-1.488.091-1.681.297-.204.219-.359.679-.242 1.614.091.726.303 1.231.618 1.553.299.305.784.54 1.638.54.922 0 1.28-.198 1.442-.379.179-.2.308-.578.308-1.371 0-.765-.123-1.242-.37-1.554-.233-.296-.693-.587-1.713-.7Z"></path><path d="M6.25 9.037a.75.75 0 0 1 .75.75v1.501a.75.75 0 0 1-1.5 0V9.787a.75.75 0 0 1 .75-.75Zm4.25.75v1.501a.75.75 0 0 1-1.5 0V9.787a.75.75 0 0 1 1.5 0Z"></path> </svg> </template> <template id="copilot-error-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-copilot-error"> <path d="M16 11.24c0 .112-.072.274-.21.467L13 9.688V7.862l-.023-.116c-.49.21-1.075.291-1.727.291-.198 0-.388-.009-.571-.029L6.833 5.226a4.01 4.01 0 0 0 .17-.782c.117-.935-.037-1.395-.241-1.614-.193-.206-.637-.413-1.682-.297-.683.076-1.115.231-1.395.415l-1.257-.91c.579-.564 1.413-.877 2.485-.996 1.206-.134 2.262.034 2.944.765.05.053.096.108.139.165.044-.057.094-.112.143-.165.682-.731 1.738-.899 2.944-.765 1.23.137 2.145.528 2.724 1.261.566.715.693 1.614.693 2.484 0 .572-.053 1.148-.254 1.656.066.228.098.429.126.612.012.076.024.148.037.218.924.385 1.522 1.471 1.591 2.095Zm-5.083-8.707c-1.044-.116-1.488.091-1.681.297-.204.219-.359.679-.242 1.614.091.726.303 1.231.618 1.553.299.305.784.54 1.638.54.922 0 1.28-.198 1.442-.379.179-.2.308-.578.308-1.371 0-.765-.123-1.242-.37-1.554-.233-.296-.693-.587-1.713-.7Zm2.511 
11.074c-1.393.776-3.272 1.428-5.43 1.428-4.562 0-7.873-2.914-7.998-3.749V9.338c.085-.628.677-1.686 1.588-2.065.013-.07.024-.143.036-.218.029-.183.06-.384.126-.612-.18-.455-.241-.963-.252-1.475L.31 4.107A.747.747 0 0 1 0 3.509V3.49a.748.748 0 0 1 .625-.73c.156-.026.306.047.435.139l14.667 10.578a.592.592 0 0 1 .227.264.752.752 0 0 1 .046.249v.022a.75.75 0 0 1-1.19.596Zm-1.367-.991L5.635 7.964a5.128 5.128 0 0 1-.889.073c-.652 0-1.236-.081-1.727-.291l-.023.116v4.255c.419.323 2.722 1.433 5.002 1.433 1.539 0 3.089-.505 4.063-.934Z"></path> </svg> </template> <template id="workflow-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-workflow"> <path d="M0 1.75C0 .784.784 0 1.75 0h3.5C6.216 0 7 .784 7 1.75v3.5A1.75 1.75 0 0 1 5.25 7H4v4a1 1 0 0 0 1 1h4v-1.25C9 9.784 9.784 9 10.75 9h3.5c.966 0 1.75.784 1.75 1.75v3.5A1.75 1.75 0 0 1 14.25 16h-3.5A1.75 1.75 0 0 1 9 14.25v-.75H5A2.5 2.5 0 0 1 2.5 11V7h-.75A1.75 1.75 0 0 1 0 5.25Zm1.75-.25a.25.25 0 0 0-.25.25v3.5c0 .138.112.25.25.25h3.5a.25.25 0 0 0 .25-.25v-3.5a.25.25 0 0 0-.25-.25Zm9 9a.25.25 0 0 0-.25.25v3.5c0 .138.112.25.25.25h3.5a.25.25 0 0 0 .25-.25v-3.5a.25.25 0 0 0-.25-.25Z"></path> </svg> </template> <template id="book-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-book"> <path d="M0 1.75A.75.75 0 0 1 .75 1h4.253c1.227 0 2.317.59 3 1.501A3.743 3.743 0 0 1 11.006 1h4.245a.75.75 0 0 1 .75.75v10.5a.75.75 0 0 1-.75.75h-4.507a2.25 2.25 0 0 0-1.591.659l-.622.621a.75.75 0 0 1-1.06 0l-.622-.621A2.25 2.25 0 0 0 5.258 13H.75a.75.75 0 0 1-.75-.75Zm7.251 10.324.004-5.073-.002-2.253A2.25 2.25 0 0 0 5.003 2.5H1.5v9h3.757a3.75 3.75 0 0 1 1.994.574ZM8.755 4.75l-.004 7.322a3.752 3.752 0 0 1 1.992-.572H14.5v-9h-3.495a2.25 2.25 0 0 0-2.25 2.25Z"></path> </svg> </template> <template id="code-review-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-code-review"> <path d="M1.75 1h12.5c.966 0 1.75.784 1.75 1.75v8.5A1.75 1.75 0 0 1 14.25 13H8.061l-2.574 2.573A1.458 1.458 0 0 1 3 14.543V13H1.75A1.75 1.75 0 0 1 0 11.25v-8.5C0 1.784.784 1 1.75 1ZM1.5 2.75v8.5c0 .138.112.25.25.25h2a.75.75 0 0 1 .75.75v2.19l2.72-2.72a.749.749 0 0 1 .53-.22h6.5a.25.25 0 0 0 .25-.25v-8.5a.25.25 0 0 0-.25-.25H1.75a.25.25 0 0 0-.25.25Zm5.28 1.72a.75.75 0 0 1 0 1.06L5.31 7l1.47 1.47a.751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018l-2-2a.75.75 0 0 1 0-1.06l2-2a.75.75 0 0 1 1.06 0Zm2.44 0a.75.75 0 0 1 1.06 0l2 2a.75.75 0 0 1 0 1.06l-2 2a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042L10.69 7 9.22 5.53a.75.75 0 0 1 0-1.06Z"></path> </svg> </template> <template id="codespaces-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-codespaces"> <path d="M0 11.25c0-.966.784-1.75 1.75-1.75h12.5c.966 0 1.75.784 1.75 1.75v3A1.75 1.75 0 0 1 14.25 16H1.75A1.75 1.75 0 0 1 0 14.25Zm2-9.5C2 .784 2.784 0 3.75 0h8.5C13.216 0 14 .784 14 1.75v5a1.75 1.75 0 0 1-1.75 1.75h-8.5A1.75 1.75 0 0 1 2 6.75Zm1.75-.25a.25.25 0 0 0-.25.25v5c0 .138.112.25.25.25h8.5a.25.25 0 0 0 .25-.25v-5a.25.25 0 0 0-.25-.25Zm-2 9.5a.25.25 0 0 0-.25.25v3c0 .138.112.25.25.25h12.5a.25.25 0 0 0 .25-.25v-3a.25.25 0 0 0-.25-.25Z"></path><path d="M7 12.75a.75.75 0 0 1 .75-.75h4.5a.75.75 0 0 1 0 1.5h-4.5a.75.75 0 0 1-.75-.75Zm-4 0a.75.75 0 0 1 .75-.75h.5a.75.75 0 0 1 0 1.5h-.5a.75.75 0 0 
1-.75-.75Z"></path> </svg> </template> <template id="comment-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-comment"> <path d="M1 2.75C1 1.784 1.784 1 2.75 1h10.5c.966 0 1.75.784 1.75 1.75v7.5A1.75 1.75 0 0 1 13.25 12H9.06l-2.573 2.573A1.458 1.458 0 0 1 4 13.543V12H2.75A1.75 1.75 0 0 1 1 10.25Zm1.75-.25a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h2a.75.75 0 0 1 .75.75v2.19l2.72-2.72a.749.749 0 0 1 .53-.22h4.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25Z"></path> </svg> </template> <template id="comment-discussion-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-comment-discussion"> <path d="M1.75 1h8.5c.966 0 1.75.784 1.75 1.75v5.5A1.75 1.75 0 0 1 10.25 10H7.061l-2.574 2.573A1.458 1.458 0 0 1 2 11.543V10h-.25A1.75 1.75 0 0 1 0 8.25v-5.5C0 1.784.784 1 1.75 1ZM1.5 2.75v5.5c0 .138.112.25.25.25h1a.75.75 0 0 1 .75.75v2.19l2.72-2.72a.749.749 0 0 1 .53-.22h3.5a.25.25 0 0 0 .25-.25v-5.5a.25.25 0 0 0-.25-.25h-8.5a.25.25 0 0 0-.25.25Zm13 2a.25.25 0 0 0-.25-.25h-.5a.75.75 0 0 1 0-1.5h.5c.966 0 1.75.784 1.75 1.75v5.5A1.75 1.75 0 0 1 14.25 12H14v1.543a1.458 1.458 0 0 1-2.487 1.03L9.22 12.28a.749.749 0 0 1 .326-1.275.749.749 0 0 1 .734.215l2.22 2.22v-2.19a.75.75 0 0 1 .75-.75h1a.25.25 0 0 0 .25-.25Z"></path> </svg> </template> <template id="organization-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-organization"> <path d="M1.75 16A1.75 1.75 0 0 1 0 14.25V1.75C0 .784.784 0 1.75 0h8.5C11.216 0 12 .784 12 1.75v12.5c0 .085-.006.168-.018.25h2.268a.25.25 0 0 0 .25-.25V8.285a.25.25 0 0 0-.111-.208l-1.055-.703a.749.749 0 1 1 .832-1.248l1.055.703c.487.325.779.871.779 1.456v5.965A1.75 1.75 0 0 1 14.25 16h-3.5a.766.766 0 0 1-.197-.026c-.099.017-.2.026-.303.026h-3a.75.75 0 0 1-.75-.75V14h-1v1.25a.75.75 0 0 1-.75.75Zm-.25-1.75c0 .138.112.25.25.25H4v-1.25a.75.75 0 0 1 .75-.75h2.5a.75.75 0 0 1 .75.75v1.25h2.25a.25.25 0 0 0 .25-.25V1.75a.25.25 0 0 0-.25-.25h-8.5a.25.25 0 0 0-.25.25ZM3.75 6h.5a.75.75 0 0 1 0 1.5h-.5a.75.75 0 0 1 0-1.5ZM3 3.75A.75.75 0 0 1 3.75 3h.5a.75.75 0 0 1 0 1.5h-.5A.75.75 0 0 1 3 3.75Zm4 3A.75.75 0 0 1 7.75 6h.5a.75.75 0 0 1 0 1.5h-.5A.75.75 0 0 1 7 6.75ZM7.75 3h.5a.75.75 0 0 1 0 1.5h-.5a.75.75 0 0 1 0-1.5ZM3 9.75A.75.75 0 0 1 3.75 9h.5a.75.75 0 0 1 0 1.5h-.5A.75.75 0 0 1 3 9.75ZM7.75 9h.5a.75.75 0 0 1 0 1.5h-.5a.75.75 0 0 1 0-1.5Z"></path> </svg> </template> <template id="rocket-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-rocket"> <path d="M14.064 0h.186C15.216 0 16 .784 16 1.75v.186a8.752 8.752 0 0 1-2.564 6.186l-.458.459c-.314.314-.641.616-.979.904v3.207c0 .608-.315 1.172-.833 1.49l-2.774 1.707a.749.749 0 0 1-1.11-.418l-.954-3.102a1.214 1.214 0 0 1-.145-.125L3.754 9.816a1.218 1.218 0 0 1-.124-.145L.528 8.717a.749.749 0 0 1-.418-1.11l1.71-2.774A1.748 1.748 0 0 1 3.31 4h3.204c.288-.338.59-.665.904-.979l.459-.458A8.749 8.749 0 0 1 14.064 0ZM8.938 3.623h-.002l-.458.458c-.76.76-1.437 1.598-2.02 2.5l-1.5 2.317 2.143 2.143 2.317-1.5c.902-.583 1.74-1.26 2.499-2.02l.459-.458a7.25 7.25 0 0 0 2.123-5.127V1.75a.25.25 0 0 0-.25-.25h-.186a7.249 7.249 0 0 0-5.125 2.123ZM3.56 14.56c-.732.732-2.334 1.045-3.005 1.148a.234.234 0 0 1-.201-.064.234.234 0 0 1-.064-.201c.103-.671.416-2.273 1.15-3.003a1.502 1.502 0 1 1 2.12 
2.12Zm6.94-3.935c-.088.06-.177.118-.266.175l-2.35 1.521.548 1.783 1.949-1.2a.25.25 0 0 0 .119-.213ZM3.678 8.116 5.2 5.766c.058-.09.117-.178.176-.266H3.309a.25.25 0 0 0-.213.119l-1.2 1.95ZM12 5a1 1 0 1 1-2 0 1 1 0 0 1 2 0Z"></path> </svg> </template> <template id="shield-check-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-shield-check"> <path d="m8.533.133 5.25 1.68A1.75 1.75 0 0 1 15 3.48V7c0 1.566-.32 3.182-1.303 4.682-.983 1.498-2.585 2.813-5.032 3.855a1.697 1.697 0 0 1-1.33 0c-2.447-1.042-4.049-2.357-5.032-3.855C1.32 10.182 1 8.566 1 7V3.48a1.75 1.75 0 0 1 1.217-1.667l5.25-1.68a1.748 1.748 0 0 1 1.066 0Zm-.61 1.429.001.001-5.25 1.68a.251.251 0 0 0-.174.237V7c0 1.36.275 2.666 1.057 3.859.784 1.194 2.121 2.342 4.366 3.298a.196.196 0 0 0 .154 0c2.245-.957 3.582-2.103 4.366-3.297C13.225 9.666 13.5 8.358 13.5 7V3.48a.25.25 0 0 0-.174-.238l-5.25-1.68a.25.25 0 0 0-.153 0ZM11.28 6.28l-3.5 3.5a.75.75 0 0 1-1.06 0l-1.5-1.5a.749.749 0 0 1 .326-1.275.749.749 0 0 1 .734.215l.97.97 2.97-2.97a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042Z"></path> </svg> </template> <template id="heart-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-heart"> <path d="m8 14.25.345.666a.75.75 0 0 1-.69 0l-.008-.004-.018-.01a7.152 7.152 0 0 1-.31-.17 22.055 22.055 0 0 1-3.434-2.414C2.045 10.731 0 8.35 0 5.5 0 2.836 2.086 1 4.25 1 5.797 1 7.153 1.802 8 3.02 8.847 1.802 10.203 1 11.75 1 13.914 1 16 2.836 16 5.5c0 2.85-2.045 5.231-3.885 6.818a22.066 22.066 0 0 1-3.744 2.584l-.018.01-.006.003h-.002ZM4.25 2.5c-1.336 0-2.75 1.164-2.75 3 0 2.15 1.58 4.144 3.365 5.682A20.58 20.58 0 0 0 8 13.393a20.58 20.58 0 0 0 3.135-2.211C12.92 9.644 14.5 7.65 14.5 5.5c0-1.836-1.414-3-2.75-3-1.373 0-2.609.986-3.029 2.456a.749.749 0 0 1-1.442 0C6.859 3.486 5.623 2.5 4.25 2.5Z"></path> </svg> </template> <template id="server-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-server"> <path d="M1.75 1h12.5c.966 0 1.75.784 1.75 1.75v4c0 .372-.116.717-.314 1 .198.283.314.628.314 1v4a1.75 1.75 0 0 1-1.75 1.75H1.75A1.75 1.75 0 0 1 0 12.75v-4c0-.358.109-.707.314-1a1.739 1.739 0 0 1-.314-1v-4C0 1.784.784 1 1.75 1ZM1.5 2.75v4c0 .138.112.25.25.25h12.5a.25.25 0 0 0 .25-.25v-4a.25.25 0 0 0-.25-.25H1.75a.25.25 0 0 0-.25.25Zm.25 5.75a.25.25 0 0 0-.25.25v4c0 .138.112.25.25.25h12.5a.25.25 0 0 0 .25-.25v-4a.25.25 0 0 0-.25-.25ZM7 4.75A.75.75 0 0 1 7.75 4h4.5a.75.75 0 0 1 0 1.5h-4.5A.75.75 0 0 1 7 4.75ZM7.75 10h4.5a.75.75 0 0 1 0 1.5h-4.5a.75.75 0 0 1 0-1.5ZM3 4.75A.75.75 0 0 1 3.75 4h.5a.75.75 0 0 1 0 1.5h-.5A.75.75 0 0 1 3 4.75ZM3.75 10h.5a.75.75 0 0 1 0 1.5h-.5a.75.75 0 0 1 0-1.5Z"></path> </svg> </template> <template id="globe-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-globe"> <path d="M8 0a8 8 0 1 1 0 16A8 8 0 0 1 8 0ZM5.78 8.75a9.64 9.64 0 0 0 1.363 4.177c.255.426.542.832.857 1.215.245-.296.551-.705.857-1.215A9.64 9.64 0 0 0 10.22 8.75Zm4.44-1.5a9.64 9.64 0 0 0-1.363-4.177c-.307-.51-.612-.919-.857-1.215a9.927 9.927 0 0 0-.857 1.215A9.64 9.64 0 0 0 5.78 7.25Zm-5.944 1.5H1.543a6.507 6.507 0 0 0 4.666 5.5c-.123-.181-.24-.365-.352-.552-.715-1.192-1.437-2.874-1.581-4.948Zm-2.733-1.5h2.733c.144-2.074.866-3.756 1.58-4.948.12-.197.237-.381.353-.552a6.507 6.507 0 0 0-4.666 5.5Zm10.181 1.5c-.144 2.074-.866 
3.756-1.58 4.948-.12.197-.237.381-.353.552a6.507 6.507 0 0 0 4.666-5.5Zm2.733-1.5a6.507 6.507 0 0 0-4.666-5.5c.123.181.24.365.353.552.714 1.192 1.436 2.874 1.58 4.948Z"></path> </svg> </template> <template id="issue-opened-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-issue-opened"> <path d="M8 9.5a1.5 1.5 0 1 0 0-3 1.5 1.5 0 0 0 0 3Z"></path><path d="M8 0a8 8 0 1 1 0 16A8 8 0 0 1 8 0ZM1.5 8a6.5 6.5 0 1 0 13 0 6.5 6.5 0 0 0-13 0Z"></path> </svg> </template> <template id="device-mobile-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-device-mobile"> <path d="M3.75 0h8.5C13.216 0 14 .784 14 1.75v12.5A1.75 1.75 0 0 1 12.25 16h-8.5A1.75 1.75 0 0 1 2 14.25V1.75C2 .784 2.784 0 3.75 0ZM3.5 1.75v12.5c0 .138.112.25.25.25h8.5a.25.25 0 0 0 .25-.25V1.75a.25.25 0 0 0-.25-.25h-8.5a.25.25 0 0 0-.25.25ZM8 13a1 1 0 1 1 0-2 1 1 0 0 1 0 2Z"></path> </svg> </template> <template id="package-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-package"> <path d="m8.878.392 5.25 3.045c.54.314.872.89.872 1.514v6.098a1.75 1.75 0 0 1-.872 1.514l-5.25 3.045a1.75 1.75 0 0 1-1.756 0l-5.25-3.045A1.75 1.75 0 0 1 1 11.049V4.951c0-.624.332-1.201.872-1.514L7.122.392a1.75 1.75 0 0 1 1.756 0ZM7.875 1.69l-4.63 2.685L8 7.133l4.755-2.758-4.63-2.685a.248.248 0 0 0-.25 0ZM2.5 5.677v5.372c0 .09.047.171.125.216l4.625 2.683V8.432Zm6.25 8.271 4.625-2.683a.25.25 0 0 0 .125-.216V5.677L8.75 8.432Z"></path> </svg> </template> <template id="credit-card-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-credit-card"> <path d="M10.75 9a.75.75 0 0 0 0 1.5h1.5a.75.75 0 0 0 0-1.5h-1.5Z"></path><path d="M0 3.75C0 2.784.784 2 1.75 2h12.5c.966 0 1.75.784 1.75 1.75v8.5A1.75 1.75 0 0 1 14.25 14H1.75A1.75 1.75 0 0 1 0 12.25ZM14.5 6.5h-13v5.75c0 .138.112.25.25.25h12.5a.25.25 0 0 0 .25-.25Zm0-2.75a.25.25 0 0 0-.25-.25H1.75a.25.25 0 0 0-.25.25V5h13Z"></path> </svg> </template> <template id="play-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-play"> <path d="M8 0a8 8 0 1 1 0 16A8 8 0 0 1 8 0ZM1.5 8a6.5 6.5 0 1 0 13 0 6.5 6.5 0 0 0-13 0Zm4.879-2.773 4.264 2.559a.25.25 0 0 1 0 .428l-4.264 2.559A.25.25 0 0 1 6 10.559V5.442a.25.25 0 0 1 .379-.215Z"></path> </svg> </template> <template id="gift-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-gift"> <path d="M2 2.75A2.75 2.75 0 0 1 4.75 0c.983 0 1.873.42 2.57 1.232.268.318.497.668.68 1.042.183-.375.411-.725.68-1.044C9.376.42 10.266 0 11.25 0a2.75 2.75 0 0 1 2.45 4h.55c.966 0 1.75.784 1.75 1.75v2c0 .698-.409 1.301-1 1.582v4.918A1.75 1.75 0 0 1 13.25 16H2.75A1.75 1.75 0 0 1 1 14.25V9.332C.409 9.05 0 8.448 0 7.75v-2C0 4.784.784 4 1.75 4h.55c-.192-.375-.3-.8-.3-1.25ZM7.25 9.5H2.5v4.75c0 .138.112.25.25.25h4.5Zm1.5 0v5h4.5a.25.25 0 0 0 .25-.25V9.5Zm0-4V8h5.5a.25.25 0 0 0 .25-.25v-2a.25.25 0 0 0-.25-.25Zm-7 0a.25.25 0 0 0-.25.25v2c0 .138.112.25.25.25h5.5V5.5h-5.5Zm3-4a1.25 1.25 0 0 0 0 2.5h2.309c-.233-.818-.542-1.401-.878-1.793-.43-.502-.915-.707-1.431-.707ZM8.941 4h2.309a1.25 1.25 0 0 0 0-2.5c-.516 0-1 .205-1.43.707-.337.392-.646.975-.879 1.793Z"></path> </svg> </template> <template 
id="code-square-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-code-square"> <path d="M0 1.75C0 .784.784 0 1.75 0h12.5C15.216 0 16 .784 16 1.75v12.5A1.75 1.75 0 0 1 14.25 16H1.75A1.75 1.75 0 0 1 0 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h12.5a.25.25 0 0 0 .25-.25V1.75a.25.25 0 0 0-.25-.25Zm7.47 3.97a.75.75 0 0 1 1.06 0l2 2a.75.75 0 0 1 0 1.06l-2 2a.749.749 0 0 1-1.275-.326.749.749 0 0 1 .215-.734L10.69 8 9.22 6.53a.75.75 0 0 1 0-1.06ZM6.78 6.53 5.31 8l1.47 1.47a.749.749 0 0 1-.326 1.275.749.749 0 0 1-.734-.215l-2-2a.75.75 0 0 1 0-1.06l2-2a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042Z"></path> </svg> </template> <template id="device-desktop-icon"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-device-desktop"> <path d="M14.25 1c.966 0 1.75.784 1.75 1.75v7.5A1.75 1.75 0 0 1 14.25 12h-3.727c.099 1.041.52 1.872 1.292 2.757A.752.752 0 0 1 11.25 16h-6.5a.75.75 0 0 1-.565-1.243c.772-.885 1.192-1.716 1.292-2.757H1.75A1.75 1.75 0 0 1 0 10.25v-7.5C0 1.784.784 1 1.75 1ZM1.75 2.5a.25.25 0 0 0-.25.25v7.5c0 .138.112.25.25.25h12.5a.25.25 0 0 0 .25-.25v-7.5a.25.25 0 0 0-.25-.25ZM9.018 12H6.982a5.72 5.72 0 0 1-.765 2.5h3.566a5.72 5.72 0 0 1-.765-2.5Z"></path> </svg> </template> <div class="position-relative"> <ul role="listbox" class="ActionListWrap QueryBuilder-ListWrap" aria-label="Suggestions" data-action=" combobox-commit:query-builder#comboboxCommit mousedown:query-builder#resultsMousedown " data-target="query-builder.resultsList" data-persist-list=false id="query-builder-test-results" ></ul> </div> <div class="FormControl-inlineValidation" id="validation-3757d968-61f4-4c81-a0b0-c92f6e5a4590" hidden="hidden"> <span class="FormControl-inlineValidation--visual"> <svg aria-hidden="true" height="12" viewBox="0 0 12 12" version="1.1" width="12" data-view-component="true" class="octicon octicon-alert-fill"> <path d="M4.855.708c.5-.896 1.79-.896 2.29 0l4.675 8.351a1.312 1.312 0 0 1-1.146 1.954H1.33A1.313 1.313 0 0 1 .183 9.058ZM7 7V3H5v4Zm-1 3a1 1 0 1 0 0-2 1 1 0 0 0 0 2Z"></path> </svg> </span> <span></span> </div> </div> <div data-target="query-builder.screenReaderFeedback" aria-live="polite" aria-atomic="true" class="sr-only"></div> </query-builder></form> <div class="d-flex flex-row color-fg-muted px-3 text-small color-bg-default search-feedback-prompt"> <a target="_blank" href="https://docs.github.com/search-github/github-code-search/understanding-github-code-search-syntax" data-view-component="true" class="Link color-fg-accent text-normal ml-2">Search syntax tips</a> <div class="d-flex flex-1"></div> </div> </div> </div> </div> </modal-dialog></div> </div> <div data-action="click:qbsearch-input#retract" class="dark-backdrop position-fixed" hidden data-target="qbsearch-input.darkBackdrop"></div> <div class="color-fg-default"> <dialog-helper> <dialog data-target="qbsearch-input.feedbackDialog" data-action="close:qbsearch-input#handleDialogClose cancel:qbsearch-input#handleDialogClose" id="feedback-dialog" aria-modal="true" aria-labelledby="feedback-dialog-title" aria-describedby="feedback-dialog-description" data-view-component="true" class="Overlay Overlay-whenNarrow Overlay--size-medium Overlay--motion-scaleFade Overlay--disableScroll"> <div data-view-component="true" class="Overlay-header"> <div class="Overlay-headerContentWrap"> <div class="Overlay-titleWrap"> <h1 class="Overlay-title " 
id="feedback-dialog-title"> Provide feedback </h1> </div> <div class="Overlay-actionWrap"> <button data-close-dialog-id="feedback-dialog" aria-label="Close" type="button" data-view-component="true" class="close-button Overlay-closeButton"><svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-x"> <path d="M3.72 3.72a.75.75 0 0 1 1.06 0L8 6.94l3.22-3.22a.749.749 0 0 1 1.275.326.749.749 0 0 1-.215.734L9.06 8l3.22 3.22a.749.749 0 0 1-.326 1.275.749.749 0 0 1-.734-.215L8 9.06l-3.22 3.22a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042L6.94 8 3.72 4.78a.75.75 0 0 1 0-1.06Z"></path> </svg></button> </div> </div> </div> <scrollable-region data-labelled-by="feedback-dialog-title"> <div data-view-component="true" class="Overlay-body"> <!-- '"` --><!-- </textarea></xmp> --></option></form><form id="code-search-feedback-form" data-turbo="false" action="/search/feedback" accept-charset="UTF-8" method="post"><input type="hidden" data-csrf="true" name="authenticity_token" value="cIEsBRm5g/kkuMKv1Z2zNig/hxvoPHirmsH5M0g0qCDpeXdDYm24DOOZaRcGqhCY0m4/T69sQ6XqCJcQqlpQUg==" /> <p>We read every piece of feedback, and take your input very seriously.</p> <textarea name="feedback" class="form-control width-full mb-2" style="height: 120px" id="feedback"></textarea> <input name="include_email" id="include_email" aria-label="Include my email address so I can be contacted" class="form-control mr-2" type="checkbox"> <label for="include_email" style="font-weight: normal">Include my email address so I can be contacted</label> </form></div> </scrollable-region> <div data-view-component="true" class="Overlay-footer Overlay-footer--alignEnd"> <button data-close-dialog-id="feedback-dialog" type="button" data-view-component="true" class="btn"> Cancel </button> <button form="code-search-feedback-form" data-action="click:qbsearch-input#submitFeedback" type="submit" data-view-component="true" class="btn-primary btn"> Submit feedback </button> </div> </dialog></dialog-helper> <custom-scopes data-target="qbsearch-input.customScopesManager"> <dialog-helper> <dialog data-target="custom-scopes.customScopesModalDialog" data-action="close:qbsearch-input#handleDialogClose cancel:qbsearch-input#handleDialogClose" id="custom-scopes-dialog" aria-modal="true" aria-labelledby="custom-scopes-dialog-title" aria-describedby="custom-scopes-dialog-description" data-view-component="true" class="Overlay Overlay-whenNarrow Overlay--size-medium Overlay--motion-scaleFade Overlay--disableScroll"> <div data-view-component="true" class="Overlay-header Overlay-header--divided"> <div class="Overlay-headerContentWrap"> <div class="Overlay-titleWrap"> <h1 class="Overlay-title " id="custom-scopes-dialog-title"> Saved searches </h1> <h2 id="custom-scopes-dialog-description" class="Overlay-description">Use saved searches to filter your results more quickly</h2> </div> <div class="Overlay-actionWrap"> <button data-close-dialog-id="custom-scopes-dialog" aria-label="Close" type="button" data-view-component="true" class="close-button Overlay-closeButton"><svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-x"> <path d="M3.72 3.72a.75.75 0 0 1 1.06 0L8 6.94l3.22-3.22a.749.749 0 0 1 1.275.326.749.749 0 0 1-.215.734L9.06 8l3.22 3.22a.749.749 0 0 1-.326 1.275.749.749 0 0 1-.734-.215L8 9.06l-3.22 3.22a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042L6.94 8 3.72 4.78a.75.75 0 0 1 0-1.06Z"></path> </svg></button> 
</div> </div> </div> <scrollable-region data-labelled-by="custom-scopes-dialog-title"> <div data-view-component="true" class="Overlay-body"> <div data-target="custom-scopes.customScopesModalDialogFlash"></div> <div hidden class="create-custom-scope-form" data-target="custom-scopes.createCustomScopeForm"> <!-- '"` --><!-- </textarea></xmp> --></option></form><form id="custom-scopes-dialog-form" data-turbo="false" action="/search/custom_scopes" accept-charset="UTF-8" method="post"><input type="hidden" data-csrf="true" name="authenticity_token" value="Hza4eLCniVa3EFZJk4ageCfZTXHjK3hYCsFPOV7AnTyzsjRKCxAM9tc8QRDQ6EB7ddf49D0R6Tn7EX8vFoh4ag==" /> <div data-target="custom-scopes.customScopesModalDialogFlash"></div> <input type="hidden" id="custom_scope_id" name="custom_scope_id" data-target="custom-scopes.customScopesIdField"> <div class="form-group"> <label for="custom_scope_name">Name</label> <auto-check src="/search/custom_scopes/check_name" required only-validate-on-blur="false"> <input type="text" name="custom_scope_name" id="custom_scope_name" data-target="custom-scopes.customScopesNameField" class="form-control" autocomplete="off" placeholder="github-ruby" required maxlength="50"> <input type="hidden" data-csrf="true" value="RjrIBJmOqzzRxVlQkTbOrkCJB+odO9ZJP/nPhWpn//cbxsbw1/pcHGPnCJaJSvKDHIqIPdi5rFUh3QErAt4zCw==" /> </auto-check> </div> <div class="form-group"> <label for="custom_scope_query">Query</label> <input type="text" name="custom_scope_query" id="custom_scope_query" data-target="custom-scopes.customScopesQueryField" class="form-control" autocomplete="off" placeholder="(repo:mona/a OR repo:mona/b) AND lang:python" required maxlength="500"> </div> <p class="text-small color-fg-muted"> To see all available qualifiers, see our <a class="Link--inTextBlock" href="https://docs.github.com/search-github/github-code-search/understanding-github-code-search-syntax">documentation</a>. 
</p> </form> </div> <div data-target="custom-scopes.manageCustomScopesForm"> <div data-target="custom-scopes.list"></div> </div> </div> </scrollable-region> <div data-view-component="true" class="Overlay-footer Overlay-footer--alignEnd Overlay-footer--divided"> <button data-action="click:custom-scopes#customScopesCancel" type="button" data-view-component="true" class="btn"> Cancel </button> <button form="custom-scopes-dialog-form" data-action="click:custom-scopes#customScopesSubmit" data-target="custom-scopes.customScopesSubmitButton" type="submit" data-view-component="true" class="btn-primary btn"> Create saved search </button> </div> </dialog></dialog-helper> </custom-scopes> </div> </qbsearch-input> <div class="position-relative HeaderMenu-link-wrap d-lg-inline-block"> <a href="/login?return_to=https%3A%2F%2Fgithub.com%2Fjax-ml%2Fjax%2Ftree%2Fmain%2Fjax%2Fexperimental%2Fjax2tf" class="HeaderMenu-link HeaderMenu-link--sign-in HeaderMenu-button flex-shrink-0 no-underline d-none d-lg-inline-flex border border-lg-0 rounded rounded-lg-0 px-2 py-1" style="margin-left: 12px;" data-hydro-click="{"event_type":"authentication.click","payload":{"location_in_page":"site header menu","repository_id":null,"auth_type":"SIGN_UP","originating_url":"https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf","user_id":null}}" data-hydro-click-hmac="97202bb3b4afeb6197315d11acbc25c8851862d74986185bef54c08fafb65461" data-analytics-event="{"category":"Marketing nav","action":"click to go to homepage","label":"ref_page:Marketing;ref_cta:Sign in;ref_loc:Header"}" > Sign in </a> </div> <a href="/signup?ref_cta=Sign+up&ref_loc=header+logged+out&ref_page=%2F%3Cuser-name%3E%2F%3Crepo-name%3E%2Ffiles%2Fdisambiguate&source=header-repo&source_repo=jax-ml%2Fjax" class="HeaderMenu-link HeaderMenu-link--sign-up HeaderMenu-button flex-shrink-0 d-flex d-lg-inline-flex no-underline border color-border-default rounded px-2 py-1" data-hydro-click="{"event_type":"authentication.click","payload":{"location_in_page":"site header menu","repository_id":null,"auth_type":"SIGN_UP","originating_url":"https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf","user_id":null}}" data-hydro-click-hmac="97202bb3b4afeb6197315d11acbc25c8851862d74986185bef54c08fafb65461" data-analytics-event="{"category":"Sign up","action":"click to sign up for account","label":"ref_page:/<user-name>/<repo-name>/files/disambiguate;ref_cta:Sign up;ref_loc:header logged out"}" > Sign up </a> <button type="button" class="sr-only js-header-menu-focus-trap d-block d-lg-none">Reseting focus</button> </div> </div> </div> </div> </header> <div hidden="hidden" data-view-component="true" class="js-stale-session-flash stale-session-flash flash flash-warn flash-full"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-alert"> <path d="M6.457 1.047c.659-1.234 2.427-1.234 3.086 0l6.082 11.378A1.75 1.75 0 0 1 14.082 15H1.918a1.75 1.75 0 0 1-1.543-2.575Zm1.763.707a.25.25 0 0 0-.44 0L1.698 13.132a.25.25 0 0 0 .22.368h12.164a.25.25 0 0 0 .22-.368Zm.53 3.996v2.5a.75.75 0 0 1-1.5 0v-2.5a.75.75 0 0 1 1.5 0ZM9 11a1 1 0 1 1-2 0 1 1 0 0 1 2 0Z"></path> </svg> <span class="js-stale-session-flash-signed-in" hidden>You signed in with another tab or window. <a class="Link--inTextBlock" href="">Reload</a> to refresh your session.</span> <span class="js-stale-session-flash-signed-out" hidden>You signed out in another tab or window. 
<a class="Link--inTextBlock" href="">Reload</a> to refresh your session.</span> <span class="js-stale-session-flash-switched" hidden>You switched accounts on another tab or window. <a class="Link--inTextBlock" href="">Reload</a> to refresh your session.</span> <button id="icon-button-30622a4a-2844-4a61-920c-259a5ff7e3b8" aria-labelledby="tooltip-9ce4264d-e2ab-4732-b1de-dc18a5afa9a1" type="button" data-view-component="true" class="Button Button--iconOnly Button--invisible Button--medium flash-close js-flash-close"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-x Button-visual"> <path d="M3.72 3.72a.75.75 0 0 1 1.06 0L8 6.94l3.22-3.22a.749.749 0 0 1 1.275.326.749.749 0 0 1-.215.734L9.06 8l3.22 3.22a.749.749 0 0 1-.326 1.275.749.749 0 0 1-.734-.215L8 9.06l-3.22 3.22a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042L6.94 8 3.72 4.78a.75.75 0 0 1 0-1.06Z"></path> </svg> </button><tool-tip id="tooltip-9ce4264d-e2ab-4732-b1de-dc18a5afa9a1" for="icon-button-30622a4a-2844-4a61-920c-259a5ff7e3b8" popover="manual" data-direction="s" data-type="label" data-view-component="true" class="sr-only position-absolute">Dismiss alert</tool-tip> </div> </div> <div id="start-of-content" class="show-on-focus"></div> <div id="js-flash-container" class="flash-container" data-turbo-replace> <template class="js-flash-template"> <div class="flash flash-full {{ className }}"> <div > <button autofocus class="flash-close js-flash-close" type="button" aria-label="Dismiss this message"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-x"> <path d="M3.72 3.72a.75.75 0 0 1 1.06 0L8 6.94l3.22-3.22a.749.749 0 0 1 1.275.326.749.749 0 0 1-.215.734L9.06 8l3.22 3.22a.749.749 0 0 1-.326 1.275.749.749 0 0 1-.734-.215L8 9.06l-3.22 3.22a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042L6.94 8 3.72 4.78a.75.75 0 0 1 0-1.06Z"></path> </svg> </button> <div aria-atomic="true" role="alert" class="js-flash-alert"> <div>{{ message }}</div> </div> </div> </div> </template> </div> <div class="application-main " data-commit-hovercards-enabled data-discussion-hovercards-enabled data-issue-and-pr-hovercards-enabled data-project-hovercards-enabled > <div itemscope itemtype="http://schema.org/SoftwareSourceCode" class=""> <main id="js-repo-pjax-container" > <div id="repository-container-header" class="pt-3 hide-full-screen" style="background-color: var(--page-header-bgColor, var(--color-page-header-bg));" data-turbo-replace> <div class="d-flex flex-nowrap flex-justify-end mb-3 px-3 px-lg-5" style="gap: 1rem;"> <div class="flex-auto min-width-0 width-fit"> <div class=" d-flex flex-wrap flex-items-center wb-break-word f3 text-normal"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-repo color-fg-muted mr-2"> <path d="M2 2.5A2.5 2.5 0 0 1 4.5 0h8.75a.75.75 0 0 1 .75.75v12.5a.75.75 0 0 1-.75.75h-2.5a.75.75 0 0 1 0-1.5h1.75v-2h-8a1 1 0 0 0-.714 1.7.75.75 0 1 1-1.072 1.05A2.495 2.495 0 0 1 2 11.5Zm10.5-1h-8a1 1 0 0 0-1 1v6.708A2.486 2.486 0 0 1 4.5 9h8ZM5 12.25a.25.25 0 0 1 .25-.25h3.5a.25.25 0 0 1 .25.25v3.25a.25.25 0 0 1-.4.2l-1.45-1.087a.249.249 0 0 0-.3 0L5.4 15.7a.25.25 0 0 1-.4-.2Z"></path> </svg> <span class="author flex-self-stretch" itemprop="author"> <a class="url fn" rel="author" data-hovercard-type="organization" data-hovercard-url="/orgs/jax-ml/hovercard" data-octo-click="hovercard-link-click" 
data-octo-dimensions="link_type:self" href="/jax-ml"> jax-ml </a> </span> <span class="mx-1 flex-self-stretch color-fg-muted">/</span> <strong itemprop="name" class="mr-2 flex-self-stretch"> <a data-pjax="#repo-content-pjax-container" data-turbo-frame="repo-content-turbo-frame" href="/jax-ml/jax">jax</a> </strong> <span></span><span class="Label Label--secondary v-align-middle mr-1">Public</span> </div> </div> <div id="repository-details-container" class="flex-shrink-0" data-turbo-replace style="max-width: 70%;"> <ul class="pagehead-actions flex-shrink-0 d-none d-md-inline" style="padding: 2px 0;"> <li> <a href="/login?return_to=%2Fjax-ml%2Fjax" rel="nofollow" id="repository-details-watch-button" data-hydro-click="{"event_type":"authentication.click","payload":{"location_in_page":"notification subscription menu watch","repository_id":null,"auth_type":"LOG_IN","originating_url":"https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf","user_id":null}}" data-hydro-click-hmac="03275bafd858fe795a3fe6a772a49ff91487990b2059c4b6c13449a2858412fe" aria-label="You must be signed in to change notification settings" data-view-component="true" class="btn-sm btn"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-bell mr-2"> <path d="M8 16a2 2 0 0 0 1.985-1.75c.017-.137-.097-.25-.235-.25h-3.5c-.138 0-.252.113-.235.25A2 2 0 0 0 8 16ZM3 5a5 5 0 0 1 10 0v2.947c0 .05.015.098.042.139l1.703 2.555A1.519 1.519 0 0 1 13.482 13H2.518a1.516 1.516 0 0 1-1.263-2.36l1.703-2.554A.255.255 0 0 0 3 7.947Zm5-3.5A3.5 3.5 0 0 0 4.5 5v2.947c0 .346-.102.683-.294.97l-1.703 2.556a.017.017 0 0 0-.003.01l.001.006c0 .002.002.004.004.006l.006.004.007.001h10.964l.007-.001.006-.004.004-.006.001-.007a.017.017 0 0 0-.003-.01l-1.703-2.554a1.745 1.745 0 0 1-.294-.97V5A3.5 3.5 0 0 0 8 1.5Z"></path> </svg>Notifications </a> <tool-tip id="tooltip-a39f8390-fe19-4f80-9017-b622edc1c5cd" for="repository-details-watch-button" popover="manual" data-direction="s" data-type="description" data-view-component="true" class="sr-only position-absolute">You must be signed in to change notification settings</tool-tip> </li> <li> <a icon="repo-forked" id="fork-button" href="/login?return_to=%2Fjax-ml%2Fjax" rel="nofollow" data-hydro-click="{"event_type":"authentication.click","payload":{"location_in_page":"repo details fork button","repository_id":154739597,"auth_type":"LOG_IN","originating_url":"https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf","user_id":null}}" data-hydro-click-hmac="12fd675829fe07464dbf3b9a4a253856e5fd9aee584facbfe18a27ce59261f17" data-view-component="true" class="btn-sm btn"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-repo-forked mr-2"> <path d="M5 5.372v.878c0 .414.336.75.75.75h4.5a.75.75 0 0 0 .75-.75v-.878a2.25 2.25 0 1 1 1.5 0v.878a2.25 2.25 0 0 1-2.25 2.25h-1.5v2.128a2.251 2.251 0 1 1-1.5 0V8.5h-1.5A2.25 2.25 0 0 1 3.5 6.25v-.878a2.25 2.25 0 1 1 1.5 0ZM5 3.25a.75.75 0 1 0-1.5 0 .75.75 0 0 0 1.5 0Zm6.75.75a.75.75 0 1 0 0-1.5.75.75 0 0 0 0 1.5Zm-3 8.75a.75.75 0 1 0-1.5 0 .75.75 0 0 0 1.5 0Z"></path> </svg>Fork <span id="repo-network-counter" data-pjax-replace="true" data-turbo-replace="true" title="2,889" data-view-component="true" class="Counter">2.9k</span> </a> </li> <li> <div data-view-component="true" class="BtnGroup d-flex"> <a href="/login?return_to=%2Fjax-ml%2Fjax" rel="nofollow" 
data-hydro-click="{"event_type":"authentication.click","payload":{"location_in_page":"star button","repository_id":154739597,"auth_type":"LOG_IN","originating_url":"https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf","user_id":null}}" data-hydro-click-hmac="85b95d563764a29349f4f0d6cd897579876c0d52f40d4daa15940814323ee56a" aria-label="You must be signed in to star a repository" data-view-component="true" class="tooltipped tooltipped-sw btn-sm btn"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-star v-align-text-bottom d-inline-block mr-2"> <path d="M8 .25a.75.75 0 0 1 .673.418l1.882 3.815 4.21.612a.75.75 0 0 1 .416 1.279l-3.046 2.97.719 4.192a.751.751 0 0 1-1.088.791L8 12.347l-3.766 1.98a.75.75 0 0 1-1.088-.79l.72-4.194L.818 6.374a.75.75 0 0 1 .416-1.28l4.21-.611L7.327.668A.75.75 0 0 1 8 .25Zm0 2.445L6.615 5.5a.75.75 0 0 1-.564.41l-3.097.45 2.24 2.184a.75.75 0 0 1 .216.664l-.528 3.084 2.769-1.456a.75.75 0 0 1 .698 0l2.77 1.456-.53-3.084a.75.75 0 0 1 .216-.664l2.24-2.183-3.096-.45a.75.75 0 0 1-.564-.41L8 2.694Z"></path> </svg><span data-view-component="true" class="d-inline"> Star </span> <span id="repo-stars-counter-star" aria-label="31286 users starred this repository" data-singular-suffix="user starred this repository" data-plural-suffix="users starred this repository" data-turbo-replace="true" title="31,286" data-view-component="true" class="Counter js-social-count">31.3k</span> </a></div> </li> </ul> </div> </div> <div id="responsive-meta-container" data-turbo-replace> </div> <nav data-pjax="#js-repo-pjax-container" aria-label="Repository" data-view-component="true" class="js-repo-nav js-sidenav-container-pjax js-responsive-underlinenav overflow-hidden UnderlineNav px-3 px-md-4 px-lg-5"> <ul data-view-component="true" class="UnderlineNav-body list-style-none"> <li data-view-component="true" class="d-inline-flex"> <a id="code-tab" href="/jax-ml/jax" data-tab-item="i0code-tab" data-selected-links="repo_source repo_downloads repo_commits repo_releases repo_tags repo_branches repo_packages repo_deployments repo_attestations /jax-ml/jax" data-pjax="#repo-content-pjax-container" data-turbo-frame="repo-content-turbo-frame" data-hotkey="g c" data-analytics-event="{"category":"Underline navbar","action":"Click tab","label":"Code","target":"UNDERLINE_NAV.TAB"}" aria-current="page" data-view-component="true" class="UnderlineNav-item no-wrap js-responsive-underlinenav-item js-selected-navigation-item selected"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-code UnderlineNav-octicon d-none d-sm-inline"> <path d="m11.28 3.22 4.25 4.25a.75.75 0 0 1 0 1.06l-4.25 4.25a.749.749 0 0 1-1.275-.326.749.749 0 0 1 .215-.734L13.94 8l-3.72-3.72a.749.749 0 0 1 .326-1.275.749.749 0 0 1 .734.215Zm-6.56 0a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042L2.06 8l3.72 3.72a.749.749 0 0 1-.326 1.275.749.749 0 0 1-.734-.215L.47 8.53a.75.75 0 0 1 0-1.06Z"></path> </svg> <span data-content="Code">Code</span> <span id="code-repo-tab-count" data-pjax-replace="" data-turbo-replace="" title="Not available" data-view-component="true" class="Counter"></span> </a></li> <li data-view-component="true" class="d-inline-flex"> <a id="issues-tab" href="/jax-ml/jax/issues" data-tab-item="i1issues-tab" data-selected-links="repo_issues repo_labels repo_milestones /jax-ml/jax/issues" data-pjax="#repo-content-pjax-container" 
data-turbo-frame="repo-content-turbo-frame" data-hotkey="g i" data-analytics-event="{"category":"Underline navbar","action":"Click tab","label":"Issues","target":"UNDERLINE_NAV.TAB"}" data-view-component="true" class="UnderlineNav-item no-wrap js-responsive-underlinenav-item js-selected-navigation-item"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-issue-opened UnderlineNav-octicon d-none d-sm-inline"> <path d="M8 9.5a1.5 1.5 0 1 0 0-3 1.5 1.5 0 0 0 0 3Z"></path><path d="M8 0a8 8 0 1 1 0 16A8 8 0 0 1 8 0ZM1.5 8a6.5 6.5 0 1 0 13 0 6.5 6.5 0 0 0-13 0Z"></path> </svg> <span data-content="Issues">Issues</span> <span id="issues-repo-tab-count" data-pjax-replace="" data-turbo-replace="" title="1,494" data-view-component="true" class="Counter">1.5k</span> </a></li> <li data-view-component="true" class="d-inline-flex"> <a id="pull-requests-tab" href="/jax-ml/jax/pulls" data-tab-item="i2pull-requests-tab" data-selected-links="repo_pulls checks /jax-ml/jax/pulls" data-pjax="#repo-content-pjax-container" data-turbo-frame="repo-content-turbo-frame" data-hotkey="g p" data-analytics-event="{"category":"Underline navbar","action":"Click tab","label":"Pull requests","target":"UNDERLINE_NAV.TAB"}" data-view-component="true" class="UnderlineNav-item no-wrap js-responsive-underlinenav-item js-selected-navigation-item"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-git-pull-request UnderlineNav-octicon d-none d-sm-inline"> <path d="M1.5 3.25a2.25 2.25 0 1 1 3 2.122v5.256a2.251 2.251 0 1 1-1.5 0V5.372A2.25 2.25 0 0 1 1.5 3.25Zm5.677-.177L9.573.677A.25.25 0 0 1 10 .854V2.5h1A2.5 2.5 0 0 1 13.5 5v5.628a2.251 2.251 0 1 1-1.5 0V5a1 1 0 0 0-1-1h-1v1.646a.25.25 0 0 1-.427.177L7.177 3.427a.25.25 0 0 1 0-.354ZM3.75 2.5a.75.75 0 1 0 0 1.5.75.75 0 0 0 0-1.5Zm0 9.5a.75.75 0 1 0 0 1.5.75.75 0 0 0 0-1.5Zm8.25.75a.75.75 0 1 0 1.5 0 .75.75 0 0 0-1.5 0Z"></path> </svg> <span data-content="Pull requests">Pull requests</span> <span id="pull-requests-repo-tab-count" data-pjax-replace="" data-turbo-replace="" title="414" data-view-component="true" class="Counter">414</span> </a></li> <li data-view-component="true" class="d-inline-flex"> <a id="discussions-tab" href="/jax-ml/jax/discussions" data-tab-item="i3discussions-tab" data-selected-links="repo_discussions /jax-ml/jax/discussions" data-pjax="#repo-content-pjax-container" data-turbo-frame="repo-content-turbo-frame" data-hotkey="g g" data-analytics-event="{"category":"Underline navbar","action":"Click tab","label":"Discussions","target":"UNDERLINE_NAV.TAB"}" data-view-component="true" class="UnderlineNav-item no-wrap js-responsive-underlinenav-item js-selected-navigation-item"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-comment-discussion UnderlineNav-octicon d-none d-sm-inline"> <path d="M1.75 1h8.5c.966 0 1.75.784 1.75 1.75v5.5A1.75 1.75 0 0 1 10.25 10H7.061l-2.574 2.573A1.458 1.458 0 0 1 2 11.543V10h-.25A1.75 1.75 0 0 1 0 8.25v-5.5C0 1.784.784 1 1.75 1ZM1.5 2.75v5.5c0 .138.112.25.25.25h1a.75.75 0 0 1 .75.75v2.19l2.72-2.72a.749.749 0 0 1 .53-.22h3.5a.25.25 0 0 0 .25-.25v-5.5a.25.25 0 0 0-.25-.25h-8.5a.25.25 0 0 0-.25.25Zm13 2a.25.25 0 0 0-.25-.25h-.5a.75.75 0 0 1 0-1.5h.5c.966 0 1.75.784 1.75 1.75v5.5A1.75 1.75 0 0 1 14.25 12H14v1.543a1.458 1.458 0 0 1-2.487 1.03L9.22 12.28a.749.749 0 0 1 .326-1.275.749.749 0 0 1 
.734.215l2.22 2.22v-2.19a.75.75 0 0 1 .75-.75h1a.25.25 0 0 0 .25-.25Z"></path> </svg> <span data-content="Discussions">Discussions</span> <span id="discussions-repo-tab-count" data-pjax-replace="" data-turbo-replace="" title="Not available" data-view-component="true" class="Counter"></span> </a></li> <li data-view-component="true" class="d-inline-flex"> <a id="actions-tab" href="/jax-ml/jax/actions" data-tab-item="i4actions-tab" data-selected-links="repo_actions /jax-ml/jax/actions" data-pjax="#repo-content-pjax-container" data-turbo-frame="repo-content-turbo-frame" data-hotkey="g a" data-analytics-event="{"category":"Underline navbar","action":"Click tab","label":"Actions","target":"UNDERLINE_NAV.TAB"}" data-view-component="true" class="UnderlineNav-item no-wrap js-responsive-underlinenav-item js-selected-navigation-item"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-play UnderlineNav-octicon d-none d-sm-inline"> <path d="M8 0a8 8 0 1 1 0 16A8 8 0 0 1 8 0ZM1.5 8a6.5 6.5 0 1 0 13 0 6.5 6.5 0 0 0-13 0Zm4.879-2.773 4.264 2.559a.25.25 0 0 1 0 .428l-4.264 2.559A.25.25 0 0 1 6 10.559V5.442a.25.25 0 0 1 .379-.215Z"></path> </svg> <span data-content="Actions">Actions</span> <span id="actions-repo-tab-count" data-pjax-replace="" data-turbo-replace="" title="Not available" data-view-component="true" class="Counter"></span> </a></li> <li data-view-component="true" class="d-inline-flex"> <a id="security-tab" href="/jax-ml/jax/security" data-tab-item="i5security-tab" data-selected-links="security overview alerts policy token_scanning code_scanning /jax-ml/jax/security" data-pjax="#repo-content-pjax-container" data-turbo-frame="repo-content-turbo-frame" data-hotkey="g s" data-analytics-event="{"category":"Underline navbar","action":"Click tab","label":"Security","target":"UNDERLINE_NAV.TAB"}" data-view-component="true" class="UnderlineNav-item no-wrap js-responsive-underlinenav-item js-selected-navigation-item"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-shield UnderlineNav-octicon d-none d-sm-inline"> <path d="M7.467.133a1.748 1.748 0 0 1 1.066 0l5.25 1.68A1.75 1.75 0 0 1 15 3.48V7c0 1.566-.32 3.182-1.303 4.682-.983 1.498-2.585 2.813-5.032 3.855a1.697 1.697 0 0 1-1.33 0c-2.447-1.042-4.049-2.357-5.032-3.855C1.32 10.182 1 8.566 1 7V3.48a1.75 1.75 0 0 1 1.217-1.667Zm.61 1.429a.25.25 0 0 0-.153 0l-5.25 1.68a.25.25 0 0 0-.174.238V7c0 1.358.275 2.666 1.057 3.86.784 1.194 2.121 2.34 4.366 3.297a.196.196 0 0 0 .154 0c2.245-.956 3.582-2.104 4.366-3.298C13.225 9.666 13.5 8.36 13.5 7V3.48a.251.251 0 0 0-.174-.237l-5.25-1.68ZM8.75 4.75v3a.75.75 0 0 1-1.5 0v-3a.75.75 0 0 1 1.5 0ZM9 10.5a1 1 0 1 1-2 0 1 1 0 0 1 2 0Z"></path> </svg> <span data-content="Security">Security</span> <include-fragment src="/jax-ml/jax/security/overall-count" accept="text/fragment+html"></include-fragment> </a></li> <li data-view-component="true" class="d-inline-flex"> <a id="insights-tab" href="/jax-ml/jax/pulse" data-tab-item="i6insights-tab" data-selected-links="repo_graphs repo_contributors dependency_graph dependabot_updates pulse people community /jax-ml/jax/pulse" data-pjax="#repo-content-pjax-container" data-turbo-frame="repo-content-turbo-frame" data-analytics-event="{"category":"Underline navbar","action":"Click tab","label":"Insights","target":"UNDERLINE_NAV.TAB"}" data-view-component="true" class="UnderlineNav-item no-wrap js-responsive-underlinenav-item 
js-selected-navigation-item"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-graph UnderlineNav-octicon d-none d-sm-inline"> <path d="M1.5 1.75V13.5h13.75a.75.75 0 0 1 0 1.5H.75a.75.75 0 0 1-.75-.75V1.75a.75.75 0 0 1 1.5 0Zm14.28 2.53-5.25 5.25a.75.75 0 0 1-1.06 0L7 7.06 4.28 9.78a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042l3.25-3.25a.75.75 0 0 1 1.06 0L10 7.94l4.72-4.72a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042Z"></path> </svg> <span data-content="Insights">Insights</span> <span id="insights-repo-tab-count" data-pjax-replace="" data-turbo-replace="" title="Not available" data-view-component="true" class="Counter"></span> </a></li> </ul> <div style="visibility:hidden;" data-view-component="true" class="UnderlineNav-actions js-responsive-underlinenav-overflow position-absolute pr-3 pr-md-4 pr-lg-5 right-0"> <action-menu data-select-variant="none" data-view-component="true"> <focus-group direction="vertical" mnemonics retain> <button id="action-menu-579ec5bf-5993-4d07-a689-31e60b226f27-button" popovertarget="action-menu-579ec5bf-5993-4d07-a689-31e60b226f27-overlay" aria-controls="action-menu-579ec5bf-5993-4d07-a689-31e60b226f27-list" aria-haspopup="true" aria-labelledby="tooltip-10724a64-0586-4a75-a084-d4f5b6904797" type="button" data-view-component="true" class="Button Button--iconOnly Button--secondary Button--medium UnderlineNav-item"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-kebab-horizontal Button-visual"> <path d="M8 9a1.5 1.5 0 1 0 0-3 1.5 1.5 0 0 0 0 3ZM1.5 9a1.5 1.5 0 1 0 0-3 1.5 1.5 0 0 0 0 3Zm13 0a1.5 1.5 0 1 0 0-3 1.5 1.5 0 0 0 0 3Z"></path> </svg> </button><tool-tip id="tooltip-10724a64-0586-4a75-a084-d4f5b6904797" for="action-menu-579ec5bf-5993-4d07-a689-31e60b226f27-button" popover="manual" data-direction="s" data-type="label" data-view-component="true" class="sr-only position-absolute">Additional navigation options</tool-tip> <anchored-position data-target="action-menu.overlay" id="action-menu-579ec5bf-5993-4d07-a689-31e60b226f27-overlay" anchor="action-menu-579ec5bf-5993-4d07-a689-31e60b226f27-button" align="start" side="outside-bottom" anchor-offset="normal" popover="auto" data-view-component="true"> <div data-view-component="true" class="Overlay Overlay--size-auto"> <div data-view-component="true" class="Overlay-body Overlay-body--paddingNone"> <action-list> <div data-view-component="true"> <ul aria-labelledby="action-menu-579ec5bf-5993-4d07-a689-31e60b226f27-button" id="action-menu-579ec5bf-5993-4d07-a689-31e60b226f27-list" role="menu" data-view-component="true" class="ActionListWrap--inset ActionListWrap"> <li hidden="hidden" data-menu-item="i0code-tab" data-targets="action-list.items" role="none" data-view-component="true" class="ActionListItem"> <a tabindex="-1" id="item-c755a066-09d2-48f5-9f5d-11ecb3bede35" href="/jax-ml/jax" role="menuitem" data-view-component="true" class="ActionListContent ActionListContent--visual16"> <span class="ActionListItem-visual ActionListItem-visual--leading"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-code"> <path d="m11.28 3.22 4.25 4.25a.75.75 0 0 1 0 1.06l-4.25 4.25a.749.749 0 0 1-1.275-.326.749.749 0 0 1 .215-.734L13.94 8l-3.72-3.72a.749.749 0 0 1 .326-1.275.749.749 0 0 1 .734.215Zm-6.56 0a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042L2.06 8l3.72 
3.72a.749.749 0 0 1-.326 1.275.749.749 0 0 1-.734-.215L.47 8.53a.75.75 0 0 1 0-1.06Z"></path> </svg> </span> <span data-view-component="true" class="ActionListItem-label"> Code </span> </a> </li> <li hidden="hidden" data-menu-item="i1issues-tab" data-targets="action-list.items" role="none" data-view-component="true" class="ActionListItem"> <a tabindex="-1" id="item-60b5e642-dc90-44c7-a000-09169f4d412c" href="/jax-ml/jax/issues" role="menuitem" data-view-component="true" class="ActionListContent ActionListContent--visual16"> <span class="ActionListItem-visual ActionListItem-visual--leading"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-issue-opened"> <path d="M8 9.5a1.5 1.5 0 1 0 0-3 1.5 1.5 0 0 0 0 3Z"></path><path d="M8 0a8 8 0 1 1 0 16A8 8 0 0 1 8 0ZM1.5 8a6.5 6.5 0 1 0 13 0 6.5 6.5 0 0 0-13 0Z"></path> </svg> </span> <span data-view-component="true" class="ActionListItem-label"> Issues </span> </a> </li> <li hidden="hidden" data-menu-item="i2pull-requests-tab" data-targets="action-list.items" role="none" data-view-component="true" class="ActionListItem"> <a tabindex="-1" id="item-ed6000bb-6df6-4f65-b949-77e9e049133b" href="/jax-ml/jax/pulls" role="menuitem" data-view-component="true" class="ActionListContent ActionListContent--visual16"> <span class="ActionListItem-visual ActionListItem-visual--leading"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-git-pull-request"> <path d="M1.5 3.25a2.25 2.25 0 1 1 3 2.122v5.256a2.251 2.251 0 1 1-1.5 0V5.372A2.25 2.25 0 0 1 1.5 3.25Zm5.677-.177L9.573.677A.25.25 0 0 1 10 .854V2.5h1A2.5 2.5 0 0 1 13.5 5v5.628a2.251 2.251 0 1 1-1.5 0V5a1 1 0 0 0-1-1h-1v1.646a.25.25 0 0 1-.427.177L7.177 3.427a.25.25 0 0 1 0-.354ZM3.75 2.5a.75.75 0 1 0 0 1.5.75.75 0 0 0 0-1.5Zm0 9.5a.75.75 0 1 0 0 1.5.75.75 0 0 0 0-1.5Zm8.25.75a.75.75 0 1 0 1.5 0 .75.75 0 0 0-1.5 0Z"></path> </svg> </span> <span data-view-component="true" class="ActionListItem-label"> Pull requests </span> </a> </li> <li hidden="hidden" data-menu-item="i3discussions-tab" data-targets="action-list.items" role="none" data-view-component="true" class="ActionListItem"> <a tabindex="-1" id="item-d0f0bec0-10c8-418f-8000-946d0101937f" href="/jax-ml/jax/discussions" role="menuitem" data-view-component="true" class="ActionListContent ActionListContent--visual16"> <span class="ActionListItem-visual ActionListItem-visual--leading"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-comment-discussion"> <path d="M1.75 1h8.5c.966 0 1.75.784 1.75 1.75v5.5A1.75 1.75 0 0 1 10.25 10H7.061l-2.574 2.573A1.458 1.458 0 0 1 2 11.543V10h-.25A1.75 1.75 0 0 1 0 8.25v-5.5C0 1.784.784 1 1.75 1ZM1.5 2.75v5.5c0 .138.112.25.25.25h1a.75.75 0 0 1 .75.75v2.19l2.72-2.72a.749.749 0 0 1 .53-.22h3.5a.25.25 0 0 0 .25-.25v-5.5a.25.25 0 0 0-.25-.25h-8.5a.25.25 0 0 0-.25.25Zm13 2a.25.25 0 0 0-.25-.25h-.5a.75.75 0 0 1 0-1.5h.5c.966 0 1.75.784 1.75 1.75v5.5A1.75 1.75 0 0 1 14.25 12H14v1.543a1.458 1.458 0 0 1-2.487 1.03L9.22 12.28a.749.749 0 0 1 .326-1.275.749.749 0 0 1 .734.215l2.22 2.22v-2.19a.75.75 0 0 1 .75-.75h1a.25.25 0 0 0 .25-.25Z"></path> </svg> </span> <span data-view-component="true" class="ActionListItem-label"> Discussions </span> </a> </li> <li hidden="hidden" data-menu-item="i4actions-tab" data-targets="action-list.items" role="none" data-view-component="true" class="ActionListItem"> <a 
tabindex="-1" id="item-220c2406-cf28-44d6-b353-a347d434e11b" href="/jax-ml/jax/actions" role="menuitem" data-view-component="true" class="ActionListContent ActionListContent--visual16"> <span class="ActionListItem-visual ActionListItem-visual--leading"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-play"> <path d="M8 0a8 8 0 1 1 0 16A8 8 0 0 1 8 0ZM1.5 8a6.5 6.5 0 1 0 13 0 6.5 6.5 0 0 0-13 0Zm4.879-2.773 4.264 2.559a.25.25 0 0 1 0 .428l-4.264 2.559A.25.25 0 0 1 6 10.559V5.442a.25.25 0 0 1 .379-.215Z"></path> </svg> </span> <span data-view-component="true" class="ActionListItem-label"> Actions </span> </a> </li> <li hidden="hidden" data-menu-item="i5security-tab" data-targets="action-list.items" role="none" data-view-component="true" class="ActionListItem"> <a tabindex="-1" id="item-35b765d1-f292-40f8-9edf-1c266eece3e8" href="/jax-ml/jax/security" role="menuitem" data-view-component="true" class="ActionListContent ActionListContent--visual16"> <span class="ActionListItem-visual ActionListItem-visual--leading"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-shield"> <path d="M7.467.133a1.748 1.748 0 0 1 1.066 0l5.25 1.68A1.75 1.75 0 0 1 15 3.48V7c0 1.566-.32 3.182-1.303 4.682-.983 1.498-2.585 2.813-5.032 3.855a1.697 1.697 0 0 1-1.33 0c-2.447-1.042-4.049-2.357-5.032-3.855C1.32 10.182 1 8.566 1 7V3.48a1.75 1.75 0 0 1 1.217-1.667Zm.61 1.429a.25.25 0 0 0-.153 0l-5.25 1.68a.25.25 0 0 0-.174.238V7c0 1.358.275 2.666 1.057 3.86.784 1.194 2.121 2.34 4.366 3.297a.196.196 0 0 0 .154 0c2.245-.956 3.582-2.104 4.366-3.298C13.225 9.666 13.5 8.36 13.5 7V3.48a.251.251 0 0 0-.174-.237l-5.25-1.68ZM8.75 4.75v3a.75.75 0 0 1-1.5 0v-3a.75.75 0 0 1 1.5 0ZM9 10.5a1 1 0 1 1-2 0 1 1 0 0 1 2 0Z"></path> </svg> </span> <span data-view-component="true" class="ActionListItem-label"> Security </span> </a> </li> <li hidden="hidden" data-menu-item="i6insights-tab" data-targets="action-list.items" role="none" data-view-component="true" class="ActionListItem"> <a tabindex="-1" id="item-9850013a-9bbe-475f-95a8-a5ffb12bf648" href="/jax-ml/jax/pulse" role="menuitem" data-view-component="true" class="ActionListContent ActionListContent--visual16"> <span class="ActionListItem-visual ActionListItem-visual--leading"> <svg aria-hidden="true" height="16" viewBox="0 0 16 16" version="1.1" width="16" data-view-component="true" class="octicon octicon-graph"> <path d="M1.5 1.75V13.5h13.75a.75.75 0 0 1 0 1.5H.75a.75.75 0 0 1-.75-.75V1.75a.75.75 0 0 1 1.5 0Zm14.28 2.53-5.25 5.25a.75.75 0 0 1-1.06 0L7 7.06 4.28 9.78a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042l3.25-3.25a.75.75 0 0 1 1.06 0L10 7.94l4.72-4.72a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042Z"></path> </svg> </span> <span data-view-component="true" class="ActionListItem-label"> Insights </span> </a> </li> </ul> </div></action-list> </div> </div></anchored-position> </focus-group> </action-menu></div> </nav> </div> <turbo-frame id="repo-content-turbo-frame" target="_top" data-turbo-action="advance" class=""> <div id="repo-content-pjax-container" class="repository-content " > <react-app app-name="react-code-view" initial-path="/jax-ml/jax/tree/main/jax/experimental/jax2tf" style="display: block; min-height: calc(100vh - 64px);" data-attempted-ssr="true" data-ssr="true" data-lazy="false" data-alternate="false" data-data-router-enabled="false" > <script type="application/json" 
data-target="react-app.embeddedData">{"payload":{"allShortcutsEnabled":false,"path":"jax/experimental/jax2tf","repo":{"id":154739597,"defaultBranch":"main","name":"jax","ownerLogin":"jax-ml","currentUserCanPush":false,"isFork":false,"isEmpty":false,"createdAt":"2018-10-25T21:25:02.000Z","ownerAvatar":"https://avatars.githubusercontent.com/u/58486408?v=4","public":true,"private":false,"isOrgOwned":true},"currentUser":null,"refInfo":{"name":"main","listCacheKey":"v0:1739784593.0","canEdit":false,"refType":"branch","currentOid":"52f8fbeee0698862c6797ade8878f72301ae10fe"},"tree":{"items":[{"name":"examples","path":"jax/experimental/jax2tf/examples","contentType":"directory"},{"name":"g3doc","path":"jax/experimental/jax2tf/g3doc","contentType":"directory"},{"name":"tests","path":"jax/experimental/jax2tf/tests","contentType":"directory"},{"name":"BUILD","path":"jax/experimental/jax2tf/BUILD","contentType":"file"},{"name":"JAX2TF_getting_started.ipynb","path":"jax/experimental/jax2tf/JAX2TF_getting_started.ipynb","contentType":"file"},{"name":"README.md","path":"jax/experimental/jax2tf/README.md","contentType":"file"},{"name":"__init__.py","path":"jax/experimental/jax2tf/__init__.py","contentType":"file"},{"name":"call_tf.py","path":"jax/experimental/jax2tf/call_tf.py","contentType":"file"},{"name":"impl_no_xla.py","path":"jax/experimental/jax2tf/impl_no_xla.py","contentType":"file"},{"name":"jax2tf.py","path":"jax/experimental/jax2tf/jax2tf.py","contentType":"file"}],"templateDirectorySuggestionUrl":null,"readme":{"displayName":"README.md","richText":"\u003carticle class=\"markdown-body entry-content container-lg\" itemprop=\"text\"\u003e\u003cdiv class=\"markdown-heading\" dir=\"auto\"\u003e\u003ch1 tabindex=\"-1\" class=\"heading-element\" dir=\"auto\"\u003eJAX and TensorFlow interoperation (jax2tf/call_tf)\u003c/h1\u003e\u003ca id=\"user-content-jax-and-tensorflow-interoperation-jax2tfcall_tf\" class=\"anchor\" aria-label=\"Permalink: JAX and TensorFlow interoperation (jax2tf/call_tf)\" href=\"#jax-and-tensorflow-interoperation-jax2tfcall_tf\"\u003e\u003csvg class=\"octicon octicon-link\" viewBox=\"0 0 16 16\" version=\"1.1\" width=\"16\" height=\"16\" aria-hidden=\"true\"\u003e\u003cpath d=\"m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z\"\u003e\u003c/path\u003e\u003c/svg\u003e\u003c/a\u003e\u003c/div\u003e\n\n\n\u003cp dir=\"auto\"\u003eThis package provides support for JAX native serialization and for interoperation\nbetween JAX and TensorFlow.\nThere are two interoperation directions:\u003c/p\u003e\n\u003cul dir=\"auto\"\u003e\n\u003cli\u003e\u003ccode\u003ejax2tf.convert\u003c/code\u003e: for calling JAX functions in a TensorFlow context, e.g.,\nfor eager or graph TensorFlow execution,\nor for serializing as a TensorFlow SavedModel; and\u003c/li\u003e\n\u003cli\u003e\u003ccode\u003ejax2tf.call_tf\u003c/code\u003e: for calling TensorFlow functions in a JAX context, e.g.,\nto call a TensorFlow library or to reload a TensorFlow SavedModel and call\nits functions in JAX.\u003c/li\u003e\n\u003c/ul\u003e\n\u003cp 
dir=\"auto\"\u003eThese APIs can be combined, e.g., to reload in JAX a program that\nhas been serialized from JAX to a TensorFlow SavedModel, or to save to\nTensorFlow SavedModel a JAX program that uses a TensorFlow library.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eTip: As of version 0.4.14 (July 2023) the default mode of JAX-TensorFlow\ninteroperation is by way of \u003cstrong\u003enative serialization\u003c/strong\u003e in which the target\nfunction is lowered to StableHLO\nusing standard native JAX or TensorFlow APIs, and then the StableHLO module\nis invoked from the other framework.\nThe native serialization mode has several advantages:\u003c/p\u003e\n\u003cul dir=\"auto\"\u003e\n\u003cli\u003esupports virtually all operations supported by native execution, e.g.,\n\u003ccode\u003eshard_map\u003c/code\u003e, \u003ccode\u003epmap\u003c/code\u003e, parallel collective operations, and all\nprimitives at all data types.\u003c/li\u003e\n\u003cli\u003euses standard native JAX code paths for lowering, and thus it is easier\nto trust that the semantics and performance stays faithful to the native\nsemantics, across platforms.\u003c/li\u003e\n\u003cli\u003ethe metadata associated with the operations, e.g., source location, is\nidentical to what native execution uses.\u003c/li\u003e\n\u003cli\u003eincludes safety checking that the serialized code is executed on\nthe platform for which it was serialized.\u003c/li\u003e\n\u003c/ul\u003e\n\u003cp dir=\"auto\"\u003eAt the moment when using JAX native serialization the whole\nJAX compilation unit is wrapped with a single thin TensorFlow op,\ncalled \u003ca href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/ops/xla_ops.cc#L1318\"\u003e\u003ccode\u003eXlaCallModule\u003c/code\u003e\u003c/a\u003e,\nthat carries the serialized version of the StableHLO obtained from JAX. This\nop is supported only on TensorFlow platforms that include the XLA compiler, and\nit compiles and then invokes the embedded StableHLO.\nThe reasons we wrap the StableHLO in a TensorFlow op are:\u003c/p\u003e\n\u003cul dir=\"auto\"\u003e\n\u003cli\u003eit allows saving the serialization in a tf.SavedModel, for use with\nmultiple mature tools for TensorFlow,\u003c/li\u003e\n\u003cli\u003eit allows composing the JAX program with TensorFlow pre-processing,\npost-processing, and host callback functions,\u003c/li\u003e\n\u003cli\u003ethe \u003ccode\u003eXlaCallModule\u003c/code\u003e contains the code that must be executed\nto deserialize, compile, and execute the JAX program, e.g., to\nhandle properly backward compatibility and to\ndo the just-in-time preprocessing needed for shape polymorphism.\u003c/li\u003e\n\u003cli\u003ethe semantics of JAX program is still preserved faithfully because it\nis entirely captured by the StableHLO serialization.\u003c/li\u003e\n\u003c/ul\u003e\n\u003cp dir=\"auto\"\u003eFor backwards compatibility purposes, and for special uses,\nthe JAX-TensorFlow interoperation APIs can be used also\nin a \u003cstrong\u003egraph serialization\u003c/strong\u003e mode (the only mode available before version 0.4.7,\nand the default mode before JAX version 0.4.15),\nwithout going through StableHLO. (Starting with JAX version 0.4.31 the\ngraph serialization mode is deprecated. 
It will be removed in the near future).\u003c/p\u003e\n\u003cul dir=\"auto\"\u003e\n\u003cli\u003e\n\u003cp dir=\"auto\"\u003eFor calling JAX functions from TensorFlow,\nit is possible to request that the JAX function be lowered with one TensorFlow\nop for each JAX primitive.\nThis can be achieved by setting \u003ccode\u003enative_serialization=False\u003c/code\u003e.\nThis enables the following:\u003c/p\u003e\n\u003cul dir=\"auto\"\u003e\n\u003cli\u003eTensorFlow eager mode execution, e.g., for debugging,\u003c/li\u003e\n\u003cli\u003eproducing a \u003ccode\u003etf.Graph\u003c/code\u003e for consumption by tooling that understands\nTensorFlow ops but does not yet work with StableHLO,\ne.g., TFLite and TensorFlow.js.\u003c/li\u003e\n\u003cli\u003eusing the more mature support for dynamic shapes in TensorFlow.\n\u003ca href=\"https://github.com/openxla/stablehlo/blob/main/rfcs/20230704-dynamism-101.md\"\u003eStableHLO does have support for dynamic\nshapes\u003c/a\u003e,\nand in the near future we expect it will support shape polymorphism\nto the same extent as graph serialization.\u003c/li\u003e\n\u003c/ul\u003e\n\u003cp dir=\"auto\"\u003eEven in the graph serialization mode the resulting TensorFlow graph\nis pretty much 1:1 with the StableHLO module\nthat would be obtained through native serialization.\u003c/p\u003e\n\u003c/li\u003e\n\u003cli\u003e\n\u003cp dir=\"auto\"\u003eFor calling TensorFlow functions from JAX, if the resulting JAX program\nis executed in op-by-op mode (i.e., not under \u003ccode\u003ejax.jit\u003c/code\u003e or \u003ccode\u003ejax.pmap\u003c/code\u003e\nand not inside \u003ccode\u003elax.cond\u003c/code\u003e or \u003ccode\u003elax.scan\u003c/code\u003e)\nthen the target TensorFlow function is executed in eager mode. This can\nbe useful if the target TensorFlow function is not lowerable to HLO, e.g.,\nis using strings.\u003c/p\u003e\n\u003c/li\u003e\n\u003c/ul\u003e\n\u003cp dir=\"auto\"\u003eTo disable native serialization, you can do the following, in decreasing\npriority order:\u003c/p\u003e\n\u003cul dir=\"auto\"\u003e\n\u003cli\u003eset \u003ccode\u003enative_serialization=False\u003c/code\u003e, or\u003c/li\u003e\n\u003cli\u003euse the configuration flag \u003ccode\u003e--jax2tf_default_native_serialization=false\u003c/code\u003e, or\u003c/li\u003e\n\u003cli\u003euse the environment variable \u003ccode\u003eJAX2TF_DEFAULT_NATIVE_SERIALIZATION=false\u003c/code\u003e.\u003c/li\u003e\n\u003c/ul\u003e\n\u003cp dir=\"auto\"\u003eWe describe below some general concepts and capabilities, first for\n\u003ccode\u003ejax2tf.convert\u003c/code\u003e and \u003ca href=\"#calling-tensorflow-functions-from-jax\"\u003elater\u003c/a\u003e\nfor \u003ccode\u003ejax2tf.call_tf\u003c/code\u003e.\nFor more involved examples, please see examples involving:\u003c/p\u003e\n\u003cul dir=\"auto\"\u003e\n\u003cli\u003eSavedModel for archival (\u003ca href=\"#usage-saved-model\"\u003eexamples below\u003c/a\u003e), including\nsaving \u003ca href=\"#shape-polymorphic-conversion\"\u003ebatch-polymorphic functions\u003c/a\u003e,\u003c/li\u003e\n\u003cli\u003eTensorFlow.js (\u003ca href=\"https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md\"\u003eexamples\u003c/a\u003e),\u003c/li\u003e\n\u003cli\u003eTFX (\u003ca href=\"https://github.com/tensorflow/tfx/blob/master/tfx/examples/penguin/README.md#instructions-for-using-flax\"\u003eexamples\u003c/a\u003e),\u003c/li\u003e\n\u003cli\u003eTensorFlow Hub and Keras (\u003ca 
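For example, a sketch of our own that uses nothing beyond the parameter, flag,
and environment variable from the list above:

```python
import os

# Environment variable (lowest priority); set before JAX reads its config:
os.environ["JAX2TF_DEFAULT_NATIVE_SERIALIZATION"] = "false"

import jax.numpy as jnp
from jax.experimental import jax2tf

def f_jax(x):
  return jnp.sin(x)

# Per-conversion parameter (highest priority):
f_tf = jax2tf.convert(f_jax, native_serialization=False)

# The configuration flag --jax2tf_default_native_serialization=false sits
# between the two and is typically passed on the command line.
```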
href=\"https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md\"\u003eexamples\u003c/a\u003e).\u003c/li\u003e\n\u003c/ul\u003e\n\u003cp dir=\"auto\"\u003e[TOC]\u003c/p\u003e\n\u003cdiv class=\"markdown-heading\" dir=\"auto\"\u003e\u003ch2 tabindex=\"-1\" class=\"heading-element\" dir=\"auto\"\u003eUsage: basic functions.\u003c/h2\u003e\u003ca id=\"user-content-usage-basic-functions\" class=\"anchor\" aria-label=\"Permalink: Usage: basic functions.\" href=\"#usage-basic-functions\"\u003e\u003csvg class=\"octicon octicon-link\" viewBox=\"0 0 16 16\" version=\"1.1\" width=\"16\" height=\"16\" aria-hidden=\"true\"\u003e\u003cpath d=\"m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z\"\u003e\u003c/path\u003e\u003c/svg\u003e\u003c/a\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eAs a rule of thumb, if you can \u003ccode\u003ejax.jit\u003c/code\u003e your function then you should be able\nto use \u003ccode\u003ejax2tf.convert\u003c/code\u003e:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"from jax.experimental import jax2tf\nfrom jax import numpy as jnp\n\nimport numpy as np\nimport tensorflow as tf\n\ndef f_jax(x):\n return jnp.sin(jnp.cos(x))\n\n# jax2tf.convert is a higher-order function that returns a wrapped function with\n# the same signature as your input function but accepting TensorFlow tensors (or\n# variables) as input.\nf_tf = jax2tf.convert(f_jax)\n\n# For example you execute f_tf eagerly with valid TensorFlow inputs:\nf_tf(np.random.random(...))\n\n# Additionally you can use tools like `tf.function` to improve the execution\n# time of your function, or to stage it out to a SavedModel:\nf_tf_graph = tf.function(f_tf, autograph=False)\"\u003e\u003cpre\u003e\u003cspan class=\"pl-k\"\u003efrom\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejax\u003c/span\u003e.\u003cspan class=\"pl-s1\"\u003eexperimental\u003c/span\u003e \u003cspan class=\"pl-k\"\u003eimport\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e\n\u003cspan class=\"pl-k\"\u003efrom\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejax\u003c/span\u003e \u003cspan class=\"pl-k\"\u003eimport\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003enumpy\u003c/span\u003e \u003cspan class=\"pl-k\"\u003eas\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejnp\u003c/span\u003e\n\n\u003cspan class=\"pl-k\"\u003eimport\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003enumpy\u003c/span\u003e \u003cspan class=\"pl-k\"\u003eas\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003enp\u003c/span\u003e\n\u003cspan class=\"pl-k\"\u003eimport\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etensorflow\u003c/span\u003e \u003cspan class=\"pl-k\"\u003eas\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e\n\n\u003cspan class=\"pl-k\"\u003edef\u003c/span\u003e \u003cspan class=\"pl-en\"\u003ef_jax\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e):\n \u003cspan 
class=\"pl-k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejnp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esin\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ejnp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ecos\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e))\n\n\u003cspan class=\"pl-c\"\u003e# jax2tf.convert is a higher-order function that returns a wrapped function with\u003c/span\u003e\n\u003cspan class=\"pl-c\"\u003e# the same signature as your input function but accepting TensorFlow tensors (or\u003c/span\u003e\n\u003cspan class=\"pl-c\"\u003e# variables) as input.\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003ef_tf\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econvert\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ef_jax\u003c/span\u003e)\n\n\u003cspan class=\"pl-c\"\u003e# For example you execute f_tf eagerly with valid TensorFlow inputs:\u003c/span\u003e\n\u003cspan class=\"pl-en\"\u003ef_tf\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003enp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003erandom\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003erandom\u003c/span\u003e(...))\n\n\u003cspan class=\"pl-c\"\u003e# Additionally you can use tools like `tf.function` to improve the execution\u003c/span\u003e\n\u003cspan class=\"pl-c\"\u003e# time of your function, or to stage it out to a SavedModel:\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003ef_tf_graph\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003efunction\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ef_tf\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003eautograph\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eFalse\u003c/span\u003e)\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eNote that when using the default native serialization, the target JAX function\nmust be jittable (see \u003ca href=\"https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html\" rel=\"nofollow\"\u003eJAX - The Sharp Bits\u003c/a\u003e).\nIn the native serialization mode, under TensorFlow eager\nthe whole JAX function executes as one op.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eThe Autograph feature of \u003ccode\u003etf.function\u003c/code\u003e cannot be expected to work on\nfunctions lowered from JAX as above, so it is recommended to\nset \u003ccode\u003eautograph=False\u003c/code\u003e in order to speed up the execution\nand to avoid warnings and outright errors.\u003c/p\u003e\n\u003cdiv class=\"markdown-heading\" dir=\"auto\"\u003e\u003ch2 tabindex=\"-1\" class=\"heading-element\" dir=\"auto\"\u003eUsage: saved model\u003c/h2\u003e\u003ca id=\"user-content-usage-saved-model\" class=\"anchor\" aria-label=\"Permalink: Usage: saved model\" href=\"#usage-saved-model\"\u003e\u003csvg class=\"octicon octicon-link\" viewBox=\"0 0 16 16\" version=\"1.1\" width=\"16\" height=\"16\" aria-hidden=\"true\"\u003e\u003cpath d=\"m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 
You can serialize a JAX program into a TensorFlow SavedModel, for use
with tooling that understands SavedModel. Both in native and non-native
serialization you can count on 6 months of backwards compatibility (you
can load a function serialized today with tooling that will be built
up to 6 months in the future), and 3 weeks of limited forwards compatibility
(you can load a function serialized today with tooling that was built
up to 3 weeks in the past, provided the model does not use any
new features).

Since jax2tf provides a regular TensorFlow function, using it with SavedModel
is trivial:

```python
# You can save the model just like you would with any other TensorFlow function:
my_model = tf.Module()
# Save a function that can take scalar inputs.
my_model.f = tf.function(jax2tf.convert(f_jax), autograph=False,
                         input_signature=[tf.TensorSpec([], tf.float32)])
tf.saved_model.save(my_model, '/some/directory',
                    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))

# Restoring (note: the restored model does *not* require JAX to run, just XLA).
restored_model = tf.saved_model.load('/some/directory')
```
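To check the round trip, a hypothetical follow-up of our own (the restored
function accepts the scalar float32 signature declared above):

```python
# This runs in a process without JAX installed; only TensorFlow/XLA is needed.
print(restored_model.f(tf.constant(0.5, tf.float32)))
```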
class=\"pl-c1\"\u003eSaveOptions\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003eexperimental_custom_gradients\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eTrue\u003c/span\u003e))\n\n\u003cspan class=\"pl-c\"\u003e# Restoring (note: the restored model does *not* require JAX to run, just XLA).\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003erestored_model\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esaved_model\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eload\u003c/span\u003e(\u003cspan class=\"pl-s\"\u003e'/some/directory'\u003c/span\u003e)\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eAn important point is that in the above code snippet \u003cstrong\u003eeverything after the\njax2tf invocation is standard TensorFlow code.\nIn particular, the saving of the model is not directly part\nof the jax2tf API, and the user has full control over how to create the SavedModel\u003c/strong\u003e.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eFor example, just like for regular TensorFlow functions, it is possible to include in the\nSavedModel multiple versions of a function for different input shapes, by\n\"warming up\" the function on different input shapes:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"my_model.f = tf.function(jax2tf.convert(f_jax), autograph=False)\nmy_model.f(tf.ones([1, 28, 28])) # a batch size of 1\nmy_model.f(tf.ones([16, 28, 28])) # a batch size of 16\ntf.saved_model.save(my_model, '/some/directory',\n options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))\"\u003e\u003cpre\u003e\u003cspan class=\"pl-s1\"\u003emy_model\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ef\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003efunction\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econvert\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ef_jax\u003c/span\u003e), \u003cspan class=\"pl-s1\"\u003eautograph\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eFalse\u003c/span\u003e)\n\u003cspan class=\"pl-s1\"\u003emy_model\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ef\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eones\u003c/span\u003e([\u003cspan class=\"pl-c1\"\u003e1\u003c/span\u003e, \u003cspan class=\"pl-c1\"\u003e28\u003c/span\u003e, \u003cspan class=\"pl-c1\"\u003e28\u003c/span\u003e])) \u003cspan class=\"pl-c\"\u003e# a batch size of 1\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003emy_model\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ef\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eones\u003c/span\u003e([\u003cspan class=\"pl-c1\"\u003e16\u003c/span\u003e, \u003cspan class=\"pl-c1\"\u003e28\u003c/span\u003e, \u003cspan class=\"pl-c1\"\u003e28\u003c/span\u003e])) \u003cspan class=\"pl-c\"\u003e# a batch size of 16\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esaved_model\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esave\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003emy_model\u003c/span\u003e, 
\u003cspan class=\"pl-s\"\u003e'/some/directory'\u003c/span\u003e,\n \u003cspan class=\"pl-s1\"\u003eoptions\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esaved_model\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eSaveOptions\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003eexperimental_custom_gradients\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eTrue\u003c/span\u003e))\u003c/pre\u003e\u003c/div\u003e\n\u003cdiv class=\"markdown-heading\" dir=\"auto\"\u003e\u003ch3 tabindex=\"-1\" class=\"heading-element\" dir=\"auto\"\u003eSaved model with parameters\u003c/h3\u003e\u003ca id=\"user-content-saved-model-with-parameters\" class=\"anchor\" aria-label=\"Permalink: Saved model with parameters\" href=\"#saved-model-with-parameters\"\u003e\u003csvg class=\"octicon octicon-link\" viewBox=\"0 0 16 16\" version=\"1.1\" width=\"16\" height=\"16\" aria-hidden=\"true\"\u003e\u003cpath d=\"m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z\"\u003e\u003c/path\u003e\u003c/svg\u003e\u003c/a\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eSome special care is needed to ensure that the model parameters are not embedded\nas constants in the graph and are instead saved separately as variables.\nThis is useful for two reasons:\nthe parameters could be very large and exceed the 2GB limits of the\nGraphDef part of the SavedModel, or you may want to fine-tune the\nmodel and change the value of the parameters.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eFor example, consider the following function:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"def model_jax(inputs):\n return param0 + param1 * inputs\"\u003e\u003cpre\u003e\u003cspan class=\"pl-k\"\u003edef\u003c/span\u003e \u003cspan class=\"pl-en\"\u003emodel_jax\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003einputs\u003c/span\u003e):\n \u003cspan class=\"pl-k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003eparam0\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e+\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003eparam1\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e*\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003einputs\u003c/span\u003e\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eIf you just lower and save the model directly, the values of\n\u003ccode\u003eparam0\u003c/code\u003e and \u003ccode\u003eparam1\u003c/code\u003e will be embedded in the computation graph. In fact, the\nvalue of \u003ccode\u003eparam1\u003c/code\u003e is needed for the gradient computation and\nwill be embedded twice: once in the computation\ngraph for the forward computation and once for the backward computation,\nunless you turn off the staging of gradients or their saving as discussed\nfurther below (e.g., \u003ccode\u003ewith_gradient=False\u003c/code\u003e). 
Note also that if one
views the above function as an ML model parameterized by `param0` and `param1`,
then the gradient function will be w.r.t. the inputs, while you probably
want gradients w.r.t. the parameters.

A better way to deal with parameters (or any large constants) is to
pass them as parameters to the function to be lowered:

```python
def model_jax(params, inputs):
  return params[0] + params[1] * inputs

# Wrap the parameter constants as tf.Variables; this will signal to the model
# saving code to save those constants as variables, separate from the
# computation graph.
params_vars = tf.nest.map_structure(tf.Variable, params)

# Build the prediction function by closing over the `params_vars`. If you
# instead were to close over `params` your SavedModel would have no variables
# and the parameters would be included in the function graph.
prediction_tf = lambda inputs: jax2tf.convert(model_jax)(params_vars, inputs)

my_model = tf.Module()
# Tell the model saver what the variables are.
my_model._variables = tf.nest.flatten(params_vars)
my_model.f = tf.function(prediction_tf, jit_compile=True, autograph=False)
tf.saved_model.save(my_model, '/some/directory')
```
This strategy will avoid any copies of the large parameters in the computation
graph (they will be saved in a `variables` area of the model, which is not
subject to the 2GB limitation).

For examples of how to save a Flax model as a SavedModel see the
[examples directory](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md).
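Because the parameters are now `tf.Variable`s, fine-tuning them from TensorFlow
becomes possible. A minimal sketch of our own, continuing the snippet above
(the data and learning rate are made up for illustration):

```python
import numpy as np
import tensorflow as tf

# Hypothetical training batch.
x_batch = np.arange(4.0, dtype=np.float32)
y_batch = 2.0 * x_batch + 1.0

variables = tf.nest.flatten(params_vars)
with tf.GradientTape() as tape:
  loss = tf.reduce_mean((prediction_tf(x_batch) - y_batch) ** 2)
grads = tape.gradient(loss, variables)

# Plain SGD update; the gradient flows through the converted function back
# to the tf.Variables.
for v, g in zip(variables, grads):
  v.assign_sub(0.1 * g)
```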
### Saved model and differentiation

The code lowered from JAX supports differentiation from TensorFlow. In order to
ensure that the result of TensorFlow differentiation is identical to the
one that JAX differentiation would produce, we
annotate the lowered primal function with a `tf.custom_gradient` that,
upon TensorFlow differentiation, will lazily
call into JAX to compute the `jax.vjp` of the lowered primal function, followed by
jax2tf lowering of the gradient function.
This ensures that ultimately it is JAX that performs the
differentiation, thus respecting any custom gradients that may be present
in the original function.

The `jax2tf.convert` function has an option `with_gradient=False` to skip the
custom gradients and instead wrap the lowered function with
`tf.raw_ops.PreventGradient`, to generate an error in case a gradient
computation is attempted.

SavedModel enables saving custom derivative rules by using the
`experimental_custom_gradients` option:

```python
options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
tf.saved_model.save(model, path, options=options)
```
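From the TensorFlow side this machinery is invisible; differentiating a
converted function looks like ordinary TensorFlow differentiation. A short
sketch of our own:

```python
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

f_tf = jax2tf.convert(lambda x: jnp.sin(x) ** 2)

x = tf.Variable(0.7)
with tf.GradientTape() as tape:
  y = f_tf(x)
# Under the hood this calls back into JAX (jax.vjp) for the gradient.
print(tape.gradient(y, x))
```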
SavedModel enables saving custom derivative rules by using the `experimental_custom_gradients` option:

```python
options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
tf.saved_model.save(model, path, options=options)
```

If you use `with_gradient=True` but forget to pass
`experimental_custom_gradients=True` to `tf.saved_model.save`, then when you
later load the saved model you will see a warning:

```
WARNING:absl:Importing a function (__inference_converted_fun_25) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
```

and if you do attempt to take a gradient of the loaded model you may get an error:

```
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: args_0:0
```

(We are working with the TF team to give a more explicit error in this case.)

### Saved model for non-differentiable JAX functions

Note that if the JAX function is not reverse-mode differentiable, e.g., it uses
`lax.while_loop`, then attempting to save its conversion to a SavedModel will
fail with:

```
ValueError: Error when tracing gradients for SavedModel
```

You have two options: either pass `with_gradient=False` to `jax2tf.convert`, or
set `tf.saved_model.SaveOptions(experimental_custom_gradients=False)`. In
either case, you will not be able to compute the gradients of the function
loaded from the SavedModel.
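For instance, the following sketch uses the first option (the
`lax.while_loop`-based function and the save path are hypothetical):

```python
import tensorflow as tf
from jax import lax
from jax.experimental import jax2tf

def count_down(x):
  # Not reverse-mode differentiable because it uses lax.while_loop.
  return lax.while_loop(lambda v: v > 0, lambda v: v - 1, x)

my_model = tf.Module()
my_model.f = tf.function(jax2tf.convert(count_down, with_gradient=False),
                         autograph=False,
                         input_signature=[tf.TensorSpec([], tf.int32)])
tf.saved_model.save(my_model, '/some/directory')
```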
## Support for partitioning

jax2tf supports JAX functions that use `jax.pjit` and `jax.jit` with sharded
arguments and results, for single-host meshes. The lowering is similar to that
for `jax.jit`, except that the arguments and results are wrapped with
`tensorflow.python.compiler.xla.experimental.xla_sharding.XlaSharding` TensorFlow ops.

In the default native serialization mode, if the target JAX function includes
sharding operations, e.g., from nested `jax.pjit`, then there should be a
top-level `jax.pjit`. E.g.,
```python
# The following is correct
with mesh:
  jax2tf.convert(pjit.pjit(f_jax, in_shardings=...))(...)

# The following will lead to errors because pjit is not at the top level.
def wrapped_pjit(x):
  ...pjit.pjit(f_jax, in_shardings=...)...

with mesh:
  jax2tf.convert(wrapped_pjit)
```
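As a more concrete sketch of the correct pattern (assumptions: at least two
devices are available, and `f_jax` is a trivial placeholder):

```python
import numpy as np
import jax
from jax.experimental import jax2tf, pjit
from jax.sharding import Mesh, PartitionSpec as P

# Hypothetical: shard the argument across two devices along mesh axis "x".
mesh = Mesh(np.asarray(jax.devices()[:2]), ("x",))
f_jax = lambda a: a * 2.0  # hypothetical function

with mesh:
  f_tf = jax2tf.convert(
      pjit.pjit(f_jax, in_shardings=P("x"), out_shardings=P("x")))
```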
A limitation of `XlaSharding` is that it cannot be used in TensorFlow eager
mode. Therefore, `jax2tf` will give an error when lowering a function that
requires sharded (not replicated) arguments or results and the lowered
function is used outside a `tf.function` context (see b/255511660).

Another limitation is that today only TPUs have integrated XLA SPMD support in
serving, while CPUs and GPUs do not yet have end-to-end XLA SPMD support in
TensorFlow. Executing a jax2tf-converted `tf.function` with `XlaSharding` ops
on CPUs and GPUs will simply ignore all the `XlaSharding` ops.

Note that when saving a model, the parameters to the model are wrapped with
`tf.Variable` before calling the lowered function (see [above](#saved_model_with_parameters)),
and are therefore outside of the `XlaSharding` wrapper.

## Shape-polymorphic conversion

**The shape polymorphism support is work in progress.
Please report any bugs you encounter.**

We described above how to include in the SavedModel several specializations
of a lowered function for a few specific input shapes. `jax2tf` can
also produce a shape-polymorphic TensorFlow graph that is usable with inputs
of any shape matching certain constraints.
This is useful, e.g., to allow a single SavedModel to be used for multiple
batch sizes.

The standard TensorFlow technique for producing a shape-polymorphic graph is
to warm up the `tf.function` on partially-specified (shape-polymorphic) inputs, e.g.,
`tf.TensorSpec([None, 28, 28], tf.float32)` for a function that processes a
batch (of unspecified batch size) of 28x28 images.
For jax2tf it is **additionally** necessary to specify the `polymorphic_shapes`
parameter of the `jax2tf.convert` function:

```python
f_tf = tf.function(jax2tf.convert(f_jax,
                                  polymorphic_shapes=["(b, 28, 28)"]),
                   autograph=False)
f_tf.get_concrete_function(tf.TensorSpec([None, 28, 28], tf.float32))
```

The `polymorphic_shapes` parameter, in the form of a pytree of strings
corresponding to the pytree of positional arguments, introduces one or more
dimension variables, e.g., `b`, to stand for shape dimensions that are assumed
to be unknown at JAX tracing time. Dimension variables are assumed to range
over all integers that are greater than or equal to 1. In this particular
example, we can also abbreviate to `polymorphic_shapes=["(b, _, _)"]`,
because the `_` placeholders take their value from the corresponding dimension
of the `tf.TensorSpec` (which must be known). As a further shortcut for a
series of `_` at the end of a shape specification you can use `...`:
`polymorphic_shapes=["(b, ...)"]`.
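For instance (same `f_jax` as in the snippet above), the following
specifications are equivalent when combined with
`tf.TensorSpec([None, 28, 28], tf.float32)`:

```python
jax2tf.convert(f_jax, polymorphic_shapes=["(b, 28, 28)"])
jax2tf.convert(f_jax, polymorphic_shapes=["(b, _, _)"])
jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
```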
In the example above, the `polymorphic_shapes` specification does not convey
more information than the partial `tf.TensorSpec`, except that it gives a name
to the unknown dimension, which improves error messages. The real need for
named shape variables arises when there are multiple unknown dimensions and
there is a relationship between them. For example, if the function to be
lowered is also polymorphic in the size of each image while requiring the
images to be square, we would add a dimension variable `d` to stand for the
unknown image size:

```python
f_tf = tf.function(jax2tf.convert(f_jax, polymorphic_shapes=["(b, d, d)"]), autograph=False)
f_tf.get_concrete_function(tf.TensorSpec([None, None, None], tf.float32))
```

The JAX tracing mechanism performs shape checking using the same strict rules as
when the shapes are fully known.
For example, given the `"(b, d, d)"` specification for the argument `x` of a
function, JAX will know that a conditional `x.shape[-2] == x.shape[-1]` is
`True`, and will also know that `x` and `jnp.sin(x)` have the same shape of a
batch of square matrices that can be passed to `jnp.matmul`.
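A small sketch of this (the function is hypothetical; `np` and `jnp` are the
usual NumPy and `jax.numpy` imports):

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

# Both operands have the symbolic shape (b, d, d), so the contracting
# dimensions match and the shape check succeeds for any square matrices.
f = lambda x: jnp.matmul(x, jnp.sin(x))
jax2tf.convert(f, polymorphic_shapes=["(b, d, d)"])(np.ones((2, 3, 3)))
```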
### Correctness of shape-polymorphic tracing

We want to trust that the lowered program produces the same results as the
original JAX program. More precisely:

For any function `f_jax` and any input signature `abs_sig` containing partially
known `tf.TensorSpec`, and any concrete input `x` whose shape matches `abs_sig`:

* If the conversion to TensorFlow succeeds: `f_tf = tf.function(jax2tf.convert(f_jax, polymorphic_shapes)).get_concrete_function(abs_sig)`
* and if the TensorFlow execution succeeds with result `y`: `f_tf(x) = y`
* then the JAX execution would produce the same result: `f_jax(x) = y`.

It is crucial to understand that `f_jax(x)` has the freedom to re-invoke the
JAX tracing machinery, and in fact it does so for each distinct concrete input
shape, while the generation of `f_tf` uses JAX tracing only once, and invoking
`f_tf(x)` does not use JAX tracing anymore. In fact, the latter invocation may
happen after `f_tf` has been serialized to a SavedModel and reloaded in an
environment where `f_jax` and the JAX tracing machinery are not available
anymore.

### Coverage of shape-polymorphic tracing

Besides correctness, a secondary goal is to be able to lower many
shape-polymorphic programs, but at the very least batch-size-polymorphic
programs, so that one SavedModel can be used for any batch size. For example,
we want to ensure that any function written using `jax.vmap` at the top level
can be lowered with the batch dimension polymorphic and the remaining
dimensions concrete.
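For instance, a sketch with a hypothetical per-example function:

```python
import numpy as np
import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

per_example = lambda x: jnp.sum(jnp.tanh(x))  # hypothetical
batched = jax.vmap(per_example)               # vmap at the top level

f_tf = tf.function(jax2tf.convert(batched, polymorphic_shapes=["(b, 16)"]),
                   autograph=False)
f_tf.get_concrete_function(tf.TensorSpec([None, 16], tf.float32))
```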
It is reasonable to expect that there will be JAX programs for which there is a
shape-polymorphic TensorFlow graph, but which will give an error when lowering
with jax2tf. In general, you should expect that shape polymorphism can handle
those programs for which all the intermediate shapes can be expressed as simple
expressions in the dimension variables appearing in the input shapes. In
particular, this does not apply to programs whose intermediate shapes depend on
the data.

### Details

In order to use shape polymorphism effectively with jax2tf, it is worth
considering what happens under the hood. When the lowered function is invoked
with a `TensorSpec`, `jax2tf` will use the `polymorphic_shapes` parameter to
obtain a shape abstraction for the inputs. The dimension sizes from the
`TensorSpec` are used to fill in the `_` and `...` placeholders from
`polymorphic_shapes`. Normally, the shape abstraction contains the dimension
sizes, but in the presence of shape polymorphism, some dimensions may be
dimension variables.

The `polymorphic_shapes` parameter must be either `None`, or a pytree of shape
specifiers corresponding to the pytree of arguments. (A value `None` for
`polymorphic_shapes` is equivalent to a list of `None`. See
[how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).)
A shape specifier is combined with a `TensorSpec` as follows:

* A shape specifier of `None` means that the shape is given by the actual
  argument `TensorSpec`, which must be fully known.

* Otherwise, the specifier must be a comma-separated string of dimension
  specifiers: `(dim_1, ..., dim_n)`, denoting an n-dimensional array. The
  `TensorSpec` must also be of rank `n`. An `...` at the end of the shape
  specifier is expanded to a list of `_` of the appropriate length. The
  corresponding dimensions from the shape specifier and the `TensorSpec` are
  matched:

  * the dimension specifier `_` means that the size of the dimension is given
    by the actual `TensorSpec`, which must have a known size in the
    corresponding dimension.
  * a dimension specifier can also be a lowercase identifier, denoting a
    dimension-size variable ranging over strictly positive integers. The
    abstract value of the dimension is going to be set to this variable. The
    corresponding dimension in the `TensorSpec` can be `None` or can be a
    constant.
  * All occurrences of a dimension variable in any dimension for any argument
    are assumed to be equal.
Note that `polymorphic_shapes` controls the shape abstraction used by JAX when
tracing the function. The `TensorSpec` gives the shape abstraction that
TensorFlow will associate with the produced graph, and can be more specific.

A few examples of shape specifications and uses:

* `polymorphic_shapes=["(b, _, _)", None]` can be used for a function with two
  arguments, the first having a batch leading dimension that should be
  polymorphic. The other dimensions for the first argument and the shape of the
  second argument are specialized based on the actual `TensorSpec`, which must
  be known. The lowered function can be used, e.g., with `TensorSpec`s
  `[None, 28, 28]` and `[28, 16]` for the first and second argument
  respectively (see the sketch after this list). An alternative `TensorSpec`
  pair can be `[1, 28, 28]` and `[28, 16]`, in which case the JAX tracing is
  done for the same polymorphic shape given by
  `polymorphic_shapes=["(b, 28, 28)", "(28, 16)"]`.

* `polymorphic_shapes=["(batch, _)", "(batch,)"]`: the leading dimensions of
  the two arguments must match, and are assumed to be greater than or equal
  to 1. The second dimension of the first argument is taken from the actual
  `TensorSpec`. This can be used with a `TensorSpec` pair `[None, 16]` and
  `[None]`. It can also be used with a pair of shapes `[8, 16]` and `[8]`.
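A sketch of the first example (the function is hypothetical; the shapes follow
the description above):

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

f_jax = lambda images, w: jnp.einsum("bhw,wf->bhf", images, w)  # hypothetical
f_tf = tf.function(
    jax2tf.convert(f_jax, polymorphic_shapes=["(b, _, _)", None]),
    autograph=False)
f_tf.get_concrete_function(tf.TensorSpec([None, 28, 28], tf.float32),
                           tf.TensorSpec([28, 16], tf.float32))
```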
### Computing with dimension variables

JAX keeps track of the shape of all intermediate results. When those shapes
depend on dimension variables JAX computes them as symbolic expressions
involving dimension variables. The symbolic expressions can represent the
result of applying arithmetic operators (add, sub, mul, floordiv, mod,
including the NumPy variants `np.sum`, `np.prod`, etc.) **on dimension
variables and integers** (`int`, `np.int`, or anything convertible by
`operator.index`). These symbolic dimensions can then be used in
shape parameters of JAX primitives and APIs, e.g., in `jnp.reshape`,
`jnp.arange`, slicing indices, etc.

For example, in the following code to flatten a 2D array, the computation
`x.shape[0] * x.shape[1]` computes the symbolic dimension `4 * b` as the
new shape:

```python
jax2tf.convert(lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],)),
               polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
```

When a symbolic dimension is used in **arithmetic operations with non-integers**,
e.g., `float`, `np.float`, `np.ndarray`, or JAX arrays, it is automatically
converted to a JAX array using `jnp.array`. For example, in the function below
all occurrences of `x.shape[0]` are converted implicitly to
`jnp.array(x.shape[0])` because they are involved in operations with
non-integer scalars or with JAX arrays:

```python
jax2tf.convert(lambda x: (x + x.shape[0] + jnp.sin(x.shape[0]),
                          5. + x.shape[0],
                          x.shape[0] - np.ones((5,), dtype=np.int32)),
               polymorphic_shapes=["b"])(np.ones(3))
```
Another typical example is when computing averages:

```python
jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0],
               polymorphic_shapes=["(v, _)"])(np.ones((3, 4)))
```
class=\"pl-c1\"\u003eones\u003c/span\u003e((\u003cspan class=\"pl-c1\"\u003e3\u003c/span\u003e, \u003cspan class=\"pl-c1\"\u003e4\u003c/span\u003e)))\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eIt is also possible to convert dimension polynomials explicitly\nto JAX arrays, with \u003ccode\u003ejnp.array(x.shape[0])\u003c/code\u003e or even \u003ccode\u003ejnp.array(x.shape)\u003c/code\u003e.\nThe result of these operations\ncannot be used anymore as dimension parameters and will raise a JAX error.\u003c/p\u003e\n\u003cdiv class=\"markdown-heading\" dir=\"auto\"\u003e\u003ch3 tabindex=\"-1\" class=\"heading-element\" dir=\"auto\"\u003eErrors in presence of shape polymorphism\u003c/h3\u003e\u003ca id=\"user-content-errors-in-presence-of-shape-polymorphism\" class=\"anchor\" aria-label=\"Permalink: Errors in presence of shape polymorphism\" href=\"#errors-in-presence-of-shape-polymorphism\"\u003e\u003csvg class=\"octicon octicon-link\" viewBox=\"0 0 16 16\" version=\"1.1\" width=\"16\" height=\"16\" aria-hidden=\"true\"\u003e\u003cpath d=\"m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z\"\u003e\u003c/path\u003e\u003c/svg\u003e\u003c/a\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eMost JAX code assumes that the shapes of JAX arrays are tuples of integers,\nbut with shape polymorphism some dimensions may be symbolic expressions.\nThis can lead to a number of errors. 
For example, the program:

```python
four_ones = np.ones((4,))
jax2tf.convert(lambda x, y: x + y,
               polymorphic_shapes=["(v,)", "(4,)"])(four_ones, four_ones)
```

will result in the error `'add got incompatible shapes for broadcasting: (v,), (4,)'`
because the shape abstraction that JAX tracing uses is given by the
`polymorphic_shapes`, even though the actual arguments are more specific and
would actually work.

Also,

```python
jax2tf.convert(lambda x: jnp.matmul(x, x),
               polymorphic_shapes=["(v, 4)"])(np.ones((4, 4)))
```

will result in the error `dot_general requires contracting dimensions to have the same shape, got [4] and [v]`.
What is happening here is that in the process of type checking the `matmul`
operation, JAX will want to ensure the size of the two contracting axes is the
same (`v == 4`). Note that `v` can stand for any integer greater than 0, so the
value of the equality expression can be true or false. Since it is not always
true that `v == 4`, the shape checking rules fail with the above error. Since
the lowered function works only for square matrices, the correct
`polymorphic_shapes` is `["(v, v)"]`.
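With the square-matrix specification the same call goes through:

```python
# Both contracting dimensions are the same symbolic size `v`,
# so the shape check succeeds.
jax2tf.convert(lambda x: jnp.matmul(x, x),
               polymorphic_shapes=["(v, v)"])(np.ones((4, 4)))
```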
As explained above, if the dimension polynomials are used in operations with
non-integers, the result will be a JAX array that cannot be used as a shape
parameter. For example, if we modify the reshape example slightly, to use
`np.array([x.shape[1]])` instead of `x.shape[1]`:

```python
jax2tf.convert(lambda x: jnp.reshape(x, (x.shape[0] * np.array([x.shape[1]]),)),
               polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
```

we get an error `Shapes must be 1D sequences of concrete values of integer type, got Traced<...>`.
If you get this error on JAX code that works for static shapes, it means that
one operation that computes shape parameters is using non-integer arguments,
e.g., `np.ndarray`, that get implicitly converted to JAX arrays. The solution
is to avoid `np.array`, `float`, or JAX arrays in operations whose results are
used as shapes, e.g., instead of `np.arange(n) * x.shape[0]` write
`[i * x.shape[0] for i in range(n)]`.
class=\"heading-element\" dir=\"auto\"\u003eDimension variables must be solvable from the input shapes\u003c/h3\u003e\u003ca id=\"user-content-dimension-variables-must-be-solvable-from-the-input-shapes\" class=\"anchor\" aria-label=\"Permalink: Dimension variables must be solvable from the input shapes\" href=\"#dimension-variables-must-be-solvable-from-the-input-shapes\"\u003e\u003csvg class=\"octicon octicon-link\" viewBox=\"0 0 16 16\" version=\"1.1\" width=\"16\" height=\"16\" aria-hidden=\"true\"\u003e\u003cpath d=\"m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z\"\u003e\u003c/path\u003e\u003c/svg\u003e\u003c/a\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eJAX will generate code to derive the values of the dimension variables\nfrom the input shapes. This works only if the symbolic dimensions in the input shapes are linear.\nFor example, the following \u003ccode\u003epolymorphic_shapes\u003c/code\u003e will result in errors:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"polymorphic_shapes = [\u0026quot;a * a\u0026quot;] # Not a linear polynomial\npolymorphic_shapes = [\u0026quot;a + b\u0026quot;] # Too few equations to derive both `a` and `b`\"\u003e\u003cpre\u003e\u003cspan class=\"pl-s1\"\u003epolymorphic_shapes\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e [\u003cspan class=\"pl-s\"\u003e\"a * a\"\u003c/span\u003e] \u003cspan class=\"pl-c\"\u003e# Not a linear polynomial\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003epolymorphic_shapes\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e [\u003cspan class=\"pl-s\"\u003e\"a + b\"\u003c/span\u003e] \u003cspan class=\"pl-c\"\u003e# Too few equations to derive both `a` and `b`\u003c/span\u003e\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eThe error message for the last specification above would be:\u003c/p\u003e\n\u003cdiv class=\"snippet-clipboard-content notranslate position-relative overflow-auto\" data-snippet-clipboard-copy-content=\"Cannot solve for values of dimension variables {'a', 'b'}. \u0026quot;\nWe can only solve linear uni-variate constraints. \u0026quot;\nUsing the following polymorphic shapes specifications: args[0].shape = (a + b,).\nUnprocessed specifications: 'a + b' for dimension size args[0].shape[0]. \u0026quot;\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.\"\u003e\u003cpre class=\"notranslate\"\u003e\u003ccode\u003eCannot solve for values of dimension variables {'a', 'b'}. \"\nWe can only solve linear uni-variate constraints. \"\nUsing the following polymorphic shapes specifications: args[0].shape = (a + b,).\nUnprocessed specifications: 'a + b' for dimension size args[0].shape[0]. 
\"\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\n\u003cdiv class=\"markdown-heading\" dir=\"auto\"\u003e\u003ch3 tabindex=\"-1\" class=\"heading-element\" dir=\"auto\"\u003eShape assertion errors\u003c/h3\u003e\u003ca id=\"user-content-shape-assertion-errors\" class=\"anchor\" aria-label=\"Permalink: Shape assertion errors\" href=\"#shape-assertion-errors\"\u003e\u003csvg class=\"octicon octicon-link\" viewBox=\"0 0 16 16\" version=\"1.1\" width=\"16\" height=\"16\" aria-hidden=\"true\"\u003e\u003cpath d=\"m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z\"\u003e\u003c/path\u003e\u003c/svg\u003e\u003c/a\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eJAX assumes that dimension variables range over strictly positive integers.\nStarting with serialization version 7 these assumptions are\nchecked against the shapes of the actual arguments\nwhen the lowered code is invoked.\nFor example, given the \u003ccode\u003epolymorphic_shapes=\"(b, b, 2*d)\"\u003c/code\u003e\nspecification, we will generate code to check the following constraints when\ninvoked with actual argument \u003ccode\u003earg\u003c/code\u003e:\u003c/p\u003e\n\u003cul dir=\"auto\"\u003e\n\u003cli\u003e\u003ccode\u003earg.shape[0] \u0026gt;= 1\u003c/code\u003e\u003c/li\u003e\n\u003cli\u003e\u003ccode\u003earg.shape[1] == arg.shape[0]\u003c/code\u003e\u003c/li\u003e\n\u003cli\u003e\u003ccode\u003earg.shape[2] % 2 == 0\u003c/code\u003e\u003c/li\u003e\n\u003cli\u003e\u003ccode\u003earg.shape[2] // 2 \u0026gt;= 1\u003c/code\u003e\u003c/li\u003e\n\u003c/ul\u003e\n\u003cp dir=\"auto\"\u003eAn example error for the third constraint above, e.g., when invoked with\nshape \u003ccode\u003e(3, 3, 5)\u003c/code\u003e, would be:\u003c/p\u003e\n\u003cdiv class=\"snippet-clipboard-content notranslate position-relative overflow-auto\" data-snippet-clipboard-copy-content=\"Input shapes do not match the polymorphic shapes specification.\nDivision had remainder 1 when computing the value of 'd'.\nUsing the following polymorphic shapes specifications: args[0].shape = (b, b, 2*d).\nObtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3).\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.\"\u003e\u003cpre class=\"notranslate\"\u003e\u003ccode\u003eInput shapes do not match the polymorphic shapes specification.\nDivision had remainder 1 when computing the value of 'd'.\nUsing the following polymorphic shapes specifications: args[0].shape = (b, b, 2*d).\nObtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3).\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eWhen using native serialization these are checked by the 
When using native serialization these are checked by the `tf.XlaCallModule` op
(starting with serialization
[version 7](https://github.com/search?q=repo%3Agoogle%2Fjax+path%3Aconfig.py+jax_serialization_version&type=code)),
and you will get `tf.errors.InvalidArgument` errors. You can disable this
checking by including `DisabledSafetyCheck.shape_assertions()` in the
`disabled_checks` parameter to `jax2tf.convert`, or by setting the environment
variable `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=shape_assertions`.
When using graph serialization these are checked using `tf.debugging.assert`,
which will also result in `tf.errors.InvalidArgument`. Note that due to
limitations in TensorFlow, these errors are suppressed when using
`jit_compile=True` and when running on TPU.
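A sketch of disabling the check from Python (`f_jax` is hypothetical):

```python
from jax.experimental import jax2tf

f_tf = jax2tf.convert(
    f_jax,
    polymorphic_shapes=["(b, b, 2*d)"],
    disabled_checks=[jax2tf.DisabledSafetyCheck.shape_assertions()])
```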
### Comparison of symbolic dimensions is partially supported

Inside JAX there are a number of equality and inequality comparisons involving
shapes, e.g., for doing shape checking or even for choosing the implementation
for some primitives. Comparisons are supported as follows:

* equality is supported with a caveat: if the two symbolic dimensions denote
  the same value under all valuations for dimension variables, then equality
  evaluates to `True`, e.g., for `b + b == 2*b`; otherwise the equality
  evaluates to `False`. See below for a discussion of important consequences of
  this behavior.
* disequality is always the negation of equality.
* inequality is partially supported, in a similar way as partial equality.
  However, in this case we take into consideration that dimension variables
  range over strictly positive integers. E.g., `b >= 1`, `b >= 0`,
  `2 * a + b >= 3` are `True`, while `b >= 2`, `a >= b`, `a - b >= 0` are
  inconclusive and result in an exception.

For example, the following code raises the exception
`core.InconclusiveDimensionOperation` with the message
`Dimension polynomial comparison 'a + 1' >= 'b' is inconclusive`.

```python
jax2tf.convert(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1,
               polymorphic_shapes=["(a, b)"])(np.ones((3, 4)))
```

If you do get a `core.InconclusiveDimensionOperation`, you can try several
strategies (see the sketch after this list):

* If your code uses the built-in `max` or `min`, or `np.max` or `np.min`, then
  you can replace those with `core.max_dim` and `core.min_dim`, which have the
  effect of delaying the inequality comparison to compilation time, when shapes
  become known.
* Try to rewrite conditionals using `core.max_dim` and `core.min_dim`, e.g.,
  instead of `d if d > 0 else 0` you can write `core.max_dim(d, 0)`.
* Try to rewrite the code to be less dependent on the fact that dimensions
  should be integers, and rely on the fact that symbolic dimensions duck-type
  as integers for most arithmetic operations. E.g., instead of `int(d) + 5`
  write `d + 5`.
* Specify symbolic constraints, as explained below.
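For instance, a sketch of the first two strategies (assuming `core` refers to
`jax.core`):

```python
import numpy as np
from jax import core
from jax.experimental import jax2tf

# `min(x.shape[0], 4)` would be inconclusive for a symbolic `b`;
# core.min_dim delays the comparison until shapes are known.
f = lambda x: x[: core.min_dim(x.shape[0], 4)]
jax2tf.convert(f, polymorphic_shapes=["(b, _)"])(np.ones((3, 8)))
```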
For example, the following code raises the exception
`core.InconclusiveDimensionOperation` with the message
`Dimension polynomial comparison 'a + 1' >= 'b' is inconclusive`.

```python
jax2tf.convert(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1,
               polymorphic_shapes=["(a, b)"])(np.ones((3, 4)))
```

If you do get a `core.InconclusiveDimensionOperation`, you can try
several strategies:

  * If your code uses the built-in `max` or `min`, or the
    `np.max` or `np.min`, then you can replace those with
    `core.max_dim` and `core.min_dim`, which have the effect
    of delaying the inequality comparison to compilation
    time, when the shapes become known (see the sketch after this list).
  * Try to rewrite conditionals using `core.max_dim` and
    `core.min_dim`, e.g., instead of `d if d > 0 else 0`
    you can write `core.max_dim(d, 0)`.
  * Try to rewrite the code to be less dependent on the fact
    that dimensions should be integers, and rely on the fact
    that symbolic dimensions duck-type as integers for most
    arithmetic operations. E.g., instead of `int(d) + 5` write
    `d + 5`.
  * Specify symbolic constraints, as explained below.
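A minimal sketch of the `core.max_dim` strategy, assuming the
`export.symbolic_shape` API described below for constructing symbolic
dimensions outside of `jax2tf.convert`:

```python
from jax import core
from jax.experimental import export

b, = export.symbolic_shape("b")

# The built-in max needs `b >= 16` to be decidable, so it raises
# core.InconclusiveDimensionOperation ...
try:
  d = max(b, 16)
except core.InconclusiveDimensionOperation:
  # ... while core.max_dim builds a symbolic expression whose value is
  # resolved at compilation time, when the shapes become known.
  d = core.max_dim(b, 16)
```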
#### User-specified symbolic constraints

By default, JAX assumes that all dimension variables range
over values greater-or-equal to 1, and it tries to derive
other simple inequalities from that, e.g.:

  * `a + 2 >= 3`,
  * `a * 2 >= 1`,
  * `a + b + c >= 3`,
  * `a // 4 >= 0`, `a**2 >= 1`, and so on.

You can avoid some inequality comparison failures if you
change the symbolic shape specifications to add implicit constraints
for dimension sizes. E.g.,

  * You can use `2*b` for a dimension to constrain it to be even (and `>= 2`).
  * You can use `b + 15` for a dimension to constrain it to
    be at least 16. E.g., the following code would fail without
    the `+ 15` part, because JAX will want to verify that slice sizes
    are at most as large as the axis size.

```python
jax2tf.convert(lambda x: x[0:16],
               polymorphic_shapes="b + 15, ...")
```

Such implicit symbolic constraints are used for reasoning, and are
checked at compile time, as explained [above](#shape-assertion-errors).
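To see the compile-time check in action, here is a sketch: with an axis of
size 8 the specification `b + 15` cannot hold (since `b >= 1`, the axis must
be at least 16), so one would expect the shape assertion to fail with
`tf.errors.InvalidArgument` under native serialization, as described above:

```python
import numpy as np
from jax.experimental import jax2tf

f_tf = jax2tf.convert(lambda x: x[0:16],
                      polymorphic_shapes="b + 15, ...")
f_tf(np.ones((20,), dtype=np.float32))  # OK: b = 5
# With an axis of size 8 there is no b >= 1 with b + 15 == 8, so we
# expect the compile-time shape assertion to fail.
f_tf(np.ones((8,), dtype=np.float32))
```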
Starting with JAX version 0.4.24 you can also specify explicit
symbolic constraints:

```python
jax2tf.convert(lambda x: x[:x.shape[1], :16],
               polymorphic_shapes="(a, b)",
               polymorphic_constraints=("a >= b", "b >= 16"))
```

The constraints form a conjunction together with the implicit
constraints. You can specify `>=`, `<=`, and `==` constraints.
At the moment, JAX has limited support for reasoning with
symbolic constraints:

  * You get the most from constraints of the form
    of a variable being greater-or-equal or
    less-or-equal to a constant.
    For example, from the constraints that
    `a >= 16` and `b >= 8` we can infer
    that `a + 2*b >= 32` (see the sketch after this list).
  * You get limited power when the constraint involves
    more complex expressions, e.g., from `a >= b + 8` we
    can infer that `a - b >= 8` but not that `a >= 9`.
    We plan to improve this area somewhat in the future.
  * Equality constraints are treated as normalization rules.
    E.g., `floordiv(a, b) == c` works by replacing all
    occurrences of the left-hand side with the right-hand side.
    You can only have equality constraints where the left-hand side
    is a multiplication of factors, e.g., `a * b`, or `4 * a`, or
    `floordiv(a, b)`. Thus, the left-hand side cannot contain
    addition or subtraction at the top level.
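A minimal sketch of constraint-based inference, using `export.symbolic_shape`
as elsewhere in this section:

```python
from jax import core
from jax.experimental import export

a, b = export.symbolic_shape("a, b",
                             constraints=("a >= 16", "b >= 8"))

print(a + 2 * b >= 32)  # True: inferred from the two constraints
try:
  print(a >= b)         # not implied by the constraints
except core.InconclusiveDimensionOperation:
  print("inconclusive")
```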
The symbolic constraints can also help to work around the
limitations in the JAX reasoning mechanisms. For example, the following
code would not be able to prove that the slice size fits
into the axis size (such examples come up when using
striding):

```python
jax2tf.convert(lambda x: x[: 4*(x.shape[0] // 4)],
               polymorphic_shapes=("b, ...",))
```

You will likely see an error that the comparison
`b >= 4*floordiv(b, 4)` is inconclusive, even though
the inequality always holds when `b >= 1`. One option
here would be to restrict the code to work only on
axis sizes that are a multiple of `4` (by replacing
`b` with `4*b` in the shape specification);
another option is to add a symbolic constraint
with the exact inconclusive inequality:

```python
jax2tf.convert(lambda x: x[: 4*(x.shape[0] // 4)],
               polymorphic_shapes=("b, ...",),
               polymorphic_constraints=("b >= 4*floordiv(b, 4)",))
```

An example where an equality constraint would be useful
is in the following code:

```python
jax2tf.convert(lambda x, y: x + y[:y.shape[0] // 2],
               polymorphic_shapes=("a", "b"))(x, y)
```

The above code would raise a `TypeError` because JAX cannot verify that
`x` and `y[:y.shape[0] // 2]` have the same shape:
```
TypeError: add got incompatible shapes for broadcasting: (a,), (floordiv(b, 2),)
```

You can fix this by adding a constraint:

```python
jax2tf.convert(lambda x, y: x + y[:y.shape[0] // 2],
               polymorphic_shapes=("a", "b"),
               polymorphic_constraints=("floordiv(b, 2) == a",))(x, y)
```

Just like the implicit constraints, the explicit
symbolic constraints are checked at compile time,
using the same mechanism as explained [above](#shape-assertion-errors).

The symbolic constraints are stored in an
`export.SymbolicScope` object, which is created implicitly
for each call to `jax2tf.convert`. You must be careful
not to mix symbolic expressions that use different scopes.
For example,
the following code will fail because `a1` and `a2`
use different scopes (created by `export.symbolic_shape`):

```python
a1, = export.symbolic_shape("a,")
a2, = export.symbolic_shape("a,", constraints=("a >= 8",))

a1 + a2
```

The symbolic expressions that originate from a single call
to `export.symbolic_shape` share a scope and
can be mixed in arithmetic operations. The result also
shares the same scope.
The result would\nalso share the same scope.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eYou can re-use scopes:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"a, = export.symbolic_shape(\u0026quot;a,\u0026quot;, constraints=(\u0026quot;a \u0026gt;= 8\u0026quot;,))\nb, = export.symbolic_shape(\u0026quot;b,\u0026quot;, scope=a1.scope)\n\na + b # Allowed\"\u003e\u003cpre\u003e\u003cspan class=\"pl-s1\"\u003ea\u003c/span\u003e, \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003eexport\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esymbolic_shape\u003c/span\u003e(\u003cspan class=\"pl-s\"\u003e\"a,\"\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003econstraints\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e(\u003cspan class=\"pl-s\"\u003e\"a \u0026gt;= 8\"\u003c/span\u003e,))\n\u003cspan class=\"pl-s1\"\u003eb\u003c/span\u003e, \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003eexport\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esymbolic_shape\u003c/span\u003e(\u003cspan class=\"pl-s\"\u003e\"b,\"\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003escope\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003ea1\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003escope\u003c/span\u003e)\n\n\u003cspan class=\"pl-s1\"\u003ea\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e+\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003eb\u003c/span\u003e \u003cspan class=\"pl-c\"\u003e# Allowed\u003c/span\u003e\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eJAX tracing uses caches keyed partially by shapes, and\nsymbolic shapes that are printed identically will be considered\ndistinct if they use different scopes.\u003c/p\u003e\n\u003cdiv class=\"markdown-heading\" dir=\"auto\"\u003e\u003ch4 tabindex=\"-1\" class=\"heading-element\" dir=\"auto\"\u003eCaveat for equality comparisons\u003c/h4\u003e\u003ca id=\"user-content-caveat-for-equality-comparisons\" class=\"anchor\" aria-label=\"Permalink: Caveat for equality comparisons\" href=\"#caveat-for-equality-comparisons\"\u003e\u003csvg class=\"octicon octicon-link\" viewBox=\"0 0 16 16\" version=\"1.1\" width=\"16\" height=\"16\" aria-hidden=\"true\"\u003e\u003cpath d=\"m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z\"\u003e\u003c/path\u003e\u003c/svg\u003e\u003c/a\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eThe equality comparison returns \u003ccode\u003eFalse\u003c/code\u003e for \u003ccode\u003eb + 1 == b\u003c/code\u003e or \u003ccode\u003eb == 0\u003c/code\u003e\n(in which case it is certain that the dimensions are different for all valuations),\nbut also for \u003ccode\u003eb == 1\u003c/code\u003e and for \u003ccode\u003ea == b\u003c/code\u003e. 
### Division of symbolic dimensions is partially supported

JAX will attempt to simplify division and modulo operations,
e.g., `(a * b + a) // (b + 1) == a` and `(6 * a + 4) % 3 == 1`.
In particular, JAX will handle the cases when either (a) there
is no remainder, or (b) the divisor is a constant,
in which case there may be a constant remainder.
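A quick check of the simplifications above (a sketch, using
`export.symbolic_shape` as elsewhere in this section):

```python
from jax.experimental import export

a, b = export.symbolic_shape("a, b")

assert (a * b + a) // (b + 1) == a  # no remainder: simplifies exactly
assert (6 * a + 4) % 3 == 1         # constant divisor: constant remainder
```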
For example, the code below results in a division error when trying to
compute the inferred dimension for a `reshape` operation:

```python
jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
               polymorphic_shapes=["(b, ...)"])(np.ones((4, 5, 7)))
```

In this case you will see the error `Cannot divide evenly the sizes of shapes (b, 5, 7) and (2, -1)`,
with a further `Details: Cannot divide '35*b' by '-2'`.
The polynomial `35*b` represents the total size of the input tensor.

Note that the following will succeed:

```python
## The resulting symbolic shape is (2, 15 b).
jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
               polymorphic_shapes=["(b, ...)"])(np.ones((4, 5, 6)))

## The resulting symbolic shape is (6 b2, b1).
jax2tf.convert(lambda x: jnp.reshape(x, (-1, x.shape[0])),
               polymorphic_shapes=["(b1, b2, ...)"])(np.ones((4, 5, 6)))
```
class=\"pl-s1\"\u003ex\u003c/span\u003e: \u003cspan class=\"pl-s1\"\u003ejnp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ereshape\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e, (\u003cspan class=\"pl-c1\"\u003e-\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e1\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eshape\u003c/span\u003e[\u003cspan class=\"pl-c1\"\u003e0\u003c/span\u003e])),\n \u003cspan class=\"pl-s1\"\u003epolymorphic_shapes\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e[\u003cspan class=\"pl-s\"\u003e\"(b1, b2, ...)\"\u003c/span\u003e])(\u003cspan class=\"pl-s1\"\u003enp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eones\u003c/span\u003e((\u003cspan class=\"pl-c1\"\u003e4\u003c/span\u003e, \u003cspan class=\"pl-c1\"\u003e5\u003c/span\u003e, \u003cspan class=\"pl-c1\"\u003e6\u003c/span\u003e)))\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eYou may also encounter division errors when working with strides, such as\nwhen computing the padding in a strided convolution.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eWhen JAX cannot simplify the result of symbolic dimension division it\nwill construct symbolic expressions of the form \u003ccode\u003efloordiv(E, N)\u003c/code\u003e and\n\u003ccode\u003emod(E, N)\u003c/code\u003e and it will use a number of heuristics to evaluate comparisons\ninvolving these. If you encounter \u003ccode\u003eInconclusiveDimensionOperation\u003c/code\u003e exceptions\nyou can specify that a dimension variable\nis a multiple of the divisor,\ne.g., \u003ccode\u003eb\u003c/code\u003e in the above example of dividing \u003ccode\u003e35*b\u003c/code\u003e by \u003ccode\u003e-2\u003c/code\u003e may\nbe known to be a multiple of \u003ccode\u003e2\u003c/code\u003e. 
```python
jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
               polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7)))
```

## Native serialization versions

We use a serialization version number to help evolve the serialization
mechanism while allowing serialized artifacts to be used by consumers built
at different code versions.

If consumers use the `tf.XlaCallModule` op, e.g., when using the TensorFlow
SavedModel, then they support a range of serialization versions.
See the [tf.XlaCallModule code](https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+kVersionMaximumSupported%22&type=code).
There is also an API to get the maximum version number supported by your
installed version of TensorFlow:

```python
from tensorflow.compiler.tf2xla.python import xla as tfxla
tfxla.call_module_maximum_supported_version()
```
For **backward compatibility**, we want to allow a freshly built consumer
to load artifacts that have been serialized in the past 6 months
(by a serializer using the latest version supported at the time). Thus,
the minimum supported version number should match the maximum supported
version number from 6 months in the past.

The serialization version used by JAX is determined by the
`--jax_serialization_version` flag, or, if missing, the
`JAX_SERIALIZATION_VERSION` environment variable. The default value is
specified in the [`config.py` file](https://github.com/search?q=repo%3Agoogle%2Fjax+path%3Aconfig.py+JAX_SERIALIZATION_VERSION&type=code).

For **forward compatibility**, we want freshly serialized artifacts to be
loadable by consumers that have been built in the last 1 month.
Thus, we bump the default serialization version
number about 1 month after `tf.XlaCallModule` is upgraded to a
given version number.

You can use `--jax_serialization_version` to adjust the serialization version
to match your deployed consumer. We reserve the right to remove support for
generating or consuming serialization versions older than 6 months.
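For instance, a sketch of pinning the version from Python. This assumes the
standard `jax.config.update` mechanism works for this flag, and a hypothetical
consumer that supports at most version 8:

```python
import jax

# Serialize with version 8 so that a consumer built against an older
# tf.XlaCallModule can still load the artifact.
jax.config.update("jax_serialization_version", 8)
```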
## Serialization version numbers

We list here a history of the serialization version numbers:

  * Version 1 used MHLO & CHLO to serialize the code, not supported anymore.
  * Version 2 supports StableHLO & CHLO. Used from October 2022. Not supported
    anymore.
  * Version 3 supports platform checking and multiple platforms.
    Used from February 2023. Not supported anymore.
  * Version 4 supports StableHLO with compatibility guarantees.
    This is the earliest version at the time of the JAX native serialization
    launch.
    Used in JAX from March 15, 2023 (cl/516885716). Starting with
    March 28th, 2023 we stopped using `dim_args_spec` (cl/520033493).
    The support for this version was dropped on
    October 17th, 2023 (cl/573858283).
  * Version 5 adds support for `call_tf_graph`. This is currently used
    for some specialized use cases. Used in JAX from May 3rd, 2023
    (cl/529106145).
  * Version 6 adds support for the `disabled_checks` attribute. This version
    mandates a non-empty `platforms` attribute. Supported by XlaCallModule
    since June 7th, 2023 and available in JAX since
    June 13th, 2023 (JAX 0.4.13).
  * Version 7 adds support for `stablehlo.shape_assertion` operations and
    for `shape_assertions` specified in `disabled_checks`.
    See [Errors in presence of shape polymorphism](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
    since July 12th, 2023 (cl/547482522),
    available in JAX serialization since July 20th, 2023 (JAX 0.4.14),
    and the default since August 12th, 2023 (JAX 0.4.15).
  * Version 8 adds support for the `jax.uses_shape_polymorphism` module
    attribute and enables the shape refinement pass only when the
    attribute is present. Supported by XlaCallModule since July 21st, 2023
    (cl/549973693), available in JAX since July 26th, 2023 (JAX 0.4.14),
    and the default since October 21st, 2023 (JAX 0.4.20).
  * Version 9 adds support for effects.
    See the docstring for `export.Exported` for the precise calling convention.
    In this serialization version we also tag the platform index and the
    dimension variables arguments with `jax.global_constant` attributes.
    Supported by XlaCallModule since October 27th, 2023,
    available in JAX since October 20th, 2023 (JAX 0.4.20),
    and the default since February 1st, 2024 (JAX 0.4.24).
    This is the only supported version as of March 27th, 2024.
## Known issues

`jax2tf` has been in use since 2020 and the vast majority of users encounter
no problems. However, there are a few rare corner cases
in which the different conventions of JAX and TensorFlow result in a breakage.
We try to give an exhaustive list below, specifying whether each limitation
applies to native serialization, non-native serialization, or both.
### Different 64-bit precision in JAX and TensorFlow

Applies to both native and non-native serialization.

JAX behaves somewhat differently than TensorFlow in the handling
of 32-bit vs. 64-bit values. However, the `jax2tf` lowered function
always behaves like the JAX function.

JAX interprets the type of Python scalars differently based on the
`JAX_ENABLE_X64` flag. (See
[JAX - The Sharp Bits: Double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).)
In the default configuration, the
flag is unset, and JAX interprets Python constants as 32-bit,
e.g., the type of `3.14` is `float32`. This is also what
TensorFlow always does.
JAX goes further: it forces
all explicitly-specified 64-bit values to be interpreted as
32-bit:

```python
# with JAX_ENABLE_X64=0
jnp.sin(3.14)  # Has type float32
tf.math.sin(3.14)  # Has type float32

jnp.sin(np.float64(3.14))  # Also has type float32
tf.math.sin(np.float64(3.14))  # Has type float64

# The jax2tf.convert function behaves like the JAX function.
jax2tf.convert(jnp.sin)(3.14)  # Has type float32
jax2tf.convert(jnp.sin)(np.float64(3.14))  # Has type float32

# The following will still compute `sin` in float32 (with a tf.cast on the argument).
tf.function(jax2tf.convert(jnp.sin), autograph=False)(tf.Variable(3.14, dtype=tf.float64))
```
class=\"pl-s1\"\u003eautograph\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eFalse\u003c/span\u003e)(\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eVariable\u003c/span\u003e(\u003cspan class=\"pl-c1\"\u003e3.14\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003edtype\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003efloat64\u003c/span\u003e))\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eWhen the \u003ccode\u003eJAX_ENABLE_X64\u003c/code\u003e flag is set, JAX uses 64-bit types\nfor Python scalars and respects the explicit 64-bit types:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"# with JAX_ENABLE_X64=1\njnp.sin(3.14) # Has type float64\ntf.math.sin(3.14) # Has type float32\n\n# The jax2tf.convert function behaves like the JAX function.\njax2tf.convert(jnp.sin)(3.14) # Has type float64\n\n# The following will compute `sin` in float64.\ntf.function(jax2tf.convert(jnp.sin), autograph=False)(tf.Variable(3.14, dtype=tf.float64))\n\n# The following will compute `sin` in float32.\ntf.function(jax2tf.convert(jnp.sin), autograph=False)(tf.Variable(3.14))\"\u003e\u003cpre\u003e\u003cspan class=\"pl-c\"\u003e# with JAX_ENABLE_X64=1\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003ejnp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esin\u003c/span\u003e(\u003cspan class=\"pl-c1\"\u003e3.14\u003c/span\u003e) \u003cspan class=\"pl-c\"\u003e# Has type float64\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003emath\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esin\u003c/span\u003e(\u003cspan class=\"pl-c1\"\u003e3.14\u003c/span\u003e) \u003cspan class=\"pl-c\"\u003e# Has type float32\u003c/span\u003e\n\n\u003cspan class=\"pl-c\"\u003e# The jax2tf.convert function behaves like the JAX function.\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econvert\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ejnp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esin\u003c/span\u003e)(\u003cspan class=\"pl-c1\"\u003e3.14\u003c/span\u003e) \u003cspan class=\"pl-c\"\u003e# Has type float64\u003c/span\u003e\n\n\u003cspan class=\"pl-c\"\u003e# The following will compute `sin` in float64.\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003efunction\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econvert\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ejnp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esin\u003c/span\u003e), \u003cspan class=\"pl-s1\"\u003eautograph\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eFalse\u003c/span\u003e)(\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eVariable\u003c/span\u003e(\u003cspan class=\"pl-c1\"\u003e3.14\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003edtype\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003efloat64\u003c/span\u003e))\n\n\u003cspan class=\"pl-c\"\u003e# The following will compute `sin` in float32.\u003c/span\u003e\n\u003cspan 
class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003efunction\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econvert\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ejnp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esin\u003c/span\u003e), \u003cspan class=\"pl-s1\"\u003eautograph\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eFalse\u003c/span\u003e)(\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eVariable\u003c/span\u003e(\u003cspan class=\"pl-c1\"\u003e3.14\u003c/span\u003e))\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eThis is achieved by inserting \u003ccode\u003etf.cast\u003c/code\u003e operations\non the input arguments inside the lowered function,\nif necessary.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eIf you want to create a \u003ccode\u003etf.Variable\u003c/code\u003e or \u003ccode\u003etf.TensorSpec\u003c/code\u003e with the\nsame dtype, you should use \u003ccode\u003ejax2tf.dtype_of_val\u003c/code\u003e:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"# The following two calls will lower jax_fun at the same dtypes\n# independently of the value of JAX_ENABLE_X64.\njax2tf.convert(jax_fun)(3.14)\njax2tf.convert(jax_fun)(tf.Variable(3.14, dtype=jax2tf.dtype_of_val(3.14)))\"\u003e\u003cpre\u003e\u003cspan class=\"pl-c\"\u003e# The following two calls will lower jax_fun at the same dtypes\u003c/span\u003e\n\u003cspan class=\"pl-c\"\u003e# independently of the value of JAX_ENABLE_X64.\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econvert\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ejax_fun\u003c/span\u003e)(\u003cspan class=\"pl-c1\"\u003e3.14\u003c/span\u003e)\n\u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econvert\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ejax_fun\u003c/span\u003e)(\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eVariable\u003c/span\u003e(\u003cspan class=\"pl-c1\"\u003e3.14\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003edtype\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003edtype_of_val\u003c/span\u003e(\u003cspan class=\"pl-c1\"\u003e3.14\u003c/span\u003e)))\u003c/pre\u003e\u003c/div\u003e\n\u003cdiv class=\"markdown-heading\" dir=\"auto\"\u003e\u003ch3 tabindex=\"-1\" class=\"heading-element\" dir=\"auto\"\u003eFunctions whose arguments and results are nested Python data structures\u003c/h3\u003e\u003ca id=\"user-content-functions-whose-arguments-and-results-are-nested-python-data-structures\" class=\"anchor\" aria-label=\"Permalink: Functions whose arguments and results are nested Python data structures\" href=\"#functions-whose-arguments-and-results-are-nested-python-data-structures\"\u003e\u003csvg class=\"octicon octicon-link\" viewBox=\"0 0 16 16\" version=\"1.1\" width=\"16\" height=\"16\" aria-hidden=\"true\"\u003e\u003cpath d=\"m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 
However, if your JAX function takes a custom container, you can register it with
the JAX `tree_util` module so that JAX will know how to operate with it. You
can still lower the function and use it in TensorFlow
eager mode and with `tf.function`, but you won't be able to save it to a SavedModel, nor
will you be able to compute gradients with TensorFlow
(code from `jax2tf_test.test_custom_pytree_readme`):
```python
class CustomPair:
  def __init__(self, a, b):
    self.a = a
    self.b = b

# Register it with the JAX tree_util module
jax.tree_util.register_pytree_node(CustomPair,
                                   lambda x: ((x.a, x.b), None),
                                   lambda _, ab: CustomPair(*ab))
def f_jax(pair: CustomPair):
  return 2. * pair.a + 3. * pair.b

x = CustomPair(4., 5.)
res_jax = f_jax(x)
# TF execution works as long as JAX can flatten the arguments
res_tf = jax2tf.convert(f_jax)(x)
self.assertAllClose(res_jax, res_tf.numpy())
res_tf_2 = tf.function(jax2tf.convert(f_jax), autograph=False, jit_compile=True)(x)
```
\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econvert\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ef_jax\u003c/span\u003e)(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e)\n\u003cspan class=\"pl-s1\"\u003eself\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eassertAllClose\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003eres_jax\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003eres_tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003enumpy\u003c/span\u003e())\n\u003cspan class=\"pl-s1\"\u003eres_tf_2\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003efunction\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econvert\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ef_jax\u003c/span\u003e), \u003cspan class=\"pl-s1\"\u003eautograph\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eFalse\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003ejit_compile\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eTrue\u003c/span\u003e)(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e)\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eIf you want to save the function in a SavedModel or compute gradients,\nyou should construct a wrapper:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\" # wrapped TF function to use only standard containers\ndef f_tf_wrapped(a, b):\n return f_tf(CustomPair(a, b))\n\n# Try to put into SavedModel\nmy_model = tf.Module()\n# Save a function that can take scalar inputs.\nmy_model.f = tf.function(f_tf_wrapped, autograph=False,\n input_signature=[tf.TensorSpec([], tf.float32),\n tf.TensorSpec([], tf.float32)])\nmodel_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(my_model)))\ntf.saved_model.save(my_model, model_dir,\n options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))\n\n# Restoring (note: the restored model does *not* require JAX to run, just XLA).\nrestored_model = tf.saved_model.load(model_dir)\ndef restored_f(pair: CustomPair):\n return restored_model.f(pair.a, pair.b)\n\nres_tf_3 = restored_f(x)\nself.assertAllClose(res_jax, res_tf_3)\ngrad_jax = jax.grad(f_jax)(x)\n\nx_v = [tf.Variable(x.a), tf.Variable(x.b)]\nwith tf.GradientTape() as tape:\n res = f_tf_wrapped(*x_v)\n grad_tf = tape.gradient(res, x_v)\n\nself.assertAllClose(grad_jax.a, grad_tf[0])\nself.assertAllClose(grad_jax.b, grad_tf[1])\"\u003e\u003cpre\u003e \u003cspan class=\"pl-c\"\u003e# wrapped TF function to use only standard containers\u003c/span\u003e\n\u003cspan class=\"pl-k\"\u003edef\u003c/span\u003e \u003cspan class=\"pl-en\"\u003ef_tf_wrapped\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ea\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003eb\u003c/span\u003e):\n \u003cspan class=\"pl-k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"pl-en\"\u003ef_tf\u003c/span\u003e(\u003cspan class=\"pl-en\"\u003eCustomPair\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ea\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003eb\u003c/span\u003e))\n\n\u003cspan class=\"pl-c\"\u003e# Try to put into SavedModel\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003emy_model\u003c/span\u003e \u003cspan 
class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eModule\u003c/span\u003e()\n\u003cspan class=\"pl-c\"\u003e# Save a function that can take scalar inputs.\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003emy_model\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ef\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003efunction\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ef_tf_wrapped\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003eautograph\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eFalse\u003c/span\u003e,\n \u003cspan class=\"pl-s1\"\u003einput_signature\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e[\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eTensorSpec\u003c/span\u003e([], \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003efloat32\u003c/span\u003e),\n \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eTensorSpec\u003c/span\u003e([], \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003efloat32\u003c/span\u003e)])\n\u003cspan class=\"pl-s1\"\u003emodel_dir\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003eos\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003epath\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ejoin\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003eabsltest\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eget_default_test_tmpdir\u003c/span\u003e(), \u003cspan class=\"pl-en\"\u003estr\u003c/span\u003e(\u003cspan class=\"pl-en\"\u003eid\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003emy_model\u003c/span\u003e)))\n\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esaved_model\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esave\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003emy_model\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003emodel_dir\u003c/span\u003e,\n \u003cspan class=\"pl-s1\"\u003eoptions\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esaved_model\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eSaveOptions\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003eexperimental_custom_gradients\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eTrue\u003c/span\u003e))\n\n\u003cspan class=\"pl-c\"\u003e# Restoring (note: the restored model does *not* require JAX to run, just XLA).\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003erestored_model\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esaved_model\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eload\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003emodel_dir\u003c/span\u003e)\n\u003cspan class=\"pl-k\"\u003edef\u003c/span\u003e \u003cspan class=\"pl-en\"\u003erestored_f\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003epair\u003c/span\u003e: \u003cspan class=\"pl-smi\"\u003eCustomPair\u003c/span\u003e):\n \u003cspan class=\"pl-k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003erestored_model\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ef\u003c/span\u003e(\u003cspan 
class=\"pl-s1\"\u003epair\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ea\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003epair\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eb\u003c/span\u003e)\n\n\u003cspan class=\"pl-s1\"\u003eres_tf_3\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-en\"\u003erestored_f\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e)\n\u003cspan class=\"pl-s1\"\u003eself\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eassertAllClose\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003eres_jax\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003eres_tf_3\u003c/span\u003e)\n\u003cspan class=\"pl-s1\"\u003egrad_jax\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejax\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003egrad\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ef_jax\u003c/span\u003e)(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e)\n\n\u003cspan class=\"pl-s1\"\u003ex_v\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e [\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eVariable\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ea\u003c/span\u003e), \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eVariable\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eb\u003c/span\u003e)]\n\u003cspan class=\"pl-k\"\u003ewith\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eGradientTape\u003c/span\u003e() \u003cspan class=\"pl-k\"\u003eas\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etape\u003c/span\u003e:\n \u003cspan class=\"pl-s1\"\u003eres\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-en\"\u003ef_tf_wrapped\u003c/span\u003e(\u003cspan class=\"pl-c1\"\u003e*\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003ex_v\u003c/span\u003e)\n \u003cspan class=\"pl-s1\"\u003egrad_tf\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etape\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003egradient\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003eres\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003ex_v\u003c/span\u003e)\n\n\u003cspan class=\"pl-s1\"\u003eself\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eassertAllClose\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003egrad_jax\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ea\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003egrad_tf\u003c/span\u003e[\u003cspan class=\"pl-c1\"\u003e0\u003c/span\u003e])\n\u003cspan class=\"pl-s1\"\u003eself\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eassertAllClose\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003egrad_jax\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eb\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003egrad_tf\u003c/span\u003e[\u003cspan class=\"pl-c1\"\u003e1\u003c/span\u003e])\u003c/pre\u003e\u003c/div\u003e\n\u003cdiv class=\"markdown-heading\" dir=\"auto\"\u003e\u003ch3 tabindex=\"-1\" class=\"heading-element\" dir=\"auto\"\u003eLowering gradients for functions with integer arguments or unused arguments\u003c/h3\u003e\u003ca id=\"user-content-lowering-gradients-for-functions-with-integer-arguments-or-unused-arguments\" class=\"anchor\" aria-label=\"Permalink: Lowering gradients for functions with integer arguments or unused arguments\" 
href=\"#lowering-gradients-for-functions-with-integer-arguments-or-unused-arguments\"\u003e\u003csvg class=\"octicon octicon-link\" viewBox=\"0 0 16 16\" version=\"1.1\" width=\"16\" height=\"16\" aria-hidden=\"true\"\u003e\u003cpath d=\"m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z\"\u003e\u003c/path\u003e\u003c/svg\u003e\u003c/a\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eApplies to both native and non-native serialization.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eWhen JAX differentiates functions with integer or boolean arguments, the gradients will\nbe zero-vectors with a special \u003ccode\u003efloat0\u003c/code\u003e type (see PR 4039](\u003ca class=\"issue-link js-issue-link\" data-error-text=\"Failed to load title\" data-id=\"677939591\" data-permission-text=\"Title is private\" data-url=\"https://github.com/jax-ml/jax/issues/4039\" data-hovercard-type=\"pull_request\" data-hovercard-url=\"/jax-ml/jax/pull/4039/hovercard\" href=\"https://github.com/jax-ml/jax/pull/4039\"\u003e#4039\u003c/a\u003e)).\nThis type is translated to \u003ccode\u003eint32\u003c/code\u003e when lowering to TF.\nFor example,\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"x = np.int16(2)\ndef f_jax(x): # x: int16\n return x * 2.\n\njax.grad(f_jax, allow_int=True)(x)\n# returns a special `float0`: array((b'',), dtype=[('float0', 'V')])\n\njax2tf.convert(jax.grad(f_jax, allow_int=True))(x)\n# returns a tf.Tensor(0, shape=(), dtype=int32)\"\u003e\u003cpre\u003e\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003enp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eint16\u003c/span\u003e(\u003cspan class=\"pl-c1\"\u003e2\u003c/span\u003e)\n\u003cspan class=\"pl-k\"\u003edef\u003c/span\u003e \u003cspan class=\"pl-en\"\u003ef_jax\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e): \u003cspan class=\"pl-c\"\u003e# x: int16\u003c/span\u003e\n \u003cspan class=\"pl-k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e*\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e2.\u003c/span\u003e\n\n\u003cspan class=\"pl-s1\"\u003ejax\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003egrad\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ef_jax\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003eallow_int\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eTrue\u003c/span\u003e)(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e)\n\u003cspan class=\"pl-c\"\u003e# returns a special `float0`: array((b'',), dtype=[('float0', 'V')])\u003c/span\u003e\n\n\u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econvert\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ejax\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003egrad\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ef_jax\u003c/span\u003e, \u003cspan 
class=\"pl-s1\"\u003eallow_int\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eTrue\u003c/span\u003e))(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e)\n\u003cspan class=\"pl-c\"\u003e# returns a tf.Tensor(0, shape=(), dtype=int32)\u003c/span\u003e\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eNote that this is different from how TensorFlow handles gradients\nfor integer or boolean arguments: sometimes the gradient is \u003ccode\u003eNone\u003c/code\u003e,\nsometimes it is a zero with the same dtype as the argument, and\nsometimes it is a one with the same dtype as the argument (e.g.,\nfor the identity function).\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"def f_tf(x): # x: int16\n return tf.cast(x, tf.float32) * 2.\n\nxv = tf.Variable(x)\nwith tf.GradientTape(persistent=True) as tape:\n print(tape.gradient(f_tf(xv), xv))\n # returns None\n print(tape.gradient(f_tf(xv), xv,\n unconnected_gradients=tf.UnconnectedGradients.ZERO))\n # returns 0 with the same shape and dtype as x\"\u003e\u003cpre\u003e\u003cspan class=\"pl-k\"\u003edef\u003c/span\u003e \u003cspan class=\"pl-en\"\u003ef_tf\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e): \u003cspan class=\"pl-c\"\u003e# x: int16\u003c/span\u003e\n \u003cspan class=\"pl-k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ecast\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003efloat32\u003c/span\u003e) \u003cspan class=\"pl-c1\"\u003e*\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e2.\u003c/span\u003e\n\n\u003cspan class=\"pl-s1\"\u003exv\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eVariable\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e)\n\u003cspan class=\"pl-k\"\u003ewith\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eGradientTape\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003epersistent\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eTrue\u003c/span\u003e) \u003cspan class=\"pl-k\"\u003eas\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etape\u003c/span\u003e:\n \u003cspan class=\"pl-en\"\u003eprint\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003etape\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003egradient\u003c/span\u003e(\u003cspan class=\"pl-en\"\u003ef_tf\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003exv\u003c/span\u003e), \u003cspan class=\"pl-s1\"\u003exv\u003c/span\u003e))\n \u003cspan class=\"pl-c\"\u003e# returns None\u003c/span\u003e\n \u003cspan class=\"pl-en\"\u003eprint\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003etape\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003egradient\u003c/span\u003e(\u003cspan class=\"pl-en\"\u003ef_tf\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003exv\u003c/span\u003e), \u003cspan class=\"pl-s1\"\u003exv\u003c/span\u003e,\n \u003cspan class=\"pl-s1\"\u003eunconnected_gradients\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eUnconnectedGradients\u003c/span\u003e.\u003cspan 
class=\"pl-c1\"\u003eZERO\u003c/span\u003e))\n \u003cspan class=\"pl-c\"\u003e# returns 0 with the same shape and dtype as x\u003c/span\u003e\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eWhen differentiating functions with unused arguments, TF by default\nreturns the value \u003ccode\u003eNone\u003c/code\u003e for the corresponding gradients. The\n\u003ccode\u003etape.gradient\u003c/code\u003e function takes the option \u003ccode\u003etf.UnconnectedGradients.ZERO\u003c/code\u003e\nto ask that gradients for unused arguments be zero.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eFunctions lowered with \u003ccode\u003ejax2tf.convert\u003c/code\u003e behave the same way under\n\u003ccode\u003etf.UnconnectedGradients.ZERO\u003c/code\u003e, but by default, they will return\n\u003ccode\u003eNone\u003c/code\u003e only for gradients corresponding to integer arguments.\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"# x1 and x3 are not used. x3 has integer type.\ndef fn(x0, x1, x2, x3):\n return x0 * 0. + x2 * 2.\n\nxs = [tf.Variable(x) for x in [10., 11., 12., 13]]\nwith tf.GradientTape(persistent=True) as tape:\n res = fn(*xs)\n\ng_tf_native = tape.gradient(res, xs)\n# Returns: 0., None, 2., None\n\ng_tf_native_0 = tape.gradient(res, xs,\n unconnected_gradients=tf.UnconnectedGradients.ZERO)\n# Returns: 0., 0., 2., 0\n\n# Now with jax2tf.convert\nwith tf.GradientTape() as tape:\n res = jax2tf.convert(fn, with_gradient=True)(*xs)\n\ng_jax2tf = tape.gradient(res, xs)\n# Returns: 0., 0., 2., None\n# Note that the gradient for x1 is 0.\n\ng_jax2tf_0 = tape.gradient(res, xs,\n unconnected_gradients=tf.UnconnectedGradients.ZERO)\n# Returns: 0., 0., 2., 0\n# In this case we get the same result as for TF native.\"\u003e\u003cpre\u003e\u003cspan class=\"pl-c\"\u003e# x1 and x3 are not used. 
x3 has integer type.\u003c/span\u003e\n\u003cspan class=\"pl-k\"\u003edef\u003c/span\u003e \u003cspan class=\"pl-en\"\u003efn\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex0\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003ex1\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003ex2\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003ex3\u003c/span\u003e):\n \u003cspan class=\"pl-k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ex0\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e*\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e0.\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e+\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ex2\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e*\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e2.\u003c/span\u003e\n\n\u003cspan class=\"pl-s1\"\u003exs\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e [\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eVariable\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e) \u003cspan class=\"pl-k\"\u003efor\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003ein\u003c/span\u003e [\u003cspan class=\"pl-c1\"\u003e10.\u003c/span\u003e, \u003cspan class=\"pl-c1\"\u003e11.\u003c/span\u003e, \u003cspan class=\"pl-c1\"\u003e12.\u003c/span\u003e, \u003cspan class=\"pl-c1\"\u003e13\u003c/span\u003e]]\n\u003cspan class=\"pl-k\"\u003ewith\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eGradientTape\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003epersistent\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eTrue\u003c/span\u003e) \u003cspan class=\"pl-k\"\u003eas\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etape\u003c/span\u003e:\n \u003cspan class=\"pl-s1\"\u003eres\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-en\"\u003efn\u003c/span\u003e(\u003cspan class=\"pl-c1\"\u003e*\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003exs\u003c/span\u003e)\n\n\u003cspan class=\"pl-s1\"\u003eg_tf_native\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etape\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003egradient\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003eres\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003exs\u003c/span\u003e)\n\u003cspan class=\"pl-c\"\u003e# Returns: 0., None, 2., None\u003c/span\u003e\n\n\u003cspan class=\"pl-s1\"\u003eg_tf_native_0\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etape\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003egradient\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003eres\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003exs\u003c/span\u003e,\n \u003cspan class=\"pl-s1\"\u003eunconnected_gradients\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eUnconnectedGradients\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eZERO\u003c/span\u003e)\n\u003cspan class=\"pl-c\"\u003e# Returns: 0., 0., 2., 0\u003c/span\u003e\n\n\u003cspan class=\"pl-c\"\u003e# Now with jax2tf.convert\u003c/span\u003e\n\u003cspan class=\"pl-k\"\u003ewith\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eGradientTape\u003c/span\u003e() \u003cspan class=\"pl-k\"\u003eas\u003c/span\u003e \u003cspan 
class=\"pl-s1\"\u003etape\u003c/span\u003e:\n \u003cspan class=\"pl-s1\"\u003eres\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econvert\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003efn\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003ewith_gradient\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003eTrue\u003c/span\u003e)(\u003cspan class=\"pl-c1\"\u003e*\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003exs\u003c/span\u003e)\n\n\u003cspan class=\"pl-s1\"\u003eg_jax2tf\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etape\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003egradient\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003eres\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003exs\u003c/span\u003e)\n\u003cspan class=\"pl-c\"\u003e# Returns: 0., 0., 2., None\u003c/span\u003e\n\u003cspan class=\"pl-c\"\u003e# Note that the gradient for x1 is 0.\u003c/span\u003e\n\n\u003cspan class=\"pl-s1\"\u003eg_jax2tf_0\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etape\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003egradient\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003eres\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003exs\u003c/span\u003e,\n \u003cspan class=\"pl-s1\"\u003eunconnected_gradients\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eUnconnectedGradients\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eZERO\u003c/span\u003e)\n\u003cspan class=\"pl-c\"\u003e# Returns: 0., 0., 2., 0\u003c/span\u003e\n\u003cspan class=\"pl-c\"\u003e# In this case we get the same result as for TF native.\u003c/span\u003e\u003c/pre\u003e\u003c/div\u003e\n\u003cdiv class=\"markdown-heading\" dir=\"auto\"\u003e\u003ch3 tabindex=\"-1\" class=\"heading-element\" dir=\"auto\"\u003eErrors due to tf.Module magic conversion during attribute assignment\u003c/h3\u003e\u003ca id=\"user-content-errors-due-to-tfmodule-magic-conversion-during-attribute-assignment\" class=\"anchor\" aria-label=\"Permalink: Errors due to tf.Module magic conversion during attribute assignment\" href=\"#errors-due-to-tfmodule-magic-conversion-during-attribute-assignment\"\u003e\u003csvg class=\"octicon octicon-link\" viewBox=\"0 0 16 16\" version=\"1.1\" width=\"16\" height=\"16\" aria-hidden=\"true\"\u003e\u003cpath d=\"m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z\"\u003e\u003c/path\u003e\u003c/svg\u003e\u003c/a\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eApplies to both native and non-native serialization.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003e\u003ccode\u003etf.Module\u003c/code\u003e will automatically wrap the standard Python container data types into\ntrackable classes during attribute assignment.\nPython Dict/List/Tuple are changed to _DictWrapper/_ListWrapper/_TupleWrapper\nclasses.\nIn most situations, 
these Wrapper classes work exactly as the standard\nPython data types. However, the low-level pytree data structures are different\nand this can lead to errors.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eIn such cases, the user can use this workaround:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"import tensorflow as tf\ninput_data = #Any data object\n\nm = tf.Module()\nflat, tree_def = jax.tree_util.tree_flatten(input_data)\nm.input_data = {\u0026quot;flat\u0026quot;: flat, \u0026quot;tree_def\u0026quot;: tree_def}\"\u003e\u003cpre\u003e\u003cspan class=\"pl-k\"\u003eimport\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etensorflow\u003c/span\u003e \u003cspan class=\"pl-k\"\u003eas\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003einput_data\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-c\"\u003e#Any data object\u003c/span\u003e\n\n\u003cspan class=\"pl-s1\"\u003em\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eModule\u003c/span\u003e()\n\u003cspan class=\"pl-s1\"\u003eflat\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003etree_def\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejax\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003etree_util\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003etree_flatten\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003einput_data\u003c/span\u003e)\n\u003cspan class=\"pl-s1\"\u003em\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003einput_data\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e {\u003cspan class=\"pl-s\"\u003e\"flat\"\u003c/span\u003e: \u003cspan class=\"pl-s1\"\u003eflat\u003c/span\u003e, \u003cspan class=\"pl-s\"\u003e\"tree_def\"\u003c/span\u003e: \u003cspan class=\"pl-s1\"\u003etree_def\u003c/span\u003e}\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eLater the user can use \u003ccode\u003etree_unflatten\u003c/code\u003e for the reverse process:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"input_data = jax.tree_util.tree_unflatten(m.input_data['tree_def'], m.input_data['flat'])\"\u003e\u003cpre\u003e\u003cspan class=\"pl-s1\"\u003einput_data\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejax\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003etree_util\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003etree_unflatten\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003em\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003einput_data\u003c/span\u003e[\u003cspan class=\"pl-s\"\u003e'tree_def'\u003c/span\u003e], \u003cspan class=\"pl-s1\"\u003em\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003einput_data\u003c/span\u003e[\u003cspan class=\"pl-s\"\u003e'flat'\u003c/span\u003e])\u003c/pre\u003e\u003c/div\u003e\n\u003cdiv class=\"markdown-heading\" dir=\"auto\"\u003e\u003ch3 tabindex=\"-1\" class=\"heading-element\" dir=\"auto\"\u003eLarge saved_model.pb due too many PRNG operations\u003c/h3\u003e\u003ca id=\"user-content-large-saved_modelpb-due-too-many-prng-operations\" class=\"anchor\" aria-label=\"Permalink: Large saved_model.pb due too many PRNG operations\" 
href=\"#large-saved_modelpb-due-too-many-prng-operations\"\u003e\u003csvg class=\"octicon octicon-link\" viewBox=\"0 0 16 16\" version=\"1.1\" width=\"16\" height=\"16\" aria-hidden=\"true\"\u003e\u003cpath d=\"m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z\"\u003e\u003c/path\u003e\u003c/svg\u003e\u003c/a\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eApplies to both native and non-native serialization.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eThe default \u003ccode\u003ethreefry2x32\u003c/code\u003e PRNG is implemented in JAX with dozens\nof additions and bitwise operations. This means that a single PRNG\noperation in JAX will result in dozens of TF ops after jax2tf.\nIf the number of RPNG operations\nis large, the generated TF graph will be very large.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eTo reduce the TF graph size and the compilation time\none can use the \u003ccode\u003eunsafe_rbg\u003c/code\u003e PRNG implementation by\nsetting \u003ccode\u003ejax.config.update('jax_default_prng_impl', 'unsafe_rbg')\u003c/code\u003e.\nThe \u003ccode\u003eunsafe_rbg\u003c/code\u003e implementation will be lowered to a TF op and several\ncasts and reshapes, thus significantly reducing the number of TF ops\nper PRNG operation. The \"unsafe\" part is that it doesn't guarantee\ndeterminism across JAX/XLA versions, and the quality of random\nstreams it generates from different keys is less well understood.\nNevertheless, this should be fine for most inference/serving cases.\nSee more details in the \u003ca href=\"https://jax.readthedocs.io/en/latest/jax.random.html?highlight=unsafe_rbg#advanced-rng-configuration\" rel=\"nofollow\"\u003eJAX PRNG documentation\u003c/a\u003e.\u003c/p\u003e\n\u003cdiv class=\"markdown-heading\" dir=\"auto\"\u003e\u003ch3 tabindex=\"-1\" class=\"heading-element\" dir=\"auto\"\u003eSavedModel supports only first-order gradients\u003c/h3\u003e\u003ca id=\"user-content-savedmodel-supports-only-first-order-gradients\" class=\"anchor\" aria-label=\"Permalink: SavedModel supports only first-order gradients\" href=\"#savedmodel-supports-only-first-order-gradients\"\u003e\u003csvg class=\"octicon octicon-link\" viewBox=\"0 0 16 16\" version=\"1.1\" width=\"16\" height=\"16\" aria-hidden=\"true\"\u003e\u003cpath d=\"m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z\"\u003e\u003c/path\u003e\u003c/svg\u003e\u003c/a\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eApplies to both native and non-native serialization.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eThe \u003ccode\u003ejax2tf\u003c/code\u003e-lowered function supports higher-order gradients, but 
### SavedModel supports only first-order gradients

Applies to both native and non-native serialization.

The `jax2tf`-lowered function supports higher-order gradients, but when the
function is saved in a SavedModel, only the first-order gradient is saved.
This is primarily a limitation of the SavedModel support for custom gradients.
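For illustration, a minimal sketch (hypothetical code, not from the test suite):

```python
# Higher-order gradients work on the lowered function itself...
f_tf = jax2tf.convert(jnp.sin)
x = tf.Variable(0.7)
with tf.GradientTape() as tape_outer:
  with tf.GradientTape() as tape_inner:
    y = f_tf(x)
  dy_dx = tape_inner.gradient(y, x)      # ~ cos(0.7)
d2y_dx2 = tape_outer.gradient(dy_dx, x)  # ~ -sin(0.7)
# ...but after a tf.saved_model.save/load round trip, only the
# first-order gradient is available from the restored function.
```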
### Native serialization supports only select dialects

Applies to native serialization only.

JAX native serialization checks that the code to be serialized contains
operations only from MLIR dialects that are known to have stability guarantees,
e.g., StableHLO, and the "builtin" dialect. As an exception, it also accepts
operations from the MHLO dialect, but they are converted to the corresponding
StableHLO operations upon serialization.

### Native serialization supports only select custom calls

Applies to native serialization only.

JAX natively uses custom calls for lowering of certain primitives.
The most common example is the implementation of PRNG on GPUs,
where we get better performance with a custom call (`cu_threefry32`)
than if we use native StableHLO. Another class of examples is
FFT and some linear algebra primitives (e.g., QR decomposition).

Unlike regular StableHLO ops, the compatibility guarantees for
custom calls are the burden of the teams maintaining the C++
code that backs the custom call. For this reason, we maintain
a list of allowed custom call targets. If you try to serialize
code that invokes other targets you will get an error.

If you want to disable this safety check for a specific custom call
with target `my_target`, you can add
`jax2tf.DisabledSafetyCheck.custom_call("my_target")` to the `disabled_checks`
parameter of the `jax2tf` function.
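For example, a sketch (`my_target` is a placeholder for the custom call target named in the error):

```python
# Opt out of the safety check for one specific custom call target.
f_tf = jax2tf.convert(
    f_jax,
    native_serialization=True,
    disabled_checks=[jax2tf.DisabledSafetyCheck.custom_call("my_target")])
```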
### XlaCallModule not supported by some TensorFlow tools

Applies to native serialization only.

JAX native serialization uses the `XlaCallModule` TensorFlow op to host
the StableHLO program obtained from JAX. This is a relatively
new TensorFlow op and may not be supported by some tools. In fact,
certain tools that need to do `tf.Graph` inspection and transformation
cannot work when the whole JAX program is a single TensorFlow op.

This is the case, for example, for the TFLite and TensorFlow.js converters.
There is work underway to enable more tools to consume StableHLO.

### Natively serialized JAX modules are platform specific

Applies to native serialization only.

When you use native serialization, JAX will record the platform for
which the module was serialized, and you will get an error if you
try to execute the `XlaCallModule` TensorFlow op on another platform.

Note that this error will only arise in native serialization; with
non-native serialization the lowering to TensorFlow ops is
platform independent, although it is only guaranteed to match the
JAX semantics and performance behavior for TPUs.

The error has the form:

```commandline
The current platform CPU is not among the platforms required by the module [CUDA]
```

where `CPU` is the TensorFlow platform where the op is being executed
and `CUDA` is the platform for which the module was serialized by JAX.
This probably means that JAX and TensorFlow see different devices
as the default device (JAX defaults to GPU and TensorFlow to CPU
in the example error above).
You can check what devices TensorFlow uses:

```python
logging.info("All TF devices: %s", tf.config.list_logical_devices())
tf_device = (tf.config.list_logical_devices("TPU") +
             tf.config.list_logical_devices("GPU") +
             tf.config.list_logical_devices())[0]
assert jax.default_backend().upper() == tf_device.device_type
with tf.device(tf_device):
  ...
```
Another case to pay attention to: you must use `jit_compile=True` in order to
execute on TPU, because with `jit_compile=False` TF "executes the function without XLA
compilation. Set this value to False when directly running a multi-device
function on TPUs (e.g. two TPU cores, one TPU core and its host CPU)" (see the
[TF doc](https://www.tensorflow.org/api_docs/python/tf/function)).

With `jit_compile=False` the converted TF program will be executed on CPU
instead of TPU, and this will result in an error message:

```
Node: 'XlaCallModule'
The current platform CPU is not among the platforms required by the module: [TPU]
	 [[{{node XlaCallModule}}]]
```

To work around this with `jit_compile=False`, you can wrap your function with a
new `tf.function` that explicitly assigns the TPU device, like this:

```python
f_tf = jax2tf.convert(jnp.sin)
x = np.float32(.5)

@tf.function(autograph=False, jit_compile=False)
def f_tf_wrapped(x):
  with tf.device('/device:TPU:0'):
    return f_tf(x)

with tf.device('/device:TPU:0'):
  self.assertAllClose(np.sin(x), f_tf_wrapped(x))
```

### Unsupported JAX features

Applies to non-native serialization only.

There is currently no support for `pmap`, `xmap`, `shard_map`,
nor for the collective operations, except in native serialization.

### Shape polymorphism with native serialization limitations for `lax.linalg.eigh`

Applies to native serialization only.

JAX lowers `lax.linalg.eigh` using custom calls, and needs to call helper
functions to determine the workspace size based on the non-batch dimensions.
Therefore, dynamic dimensions are supported only for the batch dimensions
(all but the last two dimensions).

Additionally, on GPU, the JAX lowering uses the `cuSolver` library and chooses the
`syevj` method (using the Jacobi algorithm) for non-batch dimension sizes less than or
equal to 32, and the `syevd` method (using the QR algorithm) for larger dimensions.

In the presence of shape polymorphism, JAX will always use `syevd`, because `syevj`
requires knowing the batch dimensions statically in order to compute
the workspace size. This means that the performance and the numerical behavior
may be slightly different for small matrices.
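For example, a minimal sketch (the concrete sizes are made up; `b` is a dimension variable):

```python
# Only the batch dimension `b` may be symbolic; the last two
# (non-batch) dimensions must be static, here 8 x 8.
eigh_tf = jax2tf.convert(jnp.linalg.eigh, polymorphic_shapes=["(b, 8, 8)"])
```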
### Slow implementation of associative reductions for CPU

Applies to non-native serialization only.

Operations like `jax.numpy.cumsum` are lowered by JAX differently based
on the platform. For TPU, the lowering uses the [HLO ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow)
operation, which has an efficient implementation for the cases when the
reduction function is associative. For CPU and GPU, JAX uses an alternative
lowering using [associative scans](https://github.com/jax-ml/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801).
jax2tf uses the TPU lowering (because it does not support backend-specific lowering)
and hence it can be slow in some cases on CPU and GPU.

We have filed a bug with the XLA:CPU compiler to improve ReduceWindow.
Meanwhile, if you run into this problem you can use the
`--jax2tf_associative_scan_reductions` flag to get the special
associative scan lowering.
You can alternatively use `with jax.jax2tf_associative_scan_reductions(True)`
around the code that invokes the function returned by `jax2tf.convert`,
as in the sketch below.
Use this only if it improves the performance for your application.
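A minimal sketch of the context-manager form (a hypothetical `cumsum` example):

```python
# Enable the associative-scan lowering around the invocation (and hence
# the tracing) of the function returned by jax2tf.convert.
f_tf = tf.function(jax2tf.convert(jnp.cumsum), autograph=False)
with jax.jax2tf_associative_scan_reductions(True):
  res = f_tf(np.arange(8, dtype=np.float32))
```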
### TensorFlow XLA ops

Applies to non-native serialization only.

For most JAX primitives there is a natural TensorFlow op that fits the needed semantics.
There are a few (listed in [no_xla_limitations.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/no_xla_limitations.md)) JAX primitives
for which there is no single TensorFlow op with matching semantics.
This is not so surprising, because JAX primitives have been designed
to be compiled to [HLO ops](https://www.tensorflow.org/xla/operation_semantics),
while the corresponding TensorFlow ops are sometimes higher-level.
For the cases when there is no matching canonical TensorFlow op,
we use a set of special TensorFlow ops that are thin wrappers over HLO ops
(a subset of those registered in
[tf2xla/ops/xla_ops.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/ops/xla_ops.cc)
and implemented in, e.g.,
[tf2xla/kernels/xla_pad_op.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc).)
We refer to these ops here as the XLA TensorFlow ops. Note that these are
still regular TF ops, e.g., they can be saved in a SavedModel.

There are several drawbacks of using XLA TensorFlow ops:

* These ops will only be executable by a consumer that has XLA linked in.
  This should not be a problem for TPU execution, since that requires XLA anyway.
* These ops are not yet recognized by tools that process
  tf.Graph, e.g., the TensorFlow.js converter or the TensorFlow Lite converter.

As an experimental feature we implemented alternative conversions to avoid the XLA TensorFlow ops.
You can enable this with the `enable_xla=False` parameter to `jax2tf.convert`.
For more details see [no_xla_limitations.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/no_xla_limitations.md).
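A minimal sketch of that parameter, with `f_jax` a hypothetical function; the
resulting graph avoids the XLA TensorFlow ops, subject to the limitations
documented in no_xla_limitations.md:

```python
from jax.experimental import jax2tf
import jax.numpy as jnp

def f_jax(x):
  return jnp.sin(x) + 1.0

# Avoid the XLA TensorFlow ops so the graph can be consumed by tools such
# as the TensorFlow Lite converter.
f_tf_no_xla = jax2tf.convert(f_jax, enable_xla=False)
```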
### Different performance characteristics

Applies to non-native serialization only.

The lowered code may have slightly different performance characteristics than
the original JAX code.
We do expect the performance characteristics of lowered code
to be the same as those of JAX when used with the XLA compiler (`tf.function(jit_compile=True)`).
This is because
during lowering we try to generate one TensorFlow op for each JAX primitive.
We expect that the lowering that XLA does is similar to the lowering done by JAX
before conversion. (This is a hypothesis; we have not yet verified it extensively.)

There is one known case when the performance of the lowered code will be different.
JAX programs use a [stateless
deterministic PRNG](https://github.com/jax-ml/jax/blob/main/docs/design_notes/prng.md)
and have an internal JAX primitive for it.
This primitive is at the moment lowered to a soup of tf.bitwise operations,
which has a clear performance penalty. We plan to look into using the
HLO [RNGBitGenerator](https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator)
(exposed as a TFXLA op), which does implement
the same basic Threefry algorithm as JAX's PRNG, although that would
produce different results than JAX's PRNG.

In the absence of TensorFlow XLA compilation,
if one were to write the same functionality in idiomatic JAX code vs.
idiomatic native TensorFlow code, we could end up with very different compilation paths.
Take, for example, the case of batch normalization.
In TensorFlow, if one uses [tf.nn.batch_normalization](https://www.tensorflow.org/api_docs/python/tf/nn/batch_normalization),
a "high-level" TensorFlow op for batch
normalization is generated, and in the absence of XLA, on CPU or GPU,
a custom C++ "high-level" kernel implementing batch normalization is executed.
In JAX, there is no primitive for batch normalization; instead the
operation is decomposed into low-level primitives (e.g., by [flax.linen.BatchNorm](https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.BatchNorm.html)
or haiku.BatchNorm).
Once those primitives are lowered to TensorFlow, and the resulting code is
run without XLA, the ensemble of kernels executed will quite
possibly behave differently, performance-wise or even numerically,
than either the TensorFlow-native or JAX-native batch normalization.
A similar example is that of an LSTM cell.
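As noted above, to match JAX's performance characteristics you would typically
run the lowered function under XLA; a short sketch, with `f_jax` a hypothetical
function:

```python
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f_jax(x):
  return jnp.tanh(x) * x

# Compile the lowered function with XLA, so the one-TF-op-per-JAX-primitive
# graph is fused and optimized much as XLA would do for the original JAX code.
f_tf = tf.function(jax2tf.convert(f_jax), autograph=False, jit_compile=True)
```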
### Unchecked assumption that the dimension variables take strictly positive values

Applies to non-native serialization only.

The shape-polymorphic conversion is sound under the assumption that the dimension
variables take strictly positive values. In the following example, the function to be lowered
has different behavior for empty shapes. The broken assumption is caught by jax2tf if
the lowered function is executed eagerly, but not if it is first traced to a
TensorFlow graph:

```python
def f_jax(x):
  return 0 if x.shape[0] == 0 else 1

x0 = np.array([], np.float32)
self.assertEqual(0, f_jax(x0))  # JAX sees that x.shape[0] == 0

# jax2tf catches the broken assumption b >= 1 if the lowered function is executed
# eagerly.
# Raises: ValueError: Dimension variable b must have integer value >= 1. Found value 0 when solving b == 0
jax2tf.convert(f_jax, polymorphic_shapes=["b"])(x0)

# However, if we first trace to a TensorFlow graph, we may miss the broken assumption:
f_tf = tf.function(
    jax2tf.convert(f_jax, polymorphic_shapes=["b"]), autograph=False
    ).get_concrete_function(tf.TensorSpec([None], dtype=np.float32))
self.assertEqual(1, f_tf(x0))
```

Another possible source of unsoundness is that JAX assumes that all unknown
dimensions represented by the same dimension variable have equal size. As before,
this assumption is checked if the lowered function is executed eagerly, but
it may be missed if it is first traced to a TensorFlow graph:
```python
def f_jax(x):
  return 0 if x.shape[0] != x.shape[1] else 1

x45 = np.ones((4, 5), dtype=np.float32)
self.assertEqual(0, f_jax(x45))  # JAX sees that x.shape[0] != x.shape[1]

# jax2tf catches the broken assumption x.shape[0] == x.shape[1] if the lowered
# function is executed eagerly.
# Raises: ValueError: polymorphic shape ('b, b',) has dimension variable 'b' corresponding to multiple values {4, 5}, for argument shapes (TensorShape([4, 5]),)
jax2tf.convert(f_jax, polymorphic_shapes=["b, b"])(x45)

# However, if we first trace to a TensorFlow graph, we may miss the broken assumption.
f_tf = tf.function(
    jax2tf.convert(f_jax, polymorphic_shapes=["b, b"]),
    autograph=False).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
self.assertEqual(1, f_tf(x45))
```

### Incomplete TensorFlow data type coverage

Applies to non-native serialization only.

There are a number of cases when the TensorFlow ops that are used by
jax2tf are not supported by TensorFlow for the same data types as in JAX.
There is an
[up-to-date list of unimplemented cases](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md).
If you try to lower and run in TensorFlow a program with partially supported primitives,
you may see TensorFlow errors that
a TensorFlow op is used with an unsupported data type, or that
there is no supported TensorFlow kernel for the op for the given
data type. The former case can happen even if you `jit_compile`
the TensorFlow program, and it is a priority to fix. The latter
case only appears in TensorFlow non-compiled mode; you can
avoid the problem if you use XLA to `jit_compile` (always recommended).

Our priority is to ensure numerical and performance accuracy for
the lowered program **when using XLA to compile the lowered program**.
It is always a good idea to use XLA on the lowered function.

Sometimes you cannot compile the entire TensorFlow function for your
model, because in addition to the function that is lowered from JAX,
it may include some pre-processing TensorFlow code that
is not compilable with XLA, e.g., string parsing. Even in those situations
you can instruct TensorFlow to compile only the portion that originates
from JAX:

```python
def entire_tf_fun(x):
  y = preprocess_tf_fun_not_compilable(x)
  # Compile the code that is lowered from JAX
  z = tf.function(jax2tf.convert(compute_jax_fn),
                  autograph=False, jit_compile=True)(y)
  return postprocess_tf_fun_not_compilable(z)
```

You won't be able to compile the `entire_tf_fun`, but you can still execute
it knowing that the jax2tf-lowered code is compiled. You can even save
the function to a SavedModel, knowing that upon restore the
jax2tf-lowered code will be compiled.
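For instance, a hedged sketch of saving `entire_tf_fun` from the snippet above
to a SavedModel (the input signature and directory name are hypothetical); only
the inner `jit_compile=True` function is compiled, before and after restore:

```python
import tensorflow as tf

module = tf.Module()
# Wrap the mixed compilable/non-compilable function without compiling the whole.
module.f = tf.function(
    entire_tf_fun,
    input_signature=[tf.TensorSpec([None], tf.float32)])
tf.saved_model.save(module, "/tmp/mixed_model")
restored = tf.saved_model.load("/tmp/mixed_model")
```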
For a more elaborate example, see the test `test_tf_mix_jax_with_uncompilable`
in [savedmodel_test.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/savedmodel_test.py).

# Calling TensorFlow functions from JAX

The function `call_tf` allows JAX functions to call
TensorFlow functions. These functions can be called anywhere in a JAX
computation, including in staging contexts `jax.jit`, `jax.pmap`, `jax.xmap`,
or inside JAX's control-flow primitives. In non-staging contexts,
the TensorFlow function is called in eager mode.
For now, only reverse-mode autodiff is supported for these functions
(no forward-mode autodiff, nor `vmap`).

As a trivial example, consider computing `sin(cos(1.))` with `sin` done in JAX and `cos` in TF:

```python
from jax.experimental import jax2tf

# This is a TF function. It will be called with TensorFlow-compatible arguments,
# such as `numpy.ndarray`, `tf.Tensor` or `tf.Variable`, or a pytree thereof.
# It should return a similar result. This function will be called using
# TensorFlow eager mode if called from outside JAX staged contexts (`jit`,
# `pmap`, or control-flow primitives), and will be called using TensorFlow
# compiled mode otherwise. In the latter case, the function must be compilable
# with XLA (`tf.function(func, jit_compile=True)`)
def cos_tf(x):
  return tf.math.cos(x)

# Compute cos with TF and sin with JAX
def cos_tf_sin_jax(x):
  return jax.numpy.sin(jax2tf.call_tf(cos_tf)(x))

# Calls `cos_tf` in TF eager mode
x = np.float32(1.)
cos_tf_sin_jax(x)

# Compiles `cos_tf` using TF and embeds the XLA computation into the JAX
# XLA computation (containing `sin`). The XLA compiler may even be able to
# fuse through JAX-TF computations.
jax.jit(cos_tf_sin_jax)(x)

# Uses the TF gradient for `cos_tf` and the JAX gradient for `sin`
jax.grad(cos_tf_sin_jax)(x)
```

If you inspect the generated HLO for `cos_tf_sin_jax`, you will see that the
main JAX computation (`ENTRY xla_computation_cos_tf_sin_jax`) makes a call to
the `a_inference_cos_tf_68__` HLO function that was compiled by TF from `cos_tf`:

```
HloModule xla_computation_cos_tf_sin_jax.18

a_inference_cos_tf_68__.4 {
  arg0.5 = f32[] parameter(0), parameter_replication={false}
  reshape.6 = f32[] reshape(arg0.5)
  cosine.7 = f32[] cosine(reshape.6)
  reshape.8 = f32[] reshape(cosine.7)
  tuple.9 = (f32[]) tuple(reshape.8)
  ROOT get-tuple-element.10 = f32[] get-tuple-element(tuple.9), index=0
}

ENTRY xla_computation_cos_tf_sin_jax.18 {
  constant.2 = pred[] constant(false)
  constant.3 = pred[] constant(false)
  parameter.1 = f32[] parameter(0)
  call.11 = f32[] call(parameter.1), to_apply=a_inference_cos_tf_68__.4
  tuple.12 = (f32[]) tuple(call.11)
  get-tuple-element.13 = f32[] get-tuple-element(tuple.12), index=0
  tuple.14 = (f32[]) tuple(get-tuple-element.13)
  get-tuple-element.15 = f32[] get-tuple-element(tuple.14), index=0
  sine.16 = f32[] sine(get-tuple-element.15)
  ROOT tuple.17 = (f32[]) tuple(sine.16)
}
```

For a more elaborate example, including round-tripping from JAX
to TensorFlow and back through a SavedModel, with support for
custom gradients,
see the test `test_round_trip_custom_grad_saved_model`
in [call_tf_test.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/call_tf_test.py).
href=\"https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/call_tf_test.py\"\u003ecall_tf_test.py\u003c/a\u003e.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eAll the metadata inserted by TF during tracing and compilation, e.g.,\nsource location information and op names, is carried through to the\nJAX XLA computation.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eThe TF custom gradients are respected, since it is TF that generates the\ngradient computation.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003e\u003ccode\u003ecall_tf\u003c/code\u003e works even with shape polymorphism, but in that case\nthe user must pass the \u003ccode\u003eoutput_shape_dtype\u003c/code\u003e parameter to \u003ccode\u003ecall_tf\u003c/code\u003e to declare\nthe expected output shapes. This allows JAX tracing to know the shape and\ndtype of the results so that it can continue tracing the rest of the program.\nWhen \u003ccode\u003eoutput_shape_dtype\u003c/code\u003e is not given (the default case), \u003ccode\u003ecall_tf\u003c/code\u003e will\nform a \u003ccode\u003etf.Graph\u003c/code\u003e for the called TF function and will use the inferred\ntype and shape. However, in presence of dynamic shape the inferred TF\ntype will contain \u003ccode\u003eNone\u003c/code\u003e for the dynamic dimensions, which is not enough\ninformation for JAX shape polymorphism.\u003c/p\u003e\n\u003cp dir=\"auto\"\u003eFor example:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"def fun_jax(x):\n y_shape = (x.shape[0] * 2, y.shape[1:])\n y = jax2tf.call_tf(\n lambda x: tf.concat([x, x], axis=0),\n output_shape_dype=jax.ShapeDtypeStruct(y_shape, x.dtype))(x)\n # JAX will know the y.shape\n return jnp.ones(y.shape, dtype=y.dtype) + y\n\njax2tf.convert(fun_jax, polymorphic_shapes=[\u0026quot;b, ...\u0026quot;])(x)\"\u003e\u003cpre\u003e\u003cspan class=\"pl-k\"\u003edef\u003c/span\u003e \u003cspan class=\"pl-en\"\u003efun_jax\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e):\n \u003cspan class=\"pl-s1\"\u003ey_shape\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e (\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eshape\u003c/span\u003e[\u003cspan class=\"pl-c1\"\u003e0\u003c/span\u003e] \u003cspan class=\"pl-c1\"\u003e*\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e2\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003ey\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eshape\u003c/span\u003e[\u003cspan class=\"pl-c1\"\u003e1\u003c/span\u003e:])\n \u003cspan class=\"pl-s1\"\u003ey\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ecall_tf\u003c/span\u003e(\n \u003cspan class=\"pl-k\"\u003elambda\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e: \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econcat\u003c/span\u003e([\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e], \u003cspan class=\"pl-s1\"\u003eaxis\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e0\u003c/span\u003e),\n \u003cspan class=\"pl-s1\"\u003eoutput_shape_dype\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003ejax\u003c/span\u003e.\u003cspan 
class=\"pl-c1\"\u003eShapeDtypeStruct\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ey_shape\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003edtype\u003c/span\u003e))(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e)\n \u003cspan class=\"pl-c\"\u003e# JAX will know the y.shape\u003c/span\u003e\n \u003cspan class=\"pl-k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejnp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eones\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ey\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eshape\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003edtype\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003ey\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003edtype\u003c/span\u003e) \u003cspan class=\"pl-c1\"\u003e+\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ey\u003c/span\u003e\n\n\u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econvert\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003efun_jax\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003epolymorphic_shapes\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e[\u003cspan class=\"pl-s\"\u003e\"b, ...\"\u003c/span\u003e])(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e)\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eAn even simpler example for a function that returns the same shape as the input:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"def fun_jax(x):\n return jax2tf.call_tf(tf.math.sin,\n output_shape_dtype=x)\n )(x)\n\njax2tf.convert(fun_jax, polymorphic_shapes=[\u0026quot;b, ...\u0026quot;])(x)\"\u003e\u003cpre\u003e\u003cspan class=\"pl-k\"\u003edef\u003c/span\u003e \u003cspan class=\"pl-en\"\u003efun_jax\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e):\n \u003cspan class=\"pl-k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ecall_tf\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003emath\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003esin\u003c/span\u003e,\n \u003cspan class=\"pl-s1\"\u003eoutput_shape_dtype\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e)\n )(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e)\n\n\u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econvert\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003efun_jax\u003c/span\u003e, \u003cspan class=\"pl-s1\"\u003epolymorphic_shapes\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e[\u003cspan class=\"pl-s\"\u003e\"b, ...\"\u003c/span\u003e])(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e)\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eIf all the output shapes of the TF function are static, JAX does not need the\n\u003ccode\u003eoutput_shape_dtype\u003c/code\u003e argument:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"def fun_tf(x):\n return tf.math.reduce_sum(tf.math.sin(x))\n\ndef fun_jax(x):\n return jax2tf.call_tf(fun_tf)(x)\n\n# The following will not throw an error because the output shape of fun_tf is static.\njax2tf.convert(fun_jax, 
The shape polymorphism support for `call_tf` does not yet work for native serialization.

### Limitations of call_tf

The TF function must be compilable (`tf.function(func, jit_compile=True)`)
and must have static output shapes
when used in a JAX staging context, e.g., `jax.jit`, `lax.scan`, `lax.cond`,
but may have unknown output shapes when used in JAX op-by-op mode.
For example, the following
function uses string operations that are not supported by XLA:
data-snippet-clipboard-copy-content=\"def f_tf_non_compilable(x):\n return tf.strings.length(tf.strings.format(\u0026quot;Hello {}!\u0026quot;, [x]))\n\nf_jax = jax2tf.call_tf(f_tf_non_compilable)\n# Works in op-by-op mode\nf_jax(np.float32(42.))\n\n# Fails in jit mode\njax.jit(f_jax)(np.float(42.))\"\u003e\u003cpre\u003e\u003cspan class=\"pl-k\"\u003edef\u003c/span\u003e \u003cspan class=\"pl-en\"\u003ef_tf_non_compilable\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e):\n \u003cspan class=\"pl-k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003estrings\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003elength\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003estrings\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eformat\u003c/span\u003e(\u003cspan class=\"pl-s\"\u003e\"Hello {}!\"\u003c/span\u003e, [\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e]))\n\n\u003cspan class=\"pl-s1\"\u003ef_jax\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ecall_tf\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ef_tf_non_compilable\u003c/span\u003e)\n\u003cspan class=\"pl-c\"\u003e# Works in op-by-op mode\u003c/span\u003e\n\u003cspan class=\"pl-en\"\u003ef_jax\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003enp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003efloat32\u003c/span\u003e(\u003cspan class=\"pl-c1\"\u003e42.\u003c/span\u003e))\n\n\u003cspan class=\"pl-c\"\u003e# Fails in jit mode\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003ejax\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ejit\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ef_jax\u003c/span\u003e)(\u003cspan class=\"pl-s1\"\u003enp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003efloat\u003c/span\u003e(\u003cspan class=\"pl-c1\"\u003e42.\u003c/span\u003e))\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eYet another unsupported situation is when the TF function\nis compilable but with dynamic output shapes:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"def f_tf_dynamic_shape(x):\n return x[x[0]:5]\nx = np.array([1, 2], dtype=np.int32)\n\nf_jax = jax2tf.call_tf(f_tf_dynamic_shape)\n# Works in op-by-op mode\nf_jax(x)\n\n# Fails in jit mode\njax.jit(f_jax)(x)\"\u003e\u003cpre\u003e\u003cspan class=\"pl-k\"\u003edef\u003c/span\u003e \u003cspan class=\"pl-en\"\u003ef_tf_dynamic_shape\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e):\n \u003cspan class=\"pl-k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e[\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e[\u003cspan class=\"pl-c1\"\u003e0\u003c/span\u003e]:\u003cspan class=\"pl-c1\"\u003e5\u003c/span\u003e]\n\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003enp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003earray\u003c/span\u003e([\u003cspan class=\"pl-c1\"\u003e1\u003c/span\u003e, \u003cspan class=\"pl-c1\"\u003e2\u003c/span\u003e], \u003cspan class=\"pl-s1\"\u003edtype\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003enp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eint32\u003c/span\u003e)\n\n\u003cspan 
class=\"pl-s1\"\u003ef_jax\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003ejax2tf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ecall_tf\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ef_tf_dynamic_shape\u003c/span\u003e)\n\u003cspan class=\"pl-c\"\u003e# Works in op-by-op mode\u003c/span\u003e\n\u003cspan class=\"pl-en\"\u003ef_jax\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e)\n\n\u003cspan class=\"pl-c\"\u003e# Fails in jit mode\u003c/span\u003e\n\u003cspan class=\"pl-s1\"\u003ejax\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003ejit\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ef_jax\u003c/span\u003e)(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e)\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003eAnother similar example that will fail to compile:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"def f_tf_dynamic_output_shape(x):\n return tf.cond(x[0] \u0026gt;= 0, lambda: x, lambda: x[1:])\n\nx = np.array([1, 2], dtype=np.int32)\"\u003e\u003cpre\u003e\u003cspan class=\"pl-k\"\u003edef\u003c/span\u003e \u003cspan class=\"pl-en\"\u003ef_tf_dynamic_output_shape\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e):\n \u003cspan class=\"pl-k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003etf\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003econd\u003c/span\u003e(\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e[\u003cspan class=\"pl-c1\"\u003e0\u003c/span\u003e] \u003cspan class=\"pl-c1\"\u003e\u0026gt;=\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e0\u003c/span\u003e, \u003cspan class=\"pl-k\"\u003elambda\u003c/span\u003e: \u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e, \u003cspan class=\"pl-k\"\u003elambda\u003c/span\u003e: \u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e[\u003cspan class=\"pl-c1\"\u003e1\u003c/span\u003e:])\n\n\u003cspan class=\"pl-s1\"\u003ex\u003c/span\u003e \u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e \u003cspan class=\"pl-s1\"\u003enp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003earray\u003c/span\u003e([\u003cspan class=\"pl-c1\"\u003e1\u003c/span\u003e, \u003cspan class=\"pl-c1\"\u003e2\u003c/span\u003e], \u003cspan class=\"pl-s1\"\u003edtype\u003c/span\u003e\u003cspan class=\"pl-c1\"\u003e=\u003c/span\u003e\u003cspan class=\"pl-s1\"\u003enp\u003c/span\u003e.\u003cspan class=\"pl-c1\"\u003eint32\u003c/span\u003e)\u003c/pre\u003e\u003c/div\u003e\n\u003cp dir=\"auto\"\u003e\u003ccode\u003ecall_tf\u003c/code\u003e works best with pure TF functions that do not capture\n\u003ccode\u003etf.Variable\u003c/code\u003es or tensors from the environment, and all such\ncontext is passed in explicitly through arguments, and if variables\nare modified, the resulting values are passed out through results.\nThere is a best-effort mechanism that can handle variable capture\nand variable updates,\nexcept in the case of a function that modifies \u003ccode\u003etf.Variable\u003c/code\u003es\nand is used in a JAX jitted context. Calling the \u003ccode\u003einpure_func_tf\u003c/code\u003e\nwill give an error:\u003c/p\u003e\n\u003cdiv class=\"highlight highlight-source-python notranslate position-relative overflow-auto\" dir=\"auto\" data-snippet-clipboard-copy-content=\"var1 = tf.Variable(1.)\ndef impure_func_tf(x):\n var1.write(11.) 
```python
var1 = tf.Variable(1.)

def impure_func_tf(x):
  var1.assign(11.)  # BAD: should not write to variables
  return x + var1

jax2tf.call_tf(impure_func_tf)(tf.constant(2.))  # Works in eager mode
jax.jit(jax2tf.call_tf(impure_func_tf))(tf.constant(2.))  # Fails in jit mode
```

The error can be avoided by passing the variable explicitly:

```python
def pure_func_tf(x, var1):
  new_var1 = 11.
  return x + new_var1, new_var1
```
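A hedged sketch (not in the original) of how the pure version might be used:
the current value of the variable is passed in as an argument, and the updated
value is returned and applied outside the jitted computation:

```python
import numpy as np
import jax
from jax.experimental import jax2tf

call = jax2tf.call_tf(pure_func_tf)
y, new_value = jax.jit(call)(np.float32(2.), np.float32(1.))
var1.assign(np.asarray(new_value))  # apply the update outside jit
```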
This use case is likely to be revisited.

Note that when the TF function captures a variable from the context, the
TF function must be lowered for the same TF device that hosts the variable.
By default, the lowering will use the first TF device on the same platform
as the embedding JAX computation, e.g., "/device:TPU:0" if the embedding
JAX computation runs on TPU. This will fail if the computation captures
variables on some other device. It is best to use `call_tf`
with TF functions that do not capture variables.

In some rare cases your called TF function may contain ops with outputs of
statically known shape, but for which the shape inference is not implemented
completely and which therefore appear to `call_tf` as if they had dynamically
shaped outputs. In these cases you may get an error that
`call_tf cannot call functions whose output has dynamic shape`. Try using
the `output_shape_dtype` parameter to specify the expected output shape
(this essentially allows you to override the shape inference for the
purposes of `call_tf`).
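A sketch of that workaround, where `f_tf` is a hypothetical TF function whose
outputs are statically `(2, 3)` `float32` even though TF's shape inference
reports dynamic dimensions:

```python
import numpy as np
import jax
from jax.experimental import jax2tf

# Override the inferred (dynamic) output shape with the known static one.
f_jax = jax2tf.call_tf(
    f_tf,
    output_shape_dtype=jax.ShapeDtypeStruct((2, 3), np.float32))
```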
# Misc notes

## Debugging JAX native serialization

Inside Google, you can turn on logging by using the `--vmodule` argument to
specify the logging levels for different modules,
e.g., `--vmodule=_export=3`. You can set `TF_DUMP_GRAPH_PREFIX` to
a directory where modules should be dumped, or to `"-"` to dump the
modules to the log.
The following modules are useful for debugging JAX native serialization:

* `_export=3` - will log the StableHLO module on serialization.
* `jax2tf=3` - will log the parameters to the `XlaCallModule` op on serialization.
* `xla_call_module_loader=3` - will log the StableHLO module upon loading,
  after shape refinements, and on verification error. You can use level `4` to
  add location information, and level `5` to also print the module before and
  after each transformation.
* `xla_call_module_op=3` - will log the HLO module generated after
  shape refinement and conversion from StableHLO.
* The `XlaCallModule` lowering has the TensorFlow MLIR crash reproducer enabled,
  which can be instructed to generate a crash reproducer upon MLIR pass failures
  by setting the environment variable `MLIR_CRASH_REPRODUCER_DIRECTORY`.

For the two `xla` modules mentioned above, you can control logging in OSS
with environment variables, e.g.:

```
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=xla_call_module_loader=3 python ...
```

In addition, `TF_DUMP_GRAPH_PREFIX` controls where the dump will be stored: `-`
for stderr, `${SOME_DIR}` to store the dumps in the specified directory.

## TensorFlow versions supported

The `jax2tf.convert` and `call_tf` functions require fairly recent versions of
TensorFlow. As of today, the tests are run using `tf_nightly==2.14.0.dev20230720`.

## Running on GPU

To run jax2tf on GPU, both jaxlib and TensorFlow must be installed with support
for CUDA. One must be mindful to install a version of CUDA that is compatible
with both [jaxlib](https://github.com/jax-ml/jax/blob/main/README.md#pip-installation) and
[TensorFlow](https://www.tensorflow.org/install/source#tested_build_configurations).
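As a rough sketch only (packaging extras change over time; the `cuda12` and
`and-cuda` extras below reflect one point in time, so consult the linked
installation guides for current instructions):

```
pip install --upgrade "jax[cuda12]"           # jax + jaxlib with CUDA support
pip install --upgrade "tensorflow[and-cuda]"  # TensorFlow with bundled CUDA libraries
```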

## Updating the limitations documentation

The jax2tf tests are parameterized by a set of limitations
(see `tests/primitive_harness.py` and `tests/jax2tf_limitations.py`).
The limitations specify test harnesses that are known to fail, by
JAX primitive, data type, device type, and TensorFlow execution mode (`eager`,
`graph`, or `compiled`). These limitations are also used
to generate tables of limitations, e.g.,

* [List of primitives not supported in JAX](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md),
  e.g., due to unimplemented cases in the XLA compiler, and
* [List of primitives not supported in jax2tf](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md),
  e.g., due to unimplemented cases in TensorFlow. This list is incremental
  on top of the unsupported JAX primitives.

There are instructions for updating those documents at the end of each
document.

The set of limitations is an over-approximation, in the sense that if XLA
or TensorFlow improve and support more cases, no test will fail. Instead,
we periodically check for unnecessary limitations. We do this by uncommenting
two assertions (in `tests/jax_primitives_coverage_test.py` and in
`tests/tf_test_util.py`) and running all the tests. With these assertions
enabled the tests will fail and point out unnecessary limitations. We remove
limitations until the tests pass. Then we re-generate the documentation.
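Sketched as a shell workflow (the test-runner invocation below is an
assumption; only the two files containing the assertions are named above):

```
# 1. Uncomment the two assertions in tests/jax_primitives_coverage_test.py
#    and tests/tf_test_util.py.
# 2. Run the test suite; with the assertions enabled, tests fail on
#    limitations that are no longer needed (hypothetical invocation):
python -m pytest jax/experimental/jax2tf/tests/
# 3. Remove the flagged limitations, re-run until the tests pass, then
#    re-generate the two documents following the instructions at their end.
```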
LinkButton-module__code-view-link-button--xvCGA flex-items-center fgColor-default" data-loading="false" data-size="small" data-variant="invisible" aria-describedby=":R5d6l9lab:-loading-announcement"><span data-component="buttonContent" data-align="center" class="prc-Button-ButtonContent-HKbr-"><span data-component="leadingVisual" class="prc-Button-Visual-2epfX prc-Button-VisualWrap-Db-eB"><svg aria-hidden="true" focusable="false" class="octicon octicon-history" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="m.427 1.927 1.215 1.215a8.002 8.002 0 1 1-1.6 5.685.75.75 0 1 1 1.493-.154 6.5 6.5 0 1 0 1.18-4.458l1.358 1.358A.25.25 0 0 1 3.896 6H.25A.25.25 0 0 1 0 5.75V2.104a.25.25 0 0 1 .427-.177ZM7.75 4a.75.75 0 0 1 .75.75v2.992l2.028.812a.75.75 0 0 1-.557 1.392l-2.5-1A.751.751 0 0 1 7 8.25v-3.5A.75.75 0 0 1 7.75 4Z"></path></svg></span><span data-component="text" class="prc-Button-Label-pTQ3x"><span class="fgColor-default">History</span></span></span></a><div class="d-sm-none"></div><div class="d-flex d-lg-none"><span role="tooltip" aria-label="History" id="history-icon-button-tooltip" class="Tooltip__TooltipBase-sc-17tf59c-0 hWlpPn tooltipped-n"><a href="/jax-ml/jax/commits/main/jax/experimental/jax2tf" class="prc-Button-ButtonBase-c50BI LinkButton-module__code-view-link-button--xvCGA flex-items-center fgColor-default" data-loading="false" data-size="small" data-variant="invisible" aria-describedby=":Rpd6l9lab:-loading-announcement history-icon-button-tooltip"><span data-component="buttonContent" data-align="center" class="prc-Button-ButtonContent-HKbr-"><span data-component="leadingVisual" class="prc-Button-Visual-2epfX prc-Button-VisualWrap-Db-eB"><svg aria-hidden="true" focusable="false" class="octicon octicon-history" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="m.427 1.927 1.215 1.215a8.002 8.002 0 1 1-1.6 5.685.75.75 0 1 1 1.493-.154 6.5 6.5 0 1 0 1.18-4.458l1.358 1.358A.25.25 0 0 1 3.896 6H.25A.25.25 0 0 1 0 5.75V2.104a.25.25 0 0 1 .427-.177ZM7.75 4a.75.75 0 0 1 .75.75v2.992l2.028.812a.75.75 0 0 1-.557 1.392l-2.5-1A.751.751 0 0 1 7 8.25v-3.5A.75.75 0 0 1 7.75 4Z"></path></svg></span></span></a></span></div></div></div></div></div><div class="Box-sc-g0xbh4-0 bEZNi react-blob-view-header-sticky"><div class="Box-sc-g0xbh4-0 gjQlJX"><div class="Box-sc-g0xbh4-0 hqwSEx"><div class="Box-sc-g0xbh4-0 bDVoEr"><div class="Box-sc-g0xbh4-0 kYLlPM"><div class="Box-sc-g0xbh4-0 gYjEmn"><button type="button" aria-haspopup="true" aria-expanded="false" tabindex="0" aria-label="main branch" data-testid="anchor-button" class="Box-sc-g0xbh4-0 dmxRgG prc-Button-ButtonBase-c50BI ref-selector-class" data-loading="false" data-size="medium" data-variant="default" aria-describedby="branch-picker-repos-header-ref-selector-loading-announcement" id="branch-picker-repos-header-ref-selector"><span data-component="buttonContent" class="Box-sc-g0xbh4-0 gUkoLg prc-Button-ButtonContent-HKbr-"><span data-component="text" class="prc-Button-Label-pTQ3x"><div class="Box-sc-g0xbh4-0 bZBlpz"><div class="Box-sc-g0xbh4-0 lhTYNA"><svg aria-hidden="true" focusable="false" class="octicon octicon-git-branch" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M9.5 3.25a2.25 2.25 0 1 1 3 2.122V6A2.5 
2.5 0 0 1 10 8.5H6a1 1 0 0 0-1 1v1.128a2.251 2.251 0 1 1-1.5 0V5.372a2.25 2.25 0 1 1 1.5 0v1.836A2.493 2.493 0 0 1 6 7h4a1 1 0 0 0 1-1v-.628A2.25 2.25 0 0 1 9.5 3.25Zm-6 0a.75.75 0 1 0 1.5 0 .75.75 0 0 0-1.5 0Zm8.25-.75a.75.75 0 1 0 0 1.5.75.75 0 0 0 0-1.5ZM4.25 12a.75.75 0 1 0 0 1.5.75.75 0 0 0 0-1.5Z"></path></svg></div><div class="Box-sc-g0xbh4-0 dbrgmi ref-selector-button-text-container"><span class="Text__StyledText-sc-17v1xeu-0 eMMFM"> <!-- -->main</span></div></div></span><span data-component="trailingVisual" class="prc-Button-Visual-2epfX prc-Button-VisualWrap-Db-eB"><svg aria-hidden="true" focusable="false" class="octicon octicon-triangle-down" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="m4.427 7.427 3.396 3.396a.25.25 0 0 0 .354 0l3.396-3.396A.25.25 0 0 0 11.396 7H4.604a.25.25 0 0 0-.177.427Z"></path></svg></span></span></button><button hidden="" data-hotkey-scope="read-only-cursor-text-area"></button></div><div class="Box-sc-g0xbh4-0 kGqOLL"><div class="Box-sc-g0xbh4-0 fHind"><nav data-testid="breadcrumbs" aria-labelledby="sticky-breadcrumb-heading" id="sticky-breadcrumb" class="Box-sc-g0xbh4-0 fzFXnm"><h2 class="sr-only ScreenReaderHeading-module__userSelectNone--vW4Cq prc-Heading-Heading-6CmGO" data-testid="screen-reader-heading" id="sticky-breadcrumb-heading">Breadcrumbs</h2><ol class="Box-sc-g0xbh4-0 iMnkmv"><li class="Box-sc-g0xbh4-0 ghzDag"><a class="Box-sc-g0xbh4-0 kHuKdh prc-Link-Link-85e08" sx="[object Object]" data-testid="breadcrumbs-repo-link" href="/jax-ml/jax/tree/main">jax</a></li><li class="Box-sc-g0xbh4-0 ghzDag"><span class="Text__StyledText-sc-17v1xeu-0 lauzFl" aria-hidden="true">/</span><a class="Box-sc-g0xbh4-0 kgiVEz prc-Link-Link-85e08" sx="[object Object]" href="/jax-ml/jax/tree/main/jax">jax</a></li><li class="Box-sc-g0xbh4-0 ghzDag"><span class="Text__StyledText-sc-17v1xeu-0 lauzFl" aria-hidden="true">/</span><a class="Box-sc-g0xbh4-0 kgiVEz prc-Link-Link-85e08" sx="[object Object]" href="/jax-ml/jax/tree/main/jax/experimental">experimental</a></li></ol></nav><div data-testid="breadcrumbs-filename" class="Box-sc-g0xbh4-0 ghzDag"><span class="Text__StyledText-sc-17v1xeu-0 lauzFl" aria-hidden="true">/</span><h1 class="Box-sc-g0xbh4-0 dnZoUW prc-Heading-Heading-6CmGO" tabindex="-1" id="sticky-file-name-id">jax2tf</h1><span class="Text__StyledText-sc-17v1xeu-0 HlHVj" aria-hidden="true">/</span></div></div></div></div><button style="--button-color:fg.default" type="button" class="Box-sc-g0xbh4-0 jRZWlf prc-Button-ButtonBase-c50BI" data-loading="false" data-size="small" data-variant="invisible" aria-describedby=":R2el9lab:-loading-announcement"><span data-component="buttonContent" class="Box-sc-g0xbh4-0 gUkoLg prc-Button-ButtonContent-HKbr-"><span data-component="leadingVisual" class="prc-Button-Visual-2epfX prc-Button-VisualWrap-Db-eB"><svg aria-hidden="true" focusable="false" class="octicon octicon-arrow-up" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M3.47 7.78a.75.75 0 0 1 0-1.06l4.25-4.25a.75.75 0 0 1 1.06 0l4.25 4.25a.751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018L9 4.81v7.44a.75.75 0 0 1-1.5 0V4.81L4.53 7.78a.75.75 0 0 1-1.06 0Z"></path></svg></span><span data-component="text" class="prc-Button-Label-pTQ3x">Top</span></span></button></div></div></div></div><div class="Box-sc-g0xbh4-0 vIPPs"><div data-hpc="true"><button 
hidden="" data-testid="focus-next-element-button" data-hotkey="j"></button><button hidden="" data-testid="focus-previous-element-button" data-hotkey="k"></button><h2 class="sr-only ScreenReaderHeading-module__userSelectNone--vW4Cq prc-Heading-Heading-6CmGO" data-testid="screen-reader-heading" id="folders-and-files">Folders and files</h2><table aria-labelledby="folders-and-files" class="Box-sc-g0xbh4-0 fdROMU"><thead class="Box-sc-g0xbh4-0 eHDvEW"><tr class="Box-sc-g0xbh4-0 jdgHnn"><th colSpan="2" class="Box-sc-g0xbh4-0 bQivRW"><span class="text-bold">Name</span></th><th colSpan="1" class="Box-sc-g0xbh4-0 ldkMIO"><span class="text-bold">Name</span></th><th class="hide-sm"><div title="Last commit message" class="Truncate__StyledTruncate-sc-23o1d2-0 liVpTx width-fit"><span class="text-bold">Last commit message</span></div></th><th colSpan="1" class="Box-sc-g0xbh4-0 jMbWeI"><div title="Last commit date" class="Truncate__StyledTruncate-sc-23o1d2-0 liVpTx width-fit"><span class="text-bold">Last commit date</span></div></th></tr></thead><tbody><tr id="folder-row-0" class="Box-sc-g0xbh4-0 cgFZpq"><td colSpan="3" class="f5 text-normal px-3"><h3 class="sr-only ScreenReaderHeading-module__userSelectNone--vW4Cq prc-Heading-Heading-6CmGO" data-testid="screen-reader-heading">parent directory</h3><a class="Box-sc-g0xbh4-0 bzhFQD prc-Link-Link-85e08" data-muted="true" aria-label="Parent directory" data-react-autofocus="true" data-testid="up-tree" muted="" rel="nofollow" sx="[object Object]" href="/jax-ml/jax/tree/main/jax/experimental"><div class="Box-sc-g0xbh4-0 hHzTZW width-full"><svg aria-hidden="true" focusable="false" class="Octicon-sc-9kayk9-0 jdVJbO" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M1.75 1A1.75 1.75 0 0 0 0 2.75v10.5C0 14.216.784 15 1.75 15h12.5A1.75 1.75 0 0 0 16 13.25v-8.5A1.75 1.75 0 0 0 14.25 3H7.5a.25.25 0 0 1-.2-.1l-.9-1.2C6.07 1.26 5.55 1 5 1H1.75Z"></path></svg>..</div></a></td></tr><tr class="react-directory-row undefined" id="folder-row-1"><td class="react-directory-row-name-cell-small-screen" colSpan="2"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="icon-directory" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M1.75 1A1.75 1.75 0 0 0 0 2.75v10.5C0 14.216.784 15 1.75 15h12.5A1.75 1.75 0 0 0 16 13.25v-8.5A1.75 1.75 0 0 0 14.25 3H7.5a.25.25 0 0 1-.2-.1l-.9-1.2C6.07 1.26 5.55 1 5 1H1.75Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="examples" aria-label="examples, (Directory)" class="Link--primary" href="/jax-ml/jax/tree/main/jax/experimental/jax2tf/examples">examples</a></div></div></div></div></td><td class="react-directory-row-name-cell-large-screen" colSpan="1"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="icon-directory" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M1.75 1A1.75 1.75 0 0 0 0 2.75v10.5C0 14.216.784 15 1.75 15h12.5A1.75 1.75 0 0 0 16 13.25v-8.5A1.75 1.75 0 0 0 14.25 3H7.5a.25.25 0 0 1-.2-.1l-.9-1.2C6.07 1.26 5.55 1 5 1H1.75Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div 
class="react-directory-truncate"><a title="examples" aria-label="examples, (Directory)" class="Link--primary" href="/jax-ml/jax/tree/main/jax/experimental/jax2tf/examples">examples</a></div></div></div></div></td><td class="react-directory-row-commit-cell"><div class="Skeleton Skeleton--text"> </div></td><td><div class="Skeleton Skeleton--text"> </div></td></tr><tr class="react-directory-row undefined" id="folder-row-2"><td class="react-directory-row-name-cell-small-screen" colSpan="2"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="icon-directory" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M1.75 1A1.75 1.75 0 0 0 0 2.75v10.5C0 14.216.784 15 1.75 15h12.5A1.75 1.75 0 0 0 16 13.25v-8.5A1.75 1.75 0 0 0 14.25 3H7.5a.25.25 0 0 1-.2-.1l-.9-1.2C6.07 1.26 5.55 1 5 1H1.75Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="g3doc" aria-label="g3doc, (Directory)" class="Link--primary" href="/jax-ml/jax/tree/main/jax/experimental/jax2tf/g3doc">g3doc</a></div></div></div></div></td><td class="react-directory-row-name-cell-large-screen" colSpan="1"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="icon-directory" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M1.75 1A1.75 1.75 0 0 0 0 2.75v10.5C0 14.216.784 15 1.75 15h12.5A1.75 1.75 0 0 0 16 13.25v-8.5A1.75 1.75 0 0 0 14.25 3H7.5a.25.25 0 0 1-.2-.1l-.9-1.2C6.07 1.26 5.55 1 5 1H1.75Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="g3doc" aria-label="g3doc, (Directory)" class="Link--primary" href="/jax-ml/jax/tree/main/jax/experimental/jax2tf/g3doc">g3doc</a></div></div></div></div></td><td class="react-directory-row-commit-cell"><div class="Skeleton Skeleton--text"> </div></td><td><div class="Skeleton Skeleton--text"> </div></td></tr><tr class="react-directory-row undefined" id="folder-row-3"><td class="react-directory-row-name-cell-small-screen" colSpan="2"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="icon-directory" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M1.75 1A1.75 1.75 0 0 0 0 2.75v10.5C0 14.216.784 15 1.75 15h12.5A1.75 1.75 0 0 0 16 13.25v-8.5A1.75 1.75 0 0 0 14.25 3H7.5a.25.25 0 0 1-.2-.1l-.9-1.2C6.07 1.26 5.55 1 5 1H1.75Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="tests" aria-label="tests, (Directory)" class="Link--primary" href="/jax-ml/jax/tree/main/jax/experimental/jax2tf/tests">tests</a></div></div></div></div></td><td class="react-directory-row-name-cell-large-screen" colSpan="1"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="icon-directory" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M1.75 1A1.75 1.75 0 0 0 0 2.75v10.5C0 14.216.784 15 1.75 15h12.5A1.75 1.75 0 0 0 16 13.25v-8.5A1.75 1.75 0 0 0 14.25 3H7.5a.25.25 0 0 1-.2-.1l-.9-1.2C6.07 
1.26 5.55 1 5 1H1.75Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="tests" aria-label="tests, (Directory)" class="Link--primary" href="/jax-ml/jax/tree/main/jax/experimental/jax2tf/tests">tests</a></div></div></div></div></td><td class="react-directory-row-commit-cell"><div class="Skeleton Skeleton--text"> </div></td><td><div class="Skeleton Skeleton--text"> </div></td></tr><tr class="react-directory-row undefined" id="folder-row-4"><td class="react-directory-row-name-cell-small-screen" colSpan="2"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="color-fg-muted" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M2 1.75C2 .784 2.784 0 3.75 0h6.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v9.586A1.75 1.75 0 0 1 13.25 16h-9.5A1.75 1.75 0 0 1 2 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h9.5a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 9 4.25V1.5Zm6.75.062V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="BUILD" aria-label="BUILD, (File)" class="Link--primary" href="/jax-ml/jax/blob/main/jax/experimental/jax2tf/BUILD">BUILD</a></div></div></div></div></td><td class="react-directory-row-name-cell-large-screen" colSpan="1"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="color-fg-muted" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M2 1.75C2 .784 2.784 0 3.75 0h6.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v9.586A1.75 1.75 0 0 1 13.25 16h-9.5A1.75 1.75 0 0 1 2 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h9.5a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 9 4.25V1.5Zm6.75.062V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="BUILD" aria-label="BUILD, (File)" class="Link--primary" href="/jax-ml/jax/blob/main/jax/experimental/jax2tf/BUILD">BUILD</a></div></div></div></div></td><td class="react-directory-row-commit-cell"><div class="Skeleton Skeleton--text"> </div></td><td><div class="Skeleton Skeleton--text"> </div></td></tr><tr class="react-directory-row undefined" id="folder-row-5"><td class="react-directory-row-name-cell-small-screen" colSpan="2"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="color-fg-muted" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M2 1.75C2 .784 2.784 0 3.75 0h6.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v9.586A1.75 1.75 0 0 1 13.25 16h-9.5A1.75 1.75 0 0 1 2 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h9.5a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 9 4.25V1.5Zm6.75.062V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="JAX2TF_getting_started.ipynb" 
aria-label="JAX2TF_getting_started.ipynb, (File)" class="Link--primary" href="/jax-ml/jax/blob/main/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb">JAX2TF_getting_started.ipynb</a></div></div></div></div></td><td class="react-directory-row-name-cell-large-screen" colSpan="1"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="color-fg-muted" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M2 1.75C2 .784 2.784 0 3.75 0h6.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v9.586A1.75 1.75 0 0 1 13.25 16h-9.5A1.75 1.75 0 0 1 2 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h9.5a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 9 4.25V1.5Zm6.75.062V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="JAX2TF_getting_started.ipynb" aria-label="JAX2TF_getting_started.ipynb, (File)" class="Link--primary" href="/jax-ml/jax/blob/main/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb">JAX2TF_getting_started.ipynb</a></div></div></div></div></td><td class="react-directory-row-commit-cell"><div class="Skeleton Skeleton--text"> </div></td><td><div class="Skeleton Skeleton--text"> </div></td></tr><tr class="react-directory-row undefined" id="folder-row-6"><td class="react-directory-row-name-cell-small-screen" colSpan="2"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="color-fg-muted" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M2 1.75C2 .784 2.784 0 3.75 0h6.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v9.586A1.75 1.75 0 0 1 13.25 16h-9.5A1.75 1.75 0 0 1 2 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h9.5a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 9 4.25V1.5Zm6.75.062V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="README.md" aria-label="README.md, (File)" class="Link--primary" href="/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md">README.md</a></div></div></div></div></td><td class="react-directory-row-name-cell-large-screen" colSpan="1"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="color-fg-muted" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M2 1.75C2 .784 2.784 0 3.75 0h6.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v9.586A1.75 1.75 0 0 1 13.25 16h-9.5A1.75 1.75 0 0 1 2 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h9.5a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 9 4.25V1.5Zm6.75.062V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="README.md" aria-label="README.md, (File)" class="Link--primary" href="/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md">README.md</a></div></div></div></div></td><td class="react-directory-row-commit-cell"><div class="Skeleton 
Skeleton--text"> </div></td><td><div class="Skeleton Skeleton--text"> </div></td></tr><tr class="react-directory-row undefined" id="folder-row-7"><td class="react-directory-row-name-cell-small-screen" colSpan="2"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="color-fg-muted" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M2 1.75C2 .784 2.784 0 3.75 0h6.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v9.586A1.75 1.75 0 0 1 13.25 16h-9.5A1.75 1.75 0 0 1 2 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h9.5a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 9 4.25V1.5Zm6.75.062V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="__init__.py" aria-label="__init__.py, (File)" class="Link--primary" href="/jax-ml/jax/blob/main/jax/experimental/jax2tf/__init__.py">__init__.py</a></div></div></div></div></td><td class="react-directory-row-name-cell-large-screen" colSpan="1"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="color-fg-muted" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M2 1.75C2 .784 2.784 0 3.75 0h6.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v9.586A1.75 1.75 0 0 1 13.25 16h-9.5A1.75 1.75 0 0 1 2 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h9.5a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 9 4.25V1.5Zm6.75.062V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="__init__.py" aria-label="__init__.py, (File)" class="Link--primary" href="/jax-ml/jax/blob/main/jax/experimental/jax2tf/__init__.py">__init__.py</a></div></div></div></div></td><td class="react-directory-row-commit-cell"><div class="Skeleton Skeleton--text"> </div></td><td><div class="Skeleton Skeleton--text"> </div></td></tr><tr class="react-directory-row undefined" id="folder-row-8"><td class="react-directory-row-name-cell-small-screen" colSpan="2"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="color-fg-muted" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M2 1.75C2 .784 2.784 0 3.75 0h6.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v9.586A1.75 1.75 0 0 1 13.25 16h-9.5A1.75 1.75 0 0 1 2 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h9.5a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 9 4.25V1.5Zm6.75.062V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="call_tf.py" aria-label="call_tf.py, (File)" class="Link--primary" href="/jax-ml/jax/blob/main/jax/experimental/jax2tf/call_tf.py">call_tf.py</a></div></div></div></div></td><td class="react-directory-row-name-cell-large-screen" colSpan="1"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="color-fg-muted" viewBox="0 0 16 16" 
width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M2 1.75C2 .784 2.784 0 3.75 0h6.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v9.586A1.75 1.75 0 0 1 13.25 16h-9.5A1.75 1.75 0 0 1 2 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h9.5a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 9 4.25V1.5Zm6.75.062V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="call_tf.py" aria-label="call_tf.py, (File)" class="Link--primary" href="/jax-ml/jax/blob/main/jax/experimental/jax2tf/call_tf.py">call_tf.py</a></div></div></div></div></td><td class="react-directory-row-commit-cell"><div class="Skeleton Skeleton--text"> </div></td><td><div class="Skeleton Skeleton--text"> </div></td></tr><tr class="react-directory-row undefined" id="folder-row-9"><td class="react-directory-row-name-cell-small-screen" colSpan="2"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="color-fg-muted" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M2 1.75C2 .784 2.784 0 3.75 0h6.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v9.586A1.75 1.75 0 0 1 13.25 16h-9.5A1.75 1.75 0 0 1 2 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h9.5a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 9 4.25V1.5Zm6.75.062V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="impl_no_xla.py" aria-label="impl_no_xla.py, (File)" class="Link--primary" href="/jax-ml/jax/blob/main/jax/experimental/jax2tf/impl_no_xla.py">impl_no_xla.py</a></div></div></div></div></td><td class="react-directory-row-name-cell-large-screen" colSpan="1"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="color-fg-muted" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M2 1.75C2 .784 2.784 0 3.75 0h6.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v9.586A1.75 1.75 0 0 1 13.25 16h-9.5A1.75 1.75 0 0 1 2 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h9.5a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 9 4.25V1.5Zm6.75.062V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="impl_no_xla.py" aria-label="impl_no_xla.py, (File)" class="Link--primary" href="/jax-ml/jax/blob/main/jax/experimental/jax2tf/impl_no_xla.py">impl_no_xla.py</a></div></div></div></div></td><td class="react-directory-row-commit-cell"><div class="Skeleton Skeleton--text"> </div></td><td><div class="Skeleton Skeleton--text"> </div></td></tr><tr class="react-directory-row undefined" id="folder-row-10"><td class="react-directory-row-name-cell-small-screen" colSpan="2"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="color-fg-muted" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" 
style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M2 1.75C2 .784 2.784 0 3.75 0h6.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v9.586A1.75 1.75 0 0 1 13.25 16h-9.5A1.75 1.75 0 0 1 2 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h9.5a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 9 4.25V1.5Zm6.75.062V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="jax2tf.py" aria-label="jax2tf.py, (File)" class="Link--primary" href="/jax-ml/jax/blob/main/jax/experimental/jax2tf/jax2tf.py">jax2tf.py</a></div></div></div></div></td><td class="react-directory-row-name-cell-large-screen" colSpan="1"><div class="react-directory-filename-column"><svg aria-hidden="true" focusable="false" class="color-fg-muted" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M2 1.75C2 .784 2.784 0 3.75 0h6.586c.464 0 .909.184 1.237.513l2.914 2.914c.329.328.513.773.513 1.237v9.586A1.75 1.75 0 0 1 13.25 16h-9.5A1.75 1.75 0 0 1 2 14.25Zm1.75-.25a.25.25 0 0 0-.25.25v12.5c0 .138.112.25.25.25h9.5a.25.25 0 0 0 .25-.25V6h-2.75A1.75 1.75 0 0 1 9 4.25V1.5Zm6.75.062V4.25c0 .138.112.25.25.25h2.688l-.011-.013-2.914-2.914-.013-.011Z"></path></svg><div class="overflow-hidden"><div class="react-directory-filename-cell"><div class="react-directory-truncate"><a title="jax2tf.py" aria-label="jax2tf.py, (File)" class="Link--primary" href="/jax-ml/jax/blob/main/jax/experimental/jax2tf/jax2tf.py">jax2tf.py</a></div></div></div></div></td><td class="react-directory-row-commit-cell"><div class="Skeleton Skeleton--text"> </div></td><td><div class="Skeleton Skeleton--text"> </div></td></tr><tr class="Box-sc-g0xbh4-0 eNCcrz d-none" data-testid="view-all-files-row"><td colSpan="3" class="Box-sc-g0xbh4-0 bHTcCe"><div><button class="prc-Link-Link-85e08">View all files</button></div></td></tr></tbody></table></div><div class="Box-sc-g0xbh4-0 kkSYPE"><div id="readme" class="Box-sc-g0xbh4-0 dYOCLB"><div class="Box-sc-g0xbh4-0 gGRoah"><h2 class="Box-sc-g0xbh4-0 caoWpU prc-Heading-Heading-6CmGO"><a class="Box-sc-g0xbh4-0 jARUZT prc-Link-Link-85e08" href="#readme">README.md</a></h2><button data-component="IconButton" type="button" aria-label="Outline" aria-pressed="false" class="Box-sc-g0xbh4-0 gbcGMH prc-Button-ButtonBase-c50BI prc-Button-IconButton-szpyj" data-loading="false" data-no-visuals="true" data-size="small" data-variant="invisible" aria-describedby=":R6mil9lab:-loading-announcement"><svg aria-hidden="true" focusable="false" class="octicon octicon-list-unordered" viewBox="0 0 16 16" width="16" height="16" fill="currentColor" style="display:inline-block;user-select:none;vertical-align:text-bottom;overflow:visible"><path d="M5.75 2.5h8.5a.75.75 0 0 1 0 1.5h-8.5a.75.75 0 0 1 0-1.5Zm0 5h8.5a.75.75 0 0 1 0 1.5h-8.5a.75.75 0 0 1 0-1.5Zm0 5h8.5a.75.75 0 0 1 0 1.5h-8.5a.75.75 0 0 1 0-1.5ZM2 14a1 1 0 1 1 0-2 1 1 0 0 1 0 2Zm1-6a1 1 0 1 1-2 0 1 1 0 0 1 2 0ZM2 4a1 1 0 1 1 0-2 1 1 0 0 1 0 2Z"></path></svg></button></div><div class="Box-sc-g0xbh4-0 QkQOb js-snippet-clipboard-copy-unpositioned undefined" data-hpc="true"><article class="markdown-body entry-content container-lg" itemprop="text"><div class="markdown-heading" dir="auto"><h1 tabindex="-1" class="heading-element" dir="auto">JAX and TensorFlow interoperation 
# JAX and TensorFlow interoperation (jax2tf/call_tf)

This package provides support for JAX native serialization and for interoperation between JAX and TensorFlow. There are two interoperation directions:

* `jax2tf.convert`: for calling JAX functions in a TensorFlow context, e.g., for eager or graph TensorFlow execution, or for serializing as a TensorFlow SavedModel; and
* `jax2tf.call_tf`: for calling TensorFlow functions in a JAX context, e.g., to call a TensorFlow library or to reload a TensorFlow SavedModel and call its functions in JAX.

These APIs can be combined, e.g., to reload in JAX a program that has been serialized from JAX to a TensorFlow SavedModel, or to save to a TensorFlow SavedModel a JAX program that uses a TensorFlow library.
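Both directions are covered in detail in the sections below; as a minimal, hedged sketch of what the two wrappers look like side by side (the wrapped functions here are illustrative stand-ins):

```python
# A minimal sketch of both interoperation directions; illustrative values only.
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# JAX -> TensorFlow: wrap a JAX function so that TensorFlow can call it.
f_tf = jax2tf.convert(lambda x: jnp.sin(x))
print(f_tf(tf.constant(1.0)))   # runs the JAX function from TensorFlow

# TensorFlow -> JAX: wrap a TensorFlow function so that JAX can call it.
g_jax = jax2tf.call_tf(tf.math.sin)
print(g_jax(jnp.array(1.0)))    # runs the TensorFlow function from JAX
```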
Tip: As of version 0.4.14 (July 2023) the default mode of JAX-TensorFlow interoperation is by way of **native serialization**, in which the target function is lowered to StableHLO using standard native JAX or TensorFlow APIs, and then the StableHLO module is invoked from the other framework. The native serialization mode has several advantages:

* it supports virtually all operations supported by native execution, e.g., `shard_map`, `pmap`, parallel collective operations, and all primitives at all data types;
* it uses standard native JAX code paths for lowering, so it is easier to trust that the semantics and performance stay faithful to the native semantics, across platforms;
* the metadata associated with the operations, e.g., source location, is identical to what native execution uses;
* it includes safety checking that the serialized code is executed on the platform for which it was serialized.

At the moment, when using JAX native serialization the whole JAX compilation unit is wrapped with a single thin TensorFlow op, called [`XlaCallModule`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/ops/xla_ops.cc#L1318), that carries the serialized version of the StableHLO obtained from JAX. This op is supported only on TensorFlow platforms that include the XLA compiler, and it compiles and then invokes the embedded StableHLO. The reasons we wrap the StableHLO in a TensorFlow op are:

* it allows saving the serialization in a tf.SavedModel, for use with the many mature tools for TensorFlow;
* it allows composing the JAX program with TensorFlow pre-processing, post-processing, and host callback functions;
* the `XlaCallModule` op contains the code that must be executed to deserialize, compile, and execute the JAX program, e.g., to properly handle backward compatibility and to do the just-in-time preprocessing needed for shape polymorphism;
* the semantics of the JAX program are still preserved faithfully, because they are entirely captured by the StableHLO serialization.

For backwards compatibility purposes, and for special uses, the JAX-TensorFlow interoperation APIs can also be used in a **graph serialization** mode (the only mode available before version 0.4.7, and the default mode before JAX version 0.4.15), without going through StableHLO. (Starting with JAX version 0.4.31 the graph serialization mode is deprecated. It will be removed in the near future.)

* For calling JAX functions from TensorFlow, it is possible to request that the JAX function be lowered with one TensorFlow op for each JAX primitive. This can be achieved by setting `native_serialization=False`. This enables the following:

  * TensorFlow eager mode execution, e.g., for debugging;
  * producing a `tf.Graph` for consumption by tooling that understands TensorFlow ops but does not yet work with StableHLO, e.g., TFLite and TensorFlow.js;
  * using the more mature support for dynamic shapes in TensorFlow. [StableHLO does have support for dynamic shapes](https://github.com/openxla/stablehlo/blob/main/rfcs/20230704-dynamism-101.md), and in the near future we expect it will support shape polymorphism to the same extent as graph serialization.

  Even in the graph serialization mode the resulting TensorFlow graph is pretty much 1:1 with the StableHLO module that would be obtained through native serialization.

* For calling TensorFlow functions from JAX, if the resulting JAX program is executed in op-by-op mode (i.e., not under `jax.jit` or `jax.pmap` and not inside `lax.cond` or `lax.scan`) then the target TensorFlow function is executed in eager mode. This can be useful if the target TensorFlow function is not lowerable to HLO, e.g., because it uses strings.

To disable native serialization, you can do any of the following, in decreasing priority order (each is sketched below):

* set `native_serialization=False`, or
* use the configuration flag `--jax2tf_default_native_serialization=false`, or
* use the environment variable `JAX2TF_DEFAULT_NATIVE_SERIALIZATION=false`.
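A sketch of what these three mechanisms look like in practice (`f_jax` is a stand-in function; exactly how you pass command-line flags depends on how your program parses them):

```python
# Three equivalent ways to disable native serialization, in decreasing
# priority order. f_jax is a stand-in JAX function for illustration.
import jax.numpy as jnp
from jax.experimental import jax2tf

def f_jax(x):
  return jnp.sin(x)

# 1. Per-call keyword argument (highest priority):
f_tf = jax2tf.convert(f_jax, native_serialization=False)

# 2. Configuration flag, e.g., on the command line:
#      python my_program.py --jax2tf_default_native_serialization=false

# 3. Environment variable, e.g., set before launching the program:
#      JAX2TF_DEFAULT_NATIVE_SERIALIZATION=false python my_program.py
```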
We describe below some general concepts and capabilities, first for `jax2tf.convert` and [later](#calling-tensorflow-functions-from-jax) for `jax2tf.call_tf`. For more involved examples, please see examples involving:

* SavedModel for archival ([examples below](#usage-saved-model)), including saving [batch-polymorphic functions](#shape-polymorphic-conversion),
* TensorFlow.js ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md)),
* TFX ([examples](https://github.com/tensorflow/tfx/blob/master/tfx/examples/penguin/README.md#instructions-for-using-flax)),
* TensorFlow Hub and Keras ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md)).

[TOC]

## Usage: basic functions

As a rule of thumb, if you can `jax.jit` your function then you should be able to use `jax2tf.convert`:
```python
from jax.experimental import jax2tf
from jax import numpy as jnp

import numpy as np
import tensorflow as tf

def f_jax(x):
  return jnp.sin(jnp.cos(x))

# jax2tf.convert is a higher-order function that returns a wrapped function with
# the same signature as your input function but accepting TensorFlow tensors (or
# variables) as input.
f_tf = jax2tf.convert(f_jax)

# For example, you can execute f_tf eagerly with valid TensorFlow inputs:
f_tf(np.random.random(...))

# Additionally you can use tools like `tf.function` to improve the execution
# time of your function, or to stage it out to a SavedModel:
f_tf_graph = tf.function(f_tf, autograph=False)
```

Note that when using the default native serialization, the target JAX function must be jittable (see [JAX - The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)). In the native serialization mode, under TensorFlow eager the whole JAX function executes as one op.
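One way to observe this is to inspect the TensorFlow graph of the converted function, as in the following sketch (it assumes the `f_tf` defined above; the exact set of auxiliary ops may vary across TensorFlow versions):

```python
# Inspect the TensorFlow graph produced for the converted function.
f_conc = tf.function(f_tf, autograph=False).get_concrete_function(
    tf.TensorSpec([10], tf.float32))
print(sorted({op.type for op in f_conc.graph.get_operations()}))
# With native serialization the JAX computation appears as a single
# `XlaCallModule` op, alongside placeholder/identity ops for the
# function's inputs and outputs.
```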
The Autograph feature of `tf.function` cannot be expected to work on functions lowered from JAX as above, so it is recommended to set `autograph=False` in order to speed up the execution and to avoid warnings and outright errors.

## Usage: saved model

You can serialize a JAX program into a TensorFlow SavedModel, for use with tooling that understands SavedModel. Both in native and non-native serialization you can count on 6 months of backwards compatibility (you can load a function serialized today with tooling that will be built up to 6 months in the future) and 3 weeks of limited forwards compatibility (you can load a function serialized today with tooling that was built up to 3 weeks in the past, provided the model does not use any new features).
Since jax2tf provides a regular TensorFlow function, using it with SavedModel is trivial:

```python
# You can save the model just like you would with any other TensorFlow function:
my_model = tf.Module()
# Save a function that can take scalar inputs.
my_model.f = tf.function(jax2tf.convert(f_jax), autograph=False,
                         input_signature=[tf.TensorSpec([], tf.float32)])
tf.saved_model.save(my_model, '/some/directory',
                    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))

# Restoring (note: the restored model does *not* require JAX to run, just XLA).
restored_model = tf.saved_model.load('/some/directory')
```
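The restored function can then be called like any other TensorFlow function; a short sketch, using a scalar input to match the `input_signature` above:

```python
# Call the restored function; it runs through XLA, with no JAX dependency.
print(restored_model.f(tf.constant(0.5, tf.float32)))
```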
An important point is that in the above code snippet **everything after the jax2tf invocation is standard TensorFlow code. In particular, the saving of the model is not directly part of the jax2tf API, and the user has full control over how to create the SavedModel**.

For example, just like for regular TensorFlow functions, it is possible to include in the SavedModel multiple versions of a function for different input shapes, by "warming up" the function on different input shapes:

```python
my_model.f = tf.function(jax2tf.convert(f_jax), autograph=False)
my_model.f(tf.ones([1, 28, 28]))   # a batch size of 1
my_model.f(tf.ones([16, 28, 28]))  # a batch size of 16
tf.saved_model.save(my_model, '/some/directory',
                    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
```
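After restoring such a model, both warmed-up shapes are available; a sketch under the same assumptions:

```python
restored = tf.saved_model.load('/some/directory')
restored.f(tf.ones([1, 28, 28]))   # uses the trace for batch size 1
restored.f(tf.ones([16, 28, 28]))  # uses the trace for batch size 16
```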
This is useful for two reasons: the parameters could be very large and exceed the 2GB limit of the GraphDef part of the SavedModel, or you may want to fine-tune the model and change the value of the parameters.

For example, consider the following function:

```python
def model_jax(inputs):
  return param0 + param1 * inputs
```

If you just lower and save the model directly, the values of `param0` and `param1` will be embedded in the computation graph. In fact, the value of `param1` is needed for the gradient computation and will be embedded twice: once in the computation graph for the forward computation and once for the backward computation, unless you turn off the staging of gradients or their saving as discussed further below (e.g., `with_gradient=False`). Note also that if one views the above function as an ML model parameterized by `param0` and `param1`, then the gradient function will be w.r.t. the inputs, while you probably want gradients w.r.t. the parameters.

A better way to deal with parameters (or any large constants) is to pass them as parameters to the function to be lowered:

```python
def model_jax(params, inputs):
  return params[0] + params[1] * inputs

# Wrap the parameter constants as tf.Variables; this will signal to the model
# saving code to save those constants as variables, separate from the
# computation graph.
params_vars = tf.nest.map_structure(tf.Variable, params)

# Build the prediction function by closing over the `params_vars`. If you
# instead were to close over `params` your SavedModel would have no variables
# and the parameters will be included in the function graph.
prediction_tf = lambda inputs: jax2tf.convert(model_jax)(params_vars, inputs)

my_model = tf.Module()
# Tell the model saver what the variables are.
my_model._variables = tf.nest.flatten(params_vars)
my_model.f = tf.function(prediction_tf, jit_compile=True, autograph=False)
tf.saved_model.save(my_model, '/some/directory')
```

This strategy will avoid any copies of the large parameters in the computation graph (they will be saved in a `variables` area of the model, which is not subject to the 2GB limitation).
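As a minimal sketch of what this buys you (assuming the model was saved as above; the attribute name `_variables` is just the one chosen in the snippet), the restored parameters come back as mutable variables:

```python
# A sketch, not part of the jax2tf API: standard TensorFlow restore.
restored = tf.saved_model.load('/some/directory')
print(restored._variables)         # the parameters, restored as tf.Variables
restored._variables[0].assign(0.)  # they can be updated, e.g., for fine-tuning
```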
For examples of how to save a Flax model as a SavedModel see the [examples directory](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md).

### Saved model and differentiation

The code lowered from JAX supports differentiation from TensorFlow. In order to ensure that the result of TensorFlow differentiation is identical to the one that JAX differentiation would produce, we will annotate the lowered primal function with a `tf.custom_gradient` that, upon TensorFlow differentiation, will lazily call into JAX to compute the `jax.vjp` of the lowered primal function, followed by jax2tf lowering of the gradient function. This ensures that ultimately it is JAX that performs the differentiation, thus respecting any custom gradients that may be present in the original function.
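As a small illustration (a sketch; the function and its deliberately non-standard derivative rule are made up, not taken from this README), TensorFlow gradients of a converted function pick up a JAX custom derivative:

```python
import jax
import tensorflow as tf
from jax.experimental import jax2tf

@jax.custom_jvp
def f(x):
  return x * x

@f.defjvp
def f_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  return f(x), 3. * t  # a deliberately non-standard derivative rule

f_tf = tf.function(jax2tf.convert(f), autograph=False)
x = tf.constant(2.)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = f_tf(x)
print(tape.gradient(y, x))  # 3.0: the custom JAX rule, not 2*x = 4.0
```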
The `jax2tf.convert` function has an option `with_gradient=False` to skip the custom gradients and wrap instead the lowered function with `tf.raw_ops.PreventGradient`, to generate an error in case a gradient computation is attempted.

SavedModel supports saving custom derivative rules by using the `experimental_custom_gradients` option:

```python
options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
tf.saved_model.save(model, path, options=options)
```

If you use `with_gradient=True` and forget to pass the `experimental_custom_gradients=True` parameter to `tf.saved_model.save`, then when you later load the saved model you will see a warning:

```
WARNING:absl:Importing a function (__inference_converted_fun_25) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
```

and if you do attempt to take a gradient of the loaded model you may get an error:

```
TypeError: An op outside of the function building code is being passed a "Graph" tensor.
It is possible to have Graph tensors leak out of the function building context by
including a tf.init_scope in your function building code. For example, the following
function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: args_0:0
```

(We are working with the TF team to give a more explicit error in this case.)

### Saved model for non-differentiable JAX functions

Note that if the JAX function is not reverse-mode differentiable, e.g., it uses `lax.while_loop`, then attempting to save its conversion to a SavedModel will fail with:

```
ValueError: Error when tracing gradients for SavedModel
```

You have two options: either pass `with_gradient=False` to `jax2tf.convert`, or set `tf.saved_model.SaveOptions(experimental_custom_gradients=False)`. In either case, you will not be able to compute the gradients of the function loaded from the SavedModel.
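A sketch of the first option (assuming a hypothetical `f_jax` built on `lax.while_loop`, and the scalar signature from the earlier example):

```python
# with_gradient=False wraps the function in tf.raw_ops.PreventGradient,
# so saving no longer needs to trace a gradient.
my_model = tf.Module()
my_model.f = tf.function(jax2tf.convert(f_jax, with_gradient=False),
                         autograph=False,
                         input_signature=[tf.TensorSpec([], tf.float32)])
tf.saved_model.save(my_model, '/some/directory')
```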
## Support for partitioning

jax2tf supports JAX functions that use `jax.pjit` and `jax.jit` with sharded arguments and results, for single-host meshes. The lowering is similar to that for `jax.jit`, except that the arguments and results will be wrapped with `tensorflow.python.compiler.xla.experimental.xla_sharding.XlaSharding` TensorFlow ops.

In the default native serialization mode, if the target JAX function includes sharding operations, e.g., from nested `jax.pjit`, then there should be a top-level `jax.pjit`. E.g.,

```python
# The following is correct
with mesh:
  jax2tf.convert(pjit.pjit(f_jax, in_shardings=...))(...)

# The following will lead to errors because pjit is not at the top level.
def wrapped_pjit(x):
  ...pjit.pjit(f_jax, in_shardings=...)...

with mesh:
  jax2tf.convert(wrapped_pjit)
```

A limitation of `XlaSharding` is that it cannot be used in TensorFlow eager mode. Therefore, `jax2tf` will give an error when lowering a function that requires sharded (not replicated) arguments or results and the lowered function is used outside a `tf.function` context (see b/255511660).

Another limitation is that today only TPUs have integrated XLA SPMD support in serving, while CPUs and GPUs don't have end-to-end XLA SPMD support in TensorFlow yet.
Executing a jax2tf-converted `tf.function` with `XlaSharding` ops on CPUs and GPUs will simply ignore all the `XlaSharding` ops.

Note that when saving a model, the parameters to the model are wrapped with `tf.Variable` before calling the lowered function (see [above](#saved-model-with-parameters)), and are therefore outside of the `XlaSharding` wrapper.
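To make the correct pattern above concrete, here is a hedged sketch (the two-device mesh, the axis name `'x'`, and `f_jax` are invented for illustration) of converting a top-level `pjit` on a single-host mesh:

```python
import numpy as np
import tensorflow as tf
import jax
from jax.experimental import jax2tf, pjit
from jax.sharding import Mesh, PartitionSpec as P

def f_jax(x):
  return x * 2.

# A single-host mesh over the first two devices, with one axis named 'x'.
mesh = Mesh(np.array(jax.devices()[:2]), axis_names=('x',))
with mesh:
  f_tf = tf.function(
      jax2tf.convert(pjit.pjit(f_jax, in_shardings=P('x'), out_shardings=P('x'))),
      autograph=False)
```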
## Shape-polymorphic conversion

**The shape polymorphism support is work in progress. Please report any bugs you encounter.**

We described above how to include in the SavedModel several specializations of a lowered function for a few specific input shapes. `jax2tf` can also produce a shape-polymorphic TensorFlow graph that is usable with inputs of any shape matching certain constraints. This is useful, e.g., to allow a single SavedModel to be used for multiple batch sizes.

The standard TensorFlow technique for producing a shape-polymorphic graph is to warm up the `tf.function` on partially-specified (shape-polymorphic) inputs, e.g., `tf.TensorSpec([None, 28, 28], tf.float32)` for a function that processes a batch (of unspecified batch size) of 28x28 images. For jax2tf it is **additionally** necessary to specify the `polymorphic_shapes` parameter of `jax2tf.convert`:

```python
f_tf = tf.function(jax2tf.convert(f_jax, polymorphic_shapes=["(b, 28, 28)"]),
                   autograph=False)
f_tf.get_concrete_function(tf.TensorSpec([None, 28, 28], tf.float32))
```

The `polymorphic_shapes` parameter, in the form of a pytree of strings corresponding to the pytree of positional arguments, introduces one or more dimension variables, e.g., `b`, to stand for shape dimensions that are assumed to be unknown at JAX tracing time. Dimension variables are assumed to range over all integers that are greater or equal to 1. In this particular example, we can also abbreviate `polymorphic_shapes=["(b, _, _)"]`, because the `_` placeholders take their value from the corresponding dimension of the `tf.TensorSpec` (which must be known). As a further shortcut for a series of `_` at the end of a shape specification you can use `...`: `polymorphic_shapes=["(b, ...)"]`.
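For the example above, then, all of the following specifications are interchangeable when the function is warmed up on `tf.TensorSpec([None, 28, 28], tf.float32)`:

```python
# Equivalent specifications for a TensorSpec([None, 28, 28], tf.float32):
jax2tf.convert(f_jax, polymorphic_shapes=["(b, 28, 28)"])
jax2tf.convert(f_jax, polymorphic_shapes=["(b, _, _)"])
jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
```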
In the example above, the `polymorphic_shapes` specification does not convey more information than the partial `tf.TensorSpec`, except that it gives a name to the unknown dimension, which improves error messages. The real need for named shape variables arises when there are multiple unknown dimensions and there is a relationship between them. For example, if the function to be lowered is also polymorphic on the size of each image while requiring the images to be square, we would add a dimension variable `d` to stand for the unknown image size:

```python
f_tf = tf.function(jax2tf.convert(f_jax, polymorphic_shapes=["(b, d, d)"]),
                   autograph=False)
f_tf.get_concrete_function(tf.TensorSpec([None, None, None], tf.float32))
```

The JAX tracing mechanism performs shape checking using the same strict rules as when the shapes are fully known. For example, given the `"(b, d, d)"` specification for the argument `x` of a function, JAX will know that a conditional `x.shape[-2] == x.shape[-1]` is `True`, and will also know that `x` and `jnp.sin(x)` have the same shape of a batch of square matrices that can be passed to `jnp.matmul`.
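As a small sketch (the function body is invented for illustration), a body like this type-checks under the `"(b, d, d)"` specification:

```python
import jax.numpy as jnp

def f_jax(x):  # x: f32[b, d, d] under polymorphic_shapes=["(b, d, d)"]
  # Statically True under the specification, for every value of b and d:
  assert x.shape[-2] == x.shape[-1]
  # jnp.sin(x) has the same symbolic shape as x, so the matmul type-checks:
  return jnp.matmul(x, jnp.sin(x))  # result: f32[b, d, d]
```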
### Correctness of shape-polymorphic tracing

We want to trust that the lowered program produces the same results as the original JAX program. More precisely:

For any function `f_jax` and any input signature `abs_sig` containing partially known `tf.TensorSpec`, and any concrete input `x` whose shape matches `abs_sig`:

* if the conversion to TensorFlow succeeds: `f_tf = tf.function(jax2tf.convert(f_jax, polymorphic_shapes)).get_concrete_function(abs_sig)`,
* and if the TensorFlow execution succeeds with result `y`: `f_tf(x) = y`,
* then the JAX execution would produce the same result: `f_jax(x) = y`.

It is crucial to understand that `f_jax(x)` has the freedom to re-invoke the JAX tracing machinery, and in fact it does so for each distinct concrete input shape, while the generation of `f_tf` uses JAX tracing only once, and invoking `f_tf(x)` does not use JAX tracing anymore. In fact, the latter invocation may happen after `f_tf` has been serialized to a SavedModel and reloaded in an environment where `f_jax` and the JAX tracing machinery are not available anymore.

### Coverage of shape-polymorphic tracing

Besides correctness, a secondary goal is to be able to lower many shape-polymorphic programs, but at the very least batch-size-polymorphic programs, so that one SavedModel can be used for any batch size. For example, we want to ensure that any function written using `jax.vmap` at the top level can be lowered with the batch dimension polymorphic and the remaining dimensions concrete.

It is reasonable to expect that there will be JAX programs for which there is a shape-polymorphic TensorFlow graph, but which will give an error when lowering with jax2tf. In general, you should expect that shape polymorphism can handle those programs for which all the intermediate shapes can be expressed as simple expressions in the dimension variables appearing in the input shapes.
In particular, this does not apply to programs whose intermediate shapes depend on the data.

### Details

In order to use shape polymorphism effectively with jax2tf, it is worth considering what happens under the hood. When the lowered function is invoked with a `TensorSpec`, `jax2tf` will use the `polymorphic_shapes` parameter to obtain a shape abstraction for the inputs. The dimension sizes from the `TensorSpec` are used to fill in the `_` and `...` placeholders from `polymorphic_shapes`. Normally, the shape abstraction contains the dimension sizes, but in the presence of shape polymorphism, some dimensions may be dimension variables.

The `polymorphic_shapes` parameter must be either `None`, or a pytree of shape specifiers corresponding to the pytree of arguments. (A value `None` for `polymorphic_shapes` is equivalent to a list of `None`. See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).) A shape specifier is combined with a `TensorSpec` as follows:

* A shape specifier of `None` means that the shape is given by the actual argument `TensorSpec`, which must be fully known.

* Otherwise, the specifier must be a comma-separated string of dimension specifiers: `(dim_1, ..., dim_n)`, denoting an n-dimensional array. The `TensorSpec` must also be of rank `n`. A `...` at the end of the shape specifier is expanded to a list of `_` of appropriate length. The corresponding dimensions from the shape specifier and the `TensorSpec` are matched:

  * the dimension specifier `_` means that the size of the dimension is given by the actual `TensorSpec`, which must have a known size in the corresponding dimension.
  * a dimension specifier can also be a lowercase identifier, denoting a dimension-size variable ranging over strictly positive integers. The abstract value of the dimension is going to be set to this variable.
    The corresponding dimension in the `TensorSpec` can be `None` or can be a constant.
  * All occurrences of a dimension variable in any dimension for any argument are assumed to be equal.

Note that `polymorphic_shapes` controls the shape abstraction used by JAX when tracing the function. The `TensorSpec` gives the shape abstraction that TensorFlow will associate with the produced graph, and can be more specific.

A few examples of shape specifications and uses:

* `polymorphic_shapes=["(b, _, _)", None]` can be used for a function with two arguments, the first having a batch leading dimension that should be polymorphic. The other dimensions for the first argument and the shape of the second argument are specialized based on the actual `TensorSpec`, which must be known. The lowered function can be used, e.g., with `TensorSpec`s `[None, 28, 28]` and `[28, 16]` for the first and second argument respectively (see the sketch after this list). An alternative `TensorSpec` pair can be `[1, 28, 28]` and `[28, 16]`, in which case the JAX tracing is done for the same polymorphic shape given by `polymorphic_shapes=["(b, 28, 28)", "(28, 16)"]`.

* `polymorphic_shapes=["(batch, _)", "(batch,)"]`: the leading dimensions of the two arguments must match, and are assumed to be greater or equal to 1. The second dimension of the first argument is taken from the actual `TensorSpec`. This can be used with a `TensorSpec` pair `[None, 16]` and `[None]`. It can also be used with a pair of shapes `[8, 16]` and `[8]`.
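A sketch of the first specification above (the function body is invented; only the shapes matter):

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f_jax(image, w):  # image: f32[b, 28, 28], w: f32[28, 16]
  return jnp.einsum('bhw,wk->bhk', image, w)

f_tf = tf.function(jax2tf.convert(f_jax, polymorphic_shapes=["(b, _, _)", None]),
                   autograph=False)
f_tf.get_concrete_function(tf.TensorSpec([None, 28, 28], tf.float32),
                           tf.TensorSpec([28, 16], tf.float32))
```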
### Computing with dimension variables

JAX keeps track of the shape of all intermediate results. When those shapes depend on dimension variables JAX computes them as symbolic expressions involving dimension variables. The symbolic expressions can represent the result of applying arithmetic operators (add, sub, mul, floordiv, mod, including the NumPy variants `np.sum`, `np.prod`, etc.) **on dimension variables and integers** (`int`, `np.int`, or anything convertible by `operator.index`). These symbolic dimensions can then be used in shape parameters of JAX primitives and APIs, e.g., in `jnp.reshape`, `jnp.arange`, slicing indices, etc.

For example, in the following code to flatten a 2D array, the computation `x.shape[0] * x.shape[1]` computes the symbolic dimension `4 * b` as the new shape:

```python
jax2tf.convert(lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],)),
               polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
```

When a symbolic dimension is used in **arithmetic operations with non-integers**, e.g., `float`, `np.float`, `np.ndarray`, or JAX arrays, it is automatically converted to a JAX array using `jnp.array`. For example, in the function below all occurrences of `x.shape[0]` are converted implicitly to `jnp.array(x.shape[0])` because they are involved in operations with non-integer scalars or with JAX arrays:
```python
jax2tf.convert(lambda x: (x + x.shape[0] + jnp.sin(x.shape[0]),
                          5. + x.shape[0],
                          x.shape[0] - np.ones((5,), dtype=np.int32)),
               polymorphic_shapes=["b"])(np.ones(3))
```

Another typical example is when computing averages:

```python
jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0],
               polymorphic_shapes=["(v, _)"])(np.ones((3, 4)))
```

It is also possible to convert dimension polynomials explicitly to JAX arrays, with `jnp.array(x.shape[0])` or even `jnp.array(x.shape)`.
The result of these operations can no longer be used as dimension parameters and will raise a JAX error.

### Errors in presence of shape polymorphism

Most JAX code assumes that the shapes of JAX arrays are tuples of integers, but with shape polymorphism some dimensions may be symbolic expressions. This can lead to a number of errors. For example, the program:

```python
four_ones = np.ones((4,))
jax2tf.convert(lambda x, y: x + y,
               polymorphic_shapes=["(v,)", "(4,)"])(four_ones, four_ones)
```

will result in the error `add got incompatible shapes for broadcasting: (v,), (4,)`, because the shape abstraction that JAX tracing uses is given by the `polymorphic_shapes`, even though the actual arguments are more specific and would actually work.

Also,

```python
jax2tf.convert(lambda x: jnp.matmul(x, x),
               polymorphic_shapes=["(v, 4)"])(np.ones((4, 4)))
```

will result in the error `dot_general requires contracting dimensions to have the same shape, got [4] and [v]`.
What is happening here is that in the process of type checking the `matmul` operation, JAX will want to ensure the size of the two axes is the same (`v == 4`). Note that `v` can stand for any integer greater than 0, so the value of the equality expression can be true or false. Since it is not always true that `v == 4`, the shape checking rules fail with the above error. Since the lowered function works only for square matrices, the correct `polymorphic_shapes` is `["(v, v)"]`.

As explained above, if the dimension polynomials are used in operations with non-integers, the result will be a JAX array that cannot be used as a shape parameter. For example, if we modify the reshape example slightly, to use `np.array([x.shape[1]])` instead of `x.shape[1]`:

```python
jax2tf.convert(lambda x: jnp.reshape(x, (x.shape[0] * np.array([x.shape[1]]),)),
               polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
```

we get an error `Shapes must be 1D sequences of concrete values of integer type, got Traced<...>`. If you get this error on JAX code that works for static shapes, it means that one operation that computes shape parameters is using non-integer arguments, e.g., `np.ndarray`, that get implicitly converted to JAX arrays.
The solution is to avoid `np.array`, `float`, or JAX arrays in operations whose results are used as shapes, e.g., instead of `np.arange(n) * x.shape[0]` write `[i * x.shape[0] for i in range(n)]`.

### Dimension variables must be solvable from the input shapes

JAX will generate code to derive the values of the dimension variables from the input shapes. This works only if the symbolic dimensions in the input shapes are linear. For example, the following `polymorphic_shapes` will result in errors:

```python
polymorphic_shapes = ["a * a"]  # Not a linear polynomial
polymorphic_shapes = ["a + b"]  # Too few equations to derive both `a` and `b`
```

The error message for the last specification above would be:

```
Cannot solve for values of dimension variables {'a', 'b'}.
We can only solve linear uni-variate constraints.
Using the following polymorphic shapes specifications: args[0].shape = (a + b,).
Unprocessed specifications: 'a + b' for dimension size args[0].shape[0].
Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.
```
### Shape assertion errors

JAX assumes that dimension variables range over strictly positive integers. Starting with serialization version 7 these assumptions are checked against the shapes of the actual arguments when the lowered code is invoked. For example, given the `polymorphic_shapes="(b, b, 2*d)"` specification, we will generate code to check the following constraints when invoked with an actual argument `arg`:

* `arg.shape[0] >= 1`
* `arg.shape[1] == arg.shape[0]`
* `arg.shape[2] % 2 == 0`
* `arg.shape[2] // 2 >= 1`

An example error for the third constraint above, e.g., when invoked with shape `(3, 3, 5)`, would be:

```
Input shapes do not match the polymorphic shapes specification.
Division had remainder 1 when computing the value of 'd'.
Using the following polymorphic shapes specifications: args[0].shape = (b, b, 2*d).
Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3).
Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.
```
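A sketch of a call that would trigger this error (the function `f_jax` is hypothetical; only the argument shape matters):

```python
f_tf = tf.function(jax2tf.convert(f_jax, polymorphic_shapes=["(b, b, 2*d)"]),
                   autograph=False)
f_tf(np.ones((3, 3, 5), dtype=np.float32))  # fails: arg.shape[2] = 5 is not even
```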
When using native serialization these are checked by the `tf.XlaCallModule` op (starting with serialization [version 7](https://github.com/search?q=repo%3Agoogle%2Fjax+path%3Aconfig.py+jax_serialization_version&type=code)), and you will get `tf.errors.InvalidArgument` errors. You can disable this checking by including `DisabledSafetyCheck.shape_assertions()` in the `disabled_checks` parameter to `jax2tf.convert`, or by setting the environment variable `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=shape_assertions`. When using graph serialization these are checked using `tf.debugging.assert`, which will also result in `tf.errors.InvalidArgument`. Note that due to limitations in TensorFlow, these errors are suppressed when using `jit_compile=True` and when running on TPU.

### Comparison of symbolic dimensions is partially supported

Inside JAX there are a number of equality and inequality comparisons involving shapes, e.g., for doing shape checking or even for choosing the implementation for some primitives. Comparisons are supported as follows:

* equality is supported with a caveat: if the two symbolic dimensions denote the same value under all valuations for dimension variables, then equality evaluates to `True`, e.g., for `b + b == 2*b`; otherwise the equality evaluates to `False`. See below for a discussion of important consequences of this behavior.
* disequality is always the negation of equality.
* inequality is partially supported, in a similar way as partial equality. However, in this case we take into consideration that dimension variables range over strictly positive integers.
  E.g., `b >= 1`, `b >= 0`, `2 * a + b >= 3` are `True`, while `b >= 2`, `a >= b`, `a - b >= 0` are inconclusive and result in an exception.

For example, the following code raises the exception `core.InconclusiveDimensionOperation` with the message `Dimension polynomial comparison 'a + 1' >= 'b' is inconclusive`:

```python
jax2tf.convert(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1,
               polymorphic_shapes=["(a, b)"])(np.ones((3, 4)))
```

If you do get a `core.InconclusiveDimensionOperation`, you can try several strategies:

* If your code uses the built-in `max` or `min`, or `np.max` or `np.min`, then you can replace those with `core.max_dim` and `core.min_dim`, which have the effect of delaying the inequality comparison to compilation time, when shapes become known (see the sketch after this list).
* Try to rewrite conditionals using `core.max_dim` and `core.min_dim`, e.g., instead of `d if d > 0 else 0` you can write `core.max_dim(d, 0)`.
* Try to rewrite the code to be less dependent on the fact that dimensions should be integers, and rely on the fact that symbolic dimensions duck-type as integers for most arithmetic operations.
  E.g., instead of `int(d) + 5` write `d + 5`.
* Specify symbolic constraints, as explained below.
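A small sketch of the first strategy (the function is invented; `core` here is `jax.core`, which provides `max_dim` and `min_dim` as referenced above):

```python
from jax import core

def first_up_to_16(x):
  # Instead of `x[:min(x.shape[0], 16)]`, which would need an inconclusive
  # comparison for a symbolic x.shape[0], delay it with core.min_dim:
  return x[:core.min_dim(x.shape[0], 16)]
```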
#### User-specified symbolic constraints

By default, JAX assumes that all dimension variables range over values greater-or-equal to 1, and it tries to derive other simple inequalities from that, e.g.:

* `a + 2 >= 3`,
* `a * 2 >= 1`,
* `a + b + c >= 3`,
* `a // 4 >= 0`, `a**2 >= 1`, and so on.

You can avoid some inequality comparison failures if you change the symbolic shape specifications to add implicit constraints for dimension sizes. E.g.,

* You can use `2*b` for a dimension to constrain it to be even (and `>= 2`).
* You can use `b + 15` for a dimension to constrain it to be at least 16. E.g., the following code would fail without the `+ 15` part, because JAX will want to verify that slice sizes are at most as large as the axis size:

```python
jax2tf.convert(lambda x: x[0:16],
               polymorphic_shapes="b + 15, ...")
```

Such implicit symbolic constraints are used for reasoning, and are checked at compile time, as explained [above](#shape-assertion-errors).

Starting with JAX version 0.4.24 you can also specify explicit symbolic constraints:

```python
jax2tf.convert(lambda x: x[:x.shape[1], :16],
               polymorphic_shapes="(a, b)",
               polymorphic_constraints=("a >= b", "b >= 16"))
```

The constraints form a conjunction together with the implicit constraints. You can specify `>=`, `<=`, and `==` constraints. At the moment, JAX has limited support for reasoning with symbolic constraints:

* You get the most from constraints of the form of a variable being greater-or-equal or less-or-equal to a constant. For example, from the constraints that `a >= 16` and `b >= 8` we can infer that `a + 2*b >= 32`.
* You get limited power when the constraint involves more complex expressions, e.g., from `a >= b + 8` we can infer that `a - b >= 8` but not that `a >= 9`. We plan to improve this area somewhat in the future.
* Equality constraints are treated as normalization rules. E.g., `floordiv(a, b) = c` works by replacing all occurrences of the left-hand side with the right-hand side. You can only have equality constraints where the left-hand side is a multiplication of factors, e.g., `a * b`, or `4 * a`, or `floordiv(a, b)`. Thus, the left-hand side cannot contain addition or subtraction at the top level.

The symbolic constraints can also help to work around the limitations in the JAX reasoning mechanisms.
The symbolic constraints can also help to work around limitations in the JAX reasoning mechanisms. For example, the following code would not be able to prove that the slice size fits into the axis size (such examples come up when using striding):

```python
jax2tf.convert(lambda x: x[: 4*(x.shape[0] // 4)],
               polymorphic_shapes=("b, ...",))
```

You will likely see an error that the comparison `b >= 4*floordiv(b, 4)` is inconclusive, even though the inequality always holds when `b >= 1`. One option here would be to restrict the code to work only on axis sizes that are a multiple of `4` (by replacing `b` with `4*b` in the shape specification); another option is to add a symbolic constraint with the exact inconclusive inequality:

```python
jax2tf.convert(lambda x: x[: 4*(x.shape[0] // 4)],
               polymorphic_shapes=("b, ...",),
               polymorphic_constraints=("b >= 4*floordiv(b, 4)",))
```

An example where an equality constraint would be useful is in the following code:

```python
jax2tf.convert(lambda x, y: x + y[:y.shape[0] // 2],
               polymorphic_shapes=("a", "b"))(x, y)
```

The above code would raise a `TypeError` because JAX cannot verify that `x` and `y[:y.shape[0] // 2]` have the same shape:

```
TypeError: add got incompatible shapes for broadcasting: (a,), (floordiv(b, 2),)
```
You can fix this by adding a constraint:

```python
jax2tf.convert(lambda x, y: x + y[:y.shape[0] // 2],
               polymorphic_shapes=("a", "b"),
               polymorphic_constraints=("floordiv(b, 2) == a",))(x, y)
```

Just like the implicit constraints, the explicit symbolic constraints are checked at compile time, using the same mechanism as explained [above](#shape-assertion-errors).

The symbolic constraints are stored in an `export.SymbolicScope` object, which is created implicitly for each call to `jax2tf.convert`. You must be careful not to mix symbolic expressions that use different scopes. For example, the following code will fail because `a1` and `a2` use different scopes (each call to `export.symbolic_shape` creates a new scope):

```python
a1, = export.symbolic_shape("a,")
a2, = export.symbolic_shape("a,", constraints=("a >= 8",))

a1 + a2  # Fails: the two expressions use different scopes
```

The symbolic expressions that originate from a single call to `export.symbolic_shape` share a scope and can be mixed in arithmetic operations; the result shares the same scope.
You can re-use scopes:

```python
a, = export.symbolic_shape("a,", constraints=("a >= 8",))
b, = export.symbolic_shape("b,", scope=a.scope)  # b shares a's scope

a + b  # Allowed
```

JAX tracing uses caches keyed partially by shapes, and symbolic shapes that are printed identically will be considered distinct if they use different scopes.
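A minimal sketch of this pitfall, under the same assumptions as the snippets above:

```python
s1, = export.symbolic_shape("a,")
s2, = export.symbolic_shape("a,")

print(str(s1) == str(s2))    # True: both print as "a"
print(s1.scope is s2.scope)  # False: separate calls create separate scopes,
                             # so tracing caches treat these shapes as distinct
```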
#### Caveat for equality comparisons

The equality comparison returns `False` for `b + 1 == b` or `b == 0` (in which case it is certain that the dimensions are different for all valuations), but also for `b == 1` and for `a == b`. This is unsound: we ought to raise `core.InconclusiveDimensionOperation`, because under some valuations the result should be `True` and under others it should be `False`. We choose to make equality total, thus allowing unsoundness, because otherwise we may get spurious errors in the presence of hash collisions when hashing dimension expressions or objects that include them (shapes, `core.AbstractValue`, `core.Jaxpr`). Besides the hashing errors, a partial semantics of equality would lead to errors for expressions like `b == a or b == b` or `b in [a, b]`, even though the error is avoided if we change the order of the comparisons.

We attempted to retain soundness and hashability by creating both hashable and unhashable kinds of symbolic dimensions in [PR #14200](https://github.com/jax-ml/jax/pull/14200), but it turned out to be very hard to diagnose hashing failures in user programs, because hashing is often implicit when using sets or memo tables.

Code of the form `if x.shape[0] != 1: raise NiceErrorMessage` is sound even with this treatment of equality, but code of the form `if x.shape[0] != 1: return 1` is unsound.
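A minimal sketch of this behavior, under the same assumptions as above:

```python
a, b = export.symbolic_shape("a, b")

print(b + 1 == b)  # False, soundly: the dimensions differ for every valuation
print(b == 0)      # False, soundly: b >= 1 always holds
print(a == b)      # False, unsoundly: a == b does hold under some valuations

# Sound: treating "not provably equal" as an error only over-rejects.
#   if x.shape[0] != 1: raise ValueError("expected a leading axis of size 1")
# Unsound: branching on the possibly-wrong False silently changes results.
#   if x.shape[0] != 1: return 1
```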
### Division of symbolic dimensions is partially supported

JAX will attempt to simplify division and modulo operations, e.g., `(a * b + a) // (b + 1) == a` and `(6 * a + 4) % 3 == 1`. In particular, JAX will handle the cases when either (a) there is no remainder, or (b) the divisor is a constant, in which case there may be a constant remainder.
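These simplifications can be checked directly on symbolic dimensions (a sketch, under the same assumptions as above):

```python
a, b = export.symbolic_shape("a, b")

assert (a * b + a) // (b + 1) == a  # exact division: no remainder
assert (6 * a + 4) % 3 == 1         # constant divisor: constant remainder
```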
For example, the code below results in a division error when trying to compute the inferred dimension for a `reshape` operation:

```python
jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
               polymorphic_shapes=["(b, ...)"])(np.ones((4, 5, 7)))
```

In this case you will see the error `Cannot divide evenly the sizes of shapes (b, 5, 7) and (2, -1)`, with a further `Details: Cannot divide '35*b' by '-2'`. The polynomial `35*b` represents the total size of the input tensor.

Note that the following will succeed:

```python
## The resulting symbolic shape is (2, 15 b).
jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
               polymorphic_shapes=["(b, ...)"])(np.ones((4, 5, 6)))

## The resulting symbolic shape is (6 b2, b1).
jax2tf.convert(lambda x: jnp.reshape(x, (-1, x.shape[0])),
               polymorphic_shapes=["(b1, b2, ...)"])(np.ones((4, 5, 6)))
```

You may also encounter division errors when working with strides, such as when computing the padding in a strided convolution.

When JAX cannot simplify the result of symbolic dimension division, it will construct symbolic expressions of the form `floordiv(E, N)` and `mod(E, N)`, and it will use a number of heuristics to evaluate comparisons involving these. If you encounter `InconclusiveDimensionOperation` exceptions, you can specify that a dimension variable is a multiple of the divisor; e.g., `b` in the above example of dividing `35*b` by `-2` may be known to be a multiple of `2`. You can specify that by replacing `b` with `2*b` in the polymorphic shape specification:

```python
jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
               polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7)))
```

## Native serialization versions

We use a serialization version number to help evolve the serialization mechanism while allowing serialized artifacts to be used by consumers built at different code versions.

If consumers use the `tf.XlaCallModule` op, e.g., when using the TensorFlow SavedModel, then they support a range of serialization versions. See [tf.XlaCallModule code](https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+kVersionMaximumSupported%22&type=code). There is also an API to get the maximum version number supported by your installed version of TensorFlow:

```python
from tensorflow.compiler.tf2xla.python import xla as tfxla

tfxla.call_module_maximum_supported_version()
```

For **backward compatibility**, we want to allow a freshly built consumer to load artifacts that have been serialized in the past 6 months (by a serializer using the latest version supported at the time). Thus, the minimum supported version number should match the maximum supported version number from 6 months in the past.

The serialization version used by JAX is determined by the `--jax_serialization_version` flag or, if that is missing, by the `JAX_SERIALIZATION_VERSION` environment variable. The default value is specified in the [`config.py` file](https://github.com/search?q=repo%3Agoogle%2Fjax+path%3Aconfig.py+JAX_SERIALIZATION_VERSION&type=code).
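For example, to pin the version from Python you can update the corresponding configuration option (a sketch; it assumes that the `jax_serialization_version` configuration key backing the flag above is updatable this way):

```python
import jax

# Serialize at version 8, e.g., to target consumers whose tf.XlaCallModule
# does not yet support version 9. (Assumes the config key named above.)
jax.config.update("jax_serialization_version", 8)
```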
For **forward compatibility**, we want freshly serialized artifacts to be loadable by consumers that have been built in the last 1 month. Thus, we bump the default serialization version number about 1 month after `tf.XlaCallModule` is upgraded to a given version number.

You can use `--jax_serialization_version` to adjust the serialization version to match your deployed consumer. We reserve the right to remove support for generating or consuming serialization versions older than 6 months.

## Serialization version numbers

We list here a history of the serialization version numbers:

* Version 1 used MHLO & CHLO to serialize the code. It is not supported anymore.
* Version 2 supports StableHLO & CHLO. Used from October 2022. Not supported anymore.
* Version 3 supports platform checking and multiple platforms. Used from February 2023. Not supported anymore.
* Version 4 supports StableHLO with compatibility guarantees. This is the earliest version at the time of the JAX native serialization launch. Used in JAX from March 15, 2023 (cl/516885716). Starting with March 28th, 2023 we stopped using `dim_args_spec` (cl/520033493). The support for this version was dropped on October 17th, 2023 (cl/573858283).
* Version 5 adds support for `call_tf_graph`. This is currently used for some specialized use cases. Used in JAX from May 3rd, 2023 (cl/529106145).
* Version 6 adds support for the `disabled_checks` attribute. This version mandates a non-empty `platforms` attribute. Supported by XlaCallModule since June 7th, 2023 and available in JAX since June 13th, 2023 (JAX 0.4.13).
* Version 7 adds support for `stablehlo.shape_assertion` operations and for `shape_assertions` specified in `disabled_checks`. See [Errors in presence of shape polymorphism](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
  Supported by XlaCallModule since July 12th, 2023 (cl/547482522), available in JAX serialization since July 20th, 2023 (JAX 0.4.14), and the default since August 12th, 2023 (JAX 0.4.15).
* Version 8 adds support for the `jax.uses_shape_polymorphism` module attribute and enables the shape refinement pass only when the attribute is present. Supported by XlaCallModule since July 21st, 2023 (cl/549973693), available in JAX since July 26th, 2023 (JAX 0.4.14), and the default since October 21st, 2023 (JAX 0.4.20).
* Version 9 adds support for effects; see the docstring for `export.Exported` for the precise calling convention. In this serialization version we also tag the platform index and the dimension variables arguments with `jax.global_constant` attributes. Supported by XlaCallModule since October 27th, 2023, available in JAX since October 20th, 2023 (JAX 0.4.20), and the default since February 1st, 2024 (JAX 0.4.24). This is the only supported version as of March 27th, 2024.

## Known issues

`jax2tf` has been in use since 2020 and the vast majority of users encounter no problems. However, there are a few rare corner cases in which the different conventions of JAX and TensorFlow result in a breakage. We try to give an exhaustive list below, specifying whether each limitation applies to native serialization, non-native serialization, or both.

### Different 64-bit precision in JAX and TensorFlow

Applies to both native and non-native serialization.

JAX behaves somewhat differently than TensorFlow in the handling of 32-bit vs. 64-bit values.
However, the `jax2tf` lowered function always behaves like the JAX function.

JAX interprets the type of Python scalars differently based on the `JAX_ENABLE_X64` flag. (See [JAX - The Sharp Bits: Double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).) In the default configuration, the flag is unset and JAX interprets Python constants as 32-bit, e.g., the type of `3.14` is `float32`. This is also what TensorFlow always does. JAX goes further: it forces all explicitly-specified 64-bit values to be interpreted as 32-bit:

```python
# with JAX_ENABLE_X64=0

jnp.sin(3.14)  # Has type float32
tf.math.sin(3.14)  # Has type float32

jnp.sin(np.float64(3.14))  # Also has type float32
tf.math.sin(np.float64(3.14))  # Has type float64

# The jax2tf.convert function behaves like the JAX function.
jax2tf.convert(jnp.sin)(3.14)  # Has type float32
jax2tf.convert(jnp.sin)(np.float64(3.14))  # Has type float32

# The following will still compute `sin` in float32 (with a tf.cast on the argument).
tf.function(jax2tf.convert(jnp.sin), autograph=False)(tf.Variable(3.14, dtype=tf.float64))
```

When the `JAX_ENABLE_X64` flag is set, JAX uses 64-bit types for Python scalars and respects the explicit 64-bit types:
class="highlight highlight-source-python notranslate position-relative overflow-auto" dir="auto" data-snippet-clipboard-copy-content="# with JAX_ENABLE_X64=1 jnp.sin(3.14) # Has type float64 tf.math.sin(3.14) # Has type float32 # The jax2tf.convert function behaves like the JAX function. jax2tf.convert(jnp.sin)(3.14) # Has type float64 # The following will compute `sin` in float64. tf.function(jax2tf.convert(jnp.sin), autograph=False)(tf.Variable(3.14, dtype=tf.float64)) # The following will compute `sin` in float32. tf.function(jax2tf.convert(jnp.sin), autograph=False)(tf.Variable(3.14))"><pre><span class="pl-c"># with JAX_ENABLE_X64=1</span> <span class="pl-s1">jnp</span>.<span class="pl-c1">sin</span>(<span class="pl-c1">3.14</span>) <span class="pl-c"># Has type float64</span> <span class="pl-s1">tf</span>.<span class="pl-c1">math</span>.<span class="pl-c1">sin</span>(<span class="pl-c1">3.14</span>) <span class="pl-c"># Has type float32</span> <span class="pl-c"># The jax2tf.convert function behaves like the JAX function.</span> <span class="pl-s1">jax2tf</span>.<span class="pl-c1">convert</span>(<span class="pl-s1">jnp</span>.<span class="pl-c1">sin</span>)(<span class="pl-c1">3.14</span>) <span class="pl-c"># Has type float64</span> <span class="pl-c"># The following will compute `sin` in float64.</span> <span class="pl-s1">tf</span>.<span class="pl-c1">function</span>(<span class="pl-s1">jax2tf</span>.<span class="pl-c1">convert</span>(<span class="pl-s1">jnp</span>.<span class="pl-c1">sin</span>), <span class="pl-s1">autograph</span><span class="pl-c1">=</span><span class="pl-c1">False</span>)(<span class="pl-s1">tf</span>.<span class="pl-c1">Variable</span>(<span class="pl-c1">3.14</span>, <span class="pl-s1">dtype</span><span class="pl-c1">=</span><span class="pl-s1">tf</span>.<span class="pl-c1">float64</span>)) <span class="pl-c"># The following will compute `sin` in float32.</span> <span class="pl-s1">tf</span>.<span class="pl-c1">function</span>(<span class="pl-s1">jax2tf</span>.<span class="pl-c1">convert</span>(<span class="pl-s1">jnp</span>.<span class="pl-c1">sin</span>), <span class="pl-s1">autograph</span><span class="pl-c1">=</span><span class="pl-c1">False</span>)(<span class="pl-s1">tf</span>.<span class="pl-c1">Variable</span>(<span class="pl-c1">3.14</span>))</pre></div> <p dir="auto">This is achieved by inserting <code>tf.cast</code> operations on the input arguments inside the lowered function, if necessary.</p> <p dir="auto">If you want to create a <code>tf.Variable</code> or <code>tf.TensorSpec</code> with the same dtype, you should use <code>jax2tf.dtype_of_val</code>:</p> <div class="highlight highlight-source-python notranslate position-relative overflow-auto" dir="auto" data-snippet-clipboard-copy-content="# The following two calls will lower jax_fun at the same dtypes # independently of the value of JAX_ENABLE_X64. 
```python
# The following two calls will lower jax_fun at the same dtypes,
# independently of the value of JAX_ENABLE_X64.
jax2tf.convert(jax_fun)(3.14)
jax2tf.convert(jax_fun)(tf.Variable(3.14, dtype=jax2tf.dtype_of_val(3.14)))
```

### Functions whose arguments and results are nested Python data structures

Applies to both native and non-native serialization.

`jax2tf` can lower functions with arguments and results that are nested collections (tuples, lists, dictionaries) of numeric values or JAX arrays ([pytrees](https://jax.readthedocs.io/en/latest/pytrees.html)).
The resulting TensorFlow function will take the same kind of arguments, except the leaves can be numeric values or TensorFlow tensors (`tf.Tensor`, `tf.TensorSpec`, `tf.Variable`).

As long as the arguments use only standard Python containers (tuples, lists, dictionaries), both JAX and TensorFlow can flatten and unflatten them, and you can use the lowered function in TensorFlow without limitations.

However, if your JAX function takes a custom container, you can register it with the JAX `tree_util` module so that JAX knows how to operate with it, and you can still lower the function and use it in TensorFlow eager mode and with `tf.function`; but you won't be able to save it to a SavedModel, nor compute gradients with TensorFlow (code from `jax2tf_test.test_custom_pytree_readme`):

```python
class CustomPair:
  def __init__(self, a, b):
    self.a = a
    self.b = b

# Register it with the JAX tree_util module.
jax.tree_util.register_pytree_node(CustomPair,
                                   lambda x: ((x.a, x.b), None),
                                   lambda _, ab: CustomPair(*ab))

def f_jax(pair: CustomPair):
  return 2. * pair.a + 3. * pair.b

x = CustomPair(4., 5.)
res_jax = f_jax(x)
# TF execution works as long as JAX can flatten the arguments.
res_tf = jax2tf.convert(f_jax)(x)
class="pl-c1">=</span> <span class="pl-s1">jax2tf</span>.<span class="pl-c1">convert</span>(<span class="pl-s1">f_jax</span>)(<span class="pl-s1">x</span>) <span class="pl-s1">self</span>.<span class="pl-c1">assertAllClose</span>(<span class="pl-s1">res_jax</span>, <span class="pl-s1">res_tf</span>.<span class="pl-c1">numpy</span>()) <span class="pl-s1">res_tf_2</span> <span class="pl-c1">=</span> <span class="pl-s1">tf</span>.<span class="pl-c1">function</span>(<span class="pl-s1">jax2tf</span>.<span class="pl-c1">convert</span>(<span class="pl-s1">f_jax</span>), <span class="pl-s1">autograph</span><span class="pl-c1">=</span><span class="pl-c1">False</span>, <span class="pl-s1">jit_compile</span><span class="pl-c1">=</span><span class="pl-c1">True</span>)(<span class="pl-s1">x</span>)</pre></div> <p dir="auto">If you want to save the function in a SavedModel or compute gradients, you should construct a wrapper:</p> <div class="highlight highlight-source-python notranslate position-relative overflow-auto" dir="auto" data-snippet-clipboard-copy-content=" # wrapped TF function to use only standard containers def f_tf_wrapped(a, b): return f_tf(CustomPair(a, b)) # Try to put into SavedModel my_model = tf.Module() # Save a function that can take scalar inputs. my_model.f = tf.function(f_tf_wrapped, autograph=False, input_signature=[tf.TensorSpec([], tf.float32), tf.TensorSpec([], tf.float32)]) model_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(my_model))) tf.saved_model.save(my_model, model_dir, options=tf.saved_model.SaveOptions(experimental_custom_gradients=True)) # Restoring (note: the restored model does *not* require JAX to run, just XLA). restored_model = tf.saved_model.load(model_dir) def restored_f(pair: CustomPair): return restored_model.f(pair.a, pair.b) res_tf_3 = restored_f(x) self.assertAllClose(res_jax, res_tf_3) grad_jax = jax.grad(f_jax)(x) x_v = [tf.Variable(x.a), tf.Variable(x.b)] with tf.GradientTape() as tape: res = f_tf_wrapped(*x_v) grad_tf = tape.gradient(res, x_v) self.assertAllClose(grad_jax.a, grad_tf[0]) self.assertAllClose(grad_jax.b, grad_tf[1])"><pre> <span class="pl-c"># wrapped TF function to use only standard containers</span> <span class="pl-k">def</span> <span class="pl-en">f_tf_wrapped</span>(<span class="pl-s1">a</span>, <span class="pl-s1">b</span>): <span class="pl-k">return</span> <span class="pl-en">f_tf</span>(<span class="pl-en">CustomPair</span>(<span class="pl-s1">a</span>, <span class="pl-s1">b</span>)) <span class="pl-c"># Try to put into SavedModel</span> <span class="pl-s1">my_model</span> <span class="pl-c1">=</span> <span class="pl-s1">tf</span>.<span class="pl-c1">Module</span>() <span class="pl-c"># Save a function that can take scalar inputs.</span> <span class="pl-s1">my_model</span>.<span class="pl-c1">f</span> <span class="pl-c1">=</span> <span class="pl-s1">tf</span>.<span class="pl-c1">function</span>(<span class="pl-s1">f_tf_wrapped</span>, <span class="pl-s1">autograph</span><span class="pl-c1">=</span><span class="pl-c1">False</span>, <span class="pl-s1">input_signature</span><span class="pl-c1">=</span>[<span class="pl-s1">tf</span>.<span class="pl-c1">TensorSpec</span>([], <span class="pl-s1">tf</span>.<span class="pl-c1">float32</span>), <span class="pl-s1">tf</span>.<span class="pl-c1">TensorSpec</span>([], <span class="pl-s1">tf</span>.<span class="pl-c1">float32</span>)]) <span class="pl-s1">model_dir</span> <span class="pl-c1">=</span> <span class="pl-s1">os</span>.<span class="pl-c1">path</span>.<span 
class="pl-c1">join</span>(<span class="pl-s1">absltest</span>.<span class="pl-c1">get_default_test_tmpdir</span>(), <span class="pl-en">str</span>(<span class="pl-en">id</span>(<span class="pl-s1">my_model</span>))) <span class="pl-s1">tf</span>.<span class="pl-c1">saved_model</span>.<span class="pl-c1">save</span>(<span class="pl-s1">my_model</span>, <span class="pl-s1">model_dir</span>, <span class="pl-s1">options</span><span class="pl-c1">=</span><span class="pl-s1">tf</span>.<span class="pl-c1">saved_model</span>.<span class="pl-c1">SaveOptions</span>(<span class="pl-s1">experimental_custom_gradients</span><span class="pl-c1">=</span><span class="pl-c1">True</span>)) <span class="pl-c"># Restoring (note: the restored model does *not* require JAX to run, just XLA).</span> <span class="pl-s1">restored_model</span> <span class="pl-c1">=</span> <span class="pl-s1">tf</span>.<span class="pl-c1">saved_model</span>.<span class="pl-c1">load</span>(<span class="pl-s1">model_dir</span>) <span class="pl-k">def</span> <span class="pl-en">restored_f</span>(<span class="pl-s1">pair</span>: <span class="pl-smi">CustomPair</span>): <span class="pl-k">return</span> <span class="pl-s1">restored_model</span>.<span class="pl-c1">f</span>(<span class="pl-s1">pair</span>.<span class="pl-c1">a</span>, <span class="pl-s1">pair</span>.<span class="pl-c1">b</span>) <span class="pl-s1">res_tf_3</span> <span class="pl-c1">=</span> <span class="pl-en">restored_f</span>(<span class="pl-s1">x</span>) <span class="pl-s1">self</span>.<span class="pl-c1">assertAllClose</span>(<span class="pl-s1">res_jax</span>, <span class="pl-s1">res_tf_3</span>) <span class="pl-s1">grad_jax</span> <span class="pl-c1">=</span> <span class="pl-s1">jax</span>.<span class="pl-c1">grad</span>(<span class="pl-s1">f_jax</span>)(<span class="pl-s1">x</span>) <span class="pl-s1">x_v</span> <span class="pl-c1">=</span> [<span class="pl-s1">tf</span>.<span class="pl-c1">Variable</span>(<span class="pl-s1">x</span>.<span class="pl-c1">a</span>), <span class="pl-s1">tf</span>.<span class="pl-c1">Variable</span>(<span class="pl-s1">x</span>.<span class="pl-c1">b</span>)] <span class="pl-k">with</span> <span class="pl-s1">tf</span>.<span class="pl-c1">GradientTape</span>() <span class="pl-k">as</span> <span class="pl-s1">tape</span>: <span class="pl-s1">res</span> <span class="pl-c1">=</span> <span class="pl-en">f_tf_wrapped</span>(<span class="pl-c1">*</span><span class="pl-s1">x_v</span>) <span class="pl-s1">grad_tf</span> <span class="pl-c1">=</span> <span class="pl-s1">tape</span>.<span class="pl-c1">gradient</span>(<span class="pl-s1">res</span>, <span class="pl-s1">x_v</span>) <span class="pl-s1">self</span>.<span class="pl-c1">assertAllClose</span>(<span class="pl-s1">grad_jax</span>.<span class="pl-c1">a</span>, <span class="pl-s1">grad_tf</span>[<span class="pl-c1">0</span>]) <span class="pl-s1">self</span>.<span class="pl-c1">assertAllClose</span>(<span class="pl-s1">grad_jax</span>.<span class="pl-c1">b</span>, <span class="pl-s1">grad_tf</span>[<span class="pl-c1">1</span>])</pre></div> <div class="markdown-heading" dir="auto"><h3 tabindex="-1" class="heading-element" dir="auto">Lowering gradients for functions with integer arguments or unused arguments</h3><a id="user-content-lowering-gradients-for-functions-with-integer-arguments-or-unused-arguments" class="anchor" aria-label="Permalink: Lowering gradients for functions with integer arguments or unused arguments" 
href="#lowering-gradients-for-functions-with-integer-arguments-or-unused-arguments"><svg class="octicon octicon-link" viewBox="0 0 16 16" version="1.1" width="16" height="16" aria-hidden="true"><path d="m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z"></path></svg></a></div> <p dir="auto">Applies to both native and non-native serialization.</p> <p dir="auto">When JAX differentiates functions with integer or boolean arguments, the gradients will be zero-vectors with a special <code>float0</code> type (see PR 4039](<a class="issue-link js-issue-link" data-error-text="Failed to load title" data-id="677939591" data-permission-text="Title is private" data-url="https://github.com/jax-ml/jax/issues/4039" data-hovercard-type="pull_request" data-hovercard-url="/jax-ml/jax/pull/4039/hovercard" href="https://github.com/jax-ml/jax/pull/4039">#4039</a>)). This type is translated to <code>int32</code> when lowering to TF. For example,</p> <div class="highlight highlight-source-python notranslate position-relative overflow-auto" dir="auto" data-snippet-clipboard-copy-content="x = np.int16(2) def f_jax(x): # x: int16 return x * 2. jax.grad(f_jax, allow_int=True)(x) # returns a special `float0`: array((b'',), dtype=[('float0', 'V')]) jax2tf.convert(jax.grad(f_jax, allow_int=True))(x) # returns a tf.Tensor(0, shape=(), dtype=int32)"><pre><span class="pl-s1">x</span> <span class="pl-c1">=</span> <span class="pl-s1">np</span>.<span class="pl-c1">int16</span>(<span class="pl-c1">2</span>) <span class="pl-k">def</span> <span class="pl-en">f_jax</span>(<span class="pl-s1">x</span>): <span class="pl-c"># x: int16</span> <span class="pl-k">return</span> <span class="pl-s1">x</span> <span class="pl-c1">*</span> <span class="pl-c1">2.</span> <span class="pl-s1">jax</span>.<span class="pl-c1">grad</span>(<span class="pl-s1">f_jax</span>, <span class="pl-s1">allow_int</span><span class="pl-c1">=</span><span class="pl-c1">True</span>)(<span class="pl-s1">x</span>) <span class="pl-c"># returns a special `float0`: array((b'',), dtype=[('float0', 'V')])</span> <span class="pl-s1">jax2tf</span>.<span class="pl-c1">convert</span>(<span class="pl-s1">jax</span>.<span class="pl-c1">grad</span>(<span class="pl-s1">f_jax</span>, <span class="pl-s1">allow_int</span><span class="pl-c1">=</span><span class="pl-c1">True</span>))(<span class="pl-s1">x</span>) <span class="pl-c"># returns a tf.Tensor(0, shape=(), dtype=int32)</span></pre></div> <p dir="auto">Note that this is different from how TensorFlow handles gradients for integer or boolean arguments: sometimes the gradient is <code>None</code>, sometimes it is a zero with the same dtype as the argument, and sometimes it is a one with the same dtype as the argument (e.g., for the identity function).</p> <div class="highlight highlight-source-python notranslate position-relative overflow-auto" dir="auto" data-snippet-clipboard-copy-content="def f_tf(x): # x: int16 return tf.cast(x, tf.float32) * 2. 
xv = tf.Variable(x)
with tf.GradientTape(persistent=True) as tape:
  print(tape.gradient(f_tf(xv), xv))
  # returns None
  print(tape.gradient(f_tf(xv), xv,
                      unconnected_gradients=tf.UnconnectedGradients.ZERO))
  # returns 0 with the same shape and dtype as x
```

When differentiating functions with unused arguments, TF by default returns the value `None` for the corresponding gradients. The `tape.gradient` function takes the option `tf.UnconnectedGradients.ZERO` to ask that gradients for unused arguments be zero.

Functions lowered with `jax2tf.convert` behave the same way under `tf.UnconnectedGradients.ZERO`, but by default they will return `None` only for gradients corresponding to integer arguments:

```python
# x1 and x3 are not used. x3 has integer type.
def fn(x0, x1, x2, x3):
  return x0 * 0. + x2 * 2.

xs = [tf.Variable(x) for x in [10., 11., 12., 13]]
with tf.GradientTape(persistent=True) as tape:
  res = fn(*xs)

g_tf_native = tape.gradient(res, xs)
# Returns: 0., None, 2., None

g_tf_native_0 = tape.gradient(res, xs,
                              unconnected_gradients=tf.UnconnectedGradients.ZERO)
# Returns: 0., 0., 2., 0

# Now with jax2tf.convert. (persistent=True is needed because we call
# tape.gradient twice below.)
with tf.GradientTape(persistent=True) as tape:
  res = jax2tf.convert(fn, with_gradient=True)(*xs)

g_jax2tf = tape.gradient(res, xs)
# Returns: 0., 0., 2., None
# Note that the gradient for x1 is 0.

g_jax2tf_0 = tape.gradient(res, xs,
                           unconnected_gradients=tf.UnconnectedGradients.ZERO)
# Returns: 0., 0., 2., 0
# In this case we get the same result as for TF native.
```
### Errors due to tf.Module magic conversion during attribute assignment
id="user-content-errors-due-to-tfmodule-magic-conversion-during-attribute-assignment" class="anchor" aria-label="Permalink: Errors due to tf.Module magic conversion during attribute assignment" href="#errors-due-to-tfmodule-magic-conversion-during-attribute-assignment"><svg class="octicon octicon-link" viewBox="0 0 16 16" version="1.1" width="16" height="16" aria-hidden="true"><path d="m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z"></path></svg></a></div> <p dir="auto">Applies to both native and non-native serialization.</p> <p dir="auto"><code>tf.Module</code> will automatically wrap the standard Python container data types into trackable classes during attribute assignment. Python Dict/List/Tuple are changed to _DictWrapper/_ListWrapper/_TupleWrapper classes. In most situations, these Wrapper classes work exactly as the standard Python data types. However, the low-level pytree data structures are different and this can lead to errors.</p> <p dir="auto">In such cases, the user can use this workaround:</p> <div class="highlight highlight-source-python notranslate position-relative overflow-auto" dir="auto" data-snippet-clipboard-copy-content="import tensorflow as tf input_data = #Any data object m = tf.Module() flat, tree_def = jax.tree_util.tree_flatten(input_data) m.input_data = {"flat": flat, "tree_def": tree_def}"><pre><span class="pl-k">import</span> <span class="pl-s1">tensorflow</span> <span class="pl-k">as</span> <span class="pl-s1">tf</span> <span class="pl-s1">input_data</span> <span class="pl-c1">=</span> <span class="pl-c">#Any data object</span> <span class="pl-s1">m</span> <span class="pl-c1">=</span> <span class="pl-s1">tf</span>.<span class="pl-c1">Module</span>() <span class="pl-s1">flat</span>, <span class="pl-s1">tree_def</span> <span class="pl-c1">=</span> <span class="pl-s1">jax</span>.<span class="pl-c1">tree_util</span>.<span class="pl-c1">tree_flatten</span>(<span class="pl-s1">input_data</span>) <span class="pl-s1">m</span>.<span class="pl-c1">input_data</span> <span class="pl-c1">=</span> {<span class="pl-s">"flat"</span>: <span class="pl-s1">flat</span>, <span class="pl-s">"tree_def"</span>: <span class="pl-s1">tree_def</span>}</pre></div> <p dir="auto">Later the user can use <code>tree_unflatten</code> for the reverse process:</p> <div class="highlight highlight-source-python notranslate position-relative overflow-auto" dir="auto" data-snippet-clipboard-copy-content="input_data = jax.tree_util.tree_unflatten(m.input_data['tree_def'], m.input_data['flat'])"><pre><span class="pl-s1">input_data</span> <span class="pl-c1">=</span> <span class="pl-s1">jax</span>.<span class="pl-c1">tree_util</span>.<span class="pl-c1">tree_unflatten</span>(<span class="pl-s1">m</span>.<span class="pl-c1">input_data</span>[<span class="pl-s">'tree_def'</span>], <span class="pl-s1">m</span>.<span class="pl-c1">input_data</span>[<span class="pl-s">'flat'</span>])</pre></div> <div class="markdown-heading" dir="auto"><h3 tabindex="-1" class="heading-element" dir="auto">Large saved_model.pb due too many PRNG 
### Large saved_model.pb due to too many PRNG operations

Applies to both native and non-native serialization.

The default `threefry2x32` PRNG is implemented in JAX with dozens of additions and bitwise operations. This means that a single PRNG operation in JAX will result in dozens of TF ops after jax2tf. If the number of PRNG operations is large, the generated TF graph will be very large.

To reduce the TF graph size and the compilation time, one can use the `unsafe_rbg` PRNG implementation by setting `jax.config.update('jax_default_prng_impl', 'unsafe_rbg')`. The `unsafe_rbg` implementation will be lowered to a TF op and several casts and reshapes, thus significantly reducing the number of TF ops per PRNG operation. The "unsafe" part is that it doesn't guarantee determinism across JAX/XLA versions, and the quality of the random streams it generates from different keys is less well understood. Nevertheless, this should be fine for most inference/serving cases. See more details in the [JAX PRNG documentation](https://jax.readthedocs.io/en/latest/jax.random.html?highlight=unsafe_rbg#advanced-rng-configuration).
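A minimal sketch of this configuration; `f_jax` here stands for any JAX function that generates random numbers internally:

```python
import jax
from jax.experimental import jax2tf

# Switch the default PRNG implementation before lowering.
jax.config.update('jax_default_prng_impl', 'unsafe_rbg')
# Each JAX PRNG operation inside f_jax now lowers to far fewer TF ops.
f_tf = jax2tf.convert(f_jax)
```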
### SavedModel supports only first-order gradients

Applies to both native and non-native serialization.

The `jax2tf`-lowered function supports higher-order gradients, but when the function is saved in a SavedModel, only the first-order gradient is saved. This is primarily a limitation of the SavedModel support for custom gradients.

### Native serialization supports only select dialects

Applies to native serialization only.

JAX native serialization checks that the code to be serialized contains operations only from MLIR dialects that are known to have stability guarantees, e.g., StableHLO and the "builtin" dialect. As an exception, it also accepts operations from the MHLO dialect, but they are converted to the corresponding StableHLO operations upon serialization.

### Native serialization supports only select custom calls

Applies to native serialization only.

JAX natively uses custom calls for lowering certain primitives. The most common example is the implementation of PRNG on GPUs, where we get better performance with a custom call (`cu_threefry32`) than if we use native StableHLO. Another class of examples is FFT and some linear algebra primitives (e.g., QR decomposition).

Unlike regular StableHLO ops, the compatibility guarantees for custom calls are the burden of the teams maintaining the C++ code that backs the custom call. For this reason, we maintain a list of allowed custom call targets.
If you try to serialize code that invokes other targets, you will get an error.

If you want to disable this safety check for a specific custom call with target `my_target`, you can add `jax2tf.DisabledSafetyCheck.custom_call("my_target")` to the `disabled_checks` parameter of the `jax2tf` function.
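A sketch of disabling the check, assuming `f_jax` is a JAX function whose lowering uses the custom call target `my_target` (a placeholder name from the text above):

```python
from jax.experimental import jax2tf

f_tf = jax2tf.convert(
    f_jax,
    native_serialization=True,
    # Allow the otherwise-rejected custom call target.
    disabled_checks=[jax2tf.DisabledSafetyCheck.custom_call("my_target")])
```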
### XlaCallModule not supported by some TensorFlow tools

Applies to native serialization only.

JAX native serialization uses the `XlaCallModule` TensorFlow op to host the StableHLO program obtained from JAX. This is a relatively new TensorFlow op and may not be supported by some tools. In fact, certain tools that need to do `tf.Graph` inspection and transformation cannot work when the whole JAX program is a single TensorFlow op. This is the case, for example, for the TFLite and TensorFlow.js converters. There is work underway to enable more tools to consume StableHLO.

### Natively serialized JAX modules are platform specific

Applies to native serialization only.

When you use native serialization, JAX will record the platform for which the module was serialized, and you will get an error if you try to execute the `XlaCallModule` TensorFlow op on another platform.

Note that this error will arise only with native serialization; with non-native serialization the lowering to TensorFlow ops is platform independent, although it is only guaranteed to match the JAX semantics and performance behavior for TPUs.

The error has the form:

```commandline
The current platform CPU is not among the platforms required by the module [CUDA]
```

where `CPU` is the TensorFlow platform where the op is being executed and `CUDA` is the platform for which the module was serialized by JAX. This probably means that JAX and TensorFlow see different devices as the default device (in the example error above, JAX defaults to GPU and TensorFlow to CPU).
You can check what devices TensorFlow uses:

```python
logging.info("All TF devices: %s", tf.config.list_logical_devices())
tf_device = (tf.config.list_logical_devices("TPU") +
             tf.config.list_logical_devices("GPU") +
             tf.config.list_logical_devices())[0]
assert jax.default_backend().upper() == tf_device.device_type
with tf.device(tf_device):
  ...
```

Users should also pay attention to another case: they must use `jit_compile=True` in order to execute on TPU. This is because with `jit_compile=False`, TF "executes the function without XLA compilation. Set this value to False when directly running a multi-device function on TPUs (e.g. two TPU cores, one TPU core and its host CPU)" (see the [TF documentation](https://www.tensorflow.org/api_docs/python/tf/function)).
With `jit_compile=False` the converted TF program will be executed on CPU instead of TPU, and this will result in an error message:

```
Node: 'XlaCallModule'
The current platform CPU is not among the platforms required by the module: [TPU]
	 [[{{node XlaCallModule}}]]
```

To work around this with `jit_compile=False`, you can wrap your function in a new `tf.function` that explicitly assigns the TPU device, like this:

```python
f_tf = jax2tf.convert(jnp.sin)
x = np.float32(.5)

@tf.function(autograph=False, jit_compile=False)
def f_tf_wrapped(x):
  with tf.device('/device:TPU:0'):
    return f_tf(x)

with tf.device('/device:TPU:0'):
  self.assertAllClose(np.sin(x), f_tf_wrapped(x))
```
### Unsupported JAX features

Applies to non-native serialization only.

There is currently no support for `pmap`, `xmap`, `shard_map`, nor for the collective operations, except in native serialization.

### Shape polymorphism with native serialization limitations for `lax.linalg.eigh`

Applies to native serialization only.

JAX lowers `lax.linalg.eigh` using custom calls, and needs to call helper functions to determine the workspace size based on the non-batch dimensions. Therefore, dynamic dimensions are supported only for the batch dimensions (all but the last two dimensions).

Additionally, on GPU the JAX lowering uses the `cuSolver` library and chooses the `syevj` method (using the Jacobi algorithm) for non-batch dimension sizes less than or equal to 32, and the `syevd` method (using the QR algorithm) for larger dimensions.

In the presence of shape polymorphism, JAX will always use `syevd`, because `syevj` requires knowing the batch dimensions statically in order to compute the workspace size. This means that the performance and the numerical behavior may be slightly different for small matrices.
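A sketch of the resulting constraint on `polymorphic_shapes` (the symbol names are illustrative):

```python
import jax
from jax.experimental import jax2tf

f = lambda x: jax.lax.linalg.eigh(x)

# OK: the batch dimension is symbolic, the 8x8 matrix dimensions are static.
f_tf = jax2tf.convert(f, polymorphic_shapes=["b, 8, 8"],
                      native_serialization=True)

# Expected to fail: the last two (non-batch) dimensions are symbolic.
# f_tf = jax2tf.convert(f, polymorphic_shapes=["b, n, n"],
#                       native_serialization=True)
```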
### Slow implementation of associative reductions for CPU

Applies to non-native serialization only.

Operations like `jax.numpy.cumsum` are lowered by JAX differently based on the platform. For TPU, the lowering uses the [HLO ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow) operation, which has an efficient implementation for the cases when the reduction function is associative. For CPU and GPU, JAX uses an alternative lowering using [associative scans](https://github.com/jax-ml/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801). jax2tf uses the TPU lowering (because it does not support backend-specific lowering) and hence it can be slow in some cases on CPU and GPU.

We have filed a bug with the XLA:CPU compiler to improve ReduceWindow. Meanwhile, if you run into this problem you can use the `--jax2tf_associative_scan_reductions` flag to get the special associative scan lowering. Alternatively, you can use the `jax.jax2tf_associative_scan_reductions(True)` context manager around the code that invokes the function returned by `jax2tf.convert`. Use this only if it improves the performance for your application.

Note that this lowering may not work as well as the default one in the presence of shape polymorphism.
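A minimal sketch of the context-manager form, using `jnp.cumsum` as an example of an associative reduction:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import jax2tf

f_tf = jax2tf.convert(jnp.cumsum)
# The associative-scan lowering is requested around the invocation of the
# converted function.
with jax.jax2tf_associative_scan_reductions(True):
  print(f_tf(np.arange(8., dtype=np.float32)))
```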
### TensorFlow XLA ops

Applies to non-native serialization only.

For most JAX primitives there is a natural TensorFlow op that fits the needed semantics. There are a few (listed in [no_xla_limitations.md](/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/no_xla_limitations.md)) JAX primitives for which there is no single TensorFlow op with matching semantics. This is not so surprising, because JAX primitives have been designed to be compiled to [HLO ops](https://www.tensorflow.org/xla/operation_semantics), while the corresponding TensorFlow ops are sometimes higher-level. For the cases when there is no matching canonical TensorFlow op, we use a set of special TensorFlow ops that are thin wrappers over HLO ops (a subset of those registered in [tf2xla/ops/xla_ops.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/ops/xla_ops.cc) and implemented in, e.g., [tf2xla/kernels/xla_pad_op.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc)). We refer to these ops here as the XLA TensorFlow ops. Note that these are still regular TF ops, e.g., they can be saved in a SavedModel.

There are several drawbacks of using XLA TensorFlow ops:

* These ops will only be executable by a consumer that has XLA linked in. This should not be a problem for TPU execution, since that requires XLA anyway.
* These ops are not yet recognized by tools that process `tf.Graph`, e.g., the TensorFlow.js converter or the TensorFlow Lite converter.

As an experimental feature we have implemented alternative conversions that avoid the XLA TensorFlow ops. You can enable this with the `enable_xla=False` parameter to `jax2tf.convert`. For more details see [no_xla_limitations.md](/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/no_xla_limitations.md).
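A minimal sketch of this flag; `f_jax` stands for any JAX function, and the conversion is expected to fail if some primitive cannot be lowered without the XLA TensorFlow ops:

```python
from jax.experimental import jax2tf

# Experimental: lower without the XLA TensorFlow ops.
f_tf = jax2tf.convert(f_jax, enable_xla=False)
```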
### Different performance characteristics

Applies to non-native serialization only.

The lowered code may have slightly different performance characteristics than the original JAX code. We do expect that the performance characteristics of lowered code should be the same as those of JAX when used with the XLA compiler (`tf.function(jit_compile=True)`). This is because during lowering we try to generate one TensorFlow op for one JAX primitive, and we expect that the lowering that XLA does is similar to that done by JAX before conversion. (This is a hypothesis; we have not yet verified it extensively.)

There is one known case when the performance of the lowered code will be different. JAX programs use a [stateless deterministic PRNG](https://github.com/jax-ml/jax/blob/main/docs/design_notes/prng.md) and have an internal JAX primitive for it. This primitive is at the moment lowered to a soup of `tf.bitwise` operations, which has a clear performance penalty. We plan to look into using the HLO [RNGBitGenerator](https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) (exposed as a TFXLA op), which does implement the same basic Threefry algorithm as JAX's PRNG, although that would result in different results than JAX's PRNG.

In the absence of TensorFlow XLA compilation, if one were to write the same functionality in JAX idiomatic code vs. native TensorFlow idiomatic code, we could end up with very different compilation paths. Take, for example, the case of batch normalization. In TensorFlow, if one uses [tf.nn.batch_normalization](https://www.tensorflow.org/api_docs/python/tf/nn/batch_normalization), a "high-level" TensorFlow op for batch normalization is generated, and in the absence of XLA, on CPU or GPU, a custom C++ "high-level" kernel implementing batch normalization is executed. In JAX, there is no primitive for batch normalization; instead the operation is decomposed into low-level primitives (e.g., by [flax.linen.BatchNorm](https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.BatchNorm.html) or haiku.BatchNorm).
Once those primitives are lowered to TensorFlow, and the resulting code is run without XLA, the ensemble of kernels executed will quite possibly behave differently, performance-wise or even numerically, than either the TensorFlow-native or JAX-native batch normalization. A similar example is that of an LSTM cell.

### Unchecked assumption that the dimension variables take strictly positive values

Applies to non-native serialization only.

The shape polymorphic conversion is sound under the assumption that the dimension variables take strictly positive values. In the following example, the function to be lowered has different behavior for empty shapes. The broken assumption is caught by jax2tf if the lowered function is executed eagerly, but not if it is first traced to a TensorFlow graph:

```python
def f_jax(x):
  return 0 if x.shape[0] == 0 else 1

x0 = np.array([], np.float32)
self.assertEqual(0, f_jax(x0))  # JAX sees that x.shape[0] == 0

# jax2tf catches the broken assumption b >= 1 if the lowered function is executed
# eagerly.
# Raises: ValueError: Dimension variable b must have integer value >= 1.
#         Found value 0 when solving b == 0
jax2tf.convert(f_jax, polymorphic_shapes=["b"])(x0)

# However, if we first trace to a TensorFlow graph, we may miss the broken assumption:
f_tf = tf.function(
    jax2tf.convert(f_jax, polymorphic_shapes=["b"]),
    autograph=False).get_concrete_function(tf.TensorSpec([None], dtype=np.float32))
self.assertEqual(1, f_tf(x0))
```

Another possible source of unsoundness is that JAX assumes that all unknown dimensions represented by the same dimension variable have equal size. As before, this assumption is checked if the lowered function is executed eagerly, but it may be missed if it is first traced to a TensorFlow graph:

```python
def f_jax(x):
  return 0 if x.shape[0] != x.shape[1] else 1

x45 = np.ones((4, 5), dtype=np.float32)
self.assertEqual(0, f_jax(x45))  # JAX sees that x.shape[0] != x.shape[1]

# jax2tf catches the broken assumption x.shape[0] == x.shape[1] if the lowered
# function is executed eagerly.
# Raises: ValueError: polymorphic shape ('b, b',) has dimension variable 'b'
#         corresponding to multiple values {4, 5}, for argument shapes (TensorShape([4, 5]),)
jax2tf.convert(f_jax, polymorphic_shapes=["b, b"])(x45)

# However, if we first trace to a TensorFlow graph, we may miss the broken assumption.
f_tf = tf.function(
    jax2tf.convert(f_jax, polymorphic_shapes=["b, b"]),
    autograph=False).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
self.assertEqual(1, f_tf(x45))
```

### Incomplete TensorFlow data type coverage
height="16" aria-hidden="true"><path d="m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z"></path></svg></a></div> <p dir="auto">Applies to non-native serialization only.</p> <p dir="auto">There are a number of cases when the TensorFlow ops that are used by the <code>jax2tf</code> are not supported by TensorFlow for the same data types as in JAX. There is an <a href="https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md">up-to-date list of unimplemented cases</a>.</p> <p dir="auto">If you try to lower and run in TensorFlow a program with partially supported primitives, you may see TensorFlow errors that a TensorFlow op is used with an unsupported data type, or that there is no supported TensorFlow kernel for the op for the given data type. The former case can happen even if you <code>jit_compile</code> the TensorFlow program, and it is a priority to fit. The latter case only appears in TensorFlow non-compiled mode; you can avoid the problem if you use XLA to <code>jit_compile</code> (always recommended).</p> <p dir="auto">Our priority is to ensure numerical and performance accuracy for the lowered program <strong>when using XLA to compile the lowered program</strong>. It is always a good idea to use XLA on the lowered function.</p> <p dir="auto">Sometimes you cannot compile the entire TensorFlow function for your model, because in addition to the function that is lowered from JAX, it may include some pre-processing TensorFlow code that is not compilable with XLA, e.g., string parsing. 
Even in those situations you can instruct TensorFlow to compile only the portion that originates from JAX:

```python
def entire_tf_fun(x):
  y = preprocess_tf_fun_not_compilable(x)
  # Compile the code that is lowered from JAX
  z = tf.function(jax2tf.convert(compute_jax_fn),
                  autograph=False, jit_compile=True)(y)
  return postprocess_tf_fun_not_compilable(z)
```

You won't be able to compile `entire_tf_fun`, but you can still execute it knowing that the jax2tf-lowered code is compiled. You can even save the function to a SavedModel, knowing that upon restore the jax2tf-lowered code will be compiled.

For a more elaborate example, see the test `test_tf_mix_jax_with_uncompilable` in [savedmodel_test.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/savedmodel_test.py).

# Calling TensorFlow functions from JAX

The function `call_tf` allows JAX functions to call TensorFlow functions. These functions can be called anywhere in a JAX computation, including in staging contexts (`jax.jit`, `jax.pmap`, `jax.xmap`) or inside JAX's control-flow primitives. In non-staging contexts, the TensorFlow function is called in eager mode.
For now, only reverse-mode autodiff is supported for these functions (no forward-mode autodiff, nor `vmap`).

As a trivial example, consider computing `sin(cos(1.))` with `sin` done in JAX and `cos` in TF:

```python
from jax.experimental import jax2tf

# This is a TF function. It will be called with TensorFlow-compatible arguments,
# such as `numpy.ndarray`, `tf.Tensor` or `tf.Variable`, or a pytree thereof.
# It should return a similar result. This function will be called using
# TensorFlow eager mode if called from outside JAX staged contexts (`jit`,
# `pmap`, or control-flow primitives), and will be called using TensorFlow
# compiled mode otherwise. In the latter case, the function must be compilable
# with XLA (`tf.function(func, jit_compile=True)`).
def cos_tf(x):
  return tf.math.cos(x)

# Compute cos with TF and sin with JAX
def cos_tf_sin_jax(x):
  return jax.numpy.sin(jax2tf.call_tf(cos_tf)(x))

# Calls `cos_tf` in TF eager mode
x = np.float32(1.)
cos_tf_sin_jax(x)

# Compiles `cos_tf` using TF and embeds the XLA computation into the JAX
# XLA computation (containing `sin`). The XLA compiler may even be able to
# fuse through JAX-TF computations.
jax.jit(cos_tf_sin_jax)(x)

# Uses the TF gradient for `cos_tf` and the JAX gradient for `sin`
jax.grad(cos_tf_sin_jax)(x)
```
If you inspect the generated HLO for `cos_tf_sin_jax`, you will see that the main JAX computation (`ENTRY xla_computation_cos_tf_sin_jax`) makes a call to the `a_inference_cos_tf_68__` HLO function that was compiled by TF from `cos_tf`:

```
HloModule xla_computation_cos_tf_sin_jax.18

a_inference_cos_tf_68__.4 {
  arg0.5 = f32[] parameter(0), parameter_replication={false}
  reshape.6 = f32[] reshape(arg0.5)
  cosine.7 = f32[] cosine(reshape.6)
  reshape.8 = f32[] reshape(cosine.7)
  tuple.9 = (f32[]) tuple(reshape.8)
  ROOT get-tuple-element.10 = f32[] get-tuple-element(tuple.9), index=0
}

ENTRY xla_computation_cos_tf_sin_jax.18 {
  constant.2 = pred[] constant(false)
  constant.3 = pred[] constant(false)
  parameter.1 = f32[] parameter(0)
  call.11 = f32[] call(parameter.1), to_apply=a_inference_cos_tf_68__.4
  tuple.12 = (f32[]) tuple(call.11)
  get-tuple-element.13 = f32[] get-tuple-element(tuple.12), index=0
  tuple.14 = (f32[]) tuple(get-tuple-element.13)
  get-tuple-element.15 = f32[] get-tuple-element(tuple.14), index=0
  sine.16 = f32[] sine(get-tuple-element.15)
  ROOT tuple.17 = (f32[]) tuple(sine.16)
}
```

For a more elaborate example, including round-tripping from JAX to TensorFlow and back through a SavedModel, with support for custom gradients, see the test `test_round_trip_custom_grad_saved_model` in [call_tf_test.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/call_tf_test.py).

All the metadata inserted by TF during tracing and compilation, e.g., source location information and op names, is carried through to the JAX XLA computation.

The TF custom gradients are respected, since it is TF that generates the gradient computation.

`call_tf` works even with shape polymorphism, but in that case the user must pass the `output_shape_dtype` parameter to `call_tf` to declare the expected output shapes.
This allows JAX tracing to know the shape and dtype of the results so that it can continue tracing the rest of the program. When `output_shape_dtype` is not given (the default case), `call_tf` will form a `tf.Graph` for the called TF function and will use the inferred type and shape. However, in the presence of dynamic shapes the inferred TF type will contain `None` for the dynamic dimensions, which is not enough information for JAX shape polymorphism.

For example:

```python
def fun_jax(x):
  y_shape = (x.shape[0] * 2, *x.shape[1:])
  y = jax2tf.call_tf(
      lambda x: tf.concat([x, x], axis=0),
      output_shape_dtype=jax.ShapeDtypeStruct(y_shape, x.dtype))(x)
  # JAX will know y.shape
  return jnp.ones(y.shape, dtype=y.dtype) + y

jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)
```

An even simpler example, for a function that returns the same shape as the input:

```python
def fun_jax(x):
  return jax2tf.call_tf(tf.math.sin,
                        output_shape_dtype=x)(x)

jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)
```
class="pl-s1">jax2tf</span>.<span class="pl-c1">convert</span>(<span class="pl-s1">fun_jax</span>, <span class="pl-s1">polymorphic_shapes</span><span class="pl-c1">=</span>[<span class="pl-s">"b, ..."</span>])(<span class="pl-s1">x</span>)</pre></div> <p dir="auto">If all the output shapes of the TF function are static, JAX does not need the <code>output_shape_dtype</code> argument:</p> <div class="highlight highlight-source-python notranslate position-relative overflow-auto" dir="auto" data-snippet-clipboard-copy-content="def fun_tf(x): return tf.math.reduce_sum(tf.math.sin(x)) def fun_jax(x): return jax2tf.call_tf(fun_tf)(x) # The following will not throw an error because the output shape of fun_tf is static. jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)"><pre><span class="pl-k">def</span> <span class="pl-en">fun_tf</span>(<span class="pl-s1">x</span>): <span class="pl-k">return</span> <span class="pl-s1">tf</span>.<span class="pl-c1">math</span>.<span class="pl-c1">reduce_sum</span>(<span class="pl-s1">tf</span>.<span class="pl-c1">math</span>.<span class="pl-c1">sin</span>(<span class="pl-s1">x</span>)) <span class="pl-k">def</span> <span class="pl-en">fun_jax</span>(<span class="pl-s1">x</span>): <span class="pl-k">return</span> <span class="pl-s1">jax2tf</span>.<span class="pl-c1">call_tf</span>(<span class="pl-s1">fun_tf</span>)(<span class="pl-s1">x</span>) <span class="pl-c"># The following will not throw an error because the output shape of fun_tf is static.</span> <span class="pl-s1">jax2tf</span>.<span class="pl-c1">convert</span>(<span class="pl-s1">fun_jax</span>, <span class="pl-s1">polymorphic_shapes</span><span class="pl-c1">=</span>[<span class="pl-s">"b, ..."</span>])(<span class="pl-s1">x</span>)</pre></div> <p dir="auto">The shape polymorphism support for <code>call_tf</code> does not yet work for native serialization.</p> <div class="markdown-heading" dir="auto"><h3 tabindex="-1" class="heading-element" dir="auto">Limitations of call_tf</h3><a id="user-content-limitations-of-call_tf" class="anchor" aria-label="Permalink: Limitations of call_tf" href="#limitations-of-call_tf"><svg class="octicon octicon-link" viewBox="0 0 16 16" version="1.1" width="16" height="16" aria-hidden="true"><path d="m7.775 3.275 1.25-1.25a3.5 3.5 0 1 1 4.95 4.95l-2.5 2.5a3.5 3.5 0 0 1-4.95 0 .751.751 0 0 1 .018-1.042.751.751 0 0 1 1.042-.018 1.998 1.998 0 0 0 2.83 0l2.5-2.5a2.002 2.002 0 0 0-2.83-2.83l-1.25 1.25a.751.751 0 0 1-1.042-.018.751.751 0 0 1-.018-1.042Zm-4.69 9.64a1.998 1.998 0 0 0 2.83 0l1.25-1.25a.751.751 0 0 1 1.042.018.751.751 0 0 1 .018 1.042l-1.25 1.25a3.5 3.5 0 1 1-4.95-4.95l2.5-2.5a3.5 3.5 0 0 1 4.95 0 .751.751 0 0 1-.018 1.042.751.751 0 0 1-1.042.018 1.998 1.998 0 0 0-2.83 0l-2.5 2.5a1.998 1.998 0 0 0 0 2.83Z"></path></svg></a></div> <p dir="auto">The TF function must be compilable (<code>tf.function(func, jit_compile=True)</code>) and must have static output shapes when used in a JAX staging context, e.g., <code>jax.jit</code>, <code>lax.scan</code>, <code>lax.cond</code>, but may have unknown output shapes when used in a JAX op-by-op mode. 
For example, the following function uses string operations that are not supported by XLA:

```python
def f_tf_non_compilable(x):
  return tf.strings.length(tf.strings.format("Hello {}!", [x]))

f_jax = jax2tf.call_tf(f_tf_non_compilable)
# Works in op-by-op mode
f_jax(np.float32(42.))
# Fails in jit mode
jax.jit(f_jax)(np.float32(42.))
```

Yet another unsupported situation is when the TF function is compilable but has dynamic output shapes:

```python
def f_tf_dynamic_shape(x):
  return x[x[0]:5]

x = np.array([1, 2], dtype=np.int32)
f_jax = jax2tf.call_tf(f_tf_dynamic_shape)
# Works in op-by-op mode
f_jax(x)
# Fails in jit mode
jax.jit(f_jax)(x)
```

Another similar example that will fail to compile:

```python
def f_tf_dynamic_output_shape(x):
  return tf.cond(x[0] >= 0, lambda: x, lambda: x[1:])

x = np.array([1, 2], dtype=np.int32)
```
class="pl-s1">x</span>[<span class="pl-c1">0</span>] <span class="pl-c1">>=</span> <span class="pl-c1">0</span>, <span class="pl-k">lambda</span>: <span class="pl-s1">x</span>, <span class="pl-k">lambda</span>: <span class="pl-s1">x</span>[<span class="pl-c1">1</span>:]) <span class="pl-s1">x</span> <span class="pl-c1">=</span> <span class="pl-s1">np</span>.<span class="pl-c1">array</span>([<span class="pl-c1">1</span>, <span class="pl-c1">2</span>], <span class="pl-s1">dtype</span><span class="pl-c1">=</span><span class="pl-s1">np</span>.<span class="pl-c1">int32</span>)</pre></div> <p dir="auto"><code>call_tf</code> works best with pure TF functions that do not capture <code>tf.Variable</code>s or tensors from the environment, and all such context is passed in explicitly through arguments, and if variables are modified, the resulting values are passed out through results. There is a best-effort mechanism that can handle variable capture and variable updates, except in the case of a function that modifies <code>tf.Variable</code>s and is used in a JAX jitted context. Calling the <code>inpure_func_tf</code> will give an error:</p> <div class="highlight highlight-source-python notranslate position-relative overflow-auto" dir="auto" data-snippet-clipboard-copy-content="var1 = tf.Variable(1.) def impure_func_tf(x): var1.write(11.) # BAD: should not write to variables return x + var1 jax2tf.call_tf(impure_func_tf)(tf.constant(2.)) # Works in eager mode jax.jit(jax2tf.call_tf(impure_func_tf))(tf.constant(2.)) # Fails in jit mode"><pre><span class="pl-s1">var1</span> <span class="pl-c1">=</span> <span class="pl-s1">tf</span>.<span class="pl-c1">Variable</span>(<span class="pl-c1">1.</span>) <span class="pl-k">def</span> <span class="pl-en">impure_func_tf</span>(<span class="pl-s1">x</span>): <span class="pl-s1">var1</span>.<span class="pl-c1">write</span>(<span class="pl-c1">11.</span>) <span class="pl-c"># BAD: should not write to variables</span> <span class="pl-k">return</span> <span class="pl-s1">x</span> <span class="pl-c1">+</span> <span class="pl-s1">var1</span> <span class="pl-s1">jax2tf</span>.<span class="pl-c1">call_tf</span>(<span class="pl-s1">impure_func_tf</span>)(<span class="pl-s1">tf</span>.<span class="pl-c1">constant</span>(<span class="pl-c1">2.</span>)) <span class="pl-c"># Works in eager mode</span> <span class="pl-s1">jax</span>.<span class="pl-c1">jit</span>(<span class="pl-s1">jax2tf</span>.<span class="pl-c1">call_tf</span>(<span class="pl-s1">impure_func_tf</span>))(<span class="pl-s1">tf</span>.<span class="pl-c1">constant</span>(<span class="pl-c1">2.</span>)) <span class="pl-c"># Fails in jit mode</span></pre></div> <p dir="auto">The error can be avoided by passing the variable explicitly:</p> <div class="highlight highlight-source-python notranslate position-relative overflow-auto" dir="auto" data-snippet-clipboard-copy-content="def pure_func_tf(x, var1) new_var1 = 11. 
This use case is likely to be revisited.

Note that when the TF function captures a variable from the context, the TF
function must be lowered for the same TF device that hosts the variable. By
default, the lowering will use the first TF device on the same platform as the
embedding JAX computation, e.g., "/device:TPU:0" if the embedding JAX
computation runs on TPU. This will fail if the computation captures variables
on some other devices. It is best to use `call_tf` with TF functions that do
not capture variables.

In some rare cases your called TF function may contain ops with outputs of
statically known shape, but for which the shape inference is not implemented
completely and which therefore appear to `call_tf` as if they have dynamically
shaped outputs. In these cases you may get an error that
`call_tf cannot call functions whose output has dynamic shape`. Try using the
`output_shape_dtype` parameter to specify the expected output shape (this
essentially allows you to override the shape inference for the purposes of
`call_tf`).
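A hedged sketch of such an override, assuming a function whose result is
statically `(3,)` `float32` even though shape inference cannot prove it (the
function body here is a hypothetical stand-in):

```python
import numpy as np
import jax
import tensorflow as tf
from jax.experimental import jax2tf

def fun_tf(x):
  # Stand-in for an op whose static output shape TF fails to infer.
  return tf.math.sin(x)

# Override the inferred output shape for the purposes of call_tf.
f_jax = jax2tf.call_tf(
    fun_tf,
    output_shape_dtype=jax.ShapeDtypeStruct((3,), np.float32))
jax.jit(f_jax)(np.ones((3,), dtype=np.float32))
```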
# Misc notes

## Debugging JAX native serialization

Inside Google, you can turn on logging by using the `--vmodule` argument to
specify the logging levels for different modules, e.g., `--vmodule=_export=3`.
You can set `TF_DUMP_GRAPH_PREFIX` to a directory where modules should be
dumped, or to `"-"` to dump the modules to the log. The following modules are
useful for debugging JAX native serialization:

* `_export=3` - will log the StableHLO module on serialization.
* `jax2tf=3` - will log the parameters to the `XlaCallModule` op on
  serialization.
* `xla_call_module_loader=3` - will log the StableHLO module upon loading,
  after shape refinement, and on verification error. You can use level `4` to
  add location information, and level `5` to also print the module before and
  after each transformation.
* `xla_call_module_op=3` - will log the HLO module generated after shape
  refinement and conversion from StableHLO.
* The `XlaCallModule` lowering has the TensorFlow MLIR crash reproducer
  enabled; it can be instructed to generate a crash reproducer upon MLIR pass
  failures by setting the environment variable
  `MLIR_CRASH_REPRODUCER_DIRECTORY`.

For the two `xla` modules mentioned above, you can control logging in OSS with
environment variables, e.g.:

```shell
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=xla_call_module_loader=3 python ...
```

In addition, `TF_DUMP_GRAPH_PREFIX` controls where the dump will be stored:
`-` for stderr, or `${SOME_DIR}` to store the dumps in the specified directory.
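For example, a combined OSS invocation that raises the loader's logging level
and dumps modules to stderr (the script name is a placeholder):

```shell
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=xla_call_module_loader=3 \
  TF_DUMP_GRAPH_PREFIX=- python my_script.py
```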
## TensorFlow versions supported

The `jax2tf.convert` and `call_tf` functions require fairly recent versions of
TensorFlow. As of today, the tests are run using
`tf_nightly==2.14.0.dev20230720`.

## Running on GPU

To run jax2tf on GPU, both jaxlib and TensorFlow must be installed with
support for CUDA. One must be mindful to install a version of CUDA that is
compatible with both
[jaxlib](https://github.com/jax-ml/jax/blob/main/README.md#pip-installation)
and
[TensorFlow](https://www.tensorflow.org/install/source#tested_build_configurations).

## Updating the limitations documentation

The jax2tf tests are parameterized by a set of limitations (see
`tests/primitive_harness.py` and `tests/jax2tf_limitations.py`). The
limitations specify test harnesses that are known to fail, by JAX primitive,
data type, device type, and TensorFlow execution mode (`eager`, `graph`, or
`compiled`). These limitations are also used to generate tables of
limitations, e.g.,

* [List of primitives not supported in JAX](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md),
  e.g., due to unimplemented cases in the XLA compiler, and
* [List of primitives not supported in jax2tf](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md),
  e.g., due to unimplemented cases in TensorFlow. This list is incremental on
  top of the unsupported JAX primitives.
There are instructions for updating those documents at the end of each
document.

The set of limitations is an over-approximation, in the sense that if XLA or
TensorFlow improve and support more cases, no test will fail. Instead, we
periodically check for unnecessary limitations. We do this by uncommenting two
assertions (in `tests/jax_primitives_coverage_test.py` and in
`tests/tf_test_util.py`) and running all the tests. With these assertions
enabled, the tests will fail and point out unnecessary limitations. We remove
limitations until the tests pass. Then we re-generate the documentation.