Skip to content

Commit 898109d

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Vertex AI Model Garden open model export SDK Public Preview
PiperOrigin-RevId: 742429430
1 parent 273e341 commit 898109d

File tree

2 files changed

+114
-1
lines changed

2 files changed

+114
-1
lines changed

tests/unit/vertexai/model_garden/test_model_garden.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
"publishers/hf-meta-llama/models/llama-3.3-70b-instruct@001"
4949
)
5050
_TEST_HUGGING_FACE_ACCESS_TOKEN = "test-access-token"
51-
51+
_TEST_GCS_URI = "gs://some-bucket/some-model"
5252
_TEST_ENDPOINT_NAME = "projects/test-project/locations/us-central1/endpoints/1234567890"
5353
_TEST_MODEL_NAME = "projects/test-project/locations/us-central1/models/9876543210"
5454
_TEST_MODEL_CONTAINER_SPEC = types.ModelContainerSpec(
@@ -85,6 +85,22 @@ def google_auth_mock():
8585
yield google_auth_mock
8686

8787

88+
@pytest.fixture
89+
def export_publisher_model_mock():
90+
"""Mocks the export_publisher_model method."""
91+
with mock.patch.object(
92+
model_garden_service.ModelGardenServiceClient,
93+
"export_publisher_model",
94+
) as export_publisher_model:
95+
mock_export_lro = mock.Mock(ga_operation.Operation)
96+
mock_export_lro.result.return_value = types.ExportPublisherModelResponse(
97+
publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME,
98+
destination_uri=_TEST_GCS_URI,
99+
)
100+
export_publisher_model.return_value = mock_export_lro
101+
yield export_publisher_model
102+
103+
88104
@pytest.fixture
89105
def deploy_mock():
90106
"""Mocks the deploy method."""
@@ -338,6 +354,7 @@ def list_publisher_models_mock():
338354
"deploy_mock",
339355
"get_publisher_model_mock",
340356
"list_publisher_models_mock",
357+
"export_publisher_model_mock",
341358
)
342359
class TestModelGarden:
343360
"""Test cases for ModelGarden class."""
@@ -350,6 +367,54 @@ def setup_method(self):
350367
def teardown_method(self):
351368
aiplatform.initializer.global_pool.shutdown(wait=True)
352369

370+
def test_export_full_resource_name_success(self, export_publisher_model_mock):
371+
aiplatform.init(
372+
project=_TEST_PROJECT,
373+
location=_TEST_LOCATION,
374+
)
375+
model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME)
376+
model.export(_TEST_GCS_URI)
377+
export_publisher_model_mock.assert_called_once_with(
378+
types.ExportPublisherModelRequest(
379+
parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
380+
name=_TEST_MODEL_FULL_RESOURCE_NAME,
381+
destination=types.GcsDestination(output_uri_prefix=_TEST_GCS_URI),
382+
),
383+
metadata=[("x-goog-user-project", "test-project")],
384+
)
385+
386+
def test_export_simplified_resource_name_success(self, export_publisher_model_mock):
387+
aiplatform.init(
388+
project=_TEST_PROJECT,
389+
location=_TEST_LOCATION,
390+
)
391+
model = model_garden.OpenModel(model_name=_TEST_MODEL_SIMPLIFIED_RESOURCE_NAME)
392+
model.export(_TEST_GCS_URI)
393+
export_publisher_model_mock.assert_called_once_with(
394+
types.ExportPublisherModelRequest(
395+
parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
396+
name=_TEST_MODEL_FULL_RESOURCE_NAME,
397+
destination=types.GcsDestination(output_uri_prefix=_TEST_GCS_URI),
398+
),
399+
metadata=[("x-goog-user-project", "test-project")],
400+
)
401+
402+
def test_export_hugging_face_id_success(self, export_publisher_model_mock):
403+
aiplatform.init(
404+
project=_TEST_PROJECT,
405+
location=_TEST_LOCATION,
406+
)
407+
model = model_garden.OpenModel(model_name=_TEST_MODEL_HUGGING_FACE_ID)
408+
model.export(_TEST_GCS_URI)
409+
export_publisher_model_mock.assert_called_once_with(
410+
types.ExportPublisherModelRequest(
411+
parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
412+
name=_TEST_HUGGING_FACE_MODEL_FULL_RESOURCE_NAME,
413+
destination=types.GcsDestination(output_uri_prefix=_TEST_GCS_URI),
414+
),
415+
metadata=[("x-goog-user-project", "test-project")],
416+
)
417+
353418
def test_deploy_full_resource_name_success(self, deploy_mock):
354419
aiplatform.init(
355420
project=_TEST_PROJECT,

vertexai/model_garden/_model_garden.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
_LOGGER = base.Logger(__name__)
3737
_DEFAULT_VERSION = compat.V1BETA1
3838
_DEFAULT_TIMEOUT = 2 * 60 * 60 # 2 hours, same as UI one-click deployment.
39+
_DEFAULT_EXPORT_TIMEOUT = 1 * 60 * 60 # 1 hour.
3940
_HF_WILDCARD_FILTER = "is_hf_wildcard(true)"
4041
_NATIVE_MODEL_FILTER = "is_hf_wildcard(false)"
4142
_VERIFIED_DEPLOYMENT_FILTER = (
@@ -308,6 +309,53 @@ def _us_central1_model_garden_client(
308309
location_override="us-central1",
309310
)
310311

312+
def export(
313+
self,
314+
target_gcs_path: str = "",
315+
export_request_timeout: Optional[float] = None,
316+
) -> str:
317+
"""Exports an Open Model to a google cloud storage bucket.
318+
319+
Args:
320+
target_gcs_path: target gcs path.
321+
export_request_timeout: The timeout for the deploy request. Default is 2
322+
hours.
323+
324+
Returns:
325+
str: the target gcs bucket where the model weights are downloaded to
326+
327+
328+
Raises:
329+
ValueError: If ``target_gcs_path`` is not specified
330+
"""
331+
if not target_gcs_path:
332+
raise ValueError("target_gcs_path is required.")
333+
334+
request = types.ExportPublisherModelRequest(
335+
parent=f"projects/{self._project}/locations/{self._location}",
336+
name=self._publisher_model_name,
337+
destination=types.GcsDestination(output_uri_prefix=target_gcs_path),
338+
)
339+
request_headers = [
340+
("x-goog-user-project", "{}".format(initializer.global_config.project)),
341+
]
342+
343+
_LOGGER.info(f"Exporting model weights: {self._model_name}")
344+
345+
operation_future = self._model_garden_client.export_publisher_model(
346+
request, metadata=request_headers
347+
)
348+
_LOGGER.info(f"LRO: {operation_future.operation.name}")
349+
350+
_LOGGER.info(f"Start time: {datetime.datetime.now()}")
351+
export_publisher_model_response = operation_future.result(
352+
timeout=export_request_timeout or _DEFAULT_EXPORT_TIMEOUT
353+
)
354+
_LOGGER.info(f"End time: {datetime.datetime.now()}")
355+
_LOGGER.info(f"Response: {export_publisher_model_response}")
356+
357+
return target_gcs_path
358+
311359
def deploy(
312360
self,
313361
accept_eula: bool = False,

0 commit comments

Comments
 (0)