Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: GenAI - Add cancel, delete, list methods in BatchPredictionJob #3762

Merged
1 commit merged into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
120 changes: 116 additions & 4 deletions tests/unit/vertexai/test_batch_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,36 @@ def complete_bq_uri_mock():


@pytest.fixture
def get_batch_prediction_job_mock():
def get_batch_prediction_job_with_bq_output_mock():
with mock.patch.object(
job_service_client.JobServiceClient, "get_batch_prediction_job"
) as get_job_mock:
get_job_mock.return_value = _TEST_GAPIC_BATCH_PREDICTION_JOB
get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob(
name=_TEST_BATCH_PREDICTION_JOB_NAME,
display_name=_TEST_DISPLAY_NAME,
model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
state=_TEST_JOB_STATE_SUCCESS,
output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo(
bigquery_output_table=_TEST_BQ_OUTPUT_PREFIX
),
)
yield get_job_mock


@pytest.fixture
def get_batch_prediction_job_with_gcs_output_mock():
    """Patches the get-job RPC to return a succeeded Gemini job with GCS output."""
    mocked_job = gca_batch_prediction_job_compat.BatchPredictionJob(
        name=_TEST_BATCH_PREDICTION_JOB_NAME,
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
        state=_TEST_JOB_STATE_SUCCESS,
        output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo(
            gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX
        ),
    )
    with mock.patch.object(
        job_service_client.JobServiceClient, "get_batch_prediction_job"
    ) as get_job_mock:
        get_job_mock.return_value = mocked_job
        yield get_job_mock


Expand All @@ -120,6 +145,39 @@ def create_batch_prediction_job_mock():
yield create_job_mock


@pytest.fixture
def cancel_batch_prediction_job_mock():
    """Patches the cancel-job RPC so tests can assert on its call arguments."""
    patcher = mock.patch.object(
        job_service_client.JobServiceClient, "cancel_batch_prediction_job"
    )
    with patcher as cancel_job_mock:
        yield cancel_job_mock


@pytest.fixture
def delete_batch_prediction_job_mock():
    """Patches the delete-job RPC so tests can assert on its call arguments."""
    patcher = mock.patch.object(
        job_service_client.JobServiceClient, "delete_batch_prediction_job"
    )
    with patcher as delete_job_mock:
        yield delete_job_mock


@pytest.fixture
def list_batch_prediction_jobs_mock():
    """Patches the list-jobs RPC to return one Gemini-backed and one PaLM-backed job."""
    fake_jobs = [
        _TEST_GAPIC_BATCH_PREDICTION_JOB,
        gca_batch_prediction_job_compat.BatchPredictionJob(
            name=_TEST_BATCH_PREDICTION_JOB_NAME,
            display_name=_TEST_DISPLAY_NAME,
            model=_TEST_PALM_MODEL_RESOURCE_NAME,
            state=_TEST_JOB_STATE_SUCCESS,
        ),
    ]
    with mock.patch.object(
        job_service_client.JobServiceClient, "list_batch_prediction_jobs"
    ) as list_jobs_mock:
        list_jobs_mock.return_value = fake_jobs
        yield list_jobs_mock


@pytest.mark.usefixtures(
"google_auth_mock", "generate_display_name_mock", "complete_bq_uri_mock"
)
Expand All @@ -138,10 +196,12 @@ def setup_method(self):
def teardown_method(self):
aiplatform_initializer.global_pool.shutdown(wait=True)

def test_init_batch_prediction_job(self, get_batch_prediction_job_mock):
def test_init_batch_prediction_job(
self, get_batch_prediction_job_with_gcs_output_mock
):
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)

get_batch_prediction_job_mock.assert_called_once_with(
get_batch_prediction_job_with_gcs_output_mock.assert_called_once_with(
name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY
)

Expand All @@ -157,6 +217,7 @@ def test_init_batch_prediction_job_invalid_model(self):
):
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)

@pytest.mark.usefixtures("get_batch_prediction_job_with_gcs_output_mock")
def test_submit_batch_prediction_job_with_gcs_input(
self, create_batch_prediction_job_mock
):
Expand All @@ -167,6 +228,15 @@ def test_submit_batch_prediction_job_with_gcs_input(
)

assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
assert job.state == _TEST_JOB_STATE_RUNNING
assert not job.has_ended
assert not job.has_succeeded

job.refresh()
assert job.state == _TEST_JOB_STATE_SUCCESS
assert job.has_ended
assert job.has_succeeded
assert job.output_location == _TEST_GCS_OUTPUT_PREFIX

expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
display_name=_TEST_DISPLAY_NAME,
Expand All @@ -188,6 +258,7 @@ def test_submit_batch_prediction_job_with_gcs_input(
timeout=None,
)

@pytest.mark.usefixtures("get_batch_prediction_job_with_bq_output_mock")
def test_submit_batch_prediction_job_with_bq_input(
self, create_batch_prediction_job_mock
):
Expand All @@ -198,6 +269,15 @@ def test_submit_batch_prediction_job_with_bq_input(
)

assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
assert job.state == _TEST_JOB_STATE_RUNNING
assert not job.has_ended
assert not job.has_succeeded

job.refresh()
assert job.state == _TEST_JOB_STATE_SUCCESS
assert job.has_ended
assert job.has_succeeded
assert job.output_location == _TEST_BQ_OUTPUT_PREFIX

expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
display_name=_TEST_DISPLAY_NAME,
Expand Down Expand Up @@ -349,3 +429,35 @@ def test_submit_batch_prediction_job_without_output_uri_prefix_and_bucket(self):
source_model=_TEST_GEMINI_MODEL_NAME,
input_dataset=_TEST_GCS_INPUT_URI,
)

@pytest.mark.usefixtures("create_batch_prediction_job_mock")
def test_cancel_batch_prediction_job(self, cancel_batch_prediction_job_mock):
    """Cancelling a freshly submitted job calls the cancel RPC with the job name."""
    submitted_job = batch_prediction.BatchPredictionJob.submit(
        source_model=_TEST_GEMINI_MODEL_NAME,
        input_dataset=_TEST_GCS_INPUT_URI,
        output_uri_prefix=_TEST_GCS_OUTPUT_PREFIX,
    )
    submitted_job.cancel()

    cancel_batch_prediction_job_mock.assert_called_once_with(
        name=_TEST_BATCH_PREDICTION_JOB_NAME,
    )

@pytest.mark.usefixtures("get_batch_prediction_job_with_gcs_output_mock")
def test_delete_batch_prediction_job(self, delete_batch_prediction_job_mock):
    """Deleting a job issues the delete RPC for the resolved resource name."""
    existing_job = batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)
    existing_job.delete()

    delete_batch_prediction_job_mock.assert_called_once_with(
        name=_TEST_BATCH_PREDICTION_JOB_NAME,
    )

def test_list_batch_prediction_jobs(self, list_batch_prediction_jobs_mock):
    """list() keeps only GenAI jobs: the PaLM-backed job in the mock is filtered out.

    Bug fix: the method was named `tes_list_...`, so pytest never collected
    or ran it. Renamed to `test_list_...` so it executes with the suite.
    """
    jobs = batch_prediction.BatchPredictionJob.list()

    # The mock returns two jobs; only the Gemini-backed one survives the filter.
    assert len(jobs) == 1
    assert jobs[0].gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB

    list_batch_prediction_jobs_mock.assert_called_once_with(
        request={"parent": _TEST_PARENT}
    )
87 changes: 81 additions & 6 deletions vertexai/batch_prediction/_batch_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform import jobs
from google.cloud.aiplatform import utils as aiplatform_utils
from google.cloud.aiplatform_v1 import types as gca_types
from vertexai import generative_models

from google.rpc import status_pb2


_LOGGER = aiplatform_base.Logger(__name__)

Expand All @@ -37,7 +40,6 @@ class BatchPredictionJob(aiplatform_base._VertexAiResourceNounPlus):
_resource_noun = "batchPredictionJobs"
_getter_method = "get_batch_prediction_job"
_list_method = "list_batch_prediction_jobs"
_cancel_method = "cancel_batch_prediction_job"
_delete_method = "delete_batch_prediction_job"
_job_type = "batch-predictions"
_parse_resource_name_method = "parse_batch_prediction_job_path"
Expand All @@ -63,13 +65,46 @@ def __init__(self, batch_prediction_job_name: str):
resource_name=batch_prediction_job_name
)
# TODO(b/338452508) Support tuned GenAI models.
if not re.search(_GEMINI_MODEL_PATTERN, self._gca_resource.model):
if not re.search(_GEMINI_MODEL_PATTERN, self.model_name):
raise ValueError(
f"BatchPredictionJob '{batch_prediction_job_name}' "
f"runs with the model '{self._gca_resource.model}', "
f"runs with the model '{self.model_name}', "
"which is not a GenAI model."
)

@property
def model_name(self) -> str:
    """Resource name of the model this batch prediction job runs against."""
    resource = self._gca_resource
    return resource.model

@property
def state(self) -> gca_types.JobState:
    """Current JobState of this batch prediction job, as last fetched."""
    resource = self._gca_resource
    return resource.state

@property
def has_ended(self) -> bool:
    """True once the job has reached a terminal (complete) state."""
    terminal_states = jobs._JOB_COMPLETE_STATES
    return self.state in terminal_states

@property
def has_succeeded(self) -> bool:
    """True only if the job finished in the SUCCEEDED state."""
    succeeded = gca_types.JobState.JOB_STATE_SUCCEEDED
    return self.state == succeeded

@property
def error(self) -> Optional[status_pb2.Status]:
    """Detailed error status for this job, if the service reported one."""
    resource = self._gca_resource
    return resource.error

@property
def output_location(self) -> str:
    """Output location of this job: the GCS directory if set, else the BQ table."""
    info = self._gca_resource.output_info
    if info.gcs_output_directory:
        return info.gcs_output_directory
    return info.bigquery_output_table

@classmethod
def submit(
cls,
Expand Down Expand Up @@ -178,14 +213,54 @@ def submit(
_LOGGER.log_create_complete(
cls, job._gca_resource, "job", module_name="batch_prediction"
)
_LOGGER.info(
"View Batch Prediction Job:\n%s" % aiplatform_job._dashboard_uri()
)
_LOGGER.info("View Batch Prediction Job:\n%s" % job._dashboard_uri())

return job
finally:
logging.getLogger("google.cloud.aiplatform.jobs").disabled = False

def refresh(self) -> "BatchPredictionJob":
    """Re-fetches this job's resource from the service; returns self for chaining."""
    self._sync_gca_resource()
    return self

def cancel(self):
    """Requests cancellation of this BatchPredictionJob.

    The service does not guarantee the cancellation succeeds; call
    `refresh()` and inspect `state` to confirm the outcome.
    """
    _LOGGER.log_action_start_against_resource("Cancelling", "run", self)
    client = self.api_client
    client.cancel_batch_prediction_job(name=self.resource_name)

def delete(self):
    """Permanently deletes this BatchPredictionJob resource.

    WARNING: the deletion cannot be undone.
    """
    self._delete()

@classmethod
def list(cls, filter=None) -> List["BatchPredictionJob"]:
    """Lists all BatchPredictionJob instances that run with GenAI models."""

    def _is_genai_job(gca_resource):
        # Keep only jobs whose model matches the Gemini publisher pattern.
        return re.search(_GEMINI_MODEL_PATTERN, gca_resource.model)

    return cls._list(cls_filter=_is_genai_job, filter=filter)

def _dashboard_uri(self) -> Optional[str]:
    """Google Cloud console URL where this job can be viewed."""
    parsed = self._parse_resource_name(self.resource_name)
    location = parsed.pop("location")
    project = parsed.pop("project")
    # After removing location and project, the remaining field is the job ID.
    job_id = next(iter(parsed.values()))
    return (
        "https://console.cloud.google.com/ai/platform/locations/"
        f"{location}/{self._job_type}/{job_id}?project={project}"
    )

@classmethod
def _reconcile_model_name(cls, model_name: str) -> str:
"""Reconciles model name to a publisher model resource name."""
Expand Down