/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.model;

import java.time.Instant;
import java.util.Collection;
import java.util.HashSet;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.MediaType;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.remote.metadata.client.GetDataObjectRequest;
import org.opensearch.remote.metadata.client.GetDataObjectResponse;
import org.opensearch.remote.metadata.client.PutDataObjectRequest;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.transport.client.Client;

public class MLModelGroupManager {
    /** Class-wide Log4j logger. */
    @Generated
    private static final Logger log = LogManager.getLogger(MLModelGroupManager.class);
    /** Used by createModelGroup to initialise the model group index when it is absent. */
    private final MLIndicesHandler mlIndicesHandler;
    /** Supplies the thread pool's thread context and the calling user's context. */
    private final Client client;
    /** Performs the put and search data-object calls issued by this class's own methods. */
    private final SdkClient sdkClient;
    /** Injected through the constructor; not read by any method in this class. */
    ClusterService clusterService;
    /** Answers the security/access-control-enabled and is-admin questions during validation. */
    ModelAccessControlHelper modelAccessControlHelper;
    /** Injected through the constructor; not read by any method in this class. */
    private MLFeatureEnabledSetting mlFeatureEnabledSetting;

    /**
     * Creates the manager from its injected collaborators; each argument is stored as-is
     * in the field of the same name.
     *
     * @param mlIndicesHandler         handler used to initialise the model group index
     * @param client                   node client providing thread and user context
     * @param sdkClient                client used for data-object put/search calls
     * @param clusterService           cluster service reference kept on the instance
     * @param modelAccessControlHelper helper consulted for access-control decisions
     * @param mlFeatureEnabledSetting  feature settings reference kept on the instance
     */
    @Inject
    public MLModelGroupManager(
        MLIndicesHandler mlIndicesHandler,
        Client client,
        SdkClient sdkClient,
        ClusterService clusterService,
        ModelAccessControlHelper modelAccessControlHelper,
        MLFeatureEnabledSetting mlFeatureEnabledSetting
    ) {
        // Clients first, then the helpers and settings; the assignments are independent.
        this.client = client;
        this.sdkClient = sdkClient;
        this.mlIndicesHandler = mlIndicesHandler;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.clusterService = clusterService;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
    }

    /**
     * Creates a model group after checking that no existing group already uses the requested name.
     *
     * <p>The work is asynchronous: the name is searched via
     * {@link #validateUniqueModelGroupName}, the {@link MLModelGroup} is built (running the
     * access-control validation), the model group index is initialised if absent, and the document
     * is written through the {@link SdkClient}. The thread context is stashed while the request is
     * issued, and the stored context is restored right before {@code listener} is notified from an
     * asynchronous callback.
     *
     * <p>When the name search reports matches, the listener is failed once, naming the ID of the
     * first returned hit, even if several groups share the name.
     *
     * @param input    registration request (name, description, tenant ID, access settings)
     * @param listener receives the ID of the created model group, or the failure
     */
    public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<String> listener) {
        try {
            String modelName = input.getName();
            User user = RestActionUtils.getUserContext(this.client);
            try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
                ActionListener<String> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
                this.validateUniqueModelGroupName(modelName, input.getTenantId(), ActionListener.wrap(modelGroups -> {
                    if (hasMatchingModelGroup(modelGroups)) {
                        // Fail once only: a listener must not be completed once per duplicate hit.
                        wrappedListener.onFailure(duplicateNameFailure(modelGroups));
                        return;
                    }
                    // Validation errors thrown while building are left to propagate out of this
                    // callback; ActionListener.wrap is expected to hand them to the failure handler below.
                    MLModelGroup mlModelGroup = this.buildModelGroup(input, user, modelName);
                    this.indexModelGroup(mlModelGroup, wrappedListener);
                }, e -> {
                    log.error("Failed to search model group index", e);
                    wrappedListener.onFailure(e);
                }));
            }
            catch (Exception e2) {
                log.error("Failed to create model group doc", e2);
                listener.onFailure(e2);
            }
        }
        catch (Exception e3) {
            log.error("Failed to init model group index", e3);
            listener.onFailure(e3);
        }
    }

    /** Returns true when the name search reported at least one matching model group. */
    private static boolean hasMatchingModelGroup(SearchResponse modelGroups) {
        return modelGroups != null
            && modelGroups.getHits().getTotalHits() != null
            && modelGroups.getHits().getTotalHits().value() != 0L;
    }

    /** Builds the single duplicate-name failure, naming the first returned hit when there is one. */
    private static IllegalArgumentException duplicateNameFailure(SearchResponse modelGroups) {
        SearchHit[] hits = modelGroups.getHits().getHits();
        if (hits != null && hits.length > 0) {
            return new IllegalArgumentException(
                "The name you provided is already being used by a model group with ID: " + hits[0].getId() + "."
            );
        }
        // Total hits was non-zero but no hit was returned; still fail rather than leave the caller waiting.
        return new IllegalArgumentException("The name you provided is already being used by another model group.");
    }

    /** Validates the access-control parameters of {@code input} and assembles the model group to store. */
    private MLModelGroup buildModelGroup(MLRegisterModelGroupInput input, User user, String modelName) {
        MLModelGroup.MLModelGroupBuilder builder = MLModelGroup.builder();
        if (this.modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) {
            this.validateRequestForAccessControl(input, user);
            builder = builder.access(input.getModelAccessMode().getValue());
            if (Boolean.TRUE.equals(input.getIsAddAllBackendRoles())) {
                input.setBackendRoles(user.getBackendRoles());
            }
            builder = builder.backendRoles(input.getBackendRoles()).owner(user);
        } else {
            this.validateSecurityDisabledOrModelAccessControlDisabled(input);
            builder = builder.access(AccessMode.PUBLIC.getValue());
        }
        // One timestamp for both fields so a new group never has created != last-updated.
        Instant now = Instant.now();
        return builder
            .name(modelName)
            .description(input.getDescription())
            .createdTime(now)
            .lastUpdatedTime(now)
            .tenantId(input.getTenantId())
            .build();
    }

    /** Ensures the model group index exists, writes the document, and reports the new ID to {@code listener}. */
    private void indexModelGroup(MLModelGroup mlModelGroup, ActionListener<String> listener) {
        this.mlIndicesHandler.initModelGroupIndexIfAbsent(ActionListener.wrap(res -> {
            PutDataObjectRequest putRequest = PutDataObjectRequest
                .builder()
                .tenantId(mlModelGroup.getTenantId())
                .index(".plugins-ml-model-group")
                .dataObject(mlModelGroup)
                .build();
            this.sdkClient.putDataObjectAsync(putRequest).whenComplete((r, throwable) -> {
                if (throwable != null) {
                    Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable);
                    log.error("Failed to index model group", cause);
                    listener.onFailure(cause);
                    return;
                }
                try {
                    IndexResponse indexResponse = r.indexResponse();
                    log.info("Model group creation result: {}, model group id: {}", indexResponse.getResult(), indexResponse.getId());
                    listener.onResponse(r.id());
                }
                catch (Exception e) {
                    listener.onFailure(e);
                }
            });
        }, ex -> {
            log.error("Failed to init model group index", ex);
            listener.onFailure(ex);
        }));
    }

    /**
     * Validates the access-control parameters of a registration request and, when no access mode
     * was supplied, stores a default on {@code input}: restricted if backend roles or
     * add-all-backend-roles were given, private otherwise.
     *
     * @param input request whose access mode, backend roles and add-all flag are checked
     * @param user  calling user; the admin check and the user's backend roles are consulted
     *              only for the restricted access mode
     * @throws IllegalArgumentException if the combination of parameters is not allowed
     */
    private void validateRequestForAccessControl(MLRegisterModelGroupInput input, User user) {
        final boolean addAllRoles = Boolean.TRUE.equals(input.getIsAddAllBackendRoles());
        final boolean rolesGiven = !CollectionUtils.isEmpty(input.getBackendRoles());
        AccessMode accessMode = input.getModelAccessMode();

        if (accessMode == null) {
            if (rolesGiven && addAllRoles) {
                throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time.");
            }
            if (rolesGiven || addAllRoles) {
                accessMode = AccessMode.RESTRICTED;
                input.setModelAccessMode(AccessMode.RESTRICTED);
            } else {
                // Only the request is updated here; the local mode stays null, so none of the
                // mode-specific checks below apply to a defaulted private group.
                input.setModelAccessMode(AccessMode.PRIVATE);
            }
        }

        boolean publicOrPrivate = accessMode == AccessMode.PUBLIC || accessMode == AccessMode.PRIVATE;
        if (publicOrPrivate && (rolesGiven || addAllRoles)) {
            throw new IllegalArgumentException("You can specify backend roles only for a model group with the restricted access mode.");
        }

        if (accessMode != AccessMode.RESTRICTED) {
            return;
        }

        final boolean admin = this.modelAccessControlHelper.isAdmin(user);
        if (admin && addAllRoles) {
            throw new IllegalArgumentException("Admin users cannot add all backend roles to a model group.");
        }
        if (!admin && CollectionUtils.isEmpty(user.getBackendRoles())) {
            throw new IllegalArgumentException("You must have at least one backend role to register a restricted model group.");
        }
        if (!rolesGiven && !addAllRoles) {
            throw new IllegalArgumentException("You must specify one or more backend roles or add all backend roles to register a restricted model group.");
        }
        if (rolesGiven && addAllRoles) {
            throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time.");
        }
        // A non-admin who lists roles explicitly must hold every one of them.
        if (!admin && !addAllRoles && !new HashSet<>(user.getBackendRoles()).containsAll(input.getBackendRoles())) {
            throw new IllegalArgumentException("You don't have the backend roles specified.");
        }
    }

    /**
     * Searches the model group index for documents whose {@code name.keyword} equals {@code name}
     * and hands the raw search response to {@code listener}; the caller decides what a match means.
     *
     * <p>The thread context is stashed while the search is issued. If the search fails because the
     * model group index does not exist, the listener receives a {@code null} response instead of a
     * failure.
     *
     * @param name     model group name to look up
     * @param tenantId tenant ID passed through to the search request
     * @param listener receives the search response ({@code null} when the index is missing), or the failure
     */
    public void validateUniqueModelGroupName(String name, String tenantId, ActionListener<SearchResponse> listener) throws IllegalArgumentException {
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            BoolQueryBuilder query = new BoolQueryBuilder().filter(new TermQueryBuilder("name.keyword", name));
            SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
            SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
                .builder()
                .indices(".plugins-ml-model-group")
                .searchSourceBuilder(searchSourceBuilder)
                .tenantId(tenantId)
                .build();
            this.sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((r, throwable) -> {
                if (throwable != null) {
                    if (ExceptionsHelper.unwrap(throwable, IndexNotFoundException.class) != null) {
                        // No index means no model group can hold the name yet.
                        log.debug("Model group index does not exist");
                        listener.onResponse(null);
                    } else {
                        Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable);
                        log.error("Failed to search model group index", cause);
                        listener.onFailure(cause);
                    }
                    return;
                }
                try {
                    SearchResponse searchResponse = r.searchResponse();
                    log.info("Model group search complete: {}", searchResponse.getHits().getTotalHits());
                    listener.onResponse(searchResponse);
                }
                catch (Exception e) {
                    log.error("Failed to parse search response", e);
                    // Keep the original exception as the cause so it is not lost to callers.
                    listener.onFailure(new OpenSearchStatusException("Failed to parse search response", RestStatus.INTERNAL_SERVER_ERROR, e));
                }
            });
        }
        catch (Exception e) {
            log.error("Failed to search model group index", e);
            listener.onFailure(e);
        }
    }

    /**
     * Fetches a model group document by ID using the supplied client and reports the outcome
     * to {@code listener}.
     *
     * @param sdkClient    client used to issue the get request (the argument, not this instance's field)
     * @param modelGroupId ID of the model group document to fetch
     * @param listener     receives the get response, or the failure
     */
    public void getModelGroupResponse(SdkClient sdkClient, String modelGroupId, ActionListener<GetResponse> listener) {
        sdkClient.getDataObjectAsync(this.buildGetModelGroupRequest(modelGroupId)).whenComplete((response, throwable) -> {
            if (throwable == null) {
                this.processModelGroupResponse(response, modelGroupId, listener);
            } else {
                this.handleError(throwable, listener);
            }
        });
    }

    /**
     * Builds a get request for the given document ID in the model group index.
     *
     * @param modelGroupId ID of the model group document
     * @return request targeting {@code .plugins-ml-model-group} with that ID
     */
    private GetDataObjectRequest buildGetModelGroupRequest(String modelGroupId) {
        return GetDataObjectRequest
            .builder()
            .index(".plugins-ml-model-group")
            .id(modelGroupId)
            .build();
    }

    /**
     * Converts a failure from the asynchronous get call into an {@link Exception} and passes
     * it to the listener.
     *
     * @param throwable failure reported by the asynchronous call
     * @param listener  listener to fail
     */
    private void handleError(Throwable throwable, ActionListener<GetResponse> listener) {
        listener.onFailure(SdkClientUtils.unwrapAndConvertToException(throwable));
    }

    /**
     * Forwards an existing model group document to {@link #parseAndRespond}, or fails the
     * listener with {@link MLResourceNotFoundException} when the get response is absent or
     * reports that the document does not exist.
     *
     * @param response     data-object response wrapping the get response
     * @param modelGroupId ID that was requested; used in the not-found message
     * @param listener     receives the parsed response, or the failure
     */
    private void processModelGroupResponse(GetDataObjectResponse response, String modelGroupId, ActionListener<GetResponse> listener) {
        try {
            GetResponse getResponse = response.getResponse();
            boolean found = getResponse != null && getResponse.isExists();
            if (found) {
                this.parseAndRespond(getResponse, listener);
            } else {
                listener.onFailure(new MLResourceNotFoundException("Failed to find model group with ID: " + modelGroupId));
            }
        }
        catch (Exception e) {
            // Anything thrown above, including from the delegate call, is reported as a failure.
            listener.onFailure(e);
        }
    }

    /**
     * Serialises the get response to JSON, parses it back with
     * {@link GetResponse#fromXContent}, and passes the parsed response to the listener.
     * Any exception along the way is logged with the document ID and reported as a failure.
     *
     * @param getResponse response to round-trip through JSON
     * @param listener    receives the re-parsed response, or the failure
     */
    private void parseAndRespond(GetResponse getResponse, ActionListener<GetResponse> listener) {
        try {
            String json = Strings.toString(MediaTypeRegistry.JSON, getResponse);
            try (
                XContentParser parser = JsonXContent.jsonXContent
                    .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json)
            ) {
                listener.onResponse(GetResponse.fromXContent(parser));
            }
        }
        catch (Exception e) {
            log.error("Failed to parse model group response: {}", getResponse.getId(), e);
            listener.onFailure(e);
        }
    }

    /**
     * Rejects a request that carries any access-control parameter (access mode, the
     * add-all-backend-roles flag, or backend roles).
     *
     * @param input registration request to check
     * @throws IllegalArgumentException if any access-control parameter is present
     */
    private void validateSecurityDisabledOrModelAccessControlDisabled(MLRegisterModelGroupInput input) {
        boolean accessModeGiven = input.getModelAccessMode() != null;
        boolean addAllFlagGiven = input.getIsAddAllBackendRoles() != null;
        boolean rolesGiven = !CollectionUtils.isEmpty(input.getBackendRoles());
        if (accessModeGiven || addAllFlagGiven || rolesGiven) {
            throw new IllegalArgumentException("You cannot specify model access control parameters because the Security plugin or model access control is disabled on your cluster.");
        }
    }
}

